From a3898113c290a9b4c998dc98a17d51979a3ec0c9 Mon Sep 17 00:00:00 2001 From: macie Date: Tue, 23 Jan 2024 08:28:23 +0100 Subject: [PATCH] feat: Enable Alpaca prompt format --- llama/prompt.go | 16 ++++++++++++++++ llama/prompt_test.go | 4 ++++ 2 files changed, 20 insertions(+) diff --git a/llama/prompt.go b/llama/prompt.go index b79579c..37c4ede 100644 --- a/llama/prompt.go +++ b/llama/prompt.go @@ -10,6 +10,22 @@ var promptFormats = map[string]func(Prompt) string{ "": func(p Prompt) string { return fmt.Sprintf("%s\n%s", p.System, strings.Join(p.userPrompt, "\n")) }, + "alpaca": func(p Prompt) string { + systemPrompt := "" + if p.System != "" { + systemPrompt = fmt.Sprintf("%s\n\n", p.System) + } + + userPrompt := "" + for i := range p.userPrompt { + userPrompt += fmt.Sprintf("### Instruction:\n%s\n\n", p.userPrompt[i]) + } + if userPrompt == "" { + userPrompt = "### Instruction:\n\n" + } + + return fmt.Sprintf("%s%s### Response:\n", systemPrompt, userPrompt) + }, "chatml": func(p Prompt) string { systemPrompt := fmt.Sprintf("<|im_start|>system\n%s<|im_end|>\n", p.System) userPrompt := "" diff --git a/llama/prompt_test.go b/llama/prompt_test.go index c463f39..ded2c0e 100644 --- a/llama/prompt_test.go +++ b/llama/prompt_test.go @@ -10,6 +10,8 @@ func TestPromptString(t *testing.T) { prompt Prompt want string }{ + {Prompt{Format: "alpaca"}, "### Instruction:\n\n### Response:\n"}, + {Prompt{Format: "Alpaca", System: "You are a helpful assistant."}, "You are a helpful assistant.\n\n### Instruction:\n\n### Response:\n"}, {Prompt{Format: "chatml"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"}, {Prompt{Format: "ChatML", System: "You are a helpful assistant."}, "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"}, {Prompt{Format: "openchat"}, "GPT4 Correct User: <|end_of_turn|>GPT4 Correct Assistant: "}, @@ -36,6 +38,8 @@ func TestPromptAdd(t *testing.T) { userPrompts []string want string }{ + {"Alpaca", []string{}, "### Instruction:\n\n### Response:\n"}, + {"alpaca", []string{"How are you?"}, "### Instruction:\nHow are you?\n\n### Response:\n"}, {"ChatML", []string{}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"}, {"chatml", []string{"How are you?"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n"}, {"OpenChat", []string{}, "GPT4 Correct User: <|end_of_turn|>GPT4 Correct Assistant: "},