From 1944f3514b1d2b9804a940bd38b5ead17d904f21 Mon Sep 17 00:00:00 2001 From: fern-api <115122769+fern-api[bot]@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:58:19 +0000 Subject: [PATCH 1/2] SDK regeneration --- client/client.go | 68 +++----- connectors/client.go | 24 +-- core/client_option.go | 2 +- datasets/client.go | 20 +-- embedjobs/client.go | 16 +- environments.go | 2 +- types.go | 389 ++++++++++++++++++++++++++++++++---------- 7 files changed, 352 insertions(+), 169 deletions(-) diff --git a/client/client.go b/client/client.go index d2541b3..425b6f4 100644 --- a/client/client.go +++ b/client/client.go @@ -21,9 +21,9 @@ type Client struct { caller *core.Caller header http.Header + EmbedJobs *embedjobs.Client Datasets *datasets.Client Connectors *connectors.Client - EmbedJobs *embedjobs.Client } func NewClient(opts ...core.ClientOption) *Client { @@ -35,9 +35,9 @@ func NewClient(opts ...core.ClientOption) *Client { baseURL: options.BaseURL, caller: core.NewCaller(options.HTTPClient), header: options.ToHeader(), + EmbedJobs: embedjobs.NewClient(opts...), Datasets: datasets.NewClient(opts...), Connectors: connectors.NewClient(opts...), - EmbedJobs: embedjobs.NewClient(opts...), } } @@ -45,11 +45,11 @@ func NewClient(opts ...core.ClientOption) *Client { // // The endpoint features additional parameters such as [connectors](https://docs.cohere.com/docs/connectors) and `documents` that enable conversations enriched by external knowledge. We call this ["Retrieval Augmented Generation"](https://docs.cohere.com/docs/retrieval-augmented-generation-rag), or "RAG". For a full breakdown of the Chat API endpoint, document and connector modes, and streaming (with code samples), see [this guide](https://docs.cohere.com/docs/cochat-beta). func (c *Client) ChatStream(ctx context.Context, request *v2.ChatStreamRequest) (*core.Stream[v2.StreamedChatResponse], error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "chat" + endpointURL := baseURL + "/" + "v1/chat" streamer := core.NewStreamer[v2.StreamedChatResponse](c.caller) return streamer.Stream( @@ -67,11 +67,11 @@ func (c *Client) ChatStream(ctx context.Context, request *v2.ChatStreamRequest) // // The endpoint features additional parameters such as [connectors](https://docs.cohere.com/docs/connectors) and `documents` that enable conversations enriched by external knowledge. We call this ["Retrieval Augmented Generation"](https://docs.cohere.com/docs/retrieval-augmented-generation-rag), or "RAG". For a full breakdown of the Chat API endpoint, document and connector modes, and streaming (with code samples), see [this guide](https://docs.cohere.com/docs/cochat-beta). func (c *Client) Chat(ctx context.Context, request *v2.ChatRequest) (*v2.NonStreamedChatResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "chat" + endpointURL := baseURL + "/" + "v1/chat" var response *v2.NonStreamedChatResponse if err := c.caller.Call( @@ -91,11 +91,11 @@ func (c *Client) Chat(ctx context.Context, request *v2.ChatRequest) (*v2.NonStre // This endpoint generates realistic text conditioned on a given input. 
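The hunks above move the `/v1` version segment out of the default base URL and into each endpoint path (`v1/chat`, `v1/generate`, ...), so a custom base URL should now be a bare host. A minimal sketch of the caller-side effect; `WithBaseURL` is an assumed option name (only `WithToken` appears in this patch), and the `Text` field on the non-streamed chat response is taken from the released SDK:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	cohere "github.com/cohere-ai/cohere-go/v2"
	client "github.com/cohere-ai/cohere-go/v2/client"
)

func main() {
	co := client.NewClient(
		client.WithToken(os.Getenv("COHERE_API_KEY")),
		// Hypothetical override: pass a bare host. The version segment now lives
		// in each endpoint path ("v1/chat", "v1/generate", ...), so "/v1" must
		// not be repeated here the way it previously had to be.
		// client.WithBaseURL("https://gateway.example.com"),
	)

	resp, err := co.Chat(context.TODO(), &cohere.ChatRequest{
		Message: "Say hello to the regenerated SDK.",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Text)
}
```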
func (c *Client) GenerateStream(ctx context.Context, request *v2.GenerateStreamRequest) (*core.Stream[v2.GenerateStreamedResponse], error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "generate" + endpointURL := baseURL + "/" + "v1/generate" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -138,11 +138,11 @@ func (c *Client) GenerateStream(ctx context.Context, request *v2.GenerateStreamR // This endpoint generates realistic text conditioned on a given input. func (c *Client) Generate(ctx context.Context, request *v2.GenerateRequest) (*v2.Generation, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "generate" + endpointURL := baseURL + "/" + "v1/generate" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -193,11 +193,11 @@ func (c *Client) Generate(ctx context.Context, request *v2.GenerateRequest) (*v2 // // If you want to learn more how to use the embedding model, have a look at the [Semantic Search Guide](/docs/semantic-search). func (c *Client) Embed(ctx context.Context, request *v2.EmbedRequest) (*v2.EmbedResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "embed" + endpointURL := baseURL + "/" + "v1/embed" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -244,11 +244,11 @@ func (c *Client) Embed(ctx context.Context, request *v2.EmbedRequest) (*v2.Embed // This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score. func (c *Client) Rerank(ctx context.Context, request *v2.RerankRequest) (*v2.RerankResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "rerank" + endpointURL := baseURL + "/" + "v1/rerank" var response *v2.RerankResponse if err := c.caller.Call( @@ -269,11 +269,11 @@ func (c *Client) Rerank(ctx context.Context, request *v2.RerankRequest) (*v2.Rer // This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided `examples` of text + label pairs as a reference. // Note: [Fine-tuned models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. func (c *Client) Classify(ctx context.Context, request *v2.ClassifyRequest) (*v2.ClassifyResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "classify" + endpointURL := baseURL + "/" + "v1/classify" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -318,37 +318,13 @@ func (c *Client) Classify(ctx context.Context, request *v2.ClassifyRequest) (*v2 return response, nil } -// This endpoint identifies which language each of the provided texts is written in. 
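All streaming endpoints in this file share the `core.Stream` contract: `Recv` until `io.EOF`, then `Close`. A sketch of draining a generate stream; worth noting that only `io.EOF` marks a clean end, so any other `Recv` error has to stop the loop before the message is dereferenced:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"

	cohere "github.com/cohere-ai/cohere-go/v2"
	client "github.com/cohere-ai/cohere-go/v2/client"
)

func main() {
	co := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY")))

	stream, err := co.GenerateStream(context.TODO(), &cohere.GenerateStreamRequest{
		Prompt: "Cohere is",
	})
	if err != nil {
		log.Fatal(err)
	}
	// Always release the stream when done reading, e.g. with defer.
	defer stream.Close()

	for {
		message, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break // the server is done sending messages
		}
		if err != nil {
			log.Fatal(err) // any non-EOF error is a failure, and message may be nil
		}
		if message.TextGeneration != nil {
			fmt.Print(message.TextGeneration.Text)
		}
	}
}
```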
-func (c *Client) DetectLanguage(ctx context.Context, request *v2.DetectLanguageRequest) (*v2.DetectLanguageResponse, error) { - baseURL := "https://api.cohere.ai/v1" - if c.baseURL != "" { - baseURL = c.baseURL - } - endpointURL := baseURL + "/" + "detect-language" - - var response *v2.DetectLanguageResponse - if err := c.caller.Call( - ctx, - &core.CallParams{ - URL: endpointURL, - Method: http.MethodPost, - Headers: c.header, - Request: request, - Response: &response, - }, - ); err != nil { - return nil, err - } - return response, nil -} - // This endpoint generates a summary in English for a given text. func (c *Client) Summarize(ctx context.Context, request *v2.SummarizeRequest) (*v2.SummarizeResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "summarize" + endpointURL := baseURL + "/" + "v1/summarize" var response *v2.SummarizeResponse if err := c.caller.Call( @@ -368,11 +344,11 @@ func (c *Client) Summarize(ctx context.Context, request *v2.SummarizeRequest) (* // This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page. func (c *Client) Tokenize(ctx context.Context, request *v2.TokenizeRequest) (*v2.TokenizeResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "tokenize" + endpointURL := baseURL + "/" + "v1/tokenize" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -419,11 +395,11 @@ func (c *Client) Tokenize(ctx context.Context, request *v2.TokenizeRequest) (*v2 // This endpoint takes tokens using byte-pair encoding and returns their text representation. To learn more about tokenization and byte pair encoding, see the tokens page. func (c *Client) Detokenize(ctx context.Context, request *v2.DetokenizeRequest) (*v2.DetokenizeResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "detokenize" + endpointURL := baseURL + "/" + "v1/detokenize" var response *v2.DetokenizeResponse if err := c.caller.Call( diff --git a/connectors/client.go b/connectors/client.go index 6cd4249..ed857c7 100644 --- a/connectors/client.go +++ b/connectors/client.go @@ -35,11 +35,11 @@ func NewClient(opts ...core.ClientOption) *Client { // Returns a list of connectors ordered by descending creation date (newer first). See ['Managing your Connector'](https://docs.cohere.com/docs/managing-your-connector) for more information. func (c *Client) List(ctx context.Context, request *v2.ConnectorsListRequest) (*v2.ListConnectorsResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "connectors" + endpointURL := baseURL + "/" + "v1/connectors" queryParams := make(url.Values) if request.Limit != nil { @@ -96,11 +96,11 @@ func (c *Client) List(ctx context.Context, request *v2.ConnectorsListRequest) (* // Creates a new connector. The connector is tested during registration and will cancel registration when the test is unsuccessful. See ['Creating and Deploying a Connector'](https://docs.cohere.com/docs/creating-and-deploying-a-connector) for more information. 
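Stepping back to the client.go hunks above: `Tokenize` and `Detokenize` invert each other, which makes a round trip a cheap smoke test for the new path construction. A sketch reusing the `co` client and imports from the previous example; the `base` model name is borrowed from the e2e tests in the second commit:

```go
// roundTrip encodes text to BPE tokens and back, then checks the result.
func roundTrip(ctx context.Context, co *client.Client, text string) error {
	model := "base" // model name borrowed from the e2e tests below
	tok, err := co.Tokenize(ctx, &cohere.TokenizeRequest{Text: text, Model: &model})
	if err != nil {
		return err
	}
	detok, err := co.Detokenize(ctx, &cohere.DetokenizeRequest{Tokens: tok.Tokens})
	if err != nil {
		return err
	}
	if detok.Text != text {
		return fmt.Errorf("round trip mismatch: got %q, want %q", detok.Text, text)
	}
	return nil
}
```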
func (c *Client) Create(ctx context.Context, request *v2.CreateConnectorRequest) (*v2.CreateConnectorResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "connectors" + endpointURL := baseURL + "/" + "v1/connectors" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -156,11 +156,11 @@ func (c *Client) Create(ctx context.Context, request *v2.CreateConnectorRequest) // // The ID of the connector to retrieve. func (c *Client) Get(ctx context.Context, id string) (*v2.GetConnectorResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v", id) errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -215,11 +215,11 @@ func (c *Client) Get(ctx context.Context, id string) (*v2.GetConnectorResponse, // // The ID of the connector to delete. func (c *Client) Delete(ctx context.Context, id string) (v2.DeleteConnectorResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v", id) errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -281,11 +281,11 @@ func (c *Client) Delete(ctx context.Context, id string) (v2.DeleteConnectorRespo // // The ID of the connector to update. func (c *Client) Update(ctx context.Context, id string, request *v2.UpdateConnectorRequest) (*v2.UpdateConnectorResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v", id) errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -348,11 +348,11 @@ func (c *Client) Update(ctx context.Context, id string, request *v2.UpdateConnec // // The ID of the connector to authorize. func (c *Client) OAuthAuthorize(ctx context.Context, id string, request *v2.ConnectorsOAuthAuthorizeRequest) (*v2.OAuthAuthorizeResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"connectors/%v/oauth/authorize", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/connectors/%v/oauth/authorize", id) queryParams := make(url.Values) if request.AfterTokenRedirect != nil { diff --git a/core/client_option.go b/core/client_option.go index 6a25bed..ae44a05 100644 --- a/core/client_option.go +++ b/core/client_option.go @@ -48,6 +48,6 @@ func (c *ClientOptions) cloneHeader() http.Header { headers := c.HTTPHeader.Clone() headers.Set("X-Fern-Language", "Go") headers.Set("X-Fern-SDK-Name", "github.com/cohere-ai/cohere-go/v2") - headers.Set("X-Fern-SDK-Version", "v2.5.1") + headers.Set("X-Fern-SDK-Version", "v2.5.2") return headers } diff --git a/datasets/client.go b/datasets/client.go index 9aedae0..67fe3a8 100644 --- a/datasets/client.go +++ b/datasets/client.go @@ -35,11 +35,11 @@ func NewClient(opts ...core.ClientOption) *Client { // List datasets that have been created. 
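The OAuth flow gets the same path treatment as the connector CRUD endpoints above. A sketch of starting authorization, again reusing the `co` client; the request shape is exactly as in this patch, while the `RedirectUrl` response field is an assumption since only the response wrapper appears here:

```go
// beginOAuth kicks off the connector OAuth flow and returns the URL to send
// the end user to for consent.
func beginOAuth(ctx context.Context, co *client.Client, connectorID string) (string, error) {
	// AfterTokenRedirect is where the user lands once Cohere has exchanged
	// and stored the access token.
	redirect := "https://app.example.com/connected" // hypothetical callback URL
	oauth, err := co.Connectors.OAuthAuthorize(ctx, connectorID, &cohere.ConnectorsOAuthAuthorizeRequest{
		AfterTokenRedirect: &redirect,
	})
	if err != nil {
		return "", err
	}
	return *oauth.RedirectUrl, nil // assumed field on OAuthAuthorizeResponse
}
```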
func (c *Client) List(ctx context.Context, request *v2.DatasetsListRequest) (*v2.DatasetsListResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "datasets" + endpointURL := baseURL + "/" + "v1/datasets" queryParams := make(url.Values) if request.DatasetType != nil { @@ -78,11 +78,11 @@ func (c *Client) List(ctx context.Context, request *v2.DatasetsListRequest) (*v2 // Create a dataset by uploading a file. See ['Dataset Creation'](https://docs.cohere.com/docs/datasets#dataset-creation) for more information. func (c *Client) Create(ctx context.Context, data io.Reader, evalData io.Reader, request *v2.DatasetsCreateRequest) (*v2.DatasetsCreateResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "datasets" + endpointURL := baseURL + "/" + "v1/datasets" queryParams := make(url.Values) if request.Name != nil { @@ -160,11 +160,11 @@ func (c *Client) Create(ctx context.Context, data io.Reader, evalData io.Reader, // View the dataset storage usage for your Organization. Each Organization can have up to 10GB of storage across all their users. func (c *Client) GetUsage(ctx context.Context) (*v2.DatasetsGetUsageResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "datasets/usage" + endpointURL := baseURL + "/" + "v1/datasets/usage" var response *v2.DatasetsGetUsageResponse if err := c.caller.Call( @@ -183,11 +183,11 @@ func (c *Client) GetUsage(ctx context.Context) (*v2.DatasetsGetUsageResponse, er // Retrieve a dataset by ID. See ['Datasets'](https://docs.cohere.com/docs/datasets) for more information. func (c *Client) Get(ctx context.Context, id string) (*v2.DatasetsGetResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"datasets/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/datasets/%v", id) var response *v2.DatasetsGetResponse if err := c.caller.Call( @@ -206,11 +206,11 @@ func (c *Client) Get(ctx context.Context, id string) (*v2.DatasetsGetResponse, e // Delete a dataset by ID. Datasets are automatically deleted after 30 days, but they can also be deleted manually. func (c *Client) Delete(ctx context.Context, id string) (map[string]interface{}, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"datasets/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/datasets/%v", id) var response map[string]interface{} if err := c.caller.Call( diff --git a/embedjobs/client.go b/embedjobs/client.go index 0ddd31b..de5ccf3 100644 --- a/embedjobs/client.go +++ b/embedjobs/client.go @@ -34,11 +34,11 @@ func NewClient(opts ...core.ClientOption) *Client { // The list embed job endpoint allows users to view all embed jobs history for that specific user. 
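`Datasets.Create` above streams the training and eval files as separate multipart parts. The e2e tests in the second commit wrap `strings.Reader` in a type with a `Name() string` method, which suggests the uploader reads the part filename from that optional interface; that behavior is assumed here (imports as before, plus `io` and `strings`):

```go
// namedReader gives the uploader a filename for the multipart part (assumed behavior).
type namedReader struct {
	io.Reader
	name string
}

func (n *namedReader) Name() string { return n.name }

func uploadDataset(ctx context.Context, co *client.Client) (*cohere.DatasetsCreateResponse, error) {
	data := &namedReader{
		Reader: strings.NewReader(`{"text": "The quick brown fox jumps over the lazy dog"}`),
		name:   "train.jsonl",
	}
	// The e2e tests pass an empty reader when there is no eval split.
	eval := &namedReader{Reader: strings.NewReader(""), name: "eval.jsonl"}

	name := "prompt-completion-dataset"
	return co.Datasets.Create(ctx, data, eval, &cohere.DatasetsCreateRequest{
		Name: &name,
		Type: cohere.DatasetTypeEmbedResult.Ptr(),
	})
}
```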
func (c *Client) List(ctx context.Context) (*v2.ListEmbedJobResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "embed-jobs" + endpointURL := baseURL + "/" + "v1/embed-jobs" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -84,11 +84,11 @@ func (c *Client) List(ctx context.Context) (*v2.ListEmbedJobResponse, error) { // This API launches an async Embed job for a [Dataset](https://docs.cohere.com/docs/datasets) of type `embed-input`. The result of a completed embed job is new Dataset of type `embed-output`, which contains the original text entries and the corresponding embeddings. func (c *Client) Create(ctx context.Context, request *v2.CreateEmbedJobRequest) (*v2.CreateEmbedJobResponse, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := baseURL + "/" + "embed-jobs" + endpointURL := baseURL + "/" + "v1/embed-jobs" errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -137,11 +137,11 @@ func (c *Client) Create(ctx context.Context, request *v2.CreateEmbedJobRequest) // // The ID of the embed job to retrieve. func (c *Client) Get(ctx context.Context, id string) (*v2.EmbedJob, error) { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"embed-jobs/%v", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/embed-jobs/%v", id) errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) @@ -196,11 +196,11 @@ func (c *Client) Get(ctx context.Context, id string) (*v2.EmbedJob, error) { // // The ID of the embed job to cancel. func (c *Client) Cancel(ctx context.Context, id string) error { - baseURL := "https://api.cohere.ai/v1" + baseURL := "https://api.cohere.ai" if c.baseURL != "" { baseURL = c.baseURL } - endpointURL := fmt.Sprintf(baseURL+"/"+"embed-jobs/%v/cancel", id) + endpointURL := fmt.Sprintf(baseURL+"/"+"v1/embed-jobs/%v/cancel", id) errorDecoder := func(statusCode int, body io.Reader) error { raw, err := io.ReadAll(body) diff --git a/environments.go b/environments.go index c2ad074..bb280a3 100644 --- a/environments.go +++ b/environments.go @@ -9,5 +9,5 @@ package api var Environments = struct { Production string }{ - Production: "https://api.cohere.ai/v1", + Production: "https://api.cohere.ai", } diff --git a/types.go b/types.go index 6f08540..3d58a8b 100644 --- a/types.go +++ b/types.go @@ -52,8 +52,22 @@ type ChatRequest struct { // Defaults to `0.3`. // // A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. + // + // Randomness can be further maximized by increasing the value of the `p` parameter. Temperature *float64 `json:"temperature,omitempty"` - stream bool + // The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. + MaxTokens *int `json:"max_tokens,omitempty"` + // Ensures only the top `k` most likely tokens are considered for generation at each step. + // Defaults to `0`, min value of `0`, max value of `500`. 
+ K *int `json:"k,omitempty"` + // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. + // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. + P *float64 `json:"p,omitempty"` + // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + stream bool } func (c *ChatRequest) Stream() bool { @@ -126,8 +140,22 @@ type ChatStreamRequest struct { // Defaults to `0.3`. // // A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations, and higher temperatures mean more random generations. + // + // Randomness can be further maximized by increasing the value of the `p` parameter. Temperature *float64 `json:"temperature,omitempty"` - stream bool + // The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations. + MaxTokens *int `json:"max_tokens,omitempty"` + // Ensures only the top `k` most likely tokens are considered for generation at each step. + // Defaults to `0`, min value of `0`, max value of `500`. + K *int `json:"k,omitempty"` + // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. + // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. + P *float64 `json:"p,omitempty"` + // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + stream bool } func (c *ChatStreamRequest) Stream() bool { @@ -164,7 +192,7 @@ type ClassifyRequest struct { Inputs []string `json:"inputs,omitempty"` // An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as `{text: "...",label: "..."}`. // Note: [Fine-tuned Models](https://docs.cohere.com/docs/classify-fine-tuning) trained on classification examples don't require the `examples` parameter to be passed in explicitly. 
- Examples []*ClassifyRequestExamplesItem `json:"examples,omitempty"` + Examples []*ClassifyExample `json:"examples,omitempty"` // The identifier of the model. Currently available models are `embed-multilingual-v2.0`, `embed-english-light-v2.0`, and `embed-english-v2.0` (default). Smaller "light" models are faster, while larger models will perform better. [Fine-tuned models](https://docs.cohere.com/docs/fine-tuning) can also be supplied with their full ID. Model *string `json:"model,omitempty"` // The ID of a custom playground preset. You can create presets in the [playground](https://dashboard.cohere.ai/playground/classify?model=large). If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters. @@ -175,13 +203,6 @@ type ClassifyRequest struct { Truncate *ClassifyRequestTruncate `json:"truncate,omitempty"` } -type DetectLanguageRequest struct { - // List of strings to run the detection on. - Texts []string `json:"texts,omitempty"` - // The identifier of the model to generate with. - Model *string `json:"model,omitempty"` -} - type DetokenizeRequest struct { // The list of tokens to be detokenized. Tokens []int `json:"tokens,omitempty"` @@ -215,7 +236,7 @@ type EmbedRequest struct { // * `"uint8"`: Use this when you want to get back unsigned int8 embeddings. Valid for only v3 models. // * `"binary"`: Use this when you want to get back signed binary embeddings. Valid for only v3 models. // * `"ubinary"`: Use this when you want to get back unsigned binary embeddings. Valid for only v3 models. - EmbeddingTypes []string `json:"embedding_types,omitempty"` + EmbeddingTypes []EmbedRequestEmbeddingTypesItem `json:"embedding_types,omitempty"` // One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. // // Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. @@ -261,9 +282,15 @@ type GenerateRequest struct { // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. P *float64 `json:"p,omitempty"` - // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.' + // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. + // + // Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. + // + // Can be used to reduce repetitiveness of generated tokens. 
Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + // + // Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. PresencePenalty *float64 `json:"presence_penalty,omitempty"` // One of `GENERATION|ALL|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. // @@ -271,13 +298,15 @@ type GenerateRequest struct { // // If `ALL` is selected, the token likelihoods will be provided both for the prompt and the generated text. ReturnLikelihoods *GenerateRequestReturnLikelihoods `json:"return_likelihoods,omitempty"` + // Certain models support the `logit_bias` parameter. + // // Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is `{token_id: bias}` where bias is a float between -10 and 10. Tokens can be obtained from text using [Tokenize](/reference/tokenize). // // For example, if the value `{'11': -10}` is provided, the model will be very unlikely to include the token 11 (`"\n"`, the newline character) anywhere in the generated text. In contrast `{'11': 10}` will result in generations that nearly only contain that token. Values between -10 and 10 will proportionally affect the likelihood of the token appearing in the generated text. - // - // Note: logit bias may not be supported for all custom models. LogitBias map[string]float64 `json:"logit_bias,omitempty"` - stream bool + // When enabled, the user's prompt will be sent to the model without any pre-processing. + RawPrompting *bool `json:"raw_prompting,omitempty"` + stream bool } func (g *GenerateRequest) Stream() bool { @@ -344,9 +373,15 @@ type GenerateStreamRequest struct { // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. P *float64 `json:"p,omitempty"` - // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.' + // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. + // + // Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. + // + // Can be used to reduce repetitiveness of generated tokens. Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + // + // Using `frequency_penalty` in combination with `presence_penalty` is not supported on newer models. 
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // One of `GENERATION|ALL|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. // @@ -354,13 +389,15 @@ type GenerateStreamRequest struct { // // If `ALL` is selected, the token likelihoods will be provided both for the prompt and the generated text. ReturnLikelihoods *GenerateStreamRequestReturnLikelihoods `json:"return_likelihoods,omitempty"` + // Certain models support the `logit_bias` parameter. + // // Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is `{token_id: bias}` where bias is a float between -10 and 10. Tokens can be obtained from text using [Tokenize](/reference/tokenize). // // For example, if the value `{'11': -10}` is provided, the model will be very unlikely to include the token 11 (`"\n"`, the newline character) anywhere in the generated text. In contrast `{'11': 10}` will result in generations that nearly only contain that token. Values between -10 and 10 will proportionally affect the likelihood of the token appearing in the generated text. - // - // Note: logit bias may not be supported for all custom models. LogitBias map[string]float64 `json:"logit_bias,omitempty"` - stream bool + // When enabled, the user's prompt will be sent to the model without any pre-processing. + RawPrompting *bool `json:"raw_prompting,omitempty"` + stream bool } func (g *GenerateStreamRequest) Stream() bool { @@ -628,7 +665,7 @@ func (c *ChatCitationGenerationEvent) String() string { // The connector used for fetching documents. type ChatConnector struct { - // The identifier of the connector. Currently only 'web-search' is supported. + // The identifier of the connector. Id string `json:"id"` // An optional override to set the token that Cohere passes to the connector in the Authorization header. UserAccessToken *string `json:"user_access_token,omitempty"` @@ -758,6 +795,38 @@ func (c ChatRequestCitationQuality) Ptr() *ChatRequestCitationQuality { return &c } +// (internal) Overrides specified parts of the default Chat or RAG preamble. It is recommended that these options only be used in specific scenarios where the defaults are not adequate. +type ChatRequestPromptOverride struct { + Preamble interface{} `json:"preamble,omitempty"` + TaskDescription interface{} `json:"task_description,omitempty"` + StyleGuide interface{} `json:"style_guide,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatRequestPromptOverride) UnmarshalJSON(data []byte) error { + type unmarshaler ChatRequestPromptOverride + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatRequestPromptOverride(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatRequestPromptOverride) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + // Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. // // Dictates how the prompt will be constructed. @@ -787,6 +856,39 @@ func (c ChatRequestPromptTruncation) Ptr() *ChatRequestPromptTruncation { return &c } +// (internal) Sets inference and model options for RAG search query and tool use generations. 
Defaults are used when options are not specified here, meaning that other parameters outside of search_options are ignored (such as model= or temperature=). +type ChatRequestSearchOptions struct { + Model interface{} `json:"model,omitempty"` + Temperature interface{} `json:"temperature,omitempty"` + MaxTokens interface{} `json:"max_tokens,omitempty"` + Preamble interface{} `json:"preamble,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatRequestSearchOptions) UnmarshalJSON(data []byte) error { + type unmarshaler ChatRequestSearchOptions + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatRequestSearchOptions(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatRequestSearchOptions) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + type ChatSearchQueriesGenerationEvent struct { // Generated search queries, meant to be used as part of the RAG flow. SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty"` @@ -853,7 +955,7 @@ func (c *ChatSearchQuery) String() string { type ChatSearchResult struct { SearchQuery *ChatSearchQuery `json:"search_query,omitempty"` // The connector from which this result comes from. - Connector *ChatConnector `json:"connector,omitempty"` + Connector *ChatSearchResultConnector `json:"connector,omitempty"` // Identifiers of documents found by this search query. DocumentIds []string `json:"document_ids,omitempty"` @@ -883,6 +985,37 @@ func (c *ChatSearchResult) String() string { return fmt.Sprintf("%#v", c) } +// The connector used for fetching documents. +type ChatSearchResultConnector struct { + // The identifier of the connector. + Id string `json:"id"` + + _rawJSON json.RawMessage +} + +func (c *ChatSearchResultConnector) UnmarshalJSON(data []byte) error { + type unmarshaler ChatSearchResultConnector + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatSearchResultConnector(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatSearchResultConnector) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + type ChatSearchResultsEvent struct { // Conducted searches and the ids of documents retrieved from each of them. SearchResults []*ChatSearchResult `json:"search_results,omitempty"` @@ -1097,6 +1230,38 @@ func (c ChatStreamRequestCitationQuality) Ptr() *ChatStreamRequestCitationQualit return &c } +// (internal) Overrides specified parts of the default Chat or RAG preamble. It is recommended that these options only be used in specific scenarios where the defaults are not adequate. 
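These search, citation, and text events surface through the `StreamedChatResponse` union that `ChatStream` yields. A sketch of a connector-backed call reusing the `co` client; the `Connectors` request field and the `SearchResults` union member are assumptions inferred from the types in this patch, while `TextGeneration` is the member the e2e tests actually use:

```go
// ragStream issues a connector-backed chat and prints events as they arrive.
func ragStream(ctx context.Context, co *client.Client) error {
	stream, err := co.ChatStream(ctx, &cohere.ChatStreamRequest{
		Message:    "Who discovered penicillin?",
		Connectors: []*cohere.ChatConnector{{Id: "web-search"}}, // assumed request field
	})
	if err != nil {
		return err
	}
	defer stream.Close()

	for {
		msg, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil // clean end of stream
		}
		if err != nil {
			return err
		}
		switch {
		case msg.SearchResults != nil: // assumed union member for ChatSearchResultsEvent
			for _, r := range msg.SearchResults.SearchResults {
				fmt.Printf("connector %s returned %d documents\n", r.Connector.Id, len(r.DocumentIds))
			}
		case msg.TextGeneration != nil:
			fmt.Print(msg.TextGeneration.Text)
		}
	}
}
```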
+type ChatStreamRequestPromptOverride struct { + Preamble interface{} `json:"preamble,omitempty"` + TaskDescription interface{} `json:"task_description,omitempty"` + StyleGuide interface{} `json:"style_guide,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatStreamRequestPromptOverride) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamRequestPromptOverride + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatStreamRequestPromptOverride(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatStreamRequestPromptOverride) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + // Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. // // Dictates how the prompt will be constructed. @@ -1126,6 +1291,39 @@ func (c ChatStreamRequestPromptTruncation) Ptr() *ChatStreamRequestPromptTruncat return &c } +// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of search_options are ignored (such as model= or temperature=). +type ChatStreamRequestSearchOptions struct { + Model interface{} `json:"model,omitempty"` + Temperature interface{} `json:"temperature,omitempty"` + MaxTokens interface{} `json:"max_tokens,omitempty"` + Preamble interface{} `json:"preamble,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *ChatStreamRequestSearchOptions) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamRequestSearchOptions + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatStreamRequestSearchOptions(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatStreamRequestSearchOptions) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + type ChatStreamStartEvent struct { // Unique identifier for the generated reply. Useful for submitting feedback. GenerationId string `json:"generation_id"` @@ -1186,25 +1384,25 @@ func (c *ChatTextGenerationEvent) String() string { return fmt.Sprintf("%#v", c) } -type ClassifyRequestExamplesItem struct { +type ClassifyExample struct { Text *string `json:"text,omitempty"` Label *string `json:"label,omitempty"` _rawJSON json.RawMessage } -func (c *ClassifyRequestExamplesItem) UnmarshalJSON(data []byte) error { - type unmarshaler ClassifyRequestExamplesItem +func (c *ClassifyExample) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyExample var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ClassifyRequestExamplesItem(value) + *c = ClassifyExample(value) c._rawJSON = json.RawMessage(data) return nil } -func (c *ClassifyRequestExamplesItem) String() string { +func (c *ClassifyExample) String() string { if len(c._rawJSON) > 0 { if value, err := core.StringifyJSON(c._rawJSON); err == nil { return value @@ -1452,6 +1650,10 @@ func (c ConnectorAuthStatus) Ptr() *ConnectorAuthStatus { } type ConnectorOAuth struct { + // The OAuth 2.0 client ID. This field is encrypted at rest. 
+ ClientId *string `json:"client_id,omitempty"` + // The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response. + ClientSecret *string `json:"client_secret,omitempty"` // The OAuth 2.0 /authorize endpoint to use when users authorize the connector. AuthorizeUrl string `json:"authorize_url"` // The OAuth 2.0 /token endpoint to use when users authorize the connector. @@ -1622,7 +1824,9 @@ type Dataset struct { // The creation date CreatedAt time.Time `json:"created_at"` // The last update date - UpdatedAt time.Time `json:"updated_at"` + UpdatedAt time.Time `json:"updated_at"` + DatasetType DatasetType `json:"dataset_type,omitempty"` + ValidationStatus DatasetValidationStatus `json:"validation_status,omitempty"` // Errors found during validation ValidationError *string `json:"validation_error,omitempty"` // the avro schema of the dataset @@ -1783,67 +1987,6 @@ func (d DatasetValidationStatus) Ptr() *DatasetValidationStatus { type DeleteConnectorResponse = map[string]interface{} -type DetectLanguageResponse struct { - // List of languages, one per input text - Results []*DetectLanguageResponseResultsItem `json:"results,omitempty"` - Meta *ApiMeta `json:"meta,omitempty"` - - _rawJSON json.RawMessage -} - -func (d *DetectLanguageResponse) UnmarshalJSON(data []byte) error { - type unmarshaler DetectLanguageResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = DetectLanguageResponse(value) - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *DetectLanguageResponse) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -type DetectLanguageResponseResultsItem struct { - LanguageName *string `json:"language_name,omitempty"` - LanguageCode *string `json:"language_code,omitempty"` - - _rawJSON json.RawMessage -} - -func (d *DetectLanguageResponseResultsItem) UnmarshalJSON(data []byte) error { - type unmarshaler DetectLanguageResponseResultsItem - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = DetectLanguageResponseResultsItem(value) - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *DetectLanguageResponseResultsItem) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - type DetokenizeResponse struct { // A string representing the list of tokens. 
Text string `json:"text"` @@ -2116,6 +2259,37 @@ func (e EmbedJobTruncate) Ptr() *EmbedJobTruncate { return &e } +type EmbedRequestEmbeddingTypesItem string + +const ( + EmbedRequestEmbeddingTypesItemFloat EmbedRequestEmbeddingTypesItem = "float" + EmbedRequestEmbeddingTypesItemInt8 EmbedRequestEmbeddingTypesItem = "int8" + EmbedRequestEmbeddingTypesItemUint8 EmbedRequestEmbeddingTypesItem = "uint8" + EmbedRequestEmbeddingTypesItemBinary EmbedRequestEmbeddingTypesItem = "binary" + EmbedRequestEmbeddingTypesItemUbinary EmbedRequestEmbeddingTypesItem = "ubinary" +) + +func NewEmbedRequestEmbeddingTypesItemFromString(s string) (EmbedRequestEmbeddingTypesItem, error) { + switch s { + case "float": + return EmbedRequestEmbeddingTypesItemFloat, nil + case "int8": + return EmbedRequestEmbeddingTypesItemInt8, nil + case "uint8": + return EmbedRequestEmbeddingTypesItemUint8, nil + case "binary": + return EmbedRequestEmbeddingTypesItemBinary, nil + case "ubinary": + return EmbedRequestEmbeddingTypesItemUbinary, nil + } + var t EmbedRequestEmbeddingTypesItem + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (e EmbedRequestEmbeddingTypesItem) Ptr() *EmbedRequestEmbeddingTypesItem { + return &e +} + // One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. // // Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. @@ -2705,6 +2879,8 @@ func (g *GetConnectorResponse) String() string { type ListConnectorsResponse struct { Connectors []*Connector `json:"connectors,omitempty"` + // Total number of connectors. + TotalCount *float64 `json:"total_count,omitempty"` _rawJSON json.RawMessage } @@ -2831,6 +3007,36 @@ func (o *OAuthAuthorizeResponse) String() string { return fmt.Sprintf("%#v", o) } +type ParseInfo struct { + Separator *string `json:"separator,omitempty"` + Delimiter *string `json:"delimiter,omitempty"` + + _rawJSON json.RawMessage +} + +func (p *ParseInfo) UnmarshalJSON(data []byte) error { + type unmarshaler ParseInfo + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *p = ParseInfo(value) + p._rawJSON = json.RawMessage(data) + return nil +} + +func (p *ParseInfo) String() string { + if len(p._rawJSON) > 0 { + if value, err := core.StringifyJSON(p._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(p); err == nil { + return value + } + return fmt.Sprintf("%#v", p) +} + type RerankRequestDocumentsItem struct { typeName string String string @@ -3085,7 +3291,8 @@ type SingleGenerationInStream struct { // Full text of the generation. Text string `json:"text"` // Refers to the nth generation. Only present when `num_generations` is greater than zero. 
-	Index *int `json:"index,omitempty"`
+	Index        *int         `json:"index,omitempty"`
+	FinishReason FinishReason `json:"finish_reason,omitempty"`
 
 	_rawJSON json.RawMessage
 }
 

From dc5dacc12d2e7c5f219235af2489055a8f25ee02 Mon Sep 17 00:00:00 2001
From: Billy Trend
Date: Wed, 7 Feb 2024 13:56:03 -0600
Subject: [PATCH 2/2] Add e2e tests

---
 .fernignore               |   3 +-
 .github/workflows/e2e.yml |  25 +++
 tests/sdk_test.go         | 348 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 375 insertions(+), 1 deletion(-)
 create mode 100644 .github/workflows/e2e.yml
 create mode 100644 tests/sdk_test.go

diff --git a/.fernignore b/.fernignore
index a1ebaab..8e010eb 100644
--- a/.fernignore
+++ b/.fernignore
@@ -1,4 +1,5 @@
 # Specify files that shouldn't be modified by Fern
 README.md
 banner.png
-LICENSE
\ No newline at end of file
+LICENSE
+.github/workflows/e2e.yml
\ No newline at end of file
diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml
new file mode 100644
index 0000000..51477be
--- /dev/null
+++ b/.github/workflows/e2e.yml
@@ -0,0 +1,25 @@
+name: CI
+on:
+  pull_request: {}
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Setup Go 1.x.x
+        uses: actions/setup-go@v4
+        with:
+          go-version: 1.x.x
+      - name: Install testing dependencies here so we don't have to edit the go.mod file
+        run: |
+          go get .
+          go get golang.org/x/tools/go/pointer@v0.1.0-deprecated
+          go get golang.org/x/sys@v0.8.0
+          go get golang.org/x/tools@v0.9.2-0.20230531220058-a260315e300a
+      - name: Build
+        run: go build -v ./...
+      - name: Test with the Go CLI
+        run: go test -v ./...
+        env:
+          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
+
diff --git a/tests/sdk_test.go b/tests/sdk_test.go
new file mode 100644
index 0000000..b2a8faf
--- /dev/null
+++ b/tests/sdk_test.go
@@ -0,0 +1,348 @@
+package tests
+
+import (
+	"context"
+	"errors"
+	"io"
+	"os"
+	"strings"
+	"testing"
+
+	cohere "github.com/cohere-ai/cohere-go/v2"
+	client "github.com/cohere-ai/cohere-go/v2/client"
+	"github.com/stretchr/testify/require"
+)
+
+type MyReader struct {
+	io.Reader
+	name string
+}
+
+func (m *MyReader) Name() string {
+	return m.name
+}
+
+func strPointer(s string) *string {
+	return &s
+}
+
+func TestNewClient(t *testing.T) {
+	client := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY")))
+
+	t.Run("TestGenerate", func(t *testing.T) {
+		prediction, err := client.Generate(
+			context.TODO(),
+			&cohere.GenerateRequest{
+				Prompt: "count with me!",
+			},
+		)
+
+		require.NoError(t, err)
+		print(prediction)
+	})
+
+	t.Run("TestGenerateStream", func(t *testing.T) {
+		stream, err := client.GenerateStream(
+			context.TODO(),
+			&cohere.GenerateStreamRequest{
+				Prompt: "Cohere is",
+			},
+		)
+
+		require.NoError(t, err)
+
+		// Make sure to close the stream when you're done reading.
+		// This is easily handled with defer.
+		defer stream.Close()
+
+		for {
+			message, err := stream.Recv()
+
+			if errors.Is(err, io.EOF) {
+				// An io.EOF error means the server is done sending messages
+				// and should be treated as a success.
+				break
+			}
+
+			if message.TextGeneration != nil {
+				print(message.TextGeneration.Text)
+			}
+		}
+	})
+
+	// Test Chat
+	t.Run("TestChat", func(t *testing.T) {
+		chat, err := client.Chat(
+			context.TODO(),
+			&cohere.ChatRequest{
+				Message: "2",
+			},
+		)
+
+		require.NoError(t, err)
+		print(chat)
+	})
+
+	// Test ChatStream
+	t.Run("TestChatStream", func(t *testing.T) {
+		stream, err := client.ChatStream(
+			context.TODO(),
+			&cohere.ChatStreamRequest{
+				Message: "Cohere is",
+			},
+		)
+
+		require.NoError(t, err)
+
+		// Make sure to close the stream when you're done reading.
+		// This is easily handled with defer.
+		defer stream.Close()
+
+		for {
+			message, err := stream.Recv()
+
+			if errors.Is(err, io.EOF) {
+				// An io.EOF error means the server is done sending messages
+				// and should be treated as a success.
+				break
+			}
+
+			if message.TextGeneration != nil {
+				print(message.TextGeneration.Text)
+			}
+		}
+	})
+
+	t.Run("TestClassify", func(t *testing.T) {
+		classify, err := client.Classify(
+			context.TODO(),
+			&cohere.ClassifyRequest{
+				Examples: []*cohere.ClassifyExample{
+					{
+						Text:  strPointer("orange"),
+						Label: strPointer("fruit"),
+					},
+					{
+						Text:  strPointer("pear"),
+						Label: strPointer("fruit"),
+					},
+					{
+						Text:  strPointer("lettuce"),
+						Label: strPointer("vegetable"),
+					},
+					{
+						Text:  strPointer("cauliflower"),
+						Label: strPointer("vegetable"),
+					},
+				},
+				Inputs: []string{"Abiu"},
+			},
+		)
+
+		require.NoError(t, err)
+		print(classify)
+	})
+
+	t.Run("TestTokenizeDetokenize", func(t *testing.T) {
+		str := "token mctoken face"
+
+		tokenise, err := client.Tokenize(
+			context.TODO(),
+			&cohere.TokenizeRequest{
+				Text:  str,
+				Model: strPointer("base"),
+			},
+		)
+
+		require.NoError(t, err)
+		print(tokenise)
+
+		detokenise, err := client.Detokenize(
+			context.TODO(),
+			&cohere.DetokenizeRequest{
+				Tokens: tokenise.Tokens,
+			})
+
+		require.NoError(t, err)
+		print(detokenise)
+
+		require.Equal(t, str, detokenise.Text)
+	})
+
+	t.Run("TestSummarize", func(t *testing.T) {
+		summarise, err := client.Summarize(
+			context.TODO(),
+			&cohere.SummarizeRequest{
+				Text: "the quick brown fox jumped over the lazy dog and then the dog jumped over the fox the quick brown fox jumped over the lazy dog the quick brown fox jumped over the lazy dog the quick brown fox jumped over the lazy dog the quick brown fox jumped over the lazy dog",
+			})
+
+		require.NoError(t, err)
+		print(summarise)
+	})
+
+	t.Run("TestRerank", func(t *testing.T) {
+		rerank, err := client.Rerank(
+			context.TODO(),
+			&cohere.RerankRequest{
+				Query: "What is the capital of the United States?",
+				Documents: []*cohere.RerankRequestDocumentsItem{
+					cohere.NewRerankRequestDocumentsItemFromString("Carson City is the capital city of the American state of Nevada."),
+					cohere.NewRerankRequestDocumentsItemFromString("The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."),
+					cohere.NewRerankRequestDocumentsItemFromString("Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."),
+					cohere.NewRerankRequestDocumentsItemFromString("Capital punishment (the death penalty) has existed in the United States since before the United States was a country. 
As of 2017, capital punishment is legal in 30 of the 50 states."), + }, + }) + + require.NoError(t, err) + print(rerank) + }) + + t.Run("TestEmbed", func(t *testing.T) { + embed, err := client.Embed( + context.TODO(), + &cohere.EmbedRequest{ + Texts: []string{"hello", "goodbye"}, + Model: strPointer("embed-english-v3.0"), + InputType: cohere.EmbedInputTypeSearchDocument.Ptr(), + }) + + require.NoError(t, err) + print(embed) + }) + + t.Run("TestCreateDataset", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + + dataset, err := client.Datasets.Create( + context.TODO(), + &MyReader{Reader: strings.NewReader(`{"text": "The quick brown fox jumps over the lazy dog"}`), name: "test.jsonl"}, + &MyReader{Reader: strings.NewReader(""), name: "a.jsonl"}, + &cohere.DatasetsCreateRequest{ + Name: strPointer("prompt-completion-dataset"), + Type: cohere.DatasetTypeEmbedResult.Ptr(), + }) + + require.NoError(t, err) + print(dataset) + }) + + t.Run("TestListDatasets", func(t *testing.T) { + datasets, err := client.Datasets.List( + context.TODO(), + &cohere.DatasetsListRequest{}) + + require.NoError(t, err) + print(datasets) + }) + + t.Run("TestGetDatasetUsage", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + dataset_usage, err := client.Datasets.GetUsage(context.TODO()) + + require.NoError(t, err) + print(dataset_usage) + }) + + t.Run("TestGetDataset", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + dataset, err := client.Datasets.Get(context.TODO(), "id") + + require.NoError(t, err) + print(dataset) + }) + + t.Run("TestUpdateDataset", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + _, err := client.Datasets.Delete(context.TODO(), "id") + require.NoError(t, err) + }) + + t.Run("TestCreateEmbedJob", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + job, err := client.EmbedJobs.Create( + context.TODO(), + &cohere.CreateEmbedJobRequest{ + DatasetId: "id", + InputType: cohere.EmbedInputTypeSearchDocument, + }) + + require.NoError(t, err) + print(job) + }) + + t.Run("TestListEmbedJobs", func(t *testing.T) { + embed_jobs, err := client.EmbedJobs.List(context.TODO()) + + require.NoError(t, err) + print(embed_jobs) + }) + + t.Run("TestGetEmbedJob", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + embed_job, err := client.EmbedJobs.Get(context.TODO(), "id") + + require.NoError(t, err) + print(embed_job) + }) + + t.Run("TestCancelEmbedJob", func(t *testing.T) { + t.Skip("While we have issues with dataset upload") + err := client.EmbedJobs.Cancel(context.TODO(), "id") + + require.NoError(t, err) + }) + + t.Run("TestConnectorCRUD", func(t *testing.T) { + connector, err := client.Connectors.Create( + context.TODO(), + &cohere.CreateConnectorRequest{ + Name: "Example connector", + Url: "https://dummy-connector-o5btz7ucgq-uc.a.run.app/search", + ServiceAuth: &cohere.CreateConnectorServiceAuth{ + Token: "dummy-connector-token", + Type: "bearer", + }, + }) + + require.NoError(t, err) + print(connector) + + updated_connector, err := client.Connectors.Update( + context.TODO(), + connector.Connector.Id, + &cohere.UpdateConnectorRequest{ + Name: strPointer("Example connector renamed"), + }) + + require.NoError(t, err) + print(updated_connector) + + my_connector, err := client.Connectors.Get(context.TODO(), connector.Connector.Id) + + require.NoError(t, err) + print(my_connector) + + connectors, err := client.Connectors.List( + context.TODO(), + 
&cohere.ConnectorsListRequest{}) + + require.NoError(t, err) + print(connectors) + + oauth, err := client.Connectors.OAuthAuthorize( + context.TODO(), + connector.Connector.Id, + &cohere.ConnectorsOAuthAuthorizeRequest{ + AfterTokenRedirect: strPointer("https://test.com"), + }) + + // find a way to test this + require.Error(t, err) + print(oauth) + + delete, err := client.Connectors.Delete(context.TODO(), connector.Connector.Id) + + require.NoError(t, err) + print(delete) + }) +}
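One gap in the suite above: nothing exercises the sampling controls the first commit adds to `ChatRequest` and `ChatStreamRequest`. A sketch of a follow-on subtest that would slot into tests/sdk_test.go, with local pointer helpers since the file only defines `strPointer`; field names and documented ranges are exactly as declared in types.go:

```go
func intPointer(i int) *int             { return &i }
func float64Pointer(f float64) *float64 { return &f }

func TestChatSamplingParams(t *testing.T) {
	co := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY")))

	chat, err := co.Chat(
		context.TODO(),
		&cohere.ChatRequest{
			Message:     "Write one sentence about the sea.",
			Temperature: float64Pointer(0.3),
			MaxTokens:   intPointer(100),     // low values may truncate the reply
			K:           intPointer(40),      // top-k: 0-500, default 0 (disabled)
			P:           float64Pointer(0.9), // nucleus: 0.01-0.99, default 0.75
			// Only one repetition penalty is set; the Generate field docs note
			// that combining both penalties is unsupported on newer models.
			PresencePenalty: float64Pointer(0.2),
		},
	)

	require.NoError(t, err)
	print(chat)
}
```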
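`EmbedRequest.EmbeddingTypes` also changed from `[]string` to the typed `EmbedRequestEmbeddingTypesItem` enum, and no test above covers it. A sketch in the same style; the response is only printed because the per-type response shape is not shown in this patch:

```go
func TestEmbedWithEmbeddingTypes(t *testing.T) {
	co := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY")))

	embed, err := co.Embed(
		context.TODO(),
		&cohere.EmbedRequest{
			Texts:     []string{"hello", "goodbye"},
			Model:     strPointer("embed-english-v3.0"),
			InputType: cohere.EmbedInputTypeSearchDocument.Ptr(),
			// int8/uint8/binary/ubinary are only valid for v3 models, per the docs.
			EmbeddingTypes: []cohere.EmbedRequestEmbeddingTypesItem{
				cohere.EmbedRequestEmbeddingTypesItemFloat,
				cohere.EmbedRequestEmbeddingTypesItemInt8,
			},
		})

	require.NoError(t, err)
	print(embed)
}
```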
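Finally, the regenerated `GenerateRequest` gains `raw_prompting` alongside the documented `logit_bias` behavior, whose `{'11': -10}` example translates directly to Go. A sketch exercising both, with the fields exactly as declared in types.go:

```go
func TestGenerateLogitBiasRawPrompting(t *testing.T) {
	co := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY")))

	raw := true
	gen, err := co.Generate(
		context.TODO(),
		&cohere.GenerateRequest{
			Prompt: "List three colors: ",
			// Bias token 11 ("\n") strongly downward so the model avoids
			// newlines, mirroring the example in the field documentation.
			LogitBias:    map[string]float64{"11": -10},
			RawPrompting: &raw, // send the prompt with no pre-processing
		},
	)

	require.NoError(t, err)
	print(gen)
}
```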