diff --git a/brewtils/schemas.py b/brewtils/schemas.py index 101f0678..a30a07fe 100644 --- a/brewtils/schemas.py +++ b/brewtils/schemas.py @@ -2,7 +2,7 @@ from functools import partial -from marshmallow import Schema, fields, post_load, pre_load +from marshmallow import Schema, fields, post_load, pre_load, utils from marshmallow_polyfield import PolyField __all__ = [ @@ -89,6 +89,34 @@ def __init__(self, type_field="payload_type", allowed_types=None, **kwargs): ) +class DateTime(fields.DateTime): + """Class that adds methods for (de)serializing DateTime fields as an epoch + + This is required for going from Mongo Model objects to Marshmallow model Objects + """ + + def __init__(self, format="epoch", **kwargs): + self.SERIALIZATION_FUNCS["epoch"] = self.to_epoch + self.DESERIALIZATION_FUNCS["epoch"] = self.from_epoch + super(DateTime, self).__init__(format=format, **kwargs) + + @staticmethod + def to_epoch(value): + # If already in epoch form just return it + if isinstance(value, int): + return value + + return utils.timestamp_ms(value) + + @staticmethod + def from_epoch(value): + # If already in datetime form just return it + if isinstance(value, datetime.datetime): + return value + + return utils.from_timestamp_ms(value) + + class BaseSchema(Schema): @post_load @@ -203,7 +231,7 @@ class FileSchema(BaseSchema): owner = fields.Raw(allow_none=True) job = fields.Nested("JobSchema", allow_none=True) request = fields.Nested("RequestSchema", allow_none=True) - updated_at = fields.DateTime(allow_none=True, format="timestamp_ms") + updated_at = DateTime(allow_none=True, format="epoch") file_name = fields.Str(allow_none=True) file_size = fields.Int(allow_none=False) chunks = fields.Dict(allow_none=True) @@ -283,9 +311,9 @@ class RequestSchema(RequestTemplateSchema): hidden = fields.Boolean(allow_none=True) status = fields.Str(allow_none=True) error_class = fields.Str(allow_none=True) - created_at = fields.DateTime(allow_none=True, format="timestamp_ms") - updated_at = fields.DateTime(allow_none=True, format="timestamp_ms") - status_updated_at = fields.DateTime(allow_none=True, format="timestamp_ms") + created_at = DateTime(allow_none=True, format="epoch") + updated_at = DateTime(allow_none=True, format="epoch") + status_updated_at = DateTime(allow_none=True, format="epoch") has_parent = fields.Bool(allow_none=True) requester = fields.String(allow_none=True) source_garden = fields.String(allow_none=True) @@ -293,12 +321,12 @@ class RequestSchema(RequestTemplateSchema): class StatusHistorySchema(BaseSchema): - heartbeat = fields.DateTime(allow_none=True, format="timestamp_ms") + heartbeat = DateTime(allow_none=True, format="epoch") status = fields.Str(allow_none=True) class StatusInfoSchema(BaseSchema): - heartbeat = fields.DateTime(allow_none=True, format="timestamp_ms") + heartbeat = DateTime(allow_none=True, format="epoch") history = fields.List(fields.Nested(lambda: StatusHistorySchema()), allow_none=True) @@ -351,7 +379,7 @@ class EventSchema(BaseSchema): namespace = fields.Str(allow_none=True) garden = fields.Str(allow_none=True) metadata = fields.Dict(allow_none=True) - timestamp = fields.DateTime(allow_none=True, format="timestamp_ms") + timestamp = DateTime(allow_none=True, format="epoch") payload_type = fields.Str(allow_none=True) payload = ModelField(allow_none=True, type_field="payload_type") @@ -373,13 +401,13 @@ class QueueSchema(BaseSchema): class UserTokenSchema(BaseSchema): id = fields.Str(allow_none=True) uuid = fields.Str(allow_none=True) - issued_at = fields.DateTime(allow_none=True, format="timestamp_ms") - expires_at = fields.DateTime(allow_none=True, format="timestamp_ms") + issued_at = DateTime(allow_none=True, format="epoch") + expires_at = DateTime(allow_none=True, format="epoch") username = fields.Str(allow_none=True) class DateTriggerSchema(BaseSchema): - run_date = fields.DateTime(allow_none=True, format="timestamp_ms") + run_date = DateTime(allow_none=True, format="epoch") timezone = fields.Str(allow_none=True) @@ -389,8 +417,8 @@ class IntervalTriggerSchema(BaseSchema): hours = fields.Int(allow_none=True) minutes = fields.Int(allow_none=True) seconds = fields.Int(allow_none=True) - start_date = fields.DateTime(allow_none=True, format="timestamp_ms") - end_date = fields.DateTime(allow_none=True, format="timestamp_ms") + start_date = DateTime(allow_none=True, format="epoch") + end_date = DateTime(allow_none=True, format="epoch") timezone = fields.Str(allow_none=True) jitter = fields.Int(allow_none=True) reschedule_on_finish = fields.Bool(allow_none=True) @@ -405,8 +433,8 @@ class CronTriggerSchema(BaseSchema): hour = fields.Str(allow_none=True) minute = fields.Str(allow_none=True) second = fields.Str(allow_none=True) - start_date = fields.DateTime(allow_none=True, format="timestamp_ms") - end_date = fields.DateTime(allow_none=True, format="timestamp_ms") + start_date = DateTime(allow_none=True, format="epoch") + end_date = DateTime(allow_none=True, format="epoch") timezone = fields.Str(allow_none=True) jitter = fields.Int(allow_none=True) @@ -465,7 +493,7 @@ class JobSchema(BaseSchema): request_template = fields.Nested("RequestTemplateSchema", allow_none=True) misfire_grace_time = fields.Int(allow_none=True) coalesce = fields.Bool(allow_none=True) - next_run_time = fields.DateTime(allow_none=True, format="timestamp_ms") + next_run_time = DateTime(allow_none=True, format="epoch") success_count = fields.Int(allow_none=True) error_count = fields.Int(allow_none=True) canceled_count = fields.Int(allow_none=True) @@ -578,7 +606,7 @@ class TopicSchema(BaseSchema): class ReplicationSchema(BaseSchema): id = fields.Str(allow_none=True) replication_id = fields.Str(allow_none=True) - expires_at = fields.DateTime(allow_none=True, format="timestamp_ms") + expires_at = DateTime(allow_none=True, format="epoch") class UserSchema(BaseSchema):