Skip to content

Commit

Permalink
Add embedding runner
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Oct 8, 2024
1 parent 5117b40 commit 7f76022
Show file tree
Hide file tree
Showing 6 changed files with 694 additions and 67 deletions.
70 changes: 40 additions & 30 deletions internal/models/openai_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,10 @@ import (
"fmt"
"io"
"net/http"
"sort"
"time"
)

const (
TextEmbeddingAda002 string = "text-embedding-ada-002"
TextEmbedding3Small string = "text-embedding-3-small"
TextEmbedding3Large string = "text-embedding-3-large"
)


type EmbeddingRequest struct {
// ID of the model to use.
Expand Down Expand Up @@ -84,6 +79,16 @@ type EmbeddingResponse struct {
Usage Usage `json:"usage"`
}


type ByIndex struct {
resp *EmbeddingResponse
}

func (eb *ByIndex) Len() int { return len(eb.resp.Data) }
func (eb *ByIndex) Swap(i, j int) { eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] }
func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index }


type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
Expand All @@ -96,27 +101,28 @@ type EmbedddingError struct {
}

type OpenAIEmbeddingClient struct {
api_key string
uri string
model_name string
apiKey string
url string
}

func (c *OpenAIEmbeddingClient) Check() error {
if c.model_name != TextEmbeddingAda002 && c.model_name != TextEmbedding3Small && c.model_name != TextEmbedding3Large {
return fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
c.model_name, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
}

if c.api_key == "" {
if c.apiKey == "" {
return fmt.Errorf("OpenAI api key is empty")
}

if c.uri == "" {
return fmt.Errorf("OpenAI embedding uri is empty")
if c.url == "" {
return fmt.Errorf("OpenAI embedding url is empty")
}
return nil
}

func NewOpenAIEmbeddingClient(apiKey string, url string) OpenAIEmbeddingClient{
return OpenAIEmbeddingClient{
apiKey: apiKey,
url: url,
}
}


func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error {
// call openai
Expand All @@ -143,9 +149,9 @@ func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res
return nil
}

func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, max_retries int) error {
func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, maxRetries int) error {
var err error
for i := 0; i < max_retries; i++ {
for i := 0; i < maxRetries; i++ {
err = c.send(client, req, res)
if err == nil {
return nil
Expand All @@ -154,9 +160,9 @@ func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Req
return err
}

func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, timeout_sec time.Duration) (EmbeddingResponse, error) {
func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = c.model_name
r.Model = modelName
r.Input = texts
r.EncodingFormat = "float"
if user != "" {
Expand All @@ -166,27 +172,31 @@ func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string,
r.Dimensions = dim
}

var res EmbeddingResponse
data, err := json.Marshal(r)
if err != nil {
return res, err
return nil, err
}

// call openai
if timeout_sec <= 0 {
timeout_sec = 30
if timeoutSec <= 0 {
timeoutSec = 30
}
client := &http.Client{
Timeout: timeout_sec * time.Second,
Timeout: timeoutSec * time.Second,
}
req, err := http.NewRequest("POST" , c.uri, bytes.NewBuffer(data))
req, err := http.NewRequest("POST" , c.url, bytes.NewBuffer(data))
if err != nil {
return res, err
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("api-key", c.api_key)
req.Header.Set("api-key", c.apiKey)

var res EmbeddingResponse
err = c.sendWithRetry(client, req, &res, 3)
return res, err
if err != nil {
return nil, err
}
sort.Sort(&ByIndex{&res})
return &res, err

}
79 changes: 42 additions & 37 deletions internal/models/openai_embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,34 @@
package models

import (
// "bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"sync/atomic"

"github.com/stretchr/testify/assert"
)

func TestEmbeddingClientCheck(t *testing.T) {
{
c := OpenAIEmbeddingClient{"mock_key", "mock_uri", "unknow_model"}
c := OpenAIEmbeddingClient{"", "mock_uri"}
err := c.Check();
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := OpenAIEmbeddingClient{"", "mock_uri", TextEmbeddingAda002}
c := OpenAIEmbeddingClient{"mock_key", ""}
err := c.Check();
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := OpenAIEmbeddingClient{"mock_key", "", TextEmbedding3Small}
err := c.Check();
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := OpenAIEmbeddingClient{"mock_key", "mock_uri", TextEmbedding3Small}
c := OpenAIEmbeddingClient{"mock_key", "mock_uri"}
err := c.Check();
assert.True(t, err == nil)
}
Expand All @@ -61,7 +54,7 @@ func TestEmbeddingClientCheck(t *testing.T) {
func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
res.Object = "list"
res.Model = TextEmbedding3Small
res.Model = "text-embedding-3-small"
res.Data = []EmbeddingData{
{
Object: "embedding",
Expand All @@ -84,37 +77,47 @@ func TestEmbeddingOK(t *testing.T) {
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
c := OpenAIEmbeddingClient{"mock_key", url}
err := c.Check();
assert.True(t, err == nil)
ret, err := c.Embedding([]string{"sentence"}, 0, "", 0)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret, res)
assert.Equal(t, ret, &res)
}
}


func TestEmbeddingRetry(t *testing.T) {
var res EmbeddingResponse
res.Object = "list"
res.Model = TextEmbedding3Small
res.Model = "text-embedding-3-small"
res.Data = []EmbeddingData{
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.2, 4.5},
Index: 2,
},
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
Index: 0,
},
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.2, 4.3},
Index: 1,
},
}
res.Usage = Usage{
PromptTokens: 1,
TotalTokens: 100,
}

var count = 0
var count int32 = 0

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count < 2 {
count += 1
if atomic.LoadInt32(&count) < 2 {
atomic.AddInt32(&count, 1)
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusOK)
Expand All @@ -127,59 +130,61 @@ func TestEmbeddingRetry(t *testing.T) {
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
c := OpenAIEmbeddingClient{"mock_key", url}
err := c.Check();
assert.True(t, err == nil)
ret, err := c.Embedding([]string{"sentence"}, 0, "", 0)
ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret, res)
assert.Equal(t, count, 2)
assert.Equal(t, ret.Usage, res.Usage)
assert.Equal(t, ret.Object, res.Object)
assert.Equal(t, ret.Model, res.Model)
assert.Equal(t, ret.Data[0], res.Data[1])
assert.Equal(t, ret.Data[1], res.Data[2])
assert.Equal(t, ret.Data[2], res.Data[0])
assert.Equal(t, atomic.LoadInt32(&count), int32(2))
}
}


func TestEmbeddingFailed(t *testing.T) {
var count = 0

var count int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count += 1
atomic.AddInt32(&count, 1)
w.WriteHeader(http.StatusUnauthorized)
}))

