From 7f7602247167d020eba7fc3f6858b0226995be89 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 23 Sep 2024 20:57:11 +0800 Subject: [PATCH] Add embedding runner Signed-off-by: junjie.jiang --- internal/models/openai_embedding.go | 70 ++-- internal/models/openai_embedding_test.go | 79 ++-- internal/util/function/function.go | 1 + internal/util/function/function_base.go | 67 ++++ .../function/openai_embedding_function.go | 205 +++++++++++ .../openai_embedding_function_test.go | 339 ++++++++++++++++++ 6 files changed, 694 insertions(+), 67 deletions(-) create mode 100644 internal/util/function/function_base.go create mode 100644 internal/util/function/openai_embedding_function.go create mode 100644 internal/util/function/openai_embedding_function_test.go diff --git a/internal/models/openai_embedding.go b/internal/models/openai_embedding.go index dbb568377c648..70e8e2508a14c 100644 --- a/internal/models/openai_embedding.go +++ b/internal/models/openai_embedding.go @@ -22,15 +22,10 @@ import ( "fmt" "io" "net/http" + "sort" "time" ) -const ( - TextEmbeddingAda002 string = "text-embedding-ada-002" - TextEmbedding3Small string = "text-embedding-3-small" - TextEmbedding3Large string = "text-embedding-3-large" -) - type EmbeddingRequest struct { // ID of the model to use. @@ -84,6 +79,16 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } + +type ByIndex struct { + resp *EmbeddingResponse +} + +func (eb *ByIndex) Len() int { return len(eb.resp.Data) } +func (eb *ByIndex) Swap(i, j int) { eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] } +func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index } + + type ErrorInfo struct { Code string `json:"code"` Message string `json:"message"` @@ -96,27 +101,28 @@ type EmbedddingError struct { } type OpenAIEmbeddingClient struct { - api_key string - uri string - model_name string + apiKey string + url string } func (c *OpenAIEmbeddingClient) Check() error { - if c.model_name != TextEmbeddingAda002 && c.model_name != TextEmbedding3Small && c.model_name != TextEmbedding3Large { - return fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", - c.model_name, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) - } - - if c.api_key == "" { + if c.apiKey == "" { return fmt.Errorf("OpenAI api key is empty") } - if c.uri == "" { - return fmt.Errorf("OpenAI embedding uri is empty") + if c.url == "" { + return fmt.Errorf("OpenAI embedding url is empty") } return nil } +func NewOpenAIEmbeddingClient(apiKey string, url string) OpenAIEmbeddingClient{ + return OpenAIEmbeddingClient{ + apiKey: apiKey, + url: url, + } +} + func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error { // call openai @@ -143,9 +149,9 @@ func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res return nil } -func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, max_retries int) error { +func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, maxRetries int) error { var err error - for i := 0; i < max_retries; i++ { + for i := 0; i < maxRetries; i++ { err = c.send(client, req, res) if err == nil { return nil @@ -154,9 +160,9 @@ func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Req return err } -func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, timeout_sec time.Duration) (EmbeddingResponse, error) { +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { var r EmbeddingRequest - r.Model = c.model_name + r.Model = modelName r.Input = texts r.EncodingFormat = "float" if user != "" { @@ -166,27 +172,31 @@ func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, r.Dimensions = dim } - var res EmbeddingResponse data, err := json.Marshal(r) if err != nil { - return res, err + return nil, err } // call openai - if timeout_sec <= 0 { - timeout_sec = 30 + if timeoutSec <= 0 { + timeoutSec = 30 } client := &http.Client{ - Timeout: timeout_sec * time.Second, + Timeout: timeoutSec * time.Second, } - req, err := http.NewRequest("POST" , c.uri, bytes.NewBuffer(data)) + req, err := http.NewRequest("POST" , c.url, bytes.NewBuffer(data)) if err != nil { - return res, err + return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", c.api_key) + req.Header.Set("api-key", c.apiKey) + var res EmbeddingResponse err = c.sendWithRetry(client, req, &res, 3) - return res, err + if err != nil { + return nil, err + } + sort.Sort(&ByIndex{&res}) + return &res, err } diff --git a/internal/models/openai_embedding_test.go b/internal/models/openai_embedding_test.go index 788e95e84f1cd..eb31b9c23dffc 100644 --- a/internal/models/openai_embedding_test.go +++ b/internal/models/openai_embedding_test.go @@ -17,41 +17,34 @@ package models import ( - // "bytes" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "time" + "sync/atomic" "github.com/stretchr/testify/assert" ) func TestEmbeddingClientCheck(t *testing.T) { { - c := OpenAIEmbeddingClient{"mock_key", "mock_uri", "unknow_model"} + c := OpenAIEmbeddingClient{"", "mock_uri"} err := c.Check(); assert.True(t, err != nil) fmt.Println(err) } { - c := OpenAIEmbeddingClient{"", "mock_uri", TextEmbeddingAda002} + c := OpenAIEmbeddingClient{"mock_key", ""} err := c.Check(); assert.True(t, err != nil) fmt.Println(err) } { - c := OpenAIEmbeddingClient{"mock_key", "", TextEmbedding3Small} - err := c.Check(); - assert.True(t, err != nil) - fmt.Println(err) - } - - { - c := OpenAIEmbeddingClient{"mock_key", "mock_uri", TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", "mock_uri"} err := c.Check(); assert.True(t, err == nil) } @@ -61,7 +54,7 @@ func TestEmbeddingClientCheck(t *testing.T) { func TestEmbeddingOK(t *testing.T) { var res EmbeddingResponse res.Object = "list" - res.Model = TextEmbedding3Small + res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ { Object: "embedding", @@ -84,12 +77,12 @@ func TestEmbeddingOK(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - ret, err := c.Embedding([]string{"sentence"}, 0, "", 0) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) - assert.Equal(t, ret, res) + assert.Equal(t, ret, &res) } } @@ -97,24 +90,34 @@ func TestEmbeddingOK(t *testing.T) { func TestEmbeddingRetry(t *testing.T) { var res EmbeddingResponse res.Object = "list" - res.Model = TextEmbedding3Small + res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.2, 4.5}, + Index: 2, + }, { Object: "embedding", Embedding: []float32{1.1, 2.2, 3.3, 4.4}, Index: 0, }, + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.2, 4.3}, + Index: 1, + }, } res.Usage = Usage{ PromptTokens: 1, TotalTokens: 100, } - var count = 0 + var count int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if count < 2 { - count += 1 + if atomic.LoadInt32(&count) < 2 { + atomic.AddInt32(&count, 1) w.WriteHeader(http.StatusUnauthorized) } else { w.WriteHeader(http.StatusOK) @@ -127,22 +130,26 @@ func TestEmbeddingRetry(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - ret, err := c.Embedding([]string{"sentence"}, 0, "", 0) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) - assert.Equal(t, ret, res) - assert.Equal(t, count, 2) + assert.Equal(t, ret.Usage, res.Usage) + assert.Equal(t, ret.Object, res.Object) + assert.Equal(t, ret.Model, res.Model) + assert.Equal(t, ret.Data[0], res.Data[1]) + assert.Equal(t, ret.Data[1], res.Data[2]) + assert.Equal(t, ret.Data[2], res.Data[0]) + assert.Equal(t, atomic.LoadInt32(&count), int32(2)) } } func TestEmbeddingFailed(t *testing.T) { - var count = 0 - + var count int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count += 1 + atomic.AddInt32(&count, 1) w.WriteHeader(http.StatusUnauthorized) })) @@ -150,36 +157,34 @@ func TestEmbeddingFailed(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - _, err = c.Embedding([]string{"sentence"}, 0, "", 0) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err != nil) - assert.Equal(t, count, 3) + assert.Equal(t, atomic.LoadInt32(&count), int32(3)) } } func TestTimeout(t *testing.T) { - var st = "Doing" - + var st int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(3 * time.Second) - st = "Done" + atomic.AddInt32(&st, 1) w.WriteHeader(http.StatusUnauthorized) - })) defer ts.Close() url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - _, err = c.Embedding([]string{"sentence"}, 0, "", 1) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) assert.True(t, err != nil) - assert.Equal(t, st, "Doing") + assert.Equal(t, atomic.LoadInt32(&st), int32(0)) time.Sleep(3 * time.Second) - assert.Equal(t, st, "Done") + assert.Equal(t, atomic.LoadInt32(&st), int32(1)) } } diff --git a/internal/util/function/function.go b/internal/util/function/function.go index a9056af41298d..9eeaa110c3d03 100644 --- a/internal/util/function/function.go +++ b/internal/util/function/function.go @@ -31,6 +31,7 @@ type FunctionRunner interface { GetOutputFields() []*schemapb.FieldSchema } + func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (FunctionRunner, error) { switch schema.GetType() { case schemapb.FunctionType_BM25: diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go new file mode 100644 index 0000000000000..54cd55d18d496 --- /dev/null +++ b/internal/util/function/function_base.go @@ -0,0 +1,67 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + + +type RunnerMode int + +const ( + InsertMode RunnerMode = iota + SearchMode +) + + +type FunctionBase struct { + schema *schemapb.FunctionSchema + outputFields []*schemapb.FieldSchema + mode RunnerMode +} + +func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*FunctionBase, error) { + var base FunctionBase + base.schema = schema + base.mode = mode + for _, field_id := range schema.GetOutputFieldIds() { + for _, field := range coll.GetFields() { + if field.GetFieldID() == field_id { + base.outputFields = append(base.outputFields, field) + break + } + } + } + + if len(base.outputFields) != len(schema.GetOutputFieldIds()) { + return &base, fmt.Errorf("Collection [%s]'s function [%s]'s outputs mismatch schema", coll.Name, schema.Name) + } + return &base, nil +} + +func (base *FunctionBase) GetSchema() *schemapb.FunctionSchema { + return base.schema +} + +func (base *FunctionBase) GetOutputFields() []*schemapb.FieldSchema { + return base.outputFields +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go new file mode 100644 index 0000000000000..10182cf9fa7cc --- /dev/null +++ b/internal/util/function/openai_embedding_function.go @@ -0,0 +1,205 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + + +const ( + TextEmbeddingAda002 string = "text-embedding-ada-002" + TextEmbedding3Small string = "text-embedding-3-small" + TextEmbedding3Large string = "text-embedding-3-large" +) + +const ( + maxBatch = 128 + timeoutSec = 60 + maxRowNum = 60 * maxBatch +) + +const ( + ModelNameParamKey string = "model_name" + DimParamKey string = "dim" + UserParamKey string = "user" + OpenaiEmbeddingUrlParamKey string = "embedding_url" + OpenaiApiKeyParamKey string = "api_key" +) + + +type OpenAIEmbeddingFunction struct { + base *FunctionBase + fieldDim int64 + + client *models.OpenAIEmbeddingClient + modelName string + embedDimParam int64 + user string +} + +func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbeddingClient, error) { + if apiKey == "" { + apiKey = os.Getenv("OPENAI_API_KEY") + } + if apiKey == "" { + return nil, fmt.Errorf("The apiKey configuration was not found in the environment variables") + } + + if url == "" { + url = os.Getenv("OPENAI_EMBEDDING_URL") + } + if url == "" { + url = "https://api.openai.com/v1/embeddings" + } + c := models.NewOpenAIEmbeddingClient(apiKey, url) + return &c, nil +} + +func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*OpenAIEmbeddingFunction, error) { + if len(schema.GetOutputFieldIds()) != 1 { + return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now %d", len(schema.GetOutputFieldIds())) + } + + base, err := NewBase(coll, schema, mode) + if err != nil { + return nil, err + } + + if base.outputFields[0].DataType != schemapb.DataType_FloatVector { + return nil, fmt.Errorf("Output field not match, openai embedding needs [%s], got [%s]", + schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], + schemapb.DataType_name[int32(base.outputFields[0].DataType)]) + } + + fieldDim, err := typeutil.GetDim(base.outputFields[0]) + if err != nil { + return nil, err + } + var apiKey, url, modelName, user string + var dim int64 + + for _, param := range schema.Params { + switch strings.ToLower(param.Key) { + case ModelNameParamKey: + modelName = param.Value + case DimParamKey: + dim, err := strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim [%s] is not int", param.Value) + } + + if dim != 0 && dim != fieldDim { + return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", fieldDim, dim) + } + case UserParamKey: + user = param.Value + case OpenaiApiKeyParamKey: + apiKey = param.Value + case OpenaiEmbeddingUrlParamKey: + url = param.Value + default: + } + } + + c, err := createOpenAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + + runner := OpenAIEmbeddingFunction{ + base: base, + client: c, + fieldDim: fieldDim, + modelName: modelName, + user: user, + embedDimParam: dim, + } + + if runner.modelName != TextEmbeddingAda002 && runner.modelName != TextEmbedding3Small && runner.modelName != TextEmbedding3Large { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", + runner.modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) + } + return &runner, nil +} + +func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].Type != schemapb.DataType_VarChar { + return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") + } + + texts := inputs[0].GetScalars().GetStringData().GetData() + if texts == nil { + return nil, fmt.Errorf("Input texts is empty") + } + + numRows := len(texts) + if numRows > maxRowNum { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", maxRowNum, numRows) + } + + var output_field schemapb.FieldData + output_field.FieldId = runner.base.outputFields[0].FieldID + output_field.FieldName = runner.base.outputFields[0].Name + output_field.Type = runner.base.outputFields[0].DataType + output_field.IsDynamic = runner.base.outputFields[0].IsDynamic + data := make([]float32, 0, numRows * int(runner.fieldDim)) + for i := 0; i < numRows; i += maxBatch { + end := i + maxBatch + if end > numRows { + end = numRows + } + resp, err := runner.client.Embedding(runner.modelName, texts[i:end], int(runner.embedDimParam), runner.user, timeoutSec) + if err != nil { + return nil, err + } + if end - i != len(resp.Data) { + return nil, fmt.Errorf("The texts number is [%d], but got embedding number [%d]", end - i, len(resp.Data)) + } + for _, item := range resp.Data { + if len(item.Embedding) != int(runner.fieldDim) { + return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", + runner.fieldDim, len(resp.Data[0].Embedding)) + } + data = append(data, item.Embedding...) + } + } + output_field.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: data, + }, + }, + Dim: runner.fieldDim, + }, + } + return []*schemapb.FieldData{&output_field}, nil +} diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go new file mode 100644 index 0000000000000..68420cbeddbeb --- /dev/null +++ b/internal/util/function/openai_embedding_function_test.go @@ -0,0 +1,339 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + +import ( + "io" + "fmt" + "testing" + "net/http" + "net/http/httptest" + "encoding/json" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + "github.com/milvus-io/milvus/internal/models" +) + + +func TestOpenAIEmbeddingFunction(t *testing.T) { + suite.Run(t, new(OpenAIEmbeddingFunctionSuite)) +} + +type OpenAIEmbeddingFunctionSuite struct { + suite.Suite + schema *schemapb.CollectionSchema +} + +func (s *OpenAIEmbeddingFunctionSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } +} + +func createData(texts []string) []*schemapb.FieldData{ + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + return data +} + +func createEmbedding(texts []string, dim int) [][]float32 { + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f + float32(j) * 0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func createRunner(url string, schema *schemapb.CollectionSchema) (*OpenAIEmbeddingFunction, error) { + return NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + }, + }, InsertMode) +} + +func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + embs := createEmbedding(req.Input, 4) + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + + defer ts.Close() + runner, err := createRunner(ts.URL, s.schema) + s.NoError(err) + { + data := createData([]string{"sentence"}) + ret, err2 := runner.Run(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.Run(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) + } +} + +func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0}, + Index: 1, + }) + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + runner, err := createRunner(ts.URL, s.schema) + s.NoError(err) + + // embedding dim not match + data := createData([]string{"sentence", "sentence"}) + _, err2 := runner.Run(data) + s.Error(err2) + fmt.Println(err2.Error()) + // s.NoError(err2) +} + +func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + runner, err := createRunner(ts.URL, s.schema) + + s.NoError(err) + + // embedding dim not match + data := createData([]string{"sentence", "sentence2"}) + _, err2 := runner.Run(data) + s.Error(err2) + fmt.Println(err2.Error()) + // s.NoError(err2) +} + +func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { + // outputfield datatype mismatch + { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + + _, err := NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // outputfield number mismatc + { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + {FieldID: 103, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + _, err := NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102, 103}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // outputfield miss + { + _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // error model name + { + _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-004"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // no openai api key + { + _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-003"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } +}