Skip to content

Commit

Permalink
feat: Enable Zephyr prompt format
Browse files Browse the repository at this point in the history
  • Loading branch information
macie committed Jan 23, 2024
1 parent af9a73d commit 5eba0e2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
12 changes: 12 additions & 0 deletions llama/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ var promptFormats = map[string]func(Prompt) string{

return fmt.Sprintf("%s%sGPT4 Correct Assistant: ", systemPrompt, userPrompt)
},
"zephyr": func(p Prompt) string {
systemPrompt := fmt.Sprintf("<|system|>\n%s</s>\n", p.System)
userPrompt := ""
for i := range p.userPrompt {
userPrompt += fmt.Sprintf("<|user|>\n%s</s>\n", p.userPrompt[i])
}
if userPrompt == "" {
userPrompt = "<|user|>\n</s>\n"
}

return fmt.Sprintf("%s%s<|assistant|>\n", systemPrompt, userPrompt)
},
}

// Prompt represents prompt for the LLM.
Expand Down
17 changes: 11 additions & 6 deletions llama/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ func TestPromptString(t *testing.T) {
{Prompt{Format: "ChatML", System: "You are a helpful assistant."}, "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"},
{Prompt{Format: "openchat"}, "GPT4 Correct User: <|end_of_turn|>GPT4 Correct Assistant: "},
{Prompt{Format: "openchat", System: "You are a helpful assistant."}, "You are a helpful assistant.<|end_of_turn|>GPT4 Correct User: <|end_of_turn|>GPT4 Correct Assistant: "},
{Prompt{Format: "Zephyr"}, "<|system|>\n</s>\n<|user|>\n</s>\n<|assistant|>\n"},
{Prompt{Format: "Zephyr", System: "You are a helpful assistant."}, "<|system|>\nYou are a helpful assistant.</s>\n<|user|>\n</s>\n<|assistant|>\n"},
}

for _, tc := range testcases {
Expand All @@ -30,26 +32,29 @@ func TestPromptString(t *testing.T) {

func TestPromptAdd(t *testing.T) {
testcases := []struct {
format string
userPrompts []string
want string
}{
{[]string{}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"},
{[]string{"How are you?"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n"},
{"ChatML", []string{}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"},
{"chatml", []string{"How are you?"}, "<|im_start|>system\n<|im_end|>\n<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n"},
{"OpenChat", []string{}, "GPT4 Correct User: <|end_of_turn|>GPT4 Correct Assistant: "},
{"openchat", []string{"How are you?"}, "GPT4 Correct User: How are you?<|end_of_turn|>GPT4 Correct Assistant: "},
{"Zephyr", []string{}, "<|system|>\n</s>\n<|user|>\n</s>\n<|assistant|>\n"},
{"zephyr", []string{"How are you?"}, "<|system|>\n</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\n"},
}

for _, tc := range testcases {
tc := tc
t.Run(strings.Join(tc.userPrompts, "; "), func(t *testing.T) {
t.Parallel()
prompt := Prompt{
Format: "ChatML",
}
prompt := Prompt{Format: tc.format}
for i := range tc.userPrompts {
prompt.Add(tc.userPrompts[i])
}
got := prompt.String()
if got != tc.want {
t.Fatalf("Prompt.String() = %v, want %v", got, tc.want)
t.Fatalf("Prompt{Format: %v}.String() = %v, want %v", tc.format, got, tc.want)
}
})
}
Expand Down

0 comments on commit 5eba0e2

Please sign in to comment.