diff --git a/chat.go b/chat.go index c74f979..52ae329 100644 --- a/chat.go +++ b/chat.go @@ -44,7 +44,7 @@ const ( type ChatCompletionRequest struct { // (Required) - // ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported. + // ID of the model to use. Model ChatGPTModel `json:"model"` // Required @@ -157,16 +157,22 @@ func validate(req *ChatCompletionRequest) error { return chatgpt_errors.ErrNoMessages } + isAllowed := false + allowedModels := []ChatGPTModel{ GPT35Turbo, GPT35Turbo0301, GPT35Turbo0613, GPT35Turbo16k, GPT35Turbo16k0613, GPT4, GPT4_0314, GPT4_0613, GPT4_32k, GPT4_32k_0314, GPT4_32k_0613, } for _, model := range allowedModels { - if req.Model != model { - return chatgpt_errors.ErrInvalidModel + if req.Model == model { + isAllowed = true } } + if !isAllowed { + return chatgpt_errors.ErrInvalidModel + } + for _, message := range req.Messages { if message.Role != ChatGPTModelRoleUser && message.Role != ChatGPTModelRoleSystem && message.Role != ChatGPTModelRoleAssistant { return chatgpt_errors.ErrInvalidRole diff --git a/go.mod b/go.mod index d8bca31..194e77e 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,10 @@ module github.com/ayush6624/go-chatgpt go 1.20 +require github.com/stretchr/testify v1.8.2 + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect - github.com/stretchr/testify v1.8.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index a788369..6a56e69 100644 --- a/go.sum +++ b/go.sum @@ -5,12 +5,12 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/utils/errors.go b/utils/errors.go index 59ca836..85a4d09 100644 --- a/utils/errors.go +++ b/utils/errors.go @@ -7,7 +7,7 @@ var ( ErrAPIKeyRequired = errors.New("API Key is required") // ErrInvalidModel is returned when the model is invalid - ErrInvalidModel = errors.New("invalid model. Only `gpt-3.5-turbo` and `gpt-3.5-turbo-0301` are supported") + ErrInvalidModel = errors.New("invalid model") // ErrNoMessages is returned when no messages are provided ErrNoMessages = errors.New("no messages provided")