diff --git a/_examples/table_question_answering/main.go b/_examples/table_question_answering/main.go
new file mode 100644
index 0000000..756965d
--- /dev/null
+++ b/_examples/table_question_answering/main.go
@@ -0,0 +1,33 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.TableQuestionAnswering(context.Background(), &huggingface.TableQuestionAnsweringRequest{
+		Inputs: huggingface.TableQuestionAnsweringInputs{
+			Query: "How many stars does the transformers repository have?",
+			Table: map[string][]string{
+				"Repository":   {"Transformers", "Datasets", "Tokenizers"},
+				"Stars":        {"36542", "4512", "3934"},
+				"Contributors": {"651", "77", "34"},
+			},
+		},
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println("Answer:", res.Answer)
+	fmt.Println("Coordinates:", res.Coordinates)
+	fmt.Println("Cells:", res.Cells)
+	fmt.Println("Aggregator:", res.Aggregator)
+}
diff --git a/huggingface.go b/huggingface.go
index 43206b3..4179563 100644
--- a/huggingface.go
+++ b/huggingface.go
@@ -12,6 +12,7 @@ import (
 )
 
 var (
+	// recommendedModels stores the recommended models for each task.
 	recommendedModels map[string]string
 )
 
@@ -170,6 +171,31 @@ func (ic *InferenceClient) QuestionAnswering(ctx context.Context, req *QuestionA
 	return questionAnsweringResponse, nil
 }
 
+// TableQuestionAnswering performs table-based question answering using the specified model.
+// It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
+// The response contains the answer or an error if the request fails.
+func (ic *InferenceClient) TableQuestionAnswering(ctx context.Context, req *TableQuestionAnsweringRequest) (*TableQuestionAnsweringResponse, error) {
+	if req.Inputs.Query == "" {
+		return nil, errors.New("query is required")
+	}
+
+	if req.Inputs.Table == nil {
+		return nil, errors.New("table is required")
+	}
+
+	body, err := ic.post(ctx, req.Model, "table-question-answering", req)
+	if err != nil {
+		return nil, err
+	}
+
+	tableQuestionAnsweringResponse := TableQuestionAnsweringResponse{}
+	if err := json.Unmarshal(body, &tableQuestionAnsweringResponse); err != nil {
+		return nil, err
+	}
+
+	return &tableQuestionAnsweringResponse, nil
+}
+
 // FillMask performs masked language modeling using the specified model.
 // It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
 // The response contains the generated text with the masked tokens filled or an error if the request fails.
diff --git a/table_question_answering.go b/table_question_answering.go
new file mode 100644
index 0000000..aff30fb
--- /dev/null
+++ b/table_question_answering.go
@@ -0,0 +1,33 @@
+package gohuggingface
+
+// TableQuestionAnsweringRequest is the request structure for the table question answering model.
+type TableQuestionAnsweringRequest struct {
+	Inputs  TableQuestionAnsweringInputs `json:"inputs"`
+	Options Options                      `json:"options,omitempty"`
+	Model   string                       `json:"-"`
+}
+
+type TableQuestionAnsweringInputs struct {
+	// (Required) The query, in plain text, that you want to ask the table.
+	Query string `json:"query"`
+
+	// (Required) A table of data represented as a map of lists, where the keys
+	// are the column headers and the lists hold the column values; all lists
+	// must have the same length.
+	Table map[string][]string `json:"table"`
+}
+
+// TableQuestionAnsweringResponse is the response structure for the table question answering model.
+type TableQuestionAnsweringResponse struct {
+	// The plaintext answer
+	Answer string `json:"answer,omitempty"`
+
+	// A list of coordinates of the cells referenced in the answer
+	Coordinates [][]int `json:"coordinates,omitempty"`
+
+	// The contents of the cells referenced in the answer
+	Cells []string `json:"cells,omitempty"`
+
+	// The aggregator used to get the answer
+	Aggregator string `json:"aggregator,omitempty"`
+}
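The Inputs.Table doc comment requires every column list to have the same length, but TableQuestionAnswering itself only checks that the table is non-nil. Below is a minimal caller-side sketch of that length check; validateTable is a hypothetical helper for illustration and is not part of go-huggingface.

package main

import (
	"fmt"
	"log"
)

// validateTable is a hypothetical helper (not part of go-huggingface) that
// verifies every column in the table has the same number of rows before the
// request is sent to the table-question-answering endpoint.
func validateTable(table map[string][]string) error {
	rows := -1
	for header, column := range table {
		if rows == -1 {
			rows = len(column)
			continue
		}
		if len(column) != rows {
			return fmt.Errorf("column %q has %d rows, want %d", header, len(column), rows)
		}
	}
	return nil
}

func main() {
	table := map[string][]string{
		"Repository": {"Transformers", "Datasets", "Tokenizers"},
		"Stars":      {"36542", "4512", "3934"},
	}
	if err := validateTable(table); err != nil {
		log.Fatal(err)
	}
	fmt.Println("table columns are consistent")
}

Whether such a check belongs inside TableQuestionAnswering is a judgment call; validating before the POST simply avoids a round trip to the inference endpoint for an obviously malformed table.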