From a104c84d1552411b1ffe4146285c587c1d647e72 Mon Sep 17 00:00:00 2001 From: fern-api <115122769+fern-api[bot]@users.noreply.github.com> Date: Thu, 21 Mar 2024 21:10:22 +0000 Subject: [PATCH] SDK regeneration --- client/client.go | 21 +- core/request_option.go | 2 +- datasets.go | 4 +- errors.go | 47 +++ finetuning.go | 96 +++++ finetuning/client/client.go | 701 ++++++++++++++++++++++++++++++++++++ finetuning/doc.go | 4 + finetuning/types.go | 665 ++++++++++++++++++++++++++++++++++ types.go | 80 ++-- 9 files changed, 1587 insertions(+), 33 deletions(-) create mode 100644 finetuning.go create mode 100644 finetuning/client/client.go create mode 100644 finetuning/doc.go create mode 100644 finetuning/types.go diff --git a/client/client.go b/client/client.go index 70dd563..f8e7fc6 100644 --- a/client/client.go +++ b/client/client.go @@ -12,6 +12,7 @@ import ( core "github.com/cohere-ai/cohere-go/v2/core" datasets "github.com/cohere-ai/cohere-go/v2/datasets" embedjobs "github.com/cohere-ai/cohere-go/v2/embedjobs" + finetuningclient "github.com/cohere-ai/cohere-go/v2/finetuning/client" models "github.com/cohere-ai/cohere-go/v2/models" option "github.com/cohere-ai/cohere-go/v2/option" io "io" @@ -27,6 +28,7 @@ type Client struct { Datasets *datasets.Client Connectors *connectors.Client Models *models.Client + Finetuning *finetuningclient.Client } func NewClient(opts ...option.RequestOption) *Client { @@ -44,6 +46,7 @@ func NewClient(opts ...option.RequestOption) *Client { Datasets: datasets.NewClient(opts...), Connectors: connectors.NewClient(opts...), Models: models.NewClient(opts...), + Finetuning: finetuningclient.NewClient(opts...), } } @@ -159,7 +162,11 @@ func (c *Client) Chat( return response, nil } -// This endpoint generates realistic text conditioned on a given input. +// > 🚧 Warning +// > +// > This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. 
+// +// Generates realistic text conditioned on a given input. func (c *Client) GenerateStream( ctx context.Context, request *v2.GenerateStreamRequest, @@ -226,7 +233,11 @@ func (c *Client) GenerateStream( ) } -// This endpoint generates realistic text conditioned on a given input. +// > 🚧 Warning +// > +// > This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. +// +// Generates realistic text conditioned on a given input. func (c *Client) Generate( ctx context.Context, request *v2.GenerateRequest, @@ -501,7 +512,11 @@ func (c *Client) Classify( return response, nil } -// This endpoint generates a summary in English for a given text. +// > 🚧 Warning +// > +// > This API is marked as "Legacy" and is no longer maintained. Follow the [migration guide](/docs/migrating-from-cogenerate-to-cochat) to start using the Chat API. +// +// Generates a summary in English for a given text. func (c *Client) Summarize( ctx context.Context, request *v2.SummarizeRequest, diff --git a/core/request_option.go b/core/request_option.go index 9379a37..9586474 100644 --- a/core/request_option.go +++ b/core/request_option.go @@ -56,7 +56,7 @@ func (r *RequestOptions) cloneHeader() http.Header { headers := r.HTTPHeader.Clone() headers.Set("X-Fern-Language", "Go") headers.Set("X-Fern-SDK-Name", "github.com/cohere-ai/cohere-go/v2") - headers.Set("X-Fern-SDK-Version", "v2.6.0") + headers.Set("X-Fern-SDK-Version", "v2.7.0") return headers } diff --git a/datasets.go b/datasets.go index 4b00c96..0241639 100644 --- a/datasets.go +++ b/datasets.go @@ -36,9 +36,9 @@ type DatasetsListRequest struct { // optional filter after a date After *time.Time `json:"-" url:"after,omitempty"` // optional limit to number of results - Limit *string `json:"-" url:"limit,omitempty"` + Limit *float64 `json:"-" url:"limit,omitempty"` // optional offset to start of results - Offset *string `json:"-" 
url:"offset,omitempty"` + Offset *float64 `json:"-" url:"offset,omitempty"` } type DatasetsCreateResponse struct { diff --git a/errors.go b/errors.go index f825312..2679cdb 100644 --- a/errors.go +++ b/errors.go @@ -5,6 +5,7 @@ package api import ( json "encoding/json" core "github.com/cohere-ai/cohere-go/v2/core" + finetuning "github.com/cohere-ai/cohere-go/v2/finetuning" ) type BadRequestError struct { @@ -99,6 +100,29 @@ func (n *NotFoundError) Unwrap() error { return n.APIError } +type ServiceUnavailableError struct { + *core.APIError + Body *finetuning.Error +} + +func (s *ServiceUnavailableError) UnmarshalJSON(data []byte) error { + var body *finetuning.Error + if err := json.Unmarshal(data, &body); err != nil { + return err + } + s.StatusCode = 503 + s.Body = body + return nil +} + +func (s *ServiceUnavailableError) MarshalJSON() ([]byte, error) { + return json.Marshal(s.Body) +} + +func (s *ServiceUnavailableError) Unwrap() error { + return s.APIError +} + type TooManyRequestsError struct { *core.APIError Body interface{} @@ -121,3 +145,26 @@ func (t *TooManyRequestsError) MarshalJSON() ([]byte, error) { func (t *TooManyRequestsError) Unwrap() error { return t.APIError } + +type UnauthorizedError struct { + *core.APIError + Body *finetuning.Error +} + +func (u *UnauthorizedError) UnmarshalJSON(data []byte) error { + var body *finetuning.Error + if err := json.Unmarshal(data, &body); err != nil { + return err + } + u.StatusCode = 401 + u.Body = body + return nil +} + +func (u *UnauthorizedError) MarshalJSON() ([]byte, error) { + return json.Marshal(u.Body) +} + +func (u *UnauthorizedError) Unwrap() error { + return u.APIError +} diff --git a/finetuning.go b/finetuning.go new file mode 100644 index 0000000..45d9215 --- /dev/null +++ b/finetuning.go @@ -0,0 +1,96 @@ +// This file was auto-generated by Fern from our API Definition. 
+ +package api + +import ( + json "encoding/json" + core "github.com/cohere-ai/cohere-go/v2/core" + finetuning "github.com/cohere-ai/cohere-go/v2/finetuning" + time "time" +) + +type FinetuningListEventsRequest struct { + // Maximum number of results to be returned by the server. If 0, defaults to 50. + PageSize *int `json:"-" url:"page_size,omitempty"` + // Request a specific page of the list results. + PageToken *string `json:"-" url:"page_token,omitempty"` + // Comma separated list of fields. For example: "created_at,name". The default + // sorting order is ascending. To specify descending order for a field, append + // " desc" to the field name. For example: "created_at desc,name". + // + // Supported sorting fields: + // + // - created_at (default) + OrderBy *string `json:"-" url:"order_by,omitempty"` +} + +type FinetuningListFinetunedModelsRequest struct { + // Maximum number of results to be returned by the server. If 0, defaults to 50. + PageSize *int `json:"-" url:"page_size,omitempty"` + // Request a specific page of the list results. + PageToken *string `json:"-" url:"page_token,omitempty"` + // Comma separated list of fields. For example: "created_at,name". The default + // sorting order is ascending. To specify descending order for a field, append + // " desc" to the field name. For example: "created_at desc,name". + // + // Supported sorting fields: + // + // - created_at (default) + OrderBy *string `json:"-" url:"order_by,omitempty"` +} + +type FinetuningListTrainingStepMetricsRequest struct { + // Maximum number of results to be returned by the server. If 0, defaults to 50. + PageSize *int `json:"-" url:"page_size,omitempty"` + // Request a specific page of the list results. + PageToken *string `json:"-" url:"page_token,omitempty"` +} + +type FinetuningUpdateFinetunedModelRequest struct { + // FinetunedModel name (e.g. `foobar`). + Name string `json:"name" url:"name"` + // User ID of the creator. 
+ CreatorId *string `json:"creator_id,omitempty" url:"creator_id,omitempty"` + // Organization ID. + OrganizationId *string `json:"organization_id,omitempty" url:"organization_id,omitempty"` + // FinetunedModel settings such as dataset, hyperparameters... + Settings *finetuning.Settings `json:"settings,omitempty" url:"settings,omitempty"` + // Current stage in the life-cycle of the fine-tuned model. + Status *finetuning.Status `json:"status,omitempty" url:"status,omitempty"` + // Creation timestamp. + CreatedAt *time.Time `json:"created_at,omitempty" url:"created_at,omitempty"` + // Latest update timestamp. + UpdatedAt *time.Time `json:"updated_at,omitempty" url:"updated_at,omitempty"` + // Timestamp for the completed fine-tuning. + CompletedAt *time.Time `json:"completed_at,omitempty" url:"completed_at,omitempty"` + // Timestamp for the latest request to this fine-tuned model. + LastUsed *time.Time `json:"last_used,omitempty" url:"last_used,omitempty"` +} + +func (f *FinetuningUpdateFinetunedModelRequest) UnmarshalJSON(data []byte) error { + type unmarshaler FinetuningUpdateFinetunedModelRequest + var body unmarshaler + if err := json.Unmarshal(data, &body); err != nil { + return err + } + *f = FinetuningUpdateFinetunedModelRequest(body) + return nil +} + +func (f *FinetuningUpdateFinetunedModelRequest) MarshalJSON() ([]byte, error) { + type embed FinetuningUpdateFinetunedModelRequest + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + UpdatedAt *core.DateTime `json:"updated_at,omitempty"` + CompletedAt *core.DateTime `json:"completed_at,omitempty"` + LastUsed *core.DateTime `json:"last_used,omitempty"` + }{ + embed: embed(*f), + CreatedAt: core.NewOptionalDateTime(f.CreatedAt), + UpdatedAt: core.NewOptionalDateTime(f.UpdatedAt), + CompletedAt: core.NewOptionalDateTime(f.CompletedAt), + LastUsed: core.NewOptionalDateTime(f.LastUsed), + } + return json.Marshal(marshaler) +} diff --git a/finetuning/client/client.go 
b/finetuning/client/client.go new file mode 100644 index 0000000..8f026c4 --- /dev/null +++ b/finetuning/client/client.go @@ -0,0 +1,701 @@ +// This file was auto-generated by Fern from our API Definition. + +package client + +import ( + bytes "bytes" + context "context" + json "encoding/json" + errors "errors" + fmt "fmt" + v2 "github.com/cohere-ai/cohere-go/v2" + core "github.com/cohere-ai/cohere-go/v2/core" + finetuning "github.com/cohere-ai/cohere-go/v2/finetuning" + option "github.com/cohere-ai/cohere-go/v2/option" + io "io" + http "net/http" +) + +type Client struct { + baseURL string + caller *core.Caller + header http.Header +} + +func NewClient(opts ...option.RequestOption) *Client { + options := core.NewRequestOptions(opts...) + return &Client{ + baseURL: options.BaseURL, + caller: core.NewCaller( + &core.CallerParams{ + Client: options.HTTPClient, + MaxAttempts: options.MaxAttempts, + }, + ), + header: options.ToHeader(), + } +} + +func (c *Client) ListFinetunedModels( + ctx context.Context, + request *v2.FinetuningListFinetunedModelsRequest, + opts ...option.RequestOption, +) (*finetuning.ListFinetunedModelsResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := baseURL + "/" + "finetuning/finetuned-models" + + queryParams, err := core.QueryValues(request) + if err != nil { + return nil, err + } + if len(queryParams) > 0 { + endpointURL += "?" 
+ queryParams.Encode() + } + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response *finetuning.ListFinetunedModelsResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) CreateFinetunedModel( + ctx context.Context, + request *finetuning.FinetunedModel, + opts ...option.RequestOption, +) (*finetuning.CreateFinetunedModelResponse, error) { + options := core.NewRequestOptions(opts...) 
+ + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := baseURL + "/" + "finetuning/finetuned-models" + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response *finetuning.CreateFinetunedModelResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodPost, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) GetFinetunedModel( + ctx 
context.Context, + // The fine-tuned model ID. + id string, + opts ...option.RequestOption, +) (*finetuning.GetFinetunedModelResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := fmt.Sprintf(baseURL+"/"+"finetuning/finetuned-models/%v", id) + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response *finetuning.GetFinetunedModelResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: 
options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) DeleteFinetunedModel( + ctx context.Context, + // The fine-tuned model ID. + id string, + opts ...option.RequestOption, +) (finetuning.DeleteFinetunedModelResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := fmt.Sprintf(baseURL+"/"+"finetuning/finetuned-models/%v", id) + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response 
finetuning.DeleteFinetunedModelResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodDelete, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) UpdateFinetunedModel( + ctx context.Context, + // FinetunedModel ID. + id string, + request *v2.FinetuningUpdateFinetunedModelRequest, + opts ...option.RequestOption, +) (*finetuning.UpdateFinetunedModelResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := fmt.Sprintf(baseURL+"/"+"finetuning/finetuned-models/%v", id) + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + 
} + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response *finetuning.UpdateFinetunedModelResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodPatch, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Request: request, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) ListEvents( + ctx context.Context, + // The parent fine-tuned model ID. + finetunedModelId string, + request *v2.FinetuningListEventsRequest, + opts ...option.RequestOption, +) (*finetuning.ListEventsResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := fmt.Sprintf(baseURL+"/"+"finetuning/finetuned-models/%v/events", finetunedModelId) + + queryParams, err := core.QueryValues(request) + if err != nil { + return nil, err + } + if len(queryParams) > 0 { + endpointURL += "?" 
+ queryParams.Encode() + } + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response *finetuning.ListEventsResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) ListTrainingStepMetrics( + ctx context.Context, + // The parent fine-tuned model ID. 
+ finetunedModelId string, + request *v2.FinetuningListTrainingStepMetricsRequest, + opts ...option.RequestOption, +) (*finetuning.ListTrainingStepMetricsResponse, error) { + options := core.NewRequestOptions(opts...) + + baseURL := "https://api.cohere.ai/v1" + if c.baseURL != "" { + baseURL = c.baseURL + } + if options.BaseURL != "" { + baseURL = options.BaseURL + } + endpointURL := fmt.Sprintf(baseURL+"/"+"finetuning/finetuned-models/%v/metrics", finetunedModelId) + + queryParams, err := core.QueryValues(request) + if err != nil { + return nil, err + } + if len(queryParams) > 0 { + endpointURL += "?" + queryParams.Encode() + } + + headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + + errorDecoder := func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return err + } + apiError := core.NewAPIError(statusCode, errors.New(string(raw))) + decoder := json.NewDecoder(bytes.NewReader(raw)) + switch statusCode { + case 400: + value := new(v2.BadRequestError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 401: + value := new(v2.UnauthorizedError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 403: + value := new(v2.ForbiddenError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 404: + value := new(v2.NotFoundError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 500: + value := new(v2.InternalServerError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + case 503: + value := new(v2.ServiceUnavailableError) + value.APIError = apiError + if err := decoder.Decode(value); err != nil { + return apiError + } + return value + } + return apiError + } + + var response 
*finetuning.ListTrainingStepMetricsResponse + if err := c.caller.Call( + ctx, + &core.CallParams{ + URL: endpointURL, + Method: http.MethodGet, + MaxAttempts: options.MaxAttempts, + Headers: headers, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: errorDecoder, + }, + ); err != nil { + return nil, err + } + return response, nil +} diff --git a/finetuning/doc.go b/finetuning/doc.go new file mode 100644 index 0000000..c2fef42 --- /dev/null +++ b/finetuning/doc.go @@ -0,0 +1,4 @@ +// This file was auto-generated by Fern from our API Definition. + +// Finetuning API (Beta) +package finetuning diff --git a/finetuning/types.go b/finetuning/types.go new file mode 100644 index 0000000..725c633 --- /dev/null +++ b/finetuning/types.go @@ -0,0 +1,665 @@ +// This file was auto-generated by Fern from our API Definition. + +package finetuning + +import ( + json "encoding/json" + fmt "fmt" + core "github.com/cohere-ai/cohere-go/v2/core" + time "time" +) + +// The base model used for fine-tuning. +type BaseModel struct { + // The name of the base model. + Name *string `json:"name,omitempty" url:"name,omitempty"` + // read-only. The version of the base model. + Version *string `json:"version,omitempty" url:"version,omitempty"` + // The type of the base model. + BaseType BaseType `json:"base_type,omitempty" url:"base_type,omitempty"` + // The fine-tuning strategy. 
+ Strategy *Strategy `json:"strategy,omitempty" url:"strategy,omitempty"` + + _rawJSON json.RawMessage +} + +func (b *BaseModel) UnmarshalJSON(data []byte) error { + type unmarshaler BaseModel + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *b = BaseModel(value) + b._rawJSON = json.RawMessage(data) + return nil +} + +func (b *BaseModel) String() string { + if len(b._rawJSON) > 0 { + if value, err := core.StringifyJSON(b._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(b); err == nil { + return value + } + return fmt.Sprintf("%#v", b) +} + +// The possible types of fine-tuned models. +// +// - BASE_TYPE_UNSPECIFIED: Unspecified model. +// - BASE_TYPE_GENERATIVE: Generative model. +// - BASE_TYPE_CLASSIFICATION: Classification model. +// - BASE_TYPE_RERANK: Rerank model. +// - BASE_TYPE_CHAT: Chat model. +type BaseType string + +const ( + BaseTypeBaseTypeUnspecified BaseType = "BASE_TYPE_UNSPECIFIED" + BaseTypeBaseTypeGenerative BaseType = "BASE_TYPE_GENERATIVE" + BaseTypeBaseTypeClassification BaseType = "BASE_TYPE_CLASSIFICATION" + BaseTypeBaseTypeRerank BaseType = "BASE_TYPE_RERANK" + BaseTypeBaseTypeChat BaseType = "BASE_TYPE_CHAT" +) + +func NewBaseTypeFromString(s string) (BaseType, error) { + switch s { + case "BASE_TYPE_UNSPECIFIED": + return BaseTypeBaseTypeUnspecified, nil + case "BASE_TYPE_GENERATIVE": + return BaseTypeBaseTypeGenerative, nil + case "BASE_TYPE_CLASSIFICATION": + return BaseTypeBaseTypeClassification, nil + case "BASE_TYPE_RERANK": + return BaseTypeBaseTypeRerank, nil + case "BASE_TYPE_CHAT": + return BaseTypeBaseTypeChat, nil + } + var t BaseType + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (b BaseType) Ptr() *BaseType { + return &b +} + +// Response to request to create a fine-tuned model. +type CreateFinetunedModelResponse struct { + // Information about the fine-tuned model. 
+ FinetunedModel *FinetunedModel `json:"finetuned_model,omitempty" url:"finetuned_model,omitempty"` + + _rawJSON json.RawMessage +} + +func (c *CreateFinetunedModelResponse) UnmarshalJSON(data []byte) error { + type unmarshaler CreateFinetunedModelResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CreateFinetunedModelResponse(value) + c._rawJSON = json.RawMessage(data) + return nil +} + +func (c *CreateFinetunedModelResponse) String() string { + if len(c._rawJSON) > 0 { + if value, err := core.StringifyJSON(c._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// Response to request to delete a fine-tuned model. +type DeleteFinetunedModelResponse = map[string]interface{} + +// Error is the response for any unsuccessful event. +type Error struct { + // A developer-facing error message. + Message *string `json:"message,omitempty" url:"message,omitempty"` + + _rawJSON json.RawMessage +} + +func (e *Error) UnmarshalJSON(data []byte) error { + type unmarshaler Error + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *e = Error(value) + e._rawJSON = json.RawMessage(data) + return nil +} + +func (e *Error) String() string { + if len(e._rawJSON) > 0 { + if value, err := core.StringifyJSON(e._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(e); err == nil { + return value + } + return fmt.Sprintf("%#v", e) +} + +// A change in status of a fine-tuned model. +type Event struct { + // ID of the user who initiated the event. Empty if initiated by the system. + UserId *string `json:"user_id,omitempty" url:"user_id,omitempty"` + // Status of the fine-tuned model. + Status *Status `json:"status,omitempty" url:"status,omitempty"` + // Timestamp when the event happened. 
+ CreatedAt *time.Time `json:"created_at,omitempty" url:"created_at,omitempty"` + + _rawJSON json.RawMessage +} + +func (e *Event) UnmarshalJSON(data []byte) error { + type embed Event + var unmarshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + }{ + embed: embed(*e), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *e = Event(unmarshaler.embed) + e.CreatedAt = unmarshaler.CreatedAt.TimePtr() + e._rawJSON = json.RawMessage(data) + return nil +} + +func (e *Event) MarshalJSON() ([]byte, error) { + type embed Event + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + }{ + embed: embed(*e), + CreatedAt: core.NewOptionalDateTime(e.CreatedAt), + } + return json.Marshal(marshaler) +} + +func (e *Event) String() string { + if len(e._rawJSON) > 0 { + if value, err := core.StringifyJSON(e._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(e); err == nil { + return value + } + return fmt.Sprintf("%#v", e) +} + +// This resource represents a fine-tuned model. +type FinetunedModel struct { + // read-only. FinetunedModel ID. + Id *string `json:"id,omitempty" url:"id,omitempty"` + // FinetunedModel name (e.g. `foobar`). + Name string `json:"name" url:"name"` + // read-only. User ID of the creator. + CreatorId *string `json:"creator_id,omitempty" url:"creator_id,omitempty"` + // read-only. Organization ID. + OrganizationId *string `json:"organization_id,omitempty" url:"organization_id,omitempty"` + // FinetunedModel settings such as dataset, hyperparameters... + Settings *Settings `json:"settings,omitempty" url:"settings,omitempty"` + // read-only. Current stage in the life-cycle of the fine-tuned model. + Status *Status `json:"status,omitempty" url:"status,omitempty"` + // read-only. Creation timestamp. + CreatedAt *time.Time `json:"created_at,omitempty" url:"created_at,omitempty"` + // read-only. Latest update timestamp. 
+ UpdatedAt *time.Time `json:"updated_at,omitempty" url:"updated_at,omitempty"` + // read-only. Timestamp for the completed fine-tuning. + CompletedAt *time.Time `json:"completed_at,omitempty" url:"completed_at,omitempty"` + // read-only. Timestamp for the latest request to this fine-tuned model. + LastUsed *time.Time `json:"last_used,omitempty" url:"last_used,omitempty"` + + _rawJSON json.RawMessage +} + +func (f *FinetunedModel) UnmarshalJSON(data []byte) error { + type embed FinetunedModel + var unmarshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + UpdatedAt *core.DateTime `json:"updated_at,omitempty"` + CompletedAt *core.DateTime `json:"completed_at,omitempty"` + LastUsed *core.DateTime `json:"last_used,omitempty"` + }{ + embed: embed(*f), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *f = FinetunedModel(unmarshaler.embed) + f.CreatedAt = unmarshaler.CreatedAt.TimePtr() + f.UpdatedAt = unmarshaler.UpdatedAt.TimePtr() + f.CompletedAt = unmarshaler.CompletedAt.TimePtr() + f.LastUsed = unmarshaler.LastUsed.TimePtr() + f._rawJSON = json.RawMessage(data) + return nil +} + +func (f *FinetunedModel) MarshalJSON() ([]byte, error) { + type embed FinetunedModel + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + UpdatedAt *core.DateTime `json:"updated_at,omitempty"` + CompletedAt *core.DateTime `json:"completed_at,omitempty"` + LastUsed *core.DateTime `json:"last_used,omitempty"` + }{ + embed: embed(*f), + CreatedAt: core.NewOptionalDateTime(f.CreatedAt), + UpdatedAt: core.NewOptionalDateTime(f.UpdatedAt), + CompletedAt: core.NewOptionalDateTime(f.CompletedAt), + LastUsed: core.NewOptionalDateTime(f.LastUsed), + } + return json.Marshal(marshaler) +} + +func (f *FinetunedModel) String() string { + if len(f._rawJSON) > 0 { + if value, err := core.StringifyJSON(f._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(f); err == 
nil { + return value + } + return fmt.Sprintf("%#v", f) +} + +// Response to a request to get a fine-tuned model. +type GetFinetunedModelResponse struct { + // Information about the fine-tuned model. + FinetunedModel *FinetunedModel `json:"finetuned_model,omitempty" url:"finetuned_model,omitempty"` + + _rawJSON json.RawMessage +} + +func (g *GetFinetunedModelResponse) UnmarshalJSON(data []byte) error { + type unmarshaler GetFinetunedModelResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *g = GetFinetunedModelResponse(value) + g._rawJSON = json.RawMessage(data) + return nil +} + +func (g *GetFinetunedModelResponse) String() string { + if len(g._rawJSON) > 0 { + if value, err := core.StringifyJSON(g._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(g); err == nil { + return value + } + return fmt.Sprintf("%#v", g) +} + +// The fine-tuning hyperparameters. +type Hyperparameters struct { + // Stops training if the loss metric does not improve beyond the value of + // `early_stopping_threshold` after this many times of evaluation. + EarlyStoppingPatience *int `json:"early_stopping_patience,omitempty" url:"early_stopping_patience,omitempty"` + // How much the loss must improve to prevent early stopping. + EarlyStoppingThreshold *float64 `json:"early_stopping_threshold,omitempty" url:"early_stopping_threshold,omitempty"` + // The batch size is the number of training examples included in a single + // training pass. + TrainBatchSize *int `json:"train_batch_size,omitempty" url:"train_batch_size,omitempty"` + // The number of epochs to train for. + TrainEpochs *int `json:"train_epochs,omitempty" url:"train_epochs,omitempty"` + // The learning rate to be used during training. 
+ LearningRate *float64 `json:"learning_rate,omitempty" url:"learning_rate,omitempty"` + + _rawJSON json.RawMessage +} + +func (h *Hyperparameters) UnmarshalJSON(data []byte) error { + type unmarshaler Hyperparameters + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *h = Hyperparameters(value) + h._rawJSON = json.RawMessage(data) + return nil +} + +func (h *Hyperparameters) String() string { + if len(h._rawJSON) > 0 { + if value, err := core.StringifyJSON(h._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(h); err == nil { + return value + } + return fmt.Sprintf("%#v", h) +} + +// Response to a request to list events of a fine-tuned model. +type ListEventsResponse struct { + // List of events for the fine-tuned model. + Events []*Event `json:"events,omitempty" url:"events,omitempty"` + // Pagination token to retrieve the next page of results. If the value is "", + // it means no further results for the request. + NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` + // Total count of results. + TotalSize *int `json:"total_size,omitempty" url:"total_size,omitempty"` + + _rawJSON json.RawMessage +} + +func (l *ListEventsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListEventsResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListEventsResponse(value) + l._rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListEventsResponse) String() string { + if len(l._rawJSON) > 0 { + if value, err := core.StringifyJSON(l._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +// Response to a request to list fine-tuned models. +type ListFinetunedModelsResponse struct { + // List of fine-tuned models matching the request. 
+ FinetunedModels []*FinetunedModel `json:"finetuned_models,omitempty" url:"finetuned_models,omitempty"` + // Pagination token to retrieve the next page of results. If the value is "", + // it means no further results for the request. + NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` + // Total count of results. + TotalSize *int `json:"total_size,omitempty" url:"total_size,omitempty"` + + _rawJSON json.RawMessage +} + +func (l *ListFinetunedModelsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListFinetunedModelsResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListFinetunedModelsResponse(value) + l._rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListFinetunedModelsResponse) String() string { + if len(l._rawJSON) > 0 { + if value, err := core.StringifyJSON(l._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +// Response to a request to list training-step metrics of a fine-tuned model. +type ListTrainingStepMetricsResponse struct { + // The metrics for each step the evaluation was run on. + StepMetrics []*TrainingStepMetrics `json:"step_metrics,omitempty" url:"step_metrics,omitempty"` + // Pagination token to retrieve the next page of results. If the value is "", + // it means no further results for the request. 
+ NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` + + _rawJSON json.RawMessage +} + +func (l *ListTrainingStepMetricsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListTrainingStepMetricsResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListTrainingStepMetricsResponse(value) + l._rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListTrainingStepMetricsResponse) String() string { + if len(l._rawJSON) > 0 { + if value, err := core.StringifyJSON(l._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +// The configuration used for fine-tuning. +type Settings struct { + // The base model to fine-tune. + BaseModel *BaseModel `json:"base_model,omitempty" url:"base_model,omitempty"` + // The data used for training and evaluating the fine-tuned model. + DatasetId string `json:"dataset_id" url:"dataset_id"` + // Fine-tuning hyper-parameters. + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty" url:"hyperparameters,omitempty"` + // read-only. Whether the model is single-label or multi-label (only for classification). + MultiLabel *bool `json:"multi_label,omitempty" url:"multi_label,omitempty"` + + _rawJSON json.RawMessage +} + +func (s *Settings) UnmarshalJSON(data []byte) error { + type unmarshaler Settings + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *s = Settings(value) + s._rawJSON = json.RawMessage(data) + return nil +} + +func (s *Settings) String() string { + if len(s._rawJSON) > 0 { + if value, err := core.StringifyJSON(s._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(s); err == nil { + return value + } + return fmt.Sprintf("%#v", s) +} + +// The possible stages of a fine-tuned model life-cycle. 
+// +// - STATUS_UNSPECIFIED: Unspecified status. +// - STATUS_FINETUNING: The fine-tuned model is being fine-tuned. +// - STATUS_DEPLOYING_API: The fine-tuned model is being deployed. +// - STATUS_READY: The fine-tuned model is ready to receive requests. +// - STATUS_FAILED: The fine-tuned model failed. +// - STATUS_DELETED: The fine-tuned model was deleted. +// - STATUS_TEMPORARILY_OFFLINE: The fine-tuned model is temporarily unavailable. +// - STATUS_PAUSED: The fine-tuned model is paused (Vanilla only). +// - STATUS_QUEUED: The fine-tuned model is queued for training. +type Status string + +const ( + StatusStatusUnspecified Status = "STATUS_UNSPECIFIED" + StatusStatusFinetuning Status = "STATUS_FINETUNING" + StatusStatusDeployingApi Status = "STATUS_DEPLOYING_API" + StatusStatusReady Status = "STATUS_READY" + StatusStatusFailed Status = "STATUS_FAILED" + StatusStatusDeleted Status = "STATUS_DELETED" + StatusStatusTemporarilyOffline Status = "STATUS_TEMPORARILY_OFFLINE" + StatusStatusPaused Status = "STATUS_PAUSED" + StatusStatusQueued Status = "STATUS_QUEUED" +) + +func NewStatusFromString(s string) (Status, error) { + switch s { + case "STATUS_UNSPECIFIED": + return StatusStatusUnspecified, nil + case "STATUS_FINETUNING": + return StatusStatusFinetuning, nil + case "STATUS_DEPLOYING_API": + return StatusStatusDeployingApi, nil + case "STATUS_READY": + return StatusStatusReady, nil + case "STATUS_FAILED": + return StatusStatusFailed, nil + case "STATUS_DELETED": + return StatusStatusDeleted, nil + case "STATUS_TEMPORARILY_OFFLINE": + return StatusStatusTemporarilyOffline, nil + case "STATUS_PAUSED": + return StatusStatusPaused, nil + case "STATUS_QUEUED": + return StatusStatusQueued, nil + } + var t Status + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (s Status) Ptr() *Status { + return &s +} + +// The possible strategies used to serve a fine-tuned model. +// +// - STRATEGY_UNSPECIFIED: Unspecified strategy. 
+// - STRATEGY_VANILLA: Serve the fine-tuned model on a dedicated GPU. +// - STRATEGY_TFEW: Serve the fine-tuned model on a shared GPU. +type Strategy string + +const ( + StrategyStrategyUnspecified Strategy = "STRATEGY_UNSPECIFIED" + StrategyStrategyVanilla Strategy = "STRATEGY_VANILLA" + StrategyStrategyTfew Strategy = "STRATEGY_TFEW" +) + +func NewStrategyFromString(s string) (Strategy, error) { + switch s { + case "STRATEGY_UNSPECIFIED": + return StrategyStrategyUnspecified, nil + case "STRATEGY_VANILLA": + return StrategyStrategyVanilla, nil + case "STRATEGY_TFEW": + return StrategyStrategyTfew, nil + } + var t Strategy + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (s Strategy) Ptr() *Strategy { + return &s +} + +// The evaluation metrics at a given step of the training of a fine-tuned model. +type TrainingStepMetrics struct { + // Creation timestamp. + CreatedAt *time.Time `json:"created_at,omitempty" url:"created_at,omitempty"` + // Step number. + StepNumber *int `json:"step_number,omitempty" url:"step_number,omitempty"` + // Map of names and values for each evaluation metrics. 
+ Metrics map[string]float64 `json:"metrics,omitempty" url:"metrics,omitempty"` + + _rawJSON json.RawMessage +} + +func (t *TrainingStepMetrics) UnmarshalJSON(data []byte) error { + type embed TrainingStepMetrics + var unmarshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + }{ + embed: embed(*t), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *t = TrainingStepMetrics(unmarshaler.embed) + t.CreatedAt = unmarshaler.CreatedAt.TimePtr() + t._rawJSON = json.RawMessage(data) + return nil +} + +func (t *TrainingStepMetrics) MarshalJSON() ([]byte, error) { + type embed TrainingStepMetrics + var marshaler = struct { + embed + CreatedAt *core.DateTime `json:"created_at,omitempty"` + }{ + embed: embed(*t), + CreatedAt: core.NewOptionalDateTime(t.CreatedAt), + } + return json.Marshal(marshaler) +} + +func (t *TrainingStepMetrics) String() string { + if len(t._rawJSON) > 0 { + if value, err := core.StringifyJSON(t._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +// Response to a request to update a fine-tuned model. +type UpdateFinetunedModelResponse struct { + // Information about the fine-tuned model. 
+ FinetunedModel *FinetunedModel `json:"finetuned_model,omitempty" url:"finetuned_model,omitempty"` + + _rawJSON json.RawMessage +} + +func (u *UpdateFinetunedModelResponse) UnmarshalJSON(data []byte) error { + type unmarshaler UpdateFinetunedModelResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *u = UpdateFinetunedModelResponse(value) + u._rawJSON = json.RawMessage(data) + return nil +} + +func (u *UpdateFinetunedModelResponse) String() string { + if len(u._rawJSON) > 0 { + if value, err := core.StringifyJSON(u._rawJSON); err == nil { + return value + } + } + if value, err := core.StringifyJSON(u); err == nil { + return value + } + return fmt.Sprintf("%#v", u) +} diff --git a/types.go b/types.go index f164ab2..9f00b55 100644 --- a/types.go +++ b/types.go @@ -16,9 +16,15 @@ type ChatRequest struct { // // The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Model *string `json:"model,omitempty" url:"model,omitempty"` - // When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style. + // When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. + // + // The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. 
Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` - // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. + // A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. + // + // Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. + // + // The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. ChatHistory []*ChatMessage `json:"chat_history,omitempty" url:"chat_history,omitempty"` // An alternative to `chat_history`. // @@ -76,6 +82,10 @@ type ChatRequest struct { // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. P *float64 `json:"p,omitempty" url:"p,omitempty"` + // If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. + Seed *float64 `json:"seed,omitempty" url:"seed,omitempty"` + // A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. 
+ StopSequences []string `json:"stop_sequences,omitempty" url:"stop_sequences,omitempty"` // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. // // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. @@ -88,21 +98,21 @@ type ChatRequest struct { RawPrompting *bool `json:"raw_prompting,omitempty" url:"raw_prompting,omitempty"` // A list of available tools (functions) that the model may suggest invoking before producing a text response. // - // When `tools` is passed, The `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made - // the `tool_calls` array will be empty. + // When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Tools []*Tool `json:"tools,omitempty" url:"tools,omitempty"` - // A list of results from invoking tools. Results are used to generate text and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. + // A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. // Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. // + // **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. 
// ``` // tool_results = [ // // { // "call": { - // "name": , - // "parameters": { - // : - // } + // "name": , + // "parameters": { + // : + // } // }, // "outputs": [{ // : @@ -112,6 +122,7 @@ type ChatRequest struct { // // ] // ``` + // **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. ToolResults []*ChatRequestToolResultsItem `json:"tool_results,omitempty" url:"tool_results,omitempty"` stream bool } @@ -150,9 +161,15 @@ type ChatStreamRequest struct { // // The name of a compatible [Cohere model](https://docs.cohere.com/docs/models) or the ID of a [fine-tuned](https://docs.cohere.com/docs/chat-fine-tuning) model. Model *string `json:"model,omitempty" url:"model,omitempty"` - // When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style. + // When specified, the default Cohere preamble will be replaced with the provided one. Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style, and use the `SYSTEM` role. + // + // The `SYSTEM` role is also used for the contents of the optional `chat_history=` parameter. When used with the `chat_history=` parameter it adds content throughout a conversation. Conversely, when used with the `preamble=` parameter it adds content at the start of the conversation only. Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` - // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. + // A list of previous messages between the user and the model, giving the model conversational context for responding to the user's `message`. + // + // Each item represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. 
The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. + // + // The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. ChatHistory []*ChatMessage `json:"chat_history,omitempty" url:"chat_history,omitempty"` // An alternative to `chat_history`. // @@ -210,6 +227,10 @@ type ChatStreamRequest struct { // Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. // Defaults to `0.75`. min value of `0.01`, max value of `0.99`. P *float64 `json:"p,omitempty" url:"p,omitempty"` + // If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. + Seed *float64 `json:"seed,omitempty" url:"seed,omitempty"` + // A list of up to 5 strings that the model will use to stop generation. If the model generates a string that matches any of the strings in the list, it will stop generating tokens and return the generated text up to that point not including the stop sequence. + StopSequences []string `json:"stop_sequences,omitempty" url:"stop_sequences,omitempty"` // Defaults to `0.0`, min value of `0.0`, max value of `1.0`. // // Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation. @@ -222,21 +243,21 @@ type ChatStreamRequest struct { RawPrompting *bool `json:"raw_prompting,omitempty" url:"raw_prompting,omitempty"` // A list of available tools (functions) that the model may suggest invoking before producing a text response. 
// - // When `tools` is passed, The `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made - // the `tool_calls` array will be empty. + // When `tools` is passed (without `tool_results`), the `text` field in the response will be `""` and the `tool_calls` field in the response will be populated with a list of tool calls that need to be made. If no calls need to be made, the `tool_calls` array will be empty. Tools []*Tool `json:"tools,omitempty" url:"tools,omitempty"` - // A list of results from invoking tools. Results are used to generate text and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. + // A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations. When using `tool_results`, `tools` must be passed as well. // Each tool_result contains information about how it was invoked, as well as a list of outputs in the form of dictionaries. // + // **Note**: `outputs` must be a list of objects. If your tool returns a single object (eg `{"status": 200}`), make sure to wrap it in a list. // ``` // tool_results = [ // // { // "call": { - // "name": , - // "parameters": { - // : - // } + // "name": , + // "parameters": { + // : + // } // }, // "outputs": [{ // : @@ -246,6 +267,7 @@ type ChatStreamRequest struct { // // ] // ``` + // **Note**: Chat calls with `tool_results` should not be included in the Chat history to avoid duplication of the message text. ToolResults []*ChatStreamRequestToolResultsItem `json:"tool_results,omitempty" url:"tool_results,omitempty"` stream bool } @@ -361,6 +383,8 @@ type GenerateRequest struct { // A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. 
See [Temperature](/temperature-wiki) for more details. // Defaults to `0.75`, min value of `0.0`, max value of `5.0`. Temperature *float64 `json:"temperature,omitempty" url:"temperature,omitempty"` + // If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. + Seed *float64 `json:"seed,omitempty" url:"seed,omitempty"` // Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.ai/playground/generate). // When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters. Preset *string `json:"preset,omitempty" url:"preset,omitempty"` @@ -446,6 +470,8 @@ type GenerateStreamRequest struct { // A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details. // Defaults to `0.75`, min value of `0.0`, max value of `5.0`. Temperature *float64 `json:"temperature,omitempty" url:"temperature,omitempty"` + // If specified, the backend will make a best effort to sample tokens deterministically, such that repeated requests with the same seed and parameters should return the same result. However, determinism cannot be totally guaranteed. + Seed *float64 `json:"seed,omitempty" url:"seed,omitempty"` // Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the [playground](https://dashboard.cohere.ai/playground/generate). // When a preset is specified, the `prompt` parameter becomes optional, and any included parameters will override the preset's parameters. 
Preset *string `json:"preset,omitempty" url:"preset,omitempty"` @@ -824,16 +850,14 @@ func (c *ChatDataMetrics) String() string { // passed to the model. type ChatDocument = map[string]string -// A single message in a chat history. Contains the role of the sender, the text contents of the message. +// Represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. +// +// The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. type ChatMessage struct { - // One of CHATBOT|USER to identify who the message is coming from. + // One of `CHATBOT`, `SYSTEM`, or `USER` to identify who the message is coming from. Role ChatMessageRole `json:"role,omitempty" url:"role,omitempty"` // Contents of the chat message. Message string `json:"message" url:"message"` - // Unique identifier for the generated reply. Useful for submitting feedback. - GenerationId *string `json:"generation_id,omitempty" url:"generation_id,omitempty"` - // Unique identifier for the response. - ResponseId *string `json:"response_id,omitempty" url:"response_id,omitempty"` _rawJSON json.RawMessage } @@ -861,11 +885,12 @@ func (c *ChatMessage) String() string { return fmt.Sprintf("%#v", c) } -// One of CHATBOT|USER to identify who the message is coming from. +// One of `CHATBOT`, `SYSTEM`, or `USER` to identify who the message is coming from. 
type ChatMessageRole string const ( ChatMessageRoleChatbot ChatMessageRole = "CHATBOT" + ChatMessageRoleSystem ChatMessageRole = "SYSTEM" ChatMessageRoleUser ChatMessageRole = "USER" ) @@ -873,6 +898,8 @@ func NewChatMessageRoleFromString(s string) (ChatMessageRole, error) { switch s { case "CHATBOT": return ChatMessageRoleChatbot, nil + case "SYSTEM": + return ChatMessageRoleSystem, nil case "USER": return ChatMessageRoleUser, nil } @@ -4161,8 +4188,7 @@ type ToolCall struct { // Name of the tool to call. Name string `json:"name" url:"name"` // The name and value of the parameters to use when invoking a tool. - Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` - GenerationId string `json:"generation_id" url:"generation_id"` + Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` _rawJSON json.RawMessage }