diff --git a/_examples/text_classification/main.go b/_examples/text_classification/main.go new file mode 100644 index 0000000..a91be43 --- /dev/null +++ b/_examples/text_classification/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + + huggingface "github.com/hupe1980/go-huggingface" +) + +func main() { + ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN")) + + res, err := ic.TextClassification(context.Background(), &huggingface.TextClassificationRequest{ + Inputs: "The answer to the universe is 42", + //Model: "deepset/deberta-v3-base-injection", // overwrite recommended model + }) + if err != nil { + log.Fatal(err) + } + + fmt.Println(res[0]) +} diff --git a/text_classification.go b/text_classification.go new file mode 100644 index 0000000..5c3568e --- /dev/null +++ b/text_classification.go @@ -0,0 +1,45 @@ +package gohuggingface + +import ( + "context" + "encoding/json" + "errors" +) + +// TextClassificationRequest represents a request for text classification. +type TextClassificationRequest struct { + // Inputs is the string to be generated from. + Inputs string `json:"inputs"` + // Options represents optional settings for the classification. + Options Options `json:"options,omitempty"` + // Model is the name of the model to use for classification. + Model string `json:"-"` +} + +// TextClassificationResponse represents a response for text classification. +type TextClassificationResponse [][]struct { + // Label is the label for the class (model-specific). + Label string `json:"label,omitempty"` + // Score is a float that represents how likely it is that the text belongs to this class. + Score float32 `json:"score,omitempty"` +} + +// TextClassification performs text classification using the provided request. +func (ic *InferenceClient) TextClassification(ctx context.Context, req *TextClassificationRequest) (TextClassificationResponse, error) { + // Check if inputs are provided. + if len(req.Inputs) == 0 { + return nil, errors.New("inputs are required") + } + + body, err := ic.post(ctx, req.Model, "text-classification", req) + if err != nil { + return nil, err + } + + textClassificationResponse := TextClassificationResponse{} + if err := json.Unmarshal(body, &textClassificationResponse); err != nil { + return nil, err + } + + return textClassificationResponse, nil +}