defer ts.Close()
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
c := OpenAIEmbeddingClient{"mock_key", url}
err := c.Check();
assert.True(t, err == nil)
_, err = c.Embedding([]string{"sentence"}, 0, "", 0)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0)
assert.True(t, err != nil)
assert.Equal(t, count, 3)
assert.Equal(t, atomic.LoadInt32(&count), int32(3))
}
}

func TestTimeout(t *testing.T) {
var st = "Doing"

var st int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(3 * time.Second)
st = "Done"
atomic.AddInt32(&st, 1)
w.WriteHeader(http.StatusUnauthorized)

}))

defer ts.Close()
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
c := OpenAIEmbeddingClient{"mock_key", url}
err := c.Check();
assert.True(t, err == nil)
_, err = c.Embedding([]string{"sentence"}, 0, "", 1)
_, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1)
assert.True(t, err != nil)
assert.Equal(t, st, "Doing")
assert.Equal(t, atomic.LoadInt32(&st), int32(0))
time.Sleep(3 * time.Second)
assert.Equal(t, st, "Done")
assert.Equal(t, atomic.LoadInt32(&st), int32(1))
}
}
1 change: 1 addition & 0 deletions internal/util/function/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type FunctionRunner interface {
GetOutputFields() []*schemapb.FieldSchema
}


func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (FunctionRunner, error) {
switch schema.GetType() {
case schemapb.FunctionType_BM25:
Expand Down
Loading

0 comments on commit 7f76022

Please sign in to comment.