diff --git a/service/aiproxy/controller/channel-test.go b/service/aiproxy/controller/channel-test.go
index c3505010191..45c8ee54356 100644
--- a/service/aiproxy/controller/channel-test.go
+++ b/service/aiproxy/controller/channel-test.go
@@ -101,7 +101,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
         return nil, err
     }
     if resp != nil && resp.StatusCode != http.StatusOK {
-        err := controller.RelayErrorHandler(resp)
+        err := controller.RelayErrorHandler(resp, meta.Mode)
         return &err.Error, errors.New(err.Error.Message)
     }
     usage, respErr := adaptor.DoResponse(c, resp, meta)
diff --git a/service/aiproxy/controller/relay.go b/service/aiproxy/controller/relay.go
index f4db66e65fe..892861f1062 100644
--- a/service/aiproxy/controller/relay.go
+++ b/service/aiproxy/controller/relay.go
@@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
         fallthrough
     case relaymode.AudioTranscription:
         err = controller.RelayAudioHelper(c, relayMode)
+    case relaymode.Rerank:
+        err = controller.RerankHelper(c)
     default:
         err = controller.RelayTextHelper(c)
     }
diff --git a/service/aiproxy/middleware/distributor.go b/service/aiproxy/middleware/distributor.go
index 13c4f4aea82..826f1c34568 100644
--- a/service/aiproxy/middleware/distributor.go
+++ b/service/aiproxy/middleware/distributor.go
@@ -22,6 +22,10 @@ func Distribute(c *gin.Context) {
         return
     }
     requestModel := c.GetString(ctxkey.RequestModel)
+    if requestModel == "" {
+        abortWithMessage(c, http.StatusBadRequest, "no model provided")
+        return
+    }
     var channel *model.Channel
     channelID, ok := c.Get(ctxkey.SpecificChannelID)
     if ok {
diff --git a/service/aiproxy/middleware/utils.go b/service/aiproxy/middleware/utils.go
index cf7fcc4b906..91eedb46c03 100644
--- a/service/aiproxy/middleware/utils.go
+++ b/service/aiproxy/middleware/utils.go
@@ -26,8 +26,6 @@ func getRequestModel(c *gin.Context) (string, error) {
     switch {
     case strings.HasPrefix(path, "/v1/moderations"):
         return "text-moderation-stable", nil
-    case strings.HasSuffix(path, "embeddings"):
-        return c.Param("model"), nil
     case strings.HasPrefix(path, "/v1/images/generations"):
         return "dall-e-2", nil
     case strings.HasPrefix(path, "/v1/audio/transcriptions"), strings.HasPrefix(path, "/v1/audio/translations"):
diff --git a/service/aiproxy/relay/adaptor/cohere/adaptor.go b/service/aiproxy/relay/adaptor/cohere/adaptor.go
index 525b39ab387..6815c6bb098 100644
--- a/service/aiproxy/relay/adaptor/cohere/adaptor.go
+++ b/service/aiproxy/relay/adaptor/cohere/adaptor.go
@@ -7,8 +7,10 @@ import (
 
     "github.com/gin-gonic/gin"
     "github.com/labring/sealos/service/aiproxy/relay/adaptor"
+    "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
     "github.com/labring/sealos/service/aiproxy/relay/meta"
     "github.com/labring/sealos/service/aiproxy/relay/model"
+    "github.com/labring/sealos/service/aiproxy/relay/relaymode"
 )
 
 type Adaptor struct{}
@@ -53,7 +55,12 @@ func (a *Adaptor) ConvertTTSRequest(*model.TextToSpeechRequest) (any, error) {
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
     if meta.IsStream {
         err, usage = StreamHandler(c, resp)
-    } else {
+        return
+    }
+    switch meta.Mode {
+    case relaymode.Rerank:
+        err, usage = openai.RerankHandler(c, resp, meta.PromptTokens, meta)
+    default:
         err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
     }
     return
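Reviewer note on middleware/distributor.go: requests that reach Distribute without a resolved model are now rejected up front instead of flowing downstream with an empty model name. A minimal, self-contained sketch of that guard, using a stand-in context key and an inline JSON abort (the real code uses ctxkey.RequestModel and abortWithMessage):

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/gin-gonic/gin"
)

// distribute mirrors the new guard: abort with 400 when no model was resolved.
func distribute(c *gin.Context) {
	requestModel := c.GetString("requestModel") // stand-in for ctxkey.RequestModel
	if requestModel == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "no model provided"})
		return
	}
	c.Next()
}

func main() {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.POST("/v1/rerank", distribute, func(c *gin.Context) { c.Status(http.StatusOK) })

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/v1/rerank", nil)
	r.ServeHTTP(w, req)
	fmt.Println(w.Code) // 400: nothing set a model on the context
}
```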
diff --git a/service/aiproxy/relay/adaptor/openai/adaptor.go b/service/aiproxy/relay/adaptor/openai/adaptor.go
index 48d0250707f..a5cab79b89b 100644
--- a/service/aiproxy/relay/adaptor/openai/adaptor.go
+++ b/service/aiproxy/relay/adaptor/openai/adaptor.go
@@ -185,17 +185,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
             usage.PromptTokens = meta.PromptTokens
             usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens
         }
-    } else {
-        switch meta.Mode {
-        case relaymode.ImagesGenerations:
-            err, _ = ImageHandler(c, resp)
-        case relaymode.AudioTranscription:
-            err, usage = STTHandler(c, resp, meta, a.responseFormat)
-        case relaymode.AudioSpeech:
-            err, usage = TTSHandler(c, resp, meta)
-        default:
-            err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
-        }
+        return
+    }
+    switch meta.Mode {
+    case relaymode.ImagesGenerations:
+        err, _ = ImageHandler(c, resp)
+    case relaymode.AudioTranscription:
+        err, usage = STTHandler(c, resp, meta, a.responseFormat)
+    case relaymode.AudioSpeech:
+        err, usage = TTSHandler(c, resp, meta)
+    case relaymode.Rerank:
+        err, usage = RerankHandler(c, resp, meta.PromptTokens, meta)
+    default:
+        err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
     }
     return
 }
diff --git a/service/aiproxy/relay/adaptor/openai/main.go b/service/aiproxy/relay/adaptor/openai/main.go
index d25412d3bb6..a90d8ff8e99 100644
--- a/service/aiproxy/relay/adaptor/openai/main.go
+++ b/service/aiproxy/relay/adaptor/openai/main.go
@@ -97,12 +97,13 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 }
 
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
-    var textResponse SlimTextResponse
+    defer resp.Body.Close()
+
     responseBody, err := io.ReadAll(resp.Body)
-    _ = resp.Body.Close()
     if err != nil {
         return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
     }
+    var textResponse SlimTextResponse
     err = json.Unmarshal(responseBody, &textResponse)
     if err != nil {
         return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -126,18 +127,49 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
         }
     }
 
-    resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-    defer resp.Body.Close()
-
     for k, v := range resp.Header {
         c.Writer.Header().Set(k, v[0])
     }
     c.Writer.WriteHeader(resp.StatusCode)
-    _, _ = io.Copy(c.Writer, resp.Body)
+    _, _ = c.Writer.Write(responseBody)
 
     return nil, &textResponse.Usage
 }
 
+func RerankHandler(c *gin.Context, resp *http.Response, promptTokens int, _ *meta.Meta) (*model.ErrorWithStatusCode, *model.Usage) {
+    defer resp.Body.Close()
+
+    responseBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+    }
+    var rerankResponse SlimRerankResponse
+    err = json.Unmarshal(responseBody, &rerankResponse)
+    if err != nil {
+        return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+    }
+
+    c.Writer.WriteHeader(resp.StatusCode)
+
+    _, _ = c.Writer.Write(responseBody)
+
+    if rerankResponse.Meta.Tokens == nil {
+        return nil, &model.Usage{
+            PromptTokens:     promptTokens,
+            CompletionTokens: 0,
+            TotalTokens:      promptTokens,
+        }
+    }
+    if rerankResponse.Meta.Tokens.InputTokens <= 0 {
+        rerankResponse.Meta.Tokens.InputTokens = promptTokens
+    }
+    return nil, &model.Usage{
+        PromptTokens:     rerankResponse.Meta.Tokens.InputTokens,
+        CompletionTokens: rerankResponse.Meta.Tokens.OutputTokens,
+        TotalTokens:      rerankResponse.Meta.Tokens.InputTokens + rerankResponse.Meta.Tokens.OutputTokens,
+    }
+}
+
 func TTSHandler(c *gin.Context, resp *http.Response, meta *meta.Meta) (*model.ErrorWithStatusCode, *model.Usage) {
     defer resp.Body.Close()
 
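Reviewer note on openai/main.go: the usage accounting in RerankHandler has a two-level fallback: a nil meta.tokens bills the locally estimated prompt tokens, and a missing or zero input_tokens is replaced by that estimate. A standalone sketch of just that logic, with local stand-in types rather than the patch's model.Usage:

```go
package main

import "fmt"

// tokens is a local stand-in for the upstream's meta.tokens object.
type tokens struct{ Input, Output int }

// usageFromMeta mirrors the fallback in RerankHandler.
func usageFromMeta(t *tokens, estimatedPrompt int) (prompt, completion, total int) {
	if t == nil {
		// Upstream reported no usage at all: bill the local estimate.
		return estimatedPrompt, 0, estimatedPrompt
	}
	if t.Input <= 0 {
		// Usage object present but input count missing: substitute the estimate.
		t.Input = estimatedPrompt
	}
	return t.Input, t.Output, t.Input + t.Output
}

func main() {
	fmt.Println(usageFromMeta(nil, 120))             // 120 0 120
	fmt.Println(usageFromMeta(&tokens{0, 8}, 120))   // 120 8 128
	fmt.Println(usageFromMeta(&tokens{100, 8}, 120)) // 100 8 108
}
```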
diff --git a/service/aiproxy/relay/adaptor/openai/model.go b/service/aiproxy/relay/adaptor/openai/model.go
index 9bb9b1fe2a2..fe8123b68d1 100644
--- a/service/aiproxy/relay/adaptor/openai/model.go
+++ b/service/aiproxy/relay/adaptor/openai/model.go
@@ -73,6 +73,10 @@ type SlimTextResponse struct {
     model.Usage `json:"usage"`
 }
 
+type SlimRerankResponse struct {
+    Meta model.RerankMeta `json:"meta"`
+}
+
 type TextResponseChoice struct {
     FinishReason string `json:"finish_reason"`
     model.Message `json:"message"`
diff --git a/service/aiproxy/relay/controller/audio.go b/service/aiproxy/relay/controller/audio.go
index ab455a2cbfd..fb1be0ad0ab 100644
--- a/service/aiproxy/relay/controller/audio.go
+++ b/service/aiproxy/relay/controller/audio.go
@@ -110,7 +110,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
     }
 
     if isErrorHappened(meta, resp) {
-        err := RelayErrorHandler(resp)
+        err := RelayErrorHandler(resp, meta.Mode)
         ConsumeWaitGroup.Add(1)
         go postConsumeAmount(context.Background(),
             &ConsumeWaitGroup,
diff --git a/service/aiproxy/relay/controller/error.go b/service/aiproxy/relay/controller/error.go
index 1b268c424f1..7f6433e2644 100644
--- a/service/aiproxy/relay/controller/error.go
+++ b/service/aiproxy/relay/controller/error.go
@@ -2,13 +2,17 @@ package controller
 
 import (
     "fmt"
+    "io"
     "net/http"
     "strconv"
+    "strings"
 
     json "github.com/json-iterator/go"
     "github.com/labring/sealos/service/aiproxy/common/config"
+    "github.com/labring/sealos/service/aiproxy/common/conv"
     "github.com/labring/sealos/service/aiproxy/common/logger"
     "github.com/labring/sealos/service/aiproxy/relay/model"
+    "github.com/labring/sealos/service/aiproxy/relay/relaymode"
 )
 
 type GeneralErrorResponse struct {
@@ -52,7 +56,7 @@ func (e GeneralErrorResponse) ToMessage() string {
     return ""
 }
 
-func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
+func RelayErrorHandler(resp *http.Response, relayMode int) *model.ErrorWithStatusCode {
     if resp == nil {
         return &model.ErrorWithStatusCode{
             StatusCode: 500,
@@ -63,7 +67,49 @@
             },
         }
     }
+    switch relayMode {
+    case relaymode.Rerank:
+        return RerankErrorHandler(resp)
+    default:
+        return RelayDefaultErrorHandler(resp)
+    }
+}
+
+func RerankErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
     defer resp.Body.Close()
+    respBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return &model.ErrorWithStatusCode{
+            StatusCode: resp.StatusCode,
+            Error: model.Error{
+                Message: err.Error(),
+                Type:    "upstream_error",
+                Code:    "bad_response",
+            },
+        }
+    }
+    trimmedRespBody := strings.Trim(conv.BytesToString(respBody), "\"")
+    return &model.ErrorWithStatusCode{
+        StatusCode: resp.StatusCode,
+        Error: model.Error{
+            Message: trimmedRespBody,
+            Type:    "upstream_error",
+            Code:    "bad_response",
+        },
+    }
+}
+
+func RelayDefaultErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
+    defer resp.Body.Close()
+    respBody, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return &model.ErrorWithStatusCode{
+            StatusCode: resp.StatusCode,
+            Error: model.Error{
+                Message: err.Error(),
+            },
+        }
+    }
 
     ErrorWithStatusCode := &model.ErrorWithStatusCode{
         StatusCode: resp.StatusCode,
@@ -75,7 +121,7 @@ func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
         },
     }
     var errResponse GeneralErrorResponse
-    err := json.NewDecoder(resp.Body).Decode(&errResponse)
+    err = json.Unmarshal(respBody, &errResponse)
     if err != nil {
         return ErrorWithStatusCode
     }
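Reviewer note on controller/error.go: RerankErrorHandler intentionally skips JSON decoding: some rerank upstreams return the error as a bare JSON-encoded string, so the most useful message is the body with its surrounding quotes trimmed. A quick illustration of the strings.Trim call:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Upstream error body: a JSON-encoded string, not an error object.
	body := `"invalid api token"`
	// strings.Trim strips only the leading and trailing quote characters.
	fmt.Println(strings.Trim(body, `"`)) // invalid api token
}
```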
diff --git a/service/aiproxy/relay/controller/helper.go b/service/aiproxy/relay/controller/helper.go
index 16f8f8adc1a..084a940997b 100644
--- a/service/aiproxy/relay/controller/helper.go
+++ b/service/aiproxy/relay/controller/helper.go
@@ -54,20 +54,29 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int
     return 0
 }
 
-func getPreConsumedAmount(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, price float64) float64 {
-    preConsumedTokens := int64(promptTokens)
-    if textRequest.MaxTokens != 0 {
-        preConsumedTokens += int64(textRequest.MaxTokens)
+type PreCheckGroupBalanceReq struct {
+    PromptTokens int
+    MaxTokens    int
+    Price        float64
+}
+
+func getPreConsumedAmount(req *PreCheckGroupBalanceReq) float64 {
+    if req.Price == 0 || (req.PromptTokens == 0 && req.MaxTokens == 0) {
+        return 0
+    }
+    preConsumedTokens := int64(req.PromptTokens)
+    if req.MaxTokens != 0 {
+        preConsumedTokens += int64(req.MaxTokens)
     }
     return decimal.
         NewFromInt(preConsumedTokens).
-        Mul(decimal.NewFromFloat(price)).
+        Mul(decimal.NewFromFloat(req.Price)).
         Div(decimal.NewFromInt(billingprice.PriceUnit)).
         InexactFloat64()
 }
 
-func preCheckGroupBalance(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, price float64, meta *meta.Meta) (bool, balance.PostGroupConsumer, error) {
-    preConsumedAmount := getPreConsumedAmount(textRequest, promptTokens, price)
+func preCheckGroupBalance(ctx context.Context, req *PreCheckGroupBalanceReq, meta *meta.Meta) (bool, balance.PostGroupConsumer, error) {
+    preConsumedAmount := getPreConsumedAmount(req)
 
     groupRemainBalance, postGroupConsumer, err := balance.Default.GetGroupRemainBalance(ctx, meta.Group)
     if err != nil {
diff --git a/service/aiproxy/relay/controller/image.go b/service/aiproxy/relay/controller/image.go
index 3d9a13b066f..86c947a99e6 100644
--- a/service/aiproxy/relay/controller/image.go
+++ b/service/aiproxy/relay/controller/image.go
@@ -175,7 +175,7 @@ func RelayImageHelper(c *gin.Context, _ int) *relaymodel.ErrorWithStatusCode {
     }
 
     if isErrorHappened(meta, resp) {
-        err := RelayErrorHandler(resp)
+        err := RelayErrorHandler(resp, meta.Mode)
         ConsumeWaitGroup.Add(1)
         go postConsumeAmount(context.Background(),
             &ConsumeWaitGroup,
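Reviewer note on controller/helper.go: the refactor to PreCheckGroupBalanceReq keeps the pre-consumption arithmetic unchanged, (promptTokens + maxTokens) * price / PriceUnit, and adds an early zero return when there is nothing to bill. A runnable sketch of the same computation, assuming shopspring/decimal (the decimal package in use) and an illustrative priceUnit of 1000 standing in for billingprice.PriceUnit:

```go
package main

import (
	"fmt"

	"github.com/shopspring/decimal"
)

const priceUnit = 1000 // illustrative; the real constant is billingprice.PriceUnit

func preConsumedAmount(promptTokens, maxTokens int, price float64) float64 {
	// Nothing to bill: free price, or no tokens to pre-consume.
	if price == 0 || (promptTokens == 0 && maxTokens == 0) {
		return 0
	}
	tokens := int64(promptTokens)
	if maxTokens != 0 {
		tokens += int64(maxTokens)
	}
	// decimal avoids float drift in the tokens * price / priceUnit product.
	return decimal.NewFromInt(tokens).
		Mul(decimal.NewFromFloat(price)).
		Div(decimal.NewFromInt(priceUnit)).
		InexactFloat64()
}

func main() {
	// 1000 prompt + 2000 max tokens at 0.5 per priceUnit tokens => 1.5
	fmt.Println(preConsumedAmount(1000, 2000, 0.5))
}
```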
diff --git a/service/aiproxy/relay/controller/rerank.go b/service/aiproxy/relay/controller/rerank.go
new file mode 100644
index 00000000000..3654b9ef900
--- /dev/null
+++ b/service/aiproxy/relay/controller/rerank.go
@@ -0,0 +1,162 @@
+package controller
+
+import (
+    "bytes"
+    "context"
+    "errors"
+    "fmt"
+    "io"
+    "net/http"
+    "strings"
+
+    "github.com/gin-gonic/gin"
+    json "github.com/json-iterator/go"
+    "github.com/labring/sealos/service/aiproxy/common"
+    "github.com/labring/sealos/service/aiproxy/common/logger"
+    "github.com/labring/sealos/service/aiproxy/relay"
+    "github.com/labring/sealos/service/aiproxy/relay/adaptor"
+    "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
+    "github.com/labring/sealos/service/aiproxy/relay/meta"
+    relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
+    billingprice "github.com/labring/sealos/service/aiproxy/relay/price"
+    "github.com/labring/sealos/service/aiproxy/relay/relaymode"
+)
+
+func RerankHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode {
+    ctx := c.Request.Context()
+    meta := meta.GetByContext(c)
+    rerankRequest, err := getRerankRequest(c)
+    if err != nil {
+        logger.Errorf(ctx, "get rerank request failed: %s", err.Error())
+        return openai.ErrorWrapper(err, "invalid_rerank_request", http.StatusBadRequest)
+    }
+
+    meta.OriginModelName = rerankRequest.Model
+    rerankRequest.Model, _ = getMappedModelName(rerankRequest.Model, meta.ModelMapping)
+    meta.ActualModelName = rerankRequest.Model
+
+    price, ok := billingprice.GetModelPrice(meta.OriginModelName, meta.ActualModelName, meta.ChannelType)
+    if !ok {
+        return openai.ErrorWrapper(fmt.Errorf("model price not found: %s", meta.OriginModelName), "model_price_not_found", http.StatusInternalServerError)
+    }
+    completionPrice, ok := billingprice.GetCompletionPrice(meta.OriginModelName, meta.ActualModelName, meta.ChannelType)
+    if !ok {
+        return openai.ErrorWrapper(fmt.Errorf("completion price not found: %s", meta.OriginModelName), "completion_price_not_found", http.StatusInternalServerError)
+    }
+
+    meta.PromptTokens = rerankPromptTokens(rerankRequest)
+
+    ok, postGroupConsumer, err := preCheckGroupBalance(ctx, &PreCheckGroupBalanceReq{
+        PromptTokens: meta.PromptTokens,
+        Price:        price,
+    }, meta)
+    if err != nil {
+        logger.Errorf(ctx, "get group (%s) balance failed: %s", meta.Group, err)
+        return openai.ErrorWrapper(
+            fmt.Errorf("get group (%s) balance failed", meta.Group),
+            "get_group_quota_failed",
+            http.StatusInternalServerError,
+        )
+    }
+    if !ok {
+        return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden)
+    }
+
+    adaptor := relay.GetAdaptor(meta.APIType)
+    if adaptor == nil {
+        return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
+    }
+
+    requestBody, err := getRerankRequestBody(c, meta, rerankRequest, adaptor)
+    if err != nil {
+        logger.Errorf(ctx, "get rerank request body failed: %s", err.Error())
+        return openai.ErrorWrapper(err, "invalid_rerank_request", http.StatusBadRequest)
+    }
+
+    resp, err := adaptor.DoRequest(c, meta, requestBody)
+    if err != nil {
+        logger.Errorf(ctx, "do rerank request failed: %s", err.Error())
+        ConsumeWaitGroup.Add(1)
+        go postConsumeAmount(context.Background(),
+            &ConsumeWaitGroup,
+            postGroupConsumer,
+            http.StatusInternalServerError,
+            c.Request.URL.Path,
+            nil, meta, price, completionPrice, err.Error(),
+        )
+        return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+    }
+
+    if isErrorHappened(meta, resp) {
+        err := RelayErrorHandler(resp, relaymode.Rerank)
+        ConsumeWaitGroup.Add(1)
+        go postConsumeAmount(context.Background(),
+            &ConsumeWaitGroup,
+            postGroupConsumer,
+            resp.StatusCode,
+            c.Request.URL.Path,
+            nil,
+            meta,
+            price,
+            completionPrice,
+            err.String(),
+        )
+        return err
+    }
+
+    usage, respErr := adaptor.DoResponse(c, resp, meta)
+    if respErr != nil {
+        logger.Errorf(ctx, "do rerank response failed: %+v", respErr)
+        ConsumeWaitGroup.Add(1)
+        go postConsumeAmount(context.Background(),
+            &ConsumeWaitGroup,
+            postGroupConsumer,
+            http.StatusInternalServerError,
+            c.Request.URL.Path,
+            usage, meta, price, completionPrice, respErr.String(),
+        )
+        return respErr
+    }
+
+    ConsumeWaitGroup.Add(1)
+    go postConsumeAmount(context.Background(),
+        &ConsumeWaitGroup,
+        postGroupConsumer,
+        http.StatusOK,
+        c.Request.URL.Path,
+        usage, meta, price, completionPrice, "",
+    )
+
+    return nil
+}
+
+func getRerankRequest(c *gin.Context) (*relaymodel.RerankRequest, error) {
+    rerankRequest := &relaymodel.RerankRequest{}
+    err := common.UnmarshalBodyReusable(c, rerankRequest)
+    if err != nil {
+        return nil, err
+    }
+    if rerankRequest.Model == "" {
+        return nil, errors.New("model parameter must be provided")
+    }
+    if rerankRequest.Query == "" {
+        return nil, errors.New("query must not be empty")
+    }
+    if len(rerankRequest.Documents) == 0 {
+        return nil, errors.New("document list must not be empty")
+    }
+
+    return rerankRequest, nil
+}
+
+func getRerankRequestBody(_ *gin.Context, _ *meta.Meta, textRequest *relaymodel.RerankRequest, _ adaptor.Adaptor) (io.Reader, error) {
+    jsonData, err := json.Marshal(textRequest)
+    if err != nil {
+        return nil, err
+    }
+    return bytes.NewReader(jsonData), nil
+}
+
+func rerankPromptTokens(rerankRequest *relaymodel.RerankRequest) int {
+    return len(rerankRequest.Query) + len(strings.Join(rerankRequest.Documents, ""))
+}
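Reviewer note on controller/rerank.go: end to end, the new endpoint accepts a Cohere-style body that mirrors relaymodel.RerankRequest. A client-side sketch, assuming the proxy listens on localhost:3000 and using placeholder key and model values:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Body fields mirror relaymodel.RerankRequest; model name is illustrative.
	body := []byte(`{
		"model": "rerank-english-v3.0",
		"query": "What is the capital of France?",
		"documents": ["Paris is the capital of France.", "Berlin is in Germany."],
		"top_n": 1
	}`)
	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/rerank", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer sk-placeholder")
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(out))
}
```

Worth noting: rerankPromptTokens is a byte count over the query plus all documents, not a tokenizer count; it only serves as the pre-consumption estimate and as the billing fallback when the upstream reports no usage.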
diff --git a/service/aiproxy/relay/controller/text.go b/service/aiproxy/relay/controller/text.go
index 98c70e07be9..b510e05950b 100644
--- a/service/aiproxy/relay/controller/text.go
+++ b/service/aiproxy/relay/controller/text.go
@@ -46,7 +46,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
     // pre-consume balance
     promptTokens := getPromptTokens(textRequest, meta.Mode)
     meta.PromptTokens = promptTokens
-    ok, postGroupConsumer, err := preCheckGroupBalance(ctx, textRequest, promptTokens, price, meta)
+    ok, postGroupConsumer, err := preCheckGroupBalance(ctx, &PreCheckGroupBalanceReq{
+        PromptTokens: promptTokens,
+        MaxTokens:    textRequest.MaxTokens,
+        Price:        price,
+    }, meta)
     if err != nil {
         logger.Errorf(ctx, "get group (%s) balance failed: %s", meta.Group, err)
         return openai.ErrorWrapper(
@@ -89,7 +93,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
     }
 
     if isErrorHappened(meta, resp) {
-        err := RelayErrorHandler(resp)
+        err := RelayErrorHandler(resp, meta.Mode)
         ConsumeWaitGroup.Add(1)
         go postConsumeAmount(context.Background(),
             &ConsumeWaitGroup,
diff --git a/service/aiproxy/relay/model/rerank.go b/service/aiproxy/relay/model/rerank.go
new file mode 100644
index 00000000000..a58b3cab225
--- /dev/null
+++ b/service/aiproxy/relay/model/rerank.go
@@ -0,0 +1,37 @@
+package model
+
+type RerankRequest struct {
+    TopN            *int     `json:"top_n,omitempty"`
+    MaxChunksPerDoc *int     `json:"max_chunks_per_doc,omitempty"`
+    ReturnDocuments *bool    `json:"return_documents,omitempty"`
+    OverlapTokens   *int     `json:"overlap_tokens,omitempty"`
+    Model           string   `json:"model"`
+    Query           string   `json:"query"`
+    Documents       []string `json:"documents"`
+}
+
+type Document struct {
+    Text string `json:"text"`
+}
+
+type RerankResult struct {
+    Document       Document `json:"document"`
+    Index          int      `json:"index"`
+    RelevanceScore float64  `json:"relevance_score"`
+}
+
+type RerankMetaTokens struct {
+    InputTokens  int `json:"input_tokens"`
+    OutputTokens int `json:"output_tokens"`
+}
+
+type RerankMeta struct {
+    Tokens *RerankMetaTokens `json:"tokens"`
+    Model  string            `json:"model"`
+}
+
+type RerankResponse struct {
+    Meta   RerankMeta     `json:"meta"`
+    ID     string         `json:"id"`
+    Result []RerankResult `json:"result"`
+}
diff --git a/service/aiproxy/relay/relaymode/define.go b/service/aiproxy/relay/relaymode/define.go
index 96d094382ca..88b2086a4c1 100644
--- a/service/aiproxy/relay/relaymode/define.go
+++ b/service/aiproxy/relay/relaymode/define.go
@@ -11,4 +11,5 @@ const (
     AudioSpeech
     AudioTranscription
     AudioTranslation
+    Rerank
 )
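Reviewer note on model/rerank.go: RerankMeta.Tokens is deliberately a pointer: it stays nil when the upstream omits usage, which is exactly the case RerankHandler's fallback keys on. Also note that Rerank is appended after AudioTranslation in define.go, so the existing iota values are unchanged. A decoding sketch with local copies of the new types, trimmed to the fields exercised (sample payload is illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the patch's rerank types, trimmed to the fields used here.
type RerankMetaTokens struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

type RerankMeta struct {
	Tokens *RerankMetaTokens `json:"tokens"` // nil when the upstream omits usage
	Model  string            `json:"model"`
}

type RerankResult struct {
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
}

type RerankResponse struct {
	Meta   RerankMeta     `json:"meta"`
	ID     string         `json:"id"`
	Result []RerankResult `json:"result"`
}

func main() {
	payload := `{"id":"1","result":[{"index":0,"relevance_score":0.98}],` +
		`"meta":{"model":"rerank-english-v3.0","tokens":{"input_tokens":12,"output_tokens":1}}}`
	var r RerankResponse
	if err := json.Unmarshal([]byte(payload), &r); err != nil {
		panic(err)
	}
	fmt.Println(r.Result[0].RelevanceScore, r.Meta.Tokens.InputTokens) // 0.98 12
}
```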
diff --git a/service/aiproxy/relay/relaymode/helper.go b/service/aiproxy/relay/relaymode/helper.go
index 5dc188b3f47..7a83ec53f73 100644
--- a/service/aiproxy/relay/relaymode/helper.go
+++ b/service/aiproxy/relay/relaymode/helper.go
@@ -22,6 +22,8 @@ func GetByPath(path string) int {
         return AudioTranscription
     case strings.HasPrefix(path, "/v1/audio/translations"):
         return AudioTranslation
+    case strings.HasPrefix(path, "/v1/rerank"):
+        return Rerank
     default:
         return Unknown
     }
diff --git a/service/aiproxy/router/relay.go b/service/aiproxy/router/relay.go
index f1ff1c85e7f..20886eb9ddf 100644
--- a/service/aiproxy/router/relay.go
+++ b/service/aiproxy/router/relay.go
@@ -37,6 +37,7 @@ func SetRelayRouter(router *gin.Engine) {
         relayV1Router.POST("/audio/transcriptions", controller.Relay)
         relayV1Router.POST("/audio/translations", controller.Relay)
         relayV1Router.POST("/audio/speech", controller.Relay)
+        relayV1Router.POST("/rerank", controller.Relay)
         relayV1Router.GET("/files", controller.RelayNotImplemented)
         relayV1Router.POST("/files", controller.RelayNotImplemented)
         relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
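Reviewer note on relaymode/helper.go and router/relay.go: the path-to-mode mapping can be sanity-checked in isolation. A table-driven sketch, where getByPath is a hypothetical local stand-in for relaymode.GetByPath reduced to the rerank case:

```go
package main

import (
	"fmt"
	"strings"
)

const (
	Unknown = iota
	Rerank
)

// getByPath is a stand-in that only knows about the new rerank route.
func getByPath(path string) int {
	switch {
	case strings.HasPrefix(path, "/v1/rerank"):
		return Rerank
	default:
		return Unknown
	}
}

func main() {
	for _, tc := range []struct {
		path string
		want int
	}{
		{"/v1/rerank", Rerank},
		{"/v1/chat/completions", Unknown}, // stand-in only handles rerank
	} {
		got := getByPath(tc.path)
		fmt.Println(tc.path, got == tc.want)
	}
}
```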