From 893bb830c6832725b5ec4cb0298904960948fcd6 Mon Sep 17 00:00:00 2001
From: macie
Date: Mon, 1 Jan 2024 15:58:42 +0100
Subject: [PATCH] feat: Introduce ChatML prompt format

Models are taught with specially formatted prompts. Using the same
format for inference gives the best results.
---
 cmd/boludo/config.go      | 34 +++++++++++++++++++----
 cmd/boludo/config_test.go |  4 +--
 llama/client.go           |  4 +--
 llama/client_test.go      |  5 +++-
 llama/llama.go            |  5 ++--
 llama/prompt.go           | 52 +++++++++++++++++++++++++++++++++++
 llama/prompt_test.go      | 58 +++++++++++++++++++++++++++++++++++++++
 7 files changed, 148 insertions(+), 14 deletions(-)
 create mode 100644 llama/prompt.go
 create mode 100644 llama/prompt_test.go

diff --git a/cmd/boludo/config.go b/cmd/boludo/config.go
index 63c5de7..7543cbb 100644
--- a/cmd/boludo/config.go
+++ b/cmd/boludo/config.go
@@ -48,7 +48,7 @@ func Version() string {
 type AppConfig struct {
 	Options     llama.Options
 	ServerPath  string
-	Prompt      string
+	Prompt      llama.Prompt
 	Timeout     time.Duration
 	Verbose     bool
 	ExitMessage string
@@ -91,10 +91,13 @@ func NewAppConfig(cliArgs []string) (AppConfig, error) {
 	options.Update(configFile.Options(configArgs.ConfigId))
 	options.Update(configArgs.Options())
 
-	prompt := configArgs.Prompt
-	if spec, ok := configFile[configArgs.ConfigId]; ok {
-		prompt = fmt.Sprintf("%s %s", spec.InitialPrompt, configArgs.Prompt)
+	prompt := configFile.Prompt(configArgs.ConfigId)
+	initialPrompt := configFile.InitialPrompt(configArgs.ConfigId)
+	userPrompt := configArgs.Prompt
+	if initialPrompt != "" {
+		userPrompt = fmt.Sprintf("%s %s", initialPrompt, userPrompt)
 	}
+	prompt.Add(userPrompt)
 
 	return AppConfig{
 		Prompt: prompt,
@@ -226,7 +229,7 @@ func (c *ConfigFile) UnmarshalTOML(data interface{}) error {
 	defaultSpec := ModelSpec{
 		Model:         "",
 		InitialPrompt: "",
-		Format:        llama.DefaultOptions.Format,
+		Format:        "",
 		Creativity:    llama.DefaultOptions.Temp,
 		Cutoff:        llama.DefaultOptions.MinP,
 	}
@@ -238,6 +241,8 @@
 			defaultSpec.Creativity = (float32)(v.(float64))
 		case "cutoff":
 			defaultSpec.Cutoff = (float32)(v.(float64))
+		case "format":
+			defaultSpec.Format = v.(string)
 		case "initial-prompt":
 			defaultSpec.InitialPrompt = v.(string)
 		}
@@ -255,7 +260,6 @@ func (c *ConfigFile) Options(configId string) llama.Options {
 	if spec, ok := (*c)[configId]; ok {
 		return llama.Options{
 			ModelPath: spec.Model,
-			Format:    spec.Format,
 			Temp:      spec.Creativity,
 			MinP:      spec.Cutoff,
 		}
@@ -263,3 +267,21 @@
 
 	return llama.DefaultOptions
 }
+
+// Prompt returns the llama.Prompt based on the ConfigFile.
+func (c *ConfigFile) Prompt(configId string) llama.Prompt {
+	if spec, ok := (*c)[configId]; ok {
+		return llama.Prompt{
+			Format: spec.Format,
+		}
+	}
+	return llama.Prompt{}
+}
+
+// InitialPrompt returns the initial prompt specified in the ConfigFile.
+func (c *ConfigFile) InitialPrompt(configId string) string {
+	if spec, ok := (*c)[configId]; ok {
+		return spec.InitialPrompt
+	}
+	return ""
+}
diff --git a/cmd/boludo/config_test.go b/cmd/boludo/config_test.go
index 94373f0..d2edda0 100644
--- a/cmd/boludo/config_test.go
+++ b/cmd/boludo/config_test.go
@@ -77,7 +77,7 @@ func TestConfigArgsOptions(t *testing.T) {
 		want llama.Options
 	}{
 		{ConfigArgs{}, llama.DefaultOptions},
-		{ConfigArgs{ModelPath: "model.gguf"}, llama.Options{ModelPath: "model.gguf", Format: "", Temp: 1, MinP: 0}},
+		{ConfigArgs{ModelPath: "model.gguf"}, llama.Options{ModelPath: "model.gguf", Temp: 1, MinP: 0}},
 	}
 	for _, tc := range testcases {
 		tc := tc
@@ -177,7 +177,7 @@ func TestConfigFileOptions(t *testing.T) {
 		file ConfigFile
 		want llama.Options
 	}{
-		{"chat", ConfigFile{"edit": ModelSpec{Model: "editmodel.gguf"}, "chat": ModelSpec{Model: "chatmodel.gguf", Format: "", Creativity: 0.3, Cutoff: 2}}, llama.Options{ModelPath: "chatmodel.gguf", Format: "", Temp: 0.3, MinP: 2}},
+		{"chat", ConfigFile{"edit": ModelSpec{Model: "editmodel.gguf"}, "chat": ModelSpec{Model: "chatmodel.gguf", Format: "", Creativity: 0.3, Cutoff: 2}}, llama.Options{ModelPath: "chatmodel.gguf", Temp: 0.3, MinP: 2}},
 		{"invalid", ConfigFile{}, llama.DefaultOptions},
 	}
 	for _, tc := range testcases {
diff --git a/llama/client.go b/llama/client.go
index a59a3b6..5bab9b3 100644
--- a/llama/client.go
+++ b/llama/client.go
@@ -49,12 +49,12 @@ type Client struct {
 }
 
 // Complete returns a channel with completion results for given string.
-func (c *Client) Complete(ctx context.Context, s string) (chan string, error) {
+func (c *Client) Complete(ctx context.Context, p Prompt) (chan string, error) {
 	if c.Options == nil {
 		c.Options = &DefaultOptions
 	}
 	req := completionRequest{
-		Prompt: s,
+		Prompt: p.String(),
 		Temp:   c.Options.Temp,
 		TopK:   40,
 		MinP:   c.Options.MinP,
diff --git a/llama/client_test.go b/llama/client_test.go
index 23cbac2..7bb89e2 100644
--- a/llama/client_test.go
+++ b/llama/client_test.go
@@ -18,8 +18,11 @@ func TestComplete(t *testing.T) {
 	}
 	defer server.Close()
 
+	prompt := Prompt{}
+	prompt.Add("Once upon a time")
+
 	client := Client{}
-	c, err := client.Complete(context.TODO(), "Once upon a time")
+	c, err := client.Complete(context.TODO(), prompt)
 	if err != nil {
 		t.Fatalf("client.Complete() returns error: %v", err)
 	}
diff --git a/llama/llama.go b/llama/llama.go
index 8ada8f7..fc16c58 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -33,8 +33,8 @@ func Serve(ctx context.Context, modelPath string) error {
 }
 
 // Complete returns a channel with completion results for given string.
-func Complete(ctx context.Context, s string) (chan string, error) {
-	return defaultClient.Complete(ctx, s)
+func Complete(ctx context.Context, p Prompt) (chan string, error) {
+	return defaultClient.Complete(ctx, p)
 }
 
 // Close releases all resources used by LLM server.
@@ -45,7 +45,6 @@ func Close() error {
 // Options represent parameters for interacting with LLaMA model.
 type Options struct {
 	ModelPath string
-	Format    string
 	Temp      float32
 	MinP      float32
 	Seed      uint
diff --git a/llama/prompt.go b/llama/prompt.go
new file mode 100644
index 0000000..87ffdb6
--- /dev/null
+++ b/llama/prompt.go
@@ -0,0 +1,52 @@
+// Copyright (C) 2023 Maciej Żok
+//
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package llama
+
+import (
+	"fmt"
+	"strings"
+)
+
+// supported prompt formats
+var promptFormats = map[string]func(Prompt) string{
+	"": func(p Prompt) string {
+		return fmt.Sprintf("%s\n%s", p.System, strings.Join(p.userPrompt, "\n"))
+	},
+	"chatml": func(p Prompt) string {
+		systemPrompt := fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", p.System)
+		userPrompt := ""
+		for i := range p.userPrompt {
+			userPrompt += fmt.Sprintf("<|im_start|>user\n%s<|im_end|>\n", p.userPrompt[i])
+		}
+		if userPrompt == "" {
+			userPrompt = "<|im_start|>user\n<|im_end|>\n"
+		}
+
+		return fmt.Sprintf("%s%s<|im_start|>assistant", systemPrompt, userPrompt)
+	},
+}
+
+// Prompt represents prompt for the LLM.
+type Prompt struct {
+	Format     string
+	System     string
+	userPrompt []string
+}
+
+// String returns prompt string in format specified by Format.
+// If Format is not specified, returns prompt in default format.
+func (p *Prompt) String() string {
+	formatFunc, ok := promptFormats[strings.ToLower(p.Format)]
+	if !ok {
+		formatFunc = promptFormats[""]
+	}
+
+	return formatFunc(*p)
+}
+
+// Add adds user prompt to the prompt.
+func (p *Prompt) Add(userPrompt string) {
+	p.userPrompt = append(p.userPrompt, userPrompt)
+}
diff --git a/llama/prompt_test.go b/llama/prompt_test.go
new file mode 100644
index 0000000..89ee1d0
--- /dev/null
+++ b/llama/prompt_test.go
@@ -0,0 +1,58 @@
+// Copyright (C) 2023 Maciej Żok
+//
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+package llama
+
+import (
+	"strings"
+	"testing"
+)
+
+func TestPromptString(t *testing.T) {
+	testcases := []struct {
+		prompt Prompt
+		want   string
+	}{
+		{Prompt{Format: "chatml"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant"},
+		{Prompt{Format: "ChatML", System: "You are a helpful assistant."}, "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant"},
+	}
+
+	for _, tc := range testcases {
+		tc := tc
+		t.Run(tc.prompt.String(), func(t *testing.T) {
+			t.Parallel()
+			got := tc.prompt.String()
+			if got != tc.want {
+				t.Fatalf("Prompt.String() = %v, want %v", got, tc.want)
+			}
+		})
+	}
+}
+
+func TestPromptAdd(t *testing.T) {
+	testcases := []struct {
+		userPrompts []string
+		want        string
+	}{
+		{[]string{}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant"},
+		{[]string{"How are you?"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant"},
+	}
+
+	for _, tc := range testcases {
+		tc := tc
+		t.Run(strings.Join(tc.userPrompts, "; "), func(t *testing.T) {
+			t.Parallel()
+			prompt := Prompt{
+				Format: "ChatML",
+			}
+			for i := range tc.userPrompts {
+				prompt.Add(tc.userPrompts[i])
+			}
+			got := prompt.String()
+			if got != tc.want {
+				t.Fatalf("Prompt.String() = %v, want %v", got, tc.want)
+			}
+		})
+	}
+}
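
For reference, below is a minimal, self-contained sketch of the prompt layout that the new "chatml" formatter in llama/prompt.go produces. It deliberately mirrors the formatter instead of importing the package (the module import path is not part of this patch), so the lowercase names here are illustrative only, not the package API.

	package main

	import (
		"fmt"
		"strings"
	)

	// prompt mirrors the exported surface of llama.Prompt added by this patch:
	// a format name, a system prompt, and accumulated user messages.
	type prompt struct {
		format string
		system string
		user   []string
	}

	// render reproduces the two formatters from llama/prompt.go: the default
	// plain concatenation, and ChatML with one system block, one block per
	// user message (or an empty user block), and a trailing assistant header
	// for the model to complete.
	func render(p prompt) string {
		if strings.ToLower(p.format) != "chatml" {
			return fmt.Sprintf("%s\n%s", p.system, strings.Join(p.user, "\n"))
		}
		var b strings.Builder
		fmt.Fprintf(&b, "<|im_start|>system\n%s<|im_end|>\n", p.system)
		if len(p.user) == 0 {
			b.WriteString("<|im_start|>user\n<|im_end|>\n")
		}
		for _, u := range p.user {
			fmt.Fprintf(&b, "<|im_start|>user\n%s<|im_end|>\n", u)
		}
		b.WriteString("<|im_start|>assistant")
		return b.String()
	}

	func main() {
		p := prompt{format: "ChatML", system: "You are a helpful assistant."}
		p.user = append(p.user, "How are you?")
		fmt.Println(render(p))
	}

Running it prints:

	<|im_start|>system
	You are a helpful assistant.<|im_end|>
	<|im_start|>user
	How are you?<|im_end|>
	<|im_start|>assistant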