diff --git a/.env b/.env
index e066cc2b2b350..d17408a1a817d 100644
--- a/.env
+++ b/.env
@@ -1,8 +1,8 @@
 IMAGE_REPO=milvusdb
 IMAGE_ARCH=amd64
 OS_NAME=ubuntu20.04
-DATE_VERSION=20231011-11b5213
-LATEST_DATE_VERSION=20231011-11b5213
+DATE_VERSION=20231024-4faba61
+LATEST_DATE_VERSION=20231024-4faba61
 GPU_DATE_VERSION=20230822-a64488a
 LATEST_GPU_DATE_VERSION=20230317-a1c7b0c
 MINIO_ADDRESS=minio:9000
diff --git a/.github/workflows/mac.yaml b/.github/workflows/mac.yaml
index 1bc9e73fffd04..f3d7f50b286f0 100644
--- a/.github/workflows/mac.yaml
+++ b/.github/workflows/mac.yaml
@@ -28,7 +28,7 @@ jobs:
   mac:
     name: Code Checker MacOS 12
     runs-on: macos-12
-    timeout-minutes: 300
+    timeout-minutes: 300
     steps:
       - name: Checkout
         uses: actions/checkout@v2
@@ -73,7 +73,7 @@ jobs:
           fi
           ls -alh /var/tmp/ccache
           brew install libomp ninja openblas ccache pkg-config
-          pip3 install conan==1.58.0
+          pip3 install conan==1.61.0
           if [[ ! -d "/usr/local/opt/llvm" ]]; then
             ln -s /usr/local/opt/llvm@14 /usr/local/opt/llvm
           fi
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 8606a20d5c7da..4a582e76eeca8 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -118,7 +118,7 @@ Milvus uses Conan to manage third-party dependencies for c++.
 Install Conan
 ```shell
-pip install conan==1.58.0
+pip install conan==1.61.0
 ```
-Note: Conan version 2.x is not currently supported, please use version 1.58.
+Note: Conan version 2.x is not currently supported, please use version 1.61.
@@ -288,7 +288,7 @@ start the cluster on your host machine
 ```shell
 $ ./build/builder.sh make install // build milvus
-$ ./build/build_image.sh // build milvus lastest docker image
+$ ./build/build_image.sh // build milvus latest docker image
 $ docker images // check if milvus latest image is ready
 REPOSITORY TAG IMAGE ID CREATED SIZE
 milvusdb/milvus latest 63c62ff7c1b7 52 minutes ago 570MB
diff --git a/README.md b/README.md
index 707e0b419e5c8..bf305c3ac22ea 100644
--- a/README.md
+++ b/README.md
@@ -169,7 +169,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut
 ### All contributors
-
+
@@ -290,6 +290,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut
+
diff --git a/README_CN.md b/README_CN.md
index 3cf9a188af826..ef9b4b6ff02aa 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -154,7 +154,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解
 ### All contributors
-
+
@@ -275,6 +275,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解
+
diff --git a/build/docker/builder/cpu/amazonlinux2023/Dockerfile b/build/docker/builder/cpu/amazonlinux2023/Dockerfile
index c9f3f15cecdb2..c349e17f2cceb 100644
--- a/build/docker/builder/cpu/amazonlinux2023/Dockerfile
+++ b/build/docker/builder/cpu/amazonlinux2023/Dockerfile
@@ -18,7 +18,7 @@ RUN yum install -y wget g++ gcc gdb libatomic libstdc++-static git make zip unzi
     pkg-config libuuid-devel libaio perl-IPC-Cmd && \
     rm -rf /var/cache/yum/*
-RUN pip3 install conan==1.58.0
+RUN pip3 install conan==1.61.0
 RUN echo "target arch $TARGETARCH"
 RUN wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.4-linux-`uname -m`.tar.gz" | tar --strip-components=1 -xz -C /usr/local
diff --git a/build/docker/builder/cpu/ubuntu20.04/Dockerfile b/build/docker/builder/cpu/ubuntu20.04/Dockerfile
index 3e95294dfd5f2..8f9af87d88ca7 100644
--- a/build/docker/builder/cpu/ubuntu20.04/Dockerfile
+++ b/build/docker/builder/cpu/ubuntu20.04/Dockerfile
@@ -22,7 +22,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce
     apt-get remove --purge -y && \
     rm -rf /var/lib/apt/lists/*
-RUN pip3 install conan==1.58.0
+RUN pip3 install conan==1.61.0
 RUN echo "target arch $TARGETARCH"
 RUN wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.4-linux-`uname -m`.tar.gz" | tar --strip-components=1 -xz -C /usr/local
diff --git a/build/docker/builder/gpu/ubuntu20.04/Dockerfile b/build/docker/builder/gpu/ubuntu20.04/Dockerfile
index fa2303f3f26b2..dc0febf4e0a11 100644
--- a/build/docker/builder/gpu/ubuntu20.04/Dockerfile
+++ b/build/docker/builder/gpu/ubuntu20.04/Dockerfile
@@ -11,7 +11,7 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
-RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 && \
+RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 ninja-build && \
     wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.0-linux-x86_64.tar.gz" | tar --strip-components=1 -xz -C /usr/local && \
     apt-get update && apt-get install -y --no-install-recommends \
     g++ gcc gfortran git make ccache libssl-dev zlib1g-dev zip unzip \
@@ -20,7 +20,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce
     apt-get remove --purge -y && \
     rm -rf /var/lib/apt/lists/*
-RUN pip3 install conan==1.58.0
+RUN pip3 install conan==1.61.0
 # Instal openblas
 # RUN wget https://github.com/xianyi/OpenBLAS/archive/v0.3.21.tar.gz && \
diff --git a/ci/jenkins/PublishImages.groovy b/ci/jenkins/PublishImages.groovy
index 2139fe5834b78..0d49ac0afed0b 100644
--- a/ci/jenkins/PublishImages.groovy
+++ b/ci/jenkins/PublishImages.groovy
@@ -27,7 +27,7 @@ pipeline {
     }
     stages {
-        stage('Generat Image Tag') {
+        stage('Generate Image Tag') {
             steps {
                 script {
                     def date = sh(returnStdout: true, script: 'date +%Y%m%d').trim()
diff --git a/configs/advanced/etcd.yaml b/configs/advanced/etcd.yaml
index e2d3e727f8068..79d005fb99fe7 100644
--- a/configs/advanced/etcd.yaml
+++ b/configs/advanced/etcd.yaml
@@ -15,7 +15,7 @@
 # limitations under the License.
 # This is the configuration file for the etcd server.
-# Only standalone users with embeded etcd should change this file, others could just keep this file As Is.
+# Only standalone users with embedded etcd should change this file, others could just keep this file As Is.
 # All the etcd client should be added to milvus.yaml if necessary
 # Human-readable name for this member.
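The etcd.yaml header above draws the line between standalone deployments that run an embedded etcd and cluster deployments that point at an external etcd via milvus.yaml. For readers unfamiliar with the embedded mode, here is a minimal sketch of starting an embedded etcd with the upstream `go.etcd.io/etcd/server/v3/embed` package; it is illustrative only, not Milvus's actual bootstrap path, and the data directory name is made up.

```go
package main

import (
	"log"
	"time"

	"go.etcd.io/etcd/server/v3/embed"
)

func main() {
	// NewConfig listens on the default client/peer ports (2379/2380);
	// Dir is where the embedded member persists data (name is illustrative).
	cfg := embed.NewConfig()
	cfg.Dir = "standalone.etcd"

	e, err := embed.StartEtcd(cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer e.Close()

	select {
	case <-e.Server.ReadyNotify():
		log.Println("embedded etcd is ready")
	case <-time.After(60 * time.Second):
		e.Server.Stop() // trigger a shutdown if startup hangs
		log.Fatal("embedded etcd took too long to start")
	}
	log.Fatal(<-e.Err()) // block until the server errors out or is stopped
}
```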
diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index adea48aee559e..c180ed2d012fe 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -111,7 +111,7 @@ mq:
   pulsar:
     address: localhost # Address of pulsar
     port: 6650 # Port of Pulsar
-    webport: 80 # Web port of pulsar, if you connect direcly without proxy, should use 8080
+    webport: 80 # Web port of pulsar, if you connect directly without proxy, should use 8080
     maxMessageSize: 5242880 # 5 * 1024 * 1024 Bytes, Maximum size of each message in pulsar.
     tenant: public
     namespace: default
@@ -346,7 +346,7 @@ dataCoord:
   balanceInterval: 360 #The interval for the channelBalancer on datacoord to check balance status
   segment:
     maxSize: 512 # Maximum size of a segment in MB
-    diskSegmentMaxSize: 2048 # Maximun size of a segment in MB for collection which has Disk index
+    diskSegmentMaxSize: 2048 # Maximum size of a segment in MB for collection which has Disk index
     sealProportion: 0.23
     # The time of the assignment expiration in ms
     # Warning! this parameter is an expert variable and closely related to data integrity. Without specific
@@ -443,7 +443,7 @@ grpc:
   serverMaxSendSize: 536870912
   serverMaxRecvSize: 536870912
   client:
-    compressionEnabled: true
+    compressionEnabled: false
     dialTimeout: 200
     keepAliveTime: 10000
     keepAliveTimeout: 20000
diff --git a/docs/design_docs/20210604-datanode_flowgraph_recovery_design.md b/docs/design_docs/20210604-datanode_flowgraph_recovery_design.md
index 9f89ec50629c3..7c4f28e655025 100644
--- a/docs/design_docs/20210604-datanode_flowgraph_recovery_design.md
+++ b/docs/design_docs/20210604-datanode_flowgraph_recovery_design.md
@@ -74,7 +74,7 @@ Supposing we have segments `s1, s2, s3`, corresponding positions `p1, p2, p3`
 const filter_threshold = recovery_time
 // mp means msgPack
 for mp := seeking(p1) {
-    if mp.position.endtime < filter_threshod {
+    if mp.position.endtime < filter_threshold {
         if mp.position < p3 {
             filter s3
         }
diff --git a/docs/design_docs/20211217-milvus_create_collection.md b/docs/design_docs/20211217-milvus_create_collection.md
index c98f03a666faf..44f621f95364d 100644
--- a/docs/design_docs/20211217-milvus_create_collection.md
+++ b/docs/design_docs/20211217-milvus_create_collection.md
@@ -86,7 +86,7 @@ type createCollectionTask struct {
 }
 ```
-   - `PostExecute`, `CreateCollectonTask` does nothing at this phase, and return directly.
+   - `PostExecute`, `CreateCollectionTask` does nothing at this phase and returns directly.
 4. `RootCoord` would wrap the `CreateCollection` request into `CreateCollectionReqTask`, and then call function `executeTask`.
    `executeTask` would return until the `context` is done or `CreateCollectionReqTask.Execute` is returned.
@@ -104,7 +104,7 @@ type CreateCollectionReqTask struct {
 }
 ```
-5. `CreateCollectionReqTask.Execute` would alloc `CollecitonID` and default `PartitionID`, and set `Virtual Channel` and `Physical Channel`, which are used by `MsgStream`, then write the `Collection`'s meta into `metaTable`
+5. `CreateCollectionReqTask.Execute` would alloc `CollectionID` and default `PartitionID`, and set `Virtual Channel` and `Physical Channel`, which are used by `MsgStream`, then write the `Collection`'s meta into `metaTable`
 6. After `Collection`'s meta written into `metaTable`, `Milvus` would consider this collection has been created successfully.
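Step 4 of the create-collection walkthrough above says `executeTask` returns as soon as either the `context` is done or `CreateCollectionReqTask.Execute` returns. A minimal sketch of that race follows; the interface shape and names are illustrative stand-ins, not the exact Milvus types.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// task mirrors the PreExecute/Execute/PostExecute phases described above;
// the interface shape is illustrative, not the exact Milvus type.
type task interface {
	PreExecute(ctx context.Context) error
	Execute(ctx context.Context) error
	PostExecute(ctx context.Context) error
}

// executeTask returns when the context is done or Execute returns,
// matching the behavior step 4 describes for RootCoord.
func executeTask(ctx context.Context, t task) error {
	errCh := make(chan error, 1) // buffered so the goroutine never blocks on send
	go func() { errCh <- t.Execute(ctx) }()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-errCh:
		return err
	}
}

// createCollectionTask is a stand-in whose PostExecute is a no-op,
// as the doc notes for CreateCollectionTask.
type createCollectionTask struct{}

func (t *createCollectionTask) PreExecute(ctx context.Context) error { return nil }
func (t *createCollectionTask) Execute(ctx context.Context) error {
	time.Sleep(10 * time.Millisecond) // pretend to alloc IDs and write meta
	return nil
}
func (t *createCollectionTask) PostExecute(ctx context.Context) error { return nil }

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println(executeTask(ctx, &createCollectionTask{})) // prints <nil>
}
```

The buffered error channel lets the goroutine running `Execute` finish and exit even when the context wins the race.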
diff --git a/docs/design_docs/20220105-proxy.md b/docs/design_docs/20220105-proxy.md
index 20af39b16b688..c447a5a1652cb 100644
--- a/docs/design_docs/20220105-proxy.md
+++ b/docs/design_docs/20220105-proxy.md
@@ -127,7 +127,7 @@ future work.
 For DqRequest, request and result data are written to the stream. The request data will be written to
 DqRequestChannel, and the result data will be written to DqResultChannel. Proxy will write the request of the collection into the
-DqRequestChannel, and the DqReqeustChannel will be jointly subscribed by a group of query nodes. When all query nodes
+DqRequestChannel, and the DqRequestChannel will be jointly subscribed by a group of query nodes. When all query nodes
 receive the DqRequest, they will write the query results into the DqResultChannel corresponding to the collection. As
 the consumer of the DqResultChannel, Proxy is responsible for collecting the query results and aggregating them, The
 result is then returned to the client.
diff --git a/docs/design_docs/20220105-query_boolean_expr.md b/docs/design_docs/20220105-query_boolean_expr.md
index 656d27a1a1894..5e030f925bd5f 100644
--- a/docs/design_docs/20220105-query_boolean_expr.md
+++ b/docs/design_docs/20220105-query_boolean_expr.md
@@ -31,7 +31,7 @@ ConstantExpr :=
               | UnaryArithOp ConstantExpr
 Constant :=
-    INTERGER
+    INTEGER
     | FLOAT_NUMBER
 UnaryArithOp :=
@@ -64,7 +64,7 @@ CmpOp :=
      | "=="
      | "!="
-INTERGER := 整数
+INTEGER := 整数
 FLOAT_NUM := 浮点数
 IDENTIFIER := 列名
 ```
diff --git a/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md b/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md
index 338aaa986ecbc..ec87e140e2630 100644
--- a/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md
+++ b/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md
@@ -61,7 +61,7 @@ The rules system shall follow is:
 {% note %}
-**Note:** Segments meta shall be updated *BEFORE* changing the channel checkpoint in case of datanode crashing during the prodedure. Under this premise, reconsuming from the old checkpoint shall recover all the data and duplidated entires will be discarded by segment checkpoints.
+**Note:** Segments meta shall be updated *BEFORE* changing the channel checkpoint in case of datanode crashing during the procedure. Under this premise, reconsuming from the old checkpoint shall recover all the data and duplicated entries will be discarded by segment checkpoints.
 {% endnote %}
@@ -78,7 +78,7 @@ The winning option is to:
 **Note:** `Datacoord` reloads from metastore periodically.
 Optimization 1: reload channel checkpoint first, then reload segment meta if newly read revision is greater than in-memory one.
-Optimization 2: After `L0 segemnt` is implemented, datacoord shall refresh growing segments only.
+Optimization 2: After `L0 segment` is implemented, datacoord shall refresh growing segments only.
 {% endnote %}
diff --git a/docs/design_docs/segcore/segment_growing.md b/docs/design_docs/segcore/segment_growing.md
index c3f8ad7da4027..74b3011bf40a3 100644
--- a/docs/design_docs/segcore/segment_growing.md
+++ b/docs/design_docs/segcore/segment_growing.md
@@ -2,13 +2,13 @@
 Growing segment has the following additional interfaces:
-1. `PreInsert(size) -> reseveredOffset`: serial interface, which reserves space for future insertion and returns the `reseveredOffset`.
+1. `PreInsert(size) -> reservedOffset`: serial interface, which reserves space for future insertion and returns the `reservedOffset`.
-2.
`Insert(reseveredOffset, size, ...Data...)`: write `...Data...` into range `[reseveredOffset, reseveredOffset + size)`. This interface is allowed to be called concurrently. +2. `Insert(reservedOffset, size, ...Data...)`: write `...Data...` into range `[reservedOffset, reservedOffset + size)`. This interface is allowed to be called concurrently. 1. `...Data...` contains row_ids, timestamps two system attributes, and other columns 2. data columns can be stored either row-based or column-based. - 3. `PreDelete & Delete(reseveredOffset, row_ids, timestamps)` is a delete interface similar to insert interface. + 3. `PreDelete & Delete(reservedOffset, row_ids, timestamps)` is a delete interface similar to insert interface. Growing segment stores data in the form of chunk. The number of rows in each chunk is restricted by configs. diff --git a/docs/developer_guides/appendix_a_basic_components.md b/docs/developer_guides/appendix_a_basic_components.md index cc4b42b708846..62eee6999d36a 100644 --- a/docs/developer_guides/appendix_a_basic_components.md +++ b/docs/developer_guides/appendix_a_basic_components.md @@ -107,7 +107,7 @@ type Session struct { } // NewSession is a helper to build Session object. -// ServerID, ServerName, Address, Exclusive will be assigned after registeration. +// ServerID, ServerName, Address, Exclusive will be assigned after registration. // metaRoot is a path in etcd to save session information. // etcdEndpoints is to init etcdCli when NewSession func NewSession(ctx context.Context, metaRoot string, etcdEndpoints []string) *Session {} diff --git a/docs/developer_guides/chap04_message_stream.md b/docs/developer_guides/chap04_message_stream.md index 3a6e0004bd62f..e2822477a2e6c 100644 --- a/docs/developer_guides/chap04_message_stream.md +++ b/docs/developer_guides/chap04_message_stream.md @@ -7,7 +7,7 @@ ```go type Client interface { CreateChannels(req CreateChannelRequest) (CreateChannelResponse, error) - DestoryChannels(req DestoryChannelRequest) error + DestroyChannels(req DestroyChannelRequest) error DescribeChannels(req DescribeChannelRequest) (DescribeChannelResponse, error) } ``` @@ -32,10 +32,10 @@ type CreateChannelResponse struct { } ``` -- _DestoryChannels_ +- _DestroyChannels_ ```go -type DestoryChannelRequest struct { +type DestroyChannelRequest struct { ChannelNames []string } ``` diff --git a/docs/developer_guides/chap05_proxy.md b/docs/developer_guides/chap05_proxy.md index 15240c1606e4f..8bff965ea4338 100644 --- a/docs/developer_guides/chap05_proxy.md +++ b/docs/developer_guides/chap05_proxy.md @@ -105,7 +105,7 @@ type MilvusService interface { CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) - LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitonRequest) (*commonpb.Status, error) + LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitionRequest) (*commonpb.Status, error) ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionRequest) (*commonpb.Status, error) GetPartitionStatistics(ctx context.Context, request *milvuspb.PartitionStatsRequest) (*milvuspb.PartitionStatsResponse, error) ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error) @@ -225,7 +225,7 @@ type CollectionSchema struct { Fields 
[]*FieldSchema } -type LoadPartitonRequest struct { +type LoadPartitionRequest struct { Base *commonpb.MsgBase DbID UniqueID CollectionID UniqueID diff --git a/docs/developer_guides/chap07_query_coordinator.md b/docs/developer_guides/chap07_query_coordinator.md index 067b2d534bca5..2bfba980be175 100644 --- a/docs/developer_guides/chap07_query_coordinator.md +++ b/docs/developer_guides/chap07_query_coordinator.md @@ -134,7 +134,7 @@ type PartitionStatesResponse struct { - _LoadPartitions_ ```go -type LoadPartitonRequest struct { +type LoadPartitionRequest struct { Base *commonpb.MsgBase DbID UniqueID CollectionID UniqueID diff --git a/docs/user_guides/tls_proxy.md b/docs/user_guides/tls_proxy.md index a359dff917010..63251f7ca92b1 100644 --- a/docs/user_guides/tls_proxy.md +++ b/docs/user_guides/tls_proxy.md @@ -78,7 +78,7 @@ certs = $dir/certs # Where the issued certs are kept crl_dir = $dir/crl # Where the issued crl are kept database = $dir/index.txt # database index file. #unique_subject = no # Set to 'no' to allow creation of - # several ctificates with same subject. + # several certificates with same subject. new_certs_dir = $dir/newcerts # default place for new certs. certificate = $dir/cacert.pem # The CA certificate @@ -89,7 +89,7 @@ crl = $dir/crl.pem # The current CRL private_key = $dir/private/cakey.pem# The private key RANDFILE = $dir/private/.rand # private random number file -x509_extensions = usr_cert # The extentions to add to the cert +x509_extensions = usr_cert # The extensions to add to the cert # Comment out the following two lines for the "traditional" # (and highly broken) format. @@ -141,7 +141,7 @@ default_bits = 2048 default_keyfile = privkey.pem distinguished_name = req_distinguished_name attributes = req_attributes -x509_extensions = v3_ca # The extentions to add to the self signed cert +x509_extensions = v3_ca # The extensions to add to the self signed cert # Passwords for private keys if not present they will be prompted for # input_password = secret diff --git a/go.mod b/go.mod index 3501b30ecb576..0ff6ffcee9c39 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.16.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231019101159-a0a6f5e7eff8 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2 github.com/milvus-io/milvus/pkg v0.0.1 github.com/minio/minio-go/v7 v7.0.61 github.com/prometheus/client_golang v1.14.0 diff --git a/go.sum b/go.sum index 8e5e5d28b0e25..dcaa85408c912 100644 --- a/go.sum +++ b/go.sum @@ -581,8 +581,8 @@ github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/le github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231019101159-a0a6f5e7eff8 h1:GoGErEOhdWjwSfQilXso3eINqb11yEBDLtoBMNdlve0= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231019101159-a0a6f5e7eff8/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2 h1:tBcKiEUcX6i3MaFYvMJO1F7R6fIoeLFkg1kSGE1Tvpk= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= 
github.com/milvus-io/milvus-storage/go v0.0.0-20231017063757-b4720fe2ec8f h1:NGI0nposnZ2S0K1W/pEWRNzDkaLraaI10aMgKCMnVYs= github.com/milvus-io/milvus-storage/go v0.0.0-20231017063757-b4720fe2ec8f/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= diff --git a/internal/core/conanfile.py b/internal/core/conanfile.py index 9bcea02577b93..e44130cc05dc3 100644 --- a/internal/core/conanfile.py +++ b/internal/core/conanfile.py @@ -14,18 +14,6 @@ class MilvusConan(ConanFile): "lzo/2.10", "arrow/12.0.1", "openssl/3.1.2", - "s2n/1.3.31@milvus/dev", - "aws-c-common/0.8.2@milvus/dev", - "aws-c-compression/0.2.15@milvus/dev", - "aws-c-sdkutils/0.1.3@milvus/dev", - "aws-checksums/0.1.13@milvus/dev", - "aws-c-cal/0.5.20@milvus/dev", - "aws-c-io/0.10.20@milvus/dev", - "aws-c-http/0.6.13@milvus/dev", - "aws-c-auth/0.6.11@milvus/dev", - "aws-c-event-stream/0.2.7@milvus/dev", - "aws-c-s3/0.1.37@milvus/dev", - "aws-crt-cpp/0.17.23@milvus/dev", "aws-sdk-cpp/1.9.234", "googleapis/cci.20221108", "benchmark/1.7.0", diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index 0d77425c08b97..9feb380db6c6b 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -251,7 +251,7 @@ func (c *ChannelManager) unwatchDroppedChannels() { nodeChannels := c.store.GetChannels() for _, nodeChannel := range nodeChannels { for _, ch := range nodeChannel.Channels { - if !c.h.CheckShouldDropChannel(ch.Name, ch.CollectionID) { + if !c.h.CheckShouldDropChannel(ch.Name) { continue } err := c.remove(nodeChannel.NodeID, ch) @@ -788,7 +788,7 @@ func (c *ChannelManager) Reassign(originNodeID UniqueID, channelName string) err c.mu.RUnlock() reallocates := &NodeChannelInfo{originNodeID, []*channel{ch}} - isDropped := c.isMarkedDrop(channelName, ch.CollectionID) + isDropped := c.isMarkedDrop(channelName) c.mu.Lock() defer c.mu.Unlock() @@ -843,7 +843,7 @@ func (c *ChannelManager) CleanupAndReassign(nodeID UniqueID, channelName string) } reallocates := &NodeChannelInfo{nodeID, []*channel{chToCleanUp}} - isDropped := c.isMarkedDrop(channelName, chToCleanUp.CollectionID) + isDropped := c.isMarkedDrop(channelName) c.mu.Lock() defer c.mu.Unlock() @@ -910,8 +910,8 @@ func (c *ChannelManager) getNodeIDByChannelName(chName string) (bool, UniqueID) return false, 0 } -func (c *ChannelManager) isMarkedDrop(channelName string, collectionID UniqueID) bool { - return c.h.CheckShouldDropChannel(channelName, collectionID) +func (c *ChannelManager) isMarkedDrop(channelName string) bool { + return c.h.CheckShouldDropChannel(channelName) } func getReleaseOp(nodeID UniqueID, ch *channel) ChannelOpSet { diff --git a/internal/datacoord/channel_manager_test.go b/internal/datacoord/channel_manager_test.go index 14b14d3a421dd..cf7eaf12b68eb 100644 --- a/internal/datacoord/channel_manager_test.go +++ b/internal/datacoord/channel_manager_test.go @@ -581,7 +581,7 @@ func TestChannelManager(t *testing.T) { collectionID := UniqueID(5) handler := NewNMockHandler(t) handler.EXPECT(). - CheckShouldDropChannel(mock.Anything, mock.Anything). + CheckShouldDropChannel(mock.Anything). Return(true) handler.EXPECT().FinishDropChannel(mock.Anything).Return(nil) chManager, err := NewChannelManager(watchkv, handler) @@ -603,8 +603,8 @@ func TestChannelManager(t *testing.T) { var err error handler := NewNMockHandler(t) handler.EXPECT(). - CheckShouldDropChannel(mock.Anything, mock.Anything). 
- Run(func(channel string, collectionID int64) { + CheckShouldDropChannel(mock.Anything). + Run(func(channel string) { channels, err := chManager.store.Delete(1) assert.NoError(t, err) assert.Equal(t, 1, len(channels)) @@ -628,8 +628,8 @@ func TestChannelManager(t *testing.T) { var err error handler := NewNMockHandler(t) handler.EXPECT(). - CheckShouldDropChannel(mock.Anything, mock.Anything). - Run(func(channel string, collectionID int64) { + CheckShouldDropChannel(mock.Anything). + Run(func(channel string) { channels, err := chManager.store.Delete(1) assert.NoError(t, err) assert.Equal(t, 1, len(channels)) @@ -659,7 +659,7 @@ func TestChannelManager(t *testing.T) { t.Run("test CleanupAndReassign with dropped channel", func(t *testing.T) { handler := NewNMockHandler(t) handler.EXPECT(). - CheckShouldDropChannel(mock.Anything, mock.Anything). + CheckShouldDropChannel(mock.Anything). Return(true) handler.EXPECT().FinishDropChannel(mock.Anything).Return(nil) chManager, err := NewChannelManager(watchkv, handler) diff --git a/internal/datacoord/handler.go b/internal/datacoord/handler.go index 4db996cda0ace..d9f5183977064 100644 --- a/internal/datacoord/handler.go +++ b/internal/datacoord/handler.go @@ -39,8 +39,8 @@ type Handler interface { GetQueryVChanPositions(ch *channel, partitionIDs ...UniqueID) *datapb.VchannelInfo // GetDataVChanPositions gets the information recovery needed of a channel for DataNode GetDataVChanPositions(ch *channel, partitionID UniqueID) *datapb.VchannelInfo - CheckShouldDropChannel(channel string, collectionID UniqueID) bool - FinishDropChannel(channel string) error + CheckShouldDropChannel(ch string) bool + FinishDropChannel(ch string) error GetCollection(ctx context.Context, collectionID UniqueID) (*collectionInfo, error) } @@ -403,20 +403,8 @@ func (h *ServerHandler) GetCollection(ctx context.Context, collectionID UniqueID } // CheckShouldDropChannel returns whether specified channel is marked to be removed -func (h *ServerHandler) CheckShouldDropChannel(channel string, collectionID UniqueID) bool { - if h.s.meta.catalog.ShouldDropChannel(h.s.ctx, channel) { - return true - } - // collectionID parse from channelName - has, err := h.HasCollection(h.s.ctx, collectionID) - if err != nil { - log.Info("datacoord ServerHandler CheckShouldDropChannel hasCollection failed", zap.Error(err)) - return false - } - log.Info("datacoord ServerHandler CheckShouldDropChannel hasCollection", zap.Bool("shouldDropChannel", !has), - zap.String("channel", channel)) - - return !has +func (h *ServerHandler) CheckShouldDropChannel(channel string) bool { + return h.s.meta.catalog.ShouldDropChannel(h.s.ctx, channel) } // FinishDropChannel cleans up the remove flag for channels diff --git a/internal/datacoord/mock_handler.go b/internal/datacoord/mock_handler.go index 9481afbb5f3ab..125fc7002a843 100644 --- a/internal/datacoord/mock_handler.go +++ b/internal/datacoord/mock_handler.go @@ -22,13 +22,13 @@ func (_m *NMockHandler) EXPECT() *NMockHandler_Expecter { return &NMockHandler_Expecter{mock: &_m.Mock} } -// CheckShouldDropChannel provides a mock function with given fields: channel, collectionID -func (_m *NMockHandler) CheckShouldDropChannel(channel string, collectionID int64) bool { - ret := _m.Called(channel, collectionID) +// CheckShouldDropChannel provides a mock function with given fields: ch +func (_m *NMockHandler) CheckShouldDropChannel(ch string) bool { + ret := _m.Called(ch) var r0 bool - if rf, ok := ret.Get(0).(func(string, int64) bool); ok { - r0 = rf(channel, collectionID) 
+ if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(ch) } else { r0 = ret.Get(0).(bool) } @@ -42,15 +42,14 @@ type NMockHandler_CheckShouldDropChannel_Call struct { } // CheckShouldDropChannel is a helper method to define mock.On call -// - channel string -// - collectionID int64 -func (_e *NMockHandler_Expecter) CheckShouldDropChannel(channel interface{}, collectionID interface{}) *NMockHandler_CheckShouldDropChannel_Call { - return &NMockHandler_CheckShouldDropChannel_Call{Call: _e.mock.On("CheckShouldDropChannel", channel, collectionID)} +// - ch string +func (_e *NMockHandler_Expecter) CheckShouldDropChannel(ch interface{}) *NMockHandler_CheckShouldDropChannel_Call { + return &NMockHandler_CheckShouldDropChannel_Call{Call: _e.mock.On("CheckShouldDropChannel", ch)} } -func (_c *NMockHandler_CheckShouldDropChannel_Call) Run(run func(channel string, collectionID int64)) *NMockHandler_CheckShouldDropChannel_Call { +func (_c *NMockHandler_CheckShouldDropChannel_Call) Run(run func(ch string)) *NMockHandler_CheckShouldDropChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(int64)) + run(args[0].(string)) }) return _c } @@ -60,18 +59,18 @@ func (_c *NMockHandler_CheckShouldDropChannel_Call) Return(_a0 bool) *NMockHandl return _c } -func (_c *NMockHandler_CheckShouldDropChannel_Call) RunAndReturn(run func(string, int64) bool) *NMockHandler_CheckShouldDropChannel_Call { +func (_c *NMockHandler_CheckShouldDropChannel_Call) RunAndReturn(run func(string) bool) *NMockHandler_CheckShouldDropChannel_Call { _c.Call.Return(run) return _c } -// FinishDropChannel provides a mock function with given fields: channel -func (_m *NMockHandler) FinishDropChannel(channel string) error { - ret := _m.Called(channel) +// FinishDropChannel provides a mock function with given fields: ch +func (_m *NMockHandler) FinishDropChannel(ch string) error { + ret := _m.Called(ch) var r0 error if rf, ok := ret.Get(0).(func(string) error); ok { - r0 = rf(channel) + r0 = rf(ch) } else { r0 = ret.Error(0) } @@ -85,12 +84,12 @@ type NMockHandler_FinishDropChannel_Call struct { } // FinishDropChannel is a helper method to define mock.On call -// - channel string -func (_e *NMockHandler_Expecter) FinishDropChannel(channel interface{}) *NMockHandler_FinishDropChannel_Call { - return &NMockHandler_FinishDropChannel_Call{Call: _e.mock.On("FinishDropChannel", channel)} +// - ch string +func (_e *NMockHandler_Expecter) FinishDropChannel(ch interface{}) *NMockHandler_FinishDropChannel_Call { + return &NMockHandler_FinishDropChannel_Call{Call: _e.mock.On("FinishDropChannel", ch)} } -func (_c *NMockHandler_FinishDropChannel_Call) Run(run func(channel string)) *NMockHandler_FinishDropChannel_Call { +func (_c *NMockHandler_FinishDropChannel_Call) Run(run func(ch string)) *NMockHandler_FinishDropChannel_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(string)) }) diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index b477b7db6efe8..260d5d940b497 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -784,7 +784,7 @@ func (h *mockHandler) GetDataVChanPositions(channel *channel, partitionID Unique } } -func (h *mockHandler) CheckShouldDropChannel(channel string, collectionID UniqueID) bool { +func (h *mockHandler) CheckShouldDropChannel(channel string) bool { return false } diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 7cde094158f49..f990c4f47eec5 100644 --- a/internal/datacoord/server.go +++ 
b/internal/datacoord/server.go @@ -26,7 +26,7 @@ import ( "syscall" "time" - "github.com/blang/semver/v4" + semver "github.com/blang/semver/v4" "github.com/cockroachdb/errors" "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 233fc45cb3cec..0f8ae9b882292 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -1905,24 +1905,6 @@ func TestGetChannelSeekPosition(t *testing.T) { } } -func TestDescribeCollection(t *testing.T) { - t.Run("TestNotExistCollections", func(t *testing.T) { - svr := newTestServer(t, nil) - defer closeTestServer(t, svr) - has, err := svr.handler.(*ServerHandler).HasCollection(context.TODO(), -1) - assert.NoError(t, err) - assert.False(t, has) - }) - - t.Run("TestExistCollections", func(t *testing.T) { - svr := newTestServer(t, nil) - defer closeTestServer(t, svr) - has, err := svr.handler.(*ServerHandler).HasCollection(context.TODO(), 1314) - assert.NoError(t, err) - assert.True(t, has) - }) -} - func TestGetDataVChanPositions(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) @@ -2463,12 +2445,7 @@ func TestShouldDropChannel(t *testing.T) { Count: 1, }, nil) - var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoordClient, error) { - return myRoot, nil - } - - opt := WithRootCoordCreator(crt) - svr := newTestServer(t, nil, opt) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) schema := newTestSchema() svr.meta.AddCollection(&collectionInfo{ @@ -2492,52 +2469,14 @@ func TestShouldDropChannel(t *testing.T) { }, }) - t.Run("channel name not in kv, collection not exist", func(t *testing.T) { - // myRoot.code = commonpb.ErrorCode_CollectionNotExists - myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Status(merr.WrapErrCollectionNotFound(-1)), - CollectionID: -1, - }, nil).Once() - assert.True(t, svr.handler.CheckShouldDropChannel("ch99", -1)) - }) - - t.Run("channel name not in kv, collection exist", func(t *testing.T) { - myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Success(), - CollectionID: 0, - }, nil).Once() - assert.False(t, svr.handler.CheckShouldDropChannel("ch99", 0)) + t.Run("channel name not in kv ", func(t *testing.T) { + assert.False(t, svr.handler.CheckShouldDropChannel("ch99")) }) - t.Run("collection name in kv, collection exist", func(t *testing.T) { - myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Success(), - CollectionID: 0, - }, nil).Once() - assert.False(t, svr.handler.CheckShouldDropChannel("ch1", 0)) - }) - - t.Run("collection name in kv, collection not exist", func(t *testing.T) { - myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Status(merr.WrapErrCollectionNotFound(-1)), - CollectionID: -1, - }, nil).Once() - assert.True(t, svr.handler.CheckShouldDropChannel("ch1", -1)) - }) - - t.Run("channel in remove flag, collection exist", func(t *testing.T) { + t.Run("channel in remove flag", func(t *testing.T) { err := svr.meta.catalog.MarkChannelDeleted(context.TODO(), "ch1") require.NoError(t, err) - myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). 
- Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Success(), - CollectionID: 0, - }, nil).Once() - assert.True(t, svr.handler.CheckShouldDropChannel("ch1", 0)) + assert.True(t, svr.handler.CheckShouldDropChannel("ch1")) }) } diff --git a/internal/datanode/flow_graph_time_tick_node.go b/internal/datanode/flow_graph_time_tick_node.go index e81a671f98484..daadd344d54a0 100644 --- a/internal/datanode/flow_graph_time_tick_node.go +++ b/internal/datanode/flow_graph_time_tick_node.go @@ -52,6 +52,8 @@ type ttNode struct { updateCPLock sync.Mutex notifyChannel chan checkPoint closeChannel chan struct{} + closeOnce sync.Once + closeWg sync.WaitGroup } type checkPoint struct { @@ -76,13 +78,19 @@ func (ttn *ttNode) IsValidInMsg(in []Msg) bool { return true } +func (ttn *ttNode) Close() { + ttn.closeOnce.Do(func() { + close(ttn.closeChannel) + ttn.closeWg.Wait() + }) +} + // Operate handles input messages, implementing flowgraph.Node func (ttn *ttNode) Operate(in []Msg) []Msg { fgMsg := in[0].(*flowGraphMsg) curTs, _ := tsoutil.ParseTS(fgMsg.timeRange.timestampMax) if fgMsg.IsCloseMsg() { if len(fgMsg.endPositions) > 0 { - close(ttn.closeChannel) channelPos := ttn.channel.getChannelCheckpoint(fgMsg.endPositions[0]) log.Info("flowgraph is closing, force update channel CP", zap.Time("cpTs", tsoutil.PhysicalTime(channelPos.GetTimestamp())), @@ -151,13 +159,17 @@ func newTTNode(config *nodeConfig, broker broker.Broker) (*ttNode, error) { broker: broker, notifyChannel: make(chan checkPoint, 1), closeChannel: make(chan struct{}), + closeWg: sync.WaitGroup{}, } // check point updater + tt.closeWg.Add(1) go func() { + defer tt.closeWg.Done() for { select { case <-tt.closeChannel: + log.Info("ttNode updater exited", zap.String("channel", tt.vChannelName)) return case cp := <-tt.notifyChannel: tt.updateChannelCP(cp.pos, cp.curTs) diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 0d798d46b6109..8d2e29973ceb0 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -42,6 +42,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 3d3dba2537d70..0a8b29c727d2c 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index 4b969a5dfb726..10b3b8ac02a09 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" 
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index a1545fcdb707f..0b0df80abda44 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -61,6 +61,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 4281358f12312..80716c1bd0970 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 0233b3a98b1ab..a94c68d221943 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -40,6 +40,7 @@ import ( qn "github.com/milvus-io/milvus/internal/querynodev2" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index e7be73f655c70..bb994ad4cac8c 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + _ "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util" diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 75485c1ffde5a..56bc2589d5cde 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -262,17 +262,21 @@ func (dt *deleteTask) PostExecute(ctx context.Context) error { func (dt *deleteTask) getStreamingQueryAndDelteFunc(stream msgstream.MsgStream, plan *planpb.PlanNode) executeFunc { return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { - // outputField := translateOutputFields(, dt.schema, true) - partationIDs := []int64{} if dt.partitionID != common.InvalidFieldID { partationIDs = append(partationIDs, dt.partitionID) } + log := log.Ctx(ctx).With( + zap.Int64("collectionID", dt.collectionID), + zap.Int64s("partationIDs", partationIDs), + zap.Strings("channels", channelIDs), + zap.Int64("nodeID", nodeID)) // set plan _, outputFieldIDs := translatePkOutputFields(dt.schema) outputFieldIDs = append(outputFieldIDs, common.TimeStampField) plan.OutputFieldIds = outputFieldIDs + 
log.Debug("start query for delete") serializedPlan, err := proto.Marshal(plan) if err != nil { @@ -300,9 +304,10 @@ func (dt *deleteTask) getStreamingQueryAndDelteFunc(stream msgstream.MsgStream, Scope: querypb.DataScope_All, } + rc := timerecord.NewTimeRecorder("QueryStreamDelete") client, err := qn.QueryStream(ctx, queryReq) if err != nil { - log.Warn("query for delete return error", zap.Error(err)) + log.Warn("query stream for delete create failed", zap.Error(err)) return err } @@ -310,6 +315,7 @@ func (dt *deleteTask) getStreamingQueryAndDelteFunc(stream msgstream.MsgStream, result, err := client.Recv() if err != nil { if err == io.EOF { + log.Debug("query stream for delete finished", zap.Int64("msgID", dt.msgID), zap.Duration("duration", rc.ElapseSpan())) return nil } return err @@ -317,11 +323,13 @@ func (dt *deleteTask) getStreamingQueryAndDelteFunc(stream msgstream.MsgStream, err = merr.Error(result.GetStatus()) if err != nil { + log.Warn("query stream for delete get error status", zap.Int64("msgID", dt.msgID), zap.Error(err)) return err } err = dt.produce(ctx, stream, result.GetIds()) if err != nil { + log.Warn("query stream for delete produce result failed", zap.Int64("msgID", dt.msgID), zap.Error(err)) return err } } @@ -350,7 +358,10 @@ func (dt *deleteTask) simpleDelete(ctx context.Context, termExp *planpb.Expr_Ter log.Info("Failed to get primary keys from expr", zap.Error(err)) return err } - log.Debug("get primary keys from expr", zap.Int64("len of primary keys", numRow)) + log.Debug("get primary keys from expr", + zap.Int64("len of primary keys", numRow), + zap.Int64("collectionID", dt.collectionID), + zap.Int64("partationID", dt.partitionID)) err = dt.produce(ctx, stream, primaryKeys) if err != nil { return err diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 171f0fe21b3b4..31b59fd500192 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -97,8 +97,8 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) ([]Segment segments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nid) // Only balance segments in targets segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && - b.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil + return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil && + b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil }) if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 64bbd471620d2..98f7f1405c71e 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -122,7 +122,7 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss segments := b.dist.SegmentDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nid) // Only balance segments in targets segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool { - return b.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) 
!= nil + return b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil }) if isStopping, err := b.nodeManager.IsStoppingNode(nid); err != nil { diff --git a/internal/querycoordv2/checkers/index_checker.go b/internal/querycoordv2/checkers/index_checker.go index 48471f928b2e4..943582332bc15 100644 --- a/internal/querycoordv2/checkers/index_checker.go +++ b/internal/querycoordv2/checkers/index_checker.go @@ -82,7 +82,7 @@ func (c *IndexChecker) checkReplica(ctx context.Context, collection *meta.Collec ) var tasks []task.Task - segments := c.getHistoricalSegmentsDist(replica) + segments := c.getSealedSegmentsDist(replica) idSegments := make(map[int64]*meta.Segment) targets := make(map[int64][]int64) // segmentID => FieldID @@ -133,7 +133,7 @@ func (c *IndexChecker) checkSegment(ctx context.Context, segment *meta.Segment, return result } -func (c *IndexChecker) getHistoricalSegmentsDist(replica *meta.Replica) []*meta.Segment { +func (c *IndexChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment { var ret []*meta.Segment for _, node := range replica.GetNodes() { ret = append(ret, c.dist.SegmentDistManager.GetByCollectionAndNode(replica.CollectionID, node)...) diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 56fc61e9aa692..c971886e14335 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -111,7 +111,7 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica } // compare with targets to find the lack and redundancy of segments - lacks, redundancies := c.getHistoricalSegmentDiff(replica.GetCollectionID(), replica.GetID()) + lacks, redundancies := c.getSealedSegmentDiff(replica.GetCollectionID(), replica.GetID()) tasks := c.createSegmentLoadTasks(ctx, lacks, replica) task.SetReason("lacks of segment", tasks...) ret = append(ret, tasks...) @@ -122,14 +122,14 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica ret = append(ret, tasks...) // compare inner dists to find repeated loaded segments - redundancies = c.findRepeatedHistoricalSegments(replica.GetID()) + redundancies = c.findRepeatedSealedSegments(replica.GetID()) redundancies = c.filterExistedOnLeader(replica, redundancies) tasks = c.createSegmentReduceTasks(ctx, redundancies, replica.GetID(), querypb.DataScope_Historical) task.SetReason("redundancies of segment", tasks...) ret = append(ret, tasks...) // compare with target to find the lack and redundancy of segments - _, redundancies = c.getStreamingSegmentDiff(replica.GetCollectionID(), replica.GetID()) + _, redundancies = c.getGrowingSegmentDiff(replica.GetCollectionID(), replica.GetID()) tasks = c.createSegmentReduceTasks(ctx, redundancies, replica.GetID(), querypb.DataScope_Streaming) task.SetReason("streaming segment not exists in target", tasks...) ret = append(ret, tasks...) 
@@ -137,8 +137,8 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica return ret } -// GetStreamingSegmentDiff get streaming segment diff between leader view and target -func (c *SegmentChecker) getStreamingSegmentDiff(collectionID int64, +// GetGrowingSegmentDiff get streaming segment diff between leader view and target +func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64, replicaID int64, ) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { replica := c.meta.Get(replicaID) @@ -171,8 +171,8 @@ func (c *SegmentChecker) getStreamingSegmentDiff(collectionID int64, continue } - nextTargetSegmentIDs := c.targetMgr.GetStreamingSegmentsByCollection(collectionID, meta.NextTarget) - currentTargetSegmentIDs := c.targetMgr.GetStreamingSegmentsByCollection(collectionID, meta.CurrentTarget) + nextTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(collectionID, meta.NextTarget) + currentTargetSegmentIDs := c.targetMgr.GetGrowingSegmentsByCollection(collectionID, meta.CurrentTarget) currentTargetChannelMap := c.targetMgr.GetDmChannelsByCollection(collectionID, meta.CurrentTarget) // get segment which exist on leader view, but not on current target and next target @@ -196,8 +196,8 @@ func (c *SegmentChecker) getStreamingSegmentDiff(collectionID int64, return } -// GetHistoricalSegmentDiff get historical segment diff between target and dist -func (c *SegmentChecker) getHistoricalSegmentDiff( +// GetSealedSegmentDiff get historical segment diff between target and dist +func (c *SegmentChecker) getSealedSegmentDiff( collectionID int64, replicaID int64, ) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { @@ -206,7 +206,7 @@ func (c *SegmentChecker) getHistoricalSegmentDiff( log.Info("replica does not exist, skip it") return } - dist := c.getHistoricalSegmentsDist(replica) + dist := c.getSealedSegmentsDist(replica) sort.Slice(dist, func(i, j int) bool { return dist[i].Version < dist[j].Version }) @@ -215,8 +215,8 @@ func (c *SegmentChecker) getHistoricalSegmentDiff( distMap[s.GetID()] = s.Node } - nextTargetMap := c.targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.NextTarget) - currentTargetMap := c.targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.CurrentTarget) + nextTargetMap := c.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.NextTarget) + currentTargetMap := c.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.CurrentTarget) // Segment which exist on next target, but not on dist for segmentID, segment := range nextTargetMap { @@ -256,7 +256,7 @@ func (c *SegmentChecker) getHistoricalSegmentDiff( return } -func (c *SegmentChecker) getHistoricalSegmentsDist(replica *meta.Replica) []*meta.Segment { +func (c *SegmentChecker) getSealedSegmentsDist(replica *meta.Replica) []*meta.Segment { ret := make([]*meta.Segment, 0) for _, node := range replica.GetNodes() { ret = append(ret, c.dist.SegmentDistManager.GetByCollectionAndNode(replica.CollectionID, node)...) 
@@ -264,14 +264,14 @@ func (c *SegmentChecker) getHistoricalSegmentsDist(replica *meta.Replica) []*met return ret } -func (c *SegmentChecker) findRepeatedHistoricalSegments(replicaID int64) []*meta.Segment { +func (c *SegmentChecker) findRepeatedSealedSegments(replicaID int64) []*meta.Segment { segments := make([]*meta.Segment, 0) replica := c.meta.Get(replicaID) if replica == nil { log.Info("replica does not exist, skip it") return segments } - dist := c.getHistoricalSegmentsDist(replica) + dist := c.getSealedSegmentsDist(replica) versions := make(map[int64]*meta.Segment) for _, s := range dist { maxVer, ok := versions[s.GetID()] diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index 60cf6487f3202..66fc0139a8ccc 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -114,10 +114,10 @@ func (dh *distHandler) updateSegmentsDistribution(resp *querypb.GetDataDistribut updates := make([]*meta.Segment, 0, len(resp.GetSegments())) for _, s := range resp.GetSegments() { // for collection which is already loaded - segmentInfo := dh.target.GetHistoricalSegment(s.GetCollection(), s.GetID(), meta.CurrentTarget) + segmentInfo := dh.target.GetSealedSegment(s.GetCollection(), s.GetID(), meta.CurrentTarget) if segmentInfo == nil { // for collection which is loading - segmentInfo = dh.target.GetHistoricalSegment(s.GetCollection(), s.GetID(), meta.NextTarget) + segmentInfo = dh.target.GetSealedSegment(s.GetCollection(), s.GetID(), meta.NextTarget) } var segment *meta.Segment if segmentInfo == nil { diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index 92e90373d5a93..56bbec48b0931 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -62,7 +62,7 @@ func (s *Server) checkAnyReplicaAvailable(collectionID int64) bool { func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentInfo { segments := s.dist.SegmentDistManager.GetByCollection(collection) - currentTargetSegmentsMap := s.targetMgr.GetHistoricalSegmentsByCollection(collection, meta.CurrentTarget) + currentTargetSegmentsMap := s.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget) infos := make(map[int64]*querypb.SegmentInfo) for _, segment := range segments { if _, existCurrentTarget := currentTargetSegmentsMap[segment.GetID()]; !existCurrentTarget { @@ -109,7 +109,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe // Only balance segments in targets segments := s.dist.SegmentDistManager.GetByCollectionAndNode(req.GetCollectionID(), srcNode) segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool { - return s.targetMgr.GetHistoricalSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil + return s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil }) allSegments := make(map[int64]*meta.Segment) for _, segment := range segments { diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index 145a8144441e2..8f230cb8616d1 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -1484,7 +1484,7 @@ func (suite *JobSuite) assertCollectionLoaded(collection int64) { } for _, segments := range suite.segments[collection] { for _, segment := range segments { - suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) + 
suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) } } } @@ -1501,7 +1501,7 @@ func (suite *JobSuite) assertPartitionLoaded(collection int64, partitionIDs ...i } suite.NotNil(suite.meta.GetPartition(partitionID)) for _, segment := range segments { - suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) } } } @@ -1514,7 +1514,7 @@ func (suite *JobSuite) assertCollectionReleased(collection int64) { } for _, partitions := range suite.segments[collection] { for _, segment := range partitions { - suite.Nil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) } } } @@ -1524,7 +1524,7 @@ func (suite *JobSuite) assertPartitionReleased(collection int64, partitionIDs .. suite.Nil(suite.meta.GetPartition(partition)) segments := suite.segments[collection][partition] for _, segment := range segments { - suite.Nil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) } } } diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go index c45c23a2f3f52..e6772ede82c0f 100644 --- a/internal/querycoordv2/meta/target_manager.go +++ b/internal/querycoordv2/meta/target_manager.go @@ -324,7 +324,7 @@ func (mgr *TargetManager) getTarget(scope TargetScope) *target { return mgr.next } -func (mgr *TargetManager) GetStreamingSegmentsByCollection(collectionID int64, +func (mgr *TargetManager) GetGrowingSegmentsByCollection(collectionID int64, scope TargetScope, ) typeutil.UniqueSet { mgr.rwMutex.RLock() @@ -345,7 +345,7 @@ func (mgr *TargetManager) GetStreamingSegmentsByCollection(collectionID int64, return segments } -func (mgr *TargetManager) GetStreamingSegmentsByChannel(collectionID int64, +func (mgr *TargetManager) GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope TargetScope, ) typeutil.UniqueSet { @@ -369,7 +369,7 @@ func (mgr *TargetManager) GetStreamingSegmentsByChannel(collectionID int64, return segments } -func (mgr *TargetManager) GetHistoricalSegmentsByCollection(collectionID int64, +func (mgr *TargetManager) GetSealedSegmentsByCollection(collectionID int64, scope TargetScope, ) map[int64]*datapb.SegmentInfo { mgr.rwMutex.RLock() @@ -384,7 +384,7 @@ func (mgr *TargetManager) GetHistoricalSegmentsByCollection(collectionID int64, return collectionTarget.GetAllSegments() } -func (mgr *TargetManager) GetHistoricalSegmentsByChannel(collectionID int64, +func (mgr *TargetManager) GetSealedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope, ) map[int64]*datapb.SegmentInfo { @@ -430,7 +430,7 @@ func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64, return channel.GetDroppedSegmentIds() } -func (mgr *TargetManager) GetHistoricalSegmentsByPartition(collectionID int64, +func (mgr *TargetManager) GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope TargetScope, ) map[int64]*datapb.SegmentInfo { mgr.rwMutex.RLock() @@ -479,7 +479,7 @@ func (mgr *TargetManager) GetDmChannel(collectionID int64, channel string, scope return collectionTarget.GetAllDmChannels()[channel] } -func (mgr *TargetManager) GetHistoricalSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo { +func (mgr 
*TargetManager) GetSealedSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() targetMap := mgr.getTarget(scope) diff --git a/internal/querycoordv2/meta/target_manager_test.go b/internal/querycoordv2/meta/target_manager_test.go index 41443d3cb64ce..cdc1a8c1479a0 100644 --- a/internal/querycoordv2/meta/target_manager_test.go +++ b/internal/querycoordv2/meta/target_manager_test.go @@ -162,24 +162,24 @@ func (suite *TargetManagerSuite) TearDownSuite() { func (suite *TargetManagerSuite) TestUpdateCurrentTarget() { collectionID := int64(1000) suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), - suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.mgr.UpdateCollectionCurrentTarget(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), - suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) } func (suite *TargetManagerSuite) TestUpdateNextTarget() { collectionID := int64(1003) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.meta.PutCollection(&Collection{ @@ -232,9 +232,9 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) suite.mgr.UpdateCollectionNextTarget(collectionID) - suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{11, 12}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, 
suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.broker.ExpectedCalls = nil @@ -259,42 +259,42 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() { func (suite *TargetManagerSuite) TestRemovePartition() { collectionID := int64(1000) - suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.mgr.RemovePartition(collectionID, 100) - suite.assertSegments([]int64{3, 4}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{3, 4}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) } func (suite *TargetManagerSuite) TestRemoveCollection() { collectionID := int64(1000) - suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.mgr.RemoveCollection(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) collectionID = int64(1001) suite.mgr.UpdateCollectionCurrentTarget(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, 
NextTarget)) - suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments(suite.getAllSegment(collectionID, suite.partitions[collectionID]), suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels(suite.channels[collectionID], suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.mgr.RemoveCollection(collectionID) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) } @@ -360,9 +360,9 @@ func (suite *TargetManagerSuite) TestGetCollectionTargetVersion() { func (suite *TargetManagerSuite) TestGetSegmentByChannel() { collectionID := int64(1003) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) - suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.meta.PutCollection(&Collection{ @@ -407,11 +407,11 @@ func (suite *TargetManagerSuite) TestGetSegmentByChannel() { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil) suite.mgr.UpdateCollectionNextTarget(collectionID) - suite.Len(suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget), 2) - suite.Len(suite.mgr.GetHistoricalSegmentsByChannel(collectionID, "channel-1", NextTarget), 1) - suite.Len(suite.mgr.GetHistoricalSegmentsByChannel(collectionID, "channel-2", NextTarget), 1) - suite.Len(suite.mgr.GetStreamingSegmentsByChannel(collectionID, "channel-1", NextTarget), 4) - suite.Len(suite.mgr.GetStreamingSegmentsByChannel(collectionID, "channel-2", NextTarget), 1) + suite.Len(suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget), 2) + suite.Len(suite.mgr.GetSealedSegmentsByChannel(collectionID, "channel-1", NextTarget), 1) + suite.Len(suite.mgr.GetSealedSegmentsByChannel(collectionID, "channel-2", NextTarget), 1) + suite.Len(suite.mgr.GetGrowingSegmentsByChannel(collectionID, "channel-1", NextTarget), 4) + suite.Len(suite.mgr.GetGrowingSegmentsByChannel(collectionID, "channel-2", NextTarget), 1) suite.Len(suite.mgr.GetDroppedSegmentsByChannel(collectionID, "channel-1", NextTarget), 3) } diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index 685074749690d..a1a8825a056ec 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -182,7 +182,7 @@ func (ob *CollectionObserver) 
observePartitionLoadStatus(ctx context.Context, pa
 		zap.Int64("partitionID", partition.GetPartitionID()),
 	)
-	segmentTargets := ob.targetMgr.GetHistoricalSegmentsByPartition(partition.GetCollectionID(), partition.GetPartitionID(), meta.NextTarget)
+	segmentTargets := ob.targetMgr.GetSealedSegmentsByPartition(partition.GetCollectionID(), partition.GetPartitionID(), meta.NextTarget)
 	channelTargets := ob.targetMgr.GetDmChannelsByCollection(partition.GetCollectionID(), meta.NextTarget)
 	targetNum := len(segmentTargets) + len(channelTargets)
@@ -226,7 +226,15 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa
 	}
 	ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount
-	if loadPercentage == 100 && ob.targetObserver.Check(ctx, partition.GetCollectionID()) && ob.leaderObserver.CheckTargetVersion(ctx, partition.GetCollectionID()) {
+	if loadPercentage == 100 {
+		if !ob.targetObserver.Check(ctx, partition.GetCollectionID()) {
+			log.Warn("failed manual check of current target, skip updating load status")
+			return
+		}
+		if !ob.leaderObserver.CheckTargetVersion(ctx, partition.GetCollectionID()) {
+			log.Warn("failed manual check of leader target version, skip updating load status")
+			return
+		}
 		delete(ob.partitionLoadedCount, partition.GetPartitionID())
 	}
 	collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(partition.PartitionID, loadPercentage)
diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go
index 9a7b14744d29a..e0c98a592be8e 100644
--- a/internal/querycoordv2/observers/collection_observer_test.go
+++ b/internal/querycoordv2/observers/collection_observer_test.go
@@ -328,7 +328,7 @@ func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool
 	status := suite.meta.CalculateLoadStatus(collection)
 	replicas := suite.meta.ReplicaManager.GetByCollection(collection)
 	channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
-	segments := suite.targetMgr.GetHistoricalSegmentsByCollection(collection, meta.CurrentTarget)
+	segments := suite.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
 	return exist &&
 		percentage == 100 &&
@@ -347,7 +347,7 @@ func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool
 	percentage := suite.meta.GetPartitionLoadPercentage(partitionID)
 	status := partition.GetStatus()
 	channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
-	segments := suite.targetMgr.GetHistoricalSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
+	segments := suite.targetMgr.GetSealedSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
 	expectedSegments := lo.Filter(suite.segments[collection], func(seg *datapb.SegmentInfo, _ int) bool {
 		return seg.PartitionID == partitionID
 	})
@@ -361,7 +361,7 @@ func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool
 	exist := suite.meta.Exist(collection)
 	replicas := suite.meta.ReplicaManager.GetByCollection(collection)
 	channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
-	segments := suite.targetMgr.GetHistoricalSegmentsByCollection(collection, meta.CurrentTarget)
+	segments := suite.targetMgr.GetSealedSegmentsByCollection(collection, meta.CurrentTarget)
 	return !(exist ||
 		len(replicas) > 0 ||
 		len(channels) > 0 ||
@@ -370,7 +370,7 @@ func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool
 func (suite 
*CollectionObserverSuite) isPartitionTimeout(collection int64, partitionID int64) bool { partition := suite.meta.GetPartition(partitionID) - segments := suite.targetMgr.GetHistoricalSegmentsByPartition(collection, partitionID, meta.CurrentTarget) + segments := suite.targetMgr.GetSealedSegmentsByPartition(collection, partitionID, meta.CurrentTarget) return partition == nil && len(segments) == 0 } diff --git a/internal/querycoordv2/observers/leader_observer.go b/internal/querycoordv2/observers/leader_observer.go index 0e01477fbdfa9..5778a29766f6f 100644 --- a/internal/querycoordv2/observers/leader_observer.go +++ b/internal/querycoordv2/observers/leader_observer.go @@ -41,14 +41,15 @@ const ( // LeaderObserver is to sync the distribution with leader type LeaderObserver struct { - wg sync.WaitGroup - cancel context.CancelFunc - dist *meta.DistributionManager - meta *meta.Meta - target *meta.TargetManager - broker meta.Broker - cluster session.Cluster - manualCheck chan checkRequest + wg sync.WaitGroup + cancel context.CancelFunc + dist *meta.DistributionManager + meta *meta.Meta + target *meta.TargetManager + broker meta.Broker + cluster session.Cluster + + dispatcher *taskDispatcher[int64] stopOnce sync.Once } @@ -57,27 +58,12 @@ func (o *LeaderObserver) Start() { ctx, cancel := context.WithCancel(context.Background()) o.cancel = cancel + o.dispatcher.Start() + o.wg.Add(1) go func() { defer o.wg.Done() - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - log.Info("stop leader observer") - return - - case req := <-o.manualCheck: - log.Info("triggering manual check") - ret := o.observeCollection(ctx, req.CollectionID) - req.Notifier <- ret - log.Info("manual check done", zap.Bool("result", ret)) - - case <-ticker.C: - o.observe(ctx) - } - } + o.schedule(ctx) }() } @@ -87,9 +73,26 @@ func (o *LeaderObserver) Stop() { o.cancel() } o.wg.Wait() + + o.dispatcher.Stop() }) } +func (o *LeaderObserver) schedule(ctx context.Context) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + log.Info("stop leader observer") + return + + case <-ticker.C: + o.observe(ctx) + } + } +} + func (o *LeaderObserver) observe(ctx context.Context) { o.observeSegmentsDist(ctx) } @@ -105,14 +108,13 @@ func (o *LeaderObserver) observeSegmentsDist(ctx context.Context) { collectionIDs := o.meta.CollectionManager.GetAll() for _, cid := range collectionIDs { if o.readyToObserve(cid) { - o.observeCollection(ctx, cid) + o.dispatcher.AddTask(cid) } } } -func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64) bool { +func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64) { replicas := o.meta.ReplicaManager.GetByCollection(collection) - result := true for _, replica := range replicas { leaders := o.dist.ChannelDistManager.GetShardLeadersByReplica(replica) for ch, leaderID := range leaders { @@ -128,29 +130,42 @@ func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64 if updateVersionAction != nil { actions = append(actions, updateVersionAction) } - success := o.sync(ctx, replica.GetID(), leaderView, actions) - if !success { - result = false - } + o.sync(ctx, replica.GetID(), leaderView, actions) } } - return result } -func (ob *LeaderObserver) CheckTargetVersion(ctx context.Context, collectionID int64) bool { - notifier := make(chan bool) - select { - case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}: - case <-ctx.Done(): 
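+// CheckTargetVersion reports whether every shard leader of the collection is
+// already serving the current target version. Unlike the removed manualCheck
+// channel, this call never blocks: when a leader is stale it only enqueues an
+// async observe task via the dispatcher and returns false, so the caller can
+// retry on a later round.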
+func (o *LeaderObserver) CheckTargetVersion(ctx context.Context, collectionID int64) bool {
+	// not ready to observe yet, skip adding the task
+	if !o.readyToObserve(collectionID) {
 		return false
 	}
-	select {
-	case result := <-notifier:
-		return result
-	case <-ctx.Done():
-		return false
+	result := o.checkCollectionLeaderVersionIsCurrent(ctx, collectionID)
+	if !result {
+		o.dispatcher.AddTask(collectionID)
 	}
+
+	return result
+}
+
+func (o *LeaderObserver) checkCollectionLeaderVersionIsCurrent(ctx context.Context, collectionID int64) bool {
+	replicas := o.meta.ReplicaManager.GetByCollection(collectionID)
+	for _, replica := range replicas {
+		leaders := o.dist.ChannelDistManager.GetShardLeadersByReplica(replica)
+		for ch, leaderID := range leaders {
+			leaderView := o.dist.LeaderViewManager.GetLeaderShardView(leaderID, ch)
+			if leaderView == nil {
+				return false
+			}
+
+			action := o.checkNeedUpdateTargetVersion(ctx, leaderView)
+			if action != nil {
+				return false
+			}
+		}
+	}
+	return true
 }
 
 func (o *LeaderObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView) *querypb.SyncAction {
@@ -169,8 +184,8 @@ func (o *LeaderObserver) checkNeedUpdateTargetVersion(ctx context.Context, leade
 		zap.Int64("newVersion", targetVersion),
 	)
 
-	sealedSegments := o.target.GetHistoricalSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
-	growingSegments := o.target.GetStreamingSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
+	sealedSegments := o.target.GetSealedSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
+	growingSegments := o.target.GetGrowingSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
 	droppedSegments := o.target.GetDroppedSegmentsByChannel(leaderView.CollectionID, leaderView.Channel, meta.CurrentTarget)
 
 	return &querypb.SyncAction{
@@ -187,9 +202,9 @@ func (o *LeaderObserver) findNeedLoadedSegments(leaderView *meta.LeaderView, dis
 	dists = utils.FindMaxVersionSegments(dists)
 	for _, s := range dists {
 		version, ok := leaderView.Segments[s.GetID()]
-		currentTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.CurrentTarget)
+		currentTarget := o.target.GetSealedSegment(s.CollectionID, s.GetID(), meta.CurrentTarget)
 		existInCurrentTarget := currentTarget != nil
-		existInNextTarget := o.target.GetHistoricalSegment(s.CollectionID, s.GetID(), meta.NextTarget) != nil
+		existInNextTarget := o.target.GetSealedSegment(s.CollectionID, s.GetID(), meta.NextTarget) != nil
 
 		if !existInCurrentTarget && !existInNextTarget {
 			continue
@@ -231,8 +246,8 @@ func (o *LeaderObserver) findNeedRemovedSegments(leaderView *meta.LeaderView, di
 	}
 	for sid, s := range leaderView.Segments {
 		_, ok := distMap[sid]
-		existInCurrentTarget := o.target.GetHistoricalSegment(leaderView.CollectionID, sid, meta.CurrentTarget) != nil
-		existInNextTarget := o.target.GetHistoricalSegment(leaderView.CollectionID, sid, meta.NextTarget) != nil
+		existInCurrentTarget := o.target.GetSealedSegment(leaderView.CollectionID, sid, meta.CurrentTarget) != nil
+		existInNextTarget := o.target.GetSealedSegment(leaderView.CollectionID, sid, meta.NextTarget) != nil
 		if ok || existInCurrentTarget || existInNextTarget {
 			continue
 		}
@@ -312,12 +327,16 @@ func NewLeaderObserver(
 	broker meta.Broker,
 	cluster session.Cluster,
 ) *LeaderObserver {
-	return &LeaderObserver{
-		dist:        dist,
-		meta:        meta,
-		target:      targetMgr,
-		broker:      broker,
-		cluster:     cluster,
-		manualCheck: make(chan checkRequest, 
10), + ob := &LeaderObserver{ + dist: dist, + meta: meta, + target: targetMgr, + broker: broker, + cluster: cluster, } + + dispatcher := newTaskDispatcher[int64](ob.observeCollection) + ob.dispatcher = dispatcher + + return ob } diff --git a/internal/querycoordv2/observers/leader_observer_test.go b/internal/querycoordv2/observers/leader_observer_test.go index c2f1f771d2ece..3a2738f4ff1aa 100644 --- a/internal/querycoordv2/observers/leader_observer_test.go +++ b/internal/querycoordv2/observers/leader_observer_test.go @@ -591,44 +591,6 @@ func (suite *LeaderObserverTestSuite) TestSyncTargetVersion() { suite.Len(action.SealedInTarget, 1) } -func (suite *LeaderObserverTestSuite) TestCheckTargetVersion() { - collectionID := int64(1001) - observer := suite.observer - - suite.Run("check_channel_blocked", func() { - oldCh := observer.manualCheck - defer func() { - observer.manualCheck = oldCh - }() - - // zero-length channel - observer.manualCheck = make(chan checkRequest) - - ctx, cancel := context.WithCancel(context.Background()) - // cancel context, make test return fast - cancel() - - result := observer.CheckTargetVersion(ctx, collectionID) - suite.False(result) - }) - - suite.Run("check_return_ctx_timeout", func() { - oldCh := observer.manualCheck - defer func() { - observer.manualCheck = oldCh - }() - - // make channel length = 1, task received - observer.manualCheck = make(chan checkRequest, 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - - result := observer.CheckTargetVersion(ctx, collectionID) - suite.False(result) - }) -} - func TestLeaderObserverSuite(t *testing.T) { suite.Run(t, new(LeaderObserverTestSuite)) } diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go index d1f210fed083e..cc609163d2652 100644 --- a/internal/querycoordv2/observers/target_observer.go +++ b/internal/querycoordv2/observers/target_observer.go @@ -51,36 +51,48 @@ type TargetObserver struct { distMgr *meta.DistributionManager broker meta.Broker - initChan chan initRequest - manualCheck chan checkRequest - nextTargetLastUpdate map[int64]time.Time + initChan chan initRequest + manualCheck chan checkRequest + // nextTargetLastUpdate map[int64]time.Time + nextTargetLastUpdate *typeutil.ConcurrentMap[int64, time.Time] updateChan chan targetUpdateRequest mut sync.Mutex // Guard readyNotifiers readyNotifiers map[int64][]chan struct{} // CollectionID -> Notifiers + dispatcher *taskDispatcher[int64] + stopOnce sync.Once } func NewTargetObserver(meta *meta.Meta, targetMgr *meta.TargetManager, distMgr *meta.DistributionManager, broker meta.Broker) *TargetObserver { - return &TargetObserver{ + result := &TargetObserver{ meta: meta, targetMgr: targetMgr, distMgr: distMgr, broker: broker, manualCheck: make(chan checkRequest, 10), - nextTargetLastUpdate: make(map[int64]time.Time), + nextTargetLastUpdate: typeutil.NewConcurrentMap[int64, time.Time](), updateChan: make(chan targetUpdateRequest), readyNotifiers: make(map[int64][]chan struct{}), initChan: make(chan initRequest), } + + dispatcher := newTaskDispatcher(result.check) + result.dispatcher = dispatcher + return result } func (ob *TargetObserver) Start() { ctx, cancel := context.WithCancel(context.Background()) ob.cancel = cancel + ob.dispatcher.Start() + ob.wg.Add(1) - go ob.schedule(ctx) + go func() { + defer ob.wg.Done() + ob.schedule(ctx) + }() // after target observer start, update target for all collection ob.initChan <- initRequest{} @@ -92,11 
+104,12 @@ func (ob *TargetObserver) Stop() {
 		ob.cancel()
 	}
 	ob.wg.Wait()
+
+	ob.dispatcher.Stop()
 	})
 }
 
 func (ob *TargetObserver) schedule(ctx context.Context) {
-	defer ob.wg.Done()
 	log.Info("Start update next target loop")
 
 	ticker := time.NewTicker(params.Params.QueryCoordCfg.UpdateNextTargetInterval.GetAsDuration(time.Second))
@@ -111,16 +124,11 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
 			for _, collectionID := range ob.meta.GetAll() {
 				ob.init(collectionID)
 			}
+			log.Info("target observer init done")
 
 		case <-ticker.C:
 			ob.clean()
-			for _, collectionID := range ob.meta.GetAll() {
-				ob.check(collectionID)
-			}
-
-		case req := <-ob.manualCheck:
-			ob.check(req.CollectionID)
-			req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID)
+			ob.dispatcher.AddTask(ob.meta.GetAll()...)
 
 		case req := <-ob.updateChan:
 			err := ob.updateNextTarget(req.CollectionID)
@@ -137,26 +145,17 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
 	}
 }
 
-// Check checks whether the next target is ready,
-// and updates the current target if it is,
-// returns true if current target is not nil
+// Check reports whether the provided collection has a current target.
+// If not, it submits an async check task to the dispatcher.
 func (ob *TargetObserver) Check(ctx context.Context, collectionID int64) bool {
-	notifier := make(chan bool)
-	select {
-	case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}:
-	case <-ctx.Done():
-		return false
-	}
-
-	select {
-	case result := <-notifier:
-		return result
-	case <-ctx.Done():
-		return false
+	result := ob.targetMgr.IsCurrentTargetExist(collectionID)
+	if !result {
+		ob.dispatcher.AddTask(collectionID)
 	}
+	return result
 }
 
-func (ob *TargetObserver) check(collectionID int64) {
+func (ob *TargetObserver) check(ctx context.Context, collectionID int64) {
 	if !ob.meta.Exist(collectionID) {
 		ob.ReleaseCollection(collectionID)
 		ob.targetMgr.RemoveCollection(collectionID)
@@ -215,11 +214,12 @@ func (ob *TargetObserver) ReleaseCollection(collectionID int64) {
 func (ob *TargetObserver) clean() {
 	collectionSet := typeutil.NewUniqueSet(ob.meta.GetAll()...)
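+	// nextTargetLastUpdate is now a ConcurrentMap; assuming it follows
+	// sync.Map semantics, entries can safely be removed while ranging below.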
 	// for collection which has been removed from target, try to clear nextTargetLastUpdate
-	for collection := range ob.nextTargetLastUpdate {
-		if !collectionSet.Contain(collection) {
-			delete(ob.nextTargetLastUpdate, collection)
+	ob.nextTargetLastUpdate.Range(func(collectionID int64, _ time.Time) bool {
+		if !collectionSet.Contain(collectionID) {
+			ob.nextTargetLastUpdate.Remove(collectionID)
 		}
-	}
+		return true
+	})
 
 	ob.mut.Lock()
 	defer ob.mut.Unlock()
@@ -238,7 +238,11 @@ func (ob *TargetObserver) shouldUpdateNextTarget(collectionID int64) bool {
 }
 
 func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool {
-	return time.Since(ob.nextTargetLastUpdate[collectionID]) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second)
+	lastUpdated, has := ob.nextTargetLastUpdate.Get(collectionID)
+	if !has {
+		return true
+	}
+	return time.Since(lastUpdated) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second)
 }
 
 func (ob *TargetObserver) updateNextTarget(collectionID int64) error {
@@ -256,7 +260,7 @@ func (ob *TargetObserver) updateNextTarget(collectionID int64) error {
 }
 
 func (ob *TargetObserver) updateNextTargetTimestamp(collectionID int64) {
-	ob.nextTargetLastUpdate[collectionID] = time.Now()
+	ob.nextTargetLastUpdate.Insert(collectionID, time.Now())
 }
 
 func (ob *TargetObserver) shouldUpdateCurrentTarget(collectionID int64) bool {
@@ -279,8 +283,8 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(collectionID int64) bool {
 	}
 
 	// and last check historical segment
-	historicalSegments := ob.targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.NextTarget)
-	for _, segment := range historicalSegments {
+	sealedSegments := ob.targetMgr.GetSealedSegmentsByCollection(collectionID, meta.NextTarget)
+	for _, segment := range sealedSegments {
 		group := utils.GroupNodesByReplica(ob.meta.ReplicaManager,
 			collectionID,
 			ob.distMgr.LeaderViewManager.GetSealedSegmentDist(segment.GetID()))
diff --git a/internal/querycoordv2/observers/target_observer_test.go b/internal/querycoordv2/observers/target_observer_test.go
index 2fee0098b0bb0..b5509f6ec9a7b 100644
--- a/internal/querycoordv2/observers/target_observer_test.go
+++ b/internal/querycoordv2/observers/target_observer_test.go
@@ -126,7 +126,7 @@ func (suite *TargetObserverSuite) SetupTest() {
 
 func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
 	suite.Eventually(func() bool {
-		return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 2 &&
+		return len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 2 &&
 			len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2
 	}, 5*time.Second, 1*time.Second)
 
@@ -168,7 +168,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
 		GetRecoveryInfoV2(mock.Anything, mock.Anything).
Return(suite.nextTargetChannels, suite.nextTargetSegments, nil) suite.Eventually(func() bool { - return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 && + return len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 && len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2 }, 7*time.Second, 1*time.Second) suite.broker.AssertExpectations(suite.T()) @@ -206,7 +206,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { default: } return isReady && - len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.CurrentTarget)) == 3 && + len(suite.targetMgr.GetSealedSegmentsByCollection(suite.collectionID, meta.CurrentTarget)) == 3 && len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.CurrentTarget)) == 2 }, 7*time.Second, 1*time.Second) } @@ -273,41 +273,10 @@ func (suite *TargetObserverCheckSuite) SetupTest() { suite.NoError(err) } -func (suite *TargetObserverCheckSuite) TestCheckCtxDone() { - observer := suite.observer - - suite.Run("check_channel_blocked", func() { - oldCh := observer.manualCheck - defer func() { - observer.manualCheck = oldCh - }() - - // zero-length channel - observer.manualCheck = make(chan checkRequest) - - ctx, cancel := context.WithCancel(context.Background()) - // cancel context, make test return fast - cancel() - - result := observer.Check(ctx, suite.collectionID) - suite.False(result) - }) - - suite.Run("check_return_ctx_timeout", func() { - oldCh := observer.manualCheck - defer func() { - observer.manualCheck = oldCh - }() - - // make channel length = 1, task received - observer.manualCheck = make(chan checkRequest, 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - - result := observer.Check(ctx, suite.collectionID) - suite.False(result) - }) +func (s *TargetObserverCheckSuite) TestCheck() { + r := s.observer.Check(context.Background(), s.collectionID) + s.False(r) + s.True(s.observer.dispatcher.tasks.Contain(s.collectionID)) } func TestTargetObserver(t *testing.T) { diff --git a/internal/querycoordv2/observers/task_dispatcher.go b/internal/querycoordv2/observers/task_dispatcher.go new file mode 100644 index 0000000000000..cfede74304cf3 --- /dev/null +++ b/internal/querycoordv2/observers/task_dispatcher.go @@ -0,0 +1,104 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
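+
+// task_dispatcher.go adds a generic, deduplicating async dispatcher that the
+// leader and target observers use in place of their former blocking
+// manualCheck channels.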
+
+package observers
+
+import (
+	"context"
+	"sync"
+
+	"github.com/milvus-io/milvus/pkg/util/conc"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+// taskDispatcher is a utility that deduplicates and dispatches tasks: a key
+// added via AddTask while still pending is scheduled only once.
+type taskDispatcher[K comparable] struct {
+	tasks      *typeutil.ConcurrentSet[K]
+	pool       *conc.Pool[any]
+	notifyCh   chan struct{}
+	taskRunner task[K]
+	wg         sync.WaitGroup
+	cancel     context.CancelFunc
+	stopOnce   sync.Once
+}
+
+type task[K comparable] func(context.Context, K)
+
+func newTaskDispatcher[K comparable](runner task[K]) *taskDispatcher[K] {
+	return &taskDispatcher[K]{
+		tasks:      typeutil.NewConcurrentSet[K](),
+		pool:       conc.NewPool[any](paramtable.Get().QueryCoordCfg.ObserverTaskParallel.GetAsInt()),
+		notifyCh:   make(chan struct{}, 1),
+		taskRunner: runner,
+	}
+}
+
+func (d *taskDispatcher[K]) Start() {
+	ctx, cancel := context.WithCancel(context.Background())
+	d.cancel = cancel
+
+	d.wg.Add(1)
+	go func() {
+		defer d.wg.Done()
+		d.schedule(ctx)
+	}()
+}
+
+func (d *taskDispatcher[K]) Stop() {
+	d.stopOnce.Do(func() {
+		if d.cancel != nil {
+			d.cancel()
+		}
+		d.wg.Wait()
+	})
+}
+
+func (d *taskDispatcher[K]) AddTask(keys ...K) {
+	var added bool
+	for _, key := range keys {
+		added = d.tasks.Insert(key) || added
+	}
+	if added {
+		d.notify()
+	}
+}
+
+func (d *taskDispatcher[K]) notify() {
+	select {
+	case d.notifyCh <- struct{}{}:
+	default:
+	}
+}
+
+func (d *taskDispatcher[K]) schedule(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-d.notifyCh:
+			d.tasks.Range(func(k K) bool {
+				d.tasks.Insert(k)
+				d.pool.Submit(func() (any, error) {
+					d.taskRunner(ctx, k)
+					d.tasks.Remove(k)
+					return struct{}{}, nil
+				})
+				return true
+			})
+		}
+	}
+}
diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index 1d48f39718ff0..8ef648ff01ce6 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -408,6 +408,10 @@ func (s *Server) startQueryCoord() error {
 	for _, node := range sessions {
 		s.nodeMgr.Add(session.NewNodeInfo(node.ServerID, node.Address))
 		s.taskScheduler.AddExecutor(node.ServerID)
+
+		if node.Stopping {
+			s.nodeMgr.Stopping(node.ServerID)
+		}
 	}
 	s.checkReplicas()
 	for _, node := range sessions {
diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go
index b3dad1843397b..5a0ddecda755b 100644
--- a/internal/querycoordv2/server_test.go
+++ b/internal/querycoordv2/server_test.go
@@ -167,6 +167,10 @@ func (suite *ServerSuite) TestRecover() {
 	err := suite.server.Stop()
 	suite.NoError(err)
 
+	// stopping querynode
+	downNode := suite.nodes[0]
+	downNode.Stopping()
+
 	suite.server, err = suite.newQueryCoord()
 	suite.NoError(err)
 	suite.hackServer()
@@ -176,6 +180,8 @@ func (suite *ServerSuite) TestRecover() {
 	for _, collection := range suite.collections {
 		suite.True(suite.server.meta.Exist(collection))
 	}
+
+	suite.True(suite.server.nodeMgr.IsStoppingNode(suite.nodes[0].ID))
 }
 
 func (suite *ServerSuite) TestNodeUp() {
diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go
index 2e1c52389dab6..c6adaf8973cf2 100644
--- a/internal/querycoordv2/services.go
+++ b/internal/querycoordv2/services.go
@@ -876,7 +876,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
 		return resp, nil
 	}
 
-	currentTargets := s.targetMgr.GetHistoricalSegmentsByCollection(req.GetCollectionID(), meta.CurrentTarget)
+	currentTargets := 
s.targetMgr.GetSealedSegmentsByCollection(req.GetCollectionID(), meta.CurrentTarget) for _, channel := range channels { log := log.With(zap.String("channel", channel.GetChannelName())) diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index ed7c18a786c93..0417ca7a7ba32 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -1659,7 +1659,7 @@ func (suite *ServiceSuite) assertLoaded(collection int64) { } for _, partitions := range suite.segments[collection] { for _, segment := range partitions { - suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.NextTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.NextTarget)) } } } @@ -1675,7 +1675,7 @@ func (suite *ServiceSuite) assertPartitionLoaded(collection int64, partitions .. continue } for _, segment := range segments { - suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) + suite.NotNil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) } } } @@ -1687,8 +1687,8 @@ func (suite *ServiceSuite) assertReleased(collection int64) { } for _, partitions := range suite.segments[collection] { for _, segment := range partitions { - suite.Nil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) - suite.Nil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.NextTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.CurrentTarget)) + suite.Nil(suite.targetMgr.GetSealedSegment(collection, segment, meta.NextTarget)) } } } diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index 412270a10f321..37bb22627e963 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -502,7 +502,7 @@ func (scheduler *taskScheduler) GetNodeSegmentCntDelta(nodeID int64) int { continue } segmentAction := action.(*SegmentAction) - segment := scheduler.targetMgr.GetHistoricalSegment(task.CollectionID(), segmentAction.SegmentID(), meta.NextTarget) + segment := scheduler.targetMgr.GetSealedSegment(task.CollectionID(), segmentAction.SegmentID(), meta.NextTarget) if action.Type() == ActionTypeGrow { delta += int(segment.GetNumOfRows()) } else { @@ -586,9 +586,9 @@ func (scheduler *taskScheduler) isRelated(task Task, node int64) bool { taskType := GetTaskType(task) var segment *datapb.SegmentInfo if taskType == TaskTypeMove || taskType == TaskTypeUpdate { - segment = scheduler.targetMgr.GetHistoricalSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget) + segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget) } else { - segment = scheduler.targetMgr.GetHistoricalSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget) + segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget) } if segment == nil { continue @@ -779,9 +779,9 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error { taskType := GetTaskType(task) var segment *datapb.SegmentInfo if taskType == TaskTypeMove || taskType == TaskTypeUpdate { - segment = scheduler.targetMgr.GetHistoricalSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget) + segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.CurrentTarget) } else { - segment = scheduler.targetMgr.GetHistoricalSegment(task.CollectionID(), 
task.SegmentID(), meta.NextTarget) + segment = scheduler.targetMgr.GetSealedSegment(task.CollectionID(), task.SegmentID(), meta.NextTarget) } if segment == nil { log.Warn("task stale due to the segment to load not exists in targets", diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 2796fe70269c8..b93ac04ef4db0 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -280,7 +280,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq growing = []SegmentEntry{} } - log.Info("query segments...", + log.Info("query stream segments...", zap.Int("sealedNum", len(sealed)), zap.Int("growingNum", len(growing)), ) diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index 02bb7e027d1c1..55c731310c7ae 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -264,6 +264,13 @@ func (node *QueryNode) queryChannelStream(ctx context.Context, req *querypb.Quer } func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + log.Debug("received query stream request", + zap.Int64s("outputFields", req.GetReq().GetOutputFieldsId()), + zap.Int64s("segmentIDs", req.GetSegmentIDs()), + zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()), + zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()), + ) + collection := node.manager.Collection.Get(req.Req.GetCollectionID()) if collection == nil { return merr.WrapErrCollectionNotFound(req.Req.GetCollectionID()) diff --git a/pkg/common/version.go b/pkg/common/version.go index 731b001ec938b..643b943b0bf57 100644 --- a/pkg/common/version.go +++ b/pkg/common/version.go @@ -6,5 +6,5 @@ import semver "github.com/blang/semver/v4" var Version semver.Version func init() { - Version, _ = semver.Parse("2.3.0") + Version, _ = semver.Parse("2.3.2") } diff --git a/pkg/mq/msgstream/mock_msgstream_factory.go b/pkg/mq/msgstream/mock_msgstream_factory.go index 1d0c6ff129ed2..6f5f4e7b868f7 100644 --- a/pkg/mq/msgstream/mock_msgstream_factory.go +++ b/pkg/mq/msgstream/mock_msgstream_factory.go @@ -53,7 +53,7 @@ type MockFactory_NewMsgStream_Call struct { } // NewMsgStream is a helper method to define mock.On call -// - ctx context.Context +// - ctx context.Context func (_e *MockFactory_Expecter) NewMsgStream(ctx interface{}) *MockFactory_NewMsgStream_Call { return &MockFactory_NewMsgStream_Call{Call: _e.mock.On("NewMsgStream", ctx)} } @@ -97,7 +97,7 @@ type MockFactory_NewMsgStreamDisposer_Call struct { } // NewMsgStreamDisposer is a helper method to define mock.On call -// - ctx context.Context +// - ctx context.Context func (_e *MockFactory_Expecter) NewMsgStreamDisposer(ctx interface{}) *MockFactory_NewMsgStreamDisposer_Call { return &MockFactory_NewMsgStreamDisposer_Call{Call: _e.mock.On("NewMsgStreamDisposer", ctx)} } @@ -151,7 +151,7 @@ type MockFactory_NewTtMsgStream_Call struct { } // NewTtMsgStream is a helper method to define mock.On call -// - ctx context.Context +// - ctx context.Context func (_e *MockFactory_Expecter) NewTtMsgStream(ctx interface{}) *MockFactory_NewTtMsgStream_Call { return &MockFactory_NewTtMsgStream_Call{Call: _e.mock.On("NewTtMsgStream", ctx)} } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 61f3e105928fb..6fed7efad5f4d 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go 
@@ -1212,6 +1212,7 @@ type queryCoordConfig struct {
 	CheckHealthRPCTimeout       ParamItem `refreshable:"true"`
 	BrokerTimeout               ParamItem `refreshable:"false"`
 	CollectionRecoverTimesLimit ParamItem `refreshable:"true"`
+	ObserverTaskParallel        ParamItem `refreshable:"false"`
 }
 
 func (p *queryCoordConfig) init(base *BaseTable) {
@@ -1523,6 +1524,16 @@ func (p *queryCoordConfig) init(base *BaseTable) {
 		Export: true,
 	}
 	p.CollectionRecoverTimesLimit.Init(base.mgr)
+
+	p.ObserverTaskParallel = ParamItem{
+		Key:          "queryCoord.observerTaskParallel",
+		Version:      "2.3.2",
+		DefaultValue: "16",
+		PanicIfEmpty: true,
+		Doc:          "the number of observer dispatcher tasks run in parallel",
+		Export:       true,
+	}
+	p.ObserverTaskParallel.Init(base.mgr)
 }
 
 // /////////////////////////////////////////////////////////////////////////////
diff --git a/pkg/util/paramtable/grpc_param.go b/pkg/util/paramtable/grpc_param.go
index 6264e4a71b9f7..bd4a91b629de0 100644
--- a/pkg/util/paramtable/grpc_param.go
+++ b/pkg/util/paramtable/grpc_param.go
@@ -46,7 +46,7 @@ const (
 	DefaultMaxAttempts                = 10
 	DefaultInitialBackoff     float64 = 0.2
 	DefaultMaxBackoff         float64 = 10
-	DefaultCompressionEnabled bool    = true
+	DefaultCompressionEnabled bool    = false
 
 	ProxyInternalPort = 19529
 	ProxyExternalPort = 19530
diff --git a/scripts/install_deps.sh b/scripts/install_deps.sh
index c5b0e3f8a7fd6..1e208f18504c6 100755
--- a/scripts/install_deps.sh
+++ b/scripts/install_deps.sh
@@ -24,7 +24,7 @@ function install_linux_deps() {
       clang-format-10 clang-tidy-10 lcov libtool m4 autoconf automake python3 python3-pip \
       pkg-config uuid-dev libaio-dev libgoogle-perftools-dev
 
-    sudo pip3 install conan==1.58.0
+    sudo pip3 install conan==1.61.0
   elif [[ -x "$(command -v yum)" ]]; then
     # for CentOS devtoolset-11
     sudo yum install -y epel-release centos-release-scl-rh
@@ -35,7 +35,7 @@ function install_linux_deps() {
         libaio libuuid-devel zip unzip \
        ccache lcov libtool m4 autoconf automake
 
-    sudo pip3 install conan==1.58.0
+    sudo pip3 install conan==1.61.0
 
     echo "source scl_source enable devtoolset-11" | sudo tee -a /etc/profile.d/devtoolset-11.sh
     echo "source scl_source enable llvm-toolset-11.0" | sudo tee -a /etc/profile.d/llvm-toolset-11.sh
     echo "export CLANG_TOOLS_PATH=/opt/rh/llvm-toolset-11.0/root/usr/bin" | sudo tee -a /etc/profile.d/llvm-toolset-11.sh
@@ -49,7 +49,7 @@ function install_linux_deps() {
   if [ ! $cmake_version ] || [ `expr $cmake_version \>= 3.24` -eq 0 ]; then
     echo "cmake version $cmake_version is less than 3.24, wait to installing ..."
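+    # no system CMake >= 3.24 found; fall back to the 3.24.0 binary tarball from cmake.org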
wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.0-linux-x86_64.tar.gz" | sudo tar --strip-components=1 -xz -C /usr/local - else + else echo "cmake version is $cmake_version" fi } @@ -60,7 +60,7 @@ function install_mac_deps() { export PATH="/usr/local/opt/grep/libexec/gnubin:$PATH" brew update && brew upgrade && brew cleanup - pip3 install conan==1.58.0 + pip3 install conan==1.61.0 if [[ $(arch) == 'arm64' ]]; then brew install openssl diff --git a/scripts/install_deps_msys.sh b/scripts/install_deps_msys.sh index 670ab81e8d8df..793612286789d 100644 --- a/scripts/install_deps_msys.sh +++ b/scripts/install_deps_msys.sh @@ -21,9 +21,9 @@ pacmanInstall() mingw-w64-x86_64-python2 \ mingw-w64-x86_64-python-pip \ mingw-w64-x86_64-diffutils \ - mingw-w64-x86_64-go + mingw-w64-x86_64-go - pip3 install conan==1.58.0 + pip3 install conan==1.61.0 } updateKey() diff --git a/tests/python_client/chaos/checker.py b/tests/python_client/chaos/checker.py index b4571468575ab..641bf74e8f683 100644 --- a/tests/python_client/chaos/checker.py +++ b/tests/python_client/chaos/checker.py @@ -158,7 +158,7 @@ def get_stage_success_rate(self): data_before_chaos = group[group['start_time'] < chaos_start_time].agg( {'success_rate': 'mean', 'failed_count': 'sum', 'success_count': 'sum'}) data_during_chaos = group[ - (group['start_time'] >= chaos_start_time) & (group['start_time'] <= recovery_time)].agg( + (group['start_time'] >= chaos_start_time) & (group['start_time'] <= chaos_end_time)].agg( {'success_rate': 'mean', 'failed_count': 'sum', 'success_count': 'sum'}) data_after_chaos = group[group['start_time'] > recovery_time].agg( {'success_rate': 'mean', 'failed_count': 'sum', 'success_count': 'sum'}) diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index 966932784f87c..23dc77a89c1a7 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -96,6 +96,14 @@ def gen_json_field(name=ct.default_json_field_name, description=ct.default_desc, return json_field +def gen_array_field(name=ct.default_array_field_name, element_type=DataType.INT64, max_capacity=ct.default_max_capacity, + description=ct.default_desc, is_primary=False, **kwargs): + array_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.ARRAY, + element_type=element_type, max_capacity=max_capacity, + description=description, is_primary=is_primary, **kwargs) + return array_field + + def gen_int8_field(name=ct.default_int8_field_name, description=ct.default_desc, is_primary=False, **kwargs): int8_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.INT8, description=description, is_primary=is_primary, **kwargs) @@ -170,6 +178,34 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct. 
 	return schema
+
+def gen_array_collection_schema(description=ct.default_desc, primary_field=ct.default_int64_field_name, auto_id=False,
+                                dim=ct.default_dim, enable_dynamic_field=False, max_capacity=ct.default_max_capacity,
+                                max_length=100, with_json=False, **kwargs):
+    if enable_dynamic_field:
+        if primary_field is ct.default_int64_field_name:
+            fields = [gen_int64_field(), gen_float_vec_field(dim=dim)]
+        elif primary_field is ct.default_string_field_name:
+            fields = [gen_string_field(), gen_float_vec_field(dim=dim)]
+        else:
+            log.error("Primary key only supports int or varchar")
+            assert False
+    else:
+        fields = [gen_int64_field(), gen_float_vec_field(dim=dim), gen_json_field(),
+                  gen_array_field(name=ct.default_int32_array_field_name, element_type=DataType.INT32,
+                                  max_capacity=max_capacity),
+                  gen_array_field(name=ct.default_float_array_field_name, element_type=DataType.FLOAT,
+                                  max_capacity=max_capacity),
+                  gen_array_field(name=ct.default_string_array_field_name, element_type=DataType.VARCHAR,
+                                  max_capacity=max_capacity, max_length=max_length)]
+        if with_json is False:
+            fields.remove(gen_json_field())
+
+    schema, _ = ApiCollectionSchemaWrapper().init_collection_schema(fields=fields, description=description,
+                                                                    primary_field=primary_field, auto_id=auto_id,
+                                                                    enable_dynamic_field=enable_dynamic_field, **kwargs)
+    return schema
+
+
 def gen_bulk_insert_collection_schema(description=ct.default_desc, primary_field=ct.default_int64_field_name,
                                       with_varchar_field=True, auto_id=False, dim=ct.default_dim, enable_dynamic_field=False, with_json=False):
     if enable_dynamic_field:
@@ -359,6 +395,33 @@ def gen_default_data_for_upsert(nb=ct.default_nb, dim=ct.default_dim, start=0, s
     return df, float_values
 
 
+def gen_array_dataframe_data(nb=ct.default_nb, dim=ct.default_dim, start=0,
+                             array_length=ct.default_max_capacity, with_json=False, random_primary_key=False):
+    if not random_primary_key:
+        int_values = pd.Series(data=[i for i in range(start, start + nb)])
+    else:
+        int_values = pd.Series(data=random.sample(range(start, start + nb), nb))
+    float_vec_values = gen_vectors(nb, dim)
+    json_values = [{"number": i, "float": i * 1.0} for i in range(start, start + nb)]
+
+    int32_values = pd.Series(data=[[np.int32(j) for j in range(i, i + array_length)] for i in range(start, start + nb)])
+    float_values = pd.Series(data=[[np.float32(j) for j in range(i, i + array_length)] for i in range(start, start + nb)])
+    string_values = pd.Series(data=[[str(j) for j in range(i, i + array_length)] for i in range(start, start + nb)])
+
+    df = pd.DataFrame({
+        ct.default_int64_field_name: int_values,
+        ct.default_float_vec_field_name: float_vec_values,
+        ct.default_json_field_name: json_values,
+        ct.default_int32_array_field_name: int32_values,
+        ct.default_float_array_field_name: float_values,
+        ct.default_string_array_field_name: string_values,
+    })
+    if with_json is False:
+        df.drop(ct.default_json_field_name, axis=1, inplace=True)
+
+    return df
+
+
 def gen_dataframe_multi_vec_fields(vec_fields, nb=ct.default_nb):
     """
         gen dataframe data for fields: int64, float, float_vec and vec_fields
@@ -683,6 +746,25 @@ def gen_data_by_type(field, nb=None, start=None):
         if nb is None:
             return [random.random() for i in range(dim)]
         return [[random.random() for i in range(dim)] for _ in range(nb)]
+    if data_type == DataType.ARRAY:
+        max_capacity = field.params['max_capacity']
+        element_type = field.element_type
+        if element_type == DataType.INT32:
+            if nb is None:
+                return [random.randint(-2147483648, 2147483647) for _ in 
range(max_capacity)] + return [[random.randint(-2147483648, 2147483647) for _ in range(max_capacity)] for _ in range(nb)] + if element_type == DataType.FLOAT: + if nb is None: + return [np.float32(random.random()) for _ in range(max_capacity)] + return [[np.float32(random.random()) for _ in range(max_capacity)] for _ in range(nb)] + if element_type == DataType.VARCHAR: + max_length = field.params['max_length'] + max_length = min(20, max_length - 1) + length = random.randint(0, max_length) + if nb is None: + return ["".join([chr(random.randint(97, 122)) for _ in range(length)]) for _ in range(max_capacity)] + return [["".join([chr(random.randint(97, 122)) for _ in range(length)]) for _ in range(max_capacity)] for _ in range(nb)] + return None @@ -986,6 +1068,21 @@ def gen_json_field_expressions(): return expressions +def gen_array_field_expressions(): + expressions = [ + "int32_array[0] > 0", + "0 <= int32_array[0] < 400 or 1000 > float_array[1] >= 500", + "int32_array[1] not in [1, 2, 3]", + "int32_array[1] in [1, 2, 3] and string_array[1] != '2'", + "int32_array == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]", + "int32_array[1] + 1 == 3 && int32_array[0] - 1 != 1", + "int32_array[1] % 100 == 0 && string_array[1] in ['1', '2']", + "int32_array[1] in [300/2, -10*30+800, (200-100)*2] " + "or (float_array[1] <= -4**5/2 || 100 <= int32_array[1] < 200)" + ] + return expressions + + def gen_field_compare_expressions(fields1=None, fields2=None): if fields1 is None: fields1 = ["int64_1"] @@ -1218,25 +1315,31 @@ def index_to_dict(index): def assert_json_contains(expr, list_data): + opposite = False + if expr.startswith("not"): + opposite = True + expr = expr.split("not ", 1)[1] result_ids = [] expr_prefix = expr.split('(', 1)[0] exp_ids = eval(expr.split(', ', 1)[1].split(')', 1)[0]) - if expr_prefix in ["json_contains", "JSON_CONTAINS"]: + if expr_prefix in ["json_contains", "JSON_CONTAINS", "array_contains", "ARRAY_CONTAINS"]: for i in range(len(list_data)): if exp_ids in list_data[i]: result_ids.append(i) - elif expr_prefix in ["json_contains_all", "JSON_CONTAINS_ALL"]: + elif expr_prefix in ["json_contains_all", "JSON_CONTAINS_ALL", "array_contains_all", "ARRAY_CONTAINS_ALL"]: for i in range(len(list_data)): set_list_data = set(tuple(element) if isinstance(element, list) else element for element in list_data[i]) if set(exp_ids).issubset(set_list_data): result_ids.append(i) - elif expr_prefix in ["json_contains_any", "JSON_CONTAINS_ANY"]: + elif expr_prefix in ["json_contains_any", "JSON_CONTAINS_ANY", "array_contains_any", "ARRAY_CONTAINS_ANY"]: for i in range(len(list_data)): set_list_data = set(tuple(element) if isinstance(element, list) else element for element in list_data[i]) if set(exp_ids) & set_list_data: result_ids.append(i) else: log.warning("unknown expr: %s" % expr) + if opposite: + result_ids = [i for i in range(len(list_data)) if i not in result_ids] return result_ids diff --git a/tests/python_client/common/common_type.py b/tests/python_client/common/common_type.py index 7db80d6d84837..bea7dae0a3805 100644 --- a/tests/python_client/common/common_type.py +++ b/tests/python_client/common/common_type.py @@ -8,6 +8,7 @@ default_dim = 128 default_nb = 2000 default_nb_medium = 5000 +default_max_capacity = 100 default_top_k = 10 default_nq = 2 default_limit = 10 @@ -38,6 +39,10 @@ default_double_field_name = "double" default_string_field_name = "varchar" default_json_field_name = "json_field" +default_array_field_name = "int_array" +default_int32_array_field_name = "int32_array" 
+default_float_array_field_name = "float_array" +default_string_array_field_name = "string_array" default_float_vec_field_name = "float_vector" another_float_vec_field_name = "float_vector1" default_binary_vec_field_name = "binary_vector" diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index b24114dd006de..da6e9bbcfa562 100644 --- a/tests/python_client/requirements.txt +++ b/tests/python_client/requirements.txt @@ -12,7 +12,7 @@ allure-pytest==2.7.0 pytest-print==0.2.1 pytest-level==0.1.1 pytest-xdist==2.5.0 -pymilvus==2.3.1.post1.dev8 +pymilvus==2.3.1.post1.dev18 pytest-rerunfailures==9.1.1 git+https://github.com/Projectplace/pytest-tags ndg-httpsclient diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index 5c24ae847b281..3d6330e0320d9 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from pymilvus import DataType from base.client_base import TestcaseBase from common import common_func as cf from common import common_type as ct @@ -3060,23 +3061,22 @@ def test_collection_describe(self): collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) description = \ {'collection_name': c_name, 'auto_id': False, 'num_shards': ct.default_shards_num, 'description': '', - 'fields': [{'field_id': 100, 'name': 'int64', 'description': '', 'type': 5, - 'params': {}, 'is_primary': True, 'element_type': 0,}, + 'fields': [{'field_id': 100, 'name': 'int64', 'description': '', 'type': 5, 'params': {}, + 'is_primary': True, 'element_type': 0, "auto_id": False, "is_partition_key": False, + "is_dynamic": False}, {'field_id': 101, 'name': 'float', 'description': '', 'type': 10, 'params': {}, - 'element_type': 0,}, + 'element_type': 0}, {'field_id': 102, 'name': 'varchar', 'description': '', 'type': 21, - 'params': {'max_length': 65535}, 'element_type': 0,}, + 'params': {'max_length': 65535}, 'element_type': 0}, {'field_id': 103, 'name': 'json_field', 'description': '', 'type': 23, 'params': {}, - 'element_type': 0,}, + 'element_type': 0}, {'field_id': 104, 'name': 'float_vector', 'description': '', 'type': 101, 'params': {'dim': 128}, 'element_type': 0}], - 'aliases': [], 'consistency_level': 0, 'properties': [], 'num_partitions': 1} + 'aliases': [], 'consistency_level': 0, 'properties': [], 'num_partitions': 1, + "enable_dynamic_field": False} res = collection_w.describe()[0] del res['collection_id'] log.info(res) - assert description['fields'] == res['fields'], description['aliases'] == res['aliases'] - del description['fields'], res['fields'], description['aliases'], res['aliases'] - del description['properties'], res['properties'] assert description == res @@ -3820,7 +3820,7 @@ def test_collection_string_field_is_primary_and_auto_id(self): class TestCollectionJSON(TestcaseBase): """ ****************************************************************** - The following cases are used to test about string + The following cases are used to test about json ****************************************************************** """ @pytest.mark.tags(CaseLabel.L1) @@ -3895,3 +3895,189 @@ def test_collection_multiple_json_fields_supported_primary_key(self, primary_fie self.collection_wrap.init_collection(name=c_name, schema=schema, check_task=CheckTasks.check_collection_property, check_items={exp_name: c_name, exp_schema: schema}) + + +class 
TestCollectionARRAY(TestcaseBase): + """ + ****************************************************************** + The following cases are used to test about array + ****************************************************************** + """ + + @pytest.mark.tags(CaseLabel.L2) + def test_collection_array_field_element_type_not_exist(self): + """ + target: test create collection with ARRAY field without element type + method: create collection with one array field without element type + expected: Raise exception + """ + int_field = cf.gen_int64_field(is_primary=True) + vec_field = cf.gen_float_vec_field() + array_field = cf.gen_array_field(element_type=None) + array_schema = cf.gen_collection_schema([int_field, vec_field, array_field]) + self.init_collection_wrap(schema=array_schema, check_task=CheckTasks.err_res, + check_items={ct.err_code: 65535, ct.err_msg: "element data type None is not valid"}) + + @pytest.mark.tags(CaseLabel.L2) + # @pytest.mark.skip("issue #27522") + @pytest.mark.parametrize("element_type", [1001, 'a', [], (), {1}, DataType.BINARY_VECTOR, + DataType.FLOAT_VECTOR, DataType.JSON, DataType.ARRAY]) + def test_collection_array_field_element_type_invalid(self, element_type): + """ + target: Create a field with invalid element_type + method: Create a field with invalid element_type + 1. Type not in DataType: 1, 'a', ... + 2. Type in DataType: binary_vector, float_vector, json_field, array_field + expected: Raise exception + """ + int_field = cf.gen_int64_field(is_primary=True) + vec_field = cf.gen_float_vec_field() + array_field = cf.gen_array_field(element_type=element_type) + array_schema = cf.gen_collection_schema([int_field, vec_field, array_field]) + error = {ct.err_code: 65535, ct.err_msg: "element data type None is not valid"} + if element_type in ['a', {1}]: + error = {ct.err_code: 1, ct.err_msg: "Unexpected error"} + self.init_collection_wrap(schema=array_schema, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_collection_array_field_no_capacity(self): + """ + target: Create a field without giving max_capacity + method: Create a field without giving max_capacity + expected: Raise exception + """ + int_field = cf.gen_int64_field(is_primary=True) + vec_field = cf.gen_float_vec_field() + array_field = cf.gen_array_field(max_capacity=None) + array_schema = cf.gen_collection_schema([int_field, vec_field, array_field]) + self.init_collection_wrap(schema=array_schema, check_task=CheckTasks.err_res, + check_items={ct.err_code: 65535, + ct.err_msg: "the value of max_capacity must be an integer"}) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("max_capacity", [[], 'a', (), -1, 4097]) + def test_collection_array_field_invalid_capacity(self, max_capacity): + """ + target: Create a field with invalid max_capacity + method: Create a field with invalid max_capacity + 1. Type invalid: [], 'a', () + 2. 
Value invalid: <0, >max_capacity(4096) + expected: Raise exception + """ + int_field = cf.gen_int64_field(is_primary=True) + vec_field = cf.gen_float_vec_field() + array_field = cf.gen_array_field(max_capacity=max_capacity) + array_schema = cf.gen_collection_schema([int_field, vec_field, array_field]) + self.init_collection_wrap(schema=array_schema, check_task=CheckTasks.err_res, + check_items={ct.err_code: 65535, + ct.err_msg: "the maximum capacity specified for a " + "Array should be in (0, 4096]"}) + + @pytest.mark.tags(CaseLabel.L2) + def test_collection_string_array_without_max_length(self): + """ + target: Create string array without giving max length + method: Create string array without giving max length + expected: Raise exception + """ + int_field = cf.gen_int64_field(is_primary=True) + vec_field = cf.gen_float_vec_field() + array_field = cf.gen_array_field(element_type=DataType.VARCHAR) + array_schema = cf.gen_collection_schema([int_field, vec_field, array_field]) + self.init_collection_wrap(schema=array_schema, check_task=CheckTasks.err_res, + check_items={ct.err_code: 65535, + ct.err_msg: "type param(max_length) should be specified for " + "varChar field of collection"}) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("max_length", [[], 'a', (), -1, 65536]) + def test_collection_string_array_max_length_invalid(self, max_length): + """ + target: Create string array with invalid max length + method: Create string array with invalid max length + 1. Type invalid: [], 'a', () + 2. Value invalid: <0, >max_length(65535) + expected: Raise exception + """ + int_field = cf.gen_int64_field(is_primary=True) + vec_field = cf.gen_float_vec_field() + array_field = cf.gen_array_field(element_type=DataType.VARCHAR, max_length=max_length) + array_schema = cf.gen_collection_schema([int_field, vec_field, array_field]) + self.init_collection_wrap(schema=array_schema, check_task=CheckTasks.err_res, + check_items={ct.err_code: 65535, + ct.err_msg: "the maximum length specified for a VarChar " + "should be in (0, 65535]"}) + + @pytest.mark.tags(CaseLabel.L2) + def test_collection_array_field_all_datatype(self): + """ + target: test create collection with ARRAY field all data type + method: 1. Create field respectively: int8, int16, int32, int64, varchar, bool, float, double + 2. 
Insert data respectively: int8, int16, int32, int64, varchar, bool, float, double
+        expected: create collection and insert data successfully
+        """
+        # Create field respectively
+        nb = ct.default_nb
+        pk_field = cf.gen_int64_field(is_primary=True)
+        vec_field = cf.gen_float_vec_field()
+        int8_array = cf.gen_array_field(name="int8_array", element_type=DataType.INT8, max_capacity=nb)
+        int16_array = cf.gen_array_field(name="int16_array", element_type=DataType.INT16, max_capacity=nb)
+        int32_array = cf.gen_array_field(name="int32_array", element_type=DataType.INT32, max_capacity=nb)
+        int64_array = cf.gen_array_field(name="int64_array", element_type=DataType.INT64, max_capacity=nb)
+        bool_array = cf.gen_array_field(name="bool_array", element_type=DataType.BOOL, max_capacity=nb)
+        float_array = cf.gen_array_field(name="float_array", element_type=DataType.FLOAT, max_capacity=nb)
+        double_array = cf.gen_array_field(name="double_array", element_type=DataType.DOUBLE, max_capacity=nb)
+        string_array = cf.gen_array_field(name="string_array", element_type=DataType.VARCHAR, max_capacity=nb,
+                                          max_length=100)
+        array_schema = cf.gen_collection_schema([pk_field, vec_field, int8_array, int16_array, int32_array,
+                                                 int64_array, bool_array, float_array, double_array, string_array])
+        collection_w = self.init_collection_wrap(schema=array_schema,
+                                                 check_task=CheckTasks.check_collection_property,
+                                                 check_items={exp_schema: array_schema})
+
+        # check array in collection.describe()
+        res = collection_w.describe()[0]
+        log.info(res)
+        fields = [
+            {"field_id": 100, "name": "int64", "description": "", "type": 5, "params": {},
+             "element_type": 0, "is_primary": True},
+            {"field_id": 101, "name": "float_vector", "description": "", "type": 101,
+             "params": {"dim": ct.default_dim}, "element_type": 0},
+            {"field_id": 102, "name": "int8_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 2},
+            {"field_id": 103, "name": "int16_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 3},
+            {"field_id": 104, "name": "int32_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 4},
+            {"field_id": 105, "name": "int64_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 5},
+            {"field_id": 106, "name": "bool_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 1},
+            {"field_id": 107, "name": "float_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 10},
+            {"field_id": 108, "name": "double_array", "description": "", "type": 22,
+             "params": {"max_capacity": "2000"}, "element_type": 11},
+            {"field_id": 109, "name": "string_array", "description": "", "type": 22,
+             "params": {"max_length": "100", "max_capacity": "2000"}, "element_type": 21}
+        ]
+        assert res["fields"] == fields
+
+        # Insert data respectively
+        nb = 10
+        pk_values = [i for i in range(nb)]
+        float_vec = cf.gen_vectors(nb, ct.default_dim)
+        int8_values = [[numpy.int8(j) for j in range(nb)] for i in range(nb)]
+        int16_values = [[numpy.int16(j) for j in range(nb)] for i in range(nb)]
+        int32_values = [[numpy.int32(j) for j in range(nb)] for i in range(nb)]
+        int64_values = [[numpy.int64(j) for j in range(nb)] for i in range(nb)]
+        bool_values = [[numpy.bool_(j) for j in range(nb)] for i in range(nb)]
+        float_values = [[numpy.float32(j) for j in range(nb)] for i in range(nb)]
+        double_values = [[numpy.double(j) for j in range(nb)] for i in range(nb)]
+        string_values = 
[[str(j) for j in range(nb)] for i in range(nb)]
+        data = [pk_values, float_vec, int8_values, int16_values, int32_values, int64_values,
+                bool_values, float_values, double_values, string_values]
+        collection_w.insert(data)
+
+        # check insert successfully
+        collection_w.flush()
+        assert collection_w.num_entities == nb
diff --git a/tests/python_client/testcases/test_delete.py b/tests/python_client/testcases/test_delete.py
index 7bfaab9abbf47..19fdc4556667b 100644
--- a/tests/python_client/testcases/test_delete.py
+++ b/tests/python_client/testcases/test_delete.py
@@ -1,6 +1,7 @@
 import random
 import time
 import pandas as pd
+import numpy as np
 import pytest
 from base.client_base import TestcaseBase
@@ -1892,6 +1893,52 @@ def test_delete_normal_expressions(self, expression, enable_dynamic_field):
         # query to check
         collection_w.query(f"int64 in {filter_ids}", check_task=CheckTasks.check_query_empty)
 
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("expression", cf.gen_array_field_expressions())
+    @pytest.mark.parametrize("enable_dynamic_field", [True, False])
+    def test_delete_array_expressions(self, expression, enable_dynamic_field):
+        """
+        target: test delete entities using array expressions
+        method: delete using array expressions
+        expected: delete successfully
+        """
+        # 1. create a collection
+        nb = ct.default_nb
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema, enable_dynamic_field=enable_dynamic_field)
+
+        # 2. insert data
+        array_length = 100
+        data = []
+        for i in range(nb):
+            arr = {ct.default_int64_field_name: i,
+                   ct.default_float_vec_field_name: cf.gen_vectors(1, ct.default_dim)[0],
+                   ct.default_int32_array_field_name: [np.int32(i) for i in range(array_length)],
+                   ct.default_float_array_field_name: [np.float32(i) for i in range(array_length)],
+                   ct.default_string_array_field_name: [str(i) for i in range(array_length)]}
+            data.append(arr)
+        collection_w.insert(data)
+        collection_w.flush()
+
+        # 3. filter result with expression in collection
+        expression = expression.replace("&&", "and").replace("||", "or")
+        filter_ids = []
+        for i in range(nb):
+            int32_array = data[i][ct.default_int32_array_field_name]
+            float_array = data[i][ct.default_float_array_field_name]
+            string_array = data[i][ct.default_string_array_field_name]
+            if not expression or eval(expression):
+                filter_ids.append(i)
+
+        # 4. delete by array expression
+        collection_w.create_index(ct.default_float_vec_field_name, ct.default_flat_index)
+        collection_w.load()
+        res = collection_w.delete(expression)[0]
+        assert res.delete_count == len(filter_ids)
+
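The filter step above doubles as a client-side oracle: the test rewrites the Milvus operators `&&`/`||` to Python's `and`/`or`, binds the array columns to local names matching the field names, and lets `eval()` reproduce the server-side filtering. A stripped-down, self-contained sketch of that pattern (the `expected_ids` helper name is hypothetical; plain Python, no Milvus required):

```python
def expected_ids(expression, rows):
    """Evaluate a Milvus-style filter locally to predict matching row ids."""
    # Milvus expressions accept && / ||; Python's eval needs and / or.
    expression = expression.replace("&&", "and").replace("||", "or")
    ids = []
    for i, row in enumerate(rows):
        # These locals look unused, but eval() resolves the field names in
        # the expression against them, so they must not be removed.
        int32_array = row["int32_array"]
        float_array = row["float_array"]
        string_array = row["string_array"]
        if not expression or eval(expression):
            ids.append(i)
    return ids

rows = [{"int32_array": [i, i + 1], "float_array": [float(i)], "string_array": [str(i)]}
        for i in range(10)]
print(expected_ids("int32_array[0] > 5 && float_array[0] < 9.0", rows))  # -> [6, 7, 8]
```

+        # 5.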
query to check
+        collection_w.query(expression, check_task=CheckTasks.check_query_empty)
+
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("field_name", ["varchar", "json_field['string']", "NewStr"])
     @pytest.mark.parametrize("like", ["like", "LIKE"])
@@ -1981,7 +2028,7 @@ def test_delete_expr_json_contains_base(self, expr_prefix, field_name, enable_dy
         collection_w = self.init_collection_general(prefix, False, enable_dynamic_field=enable_dynamic_field)[0]
 
         # insert
-        listMix = [[i, i + 2] for i in range(ct.default_nb)] # only int
+        listMix = [[i, i + 2] for i in range(ct.default_nb)]  # only int
         if enable_dynamic_field:
             data = cf.gen_default_rows_data()
             for i in range(ct.default_nb):
diff --git a/tests/python_client/testcases/test_high_level_api.py b/tests/python_client/testcases/test_high_level_api.py
index d5f40b908e4fc..3d7c215cc8b3a 100644
--- a/tests/python_client/testcases/test_high_level_api.py
+++ b/tests/python_client/testcases/test_high_level_api.py
@@ -42,6 +42,8 @@
 default_float_field_name = ct.default_float_field_name
 default_bool_field_name = ct.default_bool_field_name
 default_string_field_name = ct.default_string_field_name
+default_int32_array_field_name = ct.default_int32_array_field_name
+default_string_array_field_name = ct.default_string_array_field_name
 
 
 class TestHighLevelApi(TestcaseBase):
@@ -195,6 +197,41 @@ def test_high_level_search_query_default(self):
                                          "primary_field": default_primary_key_field_name})
         client_w.drop_collection(client, collection_name)
 
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_high_level_array_insert_search(self):
+        """
+        target: test high level api insert and search with array fields
+        method: create connection, collection, insert array data and search
+        expected: search successfully
+        """
+        client = self._connect(enable_high_level_api=True)
+        collection_name = cf.gen_unique_str(prefix)
+        # 1. create collection
+        client_w.create_collection(client, collection_name, default_dim)
+        collections = client_w.list_collections(client)[0]
+        assert collection_name in collections
+        # 2. insert
+        rng = np.random.default_rng(seed=19530)
+        rows = [{
+            default_primary_key_field_name: i,
+            default_vector_field_name: list(rng.random((1, default_dim))[0]),
+            default_float_field_name: i * 1.0,
+            default_int32_array_field_name: [i, i+1, i+2],
+            default_string_array_field_name: [str(i), str(i + 1), str(i + 2)]
+        } for i in range(default_nb)]
+        client_w.insert(client, collection_name, rows)
+        client_w.flush(client, collection_name)
+        assert client_w.num_entities(client, collection_name)[0] == default_nb
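Because the quick-setup path above creates the collection without an explicit schema, the array values presumably ride on the dynamic field rather than typed ARRAY columns. For orientation, a rough standalone equivalent, assuming the pymilvus 2.3.x `MilvusClient` API rather than the suite's `client_w` wrapper; the URI, collection name, and dimension are illustrative:

```python
from pymilvus import MilvusClient

# Quick setup: an "id" primary key and a "vector" field are created implicitly,
# and the dynamic field is enabled, so extra keys such as the arrays below are accepted.
client = MilvusClient(uri="http://localhost:19530")
client.create_collection("array_demo", dimension=8)

rows = [{"id": i,
         "vector": [0.1 * i] * 8,
         "int32_array": [i, i + 1, i + 2],
         "string_array": [str(i), str(i + 1)]}
        for i in range(3)]
client.insert("array_demo", rows)

res = client.search("array_demo", data=[[0.0] * 8], limit=3)
print(res)
```

+        # 3.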
search + vectors_to_search = rng.random((1, default_dim)) + insert_ids = [i for i in range(default_nb)] + client_w.search(client, collection_name, vectors_to_search, + check_task=CheckTasks.check_search_results, + check_items={"enable_high_level_api": True, + "nq": len(vectors_to_search), + "ids": insert_ids, + "limit": default_limit}) + @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip(reason="issue 25110") def test_high_level_search_query_string(self): diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index d5d080a8d8553..2723beb3b6f3a 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -247,6 +247,21 @@ def test_index_create_on_scalar_field(self): ct.err_msg: f"there is no vector index on collection: {collection_w.name}, " f"please create index firstly"}) + @pytest.mark.tags(CaseLabel.L2) + def test_index_create_on_array_field(self): + """ + target: Test create index on array field + method: create index on array field + expected: raise exception + """ + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + error = {ct.err_code: 1100, + ct.err_msg: "create index on json field is not supported: expected=supported field, " + "actual=create index on Array field: invalid parameter"} + collection_w.create_index(ct.default_string_array_field_name, {}, + check_task=CheckTasks.err_res, check_items=error) + @pytest.mark.tags(CaseLabel.L1) def test_index_collection_empty(self): """ diff --git a/tests/python_client/testcases/test_insert.py b/tests/python_client/testcases/test_insert.py index 518efb88c2eb2..8563417b8af4c 100644 --- a/tests/python_client/testcases/test_insert.py +++ b/tests/python_client/testcases/test_insert.py @@ -58,8 +58,7 @@ def test_insert_dataframe_data(self): df = cf.gen_default_dataframe_data(ct.default_nb) mutation_res, _ = collection_w.insert(data=df) assert mutation_res.insert_count == ct.default_nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( - ) + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L0) @@ -204,8 +203,7 @@ def test_insert_binary_dataframe(self): df, _ = cf.gen_default_binary_dataframe_data(ct.default_nb) mutation_res, _ = collection_w.insert(data=df) assert mutation_res.insert_count == ct.default_nb - assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist( - ) + assert mutation_res.primary_keys == df[ct.default_int64_field_name].values.tolist() assert collection_w.num_entities == ct.default_nb @pytest.mark.tags(CaseLabel.L0) @@ -407,17 +405,18 @@ def test_insert_dataframe_order_inconsistent_schema(self): collection_w = self.init_collection_wrap(name=c_name) nb = 10 int_values = pd.Series(data=[i for i in range(nb)]) - float_values = pd.Series(data=[float(i) - for i in range(nb)], dtype="float32") + float_values = pd.Series(data=[float(i) for i in range(nb)], dtype="float32") float_vec_values = cf.gen_vectors(nb, ct.default_dim) df = pd.DataFrame({ ct.default_float_field_name: float_values, ct.default_float_vec_field_name: float_vec_values, ct.default_int64_field_name: int_values }) - error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields"} - collection_w.insert( - data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The fields don't match with schema 
fields, expected: ['int64', 'float', " + "'varchar', 'json_field', 'float_vector'], got ['float', 'float_vector', " + "'int64']"} + collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_inconsistent_data(self): @@ -2221,3 +2220,172 @@ def test_upsert_tuple_using_default_value(self, default_value): data = (int_values, default_value, string_values, vectors) collection_w.upsert(data, check_task=CheckTasks.err_res, check_items={ct.err_code: 1, ct.err_msg: "Field varchar don't match in entities[0]"}) + + +class TestInsertArray(TestcaseBase): + """ Test case of Insert array """ + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("auto_id", [True, False]) + def test_insert_array_dataframe(self, auto_id): + """ + target: test insert DataFrame data + method: Insert data in the form of dataframe + expected: assert num entities + """ + schema = cf.gen_array_collection_schema(auto_id=auto_id) + collection_w = self.init_collection_wrap(schema=schema) + data = cf.gen_array_dataframe_data() + if auto_id: + data = data.drop(ct.default_int64_field_name, axis=1) + collection_w.insert(data=data) + collection_w.flush() + assert collection_w.num_entities == ct.default_nb + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("auto_id", [True, False]) + def test_insert_array_list(self, auto_id): + """ + target: test insert list data + method: Insert data in the form of a list + expected: assert num entities + """ + schema = cf.gen_array_collection_schema(auto_id=auto_id) + collection_w = self.init_collection_wrap(schema=schema) + + nb = ct.default_nb + arr_len = ct.default_max_capacity + pk_values = [i for i in range(nb)] + float_vec = cf.gen_vectors(nb, ct.default_dim) + int32_values = [[np.int32(j) for j in range(i, i+arr_len)] for i in range(nb)] + float_values = [[np.float32(j) for j in range(i, i+arr_len)] for i in range(nb)] + string_values = [[str(j) for j in range(i, i+arr_len)] for i in range(nb)] + + data = [pk_values, float_vec, int32_values, float_values, string_values] + if auto_id: + del data[0] + # log.info(data[0][1]) + collection_w.insert(data=data) + assert collection_w.num_entities == nb + + @pytest.mark.tags(CaseLabel.L1) + def test_insert_array_rows(self): + """ + target: test insert row data + method: Insert data in the form of rows + expected: assert num entities + """ + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + + data = cf.get_row_data_by_schema(schema=schema) + collection_w.insert(data=data) + assert collection_w.num_entities == ct.default_nb + + collection_w.upsert(data[:2]) + + @pytest.mark.tags(CaseLabel.L2) + def test_insert_array_empty_list(self): + """ + target: test insert DataFrame data + method: Insert data with the length of array = 0 + expected: assert num entities + """ + nb = ct.default_nb + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + data = cf.gen_array_dataframe_data() + data[ct.default_int32_array_field_name] = [[] for _ in range(nb)] + collection_w.insert(data=data) + assert collection_w.num_entities == ct.default_nb + + @pytest.mark.tags(CaseLabel.L2) + def test_insert_array_length_differ(self): + """ + target: test insert row data + method: Insert data with every row's array length differ + expected: assert num entities + """ + nb = ct.default_nb + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + array = [] + 
for i in range(nb): + arr_len1 = random.randint(0, ct.default_max_capacity) + arr_len2 = random.randint(0, ct.default_max_capacity) + arr = { + ct.default_int64_field_name: i, + ct.default_float_vec_field_name: [random.random() for _ in range(ct.default_dim)], + ct.default_int32_array_field_name: [np.int32(j) for j in range(arr_len1)], + ct.default_float_array_field_name: [np.float32(j) for j in range(arr_len2)], + ct.default_string_array_field_name: [str(j) for j in range(ct.default_max_capacity)], + } + array.append(arr) + + collection_w.insert(array) + assert collection_w.num_entities == nb + + data = cf.get_row_data_by_schema(nb=2, schema=schema) + collection_w.upsert(data) + + @pytest.mark.tags(CaseLabel.L2) + def test_insert_array_length_invalid(self): + """ + target: Insert actual array length > max_capacity + method: Insert actual array length > max_capacity + expected: raise error + """ + # init collection + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + # Insert actual array length > max_capacity + arr_len = ct.default_max_capacity + 1 + data = cf.get_row_data_by_schema(schema=schema) + data[1][ct.default_float_array_field_name] = [np.float32(i) for i in range(arr_len)] + err_msg = (f"the length (101) of 1th array exceeds max capacity ({ct.default_max_capacity}): " + f"expected=valid length array, actual=array length exceeds max capacity: invalid parameter") + collection_w.insert(data=data, check_task=CheckTasks.err_res, + check_items={ct.err_code: 1100, ct.err_msg: err_msg}) + + @pytest.mark.tags(CaseLabel.L2) + def test_insert_array_type_invalid(self): + """ + target: Insert array type invalid + method: 1. Insert string values to an int array + 2. upsert float values to a string array + expected: raise error + """ + # init collection + arr_len = 10 + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + data = cf.get_row_data_by_schema(schema=schema) + + # 1. Insert string values to an int array + data[1][ct.default_int32_array_field_name] = [str(i) for i in range(arr_len)] + err_msg = "The data in the same column must be of the same type." + collection_w.insert(data=data, check_task=CheckTasks.err_res, + check_items={ct.err_code: 1, ct.err_msg: err_msg}) + + # 2. 
upsert float values to a string array + data = cf.get_row_data_by_schema(schema=schema) + data[1][ct.default_string_array_field_name] = [np.float32(i) for i in range(arr_len)] + collection_w.upsert(data=data, check_task=CheckTasks.err_res, + check_items={ct.err_code: 1, ct.err_msg: err_msg}) + + @pytest.mark.tags(CaseLabel.L2) + def test_insert_array_mixed_value(self): + """ + target: Insert array consisting of mixed values + method: Insert array consisting of mixed values + expected: raise error + """ + # init collection + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + # Insert array consisting of mixed values + data = cf.get_row_data_by_schema(schema=schema) + data[1][ct.default_string_array_field_name] = ["a", 1, [2.0, 3.0], False] + collection_w.insert(data=data, check_task=CheckTasks.err_res, + check_items={ct.err_code: 1, + ct.err_msg: "The data in the same column must be of the same type."}) diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index c5a7bc70646e6..ac4a3d658bbcf 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -573,7 +573,8 @@ def test_query_expr_non_constant_array_term(self): collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"]) + @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS", + "array_contains", "ARRAY_CONTAINS"]) def test_query_expr_json_contains(self, enable_dynamic_field, expr_prefix): """ target: test query with expression using json_contains @@ -581,8 +582,7 @@ def test_query_expr_json_contains(self, enable_dynamic_field, expr_prefix): expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data array = cf.gen_default_rows_data() @@ -608,8 +608,7 @@ def test_query_expr_list_json_contains(self, expr_prefix): expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=True)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0] # 2. 
insert data limit = ct.default_nb // 4 @@ -656,7 +655,8 @@ def test_query_expr_json_contains_combined_with_normal(self, enable_dynamic_fiel assert len(res) == limit // 2 @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.parametrize("expr_prefix", ["json_contains_all", "JSON_CONTAINS_ALL"]) + @pytest.mark.parametrize("expr_prefix", ["json_contains_all", "JSON_CONTAINS_ALL", + "array_contains_all", "ARRAY_CONTAINS_ALL"]) def test_query_expr_all_datatype_json_contains_all(self, enable_dynamic_field, expr_prefix): """ target: test query with expression using json_contains @@ -865,7 +865,8 @@ def test_query_expr_all_datatype_json_contains_any(self, enable_dynamic_field, e assert len(res) == ct.default_nb // 2 @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.parametrize("expr_prefix", ["json_contains_any", "JSON_CONTAINS_ANY"]) + @pytest.mark.parametrize("expr_prefix", ["json_contains_any", "JSON_CONTAINS_ANY", + "array_contains_any", "ARRAY_CONTAINS_ANY"]) def test_query_expr_list_all_datatype_json_contains_any(self, expr_prefix): """ target: test query with expression using json_contains_any @@ -1018,6 +1019,72 @@ def test_query_expr_json_contains_pagination(self, enable_dynamic_field, expr_pr res = collection_w.query(expression, limit=limit, offset=offset)[0] assert len(res) == limit - offset + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("array_length", ["ARRAY_LENGTH", "array_length"]) + @pytest.mark.parametrize("op", ["==", "!="]) + def test_query_expr_array_length(self, array_length, op, enable_dynamic_field): + """ + target: test query with expression using array_length + method: query with expression using array_length + array_length only support == , != + expected: succeed + """ + # 1. create a collection + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema, enable_dynamic_field=enable_dynamic_field) + + # 2. insert data + data = cf.gen_array_dataframe_data() + length = [] + for i in range(ct.default_nb): + ran_int = random.randint(50, 53) + length.append(ran_int) + + data[ct.default_float_array_field_name] = \ + [[np.float32(j) for j in range(length[i])] for i in range(ct.default_nb)] + collection_w.insert(data) + + # 3. load and query + collection_w.create_index(ct.default_float_vec_field_name, ct.default_flat_index) + collection_w.load() + expression = f"{array_length}({ct.default_float_array_field_name}) {op} 51" + res = collection_w.query(expression)[0] + + # 4. check + expression = expression.replace(f"{array_length}(float_array)", "array_length") + filter_ids = [] + for i in range(ct.default_nb): + array_length = length[i] + if not expression or eval(expression): + filter_ids.append(i) + assert len(res) == len(filter_ids) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("op", [">", "<=", "+ 1 =="]) + def test_query_expr_invalid_array_length(self, op): + """ + target: test query with expression using array_length + method: query with expression using array_length + array_length only support == , != + expected: raise error + """ + # 1. create a collection + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema) + + # 2. insert data + data = cf.gen_array_dataframe_data() + collection_w.insert(data) + + # 3. 
load and query
+        collection_w.create_index(ct.default_float_vec_field_name, ct.default_flat_index)
+        collection_w.load()
+        expression = f"array_length({ct.default_float_array_field_name}) {op} 51"
+        collection_w.query(expression, check_task=CheckTasks.err_res,
+                           check_items={ct.err_code: 65535,
+                                        ct.err_msg: "cannot parse expression: %s, error %s "
+                                                    "is not supported" % (expression, op)})
+
     @pytest.mark.tags(CaseLabel.L1)
     def test_query_expr_empty_without_limit(self):
         """
diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py
index a676e552310cf..f996056309678 100644
--- a/tests/python_client/testcases/test_search.py
+++ b/tests/python_client/testcases/test_search.py
@@ -1,3 +1,4 @@
+import numpy as np
 from pymilvus.orm.types import CONSISTENCY_STRONG, CONSISTENCY_BOUNDED, CONSISTENCY_SESSION, CONSISTENCY_EVENTUALLY
 from common.constants import *
 from utils.util_pymilvus import *
@@ -550,6 +551,63 @@ def test_search_with_expression_invalid_like(self, expression):
                                          "err_msg": "failed to create query plan: cannot parse "
                                                     "expression: %s" % expression})
 
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_search_with_expression_invalid_array_one(self):
+        """
+        target: test search with invalid array expressions
+        method: test search with invalid array expressions:
+                the array index in the expression exceeds the array length
+        expected: raise exception
+        """
+        # 1. create a collection
+        nb = ct.default_nb
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+        data = cf.get_row_data_by_schema(schema=schema)
+        data[1][ct.default_int32_array_field_name] = [1]
+        collection_w.insert(data)
+        collection_w.create_index("float_vector", ct.default_index)
+        collection_w.load()
+
+        # 2. search
+        expression = "int32_array[101] > 0"
+        msg = ("failed to search: attempt #0: failed to search/query delegator 1 for channel "
+               "by-dev-rootcoord-dml_: fail to Search, QueryNode ID=1, reason=worker(1) query"
+               " failed: UnknownError: Assert \")index >= 0 && index < length_\" at /go/src/"
+               "github.com/milvus-io/milvus/internal/core/src/common/Array.h:454 => index out"
+               " of range, index=101, length=100: attempt #1: no available shard delegator "
+               "found: service unavailable")
+        collection_w.search(vectors[:default_nq], default_search_field,
+                            default_search_params, nb, expression,
+                            check_task=CheckTasks.err_res,
+                            check_items={ct.err_code: 65538,
+                                         ct.err_msg: msg})
+
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_search_with_expression_invalid_array_two(self):
+        """
+        target: test search with invalid array expressions
+        method: test search with an unsupported arithmetic comparison on an array element
+        expected: raise exception
+        """
+        # 1. create a collection
+        nb = ct.default_nb
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+        data = cf.get_row_data_by_schema(schema=schema)
+        collection_w.insert(data)
+        collection_w.create_index("float_vector", ct.default_index)
+        collection_w.load()
+
+        # 2.
search + expression = "int32_array[0] - 1 < 1" + error = {ct.err_code: 65535, + ct.err_msg: f"failed to create query plan: cannot parse expression: {expression}, " + f"error: LessThan is not supported in execution backend"} + collection_w.search(vectors[:default_nq], default_search_field, + default_search_params, nb, expression, + check_task=CheckTasks.err_res, check_items=error) + @pytest.mark.tags(CaseLabel.L2) def test_search_partition_invalid_type(self, get_invalid_partition): """ @@ -1465,8 +1523,7 @@ def test_search_with_dup_primary_key(self, dim, auto_id, _async, dup_times): insert_res, _ = collection_w.insert(insert_data[0]) insert_ids.extend(insert_res.primary_keys) # search - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -3019,6 +3076,57 @@ def test_search_with_expression_bool(self, dim, auto_id, _async, bool_type, enab ids = hits.ids assert set(ids).issubset(filter_ids_set) + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("expression", cf.gen_array_field_expressions()) + def test_search_with_expression_array(self, expression, _async, enable_dynamic_field): + """ + target: test search with different expressions + method: test search with different expressions + expected: searched successfully with correct limit(topK) + """ + # 1. create a collection + nb = ct.default_nb + schema = cf.gen_array_collection_schema() + collection_w = self.init_collection_wrap(schema=schema, enable_dynamic_field=enable_dynamic_field) + + # 2. insert data + array_length = 10 + data = [] + for i in range(nb): + arr = {ct.default_int64_field_name: i, + ct.default_float_vec_field_name: cf.gen_vectors(1, ct.default_dim)[0], + ct.default_int32_array_field_name: [np.int32(i) for i in range(array_length)], + ct.default_float_array_field_name: [np.float32(i) for i in range(array_length)], + ct.default_string_array_field_name: [str(i) for i in range(array_length)]} + data.append(arr) + collection_w.insert(data) + + # 3. filter result with expression in collection + expression = expression.replace("&&", "and").replace("||", "or") + filter_ids = [] + for i in range(nb): + int32_array = data[i][ct.default_int32_array_field_name] + float_array = data[i][ct.default_float_array_field_name] + string_array = data[i][ct.default_string_array_field_name] + if not expression or eval(expression): + filter_ids.append(i) + + # 4. create index + collection_w.create_index("float_vector", ct.default_index) + collection_w.load() + + # 5. 
search with expression
+        log.info("test_search_with_expression: searching with expression: %s" % expression)
+        search_res, _ = collection_w.search(vectors[:default_nq], default_search_field,
+                                            default_search_params, nb, expression, _async=_async)
+        if _async:
+            search_res.done()
+            search_res = search_res.result()
+
+        for hits in search_res:
+            ids = hits.ids
+            assert set(ids) == set(filter_ids)
+
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.xfail(reason="issue 24514")
     @pytest.mark.parametrize("expression", cf.gen_normal_expressions_field(default_float_field_name))
@@ -3111,9 +3219,8 @@ def test_search_expression_all_data_type(self, nb, nq, dim, auto_id, _async, ena
         if _async:
             res.done()
             res = res.result()
-        assert len(res[0][0].entity._row_data) != 0
         assert (default_int64_field_name and default_float_field_name and default_bool_field_name) \
-               in res[0][0].entity._row_data
+               in res[0][0].fields
 
     @pytest.mark.tags(CaseLabel.L1)
     @pytest.mark.parametrize("field", ct.all_scalar_data_types[:3])
@@ -3373,7 +3480,7 @@ def test_search_output_field_vector_after_binary_index(self, metrics, index):
                                   output_fields=[binary_field_name])[0]
 
         # 4. check the result vectors should be equal to the inserted
-        assert res[0][0].entity.binary_vector == [data[binary_field_name][res[0][0].id]]
+        assert res[0][0].entity.binary_vector == data[binary_field_name][res[0][0].id]
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("metrics", ct.structure_metrics)
@@ -3406,7 +3513,7 @@ def test_search_output_field_vector_after_structure_metrics(self, metrics, index
                                   output_fields=[binary_field_name])[0]
 
         # 4. check the result vectors should be equal to the inserted
-        assert res[0][0].entity.binary_vector == [data[binary_field_name][res[0][0].id]]
+        assert res[0][0].entity.binary_vector == data[binary_field_name][res[0][0].id]
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("dim", [32, 128, 768])
@@ -7332,7 +7439,7 @@ def test_range_search_with_output_field(self, auto_id, _async, enable_dynamic_fi
             res.done()
             res = res.result()
         assert len(res[0][0].entity._row_data) != 0
-        assert default_int64_field_name in res[0][0].entity._row_data
+        assert default_int64_field_name in res[0][0].fields
 
     @pytest.mark.tags(CaseLabel.L2)
     def test_range_search_concurrent_multi_threads(self, nb, nq, dim, auto_id, _async):
@@ -9028,6 +9135,144 @@ def test_search_expression_json_contains_combined_with_normal(self, enable_dynam
                             check_items={"nq": default_nq,
                                          "limit": limit // 2})
 
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("expr_prefix", ["array_contains", "ARRAY_CONTAINS"])
+    def test_search_expr_array_contains(self, expr_prefix):
+        """
+        target: test search with expression using array_contains
+        method: search with expression using array_contains
+        expected: succeed
+        """
+        # 1. create a collection
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+
+        # 2. insert data
+        string_field_value = [[str(j) for j in range(i, i+3)] for i in range(ct.default_nb)]
+        data = cf.gen_array_dataframe_data()
+        data[ct.default_string_array_field_name] = string_field_value
+        collection_w.insert(data)
+        collection_w.create_index(ct.default_float_vec_field_name, {})
+
+        # 3.
search
+        collection_w.load()
+        expression = f"{expr_prefix}({ct.default_string_array_field_name}, '1000')"
+        res = collection_w.search(vectors[:default_nq], default_search_field, {},
+                                  limit=ct.default_nb, expr=expression)[0]
+        exp_ids = cf.assert_json_contains(expression, string_field_value)
+        assert set(res[0].ids) == set(exp_ids)
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("expr_prefix", ["array_contains", "ARRAY_CONTAINS"])
+    def test_search_expr_not_array_contains(self, expr_prefix):
+        """
+        target: test search with expression using not array_contains
+        method: search with expression using not array_contains
+        expected: succeed
+        """
+        # 1. create a collection
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+
+        # 2. insert data
+        string_field_value = [[str(j) for j in range(i, i + 3)] for i in range(ct.default_nb)]
+        data = cf.gen_array_dataframe_data()
+        data[ct.default_string_array_field_name] = string_field_value
+        collection_w.insert(data)
+        collection_w.create_index(ct.default_float_vec_field_name, {})
+
+        # 3. search
+        collection_w.load()
+        expression = f"not {expr_prefix}({ct.default_string_array_field_name}, '1000')"
+        res = collection_w.search(vectors[:default_nq], default_search_field, {},
+                                  limit=ct.default_nb, expr=expression)[0]
+        exp_ids = cf.assert_json_contains(expression, string_field_value)
+        assert set(res[0].ids) == set(exp_ids)
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL"])
+    def test_search_expr_array_contains_all(self, expr_prefix):
+        """
+        target: test search with expression using array_contains_all
+        method: search with expression using array_contains_all
+        expected: succeed
+        """
+        # 1. create a collection
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+
+        # 2. insert data
+        string_field_value = [[str(j) for j in range(i, i + 3)] for i in range(ct.default_nb)]
+        data = cf.gen_array_dataframe_data()
+        data[ct.default_string_array_field_name] = string_field_value
+        collection_w.insert(data)
+        collection_w.create_index(ct.default_float_vec_field_name, {})
+
+        # 3. search
+        collection_w.load()
+        expression = f"{expr_prefix}({ct.default_string_array_field_name}, ['1000'])"
+        res = collection_w.search(vectors[:default_nq], default_search_field, {},
+                                  limit=ct.default_nb, expr=expression)[0]
+        exp_ids = cf.assert_json_contains(expression, string_field_value)
+        assert set(res[0].ids) == set(exp_ids)
+
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("expr_prefix", ["array_contains_any", "ARRAY_CONTAINS_ANY",
+                                             "not array_contains_any", "not ARRAY_CONTAINS_ANY"])
+    def test_search_expr_array_contains_any(self, expr_prefix):
+        """
+        target: test search with expression using array_contains_any
+        method: search with expression using array_contains_any
+        expected: succeed
+        """
+        # 1. create a collection
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+
+        # 2. insert data
+        string_field_value = [[str(j) for j in range(i, i + 3)] for i in range(ct.default_nb)]
+        data = cf.gen_array_dataframe_data()
+        data[ct.default_string_array_field_name] = string_field_value
+        collection_w.insert(data)
+        collection_w.create_index(ct.default_float_vec_field_name, {})
+
+        # 3.
search
+        collection_w.load()
+        expression = f"{expr_prefix}({ct.default_string_array_field_name}, ['1000'])"
+        res = collection_w.search(vectors[:default_nq], default_search_field, {},
+                                  limit=ct.default_nb, expr=expression)[0]
+        exp_ids = cf.assert_json_contains(expression, string_field_value)
+        assert set(res[0].ids) == set(exp_ids)
+
+    @pytest.mark.tags(CaseLabel.L2)
+    @pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL",
+                                             "array_contains_any", "ARRAY_CONTAINS_ANY"])
+    def test_search_expr_array_contains_invalid(self, expr_prefix):
+        """
+        target: test search with expression using array_contains_all/array_contains_any
+        method: search with expression using array_contains_all/any(a, b) where b is not a list
+        expected: report error
+        """
+        # 1. create a collection
+        schema = cf.gen_array_collection_schema()
+        collection_w = self.init_collection_wrap(schema=schema)
+
+        # 2. insert data
+        data = cf.gen_array_dataframe_data()
+        collection_w.insert(data)
+        collection_w.create_index(ct.default_float_vec_field_name, {})
+
+        # 3. search
+        collection_w.load()
+        expression = f"{expr_prefix}({ct.default_string_array_field_name}, '1000')"
+        collection_w.search(vectors[:default_nq], default_search_field, {},
+                            limit=ct.default_nb, expr=expression,
+                            check_task=CheckTasks.err_res,
+                            check_items={ct.err_code: 65535,
+                                         ct.err_msg: "failed to create query plan: cannot parse "
+                                                     "expression: %s, error: contains_any operation "
+                                                     "element must be an array" % expression})
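For reference, the expected ids for all of the contains-style cases above come from the extended `cf.assert_json_contains` helper. A condensed sketch of its matching semantics, with hypothetical simplified arguments (the real helper parses the operator and the probe value out of the expression string):

```python
def contains_ids(op, probe, rows):
    """Return indices of rows matched by an array_contains-style operator."""
    negate = op.startswith("not ")
    if negate:
        op = op[len("not "):]
    op = op.lower()
    ids = []
    for i, row in enumerate(rows):
        if op in ("array_contains", "json_contains"):
            hit = probe in row                      # single element membership
        elif op in ("array_contains_all", "json_contains_all"):
            hit = set(probe).issubset(set(row))     # every probe element present
        elif op in ("array_contains_any", "json_contains_any"):
            hit = bool(set(probe) & set(row))       # at least one probe element present
        else:
            raise ValueError(f"unknown operator: {op}")
        if hit != negate:                           # "not ..." flips the result set
            ids.append(i)
    return ids

rows = [["1", "2"], ["2", "3"], ["3", "4"]]
print(contains_ids("array_contains", "2", rows))             # -> [0, 1]
print(contains_ids("not array_contains", "2", rows))         # -> [2]
print(contains_ids("ARRAY_CONTAINS_ANY", ["1", "4"], rows))  # -> [0, 2]
print(contains_ids("array_contains_all", ["2", "3"], rows))  # -> [1]
```

class TestSearchIterator(TestcaseBase):
    """ Test case of search iterator """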