Skip to content

Commit

Permalink
fix: aiproxy graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Nov 22, 2024
1 parent b829895 commit 9a49350
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 13 deletions.
20 changes: 18 additions & 2 deletions service/aiproxy/common/balance/sealos.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
appType = "LLM-TOKEN"
sealosRequester = "sealos-admin"
sealosGroupBalanceKey = "sealos:balance:%s"
getBalanceRetry = 3
)

var (
Expand All @@ -34,7 +35,7 @@ var (
minConsumeAmount = decimal.NewFromInt(1)
jwtToken string
sealosRedisCacheEnable = env.Bool("BALANCE_SEALOS_REDIS_CACHE_ENABLE", true)
sealosCacheExpire = 15 * time.Second
sealosCacheExpire = 3 * time.Minute
)

type Sealos struct {
Expand Down Expand Up @@ -139,8 +140,23 @@ func cacheDecreaseGroupBalance(ctx context.Context, group string, amount int64)
return decreaseGroupBalanceScript.Run(ctx, common.RDB, []string{fmt.Sprintf(sealosGroupBalanceKey, group)}, amount).Err()
}

// GetGroupRemainBalance implements the GroupBalance interface.
//
// It calls getGroupRemainBalance up to getBalanceRetry times, waiting one
// second between attempts, and returns the result of the first attempt that
// succeeds. If every attempt fails, the errors from all attempts are joined
// and returned. If ctx is done while waiting between attempts, the wait is
// abandoned and ctx.Err() is returned joined with the errors collected so far.
func (s *Sealos) GetGroupRemainBalance(ctx context.Context, group string) (float64, PostGroupConsumer, error) {
	var errs []error
	for attempt := 1; ; attempt++ {
		balance, consumer, err := s.getGroupRemainBalance(ctx, group)
		if err == nil {
			return balance, consumer, nil
		}
		errs = append(errs, err)
		// >= rather than == so the loop always terminates, whatever value
		// getBalanceRetry is set to.
		if attempt >= getBalanceRetry {
			return 0, nil, errors.Join(errs...)
		}
		// Back off before the next attempt, but stop early when the caller's
		// context is cancelled so an abandoned request does not keep retrying.
		timer := time.NewTimer(time.Second)
		select {
		case <-ctx.Done():
			timer.Stop()
			return 0, nil, errors.Join(append(errs, ctx.Err())...)
		case <-timer.C:
		}
	}
}

// getGroupRemainBalance makes a single attempt to look up the group's remaining balance, checking the cache first.
func (s *Sealos) getGroupRemainBalance(ctx context.Context, group string) (float64, PostGroupConsumer, error) {
if cache, err := cacheGetGroupBalance(ctx, group); err == nil && cache.UserUID != "" {
return decimal.NewFromInt(cache.Balance).Div(decimalBalancePrecision).InexactFloat64(),
newSealosPostGroupConsumer(s.accountURL, group, cache.UserUID, cache.Balance), nil
Expand Down
5 changes: 4 additions & 1 deletion service/aiproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/labring/sealos/service/aiproxy/controller"
"github.com/labring/sealos/service/aiproxy/middleware"
"github.com/labring/sealos/service/aiproxy/model"
relaycontroller "github.com/labring/sealos/service/aiproxy/relay/controller"
"github.com/labring/sealos/service/aiproxy/router"
)

Expand Down Expand Up @@ -111,11 +112,13 @@ func main() {
<-quit
logger.SysLog("shutting down server...")

ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
logger.SysError("server forced to shutdown: " + err.Error())
}

relaycontroller.ConsumeWaitGroup.Wait()

logger.SysLog("server exiting")
}
9 changes: 7 additions & 2 deletions service/aiproxy/relay/controller/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/balance"
"github.com/labring/sealos/service/aiproxy/common/ctxkey"
"github.com/labring/sealos/service/aiproxy/common/helper"
"github.com/labring/sealos/service/aiproxy/relay"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/meta"
Expand Down Expand Up @@ -94,9 +95,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}

consumeCtx := context.WithValue(context.Background(), helper.RequestIDKey, c.Value(helper.RequestIDKey))

if resp.StatusCode != http.StatusOK {
err := RelayErrorHandler(resp)
go postConsumeAmount(context.Background(), postGroupConsumer, resp.StatusCode, c.Request.URL.Path, &relaymodel.Usage{
ConsumeWaitGroup.Add(1)
go postConsumeAmount(consumeCtx, &ConsumeWaitGroup, postGroupConsumer, resp.StatusCode, c.Request.URL.Path, &relaymodel.Usage{
PromptTokens: 0,
CompletionTokens: 0,
}, meta, price, completionPrice, err.Message)
Expand All @@ -108,7 +112,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return respErr
}

go postConsumeAmount(context.Background(), postGroupConsumer, resp.StatusCode, c.Request.URL.Path, usage, meta, price, completionPrice, "")
ConsumeWaitGroup.Add(1)
go postConsumeAmount(consumeCtx, &ConsumeWaitGroup, postGroupConsumer, resp.StatusCode, c.Request.URL.Path, usage, meta, price, completionPrice, "")

return nil
}
6 changes: 5 additions & 1 deletion service/aiproxy/relay/controller/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"strings"
"sync"

"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/common"
Expand All @@ -20,6 +21,8 @@ import (
"github.com/shopspring/decimal"
)

var ConsumeWaitGroup sync.WaitGroup

func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
textRequest := &relaymodel.GeneralOpenAIRequest{}
err := common.UnmarshalBodyReusable(c, textRequest)
Expand Down Expand Up @@ -76,7 +79,8 @@ func preCheckGroupBalance(ctx context.Context, textRequest *relaymodel.GeneralOp
return true, postGroupConsumer, nil
}

func postConsumeAmount(ctx context.Context, postGroupConsumer balance.PostGroupConsumer, code int, endpoint string, usage *relaymodel.Usage, meta *meta.Meta, price, completionPrice float64, content string) {
func postConsumeAmount(ctx context.Context, consumeWaitGroup *sync.WaitGroup, postGroupConsumer balance.PostGroupConsumer, code int, endpoint string, usage *relaymodel.Usage, meta *meta.Meta, price, completionPrice float64, content string) {
defer consumeWaitGroup.Done()
if usage == nil {
err := model.BatchRecordConsume(ctx, meta.Group, code, meta.ChannelID, 0, 0, meta.OriginModelName, meta.TokenID, meta.TokenName, 0, price, completionPrice, endpoint, content)
if err != nil {
Expand Down
10 changes: 6 additions & 4 deletions service/aiproxy/relay/controller/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/balance"
"github.com/labring/sealos/service/aiproxy/common/helper"
"github.com/labring/sealos/service/aiproxy/common/logger"
"github.com/labring/sealos/service/aiproxy/model"
"github.com/labring/sealos/service/aiproxy/relay"
Expand Down Expand Up @@ -159,13 +160,14 @@ func RelayImageHelper(c *gin.Context, _ int) *relaymodel.ErrorWithStatusCode {
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}

defer func(ctx context.Context) {
defer func() {
if resp == nil || resp.StatusCode != http.StatusOK {
_ = model.RecordConsumeLog(ctx, meta.Group, resp.StatusCode, meta.ChannelID, imageRequest.N, 0, imageRequest.Model, meta.TokenID, meta.TokenName, 0, imageCostPrice, 0, c.Request.URL.Path, imageRequest.Size)
return
}

_amount, err := postGroupConsumer.PostGroupConsume(ctx, meta.TokenName, amount)
consumeCtx := context.WithValue(context.Background(), helper.RequestIDKey, ctx.Value(helper.RequestIDKey))
_amount, err := postGroupConsumer.PostGroupConsume(consumeCtx, meta.TokenName, amount)
if err != nil {
logger.Error(ctx, "error consuming token remain balance: "+err.Error())
err = model.CreateConsumeError(meta.Group, meta.TokenName, imageRequest.Model, err.Error(), amount, meta.TokenID)
Expand All @@ -175,11 +177,11 @@ func RelayImageHelper(c *gin.Context, _ int) *relaymodel.ErrorWithStatusCode {
} else {
amount = _amount
}
err = model.BatchRecordConsume(ctx, meta.Group, resp.StatusCode, meta.ChannelID, imageRequest.N, 0, imageRequest.Model, meta.TokenID, meta.TokenName, amount, imageCostPrice, 0, c.Request.URL.Path, imageRequest.Size)
err = model.BatchRecordConsume(consumeCtx, meta.Group, resp.StatusCode, meta.ChannelID, imageRequest.N, 0, imageRequest.Model, meta.TokenID, meta.TokenName, amount, imageCostPrice, 0, c.Request.URL.Path, imageRequest.Size)
if err != nil {
logger.Error(ctx, "failed to record consume log: "+err.Error())
}
}(c.Request.Context())
}()

// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
Expand Down
11 changes: 8 additions & 3 deletions service/aiproxy/relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/gin-gonic/gin"
json "github.com/json-iterator/go"
"github.com/labring/sealos/service/aiproxy/common/helper"
"github.com/labring/sealos/service/aiproxy/common/logger"
"github.com/labring/sealos/service/aiproxy/relay"
"github.com/labring/sealos/service/aiproxy/relay/adaptor"
Expand Down Expand Up @@ -74,21 +75,25 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
consumeCtx := context.WithValue(context.Background(), helper.RequestIDKey, ctx.Value(helper.RequestIDKey))
if isErrorHappened(meta, resp) {
err := RelayErrorHandler(resp)
go postConsumeAmount(context.Background(), postGroupConsume, resp.StatusCode, c.Request.URL.Path, nil, meta, price, completionPrice, err.Error.Message)
ConsumeWaitGroup.Add(1)
go postConsumeAmount(consumeCtx, &ConsumeWaitGroup, postGroupConsume, resp.StatusCode, c.Request.URL.Path, nil, meta, price, completionPrice, err.Error.Message)
return err
}

// do response
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
go postConsumeAmount(context.Background(), postGroupConsume, respErr.StatusCode, c.Request.URL.Path, usage, meta, price, completionPrice, respErr.Error.Message)
ConsumeWaitGroup.Add(1)
go postConsumeAmount(consumeCtx, &ConsumeWaitGroup, postGroupConsume, respErr.StatusCode, c.Request.URL.Path, usage, meta, price, completionPrice, respErr.Error.Message)
return respErr
}
// post-consume amount
go postConsumeAmount(context.Background(), postGroupConsume, resp.StatusCode, c.Request.URL.Path, usage, meta, price, completionPrice, "")
ConsumeWaitGroup.Add(1)
go postConsumeAmount(consumeCtx, &ConsumeWaitGroup, postGroupConsume, resp.StatusCode, c.Request.URL.Path, usage, meta, price, completionPrice, "")
return nil
}

Expand Down

0 comments on commit 9a49350

Please sign in to comment.