diff --git a/.fernignore b/.fernignore index 9f170d2..c19d702 100644 --- a/.fernignore +++ b/.fernignore @@ -3,4 +3,4 @@ README.md banner.png LICENSE .github/workflows/e2e.yml -tests/** \ No newline at end of file +tests/ \ No newline at end of file diff --git a/client/client.go b/client/client.go index 9ff0f5b..70dd563 100644 --- a/client/client.go +++ b/client/client.go @@ -12,6 +12,7 @@ import ( core "github.com/cohere-ai/cohere-go/v2/core" datasets "github.com/cohere-ai/cohere-go/v2/datasets" embedjobs "github.com/cohere-ai/cohere-go/v2/embedjobs" + models "github.com/cohere-ai/cohere-go/v2/models" option "github.com/cohere-ai/cohere-go/v2/option" io "io" http "net/http" @@ -25,6 +26,7 @@ type Client struct { EmbedJobs *embedjobs.Client Datasets *datasets.Client Connectors *connectors.Client + Models *models.Client } func NewClient(opts ...option.RequestOption) *Client { @@ -41,12 +43,12 @@ func NewClient(opts ...option.RequestOption) *Client { EmbedJobs: embedjobs.NewClient(opts...), Datasets: datasets.NewClient(opts...), Connectors: connectors.NewClient(opts...), + Models: models.NewClient(opts...), } } -// The `chat` endpoint allows users to have conversations with a Large Language Model (LLM) from Cohere. Users can send messages as part of a persisted conversation using the `conversation_id` parameter, or they can pass in their own conversation history using the `chat_history` parameter. -// -// The endpoint features additional parameters such as [connectors](https://docs.cohere.com/docs/connectors) and `documents` that enable conversations enriched by external knowledge. We call this ["Retrieval Augmented Generation"](https://docs.cohere.com/docs/retrieval-augmented-generation-rag), or "RAG". For a full breakdown of the Chat API endpoint, document and connector modes, and streaming (with code samples), see [this guide](https://docs.cohere.com/docs/cochat-beta). +// Generates a text response to a user message. 
+// To learn how to use Chat with Streaming and RAG follow [this guide](https://docs.cohere.com/docs/cochat-beta#various-ways-of-using-the-chat-endpoint). func (c *Client) ChatStream( ctx context.Context, request *v2.ChatStreamRequest, @@ -54,34 +56,53 @@ func (c *Client) ChatStream( ) (*core.Stream[v2.StreamedChatResponse], error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/chat" + endpointURL := baseURL + "/" + "chat" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + streamer := core.NewStreamer[v2.StreamedChatResponse](c.caller) return streamer.Stream( ctx, &core.StreamParams{ - URL: endpointURL, - Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Request: request, + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + ErrorDecoder: errorDecoder, }, ) } -// The `chat` endpoint allows users to have conversations with a Large Language Model (LLM) from Cohere. Users can send messages as part of a persisted conversation using the `conversation_id` parameter, or they can pass in their own conversation history using the `chat_history` parameter. 
-// -// The endpoint features additional parameters such as [connectors](https://docs.cohere.com/docs/connectors) and `documents` that enable conversations enriched by external knowledge. We call this ["Retrieval Augmented Generation"](https://docs.cohere.com/docs/retrieval-augmented-generation-rag), or "RAG". For a full breakdown of the Chat API endpoint, document and connector modes, and streaming (with code samples), see [this guide](https://docs.cohere.com/docs/cochat-beta). +// Generates a text response to a user message. +// To learn how to use Chat with Streaming and RAG follow [this guide](https://docs.cohere.com/docs/cochat-beta#various-ways-of-using-the-chat-endpoint). func (c *Client) Chat( ctx context.Context, request *v2.ChatRequest, @@ -89,28 +110,48 @@ func (c *Client) Chat( ) (*v2.NonStreamedChatResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/chat" + endpointURL := baseURL + "/" + "chat" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.NonStreamedChatResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Request: request, - Response: &response, + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: 
options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -126,14 +167,14 @@ func (c *Client) GenerateStream( ) (*core.Stream[v2.GenerateStreamedResponse], error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/generate" + endpointURL := baseURL + "/" + "generate" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -152,6 +193,13 @@ func (c *Client) GenerateStream( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -186,14 +234,14 @@ func (c *Client) Generate( ) (*v2.Generation, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/generate" + endpointURL := baseURL + "/" + "generate" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -212,6 +260,13 @@ func (c *Client) Generate( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -254,14 +309,14 @@ func (c *Client) Embed( ) (*v2.EmbedResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/embed" + endpointURL := baseURL + "/" + "embed" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -280,6 +335,13 @@ func (c *Client) Embed( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -318,28 +380,48 @@ func (c *Client) Rerank( ) (*v2.RerankResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/rerank" + endpointURL := baseURL + "/" + "rerank" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.RerankResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Request: request, - Response: &response, + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + Response: &response, + 
ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -356,14 +438,14 @@ func (c *Client) Classify( ) (*v2.ClassifyResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/classify" + endpointURL := baseURL + "/" + "classify" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -382,6 +464,13 @@ func (c *Client) Classify( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -420,28 +509,48 @@ func (c *Client) Summarize( ) (*v2.SummarizeResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/summarize" + endpointURL := baseURL + "/" + "summarize" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.SummarizeResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Request: 
request, - Response: &response, + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -457,14 +566,14 @@ func (c *Client) Tokenize( ) (*v2.TokenizeResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/tokenize" + endpointURL := baseURL + "/" + "tokenize" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -483,6 +592,13 @@ func (c *Client) Tokenize( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -521,28 +637,48 @@ func (c *Client) Detokenize( ) (*v2.DetokenizeResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/detokenize" + endpointURL := baseURL + "/" + "detokenize" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.DetokenizeResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Request: request, - Response: &response, + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err diff --git a/connectors/client.go b/connectors/client.go index 724402d..135dfd9 100644 --- a/connectors/client.go +++ b/connectors/client.go @@ -43,14 +43,14 @@ func (c *Client) List( ) (*v2.ListConnectorsResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/connectors" + endpointURL := baseURL + "/" + "connectors" queryParams, err := core.QueryValues(request) if err != nil { @@ -77,6 +77,13 @@ func (c *Client) List( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -114,14 +121,14 @@ func (c *Client) Create( ) (*v2.CreateConnectorResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/connectors" + endpointURL := baseURL + "/" + "connectors" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -147,6 +154,13 @@ func (c *Client) Create( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -186,14 +200,14 @@ func (c *Client) Get( ) (*v2.GetConnectorResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -219,6 +233,13 @@ func (c *Client) Get( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -257,14 +278,14 @@ func (c *Client) Delete( ) (v2.DeleteConnectorResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -297,6 +318,13 @@ func (c *Client) Delete( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -336,14 +364,14 @@ func (c *Client) Update( ) (*v2.UpdateConnectorResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -376,6 +404,13 @@ func (c *Client) Update( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -416,14 +451,14 @@ func (c *Client) OAuthAuthorize( ) (*v2.OAuthAuthorizeResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v/oauth/authorize", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v/oauth/authorize", id) queryParams, err := core.QueryValues(request) if err != nil { @@ -457,6 +492,13 @@ func (c *Client) OAuthAuthorize( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError diff --git a/core/query_test.go b/core/query_test.go index 130720f..4f0d392 100644 --- a/core/query_test.go +++ b/core/query_test.go @@ -120,15 +120,27 @@ func TestQueryValues(t *testing.T) { t.Run("date", func(t *testing.T) { type example struct { - DateTime time.Time `json:"dateTime" url:"dateTime" format:"date"` + Date time.Time `json:"date" url:"date" format:"date"` } values, err := QueryValues( &example{ - DateTime: time.Date(1994, 3, 16, 12, 34, 
56, 0, time.UTC), + Date: time.Date(1994, 3, 16, 12, 34, 56, 0, time.UTC), }, ) require.NoError(t, err) - assert.Equal(t, "dateTime=1994-03-16", values.Encode()) + assert.Equal(t, "date=1994-03-16", values.Encode()) + }) + + t.Run("optional time", func(t *testing.T) { + type example struct { + Date *time.Time `json:"date,omitempty" url:"date,omitempty" format:"date"` + } + + values, err := QueryValues( + &example{}, + ) + require.NoError(t, err) + assert.Empty(t, values.Encode()) }) } diff --git a/core/request_option.go b/core/request_option.go index 55389a4..9379a37 100644 --- a/core/request_option.go +++ b/core/request_option.go @@ -56,7 +56,7 @@ func (r *RequestOptions) cloneHeader() http.Header { headers := r.HTTPHeader.Clone() headers.Set("X-Fern-Language", "Go") headers.Set("X-Fern-SDK-Name", "github.com/cohere-ai/cohere-go/v2") - headers.Set("X-Fern-SDK-Version", "v2.5.2") + headers.Set("X-Fern-SDK-Version", "v2.6.0") return headers } diff --git a/core/time.go b/core/time.go new file mode 100644 index 0000000..d009ab3 --- /dev/null +++ b/core/time.go @@ -0,0 +1,137 @@ +package core + +import ( + "encoding/json" + "time" +) + +const dateFormat = "2006-01-02" + +// Date wraps time.Time and adapts its JSON representation +// to conform to a RFC3339 date (e.g. 2006-01-02). +// +// Ref: https://ijmacd.github.io/rfc3339-iso8601 +type Date struct { + t *time.Time +} + +// NewDate returns a new *Date wrapping +// the given time.Time. +func NewDate(t time.Time) *Date { + return &Date{t: &t} +} + +// NewOptionalDate returns a new *Date. If the given time.Time +// is nil, nil will be returned. +func NewOptionalDate(t *time.Time) *Date { + if t == nil { + return nil + } + return &Date{t: t} +} + +// Time returns the Date's underlying time, if any. If the +// date is nil, the zero value is returned. 
+func (d *Date) Time() time.Time { + if d == nil || d.t == nil { + return time.Time{} + } + return *d.t +} + +// TimePtr returns a pointer to the Date's underlying time.Time, if any. +func (d *Date) TimePtr() *time.Time { + if d == nil || d.t == nil { + return nil + } + if d.t.IsZero() { + return nil + } + return d.t +} + +func (d *Date) MarshalJSON() ([]byte, error) { + if d == nil || d.t == nil { + return nil, nil + } + return json.Marshal(d.t.Format(dateFormat)) +} + +func (d *Date) UnmarshalJSON(data []byte) error { + var raw string + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + parsedTime, err := time.Parse(dateFormat, raw) + if err != nil { + return err + } + + *d = Date{t: &parsedTime} + return nil +} + +// DateTime wraps time.Time and adapts its JSON representation +// to conform to a RFC3339 date-time (e.g. 2017-07-21T17:32:28Z). +// +// Ref: https://ijmacd.github.io/rfc3339-iso8601 +type DateTime struct { + t *time.Time +} + +// NewDateTime returns a new *DateTime. +func NewDateTime(t time.Time) *DateTime { + return &DateTime{t: &t} +} + +// NewOptionalDateTime returns a new *DateTime. If the given time.Time +// is nil, nil will be returned. +func NewOptionalDateTime(t *time.Time) *DateTime { + if t == nil { + return nil + } + return &DateTime{t: t} +} + +// Time returns the DateTime's underlying time, if any. If the +// date-time is nil, the zero value is returned. +func (d *DateTime) Time() time.Time { + if d == nil || d.t == nil { + return time.Time{} + } + return *d.t +} + +// TimePtr returns a pointer to the DateTime's underlying time.Time, if any. 
+func (d *DateTime) TimePtr() *time.Time { + if d == nil || d.t == nil { + return nil + } + if d.t.IsZero() { + return nil + } + return d.t +} + +func (d *DateTime) MarshalJSON() ([]byte, error) { + if d == nil || d.t == nil { + return nil, nil + } + return json.Marshal(d.t.Format(time.RFC3339)) +} + +func (d *DateTime) UnmarshalJSON(data []byte) error { + var raw string + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + parsedTime, err := time.Parse(time.RFC3339, raw) + if err != nil { + return err + } + + *d = DateTime{t: &parsedTime} + return nil +} diff --git a/datasets.go b/datasets.go index f881330..4b00c96 100644 --- a/datasets.go +++ b/datasets.go @@ -11,9 +11,9 @@ import ( type DatasetsCreateRequest struct { // The name of the uploaded dataset. - Name *string `json:"-" url:"name,omitempty"` + Name string `json:"-" url:"name"` // The dataset type, which is used to validate the data. - Type *DatasetType `json:"-" url:"type,omitempty"` + Type DatasetType `json:"-" url:"type,omitempty"` // Indicates if the original file should be stored. KeepOriginalFile *bool `json:"-" url:"keep_original_file,omitempty"` // Indicates whether rows with malformed input should be dropped (instead of failing the validation check). Dropped rows will be returned in the warnings field. diff --git a/datasets/client.go b/datasets/client.go index 9d26e51..c1f3639 100644 --- a/datasets/client.go +++ b/datasets/client.go @@ -5,6 +5,8 @@ package datasets import ( bytes "bytes" context "context" + json "encoding/json" + errors "errors" fmt "fmt" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" @@ -42,14 +44,14 @@ func (c *Client) List( ) (*v2.DatasetsListResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/datasets" + endpointURL := baseURL + "/" + "datasets" queryParams, err := core.QueryValues(request) if err != nil { @@ -61,16 +63,36 @@ func (c *Client) List( headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.DatasetsListResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Response: &response, + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -88,14 +110,14 @@ func (c *Client) Create( ) (*v2.DatasetsCreateResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/datasets" + endpointURL := baseURL + "/" + "datasets" queryParams, err := core.QueryValues(request) if err != nil { @@ -107,6 +129,25 @@ func (c *Client) Create( headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.DatasetsCreateResponse requestBuffer := bytes.NewBuffer(nil) writer := multipart.NewWriter(requestBuffer) @@ -121,16 +162,18 @@ func (c *Client) Create( if _, err := io.Copy(dataPart, data); err != nil { return nil, err } - evalDataFilename := "evalData_filename" - if named, ok := evalData.(interface{ Name() string }); ok { - evalDataFilename = named.Name() - } - evalDataPart, err := writer.CreateFormFile("eval_data", evalDataFilename) - if err != nil { - return nil, err - } - if _, err := io.Copy(evalDataPart, evalData); err != nil { - return nil, err + if evalData != nil { + evalDataFilename := "evalData_filename" + if named, ok := evalData.(interface{ Name() string }); ok { + evalDataFilename = named.Name() + } + evalDataPart, err := writer.CreateFormFile("eval_data", evalDataFilename) + if err != nil { + return nil, err + } + if _, err := io.Copy(evalDataPart, evalData); err != nil { + return nil, err + } } if err := writer.Close(); err != nil { return nil, err @@ -140,13 +183,14 @@ func (c *Client) Create( if err := c.caller.Call( ctx, &core.CallParams{ - 
URL: endpointURL, - Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Request: requestBuffer, - Response: &response, + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: requestBuffer, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -161,27 +205,47 @@ func (c *Client) GetUsage( ) (*v2.DatasetsGetUsageResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/datasets/usage" + endpointURL := baseURL + "/" + "datasets/usage" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.DatasetsGetUsageResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Response: &response, + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -197,27 +261,47 @@ func (c *Client) Get( ) (*v2.DatasetsGetResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/datasets/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"datasets/%v", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response *v2.DatasetsGetResponse if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Response: &response, + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err @@ -233,27 +317,47 @@ func (c *Client) Delete( ) (map[string]interface{}, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/datasets/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"datasets/%v", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + var response map[string]interface{} if err := c.caller.Call( ctx, &core.CallParams{ - URL: endpointURL, - Method: http.MethodDelete, - MaxAttempts: options.MaxAttempts, - Headers: headers, - Client: options.HTTPClient, - Response: &response, + URL: endpointURL, + Method: http.MethodDelete, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, }, ); err != nil { return nil, err diff --git a/embedjobs/client.go b/embedjobs/client.go index be9f992..3179e97 100644 --- a/embedjobs/client.go +++ b/embedjobs/client.go @@ -42,14 +42,14 @@ func (c *Client) List( ) (*v2.ListEmbedJobResponse, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/embed-jobs" + endpointURL := baseURL + "/" + "embed-jobs" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -68,6 +68,13 @@ func (c *Client) List( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -105,14 +112,14 @@ func (c *Client) Create( ) (*v2.CreateEmbedJobResponse, error) { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := baseURL + "/" + "v1/embed-jobs" + endpointURL := baseURL + "/" + "embed-jobs" headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -131,6 +138,13 @@ func (c *Client) Create( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -170,14 +184,14 @@ func (c *Client) Get( ) (*v2.EmbedJob, error) { options := core.NewRequestOptions(opts...) 
- baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/embed-jobs/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"embed-jobs/%v", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -203,6 +217,13 @@ func (c *Client) Get( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError @@ -241,14 +262,14 @@ func (c *Client) Cancel( ) error { options := core.NewRequestOptions(opts...) - baseURL := "https://api.cohere.ai" + baseURL := "https://api.cohere.ai/v1" if c.baseURL != "" { baseURL = c.baseURL } if options.BaseURL != "" { baseURL = options.BaseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"v1/embed-jobs/%v/cancel", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"embed-jobs/%v/cancel", id) headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) @@ -274,6 +295,13 @@ func (c *Client) Cancel( return apiError } return value + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value case 500: value := new(v2.InternalServerError) value.APIError = apiError diff --git a/environments.go b/environments.go index 1e42718..56a405b 100644 --- a/environments.go +++ b/environments.go @@ -9,5 +9,5 @@ package api var Environments = struct { Production string }{ - Production: "https://api.cohere.ai", + Production: "https://api.cohere.ai/v1", } diff --git a/errors.go b/errors.go index a49084c..f825312 100644 --- a/errors.go +++ b/errors.go @@ -98,3 +98,26 @@ func (n *NotFoundError) MarshalJSON() ([]byte, error) { func (n *NotFoundError) Unwrap() error { return n.APIError } + 
+type TooManyRequestsError struct { + *core.APIError + Body interface{} +} + +func (t *TooManyRequestsError) UnmarshalJSON(data []byte) error { + var body interface{} + if err := json.Unmarshal(data, &body); err != nil { + return err + } + t.StatusCode = 429 + t.Body = body + return nil +} + +func (t *TooManyRequestsError) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Body) +} + +func (t *TooManyRequestsError) Unwrap() error { + return t.APIError +} diff --git a/models.go b/models.go new file mode 100644 index 0000000..4917b47 --- /dev/null +++ b/models.go @@ -0,0 +1,13 @@ +// This file was auto-generated by Fern from our API Definition. + +package api + +type ModelsListRequest struct { + // Maximum number of models to include in a page + // Defaults to `20`, min value of `1`, max value of `1000`. + PageSize *float64 `json:"-" url:"page_size,omitempty"` + // Page token provided in the `next_page_token` field of a previous response. + PageToken *string `json:"-" url:"page_token,omitempty"` + // When provided, filters the list of models to only those that are compatible with the specified endpoint. + Endpoint *CompatibleEndpoint `json:"-" url:"endpoint,omitempty"` +} diff --git a/models/client.go b/models/client.go new file mode 100644 index 0000000..da9e42f --- /dev/null +++ b/models/client.go @@ -0,0 +1,99 @@ +// This file was auto-generated by Fern from our API Definition. + +package models + +import ( + bytes "bytes" + context "context" + json "encoding/json" + errors "errors" + v2 "github.com/cohere-ai/cohere-go/v2" + core "github.com/cohere-ai/cohere-go/v2/core" + option "github.com/cohere-ai/cohere-go/v2/option" + io "io" + http "net/http" +) + +type Client struct { + baseURL string + caller *core.Caller + header http.Header +} + +func NewClient(opts ...option.RequestOption) *Client { + options := core.NewRequestOptions(opts...) 
+ return &Client{ + baseURL: options.BaseURL, + caller: core.NewCaller( + &core.CallerParams{ + Client: options.HTTPClient, + MaxAttempts: options.MaxAttempts, + }, + ), + header: options.ToHeader(), + } +} + +// Returns a list of models available for use. The list contains models from Cohere as well as your fine-tuned models. +func (c *Client) List( + ctx context.Context, + request *v2.ModelsListRequest, + opts ...option.RequestOption, +) (*v2.ListModelsResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := baseURL + "/" + "models" + + queryParams, err := core.QueryValues(request) + if err != nil { + return nil, err + } + if len(queryParams) > 0 { + endpointURL += "?" + queryParams.Encode() + } + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 429: + value := new(v2.TooManyRequestsError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response *v2.ListModelsResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} diff --git a/pointer.go b/pointer.go index 82fb917..faaf462 100644 --- a/pointer.go +++ b/pointer.go @@ -1,6 +1,10 @@ package api -import "time" +import ( + "time" + + "github.com/google/uuid" +) // Bool returns a pointer to the 
given bool value. func Bool(b bool) *bool { @@ -97,7 +101,32 @@ func Uintptr(u uintptr) *uintptr { return &u } +// UUID returns a pointer to the given uuid.UUID value. +func UUID(u uuid.UUID) *uuid.UUID { + return &u +} + // Time returns a pointer to the given time.Time value. func Time(t time.Time) *time.Time { return &t } + +// MustParseDate attempts to parse the given string as a +// date time.Time, and panics upon failure. +func MustParseDate(date string) time.Time { + t, err := time.Parse("2006-01-02", date) + if err != nil { + panic(err) + } + return t +} + +// MustParseDateTime attempts to parse the given string as a +// datetime time.Time, and panics upon failure. +func MustParseDateTime(datetime string) time.Time { + t, err := time.Parse(time.RFC3339, datetime) + if err != nil { + panic(err) + } + return t +} diff --git a/tests/sdk_test.go b/tests/sdk_test.go index b2a8faf..592dd17 100644 --- a/tests/sdk_test.go +++ b/tests/sdk_test.go @@ -26,6 +26,10 @@ func strPointer(s string) *string { return &s } +func boolPointer(s bool) *bool { + return &s +} + func TestNewClient(t *testing.T) { client := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY"))) @@ -186,10 +190,10 @@ func TestNewClient(t *testing.T) { &cohere.RerankRequest{ Query: "What is the capital of the United States?", Documents: []*cohere.RerankRequestDocumentsItem{ - cohere.NewRerankRequestDocumentsItemFromString("Carson City is the capital city of the American state of Nevada."), - cohere.NewRerankRequestDocumentsItemFromString("The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."), - cohere.NewRerankRequestDocumentsItemFromString("Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. 
It is a federal district."), - cohere.NewRerankRequestDocumentsItemFromString("Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."), + {String: "Carson City is the capital city of the American state of Nevada."}, + {String: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."}, + {String: "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."}, + {String: "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."}, }, }) @@ -218,8 +222,8 @@ func TestNewClient(t *testing.T) { &MyReader{Reader: strings.NewReader(`{"text": "The quick brown fox jumps over the lazy dog"}`), name: "test.jsonl"}, &MyReader{Reader: strings.NewReader(""), name: "a.jsonl"}, &cohere.DatasetsCreateRequest{ - Name: strPointer("prompt-completion-dataset"), - Type: cohere.DatasetTypeEmbedResult.Ptr(), + Name: "prompt-completion-dataset", + Type: cohere.DatasetTypeEmbedResult, }) require.NoError(t, err) @@ -345,4 +349,84 @@ func TestNewClient(t *testing.T) { require.NoError(t, err) print(delete) }) + + t.Run("TestTool", func(t *testing.T) { + tools := []*cohere.Tool{ + { + Name: "sales_database", + Description: "Connects to a database about sales volumes", + ParameterDefinitions: map[string]*cohere.ToolParameterDefinitionsValue{ + "day": { + Description: "Retrieves sales data from this day, formatted as YYYY-MM-DD.", + Type: "str", + Required: boolPointer(true), + }, + }, + }, + } + + toolsResponse, err := client.Chat( + context.TODO(), + &cohere.ChatRequest{ + Message: "How good were the sales on September 29?", + Tools: tools, + Preamble: strPointer(` + ## Task 
Description + You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging. + + ## Style Guide + Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling. + `), + }) + + require.NoError(t, err) + require.NotNil(t, toolsResponse.ToolCalls) + require.Len(t, toolsResponse.ToolCalls, 1) + require.Equal(t, toolsResponse.ToolCalls[0].Name, "sales_database") + require.Equal(t, toolsResponse.ToolCalls[0].Parameters["day"], "2023-09-29") + + print(toolsResponse) + + localTools := map[string]func(string) *[]map[string]interface{}{ + "sales_database": func(day string) *[]map[string]interface{} { + return &[]map[string]interface{}{ + { + "numberOfSales": 120, + "totalRevenue": 48500, + "averageSaleValue": 404.17, + "date": "2023-09-29", + }, + } + }, + } + + toolResults := make([]*cohere.ChatRequestToolResultsItem, 0) + + for _, toolCall := range toolsResponse.ToolCalls { + result := localTools[toolCall.Name](toolCall.Parameters["day"].(string)) + toolResult := &cohere.ChatRequestToolResultsItem{ + Call: toolCall, + Outputs: *result, + } + toolResults = append(toolResults, toolResult) + } + + citedResponse, err := client.Chat( + context.TODO(), + &cohere.ChatRequest{ + Message: "How good were the sales on September 29?", + Tools: tools, + ToolResults: toolResults, + Model: strPointer("command-nightly"), + }) + + require.NoError(t, err) + + require.Equal(t, citedResponse.Documents[0]["averageSaleValue"], "404.17") + require.Equal(t, citedResponse.Documents[0]["date"], "2023-09-29") + require.Equal(t, citedResponse.Documents[0]["numberOfSales"], "120") + require.Equal(t, citedResponse.Documents[0]["totalRevenue"], 
"48500") + + }) + } diff --git a/types.go b/types.go index 5c723b7..f164ab2 100644 --- a/types.go +++ b/types.go @@ -10,28 +10,27 @@ import ( ) type ChatRequest struct { - // Accepts a string. - // The chat message from the user to the model. + // Text input for the model to respond to. Message string `json:"message" url:"message"` // Defaults to `command`. // - // The identifier of the model, which can be one of the existing Cohere models or the full ID for a [fine-tuned custom model](https://docs.cohere.com/docs/chat-fine-tuning). - // - // Compatible Cohere models are `command` and `command-light` as well as the experimental `command-nightly` and `command-light-nightly` variants. Read more about [Cohere models](https://docs.cohere.com/docs/models). + // The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Model *string `json:"model,omitempty" url:"model,omitempty"` - // When specified, the default Cohere preamble will be replaced with the provided one. - PreambleOverride *string `json:"preamble_override,omitempty" url:"preamble_override,omitempty"` + // When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style. + Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. ChatHistory []*ChatMessage `json:"chat_history,omitempty" url:"chat_history,omitempty"` - // An alternative to `chat_history`. Previous conversations can be resumed by providing the conversation's identifier. The contents of `message` and the model's response will be stored as part of this conversation. + // An alternative to `chat_history`. 
// - // If a conversation with this id does not already exist, a new conversation will be created. + // Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string. ConversationId *string `json:"conversation_id,omitempty" url:"conversation_id,omitempty"` // Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. // // Dictates how the prompt will be constructed. // - // With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. + // With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. + // + // With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. // // With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. PromptTruncation *ChatRequestPromptTruncation `json:"prompt_truncation,omitempty" url:"prompt_truncation,omitempty"` @@ -43,12 +42,26 @@ type ChatRequest struct { // // When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. 
SearchQueriesOnly *bool `json:"search_queries_only,omitempty" url:"search_queries_only,omitempty"` - // A list of relevant documents that the model can use to enrich its reply. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. - Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` - // Defaults to `"accurate"`. + // A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. + // + // Example: + // `[ + // + // { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, + // { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." }, + // + // ]` + // + // Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. + // + // Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. // - // Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results or `"fast"` results. - CitationQuality *ChatRequestCitationQuality `json:"citation_quality,omitempty" url:"citation_quality,omitempty"` + // An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. + // + // An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. + // + // See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. 
+ Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` // Defaults to `0.3`. // // A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. @@ -63,11 +76,44 @@ type ChatRequest struct { // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. P *float64 `json:"p,omitempty" url:"p,omitempty"` + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. + // // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. FrequencyPenalty *float64 `json:"frequency_penalty,omitempty" url:"frequency_penalty,omitempty"` - // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. + // + // Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. PresencePenalty *float64 `json:"presence_penalty,omitempty" url:"presence_penalty,omitempty"` - stream bool + // When enabled, the user's prompt will be sent to the model without any pre-processing. + RawPrompting *bool `json:"raw_prompting,omitempty" url:"raw_prompting,omitempty"` + // A list of available tools (functions) that the model may suggest invoking before producing a text response. 
+ // + // When `tools` is passed, The `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made + // the `tool_calls` array will be empty. + Tools []*Tool `json:"tools,omitempty" url:"tools,omitempty"` + // A list of results from invoking tools. Results are used to generate text and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. + // Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. + // + // ``` + // tool_results = [ + // + // { + // "call": { + // "name": , + // "parameters": { + // : + // } + // }, + // "outputs": [{ + // : + // }] + // }, + // ... + // + // ] + // ``` + ToolResults []*ChatRequestToolResultsItem `json:"tool_results,omitempty" url:"tool_results,omitempty"` + stream bool } func (c *ChatRequest) Stream() bool { @@ -98,28 +144,27 @@ func (c *ChatRequest) MarshalJSON() ([]byte, error) { } type ChatStreamRequest struct { - // Accepts a string. - // The chat message from the user to the model. + // Text input for the model to respond to. Message string `json:"message" url:"message"` // Defaults to `command`. // - // The identifier of the model, which can be one of the existing Cohere models or the full ID for a [fine-tuned custom model](https://docs.cohere.com/docs/chat-fine-tuning). - // - // Compatible Cohere models are `command` and `command-light` as well as the experimental `command-nightly` and `command-light-nightly` variants. Read more about [Cohere models](https://docs.cohere.com/docs/models). + // The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Model *string `json:"model,omitempty" url:"model,omitempty"` - // When specified, the default Cohere preamble will be replaced with the provided one. 
- PreambleOverride *string `json:"preamble_override,omitempty" url:"preamble_override,omitempty"` + // When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style. + Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. ChatHistory []*ChatMessage `json:"chat_history,omitempty" url:"chat_history,omitempty"` - // An alternative to `chat_history`. Previous conversations can be resumed by providing the conversation's identifier. The contents of `message` and the model's response will be stored as part of this conversation. + // An alternative to `chat_history`. // - // If a conversation with this id does not already exist, a new conversation will be created. + // Providing a `conversation_id` creates or resumes a persisted conversation with the specified ID. The ID can be any non empty string. ConversationId *string `json:"conversation_id,omitempty" url:"conversation_id,omitempty"` // Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. // // Dictates how the prompt will be constructed. // - // With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. + // With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. 
+ // + // With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. // // With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. PromptTruncation *ChatStreamRequestPromptTruncation `json:"prompt_truncation,omitempty" url:"prompt_truncation,omitempty"` @@ -131,12 +176,26 @@ type ChatStreamRequest struct { // // When `true`, the response will only contain a list of generated search queries, but no search will take place, and no reply from the model to the user's `message` will be generated. SearchQueriesOnly *bool `json:"search_queries_only,omitempty" url:"search_queries_only,omitempty"` - // A list of relevant documents that the model can use to enrich its reply. See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. - Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` - // Defaults to `"accurate"`. + // A list of relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. // - // Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results or `"fast"` results. - CitationQuality *ChatStreamRequestCitationQuality `json:"citation_quality,omitempty" url:"citation_quality,omitempty"` + // Example: + // `[ + // + // { "title": "Tall penguins", "text": "Emperor penguins are the tallest." }, + // { "title": "Penguin habitats", "text": "Emperor penguins only live in Antarctica." 
}, + // + // ]` + // + // Keys and values from each document will be serialized to a string and passed to the model. The resulting generation will include citations that reference some of these documents. + // + // Some suggested keys are "text", "author", and "date". For better generation quality, it is recommended to keep the total word count of the strings in the dictionary to under 300 words. + // + // An `id` field (string) can be optionally supplied to identify the document in the citations. This field will not be passed to the model. + // + // An `_excludes` field (array of strings) can be optionally supplied to omit some key-value pairs from being shown to the model. The omitted fields will still show up in the citation object. The "_excludes" field will not be passed to the model. + // + // See ['Document Mode'](https://docs.cohere.com/docs/retrieval-augmented-generation-rag#document-mode) in the guide for more information. + Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` // Defaults to `0.3`. // // A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. @@ -151,11 +210,44 @@ type ChatStreamRequest struct { // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. P *float64 `json:"p,omitempty" url:"p,omitempty"` + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. + // // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. 
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty" url:"frequency_penalty,omitempty"` - // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. + // + // Used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. PresencePenalty *float64 `json:"presence_penalty,omitempty" url:"presence_penalty,omitempty"` - stream bool + // When enabled, the user's prompt will be sent to the model without any pre-processing. + RawPrompting *bool `json:"raw_prompting,omitempty" url:"raw_prompting,omitempty"` + // A list of available tools (functions) that the model may suggest invoking before producing a text response. + // + // When `tools` is passed, The `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made + // the `tool_calls` array will be empty. + Tools []*Tool `json:"tools,omitempty" url:"tools,omitempty"` + // A list of results from invoking tools. Results are used to generate text and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. + // Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. + // + // ``` + // tool_results = [ + // + // { + // "call": { + // "name": , + // "parameters": { + // : + // } + // }, + // "outputs": [{ + // : + // }] + // }, + // ... 
+ // + // ] + // ``` + ToolResults []*ChatStreamRequestToolResultsItem `json:"tool_results,omitempty" url:"tool_results,omitempty"` + stream bool } func (c *ChatStreamRequest) Stream() bool { @@ -298,12 +390,6 @@ type GenerateRequest struct { // // If `ALL` is selected, the token likelihoods will be provided both for the prompt and the generated text. ReturnLikelihoods *GenerateRequestReturnLikelihoods `json:"return_likelihoods,omitempty" url:"return_likelihoods,omitempty"` - // Certain models support the `logit_bias` parameter. - // - // Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is `{token_id: bias}` where bias is a float between -10 and 10. Tokens can be obtained from text using [Tokenize](/reference/tokenize). - // - // For example, if the value `{'11': -10}` is provided, the model will be very unlikely to include the token 11 (`"\n"`, the newline character) anywhere in the generated text. In contrast `{'11': 10}` will result in generations that nearly only contain that token. Values between -10 and 10 will proportionally affect the likelihood of the token appearing in the generated text. - LogitBias map[string]float64 `json:"logit_bias,omitempty" url:"logit_bias,omitempty"` // When enabled, the user's prompt will be sent to the model without any pre-processing. RawPrompting *bool `json:"raw_prompting,omitempty" url:"raw_prompting,omitempty"` stream bool @@ -389,12 +475,6 @@ type GenerateStreamRequest struct { // // If `ALL` is selected, the token likelihoods will be provided both for the prompt and the generated text. ReturnLikelihoods *GenerateStreamRequestReturnLikelihoods `json:"return_likelihoods,omitempty" url:"return_likelihoods,omitempty"` - // Certain models support the `logit_bias` parameter. - // - // Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. 
The format is `{token_id: bias}` where bias is a float between -10 and 10. Tokens can be obtained from text using [Tokenize](/reference/tokenize). - // - // For example, if the value `{'11': -10}` is provided, the model will be very unlikely to include the token 11 (`"\n"`, the newline character) anywhere in the generated text. In contrast `{'11': 10}` will result in generations that nearly only contain that token. Values between -10 and 10 will proportionally affect the likelihood of the token appearing in the generated text. - LogitBias map[string]float64 `json:"logit_bias,omitempty" url:"logit_bias,omitempty"` // When enabled, the user's prompt will be sent to the model without any pre-processing. RawPrompting *bool `json:"raw_prompting,omitempty" url:"raw_prompting,omitempty"` stream bool @@ -667,21 +747,15 @@ func (c *ChatCitationGenerationEvent) String() string { type ChatConnector struct { // The identifier of the connector. Id string `json:"id" url:"id"` - // An optional override to set the token that Cohere passes to the connector in the Authorization header. + // When specified, this user access token will be passed to the connector in the Authorization header instead of the Cohere generated one. UserAccessToken *string `json:"user_access_token,omitempty" url:"user_access_token,omitempty"` - // An optional override to set whether or not the request continues if this connector fails. + // Defaults to `false`. + // + // When `true`, the request will continue if this connector returned an error. ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"` // Provides the connector with different settings at request time. The key/value pairs of this object are specific to each connector. // - // The supported options are: - // - // **web-search** - // - // **site** - The web search results will be restricted to this domain (and TLD) when specified. Only a single domain is specified, and subdomains are also accepted. 
- // Examples: - // - // - `{"options": {"site": "cohere.com"}}` would restrict the results to all subdomains at cohere.com - // - `{"options": {"site": "txt.cohere.com"}}` would restrict the results to `txt.cohere.com` + // For example, the connector `web-search` supports the `site` option, which limits search results to the specified domain. Options map[string]interface{} `json:"options,omitempty" url:"options,omitempty"` _rawJSON json.RawMessage @@ -710,17 +784,56 @@ func (c *ChatConnector) String() string { return fmt.Sprintf("%#v", c) } +type ChatDataMetrics struct { + // The sum of all turns of valid train examples. + NumTrainTurns *string `json:"numTrainTurns,omitempty" url:"numTrainTurns,omitempty"` + // The sum of all turns of valid eval examples. + NumEvalTurns *string `json:"numEvalTurns,omitempty" url:"numEvalTurns,omitempty"` + // The preamble of this dataset. + Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatDataMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler ChatDataMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatDataMetrics(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatDataMetrics) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + // Relevant information that could be used by the model to generate a more accurate reply. // The contents of each document are generally short (under 300 words), and are passed in the form of a // dictionary of strings. Some suggested keys are "text", "author", "date". Both the key name and the value will be // passed to the model. type ChatDocument = map[string]string -// A single message in a chat history. 
Contains the role of the sender, the text contents of the message, and optionally a username. +// A single message in a chat history. Contains the role of the sender, the text contents of the message. type ChatMessage struct { - Role ChatMessageRole `json:"role,omitempty" url:"role,omitempty"` - Message string `json:"message" url:"message"` - UserName *string `json:"user_name,omitempty" url:"user_name,omitempty"` + // One of CHATBOT|USER to identify who the message is coming from. + Role ChatMessageRole `json:"role,omitempty" url:"role,omitempty"` + // Contents of the chat message. + Message string `json:"message" url:"message"` + // Unique identifier for the generated reply. Useful for submitting feedback. + GenerationId *string `json:"generation_id,omitempty" url:"generation_id,omitempty"` + // Unique identifier for the response. + ResponseId *string `json:"response_id,omitempty" url:"response_id,omitempty"` _rawJSON json.RawMessage } @@ -748,6 +861,7 @@ func (c *ChatMessage) String() string { return fmt.Sprintf("%#v", c) } +// One of CHATBOT|USER to identify who the message is coming from. type ChatMessageRole string const ( @@ -795,6 +909,41 @@ func (c ChatRequestCitationQuality) Ptr() *ChatRequestCitationQuality { return &c } +// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of connectors_search_options are ignored (such as model= or temperature=). 
+type ChatRequestConnectorsSearchOptions struct { + Model interface{} `json:"model,omitempty" url:"model,omitempty"` + Temperature interface{} `json:"temperature,omitempty" url:"temperature,omitempty"` + MaxTokens interface{} `json:"max_tokens,omitempty" url:"max_tokens,omitempty"` + Preamble interface{} `json:"preamble,omitempty" url:"preamble,omitempty"` + // If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinsim cannot be totally guaranteed. + Seed *float64 `json:"seed,omitempty" url:"seed,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatRequestConnectorsSearchOptions) UnmarshalJSON(data []byte) error { + type unmarshaler ChatRequestConnectorsSearchOptions + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatRequestConnectorsSearchOptions(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatRequestConnectorsSearchOptions) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + // (internal) Overrides specified parts of the default Chat or RAG preamble. It is recommended that these options only be used in specific scenarios where the defaults are not adequate. type ChatRequestPromptOverride struct { Preamble interface{} `json:"preamble,omitempty" url:"preamble,omitempty"` @@ -831,14 +980,17 @@ func (c *ChatRequestPromptOverride) String() string { // // Dictates how the prompt will be constructed. // -// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. 
+// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. +// +// With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. // // With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. type ChatRequestPromptTruncation string const ( - ChatRequestPromptTruncationOff ChatRequestPromptTruncation = "OFF" - ChatRequestPromptTruncationAuto ChatRequestPromptTruncation = "AUTO" + ChatRequestPromptTruncationOff ChatRequestPromptTruncation = "OFF" + ChatRequestPromptTruncationAuto ChatRequestPromptTruncation = "AUTO" + ChatRequestPromptTruncationAutoPreserveOrder ChatRequestPromptTruncation = "AUTO_PRESERVE_ORDER" ) func NewChatRequestPromptTruncationFromString(s string) (ChatRequestPromptTruncation, error) { @@ -847,6 +999,8 @@ func NewChatRequestPromptTruncationFromString(s string) (ChatRequestPromptTrunca return ChatRequestPromptTruncationOff, nil case "AUTO": return ChatRequestPromptTruncationAuto, nil + case "AUTO_PRESERVE_ORDER": + return ChatRequestPromptTruncationAutoPreserveOrder, nil } var t ChatRequestPromptTruncation return "", fmt.Errorf("%s is not a valid %T", s, t) @@ -856,28 +1010,25 @@ func (c ChatRequestPromptTruncation) Ptr() *ChatRequestPromptTruncation { return &c } -// (internal) Sets inference and model options for RAG search query and tool use generations. 
Defaults are used when options are not specified here, meaning that other parameters outside of search_options are ignored (such as model= or temperature=). -type ChatRequestSearchOptions struct { - Model interface{} `json:"model,omitempty" url:"model,omitempty"` - Temperature interface{} `json:"temperature,omitempty" url:"temperature,omitempty"` - MaxTokens interface{} `json:"max_tokens,omitempty" url:"max_tokens,omitempty"` - Preamble interface{} `json:"preamble,omitempty" url:"preamble,omitempty"` +type ChatRequestToolResultsItem struct { + Call *ToolCall `json:"call,omitempty" url:"call,omitempty"` + Outputs []map[string]interface{} `json:"outputs,omitempty" url:"outputs,omitempty"` _rawJSON json.RawMessage } -func (c *ChatRequestSearchOptions) UnmarshalJSON(data []byte) error { - type unmarshaler ChatRequestSearchOptions +func (c *ChatRequestToolResultsItem) UnmarshalJSON(data []byte) error { + type unmarshaler ChatRequestToolResultsItem var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatRequestSearchOptions(value) + *c = ChatRequestToolResultsItem(value) c._rawJSON = json.RawMessage(data) return nil } -func (c *ChatRequestSearchOptions) String() string { +func (c *ChatRequestToolResultsItem) String() string { if len(c._rawJSON) > 0 { if value, err := core.StringifyJSON(c._rawJSON); err == nil { return value @@ -1056,7 +1207,7 @@ type ChatStreamEndEvent struct { // - `ERROR_TOXIC` - the model generated a reply that was deemed toxic FinishReason ChatStreamEndEventFinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` // The consolidated response from the model. Contains the generated reply and all the other information streamed back in the previous events. 
- Response *ChatStreamEndEventResponse `json:"response,omitempty" url:"response,omitempty"` + Response *NonStreamedChatResponse `json:"response,omitempty" url:"response,omitempty"` _rawJSON json.RawMessage } @@ -1120,64 +1271,6 @@ func (c ChatStreamEndEventFinishReason) Ptr() *ChatStreamEndEventFinishReason { return &c } -// The consolidated response from the model. Contains the generated reply and all the other information streamed back in the previous events. -type ChatStreamEndEventResponse struct { - typeName string - NonStreamedChatResponse *NonStreamedChatResponse - SearchQueriesOnlyResponse *SearchQueriesOnlyResponse -} - -func NewChatStreamEndEventResponseFromNonStreamedChatResponse(value *NonStreamedChatResponse) *ChatStreamEndEventResponse { - return &ChatStreamEndEventResponse{typeName: "nonStreamedChatResponse", NonStreamedChatResponse: value} -} - -func NewChatStreamEndEventResponseFromSearchQueriesOnlyResponse(value *SearchQueriesOnlyResponse) *ChatStreamEndEventResponse { - return &ChatStreamEndEventResponse{typeName: "searchQueriesOnlyResponse", SearchQueriesOnlyResponse: value} -} - -func (c *ChatStreamEndEventResponse) UnmarshalJSON(data []byte) error { - valueNonStreamedChatResponse := new(NonStreamedChatResponse) - if err := json.Unmarshal(data, &valueNonStreamedChatResponse); err == nil { - c.typeName = "nonStreamedChatResponse" - c.NonStreamedChatResponse = valueNonStreamedChatResponse - return nil - } - valueSearchQueriesOnlyResponse := new(SearchQueriesOnlyResponse) - if err := json.Unmarshal(data, &valueSearchQueriesOnlyResponse); err == nil { - c.typeName = "searchQueriesOnlyResponse" - c.SearchQueriesOnlyResponse = valueSearchQueriesOnlyResponse - return nil - } - return fmt.Errorf("%s cannot be deserialized as a %T", data, c) -} - -func (c ChatStreamEndEventResponse) MarshalJSON() ([]byte, error) { - switch c.typeName { - default: - return nil, fmt.Errorf("invalid type %s in %T", c.typeName, c) - case "nonStreamedChatResponse": - return 
json.Marshal(c.NonStreamedChatResponse) - case "searchQueriesOnlyResponse": - return json.Marshal(c.SearchQueriesOnlyResponse) - } -} - -type ChatStreamEndEventResponseVisitor interface { - VisitNonStreamedChatResponse(*NonStreamedChatResponse) error - VisitSearchQueriesOnlyResponse(*SearchQueriesOnlyResponse) error -} - -func (c *ChatStreamEndEventResponse) Accept(visitor ChatStreamEndEventResponseVisitor) error { - switch c.typeName { - default: - return fmt.Errorf("invalid type %s in %T", c.typeName, c) - case "nonStreamedChatResponse": - return visitor.VisitNonStreamedChatResponse(c.NonStreamedChatResponse) - case "searchQueriesOnlyResponse": - return visitor.VisitSearchQueriesOnlyResponse(c.SearchQueriesOnlyResponse) - } -} - type ChatStreamEvent struct { _rawJSON json.RawMessage } @@ -1230,6 +1323,41 @@ func (c ChatStreamRequestCitationQuality) Ptr() *ChatStreamRequestCitationQualit return &c } +// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of connectors_search_options are ignored (such as model= or temperature=). +type ChatStreamRequestConnectorsSearchOptions struct { + Model interface{} `json:"model,omitempty" url:"model,omitempty"` + Temperature interface{} `json:"temperature,omitempty" url:"temperature,omitempty"` + MaxTokens interface{} `json:"max_tokens,omitempty" url:"max_tokens,omitempty"` + Preamble interface{} `json:"preamble,omitempty" url:"preamble,omitempty"` + // If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinsim cannot be totally guaranteed. 
+ Seed *float64 `json:"seed,omitempty" url:"seed,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatStreamRequestConnectorsSearchOptions) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamRequestConnectorsSearchOptions + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatStreamRequestConnectorsSearchOptions(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatStreamRequestConnectorsSearchOptions) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + // (internal) Overrides specified parts of the default Chat or RAG preamble. It is recommended that these options only be used in specific scenarios where the defaults are not adequate. type ChatStreamRequestPromptOverride struct { Preamble interface{} `json:"preamble,omitempty" url:"preamble,omitempty"` @@ -1266,14 +1394,17 @@ func (c *ChatStreamRequestPromptOverride) String() string { // // Dictates how the prompt will be constructed. // -// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. +// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. +// +// With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. 
During this process the order of the documents and chat history will be preserved as they are inputted into the API. // // With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. type ChatStreamRequestPromptTruncation string const ( - ChatStreamRequestPromptTruncationOff ChatStreamRequestPromptTruncation = "OFF" - ChatStreamRequestPromptTruncationAuto ChatStreamRequestPromptTruncation = "AUTO" + ChatStreamRequestPromptTruncationOff ChatStreamRequestPromptTruncation = "OFF" + ChatStreamRequestPromptTruncationAuto ChatStreamRequestPromptTruncation = "AUTO" + ChatStreamRequestPromptTruncationAutoPreserveOrder ChatStreamRequestPromptTruncation = "AUTO_PRESERVE_ORDER" ) func NewChatStreamRequestPromptTruncationFromString(s string) (ChatStreamRequestPromptTruncation, error) { @@ -1282,6 +1413,8 @@ func NewChatStreamRequestPromptTruncationFromString(s string) (ChatStreamRequest return ChatStreamRequestPromptTruncationOff, nil case "AUTO": return ChatStreamRequestPromptTruncationAuto, nil + case "AUTO_PRESERVE_ORDER": + return ChatStreamRequestPromptTruncationAutoPreserveOrder, nil } var t ChatStreamRequestPromptTruncation return "", fmt.Errorf("%s is not a valid %T", s, t) @@ -1291,28 +1424,25 @@ func (c ChatStreamRequestPromptTruncation) Ptr() *ChatStreamRequestPromptTruncat return &c } -// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of search_options are ignored (such as model= or temperature=). 
-type ChatStreamRequestSearchOptions struct { - Model interface{} `json:"model,omitempty" url:"model,omitempty"` - Temperature interface{} `json:"temperature,omitempty" url:"temperature,omitempty"` - MaxTokens interface{} `json:"max_tokens,omitempty" url:"max_tokens,omitempty"` - Preamble interface{} `json:"preamble,omitempty" url:"preamble,omitempty"` +type ChatStreamRequestToolResultsItem struct { + Call *ToolCall `json:"call,omitempty" url:"call,omitempty"` + Outputs []map[string]interface{} `json:"outputs,omitempty" url:"outputs,omitempty"` _rawJSON json.RawMessage } -func (c *ChatStreamRequestSearchOptions) UnmarshalJSON(data []byte) error { - type unmarshaler ChatStreamRequestSearchOptions +func (c *ChatStreamRequestToolResultsItem) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamRequestToolResultsItem var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatStreamRequestSearchOptions(value) + *c = ChatStreamRequestToolResultsItem(value) c._rawJSON = json.RawMessage(data) return nil } -func (c *ChatStreamRequestSearchOptions) String() string { +func (c *ChatStreamRequestToolResultsItem) String() string { if len(c._rawJSON) > 0 { if value, err := core.StringifyJSON(c._rawJSON); err == nil { return value @@ -1384,6 +1514,64 @@ func (c *ChatTextGenerationEvent) String() string { return fmt.Sprintf("%#v", c) } +type ChatToolCallsGenerationEvent struct { + ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatToolCallsGenerationEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallsGenerationEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallsGenerationEvent(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallsGenerationEvent) String() string { + if len(c._rawJSON) > 0 { + if value, err := 
core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ClassifyDataMetrics struct { + LabelMetrics []*LabelMetric `json:"labelMetrics,omitempty" url:"labelMetrics,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ClassifyDataMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyDataMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ClassifyDataMetrics(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ClassifyDataMetrics) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + type ClassifyExample struct { Text *string `json:"text,omitempty" url:"text,omitempty"` Label *string `json:"label,omitempty" url:"label,omitempty"` @@ -1568,6 +1756,44 @@ func (c *ClassifyResponseClassificationsItemLabelsValue) String() string { return fmt.Sprintf("%#v", c) } +// One of the Cohere API endpoints that the model can be used with. 
+type CompatibleEndpoint string + +const ( + CompatibleEndpointChat CompatibleEndpoint = "chat" + CompatibleEndpointEmbed CompatibleEndpoint = "embed" + CompatibleEndpointClassify CompatibleEndpoint = "classify" + CompatibleEndpointSummarize CompatibleEndpoint = "summarize" + CompatibleEndpointRerank CompatibleEndpoint = "rerank" + CompatibleEndpointRate CompatibleEndpoint = "rate" + CompatibleEndpointGenerate CompatibleEndpoint = "generate" +) + +func NewCompatibleEndpointFromString(s string) (CompatibleEndpoint, error) { + switch s { + case "chat": + return CompatibleEndpointChat, nil + case "embed": + return CompatibleEndpointEmbed, nil + case "classify": + return CompatibleEndpointClassify, nil + case "summarize": + return CompatibleEndpointSummarize, nil + case "rerank": + return CompatibleEndpointRerank, nil + case "rate": + return CompatibleEndpointRate, nil + case "generate": + return CompatibleEndpointGenerate, nil + } + var t CompatibleEndpoint + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c CompatibleEndpoint) Ptr() *CompatibleEndpoint { + return &c +} + // A connector allows you to integrate data sources with the '/chat' endpoint to create grounded generations with citations to the data source. // documents to help answer users. 
type Connector struct { @@ -1604,16 +1830,38 @@ type Connector struct { } func (c *Connector) UnmarshalJSON(data []byte) error { - type unmarshaler Connector - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { + type embed Connector + var unmarshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at"` + UpdatedAt *core.DateTime `json:"updated_at"` + }{ + embed: embed(*c), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { return err } - *c = Connector(value) + *c = Connector(unmarshaler.embed) + c.CreatedAt = unmarshaler.CreatedAt.Time() + c.UpdatedAt = unmarshaler.UpdatedAt.Time() c._rawJSON = json.RawMessage(data) return nil } +func (c *Connector) MarshalJSON() ([]byte, error) { + type embed Connector + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at"` + UpdatedAt *core.DateTime `json:"updated_at"` + }{ + embed: embed(*c), + CreatedAt: core.NewDateTime(c.CreatedAt), + UpdatedAt: core.NewDateTime(c.UpdatedAt), + } + return json.Marshal(marshaler) +} + func (c *Connector) String() string { if len(c._rawJSON) > 0 { if value, err := core.StringifyJSON(c._rawJSON); err == nil { @@ -1842,16 +2090,38 @@ type Dataset struct { } func (d *Dataset) UnmarshalJSON(data []byte) error { - type unmarshaler Dataset - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { + type embed Dataset + var unmarshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at"` + UpdatedAt *core.DateTime `json:"updated_at"` + }{ + embed: embed(*d), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { return err } - *d = Dataset(value) + *d = Dataset(unmarshaler.embed) + d.CreatedAt = unmarshaler.CreatedAt.Time() + d.UpdatedAt = unmarshaler.UpdatedAt.Time() d._rawJSON = json.RawMessage(data) return nil } +func (d *Dataset) MarshalJSON() ([]byte, error) { + type embed Dataset + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at"` + UpdatedAt 
*core.DateTime `json:"updated_at"` + }{ + embed: embed(*d), + CreatedAt: core.NewDateTime(d.CreatedAt), + UpdatedAt: core.NewDateTime(d.UpdatedAt), + } + return json.Marshal(marshaler) +} + func (d *Dataset) String() string { if len(d._rawJSON) > 0 { if value, err := core.StringifyJSON(d._rawJSON); err == nil { @@ -2182,16 +2452,34 @@ type EmbedJob struct { } func (e *EmbedJob) UnmarshalJSON(data []byte) error { - type unmarshaler EmbedJob - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { + type embed EmbedJob + var unmarshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at"` + }{ + embed: embed(*e), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { return err } - *e = EmbedJob(value) + *e = EmbedJob(unmarshaler.embed) + e.CreatedAt = unmarshaler.CreatedAt.Time() e._rawJSON = json.RawMessage(data) return nil } +func (e *EmbedJob) MarshalJSON() ([]byte, error) { + type embed EmbedJob + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at"` + }{ + embed: embed(*e), + CreatedAt: core.NewDateTime(e.CreatedAt), + } + return json.Marshal(marshaler) +} + func (e *EmbedJob) String() string { if len(e._rawJSON) > 0 { if value, err := core.StringifyJSON(e._rawJSON); err == nil { @@ -2326,14 +2614,6 @@ type EmbedResponse struct { EmbeddingsByType *EmbedByTypeResponse } -func NewEmbedResponseFromEmbeddingsFloats(value *EmbedFloatsResponse) *EmbedResponse { - return &EmbedResponse{ResponseType: "embeddings_floats", EmbeddingsFloats: value} -} - -func NewEmbedResponseFromEmbeddingsByType(value *EmbedByTypeResponse) *EmbedResponse { - return &EmbedResponse{ResponseType: "embeddings_by_type", EmbeddingsByType: value} -} - func (e *EmbedResponse) UnmarshalJSON(data []byte) error { var unmarshaler struct { ResponseType string `json:"response_type"` @@ -2360,28 +2640,27 @@ func (e *EmbedResponse) UnmarshalJSON(data []byte) error { } func (e EmbedResponse) MarshalJSON() ([]byte, error) { - switch 
e.ResponseType { - default: - return nil, fmt.Errorf("invalid type %s in %T", e.ResponseType, e) - case "embeddings_floats": + if e.EmbeddingsFloats != nil { var marshaler = struct { ResponseType string `json:"response_type"` *EmbedFloatsResponse }{ - ResponseType: e.ResponseType, + ResponseType: "embeddings_floats", EmbedFloatsResponse: e.EmbeddingsFloats, } return json.Marshal(marshaler) - case "embeddings_by_type": + } + if e.EmbeddingsByType != nil { var marshaler = struct { ResponseType string `json:"response_type"` *EmbedByTypeResponse }{ - ResponseType: e.ResponseType, + ResponseType: "embeddings_by_type", EmbedByTypeResponse: e.EmbeddingsByType, } return json.Marshal(marshaler) } + return nil, fmt.Errorf("type %T does not define a non-empty union type", e) } type EmbedResponseVisitor interface { @@ -2390,14 +2669,53 @@ type EmbedResponseVisitor interface { } func (e *EmbedResponse) Accept(visitor EmbedResponseVisitor) error { - switch e.ResponseType { - default: - return fmt.Errorf("invalid type %s in %T", e.ResponseType, e) - case "embeddings_floats": + if e.EmbeddingsFloats != nil { return visitor.VisitEmbeddingsFloats(e.EmbeddingsFloats) - case "embeddings_by_type": + } + if e.EmbeddingsByType != nil { return visitor.VisitEmbeddingsByType(e.EmbeddingsByType) } + return fmt.Errorf("type %T does not define a non-empty union type", e) +} + +type FinetuneDatasetMetrics struct { + // The number of tokens of valid examples that can be used for training. + TrainableTokenCount *string `json:"trainableTokenCount,omitempty" url:"trainableTokenCount,omitempty"` + // The overall number of examples. + TotalExamples *string `json:"totalExamples,omitempty" url:"totalExamples,omitempty"` + // The number of training examples. + TrainExamples *string `json:"trainExamples,omitempty" url:"trainExamples,omitempty"` + // The size in bytes of all training examples. 
+ TrainSizeBytes *string `json:"trainSizeBytes,omitempty" url:"trainSizeBytes,omitempty"` + // Number of evaluation examples. + EvalExamples *string `json:"evalExamples,omitempty" url:"evalExamples,omitempty"` + // The size in bytes of all eval examples. + EvalSizeBytes *string `json:"evalSizeBytes,omitempty" url:"evalSizeBytes,omitempty"` + + _rawJSON json.RawMessage +} + +func (f *FinetuneDatasetMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler FinetuneDatasetMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *f = FinetuneDatasetMetrics(value) + f._rawJSON = json.RawMessage(data) + return nil +} + +func (f *FinetuneDatasetMetrics) String() string { + if len(f._rawJSON) > 0 { + if value, err := core.StringifyJSON(f._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(f); err == nil { + return value + } + return fmt.Sprintf("%#v", f) } type FinishReason string @@ -2718,18 +3036,6 @@ type GenerateStreamedResponse struct { StreamError *GenerateStreamError } -func NewGenerateStreamedResponseFromTextGeneration(value *GenerateStreamText) *GenerateStreamedResponse { - return &GenerateStreamedResponse{EventType: "text-generation", TextGeneration: value} -} - -func NewGenerateStreamedResponseFromStreamEnd(value *GenerateStreamEnd) *GenerateStreamedResponse { - return &GenerateStreamedResponse{EventType: "stream-end", StreamEnd: value} -} - -func NewGenerateStreamedResponseFromStreamError(value *GenerateStreamError) *GenerateStreamedResponse { - return &GenerateStreamedResponse{EventType: "stream-error", StreamError: value} -} - func (g *GenerateStreamedResponse) UnmarshalJSON(data []byte) error { var unmarshaler struct { EventType string `json:"event_type"` @@ -2762,37 +3068,37 @@ func (g *GenerateStreamedResponse) UnmarshalJSON(data []byte) error { } func (g GenerateStreamedResponse) MarshalJSON() ([]byte, error) { - switch g.EventType { - default: - return nil, 
fmt.Errorf("invalid type %s in %T", g.EventType, g) - case "text-generation": + if g.TextGeneration != nil { var marshaler = struct { EventType string `json:"event_type"` *GenerateStreamText }{ - EventType: g.EventType, + EventType: "text-generation", GenerateStreamText: g.TextGeneration, } return json.Marshal(marshaler) - case "stream-end": + } + if g.StreamEnd != nil { var marshaler = struct { EventType string `json:"event_type"` *GenerateStreamEnd }{ - EventType: g.EventType, + EventType: "stream-end", GenerateStreamEnd: g.StreamEnd, } return json.Marshal(marshaler) - case "stream-error": + } + if g.StreamError != nil { var marshaler = struct { EventType string `json:"event_type"` *GenerateStreamError }{ - EventType: g.EventType, + EventType: "stream-error", GenerateStreamError: g.StreamError, } return json.Marshal(marshaler) } + return nil, fmt.Errorf("type %T does not define a non-empty union type", g) } type GenerateStreamedResponseVisitor interface { @@ -2802,16 +3108,16 @@ type GenerateStreamedResponseVisitor interface { } func (g *GenerateStreamedResponse) Accept(visitor GenerateStreamedResponseVisitor) error { - switch g.EventType { - default: - return fmt.Errorf("invalid type %s in %T", g.EventType, g) - case "text-generation": + if g.TextGeneration != nil { return visitor.VisitTextGeneration(g.TextGeneration) - case "stream-end": + } + if g.StreamEnd != nil { return visitor.VisitStreamEnd(g.StreamEnd) - case "stream-error": + } + if g.StreamError != nil { return visitor.VisitStreamError(g.StreamError) } + return fmt.Errorf("type %T does not define a non-empty union type", g) } type Generation struct { @@ -2877,6 +3183,40 @@ func (g *GetConnectorResponse) String() string { return fmt.Sprintf("%#v", g) } +type LabelMetric struct { + // Total number of examples for this label + TotalExamples *string `json:"totalExamples,omitempty" url:"totalExamples,omitempty"` + // value of the label + Label *string `json:"label,omitempty" url:"label,omitempty"` + // 
samples for this label + Samples []string `json:"samples,omitempty" url:"samples,omitempty"` + + _rawJSON json.RawMessage +} + +func (l *LabelMetric) UnmarshalJSON(data []byte) error { + type unmarshaler LabelMetric + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = LabelMetric(value) + l._rawJSON = json.RawMessage(data) + return nil +} + +func (l *LabelMetric) String() string { + if len(l._rawJSON) > 0 { + if value, err := core.StringifyJSON(l._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + type ListConnectorsResponse struct { Connectors []*Connector `json:"connectors,omitempty" url:"connectors,omitempty"` // Total number of connectors. @@ -2937,19 +3277,126 @@ func (l *ListEmbedJobResponse) String() string { return fmt.Sprintf("%#v", l) } +type ListModelsResponse struct { + Models []*Model `json:"models,omitempty" url:"models,omitempty"` + // A token to retrieve the next page of results. Provide in the page_token parameter of the next request. 
+ NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` + + _rawJSON json.RawMessage +} + +func (l *ListModelsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListModelsResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListModelsResponse(value) + l._rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListModelsResponse) String() string { + if len(l._rawJSON) > 0 { + if value, err := core.StringifyJSON(l._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +type Metrics struct { + FinetuneDatasetMetrics *FinetuneDatasetMetrics `json:"finetune_dataset_metrics,omitempty" url:"finetune_dataset_metrics,omitempty"` + + _rawJSON json.RawMessage +} + +func (m *Metrics) UnmarshalJSON(data []byte) error { + type unmarshaler Metrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *m = Metrics(value) + m._rawJSON = json.RawMessage(data) + return nil +} + +func (m *Metrics) String() string { + if len(m._rawJSON) > 0 { + if value, err := core.StringifyJSON(m._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(m); err == nil { + return value + } + return fmt.Sprintf("%#v", m) +} + +// Contains information about the model and which API endpoints it can be used with. +type Model struct { + // Specify this name in the `model` parameter of API requests to use your chosen model. + Name *string `json:"name,omitempty" url:"name,omitempty"` + // The API endpoints that the model is compatible with. + Endpoints []CompatibleEndpoint `json:"endpoints,omitempty" url:"endpoints,omitempty"` + // Whether the model has been fine-tuned or not. + Finetuned *bool `json:"finetuned,omitempty" url:"finetuned,omitempty"` + // The maximum number of tokens that the model can process in a single request. 
Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default. + ContextLength *float64 `json:"context_length,omitempty" url:"context_length,omitempty"` + // The name of the tokenizer used for the model. + Tokenizer *string `json:"tokenizer,omitempty" url:"tokenizer,omitempty"` + // Public URL to the tokenizer's configuration file. + TokenizerUrl *string `json:"tokenizer_url,omitempty" url:"tokenizer_url,omitempty"` + + _rawJSON json.RawMessage +} + +func (m *Model) UnmarshalJSON(data []byte) error { + type unmarshaler Model + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *m = Model(value) + m._rawJSON = json.RawMessage(data) + return nil +} + +func (m *Model) String() string { + if len(m._rawJSON) > 0 { + if value, err := core.StringifyJSON(m._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(m); err == nil { + return value + } + return fmt.Sprintf("%#v", m) +} + type NonStreamedChatResponse struct { // Contents of the reply generated by the model. Text string `json:"text" url:"text"` // Unique identifier for the generated reply. Useful for submitting feedback. - GenerationId string `json:"generation_id" url:"generation_id"` + GenerationId *string `json:"generation_id,omitempty" url:"generation_id,omitempty"` // Inline citations for the generated reply. Citations []*ChatCitation `json:"citations,omitempty" url:"citations,omitempty"` // Documents seen by the model when generating the reply. Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` + // Denotes that a search for documents is required during the RAG flow. + IsSearchRequired *bool `json:"is_search_required,omitempty" url:"is_search_required,omitempty"` // Generated search queries, meant to be used as part of the RAG flow. 
SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty" url:"search_queries,omitempty"` // Documents retrieved from each of the conducted searches. SearchResults []*ChatSearchResult `json:"search_results,omitempty" url:"search_results,omitempty"` + FinishReason *FinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` + ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. + ChatHistory []*ChatMessage `json:"chat_history,omitempty" url:"chat_history,omitempty"` _rawJSON json.RawMessage } @@ -3038,29 +3485,18 @@ func (p *ParseInfo) String() string { } type RerankRequestDocumentsItem struct { - typeName string String string RerankRequestDocumentsItemText *RerankRequestDocumentsItemText } -func NewRerankRequestDocumentsItemFromString(value string) *RerankRequestDocumentsItem { - return &RerankRequestDocumentsItem{typeName: "string", String: value} -} - -func NewRerankRequestDocumentsItemFromRerankRequestDocumentsItemText(value *RerankRequestDocumentsItemText) *RerankRequestDocumentsItem { - return &RerankRequestDocumentsItem{typeName: "rerankRequestDocumentsItemText", RerankRequestDocumentsItemText: value} -} - func (r *RerankRequestDocumentsItem) UnmarshalJSON(data []byte) error { var valueString string if err := json.Unmarshal(data, &valueString); err == nil { - r.typeName = "string" r.String = valueString return nil } valueRerankRequestDocumentsItemText := new(RerankRequestDocumentsItemText) if err := json.Unmarshal(data, &valueRerankRequestDocumentsItemText); err == nil { - r.typeName = "rerankRequestDocumentsItemText" r.RerankRequestDocumentsItemText = valueRerankRequestDocumentsItemText return nil } @@ -3068,14 +3504,13 @@ func (r *RerankRequestDocumentsItem) UnmarshalJSON(data []byte) error { } func (r RerankRequestDocumentsItem) MarshalJSON() ([]byte, error) 
{ - switch r.typeName { - default: - return nil, fmt.Errorf("invalid type %s in %T", r.typeName, r) - case "string": + if r.String != "" { return json.Marshal(r.String) - case "rerankRequestDocumentsItemText": + } + if r.RerankRequestDocumentsItemText != nil { return json.Marshal(r.RerankRequestDocumentsItemText) } + return nil, fmt.Errorf("type %T does not include a non-empty union type", r) } type RerankRequestDocumentsItemVisitor interface { @@ -3084,14 +3519,13 @@ type RerankRequestDocumentsItemVisitor interface { } func (r *RerankRequestDocumentsItem) Accept(visitor RerankRequestDocumentsItemVisitor) error { - switch r.typeName { - default: - return fmt.Errorf("invalid type %s in %T", r.typeName, r) - case "string": + if r.String != "" { return visitor.VisitString(r.String) - case "rerankRequestDocumentsItemText": + } + if r.RerankRequestDocumentsItemText != nil { return visitor.VisitRerankRequestDocumentsItemText(r.RerankRequestDocumentsItemText) } + return fmt.Errorf("type %T does not include a non-empty union type", r) } type RerankRequestDocumentsItemText struct { @@ -3221,34 +3655,44 @@ func (r *RerankResponseResultsItemDocument) String() string { return fmt.Sprintf("%#v", r) } -type SearchQueriesOnlyResponse struct { - // Generated search queries, meant to be used as part of the RAG flow. - SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty" url:"search_queries,omitempty"` +type RerankerDataMetrics struct { + // The number of training queries. + NumTrainQueries *string `json:"numTrainQueries,omitempty" url:"numTrainQueries,omitempty"` + // The sum of all relevant passages of valid training examples. + NumTrainRelevantPassages *string `json:"numTrainRelevantPassages,omitempty" url:"numTrainRelevantPassages,omitempty"` + // The sum of all hard negatives of valid training examples. + NumTrainHardNegatives *string `json:"numTrainHardNegatives,omitempty" url:"numTrainHardNegatives,omitempty"` + // The number of evaluation queries. 
+ NumEvalQueries *string `json:"numEvalQueries,omitempty" url:"numEvalQueries,omitempty"` + // The sum of all relevant passages of valid eval examples. + NumEvalRelevantPassages *string `json:"numEvalRelevantPassages,omitempty" url:"numEvalRelevantPassages,omitempty"` + // The sum of all hard negatives of valid eval examples. + NumEvalHardNegatives *string `json:"numEvalHardNegatives,omitempty" url:"numEvalHardNegatives,omitempty"` _rawJSON json.RawMessage } -func (s *SearchQueriesOnlyResponse) UnmarshalJSON(data []byte) error { - type unmarshaler SearchQueriesOnlyResponse +func (r *RerankerDataMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler RerankerDataMetrics var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *s = SearchQueriesOnlyResponse(value) - s._rawJSON = json.RawMessage(data) + *r = RerankerDataMetrics(value) + r._rawJSON = json.RawMessage(data) return nil } -func (s *SearchQueriesOnlyResponse) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { +func (r *RerankerDataMetrics) String() string { + if len(r._rawJSON) > 0 { + if value, err := core.StringifyJSON(r._rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(s); err == nil { + if value, err := core.StringifyJSON(r); err == nil { return value } - return fmt.Sprintf("%#v", s) + return fmt.Sprintf("%#v", r) } type SingleGeneration struct { @@ -3358,33 +3802,10 @@ type StreamedChatResponse struct { SearchResults *ChatSearchResultsEvent TextGeneration *ChatTextGenerationEvent CitationGeneration *ChatCitationGenerationEvent + ToolCallsGeneration *ChatToolCallsGenerationEvent StreamEnd *ChatStreamEndEvent } -func NewStreamedChatResponseFromStreamStart(value *ChatStreamStartEvent) *StreamedChatResponse { - return &StreamedChatResponse{EventType: "stream-start", StreamStart: value} -} - -func NewStreamedChatResponseFromSearchQueriesGeneration(value 
*ChatSearchQueriesGenerationEvent) *StreamedChatResponse { - return &StreamedChatResponse{EventType: "search-queries-generation", SearchQueriesGeneration: value} -} - -func NewStreamedChatResponseFromSearchResults(value *ChatSearchResultsEvent) *StreamedChatResponse { - return &StreamedChatResponse{EventType: "search-results", SearchResults: value} -} - -func NewStreamedChatResponseFromTextGeneration(value *ChatTextGenerationEvent) *StreamedChatResponse { - return &StreamedChatResponse{EventType: "text-generation", TextGeneration: value} -} - -func NewStreamedChatResponseFromCitationGeneration(value *ChatCitationGenerationEvent) *StreamedChatResponse { - return &StreamedChatResponse{EventType: "citation-generation", CitationGeneration: value} -} - -func NewStreamedChatResponseFromStreamEnd(value *ChatStreamEndEvent) *StreamedChatResponse { - return &StreamedChatResponse{EventType: "stream-end", StreamEnd: value} -} - func (s *StreamedChatResponse) UnmarshalJSON(data []byte) error { var unmarshaler struct { EventType string `json:"event_type"` @@ -3424,6 +3845,12 @@ func (s *StreamedChatResponse) UnmarshalJSON(data []byte) error { return err } s.CitationGeneration = value + case "tool-calls-generation": + value := new(ChatToolCallsGenerationEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolCallsGeneration = value case "stream-end": value := new(ChatStreamEndEvent) if err := json.Unmarshal(data, &value); err != nil { @@ -3435,64 +3862,77 @@ func (s *StreamedChatResponse) UnmarshalJSON(data []byte) error { } func (s StreamedChatResponse) MarshalJSON() ([]byte, error) { - switch s.EventType { - default: - return nil, fmt.Errorf("invalid type %s in %T", s.EventType, s) - case "stream-start": + if s.StreamStart != nil { var marshaler = struct { EventType string `json:"event_type"` *ChatStreamStartEvent }{ - EventType: s.EventType, + EventType: "stream-start", ChatStreamStartEvent: s.StreamStart, } return json.Marshal(marshaler) - case 
"search-queries-generation": + } + if s.SearchQueriesGeneration != nil { var marshaler = struct { EventType string `json:"event_type"` *ChatSearchQueriesGenerationEvent }{ - EventType: s.EventType, + EventType: "search-queries-generation", ChatSearchQueriesGenerationEvent: s.SearchQueriesGeneration, } return json.Marshal(marshaler) - case "search-results": + } + if s.SearchResults != nil { var marshaler = struct { EventType string `json:"event_type"` *ChatSearchResultsEvent }{ - EventType: s.EventType, + EventType: "search-results", ChatSearchResultsEvent: s.SearchResults, } return json.Marshal(marshaler) - case "text-generation": + } + if s.TextGeneration != nil { var marshaler = struct { EventType string `json:"event_type"` *ChatTextGenerationEvent }{ - EventType: s.EventType, + EventType: "text-generation", ChatTextGenerationEvent: s.TextGeneration, } return json.Marshal(marshaler) - case "citation-generation": + } + if s.CitationGeneration != nil { var marshaler = struct { EventType string `json:"event_type"` *ChatCitationGenerationEvent }{ - EventType: s.EventType, + EventType: "citation-generation", ChatCitationGenerationEvent: s.CitationGeneration, } return json.Marshal(marshaler) - case "stream-end": + } + if s.ToolCallsGeneration != nil { + var marshaler = struct { + EventType string `json:"event_type"` + *ChatToolCallsGenerationEvent + }{ + EventType: "tool-calls-generation", + ChatToolCallsGenerationEvent: s.ToolCallsGeneration, + } + return json.Marshal(marshaler) + } + if s.StreamEnd != nil { var marshaler = struct { EventType string `json:"event_type"` *ChatStreamEndEvent }{ - EventType: s.EventType, + EventType: "stream-end", ChatStreamEndEvent: s.StreamEnd, } return json.Marshal(marshaler) } + return nil, fmt.Errorf("type %T does not define a non-empty union type", s) } type StreamedChatResponseVisitor interface { @@ -3501,26 +3941,33 @@ type StreamedChatResponseVisitor interface { VisitSearchResults(*ChatSearchResultsEvent) error 
VisitTextGeneration(*ChatTextGenerationEvent) error VisitCitationGeneration(*ChatCitationGenerationEvent) error + VisitToolCallsGeneration(*ChatToolCallsGenerationEvent) error VisitStreamEnd(*ChatStreamEndEvent) error } func (s *StreamedChatResponse) Accept(visitor StreamedChatResponseVisitor) error { - switch s.EventType { - default: - return fmt.Errorf("invalid type %s in %T", s.EventType, s) - case "stream-start": + if s.StreamStart != nil { return visitor.VisitStreamStart(s.StreamStart) - case "search-queries-generation": + } + if s.SearchQueriesGeneration != nil { return visitor.VisitSearchQueriesGeneration(s.SearchQueriesGeneration) - case "search-results": + } + if s.SearchResults != nil { return visitor.VisitSearchResults(s.SearchResults) - case "text-generation": + } + if s.TextGeneration != nil { return visitor.VisitTextGeneration(s.TextGeneration) - case "citation-generation": + } + if s.CitationGeneration != nil { return visitor.VisitCitationGeneration(s.CitationGeneration) - case "stream-end": + } + if s.ToolCallsGeneration != nil { + return visitor.VisitToolCallsGeneration(s.ToolCallsGeneration) + } + if s.StreamEnd != nil { return visitor.VisitStreamEnd(s.StreamEnd) } + return fmt.Errorf("type %T does not define a non-empty union type", s) } // One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text. @@ -3663,6 +4110,120 @@ func (t *TokenizeResponse) String() string { return fmt.Sprintf("%#v", t) } +type Tool struct { + // The name of the tool to be called. Valid names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit. 
+ Name string `json:"name" url:"name"` + // The description of what the tool does, the model uses the description to choose when and how to call the function. + Description string `json:"description" url:"description"` + // The input parameters of the tool. Accepts a dictionary where the key is the name of the parameter and the value is the parameter spec. Valid parameter names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit. + // + // ``` + // + // { + // "my_param": { + // "description": , + // "type": , // any python data type, such as 'str', 'bool' + // "required": + // } + // } + // + // ``` + ParameterDefinitions map[string]*ToolParameterDefinitionsValue `json:"parameter_definitions,omitempty" url:"parameter_definitions,omitempty"` + + _rawJSON json.RawMessage +} + +func (t *Tool) UnmarshalJSON(data []byte) error { + type unmarshaler Tool + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = Tool(value) + t._rawJSON = json.RawMessage(data) + return nil +} + +func (t *Tool) String() string { + if len(t._rawJSON) > 0 { + if value, err := core.StringifyJSON(t._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +// Contains the tool calls generated by the model. Use it to invoke your tools. +type ToolCall struct { + // Name of the tool to call. + Name string `json:"name" url:"name"` + // The name and value of the parameters to use when invoking a tool. 
+ Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` + GenerationId string `json:"generation_id" url:"generation_id"` + + _rawJSON json.RawMessage +} + +func (t *ToolCall) UnmarshalJSON(data []byte) error { + type unmarshaler ToolCall + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolCall(value) + t._rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolCall) String() string { + if len(t._rawJSON) > 0 { + if value, err := core.StringifyJSON(t._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +type ToolParameterDefinitionsValue struct { + // The description of the parameter. + Description string `json:"description" url:"description"` + // The type of the parameter. Must be a valid Python type. + Type string `json:"type" url:"type"` + // Denotes whether the parameter is always present (required) or not. Defaults to not required. + Required *bool `json:"required,omitempty" url:"required,omitempty"` + + _rawJSON json.RawMessage +} + +func (t *ToolParameterDefinitionsValue) UnmarshalJSON(data []byte) error { + type unmarshaler ToolParameterDefinitionsValue + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolParameterDefinitionsValue(value) + t._rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolParameterDefinitionsValue) String() string { + if len(t._rawJSON) > 0 { + if value, err := core.StringifyJSON(t._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + type UpdateConnectorResponse struct { Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"`