Skip to content

Commit

Permalink
Insert & Upsert support functions
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Oct 11, 2024
1 parent 247588f commit 7950316
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 19 deletions.
13 changes: 13 additions & 0 deletions internal/proxy/task_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
Expand Down Expand Up @@ -132,6 +133,18 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
it.schema = schema.CollectionSchema

// Calculate embedding fields
exec, err := function.NewFunctionExecutor(schema.CollectionSchema)
if err != nil {
return err
}

if !exec.Empty() {
if err := exec.ProcessInsert(it.insertMsg); err != nil {
return err
}
}

rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs
var rowIDBegin UniqueID
Expand Down
128 changes: 128 additions & 0 deletions internal/proxy/task_insert_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
package proxy

import (
"io"
"fmt"
"context"
"testing"
"net/http"
"net/http/httptest"
"encoding/json"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/milvus-io/milvus/internal/models"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
Expand Down Expand Up @@ -308,3 +318,121 @@ func TestMaxInsertSize(t *testing.T) {
assert.ErrorIs(t, err, merr.ErrParameterTooLarge)
})
}

// TestInsertTask_Function checks that insertTask.PreExecute succeeds for a
// collection whose schema declares an OpenAI embedding function mapping the
// VarChar field "text" (ID 101) to the FloatVector field "vector" (ID 102).
// The embedding service is replaced by a local httptest server.
func TestInsertTask_Function(t *testing.T) {
	// Mock embedding endpoint: for each input in the request it returns one
	// zero-valued embedding of the requested dimension.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()

		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		var req models.EmbeddingRequest
		if err := json.Unmarshal(body, &req); err != nil {
			// Surface malformed requests instead of silently answering
			// with an empty embedding list.
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		var res models.EmbeddingResponse
		res.Object = "list"
		res.Model = "text-embedding-3-small"
		for i := 0; i < len(req.Input); i++ {
			res.Data = append(res.Data, models.EmbeddingData{
				Object:    "embedding",
				Embedding: make([]float32, req.Dimensions),
				Index:     i,
			})
		}
		res.Usage = models.Usage{
			PromptTokens: 1,
			TotalTokens:  100,
		}

		// Marshal before writing the status so an encoding failure can still
		// be reported with an error status code.
		data, err := json.Marshal(res)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(data) // best effort: the client observes a short body on failure
	}))
	defer ts.Close()

	// Column-based insert payload: two rows for the "text" input field only.
	data := []*schemapb.FieldData{
		{
			Type:      schemapb.DataType_VarChar,
			FieldId:   101,
			FieldName: "text",
			IsDynamic: false,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_StringData{
						StringData: &schemapb.StringArray{
							Data: []string{"sentence", "sentence"},
						},
					},
				},
			},
		},
	}

	collectionName := "TestInsertTask_function"
	schema := &schemapb.CollectionSchema{
		Name:        collectionName,
		Description: "TestInsertTask_function",
		AutoID:      true,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
			{
				FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "max_length", Value: "200"},
				},
			},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
			},
		},
		Functions: []*schemapb.FunctionSchema{
			{
				Name:           "test_function",
				Type:           schemapb.FunctionType_OpenAIEmbedding,
				InputFieldIds:  []int64{101},
				OutputFieldIds: []int64{102},
				Params: []*commonpb.KeyValuePair{
					{Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"},
					{Key: function.OpenaiApiKeyParamKey, Value: "mock"},
					// Point the function at the local mock server.
					{Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL},
					{Key: function.DimParamKey, Value: "4"},
				},
			},
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	rc := mocks.NewMockRootCoordClient(t)
	rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
		Status: merr.Status(nil),
		ID:     11198,
		Count:  10,
	}, nil)

	// Check the constructor error before touching the allocator; otherwise a
	// failed construction would be reported as a panic rather than a test failure.
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
	if !assert.NoError(t, err) {
		return
	}
	idAllocator.Start()
	defer idAllocator.Close()

	task := insertTask{
		ctx: context.Background(),
		insertMsg: &BaseInsertTask{
			InsertRequest: &msgpb.InsertRequest{
				CollectionName: collectionName,
				DbName:         "hooooooo",
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_Insert,
				},
				Version:    msgpb.InsertDataVersion_ColumnBased,
				FieldsData: data,
				NumRows:    2,
			},
		},
		schema:      schema,
		idAllocator: idAllocator,
	}

	info := newSchemaInfo(schema)
	cache := NewMockCache(t)
	cache.On("GetCollectionSchema",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(info, nil)

	// Install the mock cache for this test only and restore the previous
	// value afterwards so the per-test mock does not leak into other tests.
	prevCache := globalMetaCache
	defer func() { globalMetaCache = prevCache }()
	globalMetaCache = cache

	err = task.PreExecute(ctx)
	assert.NoError(t, err)
}
15 changes: 15 additions & 0 deletions internal/proxy/task_upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
Expand Down Expand Up @@ -152,6 +153,20 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err
}

