diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index af217eb844e5f5..8d9fad6ef855a1 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -239,13 +239,24 @@ def _parse_type_hint(hint): return_dict["nullable"] = True return return_dict elif origin is tuple: - raise ValueError( - "This helper does not parse Tuple types, as they are usually used to indicate that " - "each position is associated with a specific type, and this requires JSON schemas " - "that are not supported by most templates. We recommend " - "either using List instead for arguments where this is appropriate, or " - "splitting arguments with Tuple types into multiple arguments that take single inputs." - ) + if not get_args(hint): + return {"type": "array"} + if len(get_args(hint)) == 1: + raise ValueError( + "Tuple type hints should only be used when the argument has a fixed length and each " + f"element has a specific type. The hint {hint} indicates a Tuple of length 1. " + "This should be replaced with an unwrapped type hint instead like " + f"{get_args(hint)[0]}. Alternatively, if the " + "function can actually take a tuple with multiple elements, please either indicate " + f"each element type (e.g. Tuple[{get_args(hint)[0]}, {get_args(hint)[0]}]), " + f"or if the input can be variable length, use List[{get_args(hint)[0]}] instead." + ) + if ... in get_args(hint): + raise ValueError( + "'...' is not supported in Tuple type hints. Use List[] types for variable-length" + " inputs instead." + ) + return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in get_args(hint)]} elif origin is dict: # The JSON equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index b48385ae13b174..65ad7b2a292254 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -1,5 +1,5 @@ import unittest -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from transformers.utils import get_json_schema @@ -234,9 +234,10 @@ def fn(x: int) -> int: "description": "Test function", "parameters": { "type": "object", - "properties": {"x": {"type": "integer", "description": "The input"}, "return": {"type": "integer"}}, + "properties": {"x": {"type": "integer", "description": "The input"}}, "required": ["x"], }, + "return": {"type": "integer"}, } self.assertEqual(schema, expected_schema) @@ -254,6 +255,33 @@ def fn(x: int) -> int: """ return x + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + "return": {"type": "integer", "description": "The output"}, + } + self.assertEqual(schema, expected_schema) + + def test_tuple(self): + def fn(x: Tuple[int, str]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + schema = get_json_schema(fn) expected_schema = { "name": "fn", @@ -261,10 +289,49 @@ def fn(x: int) -> int: "parameters": { "type": "object", "properties": { - "x": {"type": "integer", "description": "The input"}, - "return": {"type": "integer", "description": "The output"}, + "x": { + "type": "array", + "prefixItems": [{"type": "integer"}, {"type": "string"}], + "description": "The input", + } }, "required": ["x"], }, } self.assertEqual(schema, expected_schema) + + def test_single_element_tuple_fails(self): + def fn(x: Tuple[int]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + # Single-element tuples should just be the type itself, or List[type] for variable-length inputs + with self.assertRaises(ValueError): + get_json_schema(fn) + + def test_ellipsis_type_fails(self): + def fn(x: Tuple[int, ...]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + # Variable length inputs should be specified with List[type], not Tuple[type, ...] + with self.assertRaises(ValueError): + get_json_schema(fn)