diff --git a/internal/models/ali/ali_dashscope_text_embedding.go b/internal/models/ali/ali_dashscope_text_embedding.go index 329451577f07f..bdf5ff7d866b1 100644 --- a/internal/models/ali/ali_dashscope_text_embedding.go +++ b/internal/models/ali/ali_dashscope_text_embedding.go @@ -83,6 +83,7 @@ func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) } func (eb *ByIndex) Swap(i, j int) { eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i] } + func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex } @@ -116,23 +117,23 @@ func (c *AliDashScopeEmbedding) Check() error { return nil } -func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, text_type string, output_type string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) { var r EmbeddingRequest r.Model = modelName r.Input = Input{texts} r.Parameters.Dimension = dim - r.Parameters.TextType = text_type - r.Parameters.OutputType = output_type + r.Parameters.TextType = textType + r.Parameters.OutputType = outputType data, err := json.Marshal(r) if err != nil { return nil, err } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) if err != nil { diff --git a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go index bb6f88be0cd18..00cc478e62bbb 100644 --- a/internal/models/openai/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -102,7 +102,7 @@ type EmbedddingError struct { type OpenAIEmbeddingInterface interface { Check() error - Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) + Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) } type openAIBase struct { @@ -135,20 +135,7 @@ func (c *openAIBase) genReq(modelName string, texts []string, dim int, user stri return &r } -type OpenAIEmbeddingClient struct { - openAIBase -} - -func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { - return &OpenAIEmbeddingClient{ - openAIBase{ - apiKey: apiKey, - url: url, - }, - } -} - -func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) { r := c.genReq(modelName, texts, dim, user) data, err := json.Marshal(r) if err != nil { @@ -156,20 +143,23 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + for key, value := range headers { + req.Header.Set(key, value) + } body, err := utils.RetrySend(req, 3) if err != nil { return nil, err } + var res EmbeddingResponse err = json.Unmarshal(body, &res) if err != nil { @@ -179,6 +169,27 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim return &res, err } +type OpenAIEmbeddingClient struct { + openAIBase +} + +func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { + return &OpenAIEmbeddingClient{ + openAIBase{ + apiKey: apiKey, + url: url, + }, + } +} + +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) { + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), + } + return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec) +} + type AzureOpenAIEmbeddingClient struct { openAIBase apiVersion string @@ -194,17 +205,7 @@ func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbedd } } -func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { - r := c.genReq(modelName, texts, dim, user) - data, err := json.Marshal(r) - if err != nil { - return nil, err - } - - if timeoutSec <= 0 { - timeoutSec = 30 - } - +func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) { base, err := url.Parse(c.url) if err != nil { return nil, err @@ -214,25 +215,11 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, params := url.Values{} params.Add("api-version", c.apiVersion) base.RawQuery = params.Encode() + url := base.String() - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", c.apiKey) - body, err := utils.RetrySend(req, 3) - if err != nil { - return nil, err - } - - var res EmbeddingResponse - err = json.Unmarshal(body, &res) - if err != nil { - return nil, err + headers := map[string]string{ + "Content-Type": "application/json", + "api-key": c.apiKey, } - sort.Sort(&ByIndex{&res}) - return &res, err + return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec) } diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go index 1d6e7d916cab2..b3aaccdd0cd20 100644 --- a/internal/models/utils/embedding_util.go +++ b/internal/models/utils/embedding_util.go @@ -22,6 +22,8 @@ import ( "net/http" ) +const DefaultTimeout int64 = 30 + func send(req *http.Request) ([]byte, error) { resp, err := http.DefaultClient.Do(req) if err != nil { @@ -34,7 +36,7 @@ func send(req *http.Request) ([]byte, error) { return nil, err } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf(string(body)) } return body, nil diff --git a/internal/models/vertexai/vertexai_text_embedding.go b/internal/models/vertexai/vertexai_text_embedding.go index 3842824616214..8fe2bad23ce67 100644 --- a/internal/models/vertexai/vertexai_text_embedding.go +++ b/internal/models/vertexai/vertexai_text_embedding.go @@ -24,9 +24,9 @@ import ( "net/http" "time" - "github.com/milvus-io/milvus/internal/models/utils" - "golang.org/x/oauth2/google" + + "github.com/milvus-io/milvus/internal/models/utils" ) type Instance struct { @@ -114,7 +114,7 @@ func (c *VertexAIEmbedding) getAccessToken() (string, error) { return token.AccessToken, nil } -func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec int64) (*EmbeddingResponse, error) { var r EmbeddingRequest for _, text := range texts { r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text}) @@ -129,10 +129,10 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6 } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) if err != nil { diff --git a/internal/models/vertexai/vertexai_text_embedding_test.go b/internal/models/vertexai/vertexai_text_embedding_test.go index f138d659a3ea4..83b26ac4d4634 100644 --- a/internal/models/vertexai/vertexai_text_embedding_test.go +++ b/internal/models/vertexai/vertexai_text_embedding_test.go @@ -27,7 +27,7 @@ import ( ) func TestEmbeddingClientCheck(t *testing.T) { - mockJsonKey := []byte{1, 2, 3} + mockJSONKey := []byte{1, 2, 3} { c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "") err := c.Check() @@ -36,14 +36,14 @@ func TestEmbeddingClientCheck(t *testing.T) { } { - c := NewVertexAIEmbedding("", mockJsonKey, "", "") + c := NewVertexAIEmbedding("", mockJSONKey, "", "") err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { - c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "") + c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "") err := c.Check() assert.True(t, err == nil) } diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 2586ddf37afca..006d383be3146 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -340,14 +340,19 @@ func TestInsertTask_Function(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 4af9d69a82a67..b255027cf7316 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -487,18 +487,24 @@ func TestSearchTask_WithFunctions(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 348ee64313d31..da0b3595cc45e 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -409,14 +409,19 @@ func TestUpsertTask_Function(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index 2def771d86971..bdb34a2aeb097 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -21,9 +21,7 @@ package function import ( "fmt" "os" - "strconv" "strings" - "time" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/models/ali" @@ -39,15 +37,15 @@ type AliEmbeddingProvider struct { outputType string maxBatch int - timeoutSec int + timeoutSec int64 } func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { if apiKey == "" { - apiKey = os.Getenv(dashscopeApiKey) + apiKey = os.Getenv(dashscopeAKEnvStr) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeApiKey) + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr) } if url == "" { @@ -70,17 +68,13 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } case apiKeyParamKey: apiKey = param.Value - case embeddingUrlParamKey: + case embeddingURLParamKey: url = param.Value default: } @@ -139,7 +133,7 @@ func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit b if end > numRows { end = numRows } - resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, time.Duration(provider.timeoutSec)) + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec) if err != nil { return nil, err } diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 73d8613f20a4d..6e48be8ba7aee 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -29,7 +29,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models/ali" ) @@ -49,10 +48,12 @@ func (s *AliTextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{AliDashScopeProvider} @@ -69,7 +70,7 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: TextEmbeddingV3}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, } @@ -101,7 +102,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } } @@ -134,7 +134,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } @@ -163,6 +162,5 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { data := []string{"sentence", "sentence2"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index f9d4d184e4c8e..e5cbfd16d53b1 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -23,16 +23,15 @@ import ( "encoding/json" "fmt" "os" - "strconv" "strings" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type BedrockClient interface { @@ -60,10 +59,10 @@ func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey stri } if awsSecretAccessKey == "" { - awsSecretAccessKey = os.Getenv(bedrockSecretAccessKey) + awsSecretAccessKey = os.Getenv(bedrockSAKEnvStr) } if awsSecretAccessKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSecretAccessKey) + return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSAKEnvStr) } if region == "" { return nil, fmt.Errorf("Missing region. Please pass `region` param.") @@ -94,17 +93,13 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } - case awsAccessKeyIdParamKey: + case awsAKIdParamKey: awsAccessKeyId = param.Value - case awsSecretAccessKeyParamKey: + case awsSAKParamKey: awsSecretAccessKey = param.Value case regionParamKey: region = param.Value @@ -178,7 +173,6 @@ func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLim ModelId: aws.String(provider.modelName), ContentType: aws.String("application/json"), }) - if err != nil { return nil, err } diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index eb26aa03ef3f7..ed8a6ba58a600 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -44,10 +44,12 @@ func (s *BedrockTextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{BedrockProvider} @@ -92,7 +94,6 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) } - } } @@ -105,6 +106,5 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 56da30e5ed42f..7bfce2659d9e4 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -18,6 +18,11 @@ package function +import ( + "fmt" + "strconv" +) + const ( InsertMode string = "Insert" SearchMode string = "Search" @@ -27,7 +32,7 @@ const ( const ( modelNameParamKey string = "model_name" dimParamKey string = "dim" - embeddingUrlParamKey string = "url" + embeddingURLParamKey string = "url" apiKeyParamKey string = "api_key" ) @@ -37,7 +42,7 @@ const ( TextEmbeddingV2 string = "text-embedding-v2" TextEmbeddingV3 string = "text-embedding-v3" - dashscopeApiKey string = "MILVUS_DASHSCOPE_API_KEY" + dashscopeAKEnvStr string = "MILVUS_DASHSCOPE_API_KEY" ) // openai/azure text embedding @@ -47,10 +52,10 @@ const ( TextEmbedding3Small string = "text-embedding-3-small" TextEmbedding3Large string = "text-embedding-3-large" - openaiApiKey string = "MILVUSAI_OPENAI_API_KEY" + openaiAKEnvStr string = "MILVUSAI_OPENAI_API_KEY" - azureOpenaiApiKey string = "MILVUSAI_AZURE_OPENAI_API_KEY" - azureOpenaiEndpoint string = "MILVUSAI_AZURE_OPENAI_ENDPOINT" + azureOpenaiAKEnvStr string = "MILVUSAI_AZURE_OPENAI_API_KEY" + azureOpenaiResourceName string = "MILVUSAI_AZURE_OPENAI_RESOURCE_NAME" userParamKey string = "user" ) @@ -59,13 +64,13 @@ const ( const ( BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0" - awsAccessKeyIdParamKey string = "aws_access_key_id" - awsSecretAccessKeyParamKey string = "aws_secret_access_key" + awsAKIdParamKey string = "aws_access_key_id" + awsSAKParamKey string = "aws_secret_access_key" regionParamKey string = "regin" normalizeParamKey string = "normalize" - bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" - bedrockSecretAccessKey string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY" + bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" + bedrockSAKEnvStr string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY" ) // vertexAI @@ -80,3 +85,15 @@ const ( vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) + +func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("dim [%s] is not int", dimStr) + } + + if dim != 0 && dim != fieldDim { + return 0, fmt.Errorf("Field %s's dim is [%d], but embedding's dim is [%d]", fieldName, fieldDim, dim) + } + return dim, nil +} diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index fa3253bf3c7bd..aabcfdf5c0ea2 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -29,10 +29,10 @@ type FunctionBase struct { outputFields []*schemapb.FieldSchema } -func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) { +func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase - base.schema = f_schema - for _, fieldName := range f_schema.GetOutputFieldNames() { + base.schema = fSchema + for _, fieldName := range fSchema.GetOutputFieldNames() { for _, field := range coll.GetFields() { if field.GetName() == fieldName { base.outputFields = append(base.outputFields, field) @@ -41,9 +41,9 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.Functio } } - if len(base.outputFields) != len(f_schema.GetOutputFieldNames()) { + if len(base.outputFields) != len(fSchema.GetOutputFieldNames()) { return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", - coll.Name, f_schema.Name) + coll.Name, fSchema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 6f2469cca9173..d2cc7f61d8345 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -62,8 +62,8 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc } func CheckFunctions(schema *schemapb.CollectionSchema) error { - for _, f_schema := range schema.Functions { - if _, err := createFunction(schema, f_schema); err != nil { + for _, fSchema := range schema.Functions { + if _, err := createFunction(schema, fSchema); err != nil { return err } } @@ -77,13 +77,13 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, executor := &FunctionExecutor{ runners: make(map[int64]Runner), } - for _, f_schema := range schema.Functions { - if runner, err := createFunction(schema, f_schema); err != nil { + for _, fSchema := range schema.Functions { + runner, err := createFunction(schema, fSchema) + if err != nil { return nil, err - } else { - if runner != nil { - executor.runners[f_schema.GetOutputFieldIds()[0]] = runner - } + } + if runner != nil { + executor.runners[fSchema.GetOutputFieldIds()[0]] = runner } } return executor, nil @@ -193,15 +193,14 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch()) } wg.Add(1) - go func(runner Runner, idx int64) { + go func(runner Runner, idx int64, placeholderGroup []byte) { defer wg.Done() - if newHolder, err := executor.processSingleSearch(runner, sub.GetPlaceholderGroup()); err != nil { + if newHolder, err := executor.processSingleSearch(runner, placeholderGroup); err != nil { errChan <- err } else { outputs <- map[int64][]byte{idx: newHolder} } - - }(runner, int64(idx)) + }(runner, int64(idx), sub.GetPlaceholderGroup()) } } wg.Wait() @@ -222,9 +221,8 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error { if !req.IsAdvanced { return executor.prcessSearch(req) - } else { - return executor.prcessAdvanceSearch(req) } + return executor.prcessAdvanceSearch(req) } func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) { diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index a791760351fb7..a38360d343660 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -49,16 +49,20 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, }, IsFunctionOutput: true, }, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "8"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { @@ -72,7 +76,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, }, @@ -87,17 +91,15 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "8"}, }, }, }, } - } func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg { - data := []*schemapb.FieldData{} f := schemapb.FieldData{ Type: schemapb.DataType_VarChar, @@ -173,7 +175,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) defer ts.Close() schema := s.creataSchema(ts.URL) diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go index bd0265336baa7..240e13615b4f7 100644 --- a/internal/util/function/function_util.go +++ b/internal/util/function/function_util.go @@ -26,8 +26,8 @@ import ( func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool { // Determine whether the column corresponding to outputIDs contains functions, except bm25 function, // if outputIDs is empty, check all cols - for _, f_schema := range functions { - switch f_schema.GetType() { + for _, fSchema := range functions { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: case schemapb.FunctionType_Unknown: default: @@ -35,7 +35,7 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool return true } else { for _, id := range outputIDs { - if f_schema.GetOutputFieldIds()[0] == id { + if fSchema.GetOutputFieldIds()[0] == id { return true } } @@ -47,14 +47,14 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool func GetOutputIDFunctionsMap(functions []*schemapb.FunctionSchema) (map[int64]*schemapb.FunctionSchema, error) { outputIdMap := map[int64]*schemapb.FunctionSchema{} - for _, f_schema := range functions { - switch f_schema.GetType() { + for _, fSchema := range functions { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: default: - if len(f_schema.OutputFieldIds) != 1 { - return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", f_schema.Name) + if len(fSchema.OutputFieldIds) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", fSchema.Name) } - outputIdMap[f_schema.OutputFieldIds[0]] = f_schema + outputIdMap[fSchema.OutputFieldIds[0]] = fSchema } } return outputIdMap, nil diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index 4cb181a7a0c4f..789e2d03a2b5a 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -26,6 +26,7 @@ import ( "net/http/httptest" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/milvus-io/milvus/internal/models/ali" "github.com/milvus-io/milvus/internal/models/openai" "github.com/milvus-io/milvus/internal/models/vertexai" @@ -69,7 +70,6 @@ func CreateOpenAIEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) return ts } @@ -129,7 +129,6 @@ func CreateVertexAIEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) return ts } diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index 32cfb945509f7..6cc4ed69b354b 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -21,9 +21,7 @@ package function import ( "fmt" "os" - "strconv" "strings" - "time" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/models/openai" @@ -39,15 +37,15 @@ type OpenAIEmbeddingProvider struct { user string maxBatch int - timeoutSec int + timeoutSec int64 } func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { if apiKey == "" { - apiKey = os.Getenv(openaiApiKey) + apiKey = os.Getenv(openaiAKEnvStr) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiApiKey) + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiAKEnvStr) } if url == "" { @@ -60,17 +58,19 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbed func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) { if apiKey == "" { - apiKey = os.Getenv(azureOpenaiApiKey) + apiKey = os.Getenv(azureOpenaiAKEnvStr) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiApiKey) + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiAKEnvStr) } if url == "" { - url = os.Getenv(azureOpenaiEndpoint) + if resourceName := os.Getenv(azureOpenaiResourceName); resourceName != "" { + url = fmt.Sprintf("https://%s.openai.azure.com", resourceName) + } } if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the %s environment variable in the Milvus service", azureOpenaiEndpoint) + return nil, fmt.Errorf("Must configure the %s environment variable in the Milvus service", azureOpenaiResourceName) } c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url) return c, nil @@ -89,19 +89,15 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", fieldSchema.Name, fieldDim, dim) + return nil, err } case userParamKey: user = param.Value case apiKeyParamKey: apiKey = param.Value - case embeddingUrlParamKey: + case embeddingURLParamKey: url = param.Value default: } @@ -165,7 +161,7 @@ func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimi if end > numRows { end = numRows } - resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), provider.user, time.Duration(provider.timeoutSec)) + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), provider.user, provider.timeoutSec) if err != nil { return nil, err } diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 395ecf06cdc9d..89a20101c9d35 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -29,7 +29,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models/openai" ) @@ -49,10 +48,12 @@ func (s *OpenAITextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{OpenAIProvider, AzureOpenAIProvider} @@ -70,7 +71,7 @@ func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: dimParamKey, Value: "4"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, }, } switch providerName { @@ -103,7 +104,6 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } } @@ -141,7 +141,6 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } @@ -174,6 +173,5 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { data := []string{"sentence", "sentence2"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index d359514d6cfbc..030679df812fb 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -26,7 +26,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/util/funcutil" - // "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -134,7 +133,6 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s default: return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) } - } func (runner *TextEmebddingFunction) MaxBatch() int { @@ -147,7 +145,7 @@ func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) } if inputs[0].Type != schemapb.DataType_VarChar { - return nil, fmt.Errorf("Text embedding only supports varchar field, the input is not varchar") + return nil, fmt.Errorf("Text embedding only supports varchar field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)]) } texts := inputs[0].GetScalars().GetStringData().GetData() @@ -193,11 +191,11 @@ func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.Pl func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) { if len(inputs) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + return nil, fmt.Errorf("TextEmbedding function only receives one input, bug got [%d]", len(inputs)) } if inputs[0].GetDataType() != schemapb.DataType_VarChar { - return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") + return nil, fmt.Errorf(" only supports varchar field, the input is not varchar") } texts, ok := inputs[0].GetDataRows().([]string) diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go index ce0bfc86dbf51..4fd8f1409e332 100644 --- a/internal/util/function/text_embedding_function_test.go +++ b/internal/util/function/text_embedding_function_test.go @@ -42,10 +42,12 @@ func (s *TextEmbeddingFunctionSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } } @@ -74,7 +76,6 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { ts := CreateOpenAIEmbeddingServer() defer ts.Close() { - runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, @@ -87,7 +88,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: ts.URL}, + {Key: embeddingURLParamKey, Value: ts.URL}, }, }) s.NoError(err) @@ -106,9 +107,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } } - { - runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, @@ -121,7 +120,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: ts.URL}, + {Key: embeddingURLParamKey, Value: ts.URL}, }, }) s.NoError(err) @@ -158,7 +157,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { {Key: modelNameParamKey, Value: TextEmbeddingV3}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: ts.URL}, + {Key: embeddingURLParamKey, Value: ts.URL}, }, }) s.NoError(err) @@ -176,7 +175,6 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { ret, _ := runner.ProcessInsert(data) s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } - } func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { @@ -187,10 +185,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } @@ -206,7 +206,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -219,14 +219,18 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ @@ -241,7 +245,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -261,7 +265,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -281,7 +285,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go index 1d9c997571dcf..b1e3017ef93b0 100644 --- a/internal/util/function/vertexai_embedding_provider.go +++ b/internal/util/function/vertexai_embedding_provider.go @@ -21,10 +21,8 @@ package function import ( "fmt" "os" - "strconv" "strings" "sync" - "time" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/models/vertexai" @@ -77,7 +75,7 @@ type VertextAIEmbeddingProvider struct { task string maxBatch int - timeoutSec int + timeoutSec int64 } func createVertextAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) { @@ -102,13 +100,9 @@ func NewVertextAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSc case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } case locationParamKey: location = param.Value @@ -202,7 +196,7 @@ func (provider *VertextAIEmbeddingProvider) CallEmbedding(texts []string, batchL if end > numRows { end = numRows } - resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, time.Duration(provider.timeoutSec)) + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, provider.timeoutSec) if err != nil { return nil, err } diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go index 2c18b133cc974..10a9093d69634 100644 --- a/internal/util/function/vertexai_embedding_provider_test.go +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -47,10 +47,12 @@ func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } } @@ -68,7 +70,7 @@ func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbe {Key: locationParamKey, Value: "mock_local"}, {Key: projectIDParamKey, Value: "mock_id"}, {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, } @@ -95,7 +97,6 @@ func (s *VertextAITextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {