Skip to content

Commit

Permalink
fix: get channel
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Nov 15, 2024
1 parent 6fcad0b commit 8247de0
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 23 deletions.
21 changes: 8 additions & 13 deletions service/aiproxy/middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package middleware
import (
"fmt"
"net/http"
"slices"
"strconv"

"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/common/config"
"github.com/labring/sealos/service/aiproxy/common/ctxkey"
"github.com/labring/sealos/service/aiproxy/common/logger"
"github.com/labring/sealos/service/aiproxy/model"
"github.com/labring/sealos/service/aiproxy/relay/channeltype"
)
Expand All @@ -22,34 +22,29 @@ func Distribute(c *gin.Context) {
abortWithMessage(c, http.StatusServiceUnavailable, "服务暂停中")
return
}
group := c.GetString(ctxkey.Group)
requestModel := c.GetString(ctxkey.RequestModel)
var channel *model.Channel
channelID, ok := c.Get(ctxkey.SpecificChannelID)
if ok {
id, err := strconv.Atoi(channelID.(string))
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
return
}
channel, err = model.GetChannelByID(id, false)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
channel, ok = model.CacheGetChannelByID(id)
if !ok {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
return
}
if channel.Status != model.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
if !slices.Contains(channel.Models, requestModel) {
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("渠道 %s 不支持模型 %s", channel.Name, requestModel))
return
}
} else {
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(requestModel)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.ID))
message = "数据库一致性已被破坏,请联系管理员"
}
message := fmt.Sprintf("%s 不可用", requestModel)
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
Expand Down
17 changes: 13 additions & 4 deletions service/aiproxy/model/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,11 @@ func CacheGetGroup(id string) (*GroupCache, error) {
}

var (
model2channels map[string][]*Channel
allModels []string
type2Models map[int][]string
channelSyncLock sync.RWMutex
model2channels map[string][]*Channel
allModels []string
type2Models map[int][]string
channelID2channel map[int]*Channel
channelSyncLock sync.RWMutex
)

func CacheGetAllModels() []string {
Expand Down Expand Up @@ -344,6 +345,7 @@ func InitChannelCache() {
model2channels = newModel2channels
allModels = models
type2Models = newType2Models
channelID2channel = newChannelID2channel
channelSyncLock.Unlock()
logger.SysDebug("channels synced from database")
}
Expand Down Expand Up @@ -388,3 +390,10 @@ func CacheGetRandomSatisfiedChannel(model string) (*Channel, error) {

return channels[rand.Intn(len(channels))], nil
}

func CacheGetChannelByID(id int) (*Channel, bool) {
channelSyncLock.RLock()
channel, ok := channelID2channel[id]
channelSyncLock.RUnlock()
return channel, ok
}
6 changes: 3 additions & 3 deletions service/aiproxy/relay/adaptor/anthropic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}

response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
if response == nil {
continue
}
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
Expand All @@ -312,9 +315,6 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
}
}
if response == nil {
continue
}

response.ID = id
response.Model = modelName
Expand Down
6 changes: 3 additions & 3 deletions service/aiproxy/relay/adaptor/aws/claude/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
}

response, meta := anthropic.StreamResponseClaude2OpenAI(&claudeResp)
if response == nil {
return true
}
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
Expand All @@ -175,9 +178,6 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
}
}
}
if response == nil {
return true
}
response.ID = id
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
Expand Down

0 comments on commit 8247de0

Please sign in to comment.