// Calculate embedding fields
{
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
if err != nil {
return err
}

if !exec.Empty() {
if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil {
return err
}
}
}

rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
// set upsertTask.insertRequest.rowIDs
tr := timerecord.NewTimeRecorder("applyPK")
Expand Down
18 changes: 10 additions & 8 deletions internal/util/function/function_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ package function
import (
"fmt"
"sync"

"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

Expand All @@ -35,15 +34,15 @@ type Runner interface {
GetOutputFields() []*schemapb.FieldSchema

MaxBatch() int
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
ProcessSearch(placeholderGroups [][]byte) ([][]byte, error)
}


// FunctionExecutor groups the function runners used to process a collection's
// functions. It is created by NewFunctionExecutor from a collection schema.
type FunctionExecutor struct {
// runners holds the Runner instances managed by this executor; Empty reports
// whether this slice has no entries.
// NOTE(review): presumably one runner per schema function — confirm against
// the constructor body, which is not fully visible here.
runners []Runner
}

func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
executor := new(FunctionExecutor)
for _, f_schema := range schema.Functions {
switch f_schema.GetType() {
Expand All @@ -61,6 +60,10 @@ func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor,
return executor, nil
}

// Empty reports whether the executor holds no runners, i.e. there is no
// function work to perform.
func (executor *FunctionExecutor) Empty() bool {
	runnerCount := len(executor.runners)
	return runnerCount == 0
}

func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
runner := executor.runners[idx]
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds))
Expand Down Expand Up @@ -119,7 +122,6 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error {
return nil
}


func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error {
return nil
func (executor *FunctionExecutor)ProcessSearch(req *internalpb.SearchRequest) (interface{}, error) {
return nil, nil
}
6 changes: 3 additions & 3 deletions internal/util/function/function_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (s *FunctionExecutorSuite) TestExecutor() {

defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := newFunctionExecutor(schema)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
exec.ProcessInsert(msg)
Expand Down Expand Up @@ -198,7 +198,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
}))
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := newFunctionExecutor(schema)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
err = exec.ProcessInsert(msg)
Expand All @@ -208,6 +208,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
func (s *FunctionExecutorSuite) TestErrorSchema() {
schema := s.creataSchema("http://localhost")
schema.Functions[0].Type = schemapb.FunctionType_Unknown
_, err := newFunctionExecutor(schema)
_, err := NewFunctionExecutor(schema)
s.Error(err)
}
19 changes: 15 additions & 4 deletions internal/util/function/openai_embedding_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,6 @@ func (runner *OpenAIEmbeddingFunction)MaxBatch() int {


func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
return runner.Run(inputs)
}

func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs))
}
Expand Down Expand Up @@ -211,3 +207,18 @@ func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*sc
}
return []*schemapb.FieldData{&output_field}, nil
}

// ProcessSearch validates that exactly one placeholder group was supplied.
//
// The embedding step itself is not implemented yet: after validation the
// method returns (nil, nil), so callers must not rely on the result.
//
// It returns an error when len(placeholderGroups) != 1.
func (runner *OpenAIEmbeddingFunction) ProcessSearch(placeholderGroups [][]byte) ([][]byte, error) {
	if len(placeholderGroups) != 1 {
		return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, but got [%d]", len(placeholderGroups))
	}

	// TODO: not implemented. Planned steps:
	//   1. get texts from placeholderGroups
	//   2. calculate embeddings for those texts
	//   3. encode the embeddings back into placeholder groups
	return nil, nil
}
8 changes: 4 additions & 4 deletions internal/util/function/openai_embedding_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() {
s.NoError(err)
{
data := createData([]string{"sentence"})
ret, err2 := runner.Run(data)
ret, err2 := runner.ProcessInsert(data)
s.NoError(err2)
s.Equal(1, len(ret))
s.Equal(int64(4), ret[0].GetVectors().Dim)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data)
}
{
data := createData([]string{"sentence 1", "sentence 2", "sentence 3"})
ret, _ := runner.Run(data)
ret, _ := runner.ProcessInsert(data)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
}
}
Expand Down Expand Up @@ -183,7 +183,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() {

// embedding dim not match
data := createData([]string{"sentence", "sentence"})
_, err2 := runner.Run(data)
_, err2 := runner.ProcessInsert(data)
s.Error(err2)
}

Expand Down Expand Up @@ -213,7 +213,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() {

// embedding number does not match the number of inputs
data := createData([]string{"sentence", "sentence2"})
_, err2 := runner.Run(data)
_, err2 := runner.ProcessInsert(data)
s.Error(err2)
}

Expand Down

0 comments on commit 7950316

Please sign in to comment.