forked from AkariAsai/OpenScholar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_prompt_templates.py
213 lines (176 loc) · 6.78 KB
/
_prompt_templates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import Dict, List, Protocol, Tuple
from torchtune.data import Message, Role
class PromptTemplateInterface(Protocol):
    """
    Structural interface for prompt templates.

    A prompt template decorates a conversation with role-specific text:
    for each role it may prepend a tag before, and append a tag after,
    the message content.
    """

    # Maps each role to a (prepend_tag, append_tag) pair. Use empty strings
    # for roles whose content should pass through untouched.
    template: Dict[Role, Tuple[str, str]]

    def __call__(
        self,
        messages: List[Message],
    ) -> List[Message]:
        """
        Apply the template's tags to every message of a conversation.

        Args:
            messages (List[Message]): a single conversation, structured as a list
                of :class:`~torchtune.data.Message` objects

        Returns:
            The formatted list of messages
        """
        ...
class PromptTemplate(PromptTemplateInterface):
    """
    Define a custom prompt template from a dictionary mapping role to a
    ``(PREPEND_TAG, APPEND_TAG)`` pair. For example, to produce::

        System: {content}\\n
        User: {content}\\n
        Assistant: {content}\\n
        Tool: {content}\\n

    the template would be::

        template = {
            "system": ("System: ", "\\n"),
            "user": ("User: ", "\\n"),
            "assistant": ("Assistant: ", "\\n"),
            "ipython": ("Tool: ", "\\n"),
        }

    Calling the instantiated template on a list of messages returns the same
    messages with the tags applied. Roles absent from the dictionary are
    passed through unchanged.

    Note:
        Tags attached to the assistant message are included in the loss
        calculation, while tags on the other roles (system, user, ipython)
        are, in most cases, not. To place text before the assistant message
        without including it in loss, use the user message's append tag.
        For more custom masking and templating, implement your own class
        based on the :class:`~torchtune.data.PromptTemplate` interface.

    Args:
        template (Dict[Role, Tuple[str, str]]): a dictionary mapping role to the
            prepend and append tags
    """

    def __init__(
        self,
        template: Dict[Role, Tuple[str, str]],
    ):
        self.template = template

    def __call__(self, messages: List[Message]) -> List[Message]:
        """
        Apply the configured prepend/append tags to each message's content.

        Args:
            messages (List[Message]): list of messages to apply the template to

        Returns:
            List[Message]: The formatted list of messages
        """
        result: List[Message] = []
        for msg in messages:
            tags = self.template.get(msg.role)
            if tags is None:
                # No tags configured for this role; keep content as-is.
                new_content = msg.content
            else:
                prepend_tag, append_tag = tags
                # Content is a list of typed items; wrap it with text items
                # carrying the tags.
                new_content = (
                    [{"type": "text", "content": prepend_tag}]
                    + msg.content
                    + [{"type": "text", "content": append_tag}]
                )
            result.append(
                Message(
                    role=msg.role,
                    content=new_content,
                    masked=msg.masked,
                    ipython=msg.ipython,
                    eot=msg.eot,
                )
            )
        return result
class ChatMLTemplate(PromptTemplateInterface):
    """
    OpenAI's `Chat Markup Language
    <https://github.com/MicrosoftDocs/azure-docs/blob/772c14eeabfa0c0c561d5c2d34ef19341f528b7b/articles/ai-services/openai/how-to/chat-markup-language.md>`_
    used by their chat models. It is also the default chat template used by
    Hugging Face models.

    .. code-block:: text

        <|im_start|>system
        Provide some context and/or instructions to the model.<|im_end|>
        <|im_start|>user
        The user’s message goes here<|im_end|>
        <|im_start|>assistant
        The assistant’s response goes here<|im_end|>
    """

    # (prepend, append) tags per role; ipython (tool) messages are untagged.
    template = {
        "system": ("<|im_start|>system\n", "<|im_end|>\n"),
        "user": ("<|im_start|>user\n", "<|im_end|>\n"),
        "assistant": ("<|im_start|>assistant\n", "<|im_end|>"),
        "ipython": ("", ""),
    }

    def __call__(
        self,
        messages: List[Message],
    ) -> List[Message]:
        """
        Wrap user, assistant, and system messages in ChatML tags.

        Args:
            messages (List[Message]): a single conversation, structured as a list
                of `Message` objects

        Returns:
            The formatted list of messages
        """
        formatted: List[Message] = []
        for msg in messages:
            start_tag, end_tag = self.template[msg.role]
            tagged_content = (
                [{"type": "text", "content": start_tag}]
                + msg.content
                + [{"type": "text", "content": end_tag}]
            )
            formatted.append(
                Message(
                    role=msg.role,
                    content=tagged_content,
                    masked=msg.masked,
                    ipython=msg.ipython,
                    eot=msg.eot,
                )
            )
        return formatted
# Ready-made template for grammar error correction: wraps only the user
# message, leaving the assistant response to follow the "Corrected: " cue.
GrammarErrorCorrectionTemplate = partial(
    PromptTemplate,
    template={
        "user": ("Correct this to standard English: ", "\n---\nCorrected: "),
    },
)
# partial objects have no docstring of their own, so attach one explicitly
# so the template shows up properly in generated API docs.
GrammarErrorCorrectionTemplate.__doc__ = """
A prompt template for grammar error correction tasks::
Correct this to standard English: {user_message}
---
Corrected: {assistant_message}
Please see :class:`~torchtune.data.PromptTemplate` for full API arguments.
"""
# Ready-made template for dialogue summarization: wraps only the user
# message, leaving the assistant response to follow the "Summary:" cue.
SummarizeTemplate = partial(
    PromptTemplate,
    template={
        "user": ("Summarize this dialogue:\n", "\n---\nSummary:\n"),
    },
)
# partial objects have no docstring of their own, so attach one explicitly
# so the template shows up properly in generated API docs.
SummarizeTemplate.__doc__ = """
A prompt template for summarization tasks::
Summarize this dialogue:
{user_message}
---
Summary:
{assistant_message}
Please see :class:`~torchtune.data.PromptTemplate` for full API arguments.
"""