diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 802b3efb..ba9b94b3 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,12 @@ Brewtils Changelog ================== +3.20.0 +------ +TBD + +- Expanded Auto Generation to support Literal Type Hinting, if python version >= 3.8 + 3.19.0 ------ 10/20/2023 diff --git a/brewtils/decorators.py b/brewtils/decorators.py index a0f7734a..0082faf0 100644 --- a/brewtils/decorators.py +++ b/brewtils/decorators.py @@ -19,6 +19,9 @@ else: from inspect import signature, Parameter as InspectParameter # noqa + if sys.version_info.major == 3 and sys.version_info.minor >= 8: + from typing import get_args + __all__ = [ "client", "command", @@ -563,22 +566,48 @@ def _parameter_docstring(method, parameter): return None +def _choices_type_hint(method, cmd_parameter): + if sys.version_info.major == 3 and sys.version_info.minor >= 8: + for _, arg in enumerate(signature(method).parameters.values()): + if arg.name == cmd_parameter: + if str(arg.annotation).startswith("typing.Literal"): + arg_choices = list() + for arg_choice in get_args(arg.annotation): + arg_choices.append(arg_choice) + return arg_choices + + return None + + def _parameter_type_hint(method, cmd_parameter): for _, arg in enumerate(signature(method).parameters.values()): if arg.name == cmd_parameter: - if str(arg.annotation) in ["<class 'str'>"]: + type_hint_class = str(arg.annotation) + if type_hint_class.startswith("typing.Literal"): + if sys.version_info.major == 3 and sys.version_info.minor >= 8: + choice_types = None + for arg_choice in get_args(arg.annotation): + if choice_types is None: + choice_types = type(arg_choice) + elif type(arg_choice) is not choice_types: + choice_types = None + break + + if choice_types is not None: + type_hint_class = str(choice_types) + if type_hint_class in ["<class 'str'>"]: return "String" - if str(arg.annotation) in ["<class 'int'>"]: + if type_hint_class in ["<class 'int'>"]: return "Integer" - if str(arg.annotation) in ["<class 'float'>"]: + if type_hint_class in ["<class 'float'>"]: return "Float" - if str(arg.annotation) in ["<class 'bool'>"]: + if type_hint_class in ["<class 'bool'>"]: return "Boolean" - if str(arg.annotation) in ["<class 'object'>", "<class 'dict'>"]: + if type_hint_class in ["<class 'object'>", "<class 'dict'>"]: return "Dictionary" - if str(arg.annotation).lower() in ["<class 'datetime'>"]: + if type_hint_class.lower() in ["<class 'datetime'>"]: return "DateTime" - if str(arg.annotation) in ["<class 'bytes'>"]: + if type_hint_class in ["<class 'bytes'>"]: return "Bytes" if hasattr(method, "func_doc"): @@ -720,6 +749,9 @@ def _initialize_parameter( if param.type is None and method is not None: param.type = _parameter_type_hint(method, param.key) + if param.choices is None and method is not None: + param.choices = _choices_type_hint(method, param.key) + # Type and type info # Type info is where type specific information goes. For now, this is specific # to file types. See #289 for more details. diff --git a/brewtils/rest/client.py b/brewtils/rest/client.py index ffe14fac..808d02b2 100644 --- a/brewtils/rest/client.py +++ b/brewtils/rest/client.py @@ -227,7 +227,7 @@ def can_connect(self, **kwargs): try: self.session.get(self.config_url, **kwargs) except requests.exceptions.ConnectionError as ex: - if type(ex) == requests.exceptions.ConnectionError: + if type(ex) is requests.exceptions.ConnectionError: return False raise @@ -831,7 +831,7 @@ def post_chunked_file(self, fd, file_params, current_position=0): data = fd.read(file_params["chunk_size"]) if not data: break - if type(data) != bytes: + if type(data) is not bytes: data = bytes(data, "utf-8") data = b64encode(data) chunk_result = self.session.post( diff --git a/test/decorators_test.py b/test/decorators_test.py index 583cb086..0f7936d5 100644 --- a/test/decorators_test.py +++ b/test/decorators_test.py @@ -6,6 +6,9 @@ import pytest from mock import Mock +if sys.version_info.major == 3 and sys.version_info.minor >= 8: + from typing import Literal + import brewtils.decorators from brewtils.decorators import ( _format_type, @@ -208,6 +211,54 @@ def cmd(foo: int) -> dict: assert bg_cmd.output_type == "JSON" + def test_type_hints_choices_any(self): + if sys.version_info.major == 3 and sys.version_info.minor >= 8: + + @command + def cmd(foo: Literal["a", 2] = "a") -> dict: + return foo + + bg_cmd = _parse_method(cmd) + + assert len(bg_cmd.parameters) == 1 + assert bg_cmd.parameters[0].key == "foo" + assert bg_cmd.parameters[0].type == "Any" + assert bg_cmd.parameters[0].choices.value == ["a", 2] + assert bg_cmd.parameters[0].default == "a" + assert bg_cmd.parameters[0].optional is True + + def test_type_hints_choices_string(self): + if sys.version_info.major == 3 and sys.version_info.minor >= 8: + + @command + def cmd(foo: Literal["a", "b"] = "a") -> dict: + return foo + + bg_cmd = _parse_method(cmd) + + assert len(bg_cmd.parameters) == 1 + assert bg_cmd.parameters[0].key == "foo" + assert bg_cmd.parameters[0].type == "String" + assert bg_cmd.parameters[0].choices.value == ["a", "b"] + assert bg_cmd.parameters[0].default == "a" + assert bg_cmd.parameters[0].optional is True + + def test_type_hints_choices_integer(self): + if sys.version_info.major == 3 and sys.version_info.minor >= 8: + + @command + def cmd(foo: Literal[1, 2] = 1) -> dict: + return foo + + bg_cmd = _parse_method(cmd) + + assert len(bg_cmd.parameters) == 1 + assert bg_cmd.parameters[0].key == "foo" + assert bg_cmd.parameters[0].type == "Integer" + assert bg_cmd.parameters[0].choices.value == [1, 2] + assert bg_cmd.parameters[0].default == 1 + assert bg_cmd.parameters[0].optional is True + class TestDocString(object): def test_cmd_description(self): @command