From 5faad6cc1eb92a71b719105dbd6862b2585fa279 Mon Sep 17 00:00:00 2001 From: TheBurchLog <5104941+TheBurchLog@users.noreply.github.com> Date: Thu, 19 Oct 2023 18:39:21 +0000 Subject: [PATCH] Updated Unit Tests --- brewtils/decorators.py | 72 +++++++++++++++++++++------------------ test/decorators_test.py | 75 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 112 insertions(+), 35 deletions(-) diff --git a/brewtils/decorators.py b/brewtils/decorators.py index 81c926e1..6335dd56 100644 --- a/brewtils/decorators.py +++ b/brewtils/decorators.py @@ -79,7 +79,7 @@ def command( description=None, # type: Optional[str] parameters=None, # type: Optional[List[Parameter]] command_type="ACTION", # type: str - output_type="STRING", # type: str + output_type=None, # type: str schema=None, # type: Optional[Union[dict, str]] form=None, # type: Optional[Union[dict, list, str]] template=None, # type: Optional[str] @@ -120,6 +120,8 @@ def echo_json(self, message): """ if _wrapped is None: + if output_type is None: + output_type = "STRING" if form is not None: _deprecate( "Use of form with @command is now deprecated and will eventually be removed" @@ -149,6 +151,11 @@ def echo_json(self, message): metadata=metadata, ) + if output_type is None: + if str(inspect.signature(_wrapped)._return_annotation) in ["", ""]: + output_type = "JSON" + else: + output_type = "STRING" new_command = Command( description=description, parameters=parameters, @@ -461,8 +468,8 @@ def _initialize_command(method): cmd.name = _method_name(method) cmd.description = cmd.description or _method_docstring(method) - if cmd.output_type is None and str(inspect.signature(method)._return_annotation) in ["", ""]: - cmd.output_type = "JSON" + # if str(inspect.signature(method)._return_annotation) in ["", ""]: + # cmd.output_type = "JSON" try: base_dir = os.path.dirname(inspect.getfile(method)) @@ -538,13 +545,14 @@ def _parameter_docstring(method, parameter): else: docstring = method.__doc__ - delimiters = [":", "--"] - for line in docstring.expandtabs().split("\n"): - line = line.strip() - for delimiter in delimiters: - if delimiter in line: - if line.startswith(parameter + " ") or line.startswith(parameter + delimiter): - return line.split(delimiter)[1].strip() + if docstring: + delimiters = [":", "--"] + for line in docstring.expandtabs().split("\n"): + line = line.strip() + for delimiter in delimiters: + if delimiter in line: + if line.startswith(parameter + " ") or line.startswith(parameter + delimiter): + return line.split(delimiter)[1].strip() return None @@ -570,27 +578,27 @@ def _parameter_type_hint(method, cmd_parameter): docstring = method.func_doc else: docstring = method.__doc__ - - for line in docstring.expandtabs().split("\n"): - line = line.strip() - - if line.startswith(cmd_parameter + " ") and line.find(")") > line.find("("): - docType = line.split("(")[1].split(")")[0] - - if docType in ["str"]: - return "String" - if docType in ["int"]: - return "Integer" - if docType in ["float"]: - return "Float" - if docType in ["bool"]: - return "Boolean" - if docType in ["obj", "object", "dict"]: - return "Dictionary" - if docType.lower() in ["datetime"]: - return "DateTime" - if docType in ["bytes"]: - return "Bytes" + if docstring: + for line in docstring.expandtabs().split("\n"): + line = line.strip() + + if line.startswith(cmd_parameter + " ") and line.find(")") > line.find("("): + docType = line.split("(")[1].split(")")[0] + + if docType in ["str"]: + return "String" + if docType in ["int"]: + return "Integer" + if docType in ["float"]: + return "Float" + if docType in ["bool"]: + return "Boolean" + if docType in ["obj", "object", "dict"]: + return "Dictionary" + if docType.lower() in ["datetime"]: + return "DateTime" + if docType in ["bytes"]: + return "Bytes" return None def _sig_info(arg): @@ -871,7 +879,7 @@ def _signature_parameters(cmd, method): if arg.name not in cmd.parameter_keys(): cmd.parameters.append( _initialize_parameter( - key=arg.name, default=sig_default, optional=sig_optional, type=sig_type, method=method + key=arg.name, default=sig_default, optional=sig_optional, method=method ) ) diff --git a/test/decorators_test.py b/test/decorators_test.py index 9d250526..1ce7cb8e 100644 --- a/test/decorators_test.py +++ b/test/decorators_test.py @@ -180,6 +180,75 @@ def cmd(self, foo): assert_parameter_equal(c.parameters[0], Parameter(**basic_param)) + class TestParametersExtract(object): + """Test that Type Hints and Doc Strings parse""" + + class TestTypeHint(object): + """Type Hint arguments""" + + def test_type_hints_parameter(self): + @command + def cmd(foo:int): + return foo + bg_cmd = _parse_method(cmd) + + assert len(bg_cmd.parameters) == 1 + assert bg_cmd.parameters[0].key == "foo" + assert bg_cmd.parameters[0].type == "Integer" + assert bg_cmd.parameters[0].default is None + assert bg_cmd.parameters[0].optional is False + + def test_type_hints_output(self): + @command + def cmd(foo:int) -> dict: + return foo + bg_cmd = _parse_method(cmd) + + assert bg_cmd.output_type == "JSON" + + class TestDocString(object): + + def test_cmd_description(self): + @command + def cmd(foo): + """Default Command Description + """ + return foo + bg_cmd = _parse_method(cmd) + + assert bg_cmd.description == "Default Command Description" + + def test_param_description(self): + @command + def cmd(foo): + """Default Command Description + + Args: + foo : Parameter Description + """ + return foo + bg_cmd = _parse_method(cmd) + + assert len(bg_cmd.parameters) == 1 + assert bg_cmd.parameters[0].key == "foo" + assert bg_cmd.parameters[0].description == "Parameter Description" + + def test_param_type(self): + @command + def cmd(foo): + """Default Command Description + + Args: + foo (int): Parameter Description + """ + return foo + bg_cmd = _parse_method(cmd) + + assert len(bg_cmd.parameters) == 1 + assert bg_cmd.parameters[0].key == "foo" + assert bg_cmd.parameters[0].type == "Integer" + + class TestParameterReconciliation(object): """Test that the parameters line up correctly""" @@ -962,7 +1031,7 @@ def test_parameter(self, init_mock, param): assert len(res) == 1 assert res[0] == init_mock.return_value - init_mock.assert_called_once_with(param=param) + init_mock.assert_called_once_with(param=param, method=None) def test_deprecated_model(self, init_mock, nested_1): with warnings.catch_warnings(record=True) as w: @@ -972,7 +1041,7 @@ def test_deprecated_model(self, init_mock, nested_1): assert len(res) == 1 assert res[0] == init_mock.return_value - init_mock.assert_called_once_with(param=nested_1.parameters[0]) + init_mock.assert_called_once_with(param=nested_1.parameters[0], method=None) assert issubclass(w[0].category, DeprecationWarning) assert "model class objects" in str(w[0].message) @@ -982,7 +1051,7 @@ def test_dict(self, init_mock, basic_param): assert len(res) == 1 assert res[0] == init_mock.return_value - init_mock.assert_called_once_with(**basic_param) + init_mock.assert_called_once_with(**basic_param, method=None) def test_unknown_type(self): with pytest.raises(PluginParamError):