feat: Introduce ChatML prompt format
Models are taught with specially formatted prompts. Using the same
format for inference gives the best results.
macie committed Jan 1, 2024
1 parent 66d231e commit 893bb83
Showing 7 changed files with 148 additions and 14 deletions.
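For reference, the ChatML layout introduced here (implemented in llama/prompt.go below) wraps the system prompt and each user prompt in <|im_start|>…<|im_end|> markers and leaves an open assistant turn for the model to complete:

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    How are you?<|im_end|>
    <|im_start|>assistant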
34 changes: 28 additions & 6 deletions cmd/boludo/config.go
@@ -48,7 +48,7 @@ func Version() string {
 type AppConfig struct {
     Options     llama.Options
     ServerPath  string
-    Prompt      string
+    Prompt      llama.Prompt
     Timeout     time.Duration
     Verbose     bool
     ExitMessage string
@@ -91,10 +91,13 @@ func NewAppConfig(cliArgs []string) (AppConfig, error) {
     options.Update(configFile.Options(configArgs.ConfigId))
     options.Update(configArgs.Options())
 
-    prompt := configArgs.Prompt
-    if spec, ok := configFile[configArgs.ConfigId]; ok {
-        prompt = fmt.Sprintf("%s %s", spec.InitialPrompt, configArgs.Prompt)
+    prompt := configFile.Prompt(configArgs.ConfigId)
+    initialPrompt := configFile.InitialPrompt(configArgs.ConfigId)
+    userPrompt := configArgs.Prompt
+    if initialPrompt != "" {
+        userPrompt = fmt.Sprintf("%s %s", initialPrompt, userPrompt)
     }
+    prompt.Add(userPrompt)
 
     return AppConfig{
         Prompt: prompt,
@@ -226,7 +229,7 @@ func (c *ConfigFile) UnmarshalTOML(data interface{}) error {
     defaultSpec := ModelSpec{
         Model:         "",
         InitialPrompt: "",
-        Format:        llama.DefaultOptions.Format,
+        Format:        "",
         Creativity:    llama.DefaultOptions.Temp,
         Cutoff:        llama.DefaultOptions.MinP,
     }
@@ -238,6 +241,8 @@ func (c *ConfigFile) UnmarshalTOML(data interface{}) error {
             defaultSpec.Creativity = (float32)(v.(float64))
         case "cutoff":
             defaultSpec.Cutoff = (float32)(v.(float64))
+        case "format":
+            defaultSpec.Format = v.(string)
         case "initial-prompt":
             defaultSpec.InitialPrompt = v.(string)
         }
@@ -255,11 +260,28 @@ func (c *ConfigFile) Options(configId string) llama.Options {
     if spec, ok := (*c)[configId]; ok {
         return llama.Options{
             ModelPath: spec.Model,
-            Format:    spec.Format,
             Temp:      spec.Creativity,
             MinP:      spec.Cutoff,
         }
     }
 
     return llama.DefaultOptions
 }
+
+// Prompt returns the llama.Prompt based on the ConfigFile.
+func (c *ConfigFile) Prompt(configId string) llama.Prompt {
+    if spec, ok := (*c)[configId]; ok {
+        return llama.Prompt{
+            Format: spec.Format,
+        }
+    }
+    return llama.Prompt{}
+}
+
+// InitialPrompt returns the initial prompt specified in the ConfigFile.
+func (c *ConfigFile) InitialPrompt(configId string) string {
+    if spec, ok := (*c)[configId]; ok {
+        return spec.InitialPrompt
+    }
+    return ""
+}
4 changes: 2 additions & 2 deletions cmd/boludo/config_test.go
@@ -77,7 +77,7 @@ func TestConfigArgsOptions(t *testing.T) {
         want llama.Options
     }{
         {ConfigArgs{}, llama.DefaultOptions},
-        {ConfigArgs{ModelPath: "model.gguf"}, llama.Options{ModelPath: "model.gguf", Format: "", Temp: 1, MinP: 0}},
+        {ConfigArgs{ModelPath: "model.gguf"}, llama.Options{ModelPath: "model.gguf", Temp: 1, MinP: 0}},
     }
     for _, tc := range testcases {
         tc := tc
@@ -177,7 +177,7 @@ func TestConfigFileOptions(t *testing.T) {
         file ConfigFile
         want llama.Options
     }{
-        {"chat", ConfigFile{"edit": ModelSpec{Model: "editmodel.gguf"}, "chat": ModelSpec{Model: "chatmodel.gguf", Format: "", Creativity: 0.3, Cutoff: 2}}, llama.Options{ModelPath: "chatmodel.gguf", Format: "", Temp: 0.3, MinP: 2}},
+        {"chat", ConfigFile{"edit": ModelSpec{Model: "editmodel.gguf"}, "chat": ModelSpec{Model: "chatmodel.gguf", Format: "", Creativity: 0.3, Cutoff: 2}}, llama.Options{ModelPath: "chatmodel.gguf", Temp: 0.3, MinP: 2}},
         {"invalid", ConfigFile{}, llama.DefaultOptions},
     }
     for _, tc := range testcases {
4 changes: 2 additions & 2 deletions llama/client.go
@@ -49,12 +49,12 @@ type Client struct {
 }
 
 // Complete returns a channel with completion results for given string.
-func (c *Client) Complete(ctx context.Context, s string) (chan string, error) {
+func (c *Client) Complete(ctx context.Context, p Prompt) (chan string, error) {
     if c.Options == nil {
         c.Options = &DefaultOptions
     }
     req := completionRequest{
-        Prompt: s,
+        Prompt: p.String(),
         Temp:   c.Options.Temp,
         TopK:   40,
         MinP:   c.Options.MinP,
5 changes: 4 additions & 1 deletion llama/client_test.go
@@ -18,8 +18,11 @@ func TestComplete(t *testing.T) {
     }
     defer server.Close()
 
+    prompt := Prompt{}
+    prompt.Add("Once upon a time")
+
     client := Client{}
-    c, err := client.Complete(context.TODO(), "Once upon a time")
+    c, err := client.Complete(context.TODO(), prompt)
     if err != nil {
         t.Fatalf("client.Complete() returns error: %v", err)
     }
5 changes: 2 additions & 3 deletions llama/llama.go
@@ -33,8 +33,8 @@ func Serve(ctx context.Context, modelPath string) error {
 }
 
 // Complete returns a channel with completion results for given string.
-func Complete(ctx context.Context, s string) (chan string, error) {
-    return defaultClient.Complete(ctx, s)
+func Complete(ctx context.Context, p Prompt) (chan string, error) {
+    return defaultClient.Complete(ctx, p)
 }
 
 // Close releases all resources used by LLM server.
@@ -45,7 +45,6 @@ func Close() error {
 // Options represent parameters for interacting with LLaMA model.
 type Options struct {
     ModelPath string
-    Format    string
     Temp      float32
     MinP      float32
     Seed      uint
52 changes: 52 additions & 0 deletions llama/prompt.go
@@ -0,0 +1,52 @@
+// Copyright (C) 2023 Maciej Żok
+//
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package llama
+
+import (
+    "fmt"
+    "strings"
+)
+
+// supported prompt formats
+var promptFormats = map[string]func(Prompt) string{
+    "": func(p Prompt) string {
+        return fmt.Sprintf("%s\n%s", p.System, strings.Join(p.userPrompt, "\n"))
+    },
+    "chatml": func(p Prompt) string {
+        systemPrompt := fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", p.System)
+        userPrompt := ""
+        for i := range p.userPrompt {
+            userPrompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", p.userPrompt[i])
+        }
+        if userPrompt == "" {
+            userPrompt = "<|im_start|>user\n<|im_end|>\n"
+        }
+
+        return fmt.Sprintf("%s%s<|im_start|>assistant", systemPrompt, userPrompt)
+    },
+}
+
+// Prompt represents prompt for the LLM.
+type Prompt struct {
+    Format string
+    System string
+    userPrompt []string
+}
+
+// String returns prompt string in format specified by Format.
+// If Format is not specified, returns prompt in default format.
+func (p *Prompt) String() string {
+    formatFunc, ok := promptFormats[strings.ToLower(p.Format)]
+    if !ok {
+        formatFunc = promptFormats[""]
+    }
+
+    return formatFunc(*p)
+}
+
+// Add adds user prompt to the prompt.
+func (p *Prompt) Add(userPrompt string) {
+    p.userPrompt = append(p.userPrompt, userPrompt)
+}
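For the same inputs, the two registered formats render differently. With System set to "You are a helpful assistant." and one user prompt "How are you?", the default ("") format yields the plain concatenation

    You are a helpful assistant.
    How are you?

while "chatml" (matched case-insensitively via strings.ToLower) produces the tagged transcript exercised by the tests below.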
58 changes: 58 additions & 0 deletions llama/prompt_test.go
@@ -0,0 +1,58 @@
+// Copyright (C) 2023 Maciej Żok
+//
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package llama
+
+import (
+    "strings"
+    "testing"
+)
+
+func TestPromptString(t *testing.T) {
+    testcases := []struct {
+        prompt Prompt
+        want   string
+    }{
+        {Prompt{Format: "chatml"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant"},
+        {Prompt{Format: "ChatML", System: "You are a helpful assistant."}, "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant"},
+    }
+
+    for _, tc := range testcases {
+        tc := tc
+        t.Run(tc.prompt.String(), func(t *testing.T) {
+            t.Parallel()
+            got := tc.prompt.String()
+            if got != tc.want {
+                t.Fatalf("Prompt.String() = %v, want %v", got, tc.want)
+            }
+        })
+    }
+}
+
+func TestPromptAdd(t *testing.T) {
+    testcases := []struct {
+        userPrompts []string
+        want        string
+    }{
+        {[]string{}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant"},
+        {[]string{"How are you?"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant"},
+    }
+
+    for _, tc := range testcases {
+        tc := tc
+        t.Run(strings.Join(tc.userPrompts, "; "), func(t *testing.T) {
+            t.Parallel()
+            prompt := Prompt{
+                Format: "ChatML",
+            }
+            for i := range tc.userPrompts {
+                prompt.Add(tc.userPrompts[i])
+            }
+            got := prompt.String()
+            if got != tc.want {
+                t.Fatalf("Prompt.String() = %v, want %v", got, tc.want)
+            }
+        })
+    }
+}
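Taken together, the new API can be exercised end to end. The following is a minimal sketch, not code from this commit: the module path and the model file name are assumptions, while Serve, Complete, Close, and Prompt are the functions shown in the diffs above.

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/macie/boludo/llama" // assumed module path
)

func main() {
    ctx := context.Background()

    // start the LLM server for a local model (file name is an assumption)
    if err := llama.Serve(ctx, "chatmodel.gguf"); err != nil {
        log.Fatal(err)
    }
    defer llama.Close()

    // build a ChatML-formatted prompt, as NewAppConfig now does via ConfigFile.Prompt
    prompt := llama.Prompt{Format: "chatml", System: "You are a helpful assistant."}
    prompt.Add("How are you?")

    // Complete streams generated tokens over a channel
    tokens, err := llama.Complete(ctx, prompt)
    if err != nil {
        log.Fatal(err)
    }
    for t := range tokens {
        fmt.Print(t)
    }
}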
