Skip to content

Commit

Permalink
Polish code
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Dec 30, 2024
1 parent b482b3c commit 19d79b1
Show file tree
Hide file tree
Showing 24 changed files with 223 additions and 225 deletions.
11 changes: 6 additions & 5 deletions internal/models/ali/ali_dashscope_text_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) }
// Swap exchanges the embeddings held at positions i and j of the wrapped
// response. Together with Len and Less it makes ByIndex satisfy sort.Interface.
func (eb *ByIndex) Swap(i, j int) {
	items := eb.resp.Output.Embeddings
	items[i], items[j] = items[j], items[i]
}

// Less reports whether the embedding at position i carries a smaller TextIndex
// than the one at position j, so sorting orders embeddings by ascending
// TextIndex.
func (eb *ByIndex) Less(i, j int) bool {
	items := eb.resp.Output.Embeddings
	return items[i].TextIndex < items[j].TextIndex
}
Expand Down Expand Up @@ -116,23 +117,23 @@ func (c *AliDashScopeEmbedding) Check() error {
return nil
}

func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, text_type string, output_type string, timeoutSec time.Duration) (*EmbeddingResponse, error) {
func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = modelName
r.Input = Input{texts}
r.Parameters.Dimension = dim
r.Parameters.TextType = text_type
r.Parameters.OutputType = output_type
r.Parameters.TextType = textType
r.Parameters.OutputType = outputType
data, err := json.Marshal(r)
if err != nil {
return nil, err
}

if timeoutSec <= 0 {
timeoutSec = 30
timeoutSec = utils.DefaultTimeout
}

ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
Expand Down
87 changes: 37 additions & 50 deletions internal/models/openai/openai_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ type EmbedddingError struct {

type OpenAIEmbeddingInterface interface {
Check() error
Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error)
Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error)
}

type openAIBase struct {
Expand Down Expand Up @@ -135,41 +135,31 @@ func (c *openAIBase) genReq(modelName string, texts []string, dim int, user stri
return &r
}

type OpenAIEmbeddingClient struct {
openAIBase
}

func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient {
return &OpenAIEmbeddingClient{
openAIBase{
apiKey: apiKey,
url: url,
},
}
}

func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) {
func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
r := c.genReq(modelName, texts, dim, user)
data, err := json.Marshal(r)
if err != nil {
return nil, err
}

if timeoutSec <= 0 {
timeoutSec = 30
timeoutSec = utils.DefaultTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
for key, value := range headers {
req.Header.Set(key, value)
}
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}

var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
Expand All @@ -179,6 +169,27 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim
return &res, err
}

// OpenAIEmbeddingClient is the embedding client for the OpenAI endpoint. It
// embeds openAIBase, which supplies the apiKey and url fields along with the
// shared request helpers.
type OpenAIEmbeddingClient struct {
	openAIBase
}

// NewOpenAIEmbeddingClient builds an OpenAIEmbeddingClient whose embedded
// openAIBase stores the given apiKey and url. Neither value is validated here.
func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient {
	base := openAIBase{apiKey: apiKey, url: url}
	return &OpenAIEmbeddingClient{openAIBase: base}
}

// Embedding requests embeddings for texts from the client's configured URL.
// It prepares the JSON Content-Type header and a bearer-token Authorization
// header built from the client's API key, then forwards every argument
// unchanged to the shared embedding helper and returns its result.
func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
	headers := make(map[string]string, 2)
	headers["Content-Type"] = "application/json"
	headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
	return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec)
}

type AzureOpenAIEmbeddingClient struct {
openAIBase
apiVersion string
Expand All @@ -194,17 +205,7 @@ func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbedd
}
}

func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) {
r := c.genReq(modelName, texts, dim, user)
data, err := json.Marshal(r)
if err != nil {
return nil, err
}

if timeoutSec <= 0 {
timeoutSec = 30
}

func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec int64) (*EmbeddingResponse, error) {
base, err := url.Parse(c.url)
if err != nil {
return nil, err
Expand All @@ -214,25 +215,11 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string,
params := url.Values{}
params.Add("api-version", c.apiVersion)
base.RawQuery = params.Encode()
url := base.String()

ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("api-key", c.apiKey)
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}

var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
headers := map[string]string{
"Content-Type": "application/json",
"api-key": c.apiKey,
}
sort.Sort(&ByIndex{&res})
return &res, err
return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec)
}
4 changes: 3 additions & 1 deletion internal/models/utils/embedding_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"net/http"
)

