diff --git a/requirements/test.txt b/requirements/test.txt index 893c49f0..a2c7b4d6 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,2 +1,4 @@ -r install.txt +enum34 +marshmallow_enum unittest2 diff --git a/tests/marshmallow_serializers.py b/tests/marshmallow_serializers.py new file mode 100644 index 00000000..c1b8101f --- /dev/null +++ b/tests/marshmallow_serializers.py @@ -0,0 +1,52 @@ +# serializers.py - custom serializers for unit tests +# using the marshmallow de/serialization library +# +# .._Marshmallow: http://marshmallow.readthedocs.io/en/latest/index.html +# +# Copyright 2018 Kiptoo Magutt . +# Copyright 2012, 2013, 2014, 2015, 2016 Jeffrey Finkelstein +# and contributors. +# +# This file is part of Flask-Restless. +# +# Flask-Restless is distributed under both the GNU Affero General Public +# License version 3 and under the 3-clause BSD license. For more +# information, see LICENSE.AGPL and LICENSE.BSD. +"""Helper functions for unit tests.""" + +from flask_restless import DefaultSerializer +from flask_restless import DefaultDeserializer + + +class MarshmallowSerializer(DefaultSerializer): + """ + Base class for models that need custom serializers + using the marshmallow library + + See + :class:`TestUpdating.TestSupport.AddressSchema` + for example usage + + """ + schema_class = None + + def serialize(self, instance, only=None): + schema = self.schema_class(only=only) + return schema.dump(instance).data + + def serialize_many(self, instances, only=None): + schema = self.schema_class(many=True, only=only) + return schema.dump(instances).data + + +class MarshmallowDeserializer(DefaultDeserializer): + + schema_class = None + + def deserialize(self, document): + schema = self.schema_class() + return schema.load(document).data + + def deserialize_many(self, document): + schema = self.schema_class(many=True) + return schema.load(document).data diff --git a/tests/test_updating.py b/tests/test_updating.py index 542bba22..be9efb82 100644 --- a/tests/test_updating.py +++ b/tests/test_updating.py @@ -22,6 +22,7 @@ from datetime import datetime from unittest2 import skip +from enum import IntEnum try: from flask_sqlalchemy import SQLAlchemy @@ -37,11 +38,17 @@ from sqlalchemy import Integer from sqlalchemy import Time from sqlalchemy import Unicode +from sqlalchemy import Enum +from sqlalchemy import String from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref from sqlalchemy.orm import relationship +from marshmallow import post_load +from marshmallow_jsonapi import Schema, fields +from marshmallow_enum import EnumField + from flask_restless import APIManager from flask_restless import JSONAPI_MIMETYPE from flask_restless import ProcessingException @@ -56,6 +63,12 @@ from .helpers import ManagerTestBase from .helpers import raise_s_exception as raise_exception +from .marshmallow_serializers import MarshmallowSerializer +from .marshmallow_serializers import MarshmallowDeserializer + +# from flask_restless import DefaultSerializer +# from flask_restless import DefaultDeserializer + class TestUpdating(ManagerTestBase): """Tests for updating resources.""" @@ -63,8 +76,13 @@ class TestUpdating(ManagerTestBase): def setUp(self): """Creates the database, the :class:`~flask.Flask` object, the :class:`~flask_restless.manager.APIManager` for that application, and - creates the ReSTful API endpoints for the :class:`TestSupport.Person` - and :class:`TestSupport.Article` models. + creates the ReSTful API endpoints for the :class:`TestSupport.Person`, + :class:`TestSupport.Article` and :class:`TestSupport.Address` models. + + For custom serialization, it also creates + :class:`TestSupport.AddressSchema`, + :class:`TestSupport.AddressSerializer` and + :class:`TestSupport.AddressDeserializer` """ super(TestUpdating, self).setUp() @@ -87,6 +105,37 @@ class Person(self.Base): def foo(self): return u'foo' + class AddressType(IntEnum): + HOME, OFFICE, UNKNOWN = range(3) + + class Address(self.Base): + __tablename__ = 'address' + id = Column(Integer, primary_key=True) + address_str = Column(String(30)) + address_type = Column(Enum(AddressType), default=AddressType.UNKNOWN) + # field to force model serialization, based on 'changes on update' + time_updated = Column(DateTime, onupdate=datetime.utcnow) + + class AddressSchema(Schema): + id = fields.Integer(dump_only=True) + address_str = fields.Str() + address_type = EnumField(AddressType) + + class Meta: + type_ = 'address' + model = Address + strict = True + + @post_load + def make_object(self, data): + return self.Meta.model(**data) + + class AddressSerializer(MarshmallowSerializer): + schema_class = AddressSchema + + class AddressDeserializer(MarshmallowDeserializer): + schema_class = AddressSchema + # This example comes from the SQLAlchemy documentation. # # The SQLAlchemy documentation is licensed under the MIT license. @@ -122,11 +171,18 @@ class Tag(self.Base): self.Article = Article self.Interval = Interval self.Person = Person + self.Address = Address + self.AddressType = AddressType self.Tag = Tag self.Base.metadata.create_all() + # deserializer_class = DefaultDeserializer + # deserializer_ = deserializer_class(session=self.session, model=Person,allow_client_generated_ids=True) self.manager.create_api(Article, methods=['PATCH']) self.manager.create_api(Interval, methods=['PATCH']) self.manager.create_api(Person, methods=['PATCH']) + self.manager.create_api(Address, methods=['PATCH'], + serializer_class=AddressSerializer, + deserializer_class=AddressDeserializer) def test_wrong_content_type(self): """Tests that if a client specifies only :http:header:`Accept` @@ -257,6 +313,25 @@ def test_deserializing_datetime(self): assert response.status_code == 204 assert person.birth_datetime == now + def test_deserializing_enum_field(self): + """Test for deserializing a JSON representation of an Enum field.""" + address = self.Address(id=1) + self.session.add(address) + self.session.commit() + data = { + 'data': { + 'type': 'address', + 'id': '1', + 'attributes': { + 'address_type': "HOME" + } + } + } + data = dumps(data) + response = self.app.patch('/api/address/1', data=data) + assert response.status_code == 200 # changes on update were forced + assert address.address_type == self.AddressType.HOME + def test_correct_content_type(self): """Tests that the server responds with :http:status:`201` if the request has the correct JSON API content type.