def apply_inverse_chat_template(self, formatted_chat, inverse_chat_template=None):
    """
    Invert `apply_chat_template`: convert a formatted chat string back into the arguments that would have
    produced it (the chat history as a list of dicts, and optionally other inputs like tools).

    Inverting a chat template is not trivial, so a separate Jinja *inverse* template is required. That template
    receives the formatted string as the variable `formatted_chat` and must emit the recovered arguments as JSON,
    which this method then decodes.

    Args:
        formatted_chat (`str`): The formatted chat history to convert back into arguments.
        inverse_chat_template (`str`, *optional*): A Jinja template to use for this conversion. If not provided,
            the tokenizer's `inverse_chat_template` attribute (set from the `inverse_chat_template` keyword
            argument at construction time) is used.

    Returns:
        The Python object decoded from the JSON emitted by the inverse template.

    Raises:
        ValueError: If no inverse template is passed and none is set on the tokenizer. A template that renders
            invalid JSON also surfaces as `json.JSONDecodeError`, which is a `ValueError` subclass.
    """
    # Imported locally so this method does not depend on the module's import block.
    import json

    if inverse_chat_template is None:
        inverse_chat_template = self.inverse_chat_template
    if inverse_chat_template is None:
        raise ValueError("No inverse chat template set, cannot use apply_inverse_chat_template!")

    # Compilation function uses a cache to avoid recompiling the same template
    compiled_template = self._compile_jinja_template(inverse_chat_template)

    # Jinja builds the render context with dict(*args, **kwargs); a bare positional string would be
    # iterated character by character and raise ValueError, so the input must be passed by name.
    rendered = compiled_template.render(formatted_chat=formatted_chat)

    # The inverse template's contract is to emit JSON; decode it so callers get real Python objects.
    return json.loads(rendered)