Showing 12 changed files with 398 additions and 160 deletions.
@@ -0,0 +1,33 @@
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	res, err := ic.Conversational(context.Background(), &huggingface.ConversationalRequest{
		Inputs: huggingface.ConverstationalInputs{
			PastUserInputs: []string{
				"Which movie is the best ?",
				"Can you explain why ?",
			},
			GeneratedResponses: []string{
				"It's Die Hard for sure.",
				"It's the best movie ever.",
			},
			Text: "Can you explain why ?",
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.GeneratedText)
}
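The Conversation field of the response (defined in conversational.go below) mirrors the inputs structure, so the returned state can seed the next turn. A minimal sketch, continuing from res inside main above; the follow-up question is invented for illustration:

	// Reuse the conversation state from the previous response to ask a
	// follow-up question in the same dialogue.
	followUp, err := ic.Conversational(context.Background(), &huggingface.ConversationalRequest{
		Inputs: huggingface.ConverstationalInputs{
			PastUserInputs:     res.Conversation.PastUserInputs,
			GeneratedResponses: res.Conversation.GeneratedResponses,
			Text:               "Who starred in it ?", // invented follow-up for illustration
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(followUp.GeneratedText)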
@@ -0,0 +1,23 @@
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	res, err := ic.FeatureExtraction(context.Background(), &huggingface.FeatureExtractionRequest{
		Inputs: []string{"Hello World"},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res[0][0])
}
@@ -0,0 +1,99 @@
package gohuggingface

import (
	"context"
	"encoding/json"
	"errors"
)

// Used with ConversationalRequest
type ConversationalParameters struct {
	// (Default: None). Integer to define the minimum length in tokens of the output summary.
	MinLength *int `json:"min_length,omitempty"`

	// (Default: None). Integer to define the maximum length in tokens of the output summary.
	MaxLength *int `json:"max_length,omitempty"`

	// (Default: None). Integer to define the top tokens considered within the sample operation to create
	// new text.
	TopK *int `json:"top_k,omitempty"`

	// (Default: None). Float to define the tokens that are within the sample operation of text generation.
	// Add tokens in the sample from most probable to least probable until the sum of the probabilities
	// is greater than top_p.
	TopP *float64 `json:"top_p,omitempty"`

	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
	// 0 means top_k=1, and 100.0 gets closer to uniform probability.
	Temperature *float64 `json:"temperature,omitempty"`

	// (Default: None). Float (0.0-100.0). The more a token is used within generation, the more it is penalized
	// so it is not picked in successive generation passes.
	RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`

	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
	// Network can cause some overhead, so it is a soft limit.
	MaxTime *float64 `json:"maxtime,omitempty"`
}

// Used with ConversationalRequest
type ConverstationalInputs struct {
	// (Required) The last input from the user in the conversation.
	Text string `json:"text"`

	// A list of strings corresponding to the earlier replies from the model.
	GeneratedResponses []string `json:"generated_responses,omitempty"`

	// A list of strings corresponding to the earlier replies from the user.
	// Should be of the same length as GeneratedResponses.
	PastUserInputs []string `json:"past_user_inputs,omitempty"`
}

// Request structure for the conversational endpoint
type ConversationalRequest struct {
	// (Required)
	Inputs ConverstationalInputs `json:"inputs,omitempty"`

	Parameters ConversationalParameters `json:"parameters,omitempty"`
	Options    Options                  `json:"options,omitempty"`
	Model      string                   `json:"-"`
}

// Used with ConversationalResponse
type Conversation struct {
	// The last outputs from the model in the conversation, after the model has run.
	GeneratedResponses []string `json:"generated_responses,omitempty"`

	// The last inputs from the user in the conversation, after the model has run.
	PastUserInputs []string `json:"past_user_inputs,omitempty"`
}

// Response structure for the conversational endpoint
type ConversationalResponse struct {
	// The answer from the model
	GeneratedText string `json:"generated_text,omitempty"`

	// A facility dictionary to send back for the next input (with the new user input added).
	Conversation Conversation `json:"conversation,omitempty"`
}

// Conversational performs conversational AI using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided conversational inputs.
// The response contains the generated conversational response or an error if the request fails.
func (ic *InferenceClient) Conversational(ctx context.Context, req *ConversationalRequest) (*ConversationalResponse, error) {
	if len(req.Inputs.Text) == 0 {
		return nil, errors.New("text is required")
	}

	body, err := ic.post(ctx, req.Model, "conversational", req)
	if err != nil {
		return nil, err
	}

	conversationalResponse := ConversationalResponse{}
	if err := json.Unmarshal(body, &conversationalResponse); err != nil {
		return nil, err
	}

	return &conversationalResponse, nil
}
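Every field of ConversationalParameters is a pointer, so unset values are dropped from the JSON payload via omitempty. A minimal sketch of passing generation parameters, assuming the same client setup as the example above; the intPtr and float64Ptr helpers are defined here purely for illustration and are not part of the package:

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

// Illustrative helpers for taking the address of literal values;
// they are not part of the go-huggingface package.
func intPtr(i int) *int             { return &i }
func float64Ptr(f float64) *float64 { return &f }

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	res, err := ic.Conversational(context.Background(), &huggingface.ConversationalRequest{
		Inputs: huggingface.ConverstationalInputs{
			Text: "Which movie is the best ?",
		},
		Parameters: huggingface.ConversationalParameters{
			MaxLength:   intPtr(64),
			Temperature: float64Ptr(0.7),
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.GeneratedText)
}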
@@ -0,0 +1,42 @@
package gohuggingface

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
)

// Request structure for the feature extraction endpoint
type FeatureExtractionRequest struct {
	// Strings to get the features from
	Inputs  []string `json:"inputs"`
	Options Options  `json:"options,omitempty"`
	Model   string   `json:"-"`
}

// Response structure for the feature extraction endpoint
type FeatureExtractionResponse [][][][]float64

// FeatureExtraction performs feature extraction using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided input data.
// The response contains the extracted features or an error if the request fails.
func (ic *InferenceClient) FeatureExtraction(ctx context.Context, req *FeatureExtractionRequest) (FeatureExtractionResponse, error) {
	if len(req.Inputs) == 0 {
		return nil, errors.New("inputs are required")
	}

	body, err := ic.post(ctx, req.Model, "feature-extraction", req)
	if err != nil {
		return nil, err
	}

	// Debug: print the raw response body.
	fmt.Println(string(body))

	featureExtractionResponse := FeatureExtractionResponse{}
	if err := json.Unmarshal(body, &featureExtractionResponse); err != nil {
		return nil, err
	}

	return featureExtractionResponse, nil
}
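FeatureExtractionResponse is a four-level nested slice, so the shape of a result is easiest to see by checking the length at each level. A minimal sketch of a helper that could sit alongside the code above; it is not part of go-huggingface, and the meaning of each level (input, layer, token, embedding dimension) depends on the model that served the request:

package gohuggingface

import "fmt"

// DescribeShape prints the length of each nesting level of a feature
// extraction result. Every level is guarded, since shallower shapes are
// possible depending on the model.
func DescribeShape(res FeatureExtractionResponse) {
	fmt.Println("level 1:", len(res))
	if len(res) == 0 {
		return
	}
	fmt.Println("level 2:", len(res[0]))
	if len(res[0]) == 0 {
		return
	}
	fmt.Println("level 3:", len(res[0][0]))
	if len(res[0][0]) == 0 {
		return
	}
	fmt.Println("level 4 (vector length):", len(res[0][0][0]))
}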