From 38243dd796f5615bab9662bb32e7c298b4aa4a64 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Sun, 29 Sep 2024 20:13:19 +0800 Subject: [PATCH] Add function executor Signed-off-by: junjie.jiang --- internal/util/function/function_base.go | 12 +- internal/util/function/function_executor.go | 120 ++++++++++++ .../util/function/function_executor_test.go | 171 ++++++++++++++++++ .../function/openai_embedding_function.go | 32 ++-- .../openai_embedding_function_test.go | 12 +- 5 files changed, 318 insertions(+), 29 deletions(-) create mode 100644 internal/util/function/function_executor.go create mode 100644 internal/util/function/function_executor_test.go diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index 54cd55d18d496..9727e8948040c 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -25,24 +25,14 @@ import ( ) -type RunnerMode int - -const ( - InsertMode RunnerMode = iota - SearchMode -) - - type FunctionBase struct { schema *schemapb.FunctionSchema outputFields []*schemapb.FieldSchema - mode RunnerMode } -func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*FunctionBase, error) { +func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase base.schema = schema - base.mode = mode for _, field_id := range schema.GetOutputFieldIds() { for _, field := range coll.GetFields() { if field.GetFieldID() == field_id { diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go new file mode 100644 index 0000000000000..8f72a7538daa0 --- /dev/null +++ b/internal/util/function/function_executor.go @@ -0,0 +1,120 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + + +import ( + "fmt" + "sync" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + + +type Runner interface { + GetSchema() *schemapb.FunctionSchema + GetOutputFields() []*schemapb.FieldSchema + + MaxBatch() int + ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) +} + + +type FunctionExecutor struct { + runners []Runner +} + + +func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { + executor := new(FunctionExecutor) + for _, f_schema := range schema.Functions { + switch f_schema.GetType() { + case schemapb.FunctionType_BM25: + case schemapb.FunctionType_OpenAIEmbedding: + f, err := NewOpenAIEmbeddingFunction(schema, f_schema) + if err != nil { + return nil, err + } + executor.runners = append(executor.runners, f) + default: + return nil, fmt.Errorf("unknown functionRunner type %s", f_schema.GetType().String()) + } + } + return executor, nil +} + +func (executor *FunctionExecutor)processSingeFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error){ + runner := executor.runners[idx] + inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) + for _, id := range runner.GetSchema().InputFieldIds { + for _, field := range msg.FieldsData{ + if field.FieldId == id { + inputs = append(inputs, field) + } + } + } + + if len(inputs) != len(runner.GetSchema().InputFieldIds) { + return nil, fmt.Errorf("Input field not found") + } + + outputs, err := runner.ProcessInsert(inputs) + if err != nil { + return nil, err + } + return outputs, nil +} + +func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error{ + numRows := msg.NumRows + for _, runner := range executor.runners { + if numRows > uint64(runner.MaxBatch()) { + return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch()) + } + } + + outputs := make(chan []*schemapb.FieldData, len(executor.runners)) + var wg sync.WaitGroup + for idx, _ := range executor.runners { + wg.Add(1) + go func(index int) { + defer wg.Done() + data, err := executor.processSingeFunction(index, msg) + if err != nil { + outputs <- nil + } + outputs <- data + }(idx) + } + + wg.Wait() + close(outputs) + for output := range outputs { + msg.FieldsData = append(msg.FieldsData, output...) + } + return nil +} + + +func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error{ + return nil +} diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go new file mode 100644 index 0000000000000..9d1ca8b6833ac --- /dev/null +++ b/internal/util/function/function_executor_test.go @@ -0,0 +1,171 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + + +import ( + "io" + "fmt" + "testing" + "net/http" + "net/http/httptest" + "encoding/json" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" +) + + +func TestFunctionExecutor(t *testing.T) { + suite.Run(t, new(FunctionExecutorSuite)) +} + +type FunctionExecutorSuite struct { + suite.Suite +} + + +func (s *OpenAIEmbeddingFunctionSuite) creataSchema(url string) *schemapb.CollectionSchema{ + return &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + }, + }, + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + }, + }, + }, + } + +} + +func (s *OpenAIEmbeddingFunctionSuite)createMsg(texts []string) *msgstream.InsertMsg{ + + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + + msg := msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + FieldsData: data, + }, + } + return &msg +} + + +func (s *OpenAIEmbeddingFunctionSuite)createEmbedding(texts []string, dim int) [][]float32 { + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f + float32(j) * 0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func (s *OpenAIEmbeddingFunctionSuite) TestExecutor() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + embs := s.createEmbedding(req.Input, 4) + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := newFunctionExecutor(schema) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + exec.ProcessInsert(msg) + fmt.Println(msg) + +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 10182cf9fa7cc..a9ae3712135aa 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -39,7 +39,6 @@ const ( const ( maxBatch = 128 timeoutSec = 60 - maxRowNum = 60 * maxBatch ) const ( @@ -52,7 +51,7 @@ const ( type OpenAIEmbeddingFunction struct { - base *FunctionBase + FunctionBase fieldDim int64 client *models.OpenAIEmbeddingClient @@ -79,12 +78,12 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbed return &c, nil } -func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*OpenAIEmbeddingFunction, error) { +func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) { if len(schema.GetOutputFieldIds()) != 1 { return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now %d", len(schema.GetOutputFieldIds())) } - base, err := NewBase(coll, schema, mode) + base, err := NewBase(coll, schema) if err != nil { return nil, err } @@ -131,7 +130,7 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap } runner := OpenAIEmbeddingFunction{ - base: base, + FunctionBase: *base, client: c, fieldDim: fieldDim, modelName: modelName, @@ -146,7 +145,16 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap return &runner, nil } -func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { +func (runner *OpenAIEmbeddingFunction)MaxBatch() int { + return 5 * maxBatch +} + + +func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + return runner.Run(inputs) +} + +func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { if len(inputs) != 1 { return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) } @@ -161,15 +169,15 @@ func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*sch } numRows := len(texts) - if numRows > maxRowNum { - return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", maxRowNum, numRows) + if numRows > runner.MaxBatch() { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } var output_field schemapb.FieldData - output_field.FieldId = runner.base.outputFields[0].FieldID - output_field.FieldName = runner.base.outputFields[0].Name - output_field.Type = runner.base.outputFields[0].DataType - output_field.IsDynamic = runner.base.outputFields[0].IsDynamic + output_field.FieldId = runner.outputFields[0].FieldID + output_field.FieldName = runner.outputFields[0].Name + output_field.Type = runner.outputFields[0].DataType + output_field.IsDynamic = runner.outputFields[0].IsDynamic data := make([]float32, 0, numRows * int(runner.fieldDim)) for i := 0; i < numRows; i += maxBatch { end := i + maxBatch diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go index 68420cbeddbeb..49f20941db2bf 100644 --- a/internal/util/function/openai_embedding_function_test.go +++ b/internal/util/function/openai_embedding_function_test.go @@ -103,7 +103,7 @@ func createRunner(url string, schema *schemapb.CollectionSchema) (*OpenAIEmbeddi {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: url}, }, - }, InsertMode) + }) } func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { @@ -248,7 +248,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) fmt.Println(err.Error()) } @@ -281,7 +281,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) fmt.Println(err.Error()) } @@ -299,7 +299,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) fmt.Println(err.Error()) } @@ -317,7 +317,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) fmt.Println(err.Error()) } @@ -332,7 +332,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { Params: []*commonpb.KeyValuePair{ {Key: ModelNameParamKey, Value: "text-embedding-ada-003"}, }, - }, InsertMode) + }) s.Error(err) fmt.Println(err.Error()) }