diff --git a/brewtils/schema_parser.py b/brewtils/schema_parser.py index 86e7a22c..6853a700 100644 --- a/brewtils/schema_parser.py +++ b/brewtils/schema_parser.py @@ -370,9 +370,9 @@ def parse_job_ids(cls, job_id_json, from_string=False, **kwargs): schema = brewtils.schemas.JobExportInputSchema(**kwargs) if from_string: - return schema.loads(job_id_json).data + return schema.loads(job_id_json) else: - return schema.load(job_id_json).data + return schema.load(job_id_json) @classmethod def parse_garden(cls, garden, from_string=False, **kwargs): @@ -561,7 +561,7 @@ def parse( schema.context["models"] = cls._models - return schema.loads(data).data if from_string else schema.load(data).data + return schema.loads(data) if from_string else schema.load(data) # Serialization methods @classmethod @@ -1162,7 +1162,7 @@ def serialize( schema = getattr(brewtils.schemas, schema_name)(**kwargs) - return schema.dumps(model).data if to_string else schema.dump(model).data + return schema.dumps(model) if to_string else schema.dump(model) # Explicitly force to_string to False so only original call returns a string multiple = [ diff --git a/brewtils/schemas.py b/brewtils/schemas.py index 9df24335..cdd24c71 100644 --- a/brewtils/schemas.py +++ b/brewtils/schemas.py @@ -2,8 +2,6 @@ from functools import partial -import marshmallow -import simplejson from marshmallow import Schema, fields, post_load, pre_load from marshmallow_polyfield import PolyField @@ -92,18 +90,15 @@ def __init__(self, type_field="payload_type", allowed_types=None, **kwargs): class BaseSchema(Schema): - class Meta: - version_nums = marshmallow.__version__.split(".") - if int(version_nums[0]) <= 2 and int(version_nums[1]) < 17: # pragma: no cover - json_module = simplejson - else: - render_module = simplejson - - def __init__(self, strict=True, **kwargs): - super(BaseSchema, self).__init__(strict=strict, **kwargs) + # class Meta: + # version_nums = marshmallow.__version__.split(".") + # if int(version_nums[0]) <= 2 and int(version_nums[1]) < 17: # pragma: no cover + # json_module = simplejson + # else: + # render_module = simplejson @post_load - def make_object(self, data): + def make_object(self, data, **_): try: model_class = self.context["models"][self.__class__.__name__] except KeyError: @@ -123,8 +118,8 @@ def get_attribute_names(cls): class ChoicesSchema(BaseSchema): type = fields.Str(allow_none=True) display = fields.Str(allow_none=True) - value = fields.Raw(allow_none=True, many=True) - strict = fields.Bool(allow_none=True, default=False) + value = fields.List(fields.Raw, allow_none=True) + strict = fields.Bool(allow_none=True, dump_default=False) details = fields.Dict(allow_none=True) @@ -136,8 +131,8 @@ class ParameterSchema(BaseSchema): optional = fields.Bool(allow_none=True) default = fields.Raw(allow_none=True) description = fields.Str(allow_none=True) - choices = fields.Nested("ChoicesSchema", allow_none=True, many=False) - parameters = fields.Nested("self", many=True, allow_none=True) + choices = fields.Nested(lambda: ChoicesSchema, allow_none=True) + parameters = fields.List(fields.Nested(lambda: ParameterSchema), allow_none=True) nullable = fields.Bool(allow_none=True) maximum = fields.Int(allow_none=True) minimum = fields.Int(allow_none=True) @@ -149,7 +144,7 @@ class ParameterSchema(BaseSchema): class CommandSchema(BaseSchema): name = fields.Str(allow_none=True) description = fields.Str(allow_none=True) - parameters = fields.Nested("ParameterSchema", many=True) + parameters = fields.List(fields.Nested(lambda: ParameterSchema()), allow_none=True) command_type = fields.Str(allow_none=True) output_type = fields.Str(allow_none=True) schema = fields.Dict(allow_none=True) @@ -168,7 +163,7 @@ class InstanceSchema(BaseSchema): name = fields.Str(allow_none=True) description = fields.Str(allow_none=True) status = fields.Str(allow_none=True) - status_info = fields.Nested("StatusInfoSchema", allow_none=True) + status_info = fields.Nested(lambda: StatusInfoSchema(), allow_none=True) queue_type = fields.Str(allow_none=True) queue_info = fields.Dict(allow_none=True) icon_name = fields.Str(allow_none=True) @@ -182,8 +177,8 @@ class SystemSchema(BaseSchema): version = fields.Str(allow_none=True) max_instances = fields.Integer(allow_none=True) icon_name = fields.Str(allow_none=True) - instances = fields.Nested("InstanceSchema", many=True, allow_none=True) - commands = fields.Nested("CommandSchema", many=True, allow_none=True) + instances = fields.List(fields.Nested(lambda: InstanceSchema()), allow_none=True) + commands = fields.List(fields.Nested(lambda: CommandSchema()), allow_none=True) display_name = fields.Str(allow_none=True) metadata = fields.Dict(allow_none=True) namespace = fields.Str(allow_none=True) @@ -214,9 +209,7 @@ class FileSchema(BaseSchema): owner = fields.Raw(allow_none=True) job = fields.Nested("JobSchema", allow_none=True) request = fields.Nested("RequestSchema", allow_none=True) - updated_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + updated_at = fields.DateTime(allow_none=True, format="timestamp_ms") file_name = fields.Str(allow_none=True) file_size = fields.Int(allow_none=False) chunks = fields.Dict(allow_none=True) @@ -277,23 +270,28 @@ class RequestTemplateSchema(BaseSchema): class RequestSchema(RequestTemplateSchema): id = fields.Str(allow_none=True) is_event = fields.Bool(allow_none=True) - parent = fields.Nested("self", exclude=("children",), allow_none=True) - children = fields.Nested( - "self", exclude=("parent", "children"), many=True, default=None, allow_none=True + parent = fields.Nested( + lambda: RequestSchema(exclude=("children",)), allow_none=True + ) + children = fields.List( + fields.Nested( + lambda: RequestSchema( + exclude=( + "parent", + "children", + ) + ) + ), + dump_default=None, + allow_none=True, ) output = fields.Str(allow_none=True) hidden = fields.Boolean(allow_none=True) status = fields.Str(allow_none=True) error_class = fields.Str(allow_none=True) - created_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) - updated_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) - status_updated_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + created_at = fields.DateTime(allow_none=True, format="timestamp_ms") + updated_at = fields.DateTime(allow_none=True, format="timestamp_ms") + status_updated_at = fields.DateTime(allow_none=True, format="timestamp_ms") has_parent = fields.Bool(allow_none=True) requester = fields.String(allow_none=True) source_garden = fields.String(allow_none=True) @@ -301,17 +299,13 @@ class RequestSchema(RequestTemplateSchema): class StatusHistorySchema(BaseSchema): - heartbeat = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + heartbeat = fields.DateTime(allow_none=True, format="timestamp_ms") status = fields.Str(allow_none=True) class StatusInfoSchema(BaseSchema): - heartbeat = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) - history = fields.Nested("StatusHistorySchema", many=True, allow_none=True) + heartbeat = fields.DateTime(allow_none=True, format="timestamp_ms") + history = fields.List(fields.Nested(lambda: StatusHistorySchema()), allow_none=True) class PatchSchema(BaseSchema): @@ -320,7 +314,7 @@ class PatchSchema(BaseSchema): value = fields.Raw(allow_none=True) @pre_load(pass_many=True) - def unwrap_envelope(self, data, many): + def unwrap_envelope(self, data, many, **_): """Helper function for parsing the different patch formats. This exists because previously multiple patches serialized like:: @@ -363,9 +357,7 @@ class EventSchema(BaseSchema): namespace = fields.Str(allow_none=True) garden = fields.Str(allow_none=True) metadata = fields.Dict(allow_none=True) - timestamp = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + timestamp = fields.DateTime(allow_none=True, format="timestamp_ms") payload_type = fields.Str(allow_none=True) payload = ModelField(allow_none=True, type_field="payload_type") @@ -387,19 +379,13 @@ class QueueSchema(BaseSchema): class UserTokenSchema(BaseSchema): id = fields.Str(allow_none=True) uuid = fields.Str(allow_none=True) - issued_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) - expires_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + issued_at = fields.DateTime(allow_none=True, format="timestamp_ms") + expires_at = fields.DateTime(allow_none=True, format="timestamp_ms") username = fields.Str(allow_none=True) class DateTriggerSchema(BaseSchema): - run_date = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + run_date = fields.DateTime(allow_none=True, format="timestamp_ms") timezone = fields.Str(allow_none=True) @@ -409,12 +395,8 @@ class IntervalTriggerSchema(BaseSchema): hours = fields.Int(allow_none=True) minutes = fields.Int(allow_none=True) seconds = fields.Int(allow_none=True) - start_date = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) - end_date = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + start_date = fields.DateTime(allow_none=True, format="timestamp_ms") + end_date = fields.DateTime(allow_none=True, format="timestamp_ms") timezone = fields.Str(allow_none=True) jitter = fields.Int(allow_none=True) reschedule_on_finish = fields.Bool(allow_none=True) @@ -429,12 +411,8 @@ class CronTriggerSchema(BaseSchema): hour = fields.Str(allow_none=True) minute = fields.Str(allow_none=True) second = fields.Str(allow_none=True) - start_date = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) - end_date = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + start_date = fields.DateTime(allow_none=True, format="timestamp_ms") + end_date = fields.DateTime(allow_none=True, format="timestamp_ms") timezone = fields.Str(allow_none=True) jitter = fields.Int(allow_none=True) @@ -452,7 +430,7 @@ class FileTriggerSchema(BaseSchema): class ConnectionSchema(BaseSchema): api = fields.Str(allow_none=True) status = fields.Str(allow_none=True) - status_info = fields.Nested("StatusInfoSchema", allow_none=True) + status_info = fields.Nested(lambda: StatusInfoSchema(), allow_none=True) config = fields.Dict(allow_none=True) @@ -460,20 +438,20 @@ class GardenSchema(BaseSchema): id = fields.Str(allow_none=True) name = fields.Str(allow_none=True) status = fields.Str(allow_none=True) - status_info = fields.Nested("StatusInfoSchema", allow_none=True) + status_info = fields.Nested(lambda: StatusInfoSchema(), allow_none=True) connection_type = fields.Str(allow_none=True) - receiving_connections = fields.Nested( - "ConnectionSchema", many=True, allow_none=True + receiving_connections = fields.List( + fields.Nested(lambda: ConnectionSchema()), allow_none=True ) - publishing_connections = fields.Nested( - "ConnectionSchema", many=True, allow_none=True + publishing_connections = fields.List( + fields.Nested(lambda: ConnectionSchema()), allow_none=True ) namespaces = fields.List(fields.Str(), allow_none=True) - systems = fields.Nested("SystemSchema", many=True, allow_none=True) + systems = fields.List(fields.Nested(lambda: SystemSchema()), allow_none=True) has_parent = fields.Bool(allow_none=True) parent = fields.Str(allow_none=True) - children = fields.Nested( - "self", exclude=("parent"), many=True, default=None, allow_none=True + children = fields.List( + fields.Nested(lambda: GardenSchema(exclude=("parent",))), allow_none=True ) metadata = fields.Dict(allow_none=True) default_user = fields.Str(allow_none=True) @@ -493,9 +471,7 @@ class JobSchema(BaseSchema): request_template = fields.Nested("RequestTemplateSchema", allow_none=True) misfire_grace_time = fields.Int(allow_none=True) coalesce = fields.Bool(allow_none=True) - next_run_time = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + next_run_time = fields.DateTime(allow_none=True, format="timestamp_ms") success_count = fields.Int(allow_none=True) error_count = fields.Int(allow_none=True) canceled_count = fields.Int(allow_none=True) @@ -608,9 +584,7 @@ class TopicSchema(BaseSchema): class ReplicationSchema(BaseSchema): id = fields.Str(allow_none=True) replication_id = fields.Str(allow_none=True) - expires_at = fields.DateTime( - allow_none=True, format="timestamp", example="1500065932000" - ) + expires_at = fields.DateTime(allow_none=True, format="timestamp_ms") class UserSchema(BaseSchema): diff --git a/setup.py b/setup.py index bb425e10..20292885 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def find_version(): install_requires=[ "appdirs<2", "lark-parser<1", - "marshmallow<4", + "marshmallow<4,>=3.3", "marshmallow-polyfield<6", "packaging", "pika<=1.4,>=1.0.1", diff --git a/test/schema_parser_test.py b/test/schema_parser_test.py index d49652ee..28e200b9 100644 --- a/test/schema_parser_test.py +++ b/test/schema_parser_test.py @@ -60,11 +60,11 @@ def test_error(self, data, kwargs, error): with pytest.raises(error): SchemaParser.parse_system(data, **kwargs) - def test_non_strict_failure(self, system_dict): - system_dict["name"] = 1234 - value = SchemaParser.parse_system(system_dict, from_string=False, strict=False) - assert value.get("name") is None - assert value["version"] == system_dict["version"] + # def test_non_strict_failure(self, system_dict): + # system_dict["name"] = 1234 + # value = SchemaParser.parse_system(system_dict, from_string=False, strict=False) + # assert value.get("name") is None + # assert value["version"] == system_dict["version"] def test_no_modify(self, system_dict): system_copy = copy.deepcopy(system_dict) diff --git a/test/schema_test.py b/test/schema_test.py index da25bd61..b95f0a6b 100644 --- a/test/schema_test.py +++ b/test/schema_test.py @@ -6,17 +6,15 @@ from pytest_lazyfixture import lazy_fixture from brewtils.models import System +from brewtils.schema_parser import SchemaParser from brewtils.schemas import ( BaseSchema, - DateTime, SystemSchema, _deserialize_model, _serialize_model, model_schema_map, ) -from brewtils.schema_parser import SchemaParser - class TestSchemas(object): def test_make_object(self): @@ -36,29 +34,6 @@ def test_get_attributes(self): class TestFields(object): - @pytest.mark.parametrize( - "dt,localtime,expected", - [ - (lazy_fixture("ts_dt"), False, lazy_fixture("ts_epoch")), - (lazy_fixture("ts_dt"), True, lazy_fixture("ts_epoch")), - (lazy_fixture("ts_dt_eastern"), False, lazy_fixture("ts_epoch_eastern")), - (lazy_fixture("ts_dt_eastern"), True, lazy_fixture("ts_epoch")), - (lazy_fixture("ts_epoch"), False, lazy_fixture("ts_epoch")), - (lazy_fixture("ts_epoch"), True, lazy_fixture("ts_epoch")), - ], - ) - def test_to_epoch(self, dt, localtime, expected): - assert DateTime.to_epoch(dt, localtime) == expected - - @pytest.mark.parametrize( - "epoch,expected", - [ - (lazy_fixture("ts_epoch"), lazy_fixture("ts_dt")), - (lazy_fixture("ts_dt"), lazy_fixture("ts_dt")), - ], - ) - def test_from_epoch(self, epoch, expected): - assert DateTime.from_epoch(epoch) == expected def test_modelfield_serialize_invalid_type(self): with pytest.raises(TypeError):