Skip to content

Commit

Permalink
Add table question answering
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 26, 2023
1 parent 6a94b33 commit e68b990
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
33 changes: 33 additions & 0 deletions _examples/table_question_answering/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.TableQuestionAnswering(context.Background(), &huggingface.TableQuestionAnsweringRequest{
Inputs: huggingface.TableQuestionAnsweringInputs{
Query: "How many stars does the transformers repository have?",
Table: map[string][]string{
"Repository": {"Transformers", "Datasets", "Tokenizers"},
"Stars": {"36542", "4512", "3934"},
"Contributors": {"651", "77", "34"},
},
},
})
if err != nil {
log.Fatal(err)
}

fmt.Println("Answer:", res.Answer)
fmt.Println("Coordinates:", res.Coordinates)
fmt.Println("Cells:", res.Cells)
fmt.Println("Aggregator:", res.Aggregator)
}
26 changes: 26 additions & 0 deletions huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
)

var (
// recommendedModels stores the recommended models for each task.
recommendedModels map[string]string
)

Expand Down Expand Up @@ -170,6 +171,31 @@ func (ic *InferenceClient) QuestionAnswering(ctx context.Context, req *QuestionA
return questionAnsweringResponse, nil
}

// TableQuestionAnswering performs table-based question answering using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
// The response contains the answer or an error if the request fails.
func (ic *InferenceClient) TableQuestionAnswering(ctx context.Context, req *TableQuestionAnsweringRequest) (*TableQuestionAnsweringResponse, error) {
if req.Inputs.Query == "" {
return nil, errors.New("query is required")
}

if req.Inputs.Table == nil {
return nil, errors.New("table is required")
}

body, err := ic.post(ctx, req.Model, "table-question-answering", req)
if err != nil {
return nil, err
}

tablequestionAnsweringResponse := TableQuestionAnsweringResponse{}
if err := json.Unmarshal(body, &tablequestionAnsweringResponse); err != nil {
return nil, err
}

return &tablequestionAnsweringResponse, nil
}

// FillMask performs masked language modeling using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
// The response contains the generated text with the masked tokens filled or an error if the request fails.
Expand Down
33 changes: 33 additions & 0 deletions table_question_answering.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package gohuggingface

// Request structure for table question answering model
type TableQuestionAnsweringRequest struct {
Inputs TableQuestionAnsweringInputs `json:"inputs"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

type TableQuestionAnsweringInputs struct {
// (Required) The query in plain text that you want to ask the table
Query string `json:"query"`

// (Required) A table of data represented as a dict of list where entries
// are headers and the lists are all the values, all lists must
// have the same size.
Table map[string][]string `json:"table"`
}

// Response structure for table question answering model
type TableQuestionAnsweringResponse struct {
// The plaintext answer
Answer string `json:"answer,omitempty"`

// A list of coordinates of the cells references in the answer
Coordinates [][]int `json:"coordinates,omitempty"`

// A list of coordinates of the cells contents
Cells []string `json:"cells,omitempty"`

// The aggregator used to get the answer
Aggregator string `json:"aggregator,omitempty"`
}

0 comments on commit e68b990

Please sign in to comment.