// DefaultTimeout is the request timeout, in seconds, that the embedding
// clients fall back to when the caller passes a non-positive timeoutSec.
const DefaultTimeout int64 = 30

func send(req *http.Request) ([]byte, error) {
resp, err := http.DefaultClient.Do(req)
if err != nil {
Expand All @@ -34,7 +36,7 @@ func send(req *http.Request) ([]byte, error) {
return nil, err
}

if resp.StatusCode != 200 {
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(string(body))
}
return body, nil
Expand Down
10 changes: 5 additions & 5 deletions internal/models/vertexai/vertexai_text_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import (
"net/http"
"time"

"github.com/milvus-io/milvus/internal/models/utils"

"golang.org/x/oauth2/google"

"github.com/milvus-io/milvus/internal/models/utils"
)

type Instance struct {
Expand Down Expand Up @@ -114,7 +114,7 @@ func (c *VertexAIEmbedding) getAccessToken() (string, error) {
return token.AccessToken, nil
}

func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec time.Duration) (*EmbeddingResponse, error) {
func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
for _, text := range texts {
r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text})
Expand All @@ -129,10 +129,10 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
}

if timeoutSec <= 0 {
timeoutSec = 30
timeoutSec = utils.DefaultTimeout
}

ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions internal/models/vertexai/vertexai_text_embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

func TestEmbeddingClientCheck(t *testing.T) {
mockJsonKey := []byte{1, 2, 3}
mockJSONKey := []byte{1, 2, 3}
{
c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "")
err := c.Check()
Expand All @@ -36,14 +36,14 @@ func TestEmbeddingClientCheck(t *testing.T) {
}

{
c := NewVertexAIEmbedding("", mockJsonKey, "", "")
c := NewVertexAIEmbedding("", mockJSONKey, "", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "")
c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "")
err := c.Check()
assert.True(t, err == nil)
}
Expand Down
13 changes: 9 additions & 4 deletions internal/proxy/task_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,14 +340,19 @@ func TestInsertTask_Function(t *testing.T) {
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
{
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "200"},
}},
{FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
},
},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
}, IsFunctionOutput: true},
},
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{
{
Expand Down
18 changes: 12 additions & 6 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,18 +487,24 @@ func TestSearchTask_WithFunctions(t *testing.T) {
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
{
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "200"},
}},
{FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector,
},
},
{
FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
}},
{FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
},
},
{
FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
}},
},
},
},
Functions: []*schemapb.FunctionSchema{
{
Expand Down
13 changes: 9 additions & 4 deletions internal/proxy/task_upsert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,14 +409,19 @@ func TestUpsertTask_Function(t *testing.T) {
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
{
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: "max_length", Value: "200"},
}},
{FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
},
},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
}, IsFunctionOutput: true},
},
IsFunctionOutput: true,
},
},
Functions: []*schemapb.FunctionSchema{
{
Expand Down
20 changes: 7 additions & 13 deletions internal/util/function/ali_embedding_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ package function
import (
"fmt"
"os"
"strconv"
"strings"
"time"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/models/ali"
Expand All @@ -39,15 +37,15 @@ type AliEmbeddingProvider struct {
outputType string

maxBatch int
timeoutSec int
timeoutSec int64
}

func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) {
if apiKey == "" {
apiKey = os.Getenv(dashscopeApiKey)
apiKey = os.Getenv(dashscopeAKEnvStr)
}
if apiKey == "" {
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeApiKey)
return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeAKEnvStr)
}

if url == "" {
Expand All @@ -70,17 +68,13 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio
case modelNameParamKey:
modelName = param.Value
case dimParamKey:
dim, err = strconv.ParseInt(param.Value, 10, 64)
dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
if err != nil {
return nil, fmt.Errorf("dim [%s] is not int", param.Value)
}

if dim != 0 && dim != fieldDim {
return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim)
return nil, err
}
case apiKeyParamKey:
apiKey = param.Value
case embeddingUrlParamKey:
case embeddingURLParamKey:
url = param.Value
default:
}
Expand Down Expand Up @@ -139,7 +133,7 @@ func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit b
if end > numRows {
end = numRows
}
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, time.Duration(provider.timeoutSec))
resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, provider.timeoutSec)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 19d79b1

Please sign in to comment.