diff --git a/core/request_option.go b/core/request_option.go
index 9586474..b020c18 100644
--- a/core/request_option.go
+++ b/core/request_option.go
@@ -56,7 +56,7 @@ func (r *RequestOptions) cloneHeader() http.Header {
     headers := r.HTTPHeader.Clone()
     headers.Set("X-Fern-Language", "Go")
     headers.Set("X-Fern-SDK-Name", "github.com/cohere-ai/cohere-go/v2")
-    headers.Set("X-Fern-SDK-Version", "v2.7.0")
+    headers.Set("X-Fern-SDK-Version", "v2.7.1")
     return headers
 }
diff --git a/datasets.go b/datasets.go
index 0241639..8795778 100644
--- a/datasets.go
+++ b/datasets.go
@@ -12,7 +12,7 @@ import (
 type DatasetsCreateRequest struct {
     // The name of the uploaded dataset.
     Name string `json:"-" url:"name"`
-    // The dataset type, which is used to validate the data.
+    // The dataset type, which is used to validate the data. Valid types are `embed-input`, `reranker-finetune-input`, `prompt-completion-finetune-input`, `single-label-classification-finetune-input`, `chat-finetune-input`, and `multi-label-classification-finetune-input`.
     Type DatasetType `json:"-" url:"type,omitempty"`
     // Indicates if the original file should be stored.
     KeepOriginalFile *bool `json:"-" url:"keep_original_file,omitempty"`
diff --git a/embed_jobs.go b/embed_jobs.go
index ace88b4..a734ffb 100644
--- a/embed_jobs.go
+++ b/embed_jobs.go
@@ -21,6 +21,14 @@ type CreateEmbedJobRequest struct {
     InputType EmbedInputType `json:"input_type,omitempty" url:"input_type,omitempty"`
     // The name of the embed job.
     Name *string `json:"name,omitempty" url:"name,omitempty"`
+    // Specifies the types of embeddings you want to get back. Not required, and the default is None, which returns the Embed Floats response type. Can be one or more of the following types.
+    //
+    // * `"float"`: Use this when you want to get back the default float embeddings. Valid for all models.
+    // * `"int8"`: Use this when you want to get back signed int8 embeddings. Valid only for v3 models.
+    // * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid only for v3 models.
+    // * `"binary"`: Use this when you want to get back signed binary embeddings. Valid only for v3 models.
+    // * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid only for v3 models.
+    EmbeddingTypes []EmbeddingType `json:"embedding_types,omitempty" url:"embedding_types,omitempty"`
     // One of `START|END` to specify how the API will handle inputs longer than the maximum token length.
     //
     // Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
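As context for the new `EmbeddingTypes` field above, a minimal usage sketch follows. It assumes the generated top-level client (`client.NewClient` with `option.WithToken`) and the `Model`/`DatasetId` fields of `CreateEmbedJobRequest`, none of which appear in this hunk; the model name is illustrative.

package main

import (
	"context"
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
	cohereclient "github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/option"
)

func main() {
	co := cohereclient.NewClient(option.WithToken("<CO_API_KEY>"))

	// Ask for int8 embeddings alongside the default floats; per the doc
	// comment above, "int8" is valid only for v3 models.
	job, err := co.EmbedJobs.Create(context.TODO(), &cohere.CreateEmbedJobRequest{
		Model:     "embed-english-v3.0", // assumed field and model name
		DatasetId: "<dataset-id>",       // assumed field
		InputType: cohere.EmbedInputTypeSearchDocument,
		EmbeddingTypes: []cohere.EmbeddingType{
			cohere.EmbeddingTypeFloat,
			cohere.EmbeddingTypeInt8,
		},
	})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(job)
}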
diff --git a/models/client.go b/models/client.go
index da9e42f..5c7af91 100644
--- a/models/client.go
+++ b/models/client.go
@@ -7,6 +7,7 @@ import (
     context "context"
     json "encoding/json"
     errors "errors"
+    fmt "fmt"
     v2 "github.com/cohere-ai/cohere-go/v2"
     core "github.com/cohere-ai/cohere-go/v2/core"
     option "github.com/cohere-ai/cohere-go/v2/option"
@@ -34,6 +35,76 @@ func NewClient(opts ...option.RequestOption) *Client {
     }
 }
 
+// Returns the details of a model, provided its name.
+func (c *Client) Get(
+    ctx context.Context,
+    model string,
+    opts ...option.RequestOption,
+) (*v2.GetModelResponse, error) {
+    options := core.NewRequestOptions(opts...)
+
+    baseURL := "https://api.cohere.ai/v1"
+    if c.baseURL != "" {
+        baseURL = c.baseURL
+    }
+    if options.BaseURL != "" {
+        baseURL = options.BaseURL
+    }
+    endpointURL := fmt.Sprintf(baseURL+"/"+"models/%v", model)
+
+    headers := core.MergeHeaders(c.header.Clone(), options.ToHeader())
+
+    errorDecoder := func(statusCode int, body io.Reader) error {
+        raw, err := io.ReadAll(body)
+        if err != nil {
+            return err
+        }
+        apiError := core.NewAPIError(statusCode, errors.New(string(raw)))
+        decoder := json.NewDecoder(bytes.NewReader(raw))
+        switch statusCode {
+        case 400:
+            value := new(v2.BadRequestError)
+            value.APIError = apiError
+            if err := decoder.Decode(value); err != nil {
+                return apiError
+            }
+            return value
+        case 429:
+            value := new(v2.TooManyRequestsError)
+            value.APIError = apiError
+            if err := decoder.Decode(value); err != nil {
+                return apiError
+            }
+            return value
+        case 500:
+            value := new(v2.InternalServerError)
+            value.APIError = apiError
+            if err := decoder.Decode(value); err != nil {
+                return apiError
+            }
+            return value
+        }
+        return apiError
+    }
+
+    var response *v2.GetModelResponse
+    if err := c.caller.Call(
+        ctx,
+        &core.CallParams{
+            URL:          endpointURL,
+            Method:       http.MethodGet,
+            MaxAttempts:  options.MaxAttempts,
+            Headers:      headers,
+            Client:       options.HTTPClient,
+            Response:     &response,
+            ErrorDecoder: errorDecoder,
+        },
+    ); err != nil {
+        return nil, err
+    }
+    return response, nil
+}
+
 // Returns a list of models available for use. The list contains models from Cohere as well as your fine-tuned models.
 func (c *Client) List(
     ctx context.Context,
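A quick sketch of calling the new endpoint through the top-level client; `co.Models` as the accessor for this sub-client is an assumption based on the generated SDK layout rather than anything shown in this diff.

package main

import (
	"context"
	"fmt"

	cohereclient "github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/option"
)

func main() {
	co := cohereclient.NewClient(option.WithToken("<CO_API_KEY>"))

	// Issues GET /v1/models/command-r and decodes into *v2.GetModelResponse;
	// 400/429/500 bodies come back as typed errors via the errorDecoder above.
	model, err := co.Models.Get(context.TODO(), "command-r")
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(model) // String() pretty-prints the raw JSON
}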
diff --git a/types.go b/types.go
index 9f00b55..977bd04 100644
--- a/types.go
+++ b/types.go
@@ -12,7 +12,7 @@ import (
 type ChatRequest struct {
     // Text input for the model to respond to.
     Message string `json:"message" url:"message"`
-    // Defaults to `command`.
+    // Defaults to `command-r`.
     //
     // The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
     Model *string `json:"model,omitempty" url:"model,omitempty"`
@@ -157,7 +157,7 @@ func (c *ChatRequest) MarshalJSON() ([]byte, error) {
 type ChatStreamRequest struct {
     // Text input for the model to respond to.
     Message string `json:"message" url:"message"`
-    // Defaults to `command`.
+    // Defaults to `command-r`.
     //
     // The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model.
     Model *string `json:"model,omitempty" url:"model,omitempty"`
@@ -321,7 +321,7 @@ type DetokenizeRequest struct {
     // The list of tokens to be detokenized.
     Tokens []int `json:"tokens,omitempty" url:"tokens,omitempty"`
-    // An optional parameter to provide the model name. This will ensure that the detokenization is done by the tokenizer used by that model.
-    Model *string `json:"model,omitempty" url:"model,omitempty"`
+    // The model name. This ensures that the detokenization is done by the tokenizer used by that model.
+    Model string `json:"model" url:"model"`
 }
@@ -350,7 +350,7 @@ type EmbedRequest struct {
     // * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid only for v3 models.
     // * `"binary"`: Use this when you want to get back signed binary embeddings. Valid only for v3 models.
     // * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid only for v3 models.
-    EmbeddingTypes []EmbedRequestEmbeddingTypesItem `json:"embedding_types,omitempty" url:"embedding_types,omitempty"`
+    EmbeddingTypes []EmbeddingType `json:"embedding_types,omitempty" url:"embedding_types,omitempty"`
     // One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
     //
     // Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
@@ -575,7 +575,7 @@ type TokenizeRequest struct {
     // The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.
     Text string `json:"text" url:"text"`
-    // An optional parameter to provide the model name. This will ensure that the tokenization uses the tokenizer used by that model.
-    Model *string `json:"model,omitempty" url:"model,omitempty"`
+    // The model name. This ensures that the tokenization uses the tokenizer used by that model.
+    Model string `json:"model" url:"model"`
 }
 
 type ApiMeta struct {
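`Model` switching from `*string` to `string` on `TokenizeRequest` and `DetokenizeRequest` is a breaking change: callers that omitted the field, or wrapped it with a pointer helper, must now pass the name directly. A minimal sketch, assuming the root client's `Tokenize` method and the standard client construction:

package main

import (
	"context"
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
	cohereclient "github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/option"
)

func main() {
	co := cohereclient.NewClient(option.WithToken("<CO_API_KEY>"))

	// Before: Model: cohere.String("command-r") (optional *string).
	// Now: a plain string, and required.
	resp, err := co.Tokenize(context.TODO(), &cohere.TokenizeRequest{
		Text:  "tokenize me!",
		Model: "command-r",
	})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(resp.Tokens)
}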
@@ -971,38 +971,6 @@ func (c *ChatRequestConnectorsSearchOptions) String() string {
     return fmt.Sprintf("%#v", c)
 }
 
-// (internal) Overrides specified parts of the default Chat or RAG preamble. It is recommended that these options only be used in specific scenarios where the defaults are not adequate.
-type ChatRequestPromptOverride struct {
-    Preamble        interface{} `json:"preamble,omitempty" url:"preamble,omitempty"`
-    TaskDescription interface{} `json:"task_description,omitempty" url:"task_description,omitempty"`
-    StyleGuide      interface{} `json:"style_guide,omitempty" url:"style_guide,omitempty"`
-
-    _rawJSON json.RawMessage
-}
-
-func (c *ChatRequestPromptOverride) UnmarshalJSON(data []byte) error {
-    type unmarshaler ChatRequestPromptOverride
-    var value unmarshaler
-    if err := json.Unmarshal(data, &value); err != nil {
-        return err
-    }
-    *c = ChatRequestPromptOverride(value)
-    c._rawJSON = json.RawMessage(data)
-    return nil
-}
-
-func (c *ChatRequestPromptOverride) String() string {
-    if len(c._rawJSON) > 0 {
-        if value, err := core.StringifyJSON(c._rawJSON); err == nil {
-            return value
-        }
-    }
-    if value, err := core.StringifyJSON(c); err == nil {
-        return value
-    }
-    return fmt.Sprintf("%#v", c)
-}
-
 // Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
 //
 // Dictates how the prompt will be constructed.
@@ -1136,6 +1104,10 @@ type ChatSearchResult struct {
     Connector *ChatSearchResultConnector `json:"connector,omitempty" url:"connector,omitempty"`
     // Identifiers of documents found by this search query.
     DocumentIds []string `json:"document_ids,omitempty" url:"document_ids,omitempty"`
+    // An error message if the search failed.
+    ErrorMessage *string `json:"error_message,omitempty" url:"error_message,omitempty"`
+    // Whether a chat request should continue or not if the request to this connector fails.
+    ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"`
 
     _rawJSON json.RawMessage
 }
@@ -1385,38 +1357,6 @@ func (c *ChatStreamRequestConnectorsSearchOptions) String() string {
     return fmt.Sprintf("%#v", c)
 }
 
-// (internal) Overrides specified parts of the default Chat or RAG preamble. It is recommended that these options only be used in specific scenarios where the defaults are not adequate.
-type ChatStreamRequestPromptOverride struct {
-    Preamble        interface{} `json:"preamble,omitempty" url:"preamble,omitempty"`
-    TaskDescription interface{} `json:"task_description,omitempty" url:"task_description,omitempty"`
-    StyleGuide      interface{} `json:"style_guide,omitempty" url:"style_guide,omitempty"`
-
-    _rawJSON json.RawMessage
-}
-
-func (c *ChatStreamRequestPromptOverride) UnmarshalJSON(data []byte) error {
-    type unmarshaler ChatStreamRequestPromptOverride
-    var value unmarshaler
-    if err := json.Unmarshal(data, &value); err != nil {
-        return err
-    }
-    *c = ChatStreamRequestPromptOverride(value)
-    c._rawJSON = json.RawMessage(data)
-    return nil
-}
-
-func (c *ChatStreamRequestPromptOverride) String() string {
-    if len(c._rawJSON) > 0 {
-        if value, err := core.StringifyJSON(c._rawJSON); err == nil {
-            return value
-        }
-    }
-    if value, err := core.StringifyJSON(c); err == nil {
-        return value
-    }
-    return fmt.Sprintf("%#v", c)
-}
-
 // Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases.
 //
 // Dictates how the prompt will be constructed.
@@ -2574,37 +2514,6 @@ func (e EmbedJobTruncate) Ptr() *EmbedJobTruncate {
     return &e
 }
 
-type EmbedRequestEmbeddingTypesItem string
-
-const (
-    EmbedRequestEmbeddingTypesItemFloat   EmbedRequestEmbeddingTypesItem = "float"
-    EmbedRequestEmbeddingTypesItemInt8    EmbedRequestEmbeddingTypesItem = "int8"
-    EmbedRequestEmbeddingTypesItemUint8   EmbedRequestEmbeddingTypesItem = "uint8"
-    EmbedRequestEmbeddingTypesItemBinary  EmbedRequestEmbeddingTypesItem = "binary"
-    EmbedRequestEmbeddingTypesItemUbinary EmbedRequestEmbeddingTypesItem = "ubinary"
-)
-
-func NewEmbedRequestEmbeddingTypesItemFromString(s string) (EmbedRequestEmbeddingTypesItem, error) {
-    switch s {
-    case "float":
-        return EmbedRequestEmbeddingTypesItemFloat, nil
-    case "int8":
-        return EmbedRequestEmbeddingTypesItemInt8, nil
-    case "uint8":
-        return EmbedRequestEmbeddingTypesItemUint8, nil
-    case "binary":
-        return EmbedRequestEmbeddingTypesItemBinary, nil
-    case "ubinary":
-        return EmbedRequestEmbeddingTypesItemUbinary, nil
-    }
-    var t EmbedRequestEmbeddingTypesItem
-    return "", fmt.Errorf("%s is not a valid %T", s, t)
-}
-
-func (e EmbedRequestEmbeddingTypesItem) Ptr() *EmbedRequestEmbeddingTypesItem {
-    return &e
-}
-
 // One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length.
 //
 // Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
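The two new `ChatSearchResult` fields make per-connector failures observable in RAG responses. A hedged sketch of checking them; the `Connectors` request field and the `SearchResults` response field come from the existing Chat API and are not shown in this diff:

package main

import (
	"context"
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
	cohereclient "github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/option"
)

func main() {
	co := cohereclient.NewClient(option.WithToken("<CO_API_KEY>"))

	resp, err := co.Chat(context.TODO(), &cohere.ChatRequest{
		Message:    "What is the tallest mountain on Earth?",
		Connectors: []*cohere.ChatConnector{{Id: "web-search"}},
	})
	if err != nil {
		fmt.Println(err)
		return
	}
	// Surface connector searches that failed instead of silently
	// proceeding with fewer documents.
	for _, sr := range resp.SearchResults {
		if sr.ErrorMessage != nil {
			fmt.Printf("connector search failed: %s\n", *sr.ErrorMessage)
		}
	}
}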
@@ -2705,6 +2614,37 @@ func (e *EmbedResponse) Accept(visitor EmbedResponseVisitor) error {
     return fmt.Errorf("type %T does not define a non-empty union type", e)
 }
 
+type EmbeddingType string
+
+const (
+    EmbeddingTypeFloat   EmbeddingType = "float"
+    EmbeddingTypeInt8    EmbeddingType = "int8"
+    EmbeddingTypeUint8   EmbeddingType = "uint8"
+    EmbeddingTypeBinary  EmbeddingType = "binary"
+    EmbeddingTypeUbinary EmbeddingType = "ubinary"
+)
+
+func NewEmbeddingTypeFromString(s string) (EmbeddingType, error) {
+    switch s {
+    case "float":
+        return EmbeddingTypeFloat, nil
+    case "int8":
+        return EmbeddingTypeInt8, nil
+    case "uint8":
+        return EmbeddingTypeUint8, nil
+    case "binary":
+        return EmbeddingTypeBinary, nil
+    case "ubinary":
+        return EmbeddingTypeUbinary, nil
+    }
+    var t EmbeddingType
+    return "", fmt.Errorf("%s is not a valid %T", s, t)
+}
+
+func (e EmbeddingType) Ptr() *EmbeddingType {
+    return &e
+}
+
 type FinetuneDatasetMetrics struct {
     // The number of tokens of valid examples that can be used for training.
     TrainableTokenCount *string `json:"trainableTokenCount,omitempty" url:"trainableTokenCount,omitempty"`
@@ -3210,6 +3150,47 @@ func (g *GetConnectorResponse) String() string {
     return fmt.Sprintf("%#v", g)
 }
 
+// Contains information about the model and which API endpoints it can be used with.
+type GetModelResponse struct {
+    // Specify this name in the `model` parameter of API requests to use your chosen model.
+    Name *string `json:"name,omitempty" url:"name,omitempty"`
+    // The API endpoints that the model is compatible with.
+    Endpoints []CompatibleEndpoint `json:"endpoints,omitempty" url:"endpoints,omitempty"`
+    // Whether the model has been fine-tuned or not.
+    Finetuned *bool `json:"finetuned,omitempty" url:"finetuned,omitempty"`
+    // The maximum number of tokens that the model can process in a single request. Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default.
+    ContextLength *float64 `json:"context_length,omitempty" url:"context_length,omitempty"`
+    // The name of the tokenizer used for the model.
+    Tokenizer *string `json:"tokenizer,omitempty" url:"tokenizer,omitempty"`
+    // Public URL to the tokenizer's configuration file.
+    TokenizerUrl *string `json:"tokenizer_url,omitempty" url:"tokenizer_url,omitempty"`
+
+    _rawJSON json.RawMessage
+}
+
+func (g *GetModelResponse) UnmarshalJSON(data []byte) error {
+    type unmarshaler GetModelResponse
+    var value unmarshaler
+    if err := json.Unmarshal(data, &value); err != nil {
+        return err
+    }
+    *g = GetModelResponse(value)
+    g._rawJSON = json.RawMessage(data)
+    return nil
+}
+
+func (g *GetModelResponse) String() string {
+    if len(g._rawJSON) > 0 {
+        if value, err := core.StringifyJSON(g._rawJSON); err == nil {
+            return value
+        }
+    }
+    if value, err := core.StringifyJSON(g); err == nil {
+        return value
+    }
+    return fmt.Sprintf("%#v", g)
+}
+
 type LabelMetric struct {
     // Total number of examples for this label
     TotalExamples *string `json:"totalExamples,omitempty" url:"totalExamples,omitempty"`
@@ -3305,7 +3286,7 @@ func (l *ListEmbedJobResponse) String() string {
 }
 
 type ListModelsResponse struct {
-    Models []*Model `json:"models,omitempty" url:"models,omitempty"`
+    Models []*GetModelResponse `json:"models,omitempty" url:"models,omitempty"`
     // A token to retrieve the next page of results. Provide in the page_token parameter of the next request.
     NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"`
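With `Models` retyped from `[]*Model` to `[]*GetModelResponse` (and the `Model` type removed below), any code that named `cohere.Model` must be updated. A sketch of iterating the list; the `ModelsListRequest` parameter type is an assumption, as the `List` signature's request argument is not shown in this diff:

package main

import (
	"context"
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
	cohereclient "github.com/cohere-ai/cohere-go/v2/client"
	"github.com/cohere-ai/cohere-go/v2/option"
)

func main() {
	co := cohereclient.NewClient(option.WithToken("<CO_API_KEY>"))

	resp, err := co.Models.List(context.TODO(), &cohere.ModelsListRequest{})
	if err != nil {
		fmt.Println(err)
		return
	}
	for _, m := range resp.Models {
		// m is now *cohere.GetModelResponse rather than *cohere.Model.
		if m.Name != nil {
			fmt.Println(*m.Name)
		}
	}
}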
@@ -3364,47 +3345,6 @@ func (m *Metrics) String() string {
     return fmt.Sprintf("%#v", m)
 }
 
-// Contains information about the model and which API endpoints it can be used with.
-type Model struct {
-    // Specify this name in the `model` parameter of API requests to use your chosen model.
-    Name *string `json:"name,omitempty" url:"name,omitempty"`
-    // The API endpoints that the model is compatible with.
-    Endpoints []CompatibleEndpoint `json:"endpoints,omitempty" url:"endpoints,omitempty"`
-    // Whether the model has been fine-tuned or not.
-    Finetuned *bool `json:"finetuned,omitempty" url:"finetuned,omitempty"`
-    // The maximum number of tokens that the model can process in a single request. Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default.
-    ContextLength *float64 `json:"context_length,omitempty" url:"context_length,omitempty"`
-    // The name of the tokenizer used for the model.
-    Tokenizer *string `json:"tokenizer,omitempty" url:"tokenizer,omitempty"`
-    // Public URL to the tokenizer's configuration file.
-    TokenizerUrl *string `json:"tokenizer_url,omitempty" url:"tokenizer_url,omitempty"`
-
-    _rawJSON json.RawMessage
-}
-
-func (m *Model) UnmarshalJSON(data []byte) error {
-    type unmarshaler Model
-    var value unmarshaler
-    if err := json.Unmarshal(data, &value); err != nil {
-        return err
-    }
-    *m = Model(value)
-    m._rawJSON = json.RawMessage(data)
-    return nil
-}
-
-func (m *Model) String() string {
-    if len(m._rawJSON) > 0 {
-        if value, err := core.StringifyJSON(m._rawJSON); err == nil {
-            return value
-        }
-    }
-    if value, err := core.StringifyJSON(m); err == nil {
-        return value
-    }
-    return fmt.Sprintf("%#v", m)
-}
-
 type NonStreamedChatResponse struct {
     // Contents of the reply generated by the model.
     Text string `json:"text" url:"text"`
@@ -4218,7 +4158,7 @@ func (t *ToolCall) String() string {
 
 type ToolParameterDefinitionsValue struct {
     // The description of the parameter.
-    Description string `json:"description" url:"description"`
+    Description *string `json:"description,omitempty" url:"description,omitempty"`
     // The type of the parameter. Must be a valid Python type.
     Type string `json:"type" url:"type"`
     // Denotes whether the parameter is always present (required) or not. Defaults to not required.
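Since `Description` is now a pointer, tool definitions either omit it or set it through a `*string`. A minimal sketch; the `cohere.String`/`cohere.Bool` pointer helpers and the `Required` field name are assumptions based on the rest of the SDK (only the `Required` doc comment appears in this hunk):

package main

import (
	"fmt"

	cohere "github.com/cohere-ai/cohere-go/v2"
)

func main() {
	// Description may now be left nil; when set, wrap it in a *string.
	param := cohere.ToolParameterDefinitionsValue{
		Description: cohere.String("The city to look up, e.g. Toronto."),
		Type:        "str",
		Required:    cohere.Bool(true), // assumed field name, per the doc comment above
	}
	fmt.Printf("%+v\n", param)
}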