Skip to content

Commit

Permalink
Add sentence-similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jan 7, 2024
1 parent 0560532 commit 15f0a87
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
27 changes: 27 additions & 0 deletions _examples/sentence_similarity/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.SentenceSimilarity(context.Background(), &huggingface.SentenceSimilarityRequest{
Inputs: huggingface.SentenceSimilarityInputs{
SourceSentence: "That is a happy person",
Sentences: []string{"That is a happy dog", "That is a very happy person", "Today is a sunny day"},
},
Model: "sentence-transformers/all-MiniLM-L6-v2",
})
if err != nil {
log.Fatal(err)
}

fmt.Println(res)
}
45 changes: 45 additions & 0 deletions sentence_similarity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package gohuggingface

import (
"context"
"encoding/json"
"errors"
)

// SentenceSimilarityInputs represents the inputs for sentence similarity computation.
type SentenceSimilarityInputs struct {
SourceSentence string `json:"source_sentence"`
Sentences []string `json:"sentences"`
}

// SentenceSimilarityRequest represents a request for sentence similarity computation.
type SentenceSimilarityRequest struct {
Inputs SentenceSimilarityInputs `json:"inputs"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

// SentenceSimilarityResponse represents the response for a sentence similarity computation request.
type SentenceSimilarityResponse []float32

// SentenceSimilarity sends a sentence similarity computation request to the InferenceClient
// and returns the sentence similarity response.
func (ic *InferenceClient) SentenceSimilarity(ctx context.Context, req *SentenceSimilarityRequest) (SentenceSimilarityResponse, error) {
if len(req.Inputs.SourceSentence) == 0 || len(req.Inputs.Sentences) == 0 {
return nil, errors.New("sourceSentence and sentences are required")
}

body, err := ic.post(ctx, req.Model, "sentence-similarity", req)

if err != nil {
return nil, err
}

sentenceSimilarityResponse := SentenceSimilarityResponse{}

if err := json.Unmarshal(body, &sentenceSimilarityResponse); err != nil {
return nil, err
}

return sentenceSimilarityResponse, nil
}

0 comments on commit 15f0a87

Please sign in to comment.