From a03544e5541c67b3cfcf930f89a8e06f650fa1dd Mon Sep 17 00:00:00 2001 From: Rose Davidson Date: Thu, 26 Jul 2018 16:55:05 -0700 Subject: [PATCH 1/6] Import protorpc from google/protorpc repo. --- endpoints/internal/__init__.py | 15 + endpoints/internal/protorpc/__init__.py | 21 + endpoints/internal/protorpc/definition.py | 290 +++ .../internal/protorpc/definition_test.py | 657 +++++ endpoints/internal/protorpc/descriptor.py | 712 ++++++ .../internal/protorpc/descriptor_test.py | 649 +++++ endpoints/internal/protorpc/end2end_test.py | 148 ++ .../protorpc/experimental/__init__.py | 20 + .../protorpc/experimental/parser/protobuf.g | 159 ++ .../experimental/parser/protobuf_lexer.g | 153 ++ .../protorpc/experimental/parser/pyprotobuf.g | 45 + .../protorpc/experimental/parser/test.proto | 27 + endpoints/internal/protorpc/generate.py | 128 + endpoints/internal/protorpc/generate_proto.py | 127 + .../internal/protorpc/generate_proto_test.py | 197 ++ .../internal/protorpc/generate_python.py | 218 ++ .../internal/protorpc/generate_python_test.py | 362 +++ endpoints/internal/protorpc/generate_test.py | 152 ++ endpoints/internal/protorpc/google_imports.py | 15 + endpoints/internal/protorpc/message_types.py | 119 + .../internal/protorpc/message_types_test.py | 88 + endpoints/internal/protorpc/messages.py | 1949 +++++++++++++++ endpoints/internal/protorpc/messages_test.py | 2109 +++++++++++++++++ .../internal/protorpc/non_sdk_imports.py | 21 + endpoints/internal/protorpc/protobuf.py | 359 +++ endpoints/internal/protorpc/protobuf_test.py | 299 +++ endpoints/internal/protorpc/protojson.py | 363 +++ endpoints/internal/protorpc/protojson_test.py | 565 +++++ .../internal/protorpc/protorpc_test.proto | 83 + .../internal/protorpc/protorpc_test_pb2.py | 405 ++++ endpoints/internal/protorpc/protourlencode.py | 563 +++++ .../internal/protorpc/protourlencode_test.py | 369 +++ endpoints/internal/protorpc/registry.py | 240 ++ endpoints/internal/protorpc/registry_test.py | 124 + endpoints/internal/protorpc/remote.py | 1248 ++++++++++ endpoints/internal/protorpc/remote_test.py | 933 ++++++++ endpoints/internal/protorpc/static/base.html | 57 + endpoints/internal/protorpc/static/forms.html | 31 + endpoints/internal/protorpc/static/forms.js | 685 ++++++ .../protorpc/static/jquery-1.4.2.min.js | 154 ++ .../protorpc/static/jquery.json-2.2.min.js | 31 + .../internal/protorpc/static/methods.html | 37 + endpoints/internal/protorpc/test_util.py | 671 ++++++ endpoints/internal/protorpc/transport.py | 412 ++++ endpoints/internal/protorpc/transport_test.py | 493 ++++ endpoints/internal/protorpc/util.py | 494 ++++ endpoints/internal/protorpc/util_test.py | 394 +++ .../internal/protorpc/webapp/__init__.py | 18 + endpoints/internal/protorpc/webapp/forms.py | 163 ++ .../internal/protorpc/webapp/forms_test.py | 159 ++ .../protorpc/webapp/google_imports.py | 25 + .../protorpc/webapp/service_handlers.py | 834 +++++++ .../protorpc/webapp/service_handlers_test.py | 1332 +++++++++++ .../internal/protorpc/webapp_test_util.py | 411 ++++ endpoints/internal/protorpc/wsgi/__init__.py | 16 + endpoints/internal/protorpc/wsgi/service.py | 267 +++ .../internal/protorpc/wsgi/service_test.py | 205 ++ endpoints/internal/protorpc/wsgi/util.py | 180 ++ endpoints/internal/protorpc/wsgi/util_test.py | 295 +++ 59 files changed, 21296 insertions(+) create mode 100644 endpoints/internal/__init__.py create mode 100644 endpoints/internal/protorpc/__init__.py create mode 100644 endpoints/internal/protorpc/definition.py create mode 100644 endpoints/internal/protorpc/definition_test.py create mode 100644 endpoints/internal/protorpc/descriptor.py create mode 100644 endpoints/internal/protorpc/descriptor_test.py create mode 100644 endpoints/internal/protorpc/end2end_test.py create mode 100644 endpoints/internal/protorpc/experimental/__init__.py create mode 100644 endpoints/internal/protorpc/experimental/parser/protobuf.g create mode 100644 endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g create mode 100644 endpoints/internal/protorpc/experimental/parser/pyprotobuf.g create mode 100644 endpoints/internal/protorpc/experimental/parser/test.proto create mode 100644 endpoints/internal/protorpc/generate.py create mode 100644 endpoints/internal/protorpc/generate_proto.py create mode 100644 endpoints/internal/protorpc/generate_proto_test.py create mode 100644 endpoints/internal/protorpc/generate_python.py create mode 100644 endpoints/internal/protorpc/generate_python_test.py create mode 100644 endpoints/internal/protorpc/generate_test.py create mode 100644 endpoints/internal/protorpc/google_imports.py create mode 100644 endpoints/internal/protorpc/message_types.py create mode 100644 endpoints/internal/protorpc/message_types_test.py create mode 100644 endpoints/internal/protorpc/messages.py create mode 100644 endpoints/internal/protorpc/messages_test.py create mode 100644 endpoints/internal/protorpc/non_sdk_imports.py create mode 100644 endpoints/internal/protorpc/protobuf.py create mode 100644 endpoints/internal/protorpc/protobuf_test.py create mode 100644 endpoints/internal/protorpc/protojson.py create mode 100644 endpoints/internal/protorpc/protojson_test.py create mode 100644 endpoints/internal/protorpc/protorpc_test.proto create mode 100644 endpoints/internal/protorpc/protorpc_test_pb2.py create mode 100644 endpoints/internal/protorpc/protourlencode.py create mode 100644 endpoints/internal/protorpc/protourlencode_test.py create mode 100644 endpoints/internal/protorpc/registry.py create mode 100644 endpoints/internal/protorpc/registry_test.py create mode 100644 endpoints/internal/protorpc/remote.py create mode 100644 endpoints/internal/protorpc/remote_test.py create mode 100644 endpoints/internal/protorpc/static/base.html create mode 100644 endpoints/internal/protorpc/static/forms.html create mode 100644 endpoints/internal/protorpc/static/forms.js create mode 100644 endpoints/internal/protorpc/static/jquery-1.4.2.min.js create mode 100644 endpoints/internal/protorpc/static/jquery.json-2.2.min.js create mode 100644 endpoints/internal/protorpc/static/methods.html create mode 100644 endpoints/internal/protorpc/test_util.py create mode 100644 endpoints/internal/protorpc/transport.py create mode 100644 endpoints/internal/protorpc/transport_test.py create mode 100644 endpoints/internal/protorpc/util.py create mode 100644 endpoints/internal/protorpc/util_test.py create mode 100644 endpoints/internal/protorpc/webapp/__init__.py create mode 100644 endpoints/internal/protorpc/webapp/forms.py create mode 100644 endpoints/internal/protorpc/webapp/forms_test.py create mode 100644 endpoints/internal/protorpc/webapp/google_imports.py create mode 100644 endpoints/internal/protorpc/webapp/service_handlers.py create mode 100644 endpoints/internal/protorpc/webapp/service_handlers_test.py create mode 100644 endpoints/internal/protorpc/webapp_test_util.py create mode 100644 endpoints/internal/protorpc/wsgi/__init__.py create mode 100644 endpoints/internal/protorpc/wsgi/service.py create mode 100644 endpoints/internal/protorpc/wsgi/service_test.py create mode 100644 endpoints/internal/protorpc/wsgi/util.py create mode 100644 endpoints/internal/protorpc/wsgi/util_test.py diff --git a/endpoints/internal/__init__.py b/endpoints/internal/__init__.py new file mode 100644 index 0000000..7bbe865 --- /dev/null +++ b/endpoints/internal/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedded libraries.""" diff --git a/endpoints/internal/protorpc/__init__.py b/endpoints/internal/protorpc/__init__.py new file mode 100644 index 0000000..9005262 --- /dev/null +++ b/endpoints/internal/protorpc/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Main module for ProtoRPC package.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' +__version__ = '1.0' diff --git a/endpoints/internal/protorpc/definition.py b/endpoints/internal/protorpc/definition.py new file mode 100644 index 0000000..46ee167 --- /dev/null +++ b/endpoints/internal/protorpc/definition.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Stub library.""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import sys +import types + +from . import descriptor +from . import message_types +from . import messages +from . import protobuf +from . import remote +from . import util + +__all__ = [ + 'define_enum', + 'define_field', + 'define_file', + 'define_message', + 'define_service', + 'import_file', + 'import_file_set', +] + + +# Map variant back to message field classes. +def _build_variant_map(): + """Map variants to fields. + + Returns: + Dictionary mapping field variant to its associated field type. + """ + result = {} + for name in dir(messages): + value = getattr(messages, name) + if isinstance(value, type) and issubclass(value, messages.Field): + for variant in getattr(value, 'VARIANTS', []): + result[variant] = value + return result + +_VARIANT_MAP = _build_variant_map() + +_MESSAGE_TYPE_MAP = { + message_types.DateTimeMessage.definition_name(): message_types.DateTimeField, +} + + +def _get_or_define_module(full_name, modules): + """Helper method for defining new modules. + + Args: + full_name: Fully qualified name of module to create or return. + modules: Dictionary of all modules. Defaults to sys.modules. + + Returns: + Named module if found in 'modules', else creates new module and inserts in + 'modules'. Will also construct parent modules if necessary. + """ + module = modules.get(full_name) + if not module: + module = types.ModuleType(full_name) + modules[full_name] = module + + split_name = full_name.rsplit('.', 1) + if len(split_name) > 1: + parent_module_name, sub_module_name = split_name + parent_module = _get_or_define_module(parent_module_name, modules) + setattr(parent_module, sub_module_name, module) + + return module + + +def define_enum(enum_descriptor, module_name): + """Define Enum class from descriptor. + + Args: + enum_descriptor: EnumDescriptor to build Enum class from. + module_name: Module name to give new descriptor class. + + Returns: + New messages.Enum sub-class as described by enum_descriptor. + """ + enum_values = enum_descriptor.values or [] + + class_dict = dict((value.name, value.number) for value in enum_values) + class_dict['__module__'] = module_name + return type(str(enum_descriptor.name), (messages.Enum,), class_dict) + + +def define_field(field_descriptor): + """Define Field instance from descriptor. + + Args: + field_descriptor: FieldDescriptor class to build field instance from. + + Returns: + New field instance as described by enum_descriptor. + """ + field_class = _VARIANT_MAP[field_descriptor.variant] + params = {'number': field_descriptor.number, + 'variant': field_descriptor.variant, + } + + if field_descriptor.label == descriptor.FieldDescriptor.Label.REQUIRED: + params['required'] = True + elif field_descriptor.label == descriptor.FieldDescriptor.Label.REPEATED: + params['repeated'] = True + + message_type_field = _MESSAGE_TYPE_MAP.get(field_descriptor.type_name) + if message_type_field: + return message_type_field(**params) + elif field_class in (messages.EnumField, messages.MessageField): + return field_class(field_descriptor.type_name, **params) + else: + if field_descriptor.default_value: + value = field_descriptor.default_value + try: + value = descriptor._DEFAULT_FROM_STRING_MAP[field_class](value) + except (TypeError, ValueError, KeyError): + pass # Let the value pass to the constructor. + params['default'] = value + return field_class(**params) + + +def define_message(message_descriptor, module_name): + """Define Message class from descriptor. + + Args: + message_descriptor: MessageDescriptor to describe message class from. + module_name: Module name to give to new descriptor class. + + Returns: + New messages.Message sub-class as described by message_descriptor. + """ + class_dict = {'__module__': module_name} + + for enum in message_descriptor.enum_types or []: + enum_instance = define_enum(enum, module_name) + class_dict[enum.name] = enum_instance + + # TODO(rafek): support nested messages when supported by descriptor. + + for field in message_descriptor.fields or []: + field_instance = define_field(field) + class_dict[field.name] = field_instance + + class_name = message_descriptor.name.encode('utf-8') + return type(class_name, (messages.Message,), class_dict) + + +def define_service(service_descriptor, module): + """Define a new service proxy. + + Args: + service_descriptor: ServiceDescriptor class that describes the service. + module: Module to add service to. Request and response types are found + relative to this module. + + Returns: + Service class proxy capable of communicating with a remote server. + """ + class_dict = {'__module__': module.__name__} + class_name = service_descriptor.name.encode('utf-8') + + for method_descriptor in service_descriptor.methods or []: + request_definition = messages.find_definition( + method_descriptor.request_type, module) + response_definition = messages.find_definition( + method_descriptor.response_type, module) + + method_name = method_descriptor.name.encode('utf-8') + def remote_method(self, request): + """Actual service method.""" + raise NotImplementedError('Method is not implemented') + remote_method.__name__ = method_name + remote_method_decorator = remote.method(request_definition, + response_definition) + + class_dict[method_name] = remote_method_decorator(remote_method) + + service_class = type(class_name, (remote.Service,), class_dict) + return service_class + + +def define_file(file_descriptor, module=None): + """Define module from FileDescriptor. + + Args: + file_descriptor: FileDescriptor instance to describe module from. + module: Module to add contained objects to. Module name overrides value + in file_descriptor.package. Definitions are added to existing + module if provided. + + Returns: + If no module provided, will create a new module with its name set to the + file descriptor's package. If a module is provided, returns the same + module. + """ + if module is None: + module = types.ModuleType(file_descriptor.package) + + for enum_descriptor in file_descriptor.enum_types or []: + enum_class = define_enum(enum_descriptor, module.__name__) + setattr(module, enum_descriptor.name, enum_class) + + for message_descriptor in file_descriptor.message_types or []: + message_class = define_message(message_descriptor, module.__name__) + setattr(module, message_descriptor.name, message_class) + + for service_descriptor in file_descriptor.service_types or []: + service_class = define_service(service_descriptor, module) + setattr(module, service_descriptor.name, service_class) + + return module + + +@util.positional(1) +def import_file(file_descriptor, modules=None): + """Import FileDescriptor in to module space. + + This is like define_file except that a new module and any required parent + modules are created and added to the modules parameter or sys.modules if not + provided. + + Args: + file_descriptor: FileDescriptor instance to describe module from. + modules: Dictionary of modules to update. Modules and their parents that + do not exist will be created. If an existing module is found that + matches file_descriptor.package, that module is updated with the + FileDescriptor contents. + + Returns: + Module found in modules, else a new module. + """ + if not file_descriptor.package: + raise ValueError('File descriptor must have package name') + + if modules is None: + modules = sys.modules + + module = _get_or_define_module(file_descriptor.package.encode('utf-8'), + modules) + + return define_file(file_descriptor, module) + + +@util.positional(1) +def import_file_set(file_set, modules=None, _open=open): + """Import FileSet in to module space. + + Args: + file_set: If string, open file and read serialized FileSet. Otherwise, + a FileSet instance to import definitions from. + modules: Dictionary of modules to update. Modules and their parents that + do not exist will be created. If an existing module is found that + matches file_descriptor.package, that module is updated with the + FileDescriptor contents. + _open: Used for dependency injection during tests. + """ + if isinstance(file_set, six.string_types): + encoded_file = _open(file_set, 'rb') + try: + encoded_file_set = encoded_file.read() + finally: + encoded_file.close() + + file_set = protobuf.decode_message(descriptor.FileSet, encoded_file_set) + + for file_descriptor in file_set.files: + # Do not reload built in protorpc classes. + if not file_descriptor.package.startswith('protorpc.'): + import_file(file_descriptor, modules=modules) diff --git a/endpoints/internal/protorpc/definition_test.py b/endpoints/internal/protorpc/definition_test.py new file mode 100644 index 0000000..992220e --- /dev/null +++ b/endpoints/internal/protorpc/definition_test.py @@ -0,0 +1,657 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.stub.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import StringIO +import sys +import types +import unittest + +from protorpc import definition +from protorpc import descriptor +from protorpc import message_types +from protorpc import messages +from protorpc import protobuf +from protorpc import remote +from protorpc import test_util + +import mox + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = definition + + +class DefineEnumTest(test_util.TestCase): + """Test for define_enum.""" + + def testDefineEnum_Empty(self): + """Test defining an empty enum.""" + enum_descriptor = descriptor.EnumDescriptor() + enum_descriptor.name = 'Empty' + + enum_class = definition.define_enum(enum_descriptor, 'whatever') + + self.assertEquals('Empty', enum_class.__name__) + self.assertEquals('whatever', enum_class.__module__) + + self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) + + def testDefineEnum(self): + """Test defining an enum.""" + red = descriptor.EnumValueDescriptor() + green = descriptor.EnumValueDescriptor() + blue = descriptor.EnumValueDescriptor() + + red.name = 'RED' + red.number = 1 + green.name = 'GREEN' + green.number = 2 + blue.name = 'BLUE' + blue.number = 3 + + enum_descriptor = descriptor.EnumDescriptor() + enum_descriptor.name = 'Colors' + enum_descriptor.values = [red, green, blue] + + enum_class = definition.define_enum(enum_descriptor, 'whatever') + + self.assertEquals('Colors', enum_class.__name__) + self.assertEquals('whatever', enum_class.__module__) + + self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) + + +class DefineFieldTest(test_util.TestCase): + """Test for define_field.""" + + def testDefineField_Optional(self): + """Test defining an optional field instance from a method descriptor.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT32 + field_descriptor.label = descriptor.FieldDescriptor.Label.OPTIONAL + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.IntegerField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.INT32, field.variant) + self.assertFalse(field.required) + self.assertFalse(field.repeated) + + def testDefineField_Required(self): + """Test defining a required field instance from a method descriptor.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING + field_descriptor.label = descriptor.FieldDescriptor.Label.REQUIRED + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.StringField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) + self.assertTrue(field.required) + self.assertFalse(field.repeated) + + def testDefineField_Repeated(self): + """Test defining a repeated field instance from a method descriptor.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.DOUBLE + field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.FloatField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.DOUBLE, field.variant) + self.assertFalse(field.required) + self.assertTrue(field.repeated) + + def testDefineField_Message(self): + """Test defining a message field.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE + field_descriptor.type_name = 'something.yet.to.be.Defined' + field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.MessageField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) + self.assertFalse(field.required) + self.assertTrue(field.repeated) + self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, + 'Could not find definition for ' + 'something.yet.to.be.Defined', + getattr, field, 'type') + + def testDefineField_DateTime(self): + """Test defining a date time field.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_timestamp' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE + field_descriptor.type_name = 'protorpc.message_types.DateTimeMessage' + field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, message_types.DateTimeField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) + self.assertFalse(field.required) + self.assertTrue(field.repeated) + + def testDefineField_Enum(self): + """Test defining an enum field.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.ENUM + field_descriptor.type_name = 'something.yet.to.be.Defined' + field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.EnumField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.ENUM, field.variant) + self.assertFalse(field.required) + self.assertTrue(field.repeated) + self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, + 'Could not find definition for ' + 'something.yet.to.be.Defined', + getattr, field, 'type') + + def testDefineField_Default_Bool(self): + """Test defining a default value for a bool.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.BOOL + field_descriptor.default_value = u'true' + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.BooleanField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.BOOL, field.variant) + self.assertFalse(field.required) + self.assertFalse(field.repeated) + self.assertEqual(field.default, True) + + field_descriptor.default_value = u'false' + + field = definition.define_field(field_descriptor) + + self.assertEqual(field.default, False) + + def testDefineField_Default_Float(self): + """Test defining a default value for a float.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.FLOAT + field_descriptor.default_value = u'34.567' + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.FloatField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.FLOAT, field.variant) + self.assertFalse(field.required) + self.assertFalse(field.repeated) + self.assertEqual(field.default, 34.567) + + def testDefineField_Default_Int(self): + """Test defining a default value for an int.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 + field_descriptor.default_value = u'34' + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.IntegerField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.INT64, field.variant) + self.assertFalse(field.required) + self.assertFalse(field.repeated) + self.assertEqual(field.default, 34) + + def testDefineField_Default_Str(self): + """Test defining a default value for a str.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING + field_descriptor.default_value = u'Test' + + field = definition.define_field(field_descriptor) + + # Name will not be set from the original descriptor. + self.assertFalse(hasattr(field, 'name')) + + self.assertTrue(isinstance(field, messages.StringField)) + self.assertEquals(1, field.number) + self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) + self.assertFalse(field.required) + self.assertFalse(field.repeated) + self.assertEqual(field.default, u'Test') + + def testDefineField_Default_Invalid(self): + """Test defining a default value that is not valid.""" + field_descriptor = descriptor.FieldDescriptor() + + field_descriptor.name = 'a_field' + field_descriptor.number = 1 + field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 + field_descriptor.default_value = u'Test' + + # Verify that the string is passed to the Constructor. + mock = mox.Mox() + mock.StubOutWithMock(messages.IntegerField, '__init__') + messages.IntegerField.__init__( + default=u'Test', + number=1, + variant=messages.Variant.INT64 + ).AndRaise(messages.InvalidDefaultError) + + mock.ReplayAll() + self.assertRaises(messages.InvalidDefaultError, + definition.define_field, field_descriptor) + mock.VerifyAll() + + mock.ResetAll() + mock.UnsetStubs() + + +class DefineMessageTest(test_util.TestCase): + """Test for define_message.""" + + def testDefineMessageEmpty(self): + """Test definition a message with no fields or enums.""" + + class AMessage(messages.Message): + pass + + message_descriptor = descriptor.describe_message(AMessage) + + message_class = definition.define_message(message_descriptor, '__main__') + + self.assertEquals('AMessage', message_class.__name__) + self.assertEquals('__main__', message_class.__module__) + + self.assertEquals(message_descriptor, + descriptor.describe_message(message_class)) + + def testDefineMessageEnumOnly(self): + """Test definition a message with only enums.""" + + class AMessage(messages.Message): + class NestedEnum(messages.Enum): + pass + + message_descriptor = descriptor.describe_message(AMessage) + + message_class = definition.define_message(message_descriptor, '__main__') + + self.assertEquals('AMessage', message_class.__name__) + self.assertEquals('__main__', message_class.__module__) + + self.assertEquals(message_descriptor, + descriptor.describe_message(message_class)) + + def testDefineMessageFieldsOnly(self): + """Test definition a message with only fields.""" + + class AMessage(messages.Message): + + field1 = messages.IntegerField(1) + field2 = messages.StringField(2) + + message_descriptor = descriptor.describe_message(AMessage) + + message_class = definition.define_message(message_descriptor, '__main__') + + self.assertEquals('AMessage', message_class.__name__) + self.assertEquals('__main__', message_class.__module__) + + self.assertEquals(message_descriptor, + descriptor.describe_message(message_class)) + + def testDefineMessage(self): + """Test defining Message class from descriptor.""" + + class AMessage(messages.Message): + class NestedEnum(messages.Enum): + pass + + field1 = messages.IntegerField(1) + field2 = messages.StringField(2) + + message_descriptor = descriptor.describe_message(AMessage) + + message_class = definition.define_message(message_descriptor, '__main__') + + self.assertEquals('AMessage', message_class.__name__) + self.assertEquals('__main__', message_class.__module__) + + self.assertEquals(message_descriptor, + descriptor.describe_message(message_class)) + + +class DefineServiceTest(test_util.TestCase): + """Test service proxy definition.""" + + def setUp(self): + """Set up mock and request classes.""" + self.module = types.ModuleType('stocks') + + class GetQuoteRequest(messages.Message): + __module__ = 'stocks' + + symbols = messages.StringField(1, repeated=True) + + class GetQuoteResponse(messages.Message): + __module__ = 'stocks' + + prices = messages.IntegerField(1, repeated=True) + + self.module.GetQuoteRequest = GetQuoteRequest + self.module.GetQuoteResponse = GetQuoteResponse + + def testDefineService(self): + """Test service definition from descriptor.""" + method_descriptor = descriptor.MethodDescriptor() + method_descriptor.name = 'get_quote' + method_descriptor.request_type = 'GetQuoteRequest' + method_descriptor.response_type = 'GetQuoteResponse' + + service_descriptor = descriptor.ServiceDescriptor() + service_descriptor.name = 'Stocks' + service_descriptor.methods = [method_descriptor] + + StockService = definition.define_service(service_descriptor, self.module) + + self.assertTrue(issubclass(StockService, remote.Service)) + self.assertTrue(issubclass(StockService.Stub, remote.StubBase)) + + request = self.module.GetQuoteRequest() + service = StockService() + self.assertRaises(NotImplementedError, + service.get_quote, request) + + self.assertEquals(self.module.GetQuoteRequest, + service.get_quote.remote.request_type) + self.assertEquals(self.module.GetQuoteResponse, + service.get_quote.remote.response_type) + + +class ModuleTest(test_util.TestCase): + """Test for module creation and importation functions.""" + + def MakeFileDescriptor(self, package): + """Helper method to construct FileDescriptors. + + Creates FileDescriptor with a MessageDescriptor and an EnumDescriptor. + + Args: + package: Package name to give new file descriptors. + + Returns: + New FileDescriptor instance. + """ + enum_descriptor = descriptor.EnumDescriptor() + enum_descriptor.name = u'MyEnum' + + message_descriptor = descriptor.MessageDescriptor() + message_descriptor.name = u'MyMessage' + + service_descriptor = descriptor.ServiceDescriptor() + service_descriptor.name = u'MyService' + + file_descriptor = descriptor.FileDescriptor() + file_descriptor.package = package + file_descriptor.enum_types = [enum_descriptor] + file_descriptor.message_types = [message_descriptor] + file_descriptor.service_types = [service_descriptor] + + return file_descriptor + + def testDefineModule(self): + """Test define_module function.""" + file_descriptor = self.MakeFileDescriptor('my.package') + + module = definition.define_file(file_descriptor) + + self.assertEquals('my.package', module.__name__) + self.assertEquals('my.package', module.MyEnum.__module__) + self.assertEquals('my.package', module.MyMessage.__module__) + self.assertEquals('my.package', module.MyService.__module__) + + self.assertEquals(file_descriptor, descriptor.describe_file(module)) + + def testDefineModule_ReuseModule(self): + """Test updating module with additional definitions.""" + file_descriptor = self.MakeFileDescriptor('my.package') + + module = types.ModuleType('override') + self.assertEquals(module, definition.define_file(file_descriptor, module)) + + self.assertEquals('override', module.MyEnum.__module__) + self.assertEquals('override', module.MyMessage.__module__) + self.assertEquals('override', module.MyService.__module__) + + # One thing is different between original descriptor and new. + file_descriptor.package = 'override' + self.assertEquals(file_descriptor, descriptor.describe_file(module)) + + def testImportFile(self): + """Test importing FileDescriptor in to module space.""" + modules = {} + file_descriptor = self.MakeFileDescriptor('standalone') + definition.import_file(file_descriptor, modules=modules) + self.assertEquals(file_descriptor, + descriptor.describe_file(modules['standalone'])) + + def testImportFile_InToExisting(self): + """Test importing FileDescriptor in to existing module.""" + module = types.ModuleType('standalone') + modules = {'standalone': module} + file_descriptor = self.MakeFileDescriptor('standalone') + definition.import_file(file_descriptor, modules=modules) + self.assertEquals(module, modules['standalone']) + self.assertEquals(file_descriptor, + descriptor.describe_file(modules['standalone'])) + + def testImportFile_InToGlobalModules(self): + """Test importing FileDescriptor in to global modules.""" + original_modules = sys.modules + try: + sys.modules = dict(sys.modules) + if 'standalone' in sys.modules: + del sys.modules['standalone'] + file_descriptor = self.MakeFileDescriptor('standalone') + definition.import_file(file_descriptor) + self.assertEquals(file_descriptor, + descriptor.describe_file(sys.modules['standalone'])) + finally: + sys.modules = original_modules + + def testImportFile_Nested(self): + """Test importing FileDescriptor in to existing nested module.""" + modules = {} + file_descriptor = self.MakeFileDescriptor('root.nested') + definition.import_file(file_descriptor, modules=modules) + self.assertEquals(modules['root'].nested, modules['root.nested']) + self.assertEquals(file_descriptor, + descriptor.describe_file(modules['root.nested'])) + + def testImportFile_NoPackage(self): + """Test importing FileDescriptor with no package.""" + file_descriptor = self.MakeFileDescriptor('does not matter') + file_descriptor.reset('package') + self.assertRaisesWithRegexpMatch(ValueError, + 'File descriptor must have package name', + definition.import_file, + file_descriptor) + + def testImportFileSet(self): + """Test importing a whole file set.""" + file_set = descriptor.FileSet() + file_set.files = [self.MakeFileDescriptor(u'standalone'), + self.MakeFileDescriptor(u'root.nested'), + self.MakeFileDescriptor(u'root.nested.nested'), + ] + + root = types.ModuleType('root') + nested = types.ModuleType('root.nested') + root.nested = nested + modules = { + 'root': root, + 'root.nested': nested, + } + + definition.import_file_set(file_set, modules=modules) + + self.assertEquals(root, modules['root']) + self.assertEquals(nested, modules['root.nested']) + self.assertEquals(nested.nested, modules['root.nested.nested']) + + self.assertEquals(file_set, + descriptor.describe_file_set( + [modules['standalone'], + modules['root.nested'], + modules['root.nested.nested'], + ])) + + def testImportFileSetFromFile(self): + """Test importing a whole file set from a file.""" + file_set = descriptor.FileSet() + file_set.files = [self.MakeFileDescriptor(u'standalone'), + self.MakeFileDescriptor(u'root.nested'), + self.MakeFileDescriptor(u'root.nested.nested'), + ] + + stream = StringIO.StringIO(protobuf.encode_message(file_set)) + + self.mox = mox.Mox() + opener = self.mox.CreateMockAnything() + opener('my-file.dat', 'rb').AndReturn(stream) + + self.mox.ReplayAll() + + modules = {} + definition.import_file_set('my-file.dat', modules=modules, _open=opener) + + self.assertEquals(file_set, + descriptor.describe_file_set( + [modules['standalone'], + modules['root.nested'], + modules['root.nested.nested'], + ])) + + def testImportBuiltInProtorpcClasses(self): + """Test that built in Protorpc classes are skipped.""" + file_set = descriptor.FileSet() + file_set.files = [self.MakeFileDescriptor(u'standalone'), + self.MakeFileDescriptor(u'root.nested'), + self.MakeFileDescriptor(u'root.nested.nested'), + descriptor.describe_file(descriptor), + ] + + root = types.ModuleType('root') + nested = types.ModuleType('root.nested') + root.nested = nested + modules = { + 'root': root, + 'root.nested': nested, + 'protorpc.descriptor': descriptor, + } + + definition.import_file_set(file_set, modules=modules) + + self.assertEquals(root, modules['root']) + self.assertEquals(nested, modules['root.nested']) + self.assertEquals(nested.nested, modules['root.nested.nested']) + self.assertEquals(descriptor, modules['protorpc.descriptor']) + + self.assertEquals(file_set, + descriptor.describe_file_set( + [modules['standalone'], + modules['root.nested'], + modules['root.nested.nested'], + modules['protorpc.descriptor'], + ])) + + +if __name__ == '__main__': + unittest.main() diff --git a/endpoints/internal/protorpc/descriptor.py b/endpoints/internal/protorpc/descriptor.py new file mode 100644 index 0000000..5f9e2e7 --- /dev/null +++ b/endpoints/internal/protorpc/descriptor.py @@ -0,0 +1,712 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Services descriptor definitions. + +Contains message definitions and functions for converting +service classes into transmittable message format. + +Describing an Enum instance, Enum class, Field class or Message class will +generate an appropriate descriptor object that describes that class. +This message can itself be used to transmit information to clients wishing +to know the description of an enum value, enum, field or message without +needing to download the source code. This format is also compatible with +other, non-Python languages. + +The descriptors are modeled to be binary compatible with: + + http://code.google.com/p/protobuf/source/browse/trunk/src/google/protobuf/descriptor.proto + +NOTE: The names of types and fields are not always the same between these +descriptors and the ones defined in descriptor.proto. This was done in order +to make source code files that use these descriptors easier to read. For +example, it is not necessary to prefix TYPE to all the values in +FieldDescriptor.Variant as is done in descriptor.proto FieldDescriptorProto.Type. + +Example: + + class Pixel(messages.Message): + + x = messages.IntegerField(1, required=True) + y = messages.IntegerField(2, required=True) + + color = messages.BytesField(3) + + # Describe Pixel class using message descriptor. + fields = [] + + field = FieldDescriptor() + field.name = 'x' + field.number = 1 + field.label = FieldDescriptor.Label.REQUIRED + field.variant = FieldDescriptor.Variant.INT64 + fields.append(field) + + field = FieldDescriptor() + field.name = 'y' + field.number = 2 + field.label = FieldDescriptor.Label.REQUIRED + field.variant = FieldDescriptor.Variant.INT64 + fields.append(field) + + field = FieldDescriptor() + field.name = 'color' + field.number = 3 + field.label = FieldDescriptor.Label.OPTIONAL + field.variant = FieldDescriptor.Variant.BYTES + fields.append(field) + + message = MessageDescriptor() + message.name = 'Pixel' + message.fields = fields + + # Describing is the equivalent of building the above message. + message == describe_message(Pixel) + +Public Classes: + EnumValueDescriptor: Describes Enum values. + EnumDescriptor: Describes Enum classes. + FieldDescriptor: Describes field instances. + FileDescriptor: Describes a single 'file' unit. + FileSet: Describes a collection of file descriptors. + MessageDescriptor: Describes Message classes. + MethodDescriptor: Describes a method of a service. + ServiceDescriptor: Describes a services. + +Public Functions: + describe_enum_value: Describe an individual enum-value. + describe_enum: Describe an Enum class. + describe_field: Describe a Field definition. + describe_file: Describe a 'file' unit from a Python module or object. + describe_file_set: Describe a file set from a list of modules or objects. + describe_message: Describe a Message definition. + describe_method: Describe a Method definition. + describe_service: Describe a Service definition. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import codecs +import types + +from . import messages +from . import util + + +__all__ = ['EnumDescriptor', + 'EnumValueDescriptor', + 'FieldDescriptor', + 'MessageDescriptor', + 'MethodDescriptor', + 'FileDescriptor', + 'FileSet', + 'ServiceDescriptor', + 'DescriptorLibrary', + + 'describe_enum', + 'describe_enum_value', + 'describe_field', + 'describe_message', + 'describe_method', + 'describe_file', + 'describe_file_set', + 'describe_service', + 'describe', + 'import_descriptor_loader', + ] + + +# NOTE: MessageField is missing because message fields cannot have +# a default value at this time. +# TODO(rafek): Support default message values. +# +# Map to functions that convert default values of fields of a given type +# to a string. The function must return a value that is compatible with +# FieldDescriptor.default_value and therefore a unicode string. +_DEFAULT_TO_STRING_MAP = { + messages.IntegerField: six.text_type, + messages.FloatField: six.text_type, + messages.BooleanField: lambda value: value and u'true' or u'false', + messages.BytesField: lambda value: codecs.escape_encode(value)[0], + messages.StringField: lambda value: value, + messages.EnumField: lambda value: six.text_type(value.number), +} + +_DEFAULT_FROM_STRING_MAP = { + messages.IntegerField: int, + messages.FloatField: float, + messages.BooleanField: lambda value: value == u'true', + messages.BytesField: lambda value: codecs.escape_decode(value)[0], + messages.StringField: lambda value: value, + messages.EnumField: int, +} + + +class EnumValueDescriptor(messages.Message): + """Enum value descriptor. + + Fields: + name: Name of enumeration value. + number: Number of enumeration value. + """ + + # TODO(rafek): Why are these listed as optional in descriptor.proto. + # Harmonize? + name = messages.StringField(1, required=True) + number = messages.IntegerField(2, + required=True, + variant=messages.Variant.INT32) + + +class EnumDescriptor(messages.Message): + """Enum class descriptor. + + Fields: + name: Name of Enum without any qualification. + values: Values defined by Enum class. + """ + + name = messages.StringField(1) + values = messages.MessageField(EnumValueDescriptor, 2, repeated=True) + + +class FieldDescriptor(messages.Message): + """Field definition descriptor. + + Enums: + Variant: Wire format hint sub-types for field. + Label: Values for optional, required and repeated fields. + + Fields: + name: Name of field. + number: Number of field. + variant: Variant of field. + type_name: Type name for message and enum fields. + default_value: String representation of default value. + """ + + Variant = messages.Variant + + class Label(messages.Enum): + """Field label.""" + + OPTIONAL = 1 + REQUIRED = 2 + REPEATED = 3 + + name = messages.StringField(1, required=True) + number = messages.IntegerField(3, + required=True, + variant=messages.Variant.INT32) + label = messages.EnumField(Label, 4, default=Label.OPTIONAL) + variant = messages.EnumField(Variant, 5) + type_name = messages.StringField(6) + + # For numeric types, contains the original text representation of the value. + # For booleans, "true" or "false". + # For strings, contains the default text contents (not escaped in any way). + # For bytes, contains the C escaped value. All bytes < 128 are that are + # traditionally considered unprintable are also escaped. + default_value = messages.StringField(7) + + +class MessageDescriptor(messages.Message): + """Message definition descriptor. + + Fields: + name: Name of Message without any qualification. + fields: Fields defined for message. + message_types: Nested Message classes defined on message. + enum_types: Nested Enum classes defined on message. + """ + + name = messages.StringField(1) + fields = messages.MessageField(FieldDescriptor, 2, repeated=True) + + message_types = messages.MessageField( + 'protorpc.descriptor.MessageDescriptor', 3, repeated=True) + enum_types = messages.MessageField(EnumDescriptor, 4, repeated=True) + + +class MethodDescriptor(messages.Message): + """Service method definition descriptor. + + Fields: + name: Name of service method. + request_type: Fully qualified or relative name of request message type. + response_type: Fully qualified or relative name of response message type. + """ + + name = messages.StringField(1) + + request_type = messages.StringField(2) + response_type = messages.StringField(3) + + +class ServiceDescriptor(messages.Message): + """Service definition descriptor. + + Fields: + name: Name of Service without any qualification. + methods: Remote methods of Service. + """ + + name = messages.StringField(1) + + methods = messages.MessageField(MethodDescriptor, 2, repeated=True) + + +class FileDescriptor(messages.Message): + """Description of file containing protobuf definitions. + + Fields: + package: Fully qualified name of package that definitions belong to. + message_types: Message definitions contained in file. + enum_types: Enum definitions contained in file. + service_types: Service definitions contained in file. + """ + + package = messages.StringField(2) + + # TODO(rafek): Add dependency field + + message_types = messages.MessageField(MessageDescriptor, 4, repeated=True) + enum_types = messages.MessageField(EnumDescriptor, 5, repeated=True) + service_types = messages.MessageField(ServiceDescriptor, 6, repeated=True) + + +class FileSet(messages.Message): + """A collection of FileDescriptors. + + Fields: + files: Files in file-set. + """ + + files = messages.MessageField(FileDescriptor, 1, repeated=True) + + +def describe_enum_value(enum_value): + """Build descriptor for Enum instance. + + Args: + enum_value: Enum value to provide descriptor for. + + Returns: + Initialized EnumValueDescriptor instance describing the Enum instance. + """ + enum_value_descriptor = EnumValueDescriptor() + enum_value_descriptor.name = six.text_type(enum_value.name) + enum_value_descriptor.number = enum_value.number + return enum_value_descriptor + + +def describe_enum(enum_definition): + """Build descriptor for Enum class. + + Args: + enum_definition: Enum class to provide descriptor for. + + Returns: + Initialized EnumDescriptor instance describing the Enum class. + """ + enum_descriptor = EnumDescriptor() + enum_descriptor.name = enum_definition.definition_name().split('.')[-1] + + values = [] + for number in enum_definition.numbers(): + value = enum_definition.lookup_by_number(number) + values.append(describe_enum_value(value)) + + if values: + enum_descriptor.values = values + + return enum_descriptor + + +def describe_field(field_definition): + """Build descriptor for Field instance. + + Args: + field_definition: Field instance to provide descriptor for. + + Returns: + Initialized FieldDescriptor instance describing the Field instance. + """ + field_descriptor = FieldDescriptor() + field_descriptor.name = field_definition.name + field_descriptor.number = field_definition.number + field_descriptor.variant = field_definition.variant + + if isinstance(field_definition, messages.EnumField): + field_descriptor.type_name = field_definition.type.definition_name() + + if isinstance(field_definition, messages.MessageField): + field_descriptor.type_name = field_definition.message_type.definition_name() + + if field_definition.default is not None: + field_descriptor.default_value = _DEFAULT_TO_STRING_MAP[ + type(field_definition)](field_definition.default) + + # Set label. + if field_definition.repeated: + field_descriptor.label = FieldDescriptor.Label.REPEATED + elif field_definition.required: + field_descriptor.label = FieldDescriptor.Label.REQUIRED + else: + field_descriptor.label = FieldDescriptor.Label.OPTIONAL + + return field_descriptor + + +def describe_message(message_definition): + """Build descriptor for Message class. + + Args: + message_definition: Message class to provide descriptor for. + + Returns: + Initialized MessageDescriptor instance describing the Message class. + """ + message_descriptor = MessageDescriptor() + message_descriptor.name = message_definition.definition_name().split('.')[-1] + + fields = sorted(message_definition.all_fields(), + key=lambda v: v.number) + if fields: + message_descriptor.fields = [describe_field(field) for field in fields] + + try: + nested_messages = message_definition.__messages__ + except AttributeError: + pass + else: + message_descriptors = [] + for name in nested_messages: + value = getattr(message_definition, name) + message_descriptors.append(describe_message(value)) + + message_descriptor.message_types = message_descriptors + + try: + nested_enums = message_definition.__enums__ + except AttributeError: + pass + else: + enum_descriptors = [] + for name in nested_enums: + value = getattr(message_definition, name) + enum_descriptors.append(describe_enum(value)) + + message_descriptor.enum_types = enum_descriptors + + return message_descriptor + + +def describe_method(method): + """Build descriptor for service method. + + Args: + method: Remote service method to describe. + + Returns: + Initialized MethodDescriptor instance describing the service method. + """ + method_info = method.remote + descriptor = MethodDescriptor() + descriptor.name = method_info.method.__name__ + descriptor.request_type = method_info.request_type.definition_name() + descriptor.response_type = method_info.response_type.definition_name() + + return descriptor + + +def describe_service(service_class): + """Build descriptor for service. + + Args: + service_class: Service class to describe. + + Returns: + Initialized ServiceDescriptor instance describing the service. + """ + descriptor = ServiceDescriptor() + descriptor.name = service_class.__name__ + methods = [] + remote_methods = service_class.all_remote_methods() + for name in sorted(remote_methods.keys()): + if name == 'get_descriptor': + continue + + method = remote_methods[name] + methods.append(describe_method(method)) + if methods: + descriptor.methods = methods + + return descriptor + + +def describe_file(module): + """Build a file from a specified Python module. + + Args: + module: Python module to describe. + + Returns: + Initialized FileDescriptor instance describing the module. + """ + # May not import remote at top of file because remote depends on this + # file + # TODO(rafek): Straighten out this dependency. Possibly move these functions + # from descriptor to their own module. + from . import remote + + descriptor = FileDescriptor() + descriptor.package = util.get_package_for_module(module) + + if not descriptor.package: + descriptor.package = None + + message_descriptors = [] + enum_descriptors = [] + service_descriptors = [] + + # Need to iterate over all top level attributes of the module looking for + # message, enum and service definitions. Each definition must be itself + # described. + for name in sorted(dir(module)): + value = getattr(module, name) + + if isinstance(value, type): + if issubclass(value, messages.Message): + message_descriptors.append(describe_message(value)) + + elif issubclass(value, messages.Enum): + enum_descriptors.append(describe_enum(value)) + + elif issubclass(value, remote.Service): + service_descriptors.append(describe_service(value)) + + if message_descriptors: + descriptor.message_types = message_descriptors + + if enum_descriptors: + descriptor.enum_types = enum_descriptors + + if service_descriptors: + descriptor.service_types = service_descriptors + + return descriptor + + +def describe_file_set(modules): + """Build a file set from a specified Python modules. + + Args: + modules: Iterable of Python module to describe. + + Returns: + Initialized FileSet instance describing the modules. + """ + descriptor = FileSet() + file_descriptors = [] + for module in modules: + file_descriptors.append(describe_file(module)) + + if file_descriptors: + descriptor.files = file_descriptors + + return descriptor + + +def describe(value): + """Describe any value as a descriptor. + + Helper function for describing any object with an appropriate descriptor + object. + + Args: + value: Value to describe as a descriptor. + + Returns: + Descriptor message class if object is describable as a descriptor, else + None. + """ + from . import remote + if isinstance(value, types.ModuleType): + return describe_file(value) + elif callable(value) and hasattr(value, 'remote'): + return describe_method(value) + elif isinstance(value, messages.Field): + return describe_field(value) + elif isinstance(value, messages.Enum): + return describe_enum_value(value) + elif isinstance(value, type): + if issubclass(value, messages.Message): + return describe_message(value) + elif issubclass(value, messages.Enum): + return describe_enum(value) + elif issubclass(value, remote.Service): + return describe_service(value) + return None + + +@util.positional(1) +def import_descriptor_loader(definition_name, importer=__import__): + """Find objects by importing modules as needed. + + A definition loader is a function that resolves a definition name to a + descriptor. + + The import finder resolves definitions to their names by importing modules + when necessary. + + Args: + definition_name: Name of definition to find. + importer: Import function used for importing new modules. + + Returns: + Appropriate descriptor for any describable type located by name. + + Raises: + DefinitionNotFoundError when a name does not refer to either a definition + or a module. + """ + # Attempt to import descriptor as a module. + if definition_name.startswith('.'): + definition_name = definition_name[1:] + if not definition_name.startswith('.'): + leaf = definition_name.split('.')[-1] + if definition_name: + try: + module = importer(definition_name, '', '', [leaf]) + except ImportError: + pass + else: + return describe(module) + + try: + # Attempt to use messages.find_definition to find item. + return describe(messages.find_definition(definition_name, + importer=__import__)) + except messages.DefinitionNotFoundError as err: + # There are things that find_definition will not find, but if the parent + # is loaded, its children can be searched for a match. + split_name = definition_name.rsplit('.', 1) + if len(split_name) > 1: + parent, child = split_name + try: + parent_definition = import_descriptor_loader(parent, importer=importer) + except messages.DefinitionNotFoundError: + # Fall through to original error. + pass + else: + # Check the parent definition for a matching descriptor. + if isinstance(parent_definition, FileDescriptor): + search_list = parent_definition.service_types or [] + elif isinstance(parent_definition, ServiceDescriptor): + search_list = parent_definition.methods or [] + elif isinstance(parent_definition, EnumDescriptor): + search_list = parent_definition.values or [] + elif isinstance(parent_definition, MessageDescriptor): + search_list = parent_definition.fields or [] + else: + search_list = [] + + for definition in search_list: + if definition.name == child: + return definition + + # Still didn't find. Reraise original exception. + raise err + + +class DescriptorLibrary(object): + """A descriptor library is an object that contains known definitions. + + A descriptor library contains a cache of descriptor objects mapped by + definition name. It contains all types of descriptors except for + file sets. + + When a definition name is requested that the library does not know about + it can be provided with a descriptor loader which attempt to resolve the + missing descriptor. + """ + + @util.positional(1) + def __init__(self, + descriptors=None, + descriptor_loader=import_descriptor_loader): + """Constructor. + + Args: + descriptors: A dictionary or dictionary-like object that can be used + to store and cache descriptors by definition name. + definition_loader: A function used for resolving missing descriptors. + The function takes a definition name as its parameter and returns + an appropriate descriptor. It may raise DefinitionNotFoundError. + """ + self.__descriptor_loader = descriptor_loader + self.__descriptors = descriptors or {} + + def lookup_descriptor(self, definition_name): + """Lookup descriptor by name. + + Get descriptor from library by name. If descriptor is not found will + attempt to find via descriptor loader if provided. + + Args: + definition_name: Definition name to find. + + Returns: + Descriptor that describes definition name. + + Raises: + DefinitionNotFoundError if not descriptor exists for definition name. + """ + try: + return self.__descriptors[definition_name] + except KeyError: + pass + + if self.__descriptor_loader: + definition = self.__descriptor_loader(definition_name) + self.__descriptors[definition_name] = definition + return definition + else: + raise messages.DefinitionNotFoundError( + 'Could not find definition for %s' % definition_name) + + def lookup_package(self, definition_name): + """Determines the package name for any definition. + + Determine the package that any definition name belongs to. May check + parent for package name and will resolve missing descriptors if provided + descriptor loader. + + Args: + definition_name: Definition name to find package for. + """ + while True: + descriptor = self.lookup_descriptor(definition_name) + if isinstance(descriptor, FileDescriptor): + return descriptor.package + else: + index = definition_name.rfind('.') + if index < 0: + return None + definition_name = definition_name[:index] diff --git a/endpoints/internal/protorpc/descriptor_test.py b/endpoints/internal/protorpc/descriptor_test.py new file mode 100644 index 0000000..5047e8e --- /dev/null +++ b/endpoints/internal/protorpc/descriptor_test.py @@ -0,0 +1,649 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.descriptor.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import types +import unittest + +from protorpc import descriptor +from protorpc import message_types +from protorpc import messages +from protorpc import registry +from protorpc import remote +from protorpc import test_util + + +RUSSIA = u'\u0420\u043e\u0441\u0441\u0438\u044f' + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = descriptor + + +class DescribeEnumValueTest(test_util.TestCase): + + def testDescribe(self): + class MyEnum(messages.Enum): + MY_NAME = 10 + + expected = descriptor.EnumValueDescriptor() + expected.name = 'MY_NAME' + expected.number = 10 + + described = descriptor.describe_enum_value(MyEnum.MY_NAME) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeEnumTest(test_util.TestCase): + + def testEmptyEnum(self): + class EmptyEnum(messages.Enum): + pass + + expected = descriptor.EnumDescriptor() + expected.name = 'EmptyEnum' + + described = descriptor.describe_enum(EmptyEnum) + described.check_initialized() + self.assertEquals(expected, described) + + def testNestedEnum(self): + class MyScope(messages.Message): + class NestedEnum(messages.Enum): + pass + + expected = descriptor.EnumDescriptor() + expected.name = 'NestedEnum' + + described = descriptor.describe_enum(MyScope.NestedEnum) + described.check_initialized() + self.assertEquals(expected, described) + + def testEnumWithItems(self): + class EnumWithItems(messages.Enum): + A = 3 + B = 1 + C = 2 + + expected = descriptor.EnumDescriptor() + expected.name = 'EnumWithItems' + + a = descriptor.EnumValueDescriptor() + a.name = 'A' + a.number = 3 + + b = descriptor.EnumValueDescriptor() + b.name = 'B' + b.number = 1 + + c = descriptor.EnumValueDescriptor() + c.name = 'C' + c.number = 2 + + expected.values = [b, c, a] + + described = descriptor.describe_enum(EnumWithItems) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeFieldTest(test_util.TestCase): + + def testLabel(self): + for repeated, required, expected_label in ( + (True, False, descriptor.FieldDescriptor.Label.REPEATED), + (False, True, descriptor.FieldDescriptor.Label.REQUIRED), + (False, False, descriptor.FieldDescriptor.Label.OPTIONAL)): + field = messages.IntegerField(10, required=required, repeated=repeated) + field.name = 'a_field' + + expected = descriptor.FieldDescriptor() + expected.name = 'a_field' + expected.number = 10 + expected.label = expected_label + expected.variant = descriptor.FieldDescriptor.Variant.INT64 + + described = descriptor.describe_field(field) + described.check_initialized() + self.assertEquals(expected, described) + + def testDefault(self): + for field_class, default, expected_default in ( + (messages.IntegerField, 200, '200'), + (messages.FloatField, 1.5, '1.5'), + (messages.FloatField, 1e6, '1000000.0'), + (messages.BooleanField, True, 'true'), + (messages.BooleanField, False, 'false'), + (messages.BytesField, 'ab\xF1', 'ab\\xf1'), + (messages.StringField, RUSSIA, RUSSIA), + ): + field = field_class(10, default=default) + field.name = u'a_field' + + expected = descriptor.FieldDescriptor() + expected.name = u'a_field' + expected.number = 10 + expected.label = descriptor.FieldDescriptor.Label.OPTIONAL + expected.variant = field_class.DEFAULT_VARIANT + expected.default_value = expected_default + + described = descriptor.describe_field(field) + described.check_initialized() + self.assertEquals(expected, described) + + def testDefault_EnumField(self): + class MyEnum(messages.Enum): + + VAL = 1 + + module_name = test_util.get_module_name(MyEnum) + field = messages.EnumField(MyEnum, 10, default=MyEnum.VAL) + field.name = 'a_field' + + expected = descriptor.FieldDescriptor() + expected.name = 'a_field' + expected.number = 10 + expected.label = descriptor.FieldDescriptor.Label.OPTIONAL + expected.variant = messages.EnumField.DEFAULT_VARIANT + expected.type_name = '%s.MyEnum' % module_name + expected.default_value = '1' + + described = descriptor.describe_field(field) + self.assertEquals(expected, described) + + def testMessageField(self): + field = messages.MessageField(descriptor.FieldDescriptor, 10) + field.name = 'a_field' + + expected = descriptor.FieldDescriptor() + expected.name = 'a_field' + expected.number = 10 + expected.label = descriptor.FieldDescriptor.Label.OPTIONAL + expected.variant = messages.MessageField.DEFAULT_VARIANT + expected.type_name = ('protorpc.descriptor.FieldDescriptor') + + described = descriptor.describe_field(field) + described.check_initialized() + self.assertEquals(expected, described) + + def testDateTimeField(self): + field = message_types.DateTimeField(20) + field.name = 'a_timestamp' + + expected = descriptor.FieldDescriptor() + expected.name = 'a_timestamp' + expected.number = 20 + expected.label = descriptor.FieldDescriptor.Label.OPTIONAL + expected.variant = messages.MessageField.DEFAULT_VARIANT + expected.type_name = ('protorpc.message_types.DateTimeMessage') + + described = descriptor.describe_field(field) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeMessageTest(test_util.TestCase): + + def testEmptyDefinition(self): + class MyMessage(messages.Message): + pass + + expected = descriptor.MessageDescriptor() + expected.name = 'MyMessage' + + described = descriptor.describe_message(MyMessage) + described.check_initialized() + self.assertEquals(expected, described) + + def testDefinitionWithFields(self): + class MessageWithFields(messages.Message): + field1 = messages.IntegerField(10) + field2 = messages.StringField(30) + field3 = messages.IntegerField(20) + + expected = descriptor.MessageDescriptor() + expected.name = 'MessageWithFields' + + expected.fields = [ + descriptor.describe_field(MessageWithFields.field_by_name('field1')), + descriptor.describe_field(MessageWithFields.field_by_name('field3')), + descriptor.describe_field(MessageWithFields.field_by_name('field2')), + ] + + described = descriptor.describe_message(MessageWithFields) + described.check_initialized() + self.assertEquals(expected, described) + + def testNestedEnum(self): + class MessageWithEnum(messages.Message): + class Mood(messages.Enum): + GOOD = 1 + BAD = 2 + UGLY = 3 + + class Music(messages.Enum): + CLASSIC = 1 + JAZZ = 2 + BLUES = 3 + + expected = descriptor.MessageDescriptor() + expected.name = 'MessageWithEnum' + + expected.enum_types = [descriptor.describe_enum(MessageWithEnum.Mood), + descriptor.describe_enum(MessageWithEnum.Music)] + + described = descriptor.describe_message(MessageWithEnum) + described.check_initialized() + self.assertEquals(expected, described) + + def testNestedMessage(self): + class MessageWithMessage(messages.Message): + class Nesty(messages.Message): + pass + + expected = descriptor.MessageDescriptor() + expected.name = 'MessageWithMessage' + + expected.message_types = [ + descriptor.describe_message(MessageWithMessage.Nesty)] + + described = descriptor.describe_message(MessageWithMessage) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeMethodTest(test_util.TestCase): + """Test describing remote methods.""" + + def testDescribe(self): + class Request(messages.Message): + pass + + class Response(messages.Message): + pass + + @remote.method(Request, Response) + def remote_method(request): + pass + + module_name = test_util.get_module_name(DescribeMethodTest) + expected = descriptor.MethodDescriptor() + expected.name = 'remote_method' + expected.request_type = '%s.Request' % module_name + expected.response_type = '%s.Response' % module_name + + described = descriptor.describe_method(remote_method) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeServiceTest(test_util.TestCase): + """Test describing service classes.""" + + def testDescribe(self): + class Request1(messages.Message): + pass + + class Response1(messages.Message): + pass + + class Request2(messages.Message): + pass + + class Response2(messages.Message): + pass + + class MyService(remote.Service): + + @remote.method(Request1, Response1) + def method1(self, request): + pass + + @remote.method(Request2, Response2) + def method2(self, request): + pass + + expected = descriptor.ServiceDescriptor() + expected.name = 'MyService' + expected.methods = [] + + expected.methods.append(descriptor.describe_method(MyService.method1)) + expected.methods.append(descriptor.describe_method(MyService.method2)) + + described = descriptor.describe_service(MyService) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeFileTest(test_util.TestCase): + """Test describing modules.""" + + def LoadModule(self, module_name, source): + result = {'__name__': module_name, + 'messages': messages, + 'remote': remote, + } + exec(source, result) + + module = types.ModuleType(module_name) + for name, value in result.items(): + setattr(module, name, value) + + return module + + def testEmptyModule(self): + """Test describing an empty file.""" + module = types.ModuleType('my.package.name') + + expected = descriptor.FileDescriptor() + expected.package = 'my.package.name' + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + def testNoPackageName(self): + """Test describing a module with no module name.""" + module = types.ModuleType('') + + expected = descriptor.FileDescriptor() + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + def testPackageName(self): + """Test using the 'package' module attribute.""" + module = types.ModuleType('my.module.name') + module.package = 'my.package.name' + + expected = descriptor.FileDescriptor() + expected.package = 'my.package.name' + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + def testMain(self): + """Test using the 'package' module attribute.""" + module = types.ModuleType('__main__') + module.__file__ = '/blim/blam/bloom/my_package.py' + + expected = descriptor.FileDescriptor() + expected.package = 'my_package' + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + def testMessages(self): + """Test that messages are described.""" + module = self.LoadModule('my.package', + 'class Message1(messages.Message): pass\n' + 'class Message2(messages.Message): pass\n') + + message1 = descriptor.MessageDescriptor() + message1.name = 'Message1' + + message2 = descriptor.MessageDescriptor() + message2.name = 'Message2' + + expected = descriptor.FileDescriptor() + expected.package = 'my.package' + expected.message_types = [message1, message2] + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + def testEnums(self): + """Test that enums are described.""" + module = self.LoadModule('my.package', + 'class Enum1(messages.Enum): pass\n' + 'class Enum2(messages.Enum): pass\n') + + enum1 = descriptor.EnumDescriptor() + enum1.name = 'Enum1' + + enum2 = descriptor.EnumDescriptor() + enum2.name = 'Enum2' + + expected = descriptor.FileDescriptor() + expected.package = 'my.package' + expected.enum_types = [enum1, enum2] + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + def testServices(self): + """Test that services are described.""" + module = self.LoadModule('my.package', + 'class Service1(remote.Service): pass\n' + 'class Service2(remote.Service): pass\n') + + service1 = descriptor.ServiceDescriptor() + service1.name = 'Service1' + + service2 = descriptor.ServiceDescriptor() + service2.name = 'Service2' + + expected = descriptor.FileDescriptor() + expected.package = 'my.package' + expected.service_types = [service1, service2] + + described = descriptor.describe_file(module) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeFileSetTest(test_util.TestCase): + """Test describing multiple modules.""" + + def testNoModules(self): + """Test what happens when no modules provided.""" + described = descriptor.describe_file_set([]) + described.check_initialized() + # The described FileSet.files will be None. + self.assertEquals(descriptor.FileSet(), described) + + def testWithModules(self): + """Test what happens when no modules provided.""" + modules = [types.ModuleType('package1'), types.ModuleType('package1')] + + file1 = descriptor.FileDescriptor() + file1.package = 'package1' + file2 = descriptor.FileDescriptor() + file2.package = 'package2' + + expected = descriptor.FileSet() + expected.files = [file1, file1] + + described = descriptor.describe_file_set(modules) + described.check_initialized() + self.assertEquals(expected, described) + + +class DescribeTest(test_util.TestCase): + + def testModule(self): + self.assertEquals(descriptor.describe_file(test_util), + descriptor.describe(test_util)) + + def testMethod(self): + class Param(messages.Message): + pass + + class Service(remote.Service): + + @remote.method(Param, Param) + def fn(self): + return Param() + + self.assertEquals(descriptor.describe_method(Service.fn), + descriptor.describe(Service.fn)) + + def testField(self): + self.assertEquals( + descriptor.describe_field(test_util.NestedMessage.a_value), + descriptor.describe(test_util.NestedMessage.a_value)) + + def testEnumValue(self): + self.assertEquals( + descriptor.describe_enum_value( + test_util.OptionalMessage.SimpleEnum.VAL1), + descriptor.describe(test_util.OptionalMessage.SimpleEnum.VAL1)) + + def testMessage(self): + self.assertEquals(descriptor.describe_message(test_util.NestedMessage), + descriptor.describe(test_util.NestedMessage)) + + def testEnum(self): + self.assertEquals( + descriptor.describe_enum(test_util.OptionalMessage.SimpleEnum), + descriptor.describe(test_util.OptionalMessage.SimpleEnum)) + + def testService(self): + class Service(remote.Service): + pass + + self.assertEquals(descriptor.describe_service(Service), + descriptor.describe(Service)) + + def testService(self): + class Service(remote.Service): + pass + + self.assertEquals(descriptor.describe_service(Service), + descriptor.describe(Service)) + + def testUndescribable(self): + class NonService(object): + + def fn(self): + pass + + for value in (NonService, + NonService.fn, + 1, + 'string', + 1.2, + None): + self.assertEquals(None, descriptor.describe(value)) + + +class ModuleFinderTest(test_util.TestCase): + + def testFindModule(self): + self.assertEquals(descriptor.describe_file(registry), + descriptor.import_descriptor_loader('protorpc.registry')) + + def testFindMessage(self): + self.assertEquals( + descriptor.describe_message(descriptor.FileSet), + descriptor.import_descriptor_loader('protorpc.descriptor.FileSet')) + + def testFindField(self): + self.assertEquals( + descriptor.describe_field(descriptor.FileSet.files), + descriptor.import_descriptor_loader('protorpc.descriptor.FileSet.files')) + + def testFindEnumValue(self): + self.assertEquals( + descriptor.describe_enum_value(test_util.OptionalMessage.SimpleEnum.VAL1), + descriptor.import_descriptor_loader( + 'protorpc.test_util.OptionalMessage.SimpleEnum.VAL1')) + + def testFindMethod(self): + self.assertEquals( + descriptor.describe_method(registry.RegistryService.services), + descriptor.import_descriptor_loader( + 'protorpc.registry.RegistryService.services')) + + def testFindService(self): + self.assertEquals( + descriptor.describe_service(registry.RegistryService), + descriptor.import_descriptor_loader('protorpc.registry.RegistryService')) + + def testFindWithAbsoluteName(self): + self.assertEquals( + descriptor.describe_service(registry.RegistryService), + descriptor.import_descriptor_loader('.protorpc.registry.RegistryService')) + + def testFindWrongThings(self): + for name in ('a', 'protorpc.registry.RegistryService.__init__', '', ): + self.assertRaisesWithRegexpMatch( + messages.DefinitionNotFoundError, + 'Could not find definition for %s' % name, + descriptor.import_descriptor_loader, name) + + +class DescriptorLibraryTest(test_util.TestCase): + + def setUp(self): + self.packageless = descriptor.MessageDescriptor() + self.packageless.name = 'Packageless' + self.library = descriptor.DescriptorLibrary( + descriptors={ + 'not.real.Packageless': self.packageless, + 'Packageless': self.packageless, + }) + + def testLookupPackage(self): + self.assertEquals('csv', self.library.lookup_package('csv')) + self.assertEquals('protorpc', self.library.lookup_package('protorpc')) + self.assertEquals('protorpc.registry', + self.library.lookup_package('protorpc.registry')) + self.assertEquals('protorpc.registry', + self.library.lookup_package('.protorpc.registry')) + self.assertEquals( + 'protorpc.registry', + self.library.lookup_package('protorpc.registry.RegistryService')) + self.assertEquals( + 'protorpc.registry', + self.library.lookup_package( + 'protorpc.registry.RegistryService.services')) + + def testLookupNonPackages(self): + for name in ('', 'a', 'protorpc.descriptor.DescriptorLibrary'): + self.assertRaisesWithRegexpMatch( + messages.DefinitionNotFoundError, + 'Could not find definition for %s' % name, + self.library.lookup_package, name) + + def testNoPackage(self): + self.assertRaisesWithRegexpMatch( + messages.DefinitionNotFoundError, + 'Could not find definition for not.real', + self.library.lookup_package, 'not.real.Packageless') + + self.assertEquals(None, self.library.lookup_package('Packageless')) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/end2end_test.py b/endpoints/internal/protorpc/end2end_test.py new file mode 100644 index 0000000..c3e0141 --- /dev/null +++ b/endpoints/internal/protorpc/end2end_test.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""End to end tests for ProtoRPC.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import unittest + +from protorpc import protojson +from protorpc import remote +from protorpc import test_util +from protorpc import util +from protorpc import webapp_test_util + +package = 'test_package' + + +class EndToEndTest(webapp_test_util.EndToEndTestBase): + + def testSimpleRequest(self): + self.assertEquals(test_util.OptionalMessage(string_value='+blar'), + self.stub.optional_message(string_value='blar')) + + def testSimpleRequestComplexContentType(self): + response = self.DoRawRequest( + 'optional_message', + content='{"string_value": "blar"}', + content_type='application/json; charset=utf-8') + headers = response.headers + self.assertEquals(200, response.code) + self.assertEquals('{"string_value": "+blar"}', response.read()) + self.assertEquals('application/json', headers['content-type']) + + def testInitParameter(self): + self.assertEquals(test_util.OptionalMessage(string_value='uninitialized'), + self.stub.init_parameter()) + self.assertEquals(test_util.OptionalMessage(string_value='initialized'), + self.other_stub.init_parameter()) + + def testMissingContentType(self): + code, content, headers = self.RawRequestError( + 'optional_message', + content='{"string_value": "blar"}', + content_type='') + self.assertEquals(400, code) + self.assertEquals(util.pad_string('Bad Request'), content) + self.assertEquals('text/plain; charset=utf-8', headers['content-type']) + + def testWrongPath(self): + self.assertRaisesWithRegexpMatch(remote.ServerError, + 'HTTP Error 404: Not Found', + self.bad_path_stub.optional_message) + + def testUnsupportedContentType(self): + code, content, headers = self.RawRequestError( + 'optional_message', + content='{"string_value": "blar"}', + content_type='image/png') + self.assertEquals(415, code) + self.assertEquals(util.pad_string('Unsupported Media Type'), content) + self.assertEquals(headers['content-type'], 'text/plain; charset=utf-8') + + def testUnsupportedHttpMethod(self): + code, content, headers = self.RawRequestError('optional_message') + self.assertEquals(405, code) + self.assertEquals( + util.pad_string('/my/service.optional_message is a ProtoRPC method.\n\n' + 'Service protorpc.webapp_test_util.TestService\n\n' + 'More about ProtoRPC: ' + 'http://code.google.com/p/google-protorpc\n'), + content) + self.assertEquals(headers['content-type'], 'text/plain; charset=utf-8') + + def testMethodNotFound(self): + self.assertRaisesWithRegexpMatch(remote.MethodNotFoundError, + 'Unrecognized RPC method: does_not_exist', + self.mismatched_stub.does_not_exist) + + def testBadMessageError(self): + code, content, headers = self.RawRequestError('nested_message', + content='{}') + self.assertEquals(400, code) + + expected_content = protojson.encode_message(remote.RpcStatus( + state=remote.RpcState.REQUEST_ERROR, + error_message=('Error parsing ProtoRPC request ' + '(Unable to parse request content: ' + 'Message NestedMessage is missing ' + 'required field a_value)'))) + self.assertEquals(util.pad_string(expected_content), content) + self.assertEquals(headers['content-type'], 'application/json') + + def testApplicationError(self): + try: + self.stub.raise_application_error() + except remote.ApplicationError as err: + self.assertEquals('This is an application error', unicode(err)) + self.assertEquals('ERROR_NAME', err.error_name) + else: + self.fail('Expected application error') + + def testRpcError(self): + try: + self.stub.raise_rpc_error() + except remote.ServerError as err: + self.assertEquals('Internal Server Error', unicode(err)) + else: + self.fail('Expected server error') + + def testUnexpectedError(self): + try: + self.stub.raise_unexpected_error() + except remote.ServerError as err: + self.assertEquals('Internal Server Error', unicode(err)) + else: + self.fail('Expected server error') + + def testBadResponse(self): + try: + self.stub.return_bad_message() + except remote.ServerError as err: + self.assertEquals('Internal Server Error', unicode(err)) + else: + self.fail('Expected server error') + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/experimental/__init__.py b/endpoints/internal/protorpc/experimental/__init__.py new file mode 100644 index 0000000..419fff2 --- /dev/null +++ b/endpoints/internal/protorpc/experimental/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Main module for ProtoRPC package.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' diff --git a/endpoints/internal/protorpc/experimental/parser/protobuf.g b/endpoints/internal/protorpc/experimental/parser/protobuf.g new file mode 100644 index 0000000..8115be5 --- /dev/null +++ b/endpoints/internal/protorpc/experimental/parser/protobuf.g @@ -0,0 +1,159 @@ +/* !/usr/bin/env python + * + * Copyright 2011 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +parser grammar protobuf; + +scalar_value + : STRING + | FLOAT + | INT + | BOOL + ; + +id + : ID + | PACKAGE + | SERVICE + | MESSAGE + | ENUM + | DATA_TYPE + | EXTENSIONS + ; + +user_option_id + : '(' name_root='.'? qualified_name ')' + -> ^(USER_OPTION_ID $name_root? qualified_name) + ; + +option_id + : (id | user_option_id) ('.'! (id | user_option_id))* + ; + +option + : option_id '=' (scalar_value | id) + -> ^(OPTION ^(OPTION_ID option_id) scalar_value? id?) + ; + +decl_options + : '[' option (',' option)* ']' + -> ^(OPTIONS option*) + ; + +qualified_name + : id ('.'! id)* + ; + +field_decl + : qualified_name id '=' INT decl_options? ';' + -> ^(FIELD_TYPE qualified_name) id INT decl_options? + | GROUP id '=' INT '{' message_def '}' + -> ^(FIELD_TYPE GROUP) id INT ^(GROUP_MESSAGE message_def) + ; + +field + : LABEL field_decl + -> ^(FIELD LABEL field_decl) + ; + +enum_decl + : id '=' INT decl_options? ';' + -> ^(ENUM_DECL id INT decl_options?) + ; + +enum_def + : ENUM id '{' (def_option | enum_decl | ';')* '}' + -> ^(ENUM id + ^(OPTIONS def_option*) + ^(ENUM_DECLS enum_decl*)) + ; + +extensions + : EXTENSIONS start=INT (TO (end=INT | end=MAX))? ';' -> ^(EXTENSION_RANGE $start $end) + ; + +message_def + : ( field + | enum_def + | message + | extension + | extensions + | def_option + | ';' + )* -> + ^(FIELDS field*) + ^(MESSAGES message*) + ^(ENUMS enum_def*) + ^(EXTENSIONS extensions*) + ^(OPTIONS def_option*) + ; + +message + : MESSAGE^ id '{'! message_def '}'! + ; + +method_options + : '{'! (def_option | ';'!)+ '}'! + ; + +method_def + : RPC id '(' qualified_name ')' + RETURNS '(' qualified_name ')' (method_options | ';') + ; + +service_defs + : (def_option | method_def | ';')+ + ; + +service + : SERVICE id '{' service_defs? '}' + ; + +extension + : EXTEND qualified_name '{' message_def '}' + ; + +import_line + : IMPORT! STRING ';'! + ; + +package_decl + : PACKAGE^ qualified_name ';'! + ; + +def_option + : OPTION option ';' -> option + ; + +proto_file + : ( package_decl + | import_line + | message + | enum_def + | service + | extension + | def_option + | ';' + )* + -> ^(PROTO_FILE package_decl* + ^(IMPORTS import_line*) + ^(MESSAGES message*) + ^(ENUMS enum_def*) + ^(SERVICES service*) + ^(EXTENSIONS extension*) + ^(OPTIONS def_option*) + ) + ; diff --git a/endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g b/endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g new file mode 100644 index 0000000..be789b5 --- /dev/null +++ b/endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g @@ -0,0 +1,153 @@ +/* !/usr/bin/env python + * + * Copyright 2011 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +lexer grammar protobuf_lexer; + +tokens { + // Imaginary tree nodes. + ENUMS; + ENUM_DECL; + ENUM_DECLS; + EXTENSION_RANGE; + FIELD; + FIELDS; + FIELD_TYPE; + GROUP_MESSAGE; + IMPORTS; + MESSAGES; + NAME_ROOT; + OPTIONS; + OPTION_ID; + PROTO_FILE; + SERVICES; + USER_OPTION_ID; +} + +// Basic keyword tokens. +ENUM : 'enum'; +MESSAGE : 'message'; +IMPORT : 'import'; +OPTION : 'option'; +PACKAGE : 'package'; +RPC : 'rpc'; +SERVICE : 'service'; +RETURNS : 'returns'; +EXTEND : 'extend'; +EXTENSIONS : 'extensions'; +TO : 'to'; +GROUP : 'group'; +MAX : 'max'; + +COMMENT + : '//' ~('\n'|'\r')* '\r'? '\n' {$channel=HIDDEN;} + | '/*' ( options {greedy=false;} : . )* '*/' {$channel=HIDDEN;} + ; + +WS + : ( ' ' + | '\t' + | '\r' + | '\n' + ) {$channel=HIDDEN;} + ; + +DATA_TYPE + : 'double' + | 'float' + | 'int32' + | 'int64' + | 'uint32' + | 'uint64' + | 'sint32' + | 'sint64' + | 'fixed32' + | 'fixed64' + | 'sfixed32' + | 'sfixed64' + | 'bool' + | 'string' + | 'bytes' + ; + +LABEL + : 'required' + | 'optional' + | 'repeated' + ; + +BOOL + : 'true' + | 'false' + ; + +ID + : ('a'..'z'|'A'..'Z'|'_') ('a'..'z'|'A'..'Z'|'0'..'9'|'_')* + ; + +INT + : '-'? ('0'..'9'+ | '0x' ('a'..'f'|'A'..'F'|'0'..'9')+ | 'inf') + | 'nan' + ; + +FLOAT + : '-'? ('0'..'9')+ '.' ('0'..'9')* EXPONENT? + | '-'? '.' ('0'..'9')+ EXPONENT? + | '-'? ('0'..'9')+ EXPONENT + ; + +STRING + : '"' ( STRING_INNARDS )* '"'; + +fragment +STRING_INNARDS + : ESC_SEQ + | ~('\\'|'"') + ; + +fragment +EXPONENT + : ('e'|'E') ('+'|'-')? ('0'..'9')+ + ; + +fragment +HEX_DIGIT + : ('0'..'9'|'a'..'f'|'A'..'F') + ; + +fragment +ESC_SEQ + : '\\' ('a'|'b'|'t'|'n'|'f'|'r'|'v'|'\"'|'\''|'\\') + | UNICODE_ESC + | OCTAL_ESC + | HEX_ESC + ; + +fragment +OCTAL_ESC + : '\\' ('0'..'3') ('0'..'7') ('0'..'7') + | '\\' ('0'..'7') ('0'..'7') + | '\\' ('0'..'7') + ; + +fragment +HEX_ESC + : '\\x' HEX_DIGIT HEX_DIGIT + ; + +fragment +UNICODE_ESC + : '\\' 'u' HEX_DIGIT HEX_DIGIT HEX_DIGIT HEX_DIGIT + ; diff --git a/endpoints/internal/protorpc/experimental/parser/pyprotobuf.g b/endpoints/internal/protorpc/experimental/parser/pyprotobuf.g new file mode 100644 index 0000000..534e1f8 --- /dev/null +++ b/endpoints/internal/protorpc/experimental/parser/pyprotobuf.g @@ -0,0 +1,45 @@ +/* !/usr/bin/env python + * + * Copyright 2011 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar pyprotobuf; + +options { +// language=Python; + output = AST; + ASTLabelType = CommonTree; +} + +import protobuf_lexer, protobuf; + +// For reasons I do not understand the HIDDEN elements from the imported +// with their channel intact. + +COMMENT + : '//' ~('\n'|'\r')* '\r'? '\n' {$channel=HIDDEN;} + | '/*' ( options {greedy=false;} : . )* '*/' {$channel=HIDDEN;} + ; + +WS : ( ' ' + | '\t' + | '\r' + | '\n' + ) {$channel=HIDDEN;} + ; + +py_proto_file + : proto_file EOF^ + ; diff --git a/endpoints/internal/protorpc/experimental/parser/test.proto b/endpoints/internal/protorpc/experimental/parser/test.proto new file mode 100644 index 0000000..438e1e6 --- /dev/null +++ b/endpoints/internal/protorpc/experimental/parser/test.proto @@ -0,0 +1,27 @@ +/* !/usr/bin/env python + * + * Copyright 2011 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package a.b.c; + +import "abc.def"; +import "from/here"; + +message MyMessage { + required int64 thing = 1 [a="b"]; + optional group whatever = 2 { + repeated int64 thing = 1; + } +} diff --git a/endpoints/internal/protorpc/generate.py b/endpoints/internal/protorpc/generate.py new file mode 100644 index 0000000..9a2630b --- /dev/null +++ b/endpoints/internal/protorpc/generate.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import contextlib + +from . import messages +from . import util + +__all__ = ['IndentationError', + 'IndentWriter', + ] + + +class IndentationError(messages.Error): + """Raised when end_indent is called too many times.""" + + +class IndentWriter(object): + """Utility class to make it easy to write formatted indented text. + + IndentWriter delegates to a file-like object and is able to keep track of the + level of indentation. Each call to write_line will write a line terminated + by a new line proceeded by a number of spaces indicated by the current level + of indentation. + + IndexWriter overloads the << operator to make line writing operations clearer. + + The indent method returns a context manager that can be used by the Python + with statement that makes generating python code easier to use. For example: + + index_writer << 'def factorial(n):' + with index_writer.indent(): + index_writer << 'if n <= 1:' + with index_writer.indent(): + index_writer << 'return 1' + index_writer << 'else:' + with index_writer.indent(): + index_writer << 'return factorial(n - 1)' + + This would generate: + + def factorial(n): + if n <= 1: + return 1 + else: + return factorial(n - 1) + """ + + @util.positional(2) + def __init__(self, output, indent_space=2): + """Constructor. + + Args: + output: File-like object to wrap. + indent_space: Number of spaces each level of indentation will be. + """ + # Private attributes: + # + # __output: The wrapped file-like object. + # __indent_space: String to append for each level of indentation. + # __indentation: The current full indentation string. + self.__output = output + self.__indent_space = indent_space * ' ' + self.__indentation = 0 + + @property + def indent_level(self): + """Current level of indentation for IndentWriter.""" + return self.__indentation + + def write_line(self, line): + """Write line to wrapped file-like object using correct indentation. + + The line is written with the current level of indentation printed before it + and terminated by a new line. + + Args: + line: Line to write to wrapped file-like object. + """ + if line != '': + self.__output.write(self.__indentation * self.__indent_space) + self.__output.write(line) + self.__output.write('\n') + + def begin_indent(self): + """Begin a level of indentation.""" + self.__indentation += 1 + + def end_indent(self): + """Undo the most recent level of indentation. + + Raises: + IndentationError when called with no indentation levels. + """ + if not self.__indentation: + raise IndentationError('Unable to un-indent further') + self.__indentation -= 1 + + @contextlib.contextmanager + def indent(self): + """Create indentation level compatible with the Python 'with' keyword.""" + self.begin_indent() + yield + self.end_indent() + + def __lshift__(self, line): + """Syntactic sugar for write_line method. + + Args: + line: Line to write to wrapped file-like object. + """ + self.write_line(line) diff --git a/endpoints/internal/protorpc/generate_proto.py b/endpoints/internal/protorpc/generate_proto.py new file mode 100644 index 0000000..8e4b19e --- /dev/null +++ b/endpoints/internal/protorpc/generate_proto.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import with_statement + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import logging + +from . import descriptor +from . import generate +from . import messages +from . import util + + +__all__ = ['format_proto_file'] + + +@util.positional(2) +def format_proto_file(file_descriptor, output, indent_space=2): + out = generate.IndentWriter(output, indent_space=indent_space) + + if file_descriptor.package: + out << 'package %s;' % file_descriptor.package + + def write_enums(enum_descriptors): + """Write nested and non-nested Enum types. + + Args: + enum_descriptors: List of EnumDescriptor objects from which to generate + enums. + """ + # Write enums. + for enum in enum_descriptors or []: + out << '' + out << '' + out << 'enum %s {' % enum.name + out << '' + + with out.indent(): + if enum.values: + for enum_value in enum.values: + out << '%s = %s;' % (enum_value.name, enum_value.number) + + out << '}' + + write_enums(file_descriptor.enum_types) + + def write_fields(field_descriptors): + """Write fields for Message types. + + Args: + field_descriptors: List of FieldDescriptor objects from which to generate + fields. + """ + for field in field_descriptors or []: + default_format = '' + if field.default_value is not None: + if field.label == descriptor.FieldDescriptor.Label.REPEATED: + logging.warning('Default value for repeated field %s is not being ' + 'written to proto file' % field.name) + else: + # Convert default value to string. + if field.variant == messages.Variant.MESSAGE: + logging.warning( + 'Message field %s should not have default values' % field.name) + default = None + elif field.variant == messages.Variant.STRING: + default = repr(field.default_value.encode('utf-8')) + elif field.variant == messages.Variant.BYTES: + default = repr(field.default_value) + else: + default = str(field.default_value) + + if default is not None: + default_format = ' [default=%s]' % default + + if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM): + field_type = field.type_name + else: + field_type = str(field.variant).lower() + + out << '%s %s %s = %s%s;' % (str(field.label).lower(), + field_type, + field.name, + field.number, + default_format) + + def write_messages(message_descriptors): + """Write nested and non-nested Message types. + + Args: + message_descriptors: List of MessageDescriptor objects from which to + generate messages. + """ + for message in message_descriptors or []: + out << '' + out << '' + out << 'message %s {' % message.name + + with out.indent(): + if message.enum_types: + write_enums(message.enum_types) + + if message.message_types: + write_messages(message.message_types) + + if message.fields: + write_fields(message.fields) + + out << '}' + + write_messages(file_descriptor.message_types) diff --git a/endpoints/internal/protorpc/generate_proto_test.py b/endpoints/internal/protorpc/generate_proto_test.py new file mode 100644 index 0000000..43469b5 --- /dev/null +++ b/endpoints/internal/protorpc/generate_proto_test.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.generate_proto_test.""" + + +import os +import shutil +import cStringIO +import sys +import tempfile +import unittest + +from protorpc import descriptor +from protorpc import generate_proto +from protorpc import test_util +from protorpc import util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = generate_proto + + +class FormatProtoFileTest(test_util.TestCase): + + def setUp(self): + self.file_descriptor = descriptor.FileDescriptor() + self.output = cStringIO.StringIO() + + @property + def result(self): + return self.output.getvalue() + + def MakeMessage(self, name='MyMessage', fields=[]): + message = descriptor.MessageDescriptor() + message.name = name + message.fields = fields + + messages_list = getattr(self.file_descriptor, 'fields', []) + messages_list.append(message) + self.file_descriptor.message_types = messages_list + + def testBlankPackage(self): + self.file_descriptor.package = None + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('', self.result) + + def testEmptyPackage(self): + self.file_descriptor.package = 'my_package' + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('package my_package;\n', self.result) + + def testSingleField(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.INT64 + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + ' optional int64 integer_field = 1;\n' + '}\n', + self.result) + + def testSingleFieldWithDefault(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.INT64 + field.default_value = '10' + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + ' optional int64 integer_field = 1 [default=10];\n' + '}\n', + self.result) + + def testRepeatedFieldWithDefault(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.REPEATED + field.variant = descriptor.FieldDescriptor.Variant.INT64 + field.default_value = '[10, 20]' + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + ' repeated int64 integer_field = 1;\n' + '}\n', + self.result) + + def testSingleFieldWithDefaultString(self): + field = descriptor.FieldDescriptor() + field.name = 'string_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.STRING + field.default_value = 'hello' + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + " optional string string_field = 1 [default='hello'];\n" + '}\n', + self.result) + + def testSingleFieldWithDefaultEmptyString(self): + field = descriptor.FieldDescriptor() + field.name = 'string_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.STRING + field.default_value = '' + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + " optional string string_field = 1 [default=''];\n" + '}\n', + self.result) + + def testSingleFieldWithDefaultMessage(self): + field = descriptor.FieldDescriptor() + field.name = 'message_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.MESSAGE + field.type_name = 'MyNestedMessage' + field.default_value = 'not valid' + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + " optional MyNestedMessage message_field = 1;\n" + '}\n', + self.result) + + def testSingleFieldWithDefaultEnum(self): + field = descriptor.FieldDescriptor() + field.name = 'enum_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.ENUM + field.type_name = 'my_package.MyEnum' + field.default_value = '17' + + self.MakeMessage(fields=[field]) + + generate_proto.format_proto_file(self.file_descriptor, self.output) + self.assertEquals('\n\n' + 'message MyMessage {\n' + " optional my_package.MyEnum enum_field = 1 " + "[default=17];\n" + '}\n', + self.result) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() + diff --git a/endpoints/internal/protorpc/generate_python.py b/endpoints/internal/protorpc/generate_python.py new file mode 100644 index 0000000..5234e05 --- /dev/null +++ b/endpoints/internal/protorpc/generate_python.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import with_statement + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +from . import descriptor +from . import generate +from . import message_types +from . import messages +from . import util + + +__all__ = ['format_python_file'] + +_MESSAGE_FIELD_MAP = { + message_types.DateTimeMessage.definition_name(): message_types.DateTimeField, +} + + +def _write_enums(enum_descriptors, out): + """Write nested and non-nested Enum types. + + Args: + enum_descriptors: List of EnumDescriptor objects from which to generate + enums. + out: Indent writer used for generating text. + """ + # Write enums. + for enum in enum_descriptors or []: + out << '' + out << '' + out << 'class %s(messages.Enum):' % enum.name + out << '' + + with out.indent(): + if not enum.values: + out << 'pass' + else: + for enum_value in enum.values: + out << '%s = %s' % (enum_value.name, enum_value.number) + + +def _write_fields(field_descriptors, out): + """Write fields for Message types. + + Args: + field_descriptors: List of FieldDescriptor objects from which to generate + fields. + out: Indent writer used for generating text. + """ + out << '' + for field in field_descriptors or []: + type_format = '' + label_format = '' + + message_field = _MESSAGE_FIELD_MAP.get(field.type_name) + if message_field: + module = 'message_types' + field_type = message_field + else: + module = 'messages' + field_type = messages.Field.lookup_field_type_by_variant(field.variant) + + if field_type in (messages.EnumField, messages.MessageField): + type_format = '\'%s\', ' % field.type_name + + if field.label == descriptor.FieldDescriptor.Label.REQUIRED: + label_format = ', required=True' + + elif field.label == descriptor.FieldDescriptor.Label.REPEATED: + label_format = ', repeated=True' + + if field_type.DEFAULT_VARIANT != field.variant: + variant_format = ', variant=messages.Variant.%s' % field.variant + else: + variant_format = '' + + if field.default_value: + if field_type in [messages.BytesField, + messages.StringField, + ]: + default_value = repr(field.default_value) + elif field_type is messages.EnumField: + try: + default_value = str(int(field.default_value)) + except ValueError: + default_value = repr(field.default_value) + else: + default_value = field.default_value + + default_format = ', default=%s' % (default_value,) + else: + default_format = '' + + out << '%s = %s.%s(%s%s%s%s%s)' % (field.name, + module, + field_type.__name__, + type_format, + field.number, + label_format, + variant_format, + default_format) + + +def _write_messages(message_descriptors, out): + """Write nested and non-nested Message types. + + Args: + message_descriptors: List of MessageDescriptor objects from which to + generate messages. + out: Indent writer used for generating text. + """ + for message in message_descriptors or []: + out << '' + out << '' + out << 'class %s(messages.Message):' % message.name + + with out.indent(): + if not (message.enum_types or message.message_types or message.fields): + out << '' + out << 'pass' + else: + _write_enums(message.enum_types, out) + _write_messages(message.message_types, out) + _write_fields(message.fields, out) + + +def _write_methods(method_descriptors, out): + """Write methods of Service types. + + All service method implementations raise NotImplementedError. + + Args: + method_descriptors: List of MethodDescriptor objects from which to + generate methods. + out: Indent writer used for generating text. + """ + for method in method_descriptors: + out << '' + out << "@remote.method('%s', '%s')" % (method.request_type, + method.response_type) + out << 'def %s(self, request):' % (method.name,) + with out.indent(): + out << ('raise NotImplementedError' + "('Method %s is not implemented')" % (method.name)) + + +def _write_services(service_descriptors, out): + """Write Service types. + + Args: + service_descriptors: List of ServiceDescriptor instances from which to + generate services. + out: Indent writer used for generating text. + """ + for service in service_descriptors or []: + out << '' + out << '' + out << 'class %s(remote.Service):' % service.name + + with out.indent(): + if service.methods: + _write_methods(service.methods, out) + else: + out << '' + out << 'pass' + + +@util.positional(2) +def format_python_file(file_descriptor, output, indent_space=2): + """Format FileDescriptor object as a single Python module. + + Services generated by this function will raise NotImplementedError. + + All Python classes generated by this function use delayed binding for all + message fields, enum fields and method parameter types. For example a + service method might be generated like so: + + class MyService(remote.Service): + + @remote.method('my_package.MyRequestType', 'my_package.MyResponseType') + def my_method(self, request): + raise NotImplementedError('Method my_method is not implemented') + + Args: + file_descriptor: FileDescriptor instance to format as python module. + output: File-like object to write module source code to. + indent_space: Number of spaces for each level of Python indentation. + """ + out = generate.IndentWriter(output, indent_space=indent_space) + + out << 'from protorpc import message_types' + out << 'from protorpc import messages' + if file_descriptor.service_types: + out << 'from protorpc import remote' + + if file_descriptor.package: + out << "package = '%s'" % file_descriptor.package + + _write_enums(file_descriptor.enum_types, out) + _write_messages(file_descriptor.message_types, out) + _write_services(file_descriptor.service_types, out) diff --git a/endpoints/internal/protorpc/generate_python_test.py b/endpoints/internal/protorpc/generate_python_test.py new file mode 100644 index 0000000..21a05cc --- /dev/null +++ b/endpoints/internal/protorpc/generate_python_test.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.generate_python_test.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import os +import shutil +import sys +import tempfile +import unittest + +from protorpc import descriptor +from protorpc import generate_python +from protorpc import test_util +from protorpc import util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = generate_python + + +class FormatPythonFileTest(test_util.TestCase): + + def setUp(self): + self.original_path = list(sys.path) + self.original_modules = dict(sys.modules) + sys.path = list(sys.path) + self.file_descriptor = descriptor.FileDescriptor() + + # Create temporary directory and add to Python path so that generated + # Python code can be easily parsed, imported and executed. + self.temp_dir = tempfile.mkdtemp() + sys.path.append(self.temp_dir) + + def tearDown(self): + # Reset path. + sys.path[:] = [] + sys.path.extend(self.original_path) + + # Reset modules. + sys.modules.clear() + sys.modules.update(self.original_modules) + + # Remove temporary directory. + try: + shutil.rmtree(self.temp_dir) + except IOError: + pass + + def DoPythonTest(self, file_descriptor): + """Execute python test based on a FileDescriptor object. + + The full test of the Python code generation is to generate a Python source + code file, import the module and regenerate the FileDescriptor from it. + If the generated FileDescriptor is the same as the original, it means that + the generated source code correctly implements the actual FileDescriptor. + """ + file_name = os.path.join(self.temp_dir, + '%s.py' % (file_descriptor.package or 'blank',)) + source_file = open(file_name, 'wt') + try: + generate_python.format_python_file(file_descriptor, source_file) + finally: + source_file.close() + + module_to_import = file_descriptor.package or 'blank' + module = __import__(module_to_import) + + if not file_descriptor.package: + self.assertFalse(hasattr(module, 'package')) + module.package = '' # Create package name so that comparison will work. + + reloaded_descriptor = descriptor.describe_file(module) + + # Need to sort both message_types fields because document order is never + # Ensured. + # TODO(rafek): Ensure document order. + if reloaded_descriptor.message_types: + reloaded_descriptor.message_types = sorted( + reloaded_descriptor.message_types, key=lambda v: v.name) + + if file_descriptor.message_types: + file_descriptor.message_types = sorted( + file_descriptor.message_types, key=lambda v: v.name) + + self.assertEquals(file_descriptor, reloaded_descriptor) + + @util.positional(2) + def DoMessageTest(self, + field_descriptors, + message_types=None, + enum_types=None): + """Execute message generation test based on FieldDescriptor objects. + + Args: + field_descriptor: List of FieldDescriptor object to generate and test. + message_types: List of other MessageDescriptor objects that the new + Message class depends on. + enum_types: List of EnumDescriptor objects that the new Message class + depends on. + """ + file_descriptor = descriptor.FileDescriptor() + file_descriptor.package = 'my_package' + + message_descriptor = descriptor.MessageDescriptor() + message_descriptor.name = 'MyMessage' + + message_descriptor.fields = list(field_descriptors) + + file_descriptor.message_types = message_types or [] + file_descriptor.message_types.append(message_descriptor) + + if enum_types: + file_descriptor.enum_types = list(enum_types) + + self.DoPythonTest(file_descriptor) + + def testBlankPackage(self): + self.DoPythonTest(descriptor.FileDescriptor()) + + def testEmptyPackage(self): + file_descriptor = descriptor.FileDescriptor() + file_descriptor.package = 'mypackage' + self.DoPythonTest(file_descriptor) + + def testSingleField(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.INT64 + + self.DoMessageTest([field]) + + def testMessageField_InternalReference(self): + other_message = descriptor.MessageDescriptor() + other_message.name = 'OtherMessage' + + field = descriptor.FieldDescriptor() + field.name = 'message_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.MESSAGE + field.type_name = 'my_package.OtherMessage' + + self.DoMessageTest([field], message_types=[other_message]) + + def testMessageField_ExternalReference(self): + field = descriptor.FieldDescriptor() + field.name = 'message_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.MESSAGE + field.type_name = 'protorpc.registry.GetFileSetResponse' + + self.DoMessageTest([field]) + + def testEnumField_InternalReference(self): + enum = descriptor.EnumDescriptor() + enum.name = 'Color' + + field = descriptor.FieldDescriptor() + field.name = 'color' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.ENUM + field.type_name = 'my_package.Color' + + self.DoMessageTest([field], enum_types=[enum]) + + def testEnumField_ExternalReference(self): + field = descriptor.FieldDescriptor() + field.name = 'color' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.ENUM + field.type_name = 'protorpc.descriptor.FieldDescriptor.Label' + + self.DoMessageTest([field]) + + def testDateTimeField(self): + field = descriptor.FieldDescriptor() + field.name = 'timestamp' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.MESSAGE + field.type_name = 'protorpc.message_types.DateTimeMessage' + + self.DoMessageTest([field]) + + def testNonDefaultVariant(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.UINT64 + + self.DoMessageTest([field]) + + def testRequiredField(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.REQUIRED + field.variant = descriptor.FieldDescriptor.Variant.INT64 + + self.DoMessageTest([field]) + + def testRepeatedField(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.REPEATED + field.variant = descriptor.FieldDescriptor.Variant.INT64 + + self.DoMessageTest([field]) + + def testIntegerDefaultValue(self): + field = descriptor.FieldDescriptor() + field.name = 'integer_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.INT64 + field.default_value = '10' + + self.DoMessageTest([field]) + + def testFloatDefaultValue(self): + field = descriptor.FieldDescriptor() + field.name = 'float_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.DOUBLE + field.default_value = '10.1' + + self.DoMessageTest([field]) + + def testStringDefaultValue(self): + field = descriptor.FieldDescriptor() + field.name = 'string_field' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.STRING + field.default_value = u'a nice lovely string\'s "string"' + + self.DoMessageTest([field]) + + def testEnumDefaultValue(self): + field = descriptor.FieldDescriptor() + field.name = 'label' + field.number = 1 + field.label = descriptor.FieldDescriptor.Label.OPTIONAL + field.variant = descriptor.FieldDescriptor.Variant.ENUM + field.type_name = 'protorpc.descriptor.FieldDescriptor.Label' + field.default_value = '2' + + self.DoMessageTest([field]) + + def testMultiFields(self): + field1 = descriptor.FieldDescriptor() + field1.name = 'integer_field' + field1.number = 1 + field1.label = descriptor.FieldDescriptor.Label.OPTIONAL + field1.variant = descriptor.FieldDescriptor.Variant.INT64 + + field2 = descriptor.FieldDescriptor() + field2.name = 'string_field' + field2.number = 2 + field2.label = descriptor.FieldDescriptor.Label.OPTIONAL + field2.variant = descriptor.FieldDescriptor.Variant.STRING + + field3 = descriptor.FieldDescriptor() + field3.name = 'unsigned_integer_field' + field3.number = 3 + field3.label = descriptor.FieldDescriptor.Label.OPTIONAL + field3.variant = descriptor.FieldDescriptor.Variant.UINT64 + + self.DoMessageTest([field1, field2, field3]) + + def testNestedMessage(self): + message = descriptor.MessageDescriptor() + message.name = 'OuterMessage' + + inner_message = descriptor.MessageDescriptor() + inner_message.name = 'InnerMessage' + + inner_inner_message = descriptor.MessageDescriptor() + inner_inner_message.name = 'InnerInnerMessage' + + inner_message.message_types = [inner_inner_message] + + message.message_types = [inner_message] + + file_descriptor = descriptor.FileDescriptor() + file_descriptor.message_types = [message] + + self.DoPythonTest(file_descriptor) + + def testNestedEnum(self): + message = descriptor.MessageDescriptor() + message.name = 'OuterMessage' + + inner_enum = descriptor.EnumDescriptor() + inner_enum.name = 'InnerEnum' + + message.enum_types = [inner_enum] + + file_descriptor = descriptor.FileDescriptor() + file_descriptor.message_types = [message] + + self.DoPythonTest(file_descriptor) + + def testService(self): + service = descriptor.ServiceDescriptor() + service.name = 'TheService' + + method1 = descriptor.MethodDescriptor() + method1.name = 'method1' + method1.request_type = 'protorpc.descriptor.FileDescriptor' + method1.response_type = 'protorpc.descriptor.MethodDescriptor' + + service.methods = [method1] + + file_descriptor = descriptor.FileDescriptor() + file_descriptor.service_types = [service] + + self.DoPythonTest(file_descriptor) + + # Test to make sure that implementation methods raise an exception. + import blank + service_instance = blank.TheService() + self.assertRaisesWithRegexpMatch(NotImplementedError, + 'Method method1 is not implemented', + service_instance.method1, + descriptor.FileDescriptor()) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/generate_test.py b/endpoints/internal/protorpc/generate_test.py new file mode 100644 index 0000000..7b9893a --- /dev/null +++ b/endpoints/internal/protorpc/generate_test.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.generate.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import operator + +import cStringIO +import sys +import unittest + +from protorpc import generate +from protorpc import test_util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = generate + + +class IndentWriterTest(test_util.TestCase): + + def setUp(self): + self.out = cStringIO.StringIO() + self.indent_writer = generate.IndentWriter(self.out) + + def testWriteLine(self): + self.indent_writer.write_line('This is a line') + self.indent_writer.write_line('This is another line') + + self.assertEquals('This is a line\n' + 'This is another line\n', + self.out.getvalue()) + + def testLeftShift(self): + self.run_count = 0 + def mock_write_line(line): + self.run_count += 1 + self.assertEquals('same as calling write_line', line) + + self.indent_writer.write_line = mock_write_line + self.indent_writer << 'same as calling write_line' + self.assertEquals(1, self.run_count) + + def testIndentation(self): + self.indent_writer << 'indent 0' + self.indent_writer.begin_indent() + self.indent_writer << 'indent 1' + self.indent_writer.begin_indent() + self.indent_writer << 'indent 2' + self.indent_writer.end_indent() + self.indent_writer << 'end 2' + self.indent_writer.end_indent() + self.indent_writer << 'end 1' + self.assertRaises(generate.IndentationError, + self.indent_writer.end_indent) + + self.assertEquals('indent 0\n' + ' indent 1\n' + ' indent 2\n' + ' end 2\n' + 'end 1\n', + self.out.getvalue()) + + def testBlankLine(self): + self.indent_writer << '' + self.indent_writer.begin_indent() + self.indent_writer << '' + self.assertEquals('\n\n', self.out.getvalue()) + + def testNoneInvalid(self): + self.assertRaises( + TypeError, operator.lshift, self.indent_writer, None) + + def testAltIndentation(self): + self.indent_writer = generate.IndentWriter(self.out, indent_space=3) + self.indent_writer << 'indent 0' + self.assertEquals(0, self.indent_writer.indent_level) + self.indent_writer.begin_indent() + self.indent_writer << 'indent 1' + self.assertEquals(1, self.indent_writer.indent_level) + self.indent_writer.begin_indent() + self.indent_writer << 'indent 2' + self.assertEquals(2, self.indent_writer.indent_level) + self.indent_writer.end_indent() + self.indent_writer << 'end 2' + self.assertEquals(1, self.indent_writer.indent_level) + self.indent_writer.end_indent() + self.indent_writer << 'end 1' + self.assertEquals(0, self.indent_writer.indent_level) + self.assertRaises(generate.IndentationError, + self.indent_writer.end_indent) + self.assertEquals(0, self.indent_writer.indent_level) + + self.assertEquals('indent 0\n' + ' indent 1\n' + ' indent 2\n' + ' end 2\n' + 'end 1\n', + self.out.getvalue()) + + def testIndent(self): + self.indent_writer << 'indent 0' + self.assertEquals(0, self.indent_writer.indent_level) + + def indent1(): + self.indent_writer << 'indent 1' + self.assertEquals(1, self.indent_writer.indent_level) + + def indent2(): + self.indent_writer << 'indent 2' + self.assertEquals(2, self.indent_writer.indent_level) + test_util.do_with(self.indent_writer.indent(), indent2) + + self.assertEquals(1, self.indent_writer.indent_level) + self.indent_writer << 'end 2' + test_util.do_with(self.indent_writer.indent(), indent1) + + self.assertEquals(0, self.indent_writer.indent_level) + self.indent_writer << 'end 1' + + self.assertEquals('indent 0\n' + ' indent 1\n' + ' indent 2\n' + ' end 2\n' + 'end 1\n', + self.out.getvalue()) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/google_imports.py b/endpoints/internal/protorpc/google_imports.py new file mode 100644 index 0000000..7ed0f40 --- /dev/null +++ b/endpoints/internal/protorpc/google_imports.py @@ -0,0 +1,15 @@ +"""Dynamically decide from where to import other SDK modules. + +All other protorpc code should import other SDK modules from +this module. If necessary, add new imports here (in both places). +""" + +__author__ = 'yey@google.com (Ye Yuan)' + +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import + +try: + from google.net.proto import ProtocolBuffer +except ImportError: + pass diff --git a/endpoints/internal/protorpc/message_types.py b/endpoints/internal/protorpc/message_types.py new file mode 100644 index 0000000..f707d52 --- /dev/null +++ b/endpoints/internal/protorpc/message_types.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Simple protocol message types. + +Includes new message and field types that are outside what is defined by the +protocol buffers standard. +""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import datetime + +from . import messages +from . import util + +__all__ = [ + 'DateTimeField', + 'DateTimeMessage', + 'VoidMessage', +] + +class VoidMessage(messages.Message): + """Empty message.""" + + +class DateTimeMessage(messages.Message): + """Message to store/transmit a DateTime. + + Fields: + milliseconds: Milliseconds since Jan 1st 1970 local time. + time_zone_offset: Optional time zone offset, in minutes from UTC. + """ + milliseconds = messages.IntegerField(1, required=True) + time_zone_offset = messages.IntegerField(2) + + +class DateTimeField(messages.MessageField): + """Field definition for datetime values. + + Stores a python datetime object as a field. If time zone information is + included in the datetime object, it will be included in + the encoded data when this is encoded/decoded. + """ + + type = datetime.datetime + + message_type = DateTimeMessage + + @util.positional(3) + def __init__(self, + number, + **kwargs): + super(DateTimeField, self).__init__(self.message_type, + number, + **kwargs) + + def value_from_message(self, message): + """Convert DateTimeMessage to a datetime. + + Args: + A DateTimeMessage instance. + + Returns: + A datetime instance. + """ + message = super(DateTimeField, self).value_from_message(message) + if message.time_zone_offset is None: + return datetime.datetime.utcfromtimestamp(message.milliseconds / 1000.0) + + # Need to subtract the time zone offset, because when we call + # datetime.fromtimestamp, it will add the time zone offset to the + # value we pass. + milliseconds = (message.milliseconds - + 60000 * message.time_zone_offset) + + timezone = util.TimeZoneOffset(message.time_zone_offset) + return datetime.datetime.fromtimestamp(milliseconds / 1000.0, + tz=timezone) + + def value_to_message(self, value): + value = super(DateTimeField, self).value_to_message(value) + # First, determine the delta from the epoch, so we can fill in + # DateTimeMessage's milliseconds field. + if value.tzinfo is None: + time_zone_offset = 0 + local_epoch = datetime.datetime.utcfromtimestamp(0) + else: + time_zone_offset = util.total_seconds(value.tzinfo.utcoffset(value)) + # Determine Jan 1, 1970 local time. + local_epoch = datetime.datetime.fromtimestamp(-time_zone_offset, + tz=value.tzinfo) + delta = value - local_epoch + + # Create and fill in the DateTimeMessage, including time zone if + # one was specified. + message = DateTimeMessage() + message.milliseconds = int(util.total_seconds(delta) * 1000) + if value.tzinfo is not None: + utc_offset = value.tzinfo.utcoffset(value) + if utc_offset is not None: + message.time_zone_offset = int( + util.total_seconds(value.tzinfo.utcoffset(value)) / 60) + + return message diff --git a/endpoints/internal/protorpc/message_types_test.py b/endpoints/internal/protorpc/message_types_test.py new file mode 100644 index 0000000..b061cdf --- /dev/null +++ b/endpoints/internal/protorpc/message_types_test.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# +# Copyright 2013 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.message_types.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import datetime + +import unittest + +from protorpc import message_types +from protorpc import messages +from protorpc import test_util +from protorpc import util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = message_types + + +class DateTimeFieldTest(test_util.TestCase): + + def testValueToMessage(self): + field = message_types.DateTimeField(1) + message = field.value_to_message(datetime.datetime(2033, 2, 4, 11, 22, 10)) + self.assertEqual(message_types.DateTimeMessage(milliseconds=1991128930000), + message) + + def testValueToMessageBadValue(self): + field = message_types.DateTimeField(1) + self.assertRaisesWithRegexpMatch( + messages.EncodeError, + 'Expected type datetime, got int: 20', + field.value_to_message, 20) + + def testValueToMessageWithTimeZone(self): + time_zone = util.TimeZoneOffset(60 * 10) + field = message_types.DateTimeField(1) + message = field.value_to_message( + datetime.datetime(2033, 2, 4, 11, 22, 10, tzinfo=time_zone)) + self.assertEqual(message_types.DateTimeMessage(milliseconds=1991128930000, + time_zone_offset=600), + message) + + def testValueFromMessage(self): + message = message_types.DateTimeMessage(milliseconds=1991128000000) + field = message_types.DateTimeField(1) + timestamp = field.value_from_message(message) + self.assertEqual(datetime.datetime(2033, 2, 4, 11, 6, 40), + timestamp) + + def testValueFromMessageBadValue(self): + field = message_types.DateTimeField(1) + self.assertRaisesWithRegexpMatch( + messages.DecodeError, + 'Expected type DateTimeMessage, got VoidMessage: ', + field.value_from_message, message_types.VoidMessage()) + + def testValueFromMessageWithTimeZone(self): + message = message_types.DateTimeMessage(milliseconds=1991128000000, + time_zone_offset=300) + field = message_types.DateTimeField(1) + timestamp = field.value_from_message(message) + time_zone = util.TimeZoneOffset(60 * 5) + self.assertEqual(datetime.datetime(2033, 2, 4, 11, 6, 40, tzinfo=time_zone), + timestamp) + + +if __name__ == '__main__': + unittest.main() diff --git a/endpoints/internal/protorpc/messages.py b/endpoints/internal/protorpc/messages.py new file mode 100644 index 0000000..024039a --- /dev/null +++ b/endpoints/internal/protorpc/messages.py @@ -0,0 +1,1949 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Stand-alone implementation of in memory protocol messages. + +Public Classes: + Enum: Represents an enumerated type. + Variant: Hint for wire format to determine how to serialize. + Message: Base class for user defined messages. + IntegerField: Field for integer values. + FloatField: Field for float values. + BooleanField: Field for boolean values. + BytesField: Field for binary string values. + StringField: Field for UTF-8 string values. + MessageField: Field for other message type values. + EnumField: Field for enumerated type values. + +Public Exceptions (indentation indications class hierarchy): + EnumDefinitionError: Raised when enumeration is incorrectly defined. + FieldDefinitionError: Raised when field is incorrectly defined. + InvalidVariantError: Raised when variant is not compatible with field type. + InvalidDefaultError: Raised when default is not compatiable with field. + InvalidNumberError: Raised when field number is out of range or reserved. + MessageDefinitionError: Raised when message is incorrectly defined. + DuplicateNumberError: Raised when field has duplicate number with another. + ValidationError: Raised when a message or field is not valid. + DefinitionNotFoundError: Raised when definition not found. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import types +import weakref + +from . import util + +__all__ = ['MAX_ENUM_VALUE', + 'MAX_FIELD_NUMBER', + 'FIRST_RESERVED_FIELD_NUMBER', + 'LAST_RESERVED_FIELD_NUMBER', + + 'Enum', + 'Field', + 'FieldList', + 'Variant', + 'Message', + 'IntegerField', + 'FloatField', + 'BooleanField', + 'BytesField', + 'StringField', + 'MessageField', + 'EnumField', + 'find_definition', + + 'Error', + 'DecodeError', + 'EncodeError', + 'EnumDefinitionError', + 'FieldDefinitionError', + 'InvalidVariantError', + 'InvalidDefaultError', + 'InvalidNumberError', + 'MessageDefinitionError', + 'DuplicateNumberError', + 'ValidationError', + 'DefinitionNotFoundError', + ] + + +# TODO(rafek): Add extended module test to ensure all exceptions +# in services extends Error. +Error = util.Error + + +class EnumDefinitionError(Error): + """Enumeration definition error.""" + + +class FieldDefinitionError(Error): + """Field definition error.""" + + +class InvalidVariantError(FieldDefinitionError): + """Invalid variant provided to field.""" + + +class InvalidDefaultError(FieldDefinitionError): + """Invalid default provided to field.""" + + +class InvalidNumberError(FieldDefinitionError): + """Invalid number provided to field.""" + + +class MessageDefinitionError(Error): + """Message definition error.""" + + +class DuplicateNumberError(Error): + """Duplicate number assigned to field.""" + + +class DefinitionNotFoundError(Error): + """Raised when definition is not found.""" + + +class DecodeError(Error): + """Error found decoding message from encoded form.""" + + +class EncodeError(Error): + """Error found when encoding message.""" + + +class ValidationError(Error): + """Invalid value for message error.""" + + def __str__(self): + """Prints string with field name if present on exception.""" + message = Error.__str__(self) + try: + field_name = self.field_name + except AttributeError: + return message + else: + return message + + +# Attributes that are reserved by a class definition that +# may not be used by either Enum or Message class definitions. +_RESERVED_ATTRIBUTE_NAMES = frozenset( + ['__module__', '__doc__', '__qualname__']) + +_POST_INIT_FIELD_ATTRIBUTE_NAMES = frozenset( + ['name', + '_message_definition', + '_MessageField__type', + '_EnumField__type', + '_EnumField__resolved_default']) + +_POST_INIT_ATTRIBUTE_NAMES = frozenset( + ['_message_definition']) + +# Maximum enumeration value as defined by the protocol buffers standard. +# All enum values must be less than or equal to this value. +MAX_ENUM_VALUE = (2 ** 29) - 1 + +# Maximum field number as defined by the protocol buffers standard. +# All field numbers must be less than or equal to this value. +MAX_FIELD_NUMBER = (2 ** 29) - 1 + +# Field numbers between 19000 and 19999 inclusive are reserved by the +# protobuf protocol and may not be used by fields. +FIRST_RESERVED_FIELD_NUMBER = 19000 +LAST_RESERVED_FIELD_NUMBER = 19999 + + +class _DefinitionClass(type): + """Base meta-class used for definition meta-classes. + + The Enum and Message definition classes share some basic functionality. + Both of these classes may be contained by a Message definition. After + initialization, neither class may have attributes changed + except for the protected _message_definition attribute, and that attribute + may change only once. + """ + + __initialized = False + + def __init__(cls, name, bases, dct): + """Constructor.""" + type.__init__(cls, name, bases, dct) + # Base classes may never be initialized. + if cls.__bases__ != (object,): + cls.__initialized = True + + def message_definition(cls): + """Get outer Message definition that contains this definition. + + Returns: + Containing Message definition if definition is contained within one, + else None. + """ + try: + return cls._message_definition() + except AttributeError: + return None + + def __setattr__(cls, name, value): + """Overridden so that cannot set variables on definition classes after init. + + Setting attributes on a class must work during the period of initialization + to set the enumation value class variables and build the name/number maps. + Once __init__ has set the __initialized flag to True prohibits setting any + more values on the class. The class is in effect frozen. + + Args: + name: Name of value to set. + value: Value to set. + """ + if cls.__initialized and name not in _POST_INIT_ATTRIBUTE_NAMES: + raise AttributeError('May not change values: %s' % name) + else: + type.__setattr__(cls, name, value) + + def __delattr__(cls, name): + """Overridden so that cannot delete varaibles on definition classes.""" + raise TypeError('May not delete attributes on definition class') + + def definition_name(cls): + """Helper method for creating definition name. + + Names will be generated to include the classes package name, scope (if the + class is nested in another definition) and class name. + + By default, the package name for a definition is derived from its module + name. However, this value can be overriden by placing a 'package' attribute + in the module that contains the definition class. For example: + + package = 'some.alternate.package' + + class MyMessage(Message): + ... + + >>> MyMessage.definition_name() + some.alternate.package.MyMessage + + Returns: + Dot-separated fully qualified name of definition. + """ + outer_definition_name = cls.outer_definition_name() + if outer_definition_name is None: + return six.text_type(cls.__name__) + else: + return u'%s.%s' % (outer_definition_name, cls.__name__) + + def outer_definition_name(cls): + """Helper method for creating outer definition name. + + Returns: + If definition is nested, will return the outer definitions name, else the + package name. + """ + outer_definition = cls.message_definition() + if not outer_definition: + return util.get_package_for_module(cls.__module__) + else: + return outer_definition.definition_name() + + def definition_package(cls): + """Helper method for creating creating the package of a definition. + + Returns: + Name of package that definition belongs to. + """ + outer_definition = cls.message_definition() + if not outer_definition: + return util.get_package_for_module(cls.__module__) + else: + return outer_definition.definition_package() + + +class _EnumClass(_DefinitionClass): + """Meta-class used for defining the Enum base class. + + Meta-class enables very specific behavior for any defined Enum + class. All attributes defined on an Enum sub-class must be integers. + Each attribute defined on an Enum sub-class is translated + into an instance of that sub-class, with the name of the attribute + as its name, and the number provided as its value. It also ensures + that only one level of Enum class hierarchy is possible. In other + words it is not possible to delcare sub-classes of sub-classes of + Enum. + + This class also defines some functions in order to restrict the + behavior of the Enum class and its sub-classes. It is not possible + to change the behavior of the Enum class in later classes since + any new classes may be defined with only integer values, and no methods. + """ + + def __init__(cls, name, bases, dct): + # Can only define one level of sub-classes below Enum. + if not (bases == (object,) or bases == (Enum,)): + raise EnumDefinitionError('Enum type %s may only inherit from Enum' % + (name,)) + + cls.__by_number = {} + cls.__by_name = {} + + # Enum base class does not need to be initialized or locked. + if bases != (object,): + # Replace integer with number. + for attribute, value in dct.items(): + + # Module will be in every enum class. + if attribute in _RESERVED_ATTRIBUTE_NAMES: + continue + + # Reject anything that is not an int. + if not isinstance(value, six.integer_types): + raise EnumDefinitionError( + 'May only use integers in Enum definitions. Found: %s = %s' % + (attribute, value)) + + # Protocol buffer standard recommends non-negative values. + # Reject negative values. + if value < 0: + raise EnumDefinitionError( + 'Must use non-negative enum values. Found: %s = %d' % + (attribute, value)) + + if value > MAX_ENUM_VALUE: + raise EnumDefinitionError( + 'Must use enum values less than or equal %d. Found: %s = %d' % + (MAX_ENUM_VALUE, attribute, value)) + + if value in cls.__by_number: + raise EnumDefinitionError( + 'Value for %s = %d is already defined: %s' % + (attribute, value, cls.__by_number[value].name)) + + # Create enum instance and list in new Enum type. + instance = object.__new__(cls) + cls.__init__(instance, attribute, value) + cls.__by_name[instance.name] = instance + cls.__by_number[instance.number] = instance + setattr(cls, attribute, instance) + + _DefinitionClass.__init__(cls, name, bases, dct) + + def __iter__(cls): + """Iterate over all values of enum. + + Yields: + Enumeration instances of the Enum class in arbitrary order. + """ + return iter(cls.__by_number.values()) + + def names(cls): + """Get all names for Enum. + + Returns: + An iterator for names of the enumeration in arbitrary order. + """ + return cls.__by_name.keys() + + def numbers(cls): + """Get all numbers for Enum. + + Returns: + An iterator for all numbers of the enumeration in arbitrary order. + """ + return cls.__by_number.keys() + + def lookup_by_name(cls, name): + """Look up Enum by name. + + Args: + name: Name of enum to find. + + Returns: + Enum sub-class instance of that value. + """ + return cls.__by_name[name] + + def lookup_by_number(cls, number): + """Look up Enum by number. + + Args: + number: Number of enum to find. + + Returns: + Enum sub-class instance of that value. + """ + return cls.__by_number[number] + + def __len__(cls): + return len(cls.__by_name) + + +class Enum(six.with_metaclass(_EnumClass, object)): + """Base class for all enumerated types.""" + + __slots__ = set(('name', 'number')) + + def __new__(cls, index): + """Acts as look-up routine after class is initialized. + + The purpose of overriding __new__ is to provide a way to treat + Enum subclasses as casting types, similar to how the int type + functions. A program can pass a string or an integer and this + method with "convert" that value in to an appropriate Enum instance. + + Args: + index: Name or number to look up. During initialization + this is always the name of the new enum value. + + Raises: + TypeError: When an inappropriate index value is passed provided. + """ + # If is enum type of this class, return it. + if isinstance(index, cls): + return index + + # If number, look up by number. + if isinstance(index, six.integer_types): + try: + return cls.lookup_by_number(index) + except KeyError: + pass + + # If name, look up by name. + if isinstance(index, six.string_types): + try: + return cls.lookup_by_name(index) + except KeyError: + pass + + raise TypeError('No such value for %s in Enum %s' % + (index, cls.__name__)) + + def __init__(self, name, number=None): + """Initialize new Enum instance. + + Since this should only be called during class initialization any + calls that happen after the class is frozen raises an exception. + """ + # Immediately return if __init__ was called after _Enum.__init__(). + # It means that casting operator version of the class constructor + # is being used. + if getattr(type(self), '_DefinitionClass__initialized'): + return + object.__setattr__(self, 'name', name) + object.__setattr__(self, 'number', number) + + def __setattr__(self, name, value): + raise TypeError('May not change enum values') + + def __str__(self): + return self.name + + def __int__(self): + return self.number + + def __repr__(self): + return '%s(%s, %d)' % (type(self).__name__, self.name, self.number) + + def __reduce__(self): + """Enable pickling. + + Returns: + A 2-tuple containing the class and __new__ args to be used for restoring + a pickled instance. + """ + return self.__class__, (self.number,) + + def __cmp__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return cmp(self.number, other.number) + return NotImplemented + + def __lt__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number < other.number + return NotImplemented + + def __le__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number <= other.number + return NotImplemented + + def __eq__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number == other.number + return NotImplemented + + def __ne__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number != other.number + return NotImplemented + + def __ge__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number >= other.number + return NotImplemented + + def __gt__(self, other): + """Order is by number.""" + if isinstance(other, type(self)): + return self.number > other.number + return NotImplemented + + def __hash__(self): + """Hash by number.""" + return hash(self.number) + + @classmethod + def to_dict(cls): + """Make dictionary version of enumerated class. + + Dictionary created this way can be used with def_num. + + Returns: + A dict (name) -> number + """ + return dict((item.name, item.number) for item in iter(cls)) + + @staticmethod + def def_enum(dct, name): + """Define enum class from dictionary. + + Args: + dct: Dictionary of enumerated values for type. + name: Name of enum. + """ + return type(name, (Enum,), dct) + + +# TODO(rafek): Determine to what degree this enumeration should be compatible +# with FieldDescriptor.Type in: +# +# http://code.google.com/p/protobuf/source/browse/trunk/src/google/protobuf/descriptor.proto +class Variant(Enum): + """Wire format variant. + + Used by the 'protobuf' wire format to determine how to transmit + a single piece of data. May be used by other formats. + + See: http://code.google.com/apis/protocolbuffers/docs/encoding.html + + Values: + DOUBLE: 64-bit floating point number. + FLOAT: 32-bit floating point number. + INT64: 64-bit signed integer. + UINT64: 64-bit unsigned integer. + INT32: 32-bit signed integer. + BOOL: Boolean value (True or False). + STRING: String of UTF-8 encoded text. + MESSAGE: Embedded message as byte string. + BYTES: String of 8-bit bytes. + UINT32: 32-bit unsigned integer. + ENUM: Enum value as integer. + SINT32: 32-bit signed integer. Uses "zig-zag" encoding. + SINT64: 64-bit signed integer. Uses "zig-zag" encoding. + """ + DOUBLE = 1 + FLOAT = 2 + INT64 = 3 + UINT64 = 4 + INT32 = 5 + BOOL = 8 + STRING = 9 + MESSAGE = 11 + BYTES = 12 + UINT32 = 13 + ENUM = 14 + SINT32 = 17 + SINT64 = 18 + + +class _MessageClass(_DefinitionClass): + """Meta-class used for defining the Message base class. + + For more details about Message classes, see the Message class docstring. + Information contained there may help understanding this class. + + Meta-class enables very specific behavior for any defined Message + class. All attributes defined on an Message sub-class must be field + instances, Enum class definitions or other Message class definitions. Each + field attribute defined on an Message sub-class is added to the set of + field definitions and the attribute is translated in to a slot. It also + ensures that only one level of Message class hierarchy is possible. In other + words it is not possible to declare sub-classes of sub-classes of + Message. + + This class also defines some functions in order to restrict the + behavior of the Message class and its sub-classes. It is not possible + to change the behavior of the Message class in later classes since + any new classes may be defined with only field, Enums and Messages, and + no methods. + """ + + def __new__(cls, name, bases, dct): + """Create new Message class instance. + + The __new__ method of the _MessageClass type is overridden so as to + allow the translation of Field instances to slots. + """ + by_number = {} + by_name = {} + + variant_map = {} + + if bases != (object,): + # Can only define one level of sub-classes below Message. + if bases != (Message,): + raise MessageDefinitionError( + 'Message types may only inherit from Message') + + enums = [] + messages = [] + # Must not use iteritems because this loop will change the state of dct. + for key, field in dct.items(): + + if key in _RESERVED_ATTRIBUTE_NAMES: + continue + + if isinstance(field, type) and issubclass(field, Enum): + enums.append(key) + continue + + if (isinstance(field, type) and + issubclass(field, Message) and + field is not Message): + messages.append(key) + continue + + # Reject anything that is not a field. + if type(field) is Field or not isinstance(field, Field): + raise MessageDefinitionError( + 'May only use fields in message definitions. Found: %s = %s' % + (key, field)) + + if field.number in by_number: + raise DuplicateNumberError( + 'Field with number %d declared more than once in %s' % + (field.number, name)) + + field.name = key + + # Place in name and number maps. + by_name[key] = field + by_number[field.number] = field + + # Add enums if any exist. + if enums: + dct['__enums__'] = sorted(enums) + + # Add messages if any exist. + if messages: + dct['__messages__'] = sorted(messages) + + dct['_Message__by_number'] = by_number + dct['_Message__by_name'] = by_name + + return _DefinitionClass.__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct): + """Initializer required to assign references to new class.""" + if bases != (object,): + for value in dct.values(): + if isinstance(value, _DefinitionClass) and not value is Message: + value._message_definition = weakref.ref(cls) + + for field in cls.all_fields(): + field._message_definition = weakref.ref(cls) + + _DefinitionClass.__init__(cls, name, bases, dct) + + +class Message(six.with_metaclass(_MessageClass, object)): + """Base class for user defined message objects. + + Used to define messages for efficient transmission across network or + process space. Messages are defined using the field classes (IntegerField, + FloatField, EnumField, etc.). + + Messages are more restricted than normal classes in that they may only + contain field attributes and other Message and Enum definitions. These + restrictions are in place because the structure of the Message class is + intentended to itself be transmitted across network or process space and + used directly by clients or even other servers. As such methods and + non-field attributes could not be transmitted with the structural information + causing discrepancies between different languages and implementations. + + Initialization and validation: + + A Message object is considered to be initialized if it has all required + fields and any nested messages are also initialized. + + Calling 'check_initialized' will raise a ValidationException if it is not + initialized; 'is_initialized' returns a boolean value indicating if it is + valid. + + Validation automatically occurs when Message objects are created + and populated. Validation that a given value will be compatible with + a field that it is assigned to can be done through the Field instances + validate() method. The validate method used on a message will check that + all values of a message and its sub-messages are valid. Assingning an + invalid value to a field will raise a ValidationException. + + Example: + + # Trade type. + class TradeType(Enum): + BUY = 1 + SELL = 2 + SHORT = 3 + CALL = 4 + + class Lot(Message): + price = IntegerField(1, required=True) + quantity = IntegerField(2, required=True) + + class Order(Message): + symbol = StringField(1, required=True) + total_quantity = IntegerField(2, required=True) + trade_type = EnumField(TradeType, 3, required=True) + lots = MessageField(Lot, 4, repeated=True) + limit = IntegerField(5) + + order = Order(symbol='GOOG', + total_quantity=10, + trade_type=TradeType.BUY) + + lot1 = Lot(price=304, + quantity=7) + + lot2 = Lot(price = 305, + quantity=3) + + order.lots = [lot1, lot2] + + # Now object is initialized! + order.check_initialized() + """ + + def __init__(self, **kwargs): + """Initialize internal messages state. + + Args: + A message can be initialized via the constructor by passing in keyword + arguments corresponding to fields. For example: + + class Date(Message): + day = IntegerField(1) + month = IntegerField(2) + year = IntegerField(3) + + Invoking: + + date = Date(day=6, month=6, year=1911) + + is the same as doing: + + date = Date() + date.day = 6 + date.month = 6 + date.year = 1911 + """ + # Tag being an essential implementation detail must be private. + self.__tags = {} + self.__unrecognized_fields = {} + + assigned = set() + for name, value in kwargs.items(): + setattr(self, name, value) + assigned.add(name) + + # initialize repeated fields. + for field in self.all_fields(): + if field.repeated and field.name not in assigned: + setattr(self, field.name, []) + + + def check_initialized(self): + """Check class for initialization status. + + Check that all required fields are initialized + + Raises: + ValidationError: If message is not initialized. + """ + for name, field in self.__by_name.items(): + value = getattr(self, name) + if value is None: + if field.required: + raise ValidationError("Message %s is missing required field %s" % + (type(self).__name__, name)) + else: + try: + if (isinstance(field, MessageField) and + issubclass(field.message_type, Message)): + if field.repeated: + for item in value: + item_message_value = field.value_to_message(item) + item_message_value.check_initialized() + else: + message_value = field.value_to_message(value) + message_value.check_initialized() + except ValidationError as err: + if not hasattr(err, 'message_name'): + err.message_name = type(self).__name__ + raise + + def is_initialized(self): + """Get initialization status. + + Returns: + True if message is valid, else False. + """ + try: + self.check_initialized() + except ValidationError: + return False + else: + return True + + @classmethod + def all_fields(cls): + """Get all field definition objects. + + Ordering is arbitrary. + + Returns: + Iterator over all values in arbitrary order. + """ + return cls.__by_name.values() + + @classmethod + def field_by_name(cls, name): + """Get field by name. + + Returns: + Field object associated with name. + + Raises: + KeyError if no field found by that name. + """ + return cls.__by_name[name] + + @classmethod + def field_by_number(cls, number): + """Get field by number. + + Returns: + Field object associated with number. + + Raises: + KeyError if no field found by that number. + """ + return cls.__by_number[number] + + def get_assigned_value(self, name): + """Get the assigned value of an attribute. + + Get the underlying value of an attribute. If value has not been set, will + not return the default for the field. + + Args: + name: Name of attribute to get. + + Returns: + Value of attribute, None if it has not been set. + """ + message_type = type(self) + try: + field = message_type.field_by_name(name) + except KeyError: + raise AttributeError('Message %s has no field %s' % ( + message_type.__name__, name)) + return self.__tags.get(field.number) + + def reset(self, name): + """Reset assigned value for field. + + Resetting a field will return it to its default value or None. + + Args: + name: Name of field to reset. + """ + message_type = type(self) + try: + field = message_type.field_by_name(name) + except KeyError: + if name not in message_type.__by_name: + raise AttributeError('Message %s has no field %s' % ( + message_type.__name__, name)) + if field.repeated: + self.__tags[field.number] = FieldList(field, []) + else: + self.__tags.pop(field.number, None) + + def all_unrecognized_fields(self): + """Get the names of all unrecognized fields in this message.""" + return list(self.__unrecognized_fields.keys()) + + def get_unrecognized_field_info(self, key, value_default=None, + variant_default=None): + """Get the value and variant of an unknown field in this message. + + Args: + key: The name or number of the field to retrieve. + value_default: Value to be returned if the key isn't found. + variant_default: Value to be returned as variant if the key isn't + found. + + Returns: + (value, variant), where value and variant are whatever was passed + to set_unrecognized_field. + """ + value, variant = self.__unrecognized_fields.get(key, (value_default, + variant_default)) + return value, variant + + def set_unrecognized_field(self, key, value, variant): + """Set an unrecognized field, used when decoding a message. + + Args: + key: The name or number used to refer to this unknown value. + value: The value of the field. + variant: Type information needed to interpret the value or re-encode it. + + Raises: + TypeError: If the variant is not an instance of messages.Variant. + """ + if not isinstance(variant, Variant): + raise TypeError('Variant type %s is not valid.' % variant) + self.__unrecognized_fields[key] = value, variant + + def __setattr__(self, name, value): + """Change set behavior for messages. + + Messages may only be assigned values that are fields. + + Does not try to validate field when set. + + Args: + name: Name of field to assign to. + value: Value to assign to field. + + Raises: + AttributeError when trying to assign value that is not a field. + """ + if name in self.__by_name or name.startswith('_Message__'): + object.__setattr__(self, name, value) + else: + raise AttributeError("May not assign arbitrary value %s " + "to message %s" % (name, type(self).__name__)) + + def __repr__(self): + """Make string representation of message. + + Example: + + class MyMessage(messages.Message): + integer_value = messages.IntegerField(1) + string_value = messages.StringField(2) + + my_message = MyMessage() + my_message.integer_value = 42 + my_message.string_value = u'A string' + + print my_message + >>> + + Returns: + String representation of message, including the values + of all fields and repr of all sub-messages. + """ + body = ['<', type(self).__name__] + for field in sorted(self.all_fields(), + key=lambda f: f.number): + attribute = field.name + value = self.get_assigned_value(field.name) + if value is not None: + body.append('\n %s: %s' % (attribute, repr(value))) + body.append('>') + return ''.join(body) + + def __eq__(self, other): + """Equality operator. + + Does field by field comparison with other message. For + equality, must be same type and values of all fields must be + equal. + + Messages not required to be initialized for comparison. + + Does not attempt to determine equality for values that have + default values that are not set. In other words: + + class HasDefault(Message): + + attr1 = StringField(1, default='default value') + + message1 = HasDefault() + message2 = HasDefault() + message2.attr1 = 'default value' + + message1 != message2 + + Does not compare unknown values. + + Args: + other: Other message to compare with. + """ + # TODO(rafek): Implement "equivalent" which does comparisons + # taking default values in to consideration. + if self is other: + return True + + if type(self) is not type(other): + return False + + return self.__tags == other.__tags + + def __ne__(self, other): + """Not equals operator. + + Does field by field comparison with other message. For + non-equality, must be different type or any value of a field must be + non-equal to the same field in the other instance. + + Messages not required to be initialized for comparison. + + Args: + other: Other message to compare with. + """ + return not self.__eq__(other) + + +class FieldList(list): + """List implementation that validates field values. + + This list implementation overrides all methods that add values in to a list + in order to validate those new elements. Attempting to add or set list + values that are not of the correct type will raise ValidationError. + """ + + def __init__(self, field_instance, sequence): + """Constructor. + + Args: + field_instance: Instance of field that validates the list. + sequence: List or tuple to construct list from. + """ + if not field_instance.repeated: + raise FieldDefinitionError('FieldList may only accept repeated fields') + self.__field = field_instance + self.__field.validate(sequence) + list.__init__(self, sequence) + + def __getstate__(self): + """Enable pickling. + + The assigned field instance can't be pickled if it belongs to a Message + definition (message_definition uses a weakref), so the Message class and + field number are returned in that case. + + Returns: + A 3-tuple containing: + - The field instance, or None if it belongs to a Message class. + - The Message class that the field instance belongs to, or None. + - The field instance number of the Message class it belongs to, or None. + """ + message_class = self.__field.message_definition() + if message_class is None: + return self.__field, None, None + else: + return None, message_class, self.__field.number + + def __setstate__(self, state): + """Enable unpickling. + + Args: + state: A 3-tuple containing: + - The field instance, or None if it belongs to a Message class. + - The Message class that the field instance belongs to, or None. + - The field instance number of the Message class it belongs to, or None. + """ + field_instance, message_class, number = state + if field_instance is None: + self.__field = message_class.field_by_number(number) + else: + self.__field = field_instance + + @property + def field(self): + """Field that validates list.""" + return self.__field + + def __setslice__(self, i, j, sequence): + """Validate slice assignment to list.""" + self.__field.validate(sequence) + list.__setslice__(self, i, j, sequence) + + def __setitem__(self, index, value): + """Validate item assignment to list.""" + if isinstance(index, slice): + self.__field.validate(value) + else: + self.__field.validate_element(value) + list.__setitem__(self, index, value) + + def append(self, value): + """Validate item appending to list.""" + self.__field.validate_element(value) + return list.append(self, value) + + def extend(self, sequence): + """Validate extension of list.""" + self.__field.validate(sequence) + return list.extend(self, sequence) + + def insert(self, index, value): + """Validate item insertion to list.""" + self.__field.validate_element(value) + return list.insert(self, index, value) + + +class _FieldMeta(type): + + def __init__(cls, name, bases, dct): + getattr(cls, '_Field__variant_to_type').update( + (variant, cls) for variant in dct.get('VARIANTS', [])) + type.__init__(cls, name, bases, dct) + + +# TODO(rafek): Prevent additional field subclasses. +class Field(six.with_metaclass(_FieldMeta, object)): + + __initialized = False + __variant_to_type = {} + + @util.positional(2) + def __init__(self, + number, + required=False, + repeated=False, + variant=None, + default=None): + """Constructor. + + The required and repeated parameters are mutually exclusive. Setting both + to True will raise a FieldDefinitionError. + + Sub-class Attributes: + Each sub-class of Field must define the following: + VARIANTS: Set of variant types accepted by that field. + DEFAULT_VARIANT: Default variant type if not specified in constructor. + + Args: + number: Number of field. Must be unique per message class. + required: Whether or not field is required. Mutually exclusive with + 'repeated'. + repeated: Whether or not field is repeated. Mutually exclusive with + 'required'. + variant: Wire-format variant hint. + default: Default value for field if not found in stream. + + Raises: + InvalidVariantError when invalid variant for field is provided. + InvalidDefaultError when invalid default for field is provided. + FieldDefinitionError when invalid number provided or mutually exclusive + fields are used. + InvalidNumberError when the field number is out of range or reserved. + """ + if not isinstance(number, int) or not 1 <= number <= MAX_FIELD_NUMBER: + raise InvalidNumberError('Invalid number for field: %s\n' + 'Number must be 1 or greater and %d or less' % + (number, MAX_FIELD_NUMBER)) + + if FIRST_RESERVED_FIELD_NUMBER <= number <= LAST_RESERVED_FIELD_NUMBER: + raise InvalidNumberError('Tag number %d is a reserved number.\n' + 'Numbers %d to %d are reserved' % + (number, FIRST_RESERVED_FIELD_NUMBER, + LAST_RESERVED_FIELD_NUMBER)) + + if repeated and required: + raise FieldDefinitionError('Cannot set both repeated and required') + + if variant is None: + variant = self.DEFAULT_VARIANT + + if repeated and default is not None: + raise FieldDefinitionError('Repeated fields may not have defaults') + + if variant not in self.VARIANTS: + raise InvalidVariantError( + 'Invalid variant: %s\nValid variants for %s are %r' % + (variant, type(self).__name__, sorted(self.VARIANTS))) + + self.number = number + self.required = required + self.repeated = repeated + self.variant = variant + + if default is not None: + try: + self.validate_default(default) + except ValidationError as err: + try: + name = self.name + except AttributeError: + # For when raising error before name initialization. + raise InvalidDefaultError('Invalid default value for %s: %r: %s' % + (self.__class__.__name__, default, err)) + else: + raise InvalidDefaultError('Invalid default value for field %s: ' + '%r: %s' % (name, default, err)) + + self.__default = default + self.__initialized = True + + def __setattr__(self, name, value): + """Setter overidden to prevent assignment to fields after creation. + + Args: + name: Name of attribute to set. + value: Value to assign. + """ + # Special case post-init names. They need to be set after constructor. + if name in _POST_INIT_FIELD_ATTRIBUTE_NAMES: + object.__setattr__(self, name, value) + return + + # All other attributes must be set before __initialized. + if not self.__initialized: + # Not initialized yet, allow assignment. + object.__setattr__(self, name, value) + else: + raise AttributeError('Field objects are read-only') + + def __set__(self, message_instance, value): + """Set value on message. + + Args: + message_instance: Message instance to set value on. + value: Value to set on message. + """ + # Reaches in to message instance directly to assign to private tags. + if value is None: + if self.repeated: + raise ValidationError( + 'May not assign None to repeated field %s' % self.name) + else: + message_instance._Message__tags.pop(self.number, None) + else: + if self.repeated: + value = FieldList(self, value) + else: + value = self.validate(value) + message_instance._Message__tags[self.number] = value + + def __get__(self, message_instance, message_class): + if message_instance is None: + return self + + result = message_instance._Message__tags.get(self.number) + if result is None: + return self.default + else: + return result + + def validate_element(self, value): + """Validate single element of field. + + This is different from validate in that it is used on individual + values of repeated fields. + + Args: + value: Value to validate. + + Returns: + The value casted in the expectes type. + + Raises: + ValidationError if value is not expected type. + """ + if not isinstance(value, self.type): + # Authorize in values as float + if isinstance(value, six.integer_types) and self.type == float: + return float(value) + + if value is None: + if self.required: + raise ValidationError('Required field is missing') + else: + try: + name = self.name + except AttributeError: + raise ValidationError('Expected type %s for %s, ' + 'found %s (type %s)' % + (self.type, self.__class__.__name__, + value, type(value))) + else: + raise ValidationError('Expected type %s for field %s, ' + 'found %s (type %s)' % + (self.type, name, value, type(value))) + return value + + def __validate(self, value, validate_element): + """Internal validation function. + + Validate an internal value using a function to validate individual elements. + + Args: + value: Value to validate. + validate_element: Function to use to validate individual elements. + + Raises: + ValidationError if value is not expected type. + """ + if not self.repeated: + return validate_element(value) + else: + # Must be a list or tuple, may not be a string. + if isinstance(value, (list, tuple)): + result = [] + for element in value: + if element is None: + try: + name = self.name + except AttributeError: + raise ValidationError('Repeated values for %s ' + 'may not be None' % self.__class__.__name__) + else: + raise ValidationError('Repeated values for field %s ' + 'may not be None' % name) + result.append(validate_element(element)) + return result + elif value is not None: + try: + name = self.name + except AttributeError: + raise ValidationError('%s is repeated. Found: %s' % ( + self.__class__.__name__, value)) + else: + raise ValidationError('Field %s is repeated. Found: %s' % (name, + value)) + return value + + def validate(self, value): + """Validate value assigned to field. + + Args: + value: Value to validate. + + Returns: + the value eventually casted in the correct type. + + Raises: + ValidationError if value is not expected type. + """ + return self.__validate(value, self.validate_element) + + def validate_default_element(self, value): + """Validate value as assigned to field default field. + + Some fields may allow for delayed resolution of default types necessary + in the case of circular definition references. In this case, the default + value might be a place holder that is resolved when needed after all the + message classes are defined. + + Args: + value: Default value to validate. + + Returns: + the value eventually casted in the correct type. + + Raises: + ValidationError if value is not expected type. + """ + return self.validate_element(value) + + def validate_default(self, value): + """Validate default value assigned to field. + + Args: + value: Value to validate. + + Returns: + the value eventually casted in the correct type. + + Raises: + ValidationError if value is not expected type. + """ + return self.__validate(value, self.validate_default_element) + + def message_definition(self): + """Get Message definition that contains this Field definition. + + Returns: + Containing Message definition for Field. Will return None if for + some reason Field is defined outside of a Message class. + """ + try: + return self._message_definition() + except AttributeError: + return None + + @property + def default(self): + """Get default value for field.""" + return self.__default + + @classmethod + def lookup_field_type_by_variant(cls, variant): + return cls.__variant_to_type[variant] + + +class IntegerField(Field): + """Field definition for integer values.""" + + VARIANTS = frozenset([Variant.INT32, + Variant.INT64, + Variant.UINT32, + Variant.UINT64, + Variant.SINT32, + Variant.SINT64, + ]) + + DEFAULT_VARIANT = Variant.INT64 + + type = six.integer_types + + +class FloatField(Field): + """Field definition for float values.""" + + VARIANTS = frozenset([Variant.FLOAT, + Variant.DOUBLE, + ]) + + DEFAULT_VARIANT = Variant.DOUBLE + + type = float + + +class BooleanField(Field): + """Field definition for boolean values.""" + + VARIANTS = frozenset([Variant.BOOL]) + + DEFAULT_VARIANT = Variant.BOOL + + type = bool + + +class BytesField(Field): + """Field definition for byte string values.""" + + VARIANTS = frozenset([Variant.BYTES]) + + DEFAULT_VARIANT = Variant.BYTES + + type = bytes + + +class StringField(Field): + """Field definition for unicode string values.""" + + VARIANTS = frozenset([Variant.STRING]) + + DEFAULT_VARIANT = Variant.STRING + + type = six.text_type + + def validate_element(self, value): + """Validate StringField allowing for str and unicode. + + Raises: + ValidationError if a str value is not 7-bit ascii. + """ + # If value is str is it considered valid. Satisfies "required=True". + if isinstance(value, bytes): + try: + six.text_type(value, 'ascii') + except UnicodeDecodeError as err: + try: + name = self.name + except AttributeError: + validation_error = ValidationError( + 'Field encountered non-ASCII string %r: %s' % (value, + err)) + else: + validation_error = ValidationError( + 'Field %s encountered non-ASCII string %r: %s' % (self.name, + value, + err)) + validation_error.field_name = self.name + raise validation_error + else: + return super(StringField, self).validate_element(value) + + +class MessageField(Field): + """Field definition for sub-message values. + + Message fields contain instance of other messages. Instances stored + on messages stored on message fields are considered to be owned by + the containing message instance and should not be shared between + owning instances. + + Message fields must be defined to reference a single type of message. + Normally message field are defined by passing the referenced message + class in to the constructor. + + It is possible to define a message field for a type that does not yet + exist by passing the name of the message in to the constructor instead + of a message class. Resolution of the actual type of the message is + deferred until it is needed, for example, during message verification. + Names provided to the constructor must refer to a class within the same + python module as the class that is using it. Names refer to messages + relative to the containing messages scope. For example, the two fields + of OuterMessage refer to the same message type: + + class Outer(Message): + + inner_relative = MessageField('Inner', 1) + inner_absolute = MessageField('Outer.Inner', 2) + + class Inner(Message): + ... + + When resolving an actual type, MessageField will traverse the entire + scope of nested messages to match a message name. This makes it easy + for siblings to reference siblings: + + class Outer(Message): + + class Inner(Message): + + sibling = MessageField('Sibling', 1) + + class Sibling(Message): + ... + """ + + VARIANTS = frozenset([Variant.MESSAGE]) + + DEFAULT_VARIANT = Variant.MESSAGE + + @util.positional(3) + def __init__(self, + message_type, + number, + required=False, + repeated=False, + variant=None): + """Constructor. + + Args: + message_type: Message type for field. Must be subclass of Message. + number: Number of field. Must be unique per message class. + required: Whether or not field is required. Mutually exclusive to + 'repeated'. + repeated: Whether or not field is repeated. Mutually exclusive to + 'required'. + variant: Wire-format variant hint. + + Raises: + FieldDefinitionError when invalid message_type is provided. + """ + valid_type = (isinstance(message_type, six.string_types) or + (message_type is not Message and + isinstance(message_type, type) and + issubclass(message_type, Message))) + + if not valid_type: + raise FieldDefinitionError('Invalid message class: %s' % message_type) + + if isinstance(message_type, six.string_types): + self.__type_name = message_type + self.__type = None + else: + self.__type = message_type + + super(MessageField, self).__init__(number, + required=required, + repeated=repeated, + variant=variant) + + def __set__(self, message_instance, value): + """Set value on message. + + Args: + message_instance: Message instance to set value on. + value: Value to set on message. + """ + message_type = self.type + if isinstance(message_type, type) and issubclass(message_type, Message): + if self.repeated: + if value and isinstance(value, (list, tuple)): + value = [(message_type(**v) if isinstance(v, dict) else v) + for v in value] + elif isinstance(value, dict): + value = message_type(**value) + super(MessageField, self).__set__(message_instance, value) + + @property + def type(self): + """Message type used for field.""" + if self.__type is None: + message_type = find_definition(self.__type_name, self.message_definition()) + if not (message_type is not Message and + isinstance(message_type, type) and + issubclass(message_type, Message)): + raise FieldDefinitionError('Invalid message class: %s' % message_type) + self.__type = message_type + return self.__type + + @property + def message_type(self): + """Underlying message type used for serialization. + + Will always be a sub-class of Message. This is different from type + which represents the python value that message_type is mapped to for + use by the user. + """ + return self.type + + def value_from_message(self, message): + """Convert a message to a value instance. + + Used by deserializers to convert from underlying messages to + value of expected user type. + + Args: + message: A message instance of type self.message_type. + + Returns: + Value of self.message_type. + """ + if not isinstance(message, self.message_type): + raise DecodeError('Expected type %s, got %s: %r' % + (self.message_type.__name__, + type(message).__name__, + message)) + return message + + def value_to_message(self, value): + """Convert a value instance to a message. + + Used by serializers to convert Python user types to underlying + messages for transmission. + + Args: + value: A value of type self.type. + + Returns: + An instance of type self.message_type. + """ + if not isinstance(value, self.type): + raise EncodeError('Expected type %s, got %s: %r' % + (self.type.__name__, + type(value).__name__, + value)) + return value + + +class EnumField(Field): + """Field definition for enum values. + + Enum fields may have default values that are delayed until the associated enum + type is resolved. This is necessary to support certain circular references. + + For example: + + class Message1(Message): + + class Color(Enum): + + RED = 1 + GREEN = 2 + BLUE = 3 + + # This field default value will be validated when default is accessed. + animal = EnumField('Message2.Animal', 1, default='HORSE') + + class Message2(Message): + + class Animal(Enum): + + DOG = 1 + CAT = 2 + HORSE = 3 + + # This fields default value will be validated right away since Color is + # already fully resolved. + color = EnumField(Message1.Color, 1, default='RED') + """ + + VARIANTS = frozenset([Variant.ENUM]) + + DEFAULT_VARIANT = Variant.ENUM + + def __init__(self, enum_type, number, **kwargs): + """Constructor. + + Args: + enum_type: Enum type for field. Must be subclass of Enum. + number: Number of field. Must be unique per message class. + required: Whether or not field is required. Mutually exclusive to + 'repeated'. + repeated: Whether or not field is repeated. Mutually exclusive to + 'required'. + variant: Wire-format variant hint. + default: Default value for field if not found in stream. + + Raises: + FieldDefinitionError when invalid enum_type is provided. + """ + valid_type = (isinstance(enum_type, six.string_types) or + (enum_type is not Enum and + isinstance(enum_type, type) and + issubclass(enum_type, Enum))) + + if not valid_type: + raise FieldDefinitionError('Invalid enum type: %s' % enum_type) + + if isinstance(enum_type, six.string_types): + self.__type_name = enum_type + self.__type = None + else: + self.__type = enum_type + + super(EnumField, self).__init__(number, **kwargs) + + def validate_default_element(self, value): + """Validate default element of Enum field. + + Enum fields allow for delayed resolution of default values when the type + of the field has not been resolved. The default value of a field may be + a string or an integer. If the Enum type of the field has been resolved, + the default value is validated against that type. + + Args: + value: Value to validate. + + Raises: + ValidationError if value is not expected message type. + """ + if isinstance(value, (six.string_types, six.integer_types)): + # Validation of the value does not happen for delayed resolution + # enumerated types. Ignore if type is not yet resolved. + if self.__type: + self.__type(value) + return + + return super(EnumField, self).validate_default_element(value) + + @property + def type(self): + """Enum type used for field.""" + if self.__type is None: + found_type = find_definition(self.__type_name, self.message_definition()) + if not (found_type is not Enum and + isinstance(found_type, type) and + issubclass(found_type, Enum)): + raise FieldDefinitionError('Invalid enum type: %s' % found_type) + + self.__type = found_type + return self.__type + + @property + def default(self): + """Default for enum field. + + Will cause resolution of Enum type and unresolved default value. + """ + try: + return self.__resolved_default + except AttributeError: + resolved_default = super(EnumField, self).default + if isinstance(resolved_default, (six.string_types, six.integer_types)): + resolved_default = self.type(resolved_default) + self.__resolved_default = resolved_default + return self.__resolved_default + + +@util.positional(2) +def find_definition(name, relative_to=None, importer=__import__): + """Find definition by name in module-space. + + The find algorthm will look for definitions by name relative to a message + definition or by fully qualfied name. If no definition is found relative + to the relative_to parameter it will do the same search against the container + of relative_to. If relative_to is a nested Message, it will search its + message_definition(). If that message has no message_definition() it will + search its module. If relative_to is a module, it will attempt to look for + the containing module and search relative to it. If the module is a top-level + module, it will look for the a message using a fully qualified name. If + no message is found then, the search fails and DefinitionNotFoundError is + raised. + + For example, when looking for any definition 'foo.bar.ADefinition' relative to + an actual message definition abc.xyz.SomeMessage: + + find_definition('foo.bar.ADefinition', SomeMessage) + + It is like looking for the following fully qualified names: + + abc.xyz.SomeMessage. foo.bar.ADefinition + abc.xyz. foo.bar.ADefinition + abc. foo.bar.ADefinition + foo.bar.ADefinition + + When resolving the name relative to Message definitions and modules, the + algorithm searches any Messages or sub-modules found in its path. + Non-Message values are not searched. + + A name that begins with '.' is considered to be a fully qualified name. The + name is always searched for from the topmost package. For example, assume + two message types: + + abc.xyz.SomeMessage + xyz.SomeMessage + + Searching for '.xyz.SomeMessage' relative to 'abc' will resolve to + 'xyz.SomeMessage' and not 'abc.xyz.SomeMessage'. For this kind of name, + the relative_to parameter is effectively ignored and always set to None. + + For more information about package name resolution, please see: + + http://code.google.com/apis/protocolbuffers/docs/proto.html#packages + + Args: + name: Name of definition to find. May be fully qualified or relative name. + relative_to: Search for definition relative to message definition or module. + None will cause a fully qualified name search. + importer: Import function to use for resolving modules. + + Returns: + Enum or Message class definition associated with name. + + Raises: + DefinitionNotFoundError if no definition is found in any search path. + """ + # Check parameters. + if not (relative_to is None or + isinstance(relative_to, types.ModuleType) or + isinstance(relative_to, type) and issubclass(relative_to, Message)): + raise TypeError('relative_to must be None, Message definition or module. ' + 'Found: %s' % relative_to) + + name_path = name.split('.') + + # Handle absolute path reference. + if not name_path[0]: + relative_to = None + name_path = name_path[1:] + + def search_path(): + """Performs a single iteration searching the path from relative_to. + + This is the function that searches up the path from a relative object. + + fully.qualified.object . relative.or.nested.Definition + ----------------------------> + ^ + | + this part of search --+ + + Returns: + Message or Enum at the end of name_path, else None. + """ + next = relative_to + for node in name_path: + # Look for attribute first. + attribute = getattr(next, node, None) + + if attribute is not None: + next = attribute + else: + # If module, look for sub-module. + if next is None or isinstance(next, types.ModuleType): + if next is None: + module_name = node + else: + module_name = '%s.%s' % (next.__name__, node) + + try: + fromitem = module_name.split('.')[-1] + next = importer(module_name, '', '', [str(fromitem)]) + except ImportError: + return None + else: + return None + + if (not isinstance(next, types.ModuleType) and + not (isinstance(next, type) and + issubclass(next, (Message, Enum)))): + return None + + return next + + while True: + found = search_path() + if isinstance(found, type) and issubclass(found, (Enum, Message)): + return found + else: + # Find next relative_to to search against. + # + # fully.qualified.object . relative.or.nested.Definition + # <--------------------- + # ^ + # | + # does this part of search + if relative_to is None: + # Fully qualified search was done. Nothing found. Fail. + raise DefinitionNotFoundError('Could not find definition for %s' + % (name,)) + else: + if isinstance(relative_to, types.ModuleType): + # Find parent module. + module_path = relative_to.__name__.split('.')[:-1] + if not module_path: + relative_to = None + else: + # Should not raise ImportError. If it does... weird and + # unexepected. Propagate. + relative_to = importer( + '.'.join(module_path), '', '', [module_path[-1]]) + elif (isinstance(relative_to, type) and + issubclass(relative_to, Message)): + parent = relative_to.message_definition() + if parent is None: + last_module_name = relative_to.__module__.split('.')[-1] + relative_to = importer( + relative_to.__module__, '', '', [last_module_name]) + else: + relative_to = parent diff --git a/endpoints/internal/protorpc/messages_test.py b/endpoints/internal/protorpc/messages_test.py new file mode 100644 index 0000000..9460b31 --- /dev/null +++ b/endpoints/internal/protorpc/messages_test.py @@ -0,0 +1,2109 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.messages.""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import pickle +import re +import sys +import types +import unittest + +from protorpc import descriptor +from protorpc import message_types +from protorpc import messages +from protorpc import test_util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = messages + + +class ValidationErrorTest(test_util.TestCase): + + def testStr_NoFieldName(self): + """Test string version of ValidationError when no name provided.""" + self.assertEquals('Validation error', + str(messages.ValidationError('Validation error'))) + + def testStr_FieldName(self): + """Test string version of ValidationError when no name provided.""" + validation_error = messages.ValidationError('Validation error') + validation_error.field_name = 'a_field' + self.assertEquals('Validation error', str(validation_error)) + + +class EnumTest(test_util.TestCase): + + def setUp(self): + """Set up tests.""" + # Redefine Color class in case so that changes to it (an error) in one test + # does not affect other tests. + global Color + class Color(messages.Enum): + RED = 20 + ORANGE = 2 + YELLOW = 40 + GREEN = 4 + BLUE = 50 + INDIGO = 5 + VIOLET = 80 + + def testNames(self): + """Test that names iterates over enum names.""" + self.assertEquals( + set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']), + set(Color.names())) + + def testNumbers(self): + """Tests that numbers iterates of enum numbers.""" + self.assertEquals(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers())) + + def testIterate(self): + """Test that __iter__ iterates over all enum values.""" + self.assertEquals(set(Color), + set([Color.RED, + Color.ORANGE, + Color.YELLOW, + Color.GREEN, + Color.BLUE, + Color.INDIGO, + Color.VIOLET])) + + def testNaturalOrder(self): + """Test that natural order enumeration is in numeric order.""" + self.assertEquals([Color.ORANGE, + Color.GREEN, + Color.INDIGO, + Color.RED, + Color.YELLOW, + Color.BLUE, + Color.VIOLET], + sorted(Color)) + + def testByName(self): + """Test look-up by name.""" + self.assertEquals(Color.RED, Color.lookup_by_name('RED')) + self.assertRaises(KeyError, Color.lookup_by_name, 20) + self.assertRaises(KeyError, Color.lookup_by_name, Color.RED) + + def testByNumber(self): + """Test look-up by number.""" + self.assertRaises(KeyError, Color.lookup_by_number, 'RED') + self.assertEquals(Color.RED, Color.lookup_by_number(20)) + self.assertRaises(KeyError, Color.lookup_by_number, Color.RED) + + def testConstructor(self): + """Test that constructor look-up by name or number.""" + self.assertEquals(Color.RED, Color('RED')) + self.assertEquals(Color.RED, Color(u'RED')) + self.assertEquals(Color.RED, Color(20)) + if six.PY2: + self.assertEquals(Color.RED, Color(long(20))) + self.assertEquals(Color.RED, Color(Color.RED)) + self.assertRaises(TypeError, Color, 'Not exists') + self.assertRaises(TypeError, Color, 'Red') + self.assertRaises(TypeError, Color, 100) + self.assertRaises(TypeError, Color, 10.0) + + def testLen(self): + """Test that len function works to count enums.""" + self.assertEquals(7, len(Color)) + + def testNoSubclasses(self): + """Test that it is not possible to sub-class enum classes.""" + def declare_subclass(): + class MoreColor(Color): + pass + self.assertRaises(messages.EnumDefinitionError, + declare_subclass) + + def testClassNotMutable(self): + """Test that enum classes themselves are not mutable.""" + self.assertRaises(AttributeError, + setattr, + Color, + 'something_new', + 10) + + def testInstancesMutable(self): + """Test that enum instances are not mutable.""" + self.assertRaises(TypeError, + setattr, + Color.RED, + 'something_new', + 10) + + def testDefEnum(self): + """Test def_enum works by building enum class from dict.""" + WeekDay = messages.Enum.def_enum({'Monday': 1, + 'Tuesday': 2, + 'Wednesday': 3, + 'Thursday': 4, + 'Friday': 6, + 'Saturday': 7, + 'Sunday': 8}, + 'WeekDay') + self.assertEquals('Wednesday', WeekDay(3).name) + self.assertEquals(6, WeekDay('Friday').number) + self.assertEquals(WeekDay.Sunday, WeekDay('Sunday')) + + def testNonInt(self): + """Test that non-integer values rejection by enum def.""" + self.assertRaises(messages.EnumDefinitionError, + messages.Enum.def_enum, + {'Bad': '1'}, + 'BadEnum') + + def testNegativeInt(self): + """Test that negative numbers rejection by enum def.""" + self.assertRaises(messages.EnumDefinitionError, + messages.Enum.def_enum, + {'Bad': -1}, + 'BadEnum') + + def testLowerBound(self): + """Test that zero is accepted by enum def.""" + class NotImportant(messages.Enum): + """Testing for value zero""" + VALUE = 0 + + self.assertEquals(0, int(NotImportant.VALUE)) + + def testTooLargeInt(self): + """Test that numbers too large are rejected.""" + self.assertRaises(messages.EnumDefinitionError, + messages.Enum.def_enum, + {'Bad': (2 ** 29)}, + 'BadEnum') + + def testRepeatedInt(self): + """Test duplicated numbers are forbidden.""" + self.assertRaises(messages.EnumDefinitionError, + messages.Enum.def_enum, + {'Ok': 1, 'Repeated': 1}, + 'BadEnum') + + def testStr(self): + """Test converting to string.""" + self.assertEquals('RED', str(Color.RED)) + self.assertEquals('ORANGE', str(Color.ORANGE)) + + def testInt(self): + """Test converting to int.""" + self.assertEquals(20, int(Color.RED)) + self.assertEquals(2, int(Color.ORANGE)) + + def testRepr(self): + """Test enum representation.""" + self.assertEquals('Color(RED, 20)', repr(Color.RED)) + self.assertEquals('Color(YELLOW, 40)', repr(Color.YELLOW)) + + def testDocstring(self): + """Test that docstring is supported ok.""" + class NotImportant(messages.Enum): + """I have a docstring.""" + + VALUE1 = 1 + + self.assertEquals('I have a docstring.', NotImportant.__doc__) + + def testDeleteEnumValue(self): + """Test that enum values cannot be deleted.""" + self.assertRaises(TypeError, delattr, Color, 'RED') + + def testEnumName(self): + """Test enum name.""" + module_name = test_util.get_module_name(EnumTest) + self.assertEquals('%s.Color' % module_name, Color.definition_name()) + self.assertEquals(module_name, Color.outer_definition_name()) + self.assertEquals(module_name, Color.definition_package()) + + def testDefinitionName_OverrideModule(self): + """Test enum module is overriden by module package name.""" + global package + try: + package = 'my.package' + self.assertEquals('my.package.Color', Color.definition_name()) + self.assertEquals('my.package', Color.outer_definition_name()) + self.assertEquals('my.package', Color.definition_package()) + finally: + del package + + def testDefinitionName_NoModule(self): + """Test what happens when there is no module for enum.""" + class Enum1(messages.Enum): + pass + + original_modules = sys.modules + sys.modules = dict(sys.modules) + try: + del sys.modules[__name__] + self.assertEquals('Enum1', Enum1.definition_name()) + self.assertEquals(None, Enum1.outer_definition_name()) + self.assertEquals(None, Enum1.definition_package()) + self.assertEquals(six.text_type, type(Enum1.definition_name())) + finally: + sys.modules = original_modules + + def testDefinitionName_Nested(self): + """Test nested Enum names.""" + class MyMessage(messages.Message): + + class NestedEnum(messages.Enum): + + pass + + class NestedMessage(messages.Message): + + class NestedEnum(messages.Enum): + + pass + + module_name = test_util.get_module_name(EnumTest) + self.assertEquals('%s.MyMessage.NestedEnum' % module_name, + MyMessage.NestedEnum.definition_name()) + self.assertEquals('%s.MyMessage' % module_name, + MyMessage.NestedEnum.outer_definition_name()) + self.assertEquals(module_name, + MyMessage.NestedEnum.definition_package()) + + self.assertEquals('%s.MyMessage.NestedMessage.NestedEnum' % module_name, + MyMessage.NestedMessage.NestedEnum.definition_name()) + self.assertEquals( + '%s.MyMessage.NestedMessage' % module_name, + MyMessage.NestedMessage.NestedEnum.outer_definition_name()) + self.assertEquals(module_name, + MyMessage.NestedMessage.NestedEnum.definition_package()) + + def testMessageDefinition(self): + """Test that enumeration knows its enclosing message definition.""" + class OuterEnum(messages.Enum): + pass + + self.assertEquals(None, OuterEnum.message_definition()) + + class OuterMessage(messages.Message): + + class InnerEnum(messages.Enum): + pass + + self.assertEquals(OuterMessage, OuterMessage.InnerEnum.message_definition()) + + def testComparison(self): + """Test comparing various enums to different types.""" + class Enum1(messages.Enum): + VAL1 = 1 + VAL2 = 2 + + class Enum2(messages.Enum): + VAL1 = 1 + + self.assertEquals(Enum1.VAL1, Enum1.VAL1) + self.assertNotEquals(Enum1.VAL1, Enum1.VAL2) + self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) + self.assertNotEquals(Enum1.VAL1, 'VAL1') + self.assertNotEquals(Enum1.VAL1, 1) + self.assertNotEquals(Enum1.VAL1, 2) + self.assertNotEquals(Enum1.VAL1, None) + self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) + + self.assertTrue(Enum1.VAL1 < Enum1.VAL2) + self.assertTrue(Enum1.VAL2 > Enum1.VAL1) + + self.assertNotEquals(1, Enum2.VAL1) + + def testPickle(self): + """Testing pickling and unpickling of Enum instances.""" + colors = list(Color) + unpickled = pickle.loads(pickle.dumps(colors)) + self.assertEquals(colors, unpickled) + # Unpickling shouldn't create new enum instances. + for i, color in enumerate(colors): + self.assertTrue(color is unpickled[i]) + + +class FieldListTest(test_util.TestCase): + + def setUp(self): + self.integer_field = messages.IntegerField(1, repeated=True) + + def testConstructor(self): + self.assertEquals([1, 2, 3], + messages.FieldList(self.integer_field, [1, 2, 3])) + self.assertEquals([1, 2, 3], + messages.FieldList(self.integer_field, (1, 2, 3))) + self.assertEquals([], messages.FieldList(self.integer_field, [])) + + def testNone(self): + self.assertRaises(TypeError, messages.FieldList, self.integer_field, None) + + def testDoNotAutoConvertString(self): + string_field = messages.StringField(1, repeated=True) + self.assertRaises(messages.ValidationError, + messages.FieldList, string_field, 'abc') + + def testConstructorCopies(self): + a_list = [1, 3, 6] + field_list = messages.FieldList(self.integer_field, a_list) + self.assertFalse(a_list is field_list) + self.assertFalse(field_list is + messages.FieldList(self.integer_field, field_list)) + + def testNonRepeatedField(self): + self.assertRaisesWithRegexpMatch( + messages.FieldDefinitionError, + 'FieldList may only accept repeated fields', + messages.FieldList, + messages.IntegerField(1), + []) + + def testConstructor_InvalidValues(self): + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + re.escape("Expected type %r " + "for IntegerField, found 1 (type %r)" + % (six.integer_types, str)), + messages.FieldList, self.integer_field, ["1", "2", "3"]) + + def testConstructor_Scalars(self): + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + "IntegerField is repeated. Found: 3", + messages.FieldList, self.integer_field, 3) + + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + "IntegerField is repeated. Found: <(list[_]?|sequence)iterator object", + messages.FieldList, self.integer_field, iter([1, 2, 3])) + + def testSetSlice(self): + field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) + field_list[1:3] = [10, 20] + self.assertEquals([1, 10, 20, 4, 5], field_list) + + def testSetSlice_InvalidValues(self): + field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) + + def setslice(): + field_list[1:3] = ['10', '20'] + + msg_re = re.escape("Expected type %r " + "for IntegerField, found 10 (type %r)" + % (six.integer_types, str)) + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + msg_re, + setslice) + + def testSetItem(self): + field_list = messages.FieldList(self.integer_field, [2]) + field_list[0] = 10 + self.assertEquals([10], field_list) + + def testSetItem_InvalidValues(self): + field_list = messages.FieldList(self.integer_field, [2]) + + def setitem(): + field_list[0] = '10' + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + re.escape("Expected type %r " + "for IntegerField, found 10 (type %r)" + % (six.integer_types, str)), + setitem) + + def testAppend(self): + field_list = messages.FieldList(self.integer_field, [2]) + field_list.append(10) + self.assertEquals([2, 10], field_list) + + def testAppend_InvalidValues(self): + field_list = messages.FieldList(self.integer_field, [2]) + field_list.name = 'a_field' + + def append(): + field_list.append('10') + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + re.escape("Expected type %r " + "for IntegerField, found 10 (type %r)" + % (six.integer_types, str)), + append) + + def testExtend(self): + field_list = messages.FieldList(self.integer_field, [2]) + field_list.extend([10]) + self.assertEquals([2, 10], field_list) + + def testExtend_InvalidValues(self): + field_list = messages.FieldList(self.integer_field, [2]) + + def extend(): + field_list.extend(['10']) + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + re.escape("Expected type %r " + "for IntegerField, found 10 (type %r)" + % (six.integer_types, str)), + extend) + + def testInsert(self): + field_list = messages.FieldList(self.integer_field, [2, 3]) + field_list.insert(1, 10) + self.assertEquals([2, 10, 3], field_list) + + def testInsert_InvalidValues(self): + field_list = messages.FieldList(self.integer_field, [2, 3]) + + def insert(): + field_list.insert(1, '10') + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + re.escape("Expected type %r " + "for IntegerField, found 10 (type %r)" + % (six.integer_types, str)), + insert) + + def testPickle(self): + """Testing pickling and unpickling of disconnected FieldList instances.""" + field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) + unpickled = pickle.loads(pickle.dumps(field_list)) + self.assertEquals(field_list, unpickled) + self.assertIsInstance(unpickled.field, messages.IntegerField) + self.assertEquals(1, unpickled.field.number) + self.assertTrue(unpickled.field.repeated) + + +class FieldTest(test_util.TestCase): + + def ActionOnAllFieldClasses(self, action): + """Test all field classes except Message and Enum. + + Message and Enum require separate tests. + + Args: + action: Callable that takes the field class as a parameter. + """ + for field_class in (messages.IntegerField, + messages.FloatField, + messages.BooleanField, + messages.BytesField, + messages.StringField, + ): + action(field_class) + + def testNumberAttribute(self): + """Test setting the number attribute.""" + def action(field_class): + # Check range. + self.assertRaises(messages.InvalidNumberError, + field_class, + 0) + self.assertRaises(messages.InvalidNumberError, + field_class, + -1) + self.assertRaises(messages.InvalidNumberError, + field_class, + messages.MAX_FIELD_NUMBER + 1) + + # Check reserved. + self.assertRaises(messages.InvalidNumberError, + field_class, + messages.FIRST_RESERVED_FIELD_NUMBER) + self.assertRaises(messages.InvalidNumberError, + field_class, + messages.LAST_RESERVED_FIELD_NUMBER) + self.assertRaises(messages.InvalidNumberError, + field_class, + '1') + + # This one should work. + field_class(number=1) + self.ActionOnAllFieldClasses(action) + + def testRequiredAndRepeated(self): + """Test setting the required and repeated fields.""" + def action(field_class): + field_class(1, required=True) + field_class(1, repeated=True) + self.assertRaises(messages.FieldDefinitionError, + field_class, + 1, + required=True, + repeated=True) + self.ActionOnAllFieldClasses(action) + + def testInvalidVariant(self): + """Test field with invalid variants.""" + def action(field_class): + if field_class is not message_types.DateTimeField: + self.assertRaises(messages.InvalidVariantError, + field_class, + 1, + variant=messages.Variant.ENUM) + self.ActionOnAllFieldClasses(action) + + def testDefaultVariant(self): + """Test that default variant is used when not set.""" + def action(field_class): + field = field_class(1) + self.assertEquals(field_class.DEFAULT_VARIANT, field.variant) + + self.ActionOnAllFieldClasses(action) + + def testAlternateVariant(self): + """Test that default variant is used when not set.""" + field = messages.IntegerField(1, variant=messages.Variant.UINT32) + self.assertEquals(messages.Variant.UINT32, field.variant) + + def testDefaultFields_Single(self): + """Test default field is correct type (single).""" + defaults = {messages.IntegerField: 10, + messages.FloatField: 1.5, + messages.BooleanField: False, + messages.BytesField: b'abc', + messages.StringField: u'abc', + } + + def action(field_class): + field_class(1, default=defaults[field_class]) + self.ActionOnAllFieldClasses(action) + + # Run defaults test again checking for str/unicode compatiblity. + defaults[messages.StringField] = 'abc' + self.ActionOnAllFieldClasses(action) + + def testStringField_BadUnicodeInDefault(self): + """Test binary values in string field.""" + self.assertRaisesWithRegexpMatch( + messages.InvalidDefaultError, + r"Invalid default value for StringField:.*: " + r"Field encountered non-ASCII string .*: " + r"'ascii' codec can't decode byte 0x89 in position 0: " + r"ordinal not in range", + messages.StringField, 1, default=b'\x89') + + def testDefaultFields_InvalidSingle(self): + """Test default field is correct type (invalid single).""" + def action(field_class): + self.assertRaises(messages.InvalidDefaultError, + field_class, + 1, + default=object()) + self.ActionOnAllFieldClasses(action) + + def testDefaultFields_InvalidRepeated(self): + """Test default field does not accept defaults.""" + self.assertRaisesWithRegexpMatch( + messages.FieldDefinitionError, + 'Repeated fields may not have defaults', + messages.StringField, 1, repeated=True, default=[1, 2, 3]) + + def testDefaultFields_None(self): + """Test none is always acceptable.""" + def action(field_class): + field_class(1, default=None) + field_class(1, required=True, default=None) + field_class(1, repeated=True, default=None) + self.ActionOnAllFieldClasses(action) + + def testDefaultFields_Enum(self): + """Test the default for enum fields.""" + class Symbol(messages.Enum): + + ALPHA = 1 + BETA = 2 + GAMMA = 3 + + field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA) + + self.assertEquals(Symbol.ALPHA, field.default) + + def testDefaultFields_EnumStringDelayedResolution(self): + """Test that enum fields resolve default strings.""" + field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', + 1, + default='OPTIONAL') + + self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default) + + def testDefaultFields_EnumIntDelayedResolution(self): + """Test that enum fields resolve default integers.""" + field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', + 1, + default=2) + + self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default) + + def testDefaultFields_EnumOkIfTypeKnown(self): + """Test that enum fields accept valid default values when type is known.""" + field = messages.EnumField(descriptor.FieldDescriptor.Label, + 1, + default='REPEATED') + + self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default) + + def testDefaultFields_EnumForceCheckIfTypeKnown(self): + """Test that enum fields validate default values if type is known.""" + self.assertRaisesWithRegexpMatch(TypeError, + 'No such value for NOT_A_LABEL in ' + 'Enum Label', + messages.EnumField, + descriptor.FieldDescriptor.Label, + 1, + default='NOT_A_LABEL') + + def testDefaultFields_EnumInvalidDelayedResolution(self): + """Test that enum fields raise errors upon delayed resolution error.""" + field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', + 1, + default=200) + + self.assertRaisesWithRegexpMatch(TypeError, + 'No such value for 200 in Enum Label', + getattr, + field, + 'default') + + def testValidate_Valid(self): + """Test validation of valid values.""" + values = {messages.IntegerField: 10, + messages.FloatField: 1.5, + messages.BooleanField: False, + messages.BytesField: b'abc', + messages.StringField: u'abc', + } + def action(field_class): + # Optional. + field = field_class(1) + field.validate(values[field_class]) + + # Required. + field = field_class(1, required=True) + field.validate(values[field_class]) + + # Repeated. + field = field_class(1, repeated=True) + field.validate([]) + field.validate(()) + field.validate([values[field_class]]) + field.validate((values[field_class],)) + + # Right value, but not repeated. + self.assertRaises(messages.ValidationError, + field.validate, + values[field_class]) + self.assertRaises(messages.ValidationError, + field.validate, + values[field_class]) + + self.ActionOnAllFieldClasses(action) + + def testValidate_Invalid(self): + """Test validation of valid values.""" + values = {messages.IntegerField: "10", + messages.FloatField: "blah", + messages.BooleanField: 0, + messages.BytesField: 10.20, + messages.StringField: 42, + } + def action(field_class): + # Optional. + field = field_class(1) + self.assertRaises(messages.ValidationError, + field.validate, + values[field_class]) + + # Required. + field = field_class(1, required=True) + self.assertRaises(messages.ValidationError, + field.validate, + values[field_class]) + + # Repeated. + field = field_class(1, repeated=True) + self.assertRaises(messages.ValidationError, + field.validate, + [values[field_class]]) + self.assertRaises(messages.ValidationError, + field.validate, + (values[field_class],)) + self.ActionOnAllFieldClasses(action) + + def testValidate_None(self): + """Test that None is valid for non-required fields.""" + def action(field_class): + # Optional. + field = field_class(1) + field.validate(None) + + # Required. + field = field_class(1, required=True) + self.assertRaisesWithRegexpMatch(messages.ValidationError, + 'Required field is missing', + field.validate, + None) + + # Repeated. + field = field_class(1, repeated=True) + field.validate(None) + self.assertRaisesWithRegexpMatch(messages.ValidationError, + 'Repeated values for %s may ' + 'not be None' % field_class.__name__, + field.validate, + [None]) + self.assertRaises(messages.ValidationError, + field.validate, + (None,)) + self.ActionOnAllFieldClasses(action) + + def testValidateElement(self): + """Test validation of valid values.""" + values = {messages.IntegerField: (10, -1, 0), + messages.FloatField: (1.5, -1.5, 3), # for json it is all a number + messages.BooleanField: (True, False), + messages.BytesField: (b'abc',), + messages.StringField: (u'abc',), + } + def action(field_class): + # Optional. + field = field_class(1) + for value in values[field_class]: + field.validate_element(value) + + # Required. + field = field_class(1, required=True) + for value in values[field_class]: + field.validate_element(value) + + # Repeated. + field = field_class(1, repeated=True) + self.assertRaises(messages.ValidationError, + field.validate_element, + []) + self.assertRaises(messages.ValidationError, + field.validate_element, + ()) + for value in values[field_class]: + field.validate_element(value) + + # Right value, but repeated. + self.assertRaises(messages.ValidationError, + field.validate_element, + list(values[field_class])) # testing list + self.assertRaises(messages.ValidationError, + field.validate_element, + values[field_class]) # testing tuple + self.ActionOnAllFieldClasses(action) + + def testValidateCastingElement(self): + field = messages.FloatField(1) + self.assertEquals(type(field.validate_element(12)), float) + self.assertEquals(type(field.validate_element(12.0)), float) + field = messages.IntegerField(1) + self.assertEquals(type(field.validate_element(12)), int) + self.assertRaises(messages.ValidationError, + field.validate_element, + 12.0) # should fail from float to int + + def testReadOnly(self): + """Test that objects are all read-only.""" + def action(field_class): + field = field_class(10) + self.assertRaises(AttributeError, + setattr, + field, + 'number', + 20) + self.assertRaises(AttributeError, + setattr, + field, + 'anything_else', + 'whatever') + self.ActionOnAllFieldClasses(action) + + def testMessageField(self): + """Test the construction of message fields.""" + self.assertRaises(messages.FieldDefinitionError, + messages.MessageField, + str, + 10) + + self.assertRaises(messages.FieldDefinitionError, + messages.MessageField, + messages.Message, + 10) + + class MyMessage(messages.Message): + pass + + field = messages.MessageField(MyMessage, 10) + self.assertEquals(MyMessage, field.type) + + def testMessageField_ForwardReference(self): + """Test the construction of forward reference message fields.""" + global MyMessage + global ForwardMessage + try: + class MyMessage(messages.Message): + + self_reference = messages.MessageField('MyMessage', 1) + forward = messages.MessageField('ForwardMessage', 2) + nested = messages.MessageField('ForwardMessage.NestedMessage', 3) + inner = messages.MessageField('Inner', 4) + + class Inner(messages.Message): + + sibling = messages.MessageField('Sibling', 1) + + class Sibling(messages.Message): + + pass + + class ForwardMessage(messages.Message): + + class NestedMessage(messages.Message): + + pass + + self.assertEquals(MyMessage, + MyMessage.field_by_name('self_reference').type) + + self.assertEquals(ForwardMessage, + MyMessage.field_by_name('forward').type) + + self.assertEquals(ForwardMessage.NestedMessage, + MyMessage.field_by_name('nested').type) + + self.assertEquals(MyMessage.Inner, + MyMessage.field_by_name('inner').type) + + self.assertEquals(MyMessage.Sibling, + MyMessage.Inner.field_by_name('sibling').type) + finally: + try: + del MyMessage + del ForwardMessage + except: + pass + + def testMessageField_WrongType(self): + """Test that forward referencing the wrong type raises an error.""" + global AnEnum + try: + class AnEnum(messages.Enum): + pass + + class AnotherMessage(messages.Message): + + a_field = messages.MessageField('AnEnum', 1) + + self.assertRaises(messages.FieldDefinitionError, + getattr, + AnotherMessage.field_by_name('a_field'), + 'type') + finally: + del AnEnum + + def testMessageFieldValidate(self): + """Test validation on message field.""" + class MyMessage(messages.Message): + pass + + class AnotherMessage(messages.Message): + pass + + field = messages.MessageField(MyMessage, 10) + field.validate(MyMessage()) + + self.assertRaises(messages.ValidationError, + field.validate, + AnotherMessage()) + + def testMessageFieldMessageType(self): + """Test message_type property.""" + class MyMessage(messages.Message): + pass + + class HasMessage(messages.Message): + field = messages.MessageField(MyMessage, 1) + + self.assertEqual(HasMessage.field.type, HasMessage.field.message_type) + + def testMessageFieldValueFromMessage(self): + class MyMessage(messages.Message): + pass + + class HasMessage(messages.Message): + field = messages.MessageField(MyMessage, 1) + + instance = MyMessage() + + self.assertTrue(instance is HasMessage.field.value_from_message(instance)) + + def testMessageFieldValueFromMessageWrongType(self): + class MyMessage(messages.Message): + pass + + class HasMessage(messages.Message): + field = messages.MessageField(MyMessage, 1) + + self.assertRaisesWithRegexpMatch( + messages.DecodeError, + 'Expected type MyMessage, got int: 10', + HasMessage.field.value_from_message, 10) + + def testMessageFieldValueToMessage(self): + class MyMessage(messages.Message): + pass + + class HasMessage(messages.Message): + field = messages.MessageField(MyMessage, 1) + + instance = MyMessage() + + self.assertTrue(instance is HasMessage.field.value_to_message(instance)) + + def testMessageFieldValueToMessageWrongType(self): + class MyMessage(messages.Message): + pass + + class MyOtherMessage(messages.Message): + pass + + class HasMessage(messages.Message): + field = messages.MessageField(MyMessage, 1) + + instance = MyOtherMessage() + + self.assertRaisesWithRegexpMatch( + messages.EncodeError, + 'Expected type MyMessage, got MyOtherMessage: ', + HasMessage.field.value_to_message, instance) + + def testIntegerField_AllowLong(self): + """Test that the integer field allows for longs.""" + if six.PY2: + messages.IntegerField(10, default=long(10)) + + def testMessageFieldValidate_Initialized(self): + """Test validation on message field.""" + class MyMessage(messages.Message): + field1 = messages.IntegerField(1, required=True) + + field = messages.MessageField(MyMessage, 10) + + # Will validate messages where is_initialized() is False. + message = MyMessage() + field.validate(message) + message.field1 = 20 + field.validate(message) + + def testEnumField(self): + """Test the construction of enum fields.""" + self.assertRaises(messages.FieldDefinitionError, + messages.EnumField, + str, + 10) + + self.assertRaises(messages.FieldDefinitionError, + messages.EnumField, + messages.Enum, + 10) + + class Color(messages.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + field = messages.EnumField(Color, 10) + self.assertEquals(Color, field.type) + + class Another(messages.Enum): + VALUE = 1 + + self.assertRaises(messages.InvalidDefaultError, + messages.EnumField, + Color, + 10, + default=Another.VALUE) + + def testEnumField_ForwardReference(self): + """Test the construction of forward reference enum fields.""" + global MyMessage + global ForwardEnum + global ForwardMessage + try: + class MyMessage(messages.Message): + + forward = messages.EnumField('ForwardEnum', 1) + nested = messages.EnumField('ForwardMessage.NestedEnum', 2) + inner = messages.EnumField('Inner', 3) + + class Inner(messages.Enum): + pass + + class ForwardEnum(messages.Enum): + pass + + class ForwardMessage(messages.Message): + + class NestedEnum(messages.Enum): + pass + + self.assertEquals(ForwardEnum, + MyMessage.field_by_name('forward').type) + + self.assertEquals(ForwardMessage.NestedEnum, + MyMessage.field_by_name('nested').type) + + self.assertEquals(MyMessage.Inner, + MyMessage.field_by_name('inner').type) + finally: + try: + del MyMessage + del ForwardEnum + del ForwardMessage + except: + pass + + def testEnumField_WrongType(self): + """Test that forward referencing the wrong type raises an error.""" + global AMessage + try: + class AMessage(messages.Message): + pass + + class AnotherMessage(messages.Message): + + a_field = messages.EnumField('AMessage', 1) + + self.assertRaises(messages.FieldDefinitionError, + getattr, + AnotherMessage.field_by_name('a_field'), + 'type') + finally: + del AMessage + + def testMessageDefinition(self): + """Test that message definition is set on fields.""" + class MyMessage(messages.Message): + + my_field = messages.StringField(1) + + self.assertEquals(MyMessage, + MyMessage.field_by_name('my_field').message_definition()) + + def testNoneAssignment(self): + """Test that assigning None does not change comparison.""" + class MyMessage(messages.Message): + + my_field = messages.StringField(1) + + m1 = MyMessage() + m2 = MyMessage() + m2.my_field = None + self.assertEquals(m1, m2) + + def testNonAsciiStr(self): + """Test validation fails for non-ascii StringField values.""" + class Thing(messages.Message): + string_field = messages.StringField(2) + + thing = Thing() + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + 'Field string_field encountered non-ASCII string', + setattr, thing, 'string_field', test_util.BINARY) + + +class MessageTest(test_util.TestCase): + """Tests for message class.""" + + def CreateMessageClass(self): + """Creates a simple message class with 3 fields. + + Fields are defined in alphabetical order but with conflicting numeric + order. + """ + class ComplexMessage(messages.Message): + a3 = messages.IntegerField(3) + b1 = messages.StringField(1) + c2 = messages.StringField(2) + + return ComplexMessage + + def testSameNumbers(self): + """Test that cannot assign two fields with same numbers.""" + + def action(): + class BadMessage(messages.Message): + f1 = messages.IntegerField(1) + f2 = messages.IntegerField(1) + self.assertRaises(messages.DuplicateNumberError, + action) + + def testStrictAssignment(self): + """Tests that cannot assign to unknown or non-reserved attributes.""" + class SimpleMessage(messages.Message): + field = messages.IntegerField(1) + + simple_message = SimpleMessage() + self.assertRaises(AttributeError, + setattr, + simple_message, + 'does_not_exist', + 10) + + def testListAssignmentDoesNotCopy(self): + class SimpleMessage(messages.Message): + repeated = messages.IntegerField(1, repeated=True) + + message = SimpleMessage() + original = message.repeated + message.repeated = [] + self.assertFalse(original is message.repeated) + + def testValidate_Optional(self): + """Tests validation of optional fields.""" + class SimpleMessage(messages.Message): + non_required = messages.IntegerField(1) + + simple_message = SimpleMessage() + simple_message.check_initialized() + simple_message.non_required = 10 + simple_message.check_initialized() + + def testValidate_Required(self): + """Tests validation of required fields.""" + class SimpleMessage(messages.Message): + required = messages.IntegerField(1, required=True) + + simple_message = SimpleMessage() + self.assertRaises(messages.ValidationError, + simple_message.check_initialized) + simple_message.required = 10 + simple_message.check_initialized() + + def testValidate_Repeated(self): + """Tests validation of repeated fields.""" + class SimpleMessage(messages.Message): + repeated = messages.IntegerField(1, repeated=True) + + simple_message = SimpleMessage() + + # Check valid values. + for valid_value in [], [10], [10, 20], (), (10,), (10, 20): + simple_message.repeated = valid_value + simple_message.check_initialized() + + # Check cleared. + simple_message.repeated = [] + simple_message.check_initialized() + + # Check invalid values. + for invalid_value in 10, ['10', '20'], [None], (None,): + self.assertRaises(messages.ValidationError, + setattr, simple_message, 'repeated', invalid_value) + + def testIsInitialized(self): + """Tests is_initialized.""" + class SimpleMessage(messages.Message): + required = messages.IntegerField(1, required=True) + + simple_message = SimpleMessage() + self.assertFalse(simple_message.is_initialized()) + + simple_message.required = 10 + + self.assertTrue(simple_message.is_initialized()) + + def testIsInitializedNestedField(self): + """Tests is_initialized for nested fields.""" + class SimpleMessage(messages.Message): + required = messages.IntegerField(1, required=True) + + class NestedMessage(messages.Message): + simple = messages.MessageField(SimpleMessage, 1) + + simple_message = SimpleMessage() + self.assertFalse(simple_message.is_initialized()) + nested_message = NestedMessage(simple=simple_message) + self.assertFalse(nested_message.is_initialized()) + + simple_message.required = 10 + + self.assertTrue(simple_message.is_initialized()) + self.assertTrue(nested_message.is_initialized()) + + def testInitializeNestedFieldFromDict(self): + """Tests initializing nested fields from dict.""" + class SimpleMessage(messages.Message): + required = messages.IntegerField(1, required=True) + + class NestedMessage(messages.Message): + simple = messages.MessageField(SimpleMessage, 1) + + class RepeatedMessage(messages.Message): + simple = messages.MessageField(SimpleMessage, 1, repeated=True) + + nested_message1 = NestedMessage(simple={'required': 10}) + self.assertTrue(nested_message1.is_initialized()) + self.assertTrue(nested_message1.simple.is_initialized()) + + nested_message2 = NestedMessage() + nested_message2.simple = {'required': 10} + self.assertTrue(nested_message2.is_initialized()) + self.assertTrue(nested_message2.simple.is_initialized()) + + repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)] + + repeated_message1 = RepeatedMessage(simple=repeated_values) + self.assertEquals(3, len(repeated_message1.simple)) + self.assertFalse(repeated_message1.is_initialized()) + + repeated_message1.simple[0].required = 0 + self.assertTrue(repeated_message1.is_initialized()) + + repeated_message2 = RepeatedMessage() + repeated_message2.simple = repeated_values + self.assertEquals(3, len(repeated_message2.simple)) + self.assertFalse(repeated_message2.is_initialized()) + + repeated_message2.simple[0].required = 0 + self.assertTrue(repeated_message2.is_initialized()) + + def testNestedMethodsNotAllowed(self): + """Test that method definitions on Message classes are not allowed.""" + def action(): + class WithMethods(messages.Message): + def not_allowed(self): + pass + + self.assertRaises(messages.MessageDefinitionError, + action) + + def testNestedAttributesNotAllowed(self): + """Test that attribute assignment on Message classes are not allowed.""" + def int_attribute(): + class WithMethods(messages.Message): + not_allowed = 1 + + def string_attribute(): + class WithMethods(messages.Message): + not_allowed = 'not allowed' + + def enum_attribute(): + class WithMethods(messages.Message): + not_allowed = Color.RED + + for action in (int_attribute, string_attribute, enum_attribute): + self.assertRaises(messages.MessageDefinitionError, + action) + + def testNameIsSetOnFields(self): + """Make sure name is set on fields after Message class init.""" + class HasNamedFields(messages.Message): + field = messages.StringField(1) + + self.assertEquals('field', HasNamedFields.field_by_number(1).name) + + def testSubclassingMessageDisallowed(self): + """Not permitted to create sub-classes of message classes.""" + class SuperClass(messages.Message): + pass + + def action(): + class SubClass(SuperClass): + pass + + self.assertRaises(messages.MessageDefinitionError, + action) + + def testAllFields(self): + """Test all_fields method.""" + ComplexMessage = self.CreateMessageClass() + fields = list(ComplexMessage.all_fields()) + + # Order does not matter, so sort now. + fields = sorted(fields, key=lambda f: f.name) + + self.assertEquals(3, len(fields)) + self.assertEquals('a3', fields[0].name) + self.assertEquals('b1', fields[1].name) + self.assertEquals('c2', fields[2].name) + + def testFieldByName(self): + """Test getting field by name.""" + ComplexMessage = self.CreateMessageClass() + + self.assertEquals(3, ComplexMessage.field_by_name('a3').number) + self.assertEquals(1, ComplexMessage.field_by_name('b1').number) + self.assertEquals(2, ComplexMessage.field_by_name('c2').number) + + self.assertRaises(KeyError, + ComplexMessage.field_by_name, + 'unknown') + + def testFieldByNumber(self): + """Test getting field by number.""" + ComplexMessage = self.CreateMessageClass() + + self.assertEquals('a3', ComplexMessage.field_by_number(3).name) + self.assertEquals('b1', ComplexMessage.field_by_number(1).name) + self.assertEquals('c2', ComplexMessage.field_by_number(2).name) + + self.assertRaises(KeyError, + ComplexMessage.field_by_number, + 4) + + def testGetAssignedValue(self): + """Test getting the assigned value of a field.""" + class SomeMessage(messages.Message): + a_value = messages.StringField(1, default=u'a default') + + message = SomeMessage() + self.assertEquals(None, message.get_assigned_value('a_value')) + + message.a_value = u'a string' + self.assertEquals(u'a string', message.get_assigned_value('a_value')) + + message.a_value = u'a default' + self.assertEquals(u'a default', message.get_assigned_value('a_value')) + + self.assertRaisesWithRegexpMatch( + AttributeError, + 'Message SomeMessage has no field no_such_field', + message.get_assigned_value, + 'no_such_field') + + def testReset(self): + """Test resetting a field value.""" + class SomeMessage(messages.Message): + a_value = messages.StringField(1, default=u'a default') + repeated = messages.IntegerField(2, repeated=True) + + message = SomeMessage() + + self.assertRaises(AttributeError, message.reset, 'unknown') + + self.assertEquals(u'a default', message.a_value) + message.reset('a_value') + self.assertEquals(u'a default', message.a_value) + + message.a_value = u'a new value' + self.assertEquals(u'a new value', message.a_value) + message.reset('a_value') + self.assertEquals(u'a default', message.a_value) + + message.repeated = [1, 2, 3] + self.assertEquals([1, 2, 3], message.repeated) + saved = message.repeated + message.reset('repeated') + self.assertEquals([], message.repeated) + self.assertIsInstance(message.repeated, messages.FieldList) + self.assertEquals([1, 2, 3], saved) + + def testAllowNestedEnums(self): + """Test allowing nested enums in a message definition.""" + class Trade(messages.Message): + class Duration(messages.Enum): + GTC = 1 + DAY = 2 + + class Currency(messages.Enum): + USD = 1 + GBP = 2 + INR = 3 + + # Sorted by name order seems to be the only feasible option. + self.assertEquals(['Currency', 'Duration'], Trade.__enums__) + + # Message definition will now be set on Enumerated objects. + self.assertEquals(Trade, Trade.Duration.message_definition()) + + def testAllowNestedMessages(self): + """Test allowing nested messages in a message definition.""" + class Trade(messages.Message): + class Lot(messages.Message): + pass + + class Agent(messages.Message): + pass + + # Sorted by name order seems to be the only feasible option. + self.assertEquals(['Agent', 'Lot'], Trade.__messages__) + self.assertEquals(Trade, Trade.Agent.message_definition()) + self.assertEquals(Trade, Trade.Lot.message_definition()) + + # But not Message itself. + def action(): + class Trade(messages.Message): + NiceTry = messages.Message + self.assertRaises(messages.MessageDefinitionError, action) + + def testDisallowClassAssignments(self): + """Test setting class attributes may not happen.""" + class MyMessage(messages.Message): + pass + + self.assertRaises(AttributeError, + setattr, + MyMessage, + 'x', + 'do not assign') + + def testEquality(self): + """Test message class equality.""" + # Comparison against enums must work. + class MyEnum(messages.Enum): + val1 = 1 + val2 = 2 + + # Comparisons against nested messages must work. + class AnotherMessage(messages.Message): + string = messages.StringField(1) + + class MyMessage(messages.Message): + field1 = messages.IntegerField(1) + field2 = messages.EnumField(MyEnum, 2) + field3 = messages.MessageField(AnotherMessage, 3) + + message1 = MyMessage() + + self.assertNotEquals('hi', message1) + self.assertNotEquals(AnotherMessage(), message1) + self.assertEquals(message1, message1) + + message2 = MyMessage() + + self.assertEquals(message1, message2) + + message1.field1 = 10 + self.assertNotEquals(message1, message2) + + message2.field1 = 20 + self.assertNotEquals(message1, message2) + + message2.field1 = 10 + self.assertEquals(message1, message2) + + message1.field2 = MyEnum.val1 + self.assertNotEquals(message1, message2) + + message2.field2 = MyEnum.val2 + self.assertNotEquals(message1, message2) + + message2.field2 = MyEnum.val1 + self.assertEquals(message1, message2) + + message1.field3 = AnotherMessage() + message1.field3.string = u'value1' + self.assertNotEquals(message1, message2) + + message2.field3 = AnotherMessage() + message2.field3.string = u'value2' + self.assertNotEquals(message1, message2) + + message2.field3.string = u'value1' + self.assertEquals(message1, message2) + + def testEqualityWithUnknowns(self): + """Test message class equality with unknown fields.""" + + class MyMessage(messages.Message): + field1 = messages.IntegerField(1) + + message1 = MyMessage() + message2 = MyMessage() + self.assertEquals(message1, message2) + message1.set_unrecognized_field('unknown1', 'value1', + messages.Variant.STRING) + self.assertEquals(message1, message2) + + message1.set_unrecognized_field('unknown2', ['asdf', 3], + messages.Variant.STRING) + message1.set_unrecognized_field('unknown3', 4.7, + messages.Variant.DOUBLE) + self.assertEquals(message1, message2) + + def testUnrecognizedFieldInvalidVariant(self): + class MyMessage(messages.Message): + field1 = messages.IntegerField(1) + + message1 = MyMessage() + self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', + {'unhandled': 'type'}, None) + self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', + {'unhandled': 'type'}, 123) + + def testRepr(self): + """Test represtation of Message object.""" + class MyMessage(messages.Message): + integer_value = messages.IntegerField(1) + string_value = messages.StringField(2) + unassigned = messages.StringField(3) + unassigned_with_default = messages.StringField(4, default=u'a default') + + my_message = MyMessage() + my_message.integer_value = 42 + my_message.string_value = u'A string' + + pat = re.compile(r"") + self.assertTrue(pat.match(repr(my_message)) is not None) + + def testValidation(self): + """Test validation of message values.""" + # Test optional. + class SubMessage(messages.Message): + pass + + class Message(messages.Message): + val = messages.MessageField(SubMessage, 1) + + message = Message() + + message_field = messages.MessageField(Message, 1) + message_field.validate(message) + message.val = SubMessage() + message_field.validate(message) + self.assertRaises(messages.ValidationError, + setattr, message, 'val', [SubMessage()]) + + # Test required. + class Message(messages.Message): + val = messages.MessageField(SubMessage, 1, required=True) + + message = Message() + + message_field = messages.MessageField(Message, 1) + message_field.validate(message) + message.val = SubMessage() + message_field.validate(message) + self.assertRaises(messages.ValidationError, + setattr, message, 'val', [SubMessage()]) + + # Test repeated. + class Message(messages.Message): + val = messages.MessageField(SubMessage, 1, repeated=True) + + message = Message() + + message_field = messages.MessageField(Message, 1) + message_field.validate(message) + self.assertRaisesWithRegexpMatch( + messages.ValidationError, + "Field val is repeated. Found: ", + setattr, message, 'val', SubMessage()) + message.val = [SubMessage()] + message_field.validate(message) + + def testDefinitionName(self): + """Test message name.""" + class MyMessage(messages.Message): + pass + + module_name = test_util.get_module_name(FieldTest) + self.assertEquals('%s.MyMessage' % module_name, + MyMessage.definition_name()) + self.assertEquals(module_name, MyMessage.outer_definition_name()) + self.assertEquals(module_name, MyMessage.definition_package()) + + self.assertEquals(six.text_type, type(MyMessage.definition_name())) + self.assertEquals(six.text_type, type(MyMessage.outer_definition_name())) + self.assertEquals(six.text_type, type(MyMessage.definition_package())) + + def testDefinitionName_OverrideModule(self): + """Test message module is overriden by module package name.""" + class MyMessage(messages.Message): + pass + + global package + package = 'my.package' + + try: + self.assertEquals('my.package.MyMessage', MyMessage.definition_name()) + self.assertEquals('my.package', MyMessage.outer_definition_name()) + self.assertEquals('my.package', MyMessage.definition_package()) + + self.assertEquals(six.text_type, type(MyMessage.definition_name())) + self.assertEquals(six.text_type, type(MyMessage.outer_definition_name())) + self.assertEquals(six.text_type, type(MyMessage.definition_package())) + finally: + del package + + def testDefinitionName_NoModule(self): + """Test what happens when there is no module for message.""" + class MyMessage(messages.Message): + pass + + original_modules = sys.modules + sys.modules = dict(sys.modules) + try: + del sys.modules[__name__] + self.assertEquals('MyMessage', MyMessage.definition_name()) + self.assertEquals(None, MyMessage.outer_definition_name()) + self.assertEquals(None, MyMessage.definition_package()) + + self.assertEquals(six.text_type, type(MyMessage.definition_name())) + finally: + sys.modules = original_modules + + def testDefinitionName_Nested(self): + """Test nested message names.""" + class MyMessage(messages.Message): + + class NestedMessage(messages.Message): + + class NestedMessage(messages.Message): + + pass + + module_name = test_util.get_module_name(MessageTest) + self.assertEquals('%s.MyMessage.NestedMessage' % module_name, + MyMessage.NestedMessage.definition_name()) + self.assertEquals('%s.MyMessage' % module_name, + MyMessage.NestedMessage.outer_definition_name()) + self.assertEquals(module_name, + MyMessage.NestedMessage.definition_package()) + + self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name, + MyMessage.NestedMessage.NestedMessage.definition_name()) + self.assertEquals( + '%s.MyMessage.NestedMessage' % module_name, + MyMessage.NestedMessage.NestedMessage.outer_definition_name()) + self.assertEquals( + module_name, + MyMessage.NestedMessage.NestedMessage.definition_package()) + + + def testMessageDefinition(self): + """Test that enumeration knows its enclosing message definition.""" + class OuterMessage(messages.Message): + + class InnerMessage(messages.Message): + pass + + self.assertEquals(None, OuterMessage.message_definition()) + self.assertEquals(OuterMessage, + OuterMessage.InnerMessage.message_definition()) + + def testConstructorKwargs(self): + """Test kwargs via constructor.""" + class SomeMessage(messages.Message): + name = messages.StringField(1) + number = messages.IntegerField(2) + + expected = SomeMessage() + expected.name = 'my name' + expected.number = 200 + self.assertEquals(expected, SomeMessage(name='my name', number=200)) + + def testConstructorNotAField(self): + """Test kwargs via constructor with wrong names.""" + class SomeMessage(messages.Message): + pass + + self.assertRaisesWithRegexpMatch( + AttributeError, + 'May not assign arbitrary value does_not_exist to message SomeMessage', + SomeMessage, + does_not_exist=10) + + def testGetUnsetRepeatedValue(self): + class SomeMessage(messages.Message): + repeated = messages.IntegerField(1, repeated=True) + + instance = SomeMessage() + self.assertEquals([], instance.repeated) + self.assertTrue(isinstance(instance.repeated, messages.FieldList)) + + def testCompareAutoInitializedRepeatedFields(self): + class SomeMessage(messages.Message): + repeated = messages.IntegerField(1, repeated=True) + + message1 = SomeMessage(repeated=[]) + message2 = SomeMessage() + self.assertEquals(message1, message2) + + def testUnknownValues(self): + """Test message class equality with unknown fields.""" + class MyMessage(messages.Message): + field1 = messages.IntegerField(1) + + message = MyMessage() + self.assertEquals([], message.all_unrecognized_fields()) + self.assertEquals((None, None), + message.get_unrecognized_field_info('doesntexist')) + self.assertEquals((None, None), + message.get_unrecognized_field_info( + 'doesntexist', None, None)) + self.assertEquals(('defaultvalue', 'defaultwire'), + message.get_unrecognized_field_info( + 'doesntexist', 'defaultvalue', 'defaultwire')) + self.assertEquals((3, None), + message.get_unrecognized_field_info( + 'doesntexist', value_default=3)) + + message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE) + self.assertEquals(1, len(message.all_unrecognized_fields())) + self.assertTrue('exists' in message.all_unrecognized_fields()) + self.assertEquals((9.5, messages.Variant.DOUBLE), + message.get_unrecognized_field_info('exists')) + self.assertEquals((9.5, messages.Variant.DOUBLE), + message.get_unrecognized_field_info('exists', 'type', + 1234)) + self.assertEquals((1234, None), + message.get_unrecognized_field_info('doesntexist', 1234)) + + message.set_unrecognized_field('another', 'value', messages.Variant.STRING) + self.assertEquals(2, len(message.all_unrecognized_fields())) + self.assertTrue('exists' in message.all_unrecognized_fields()) + self.assertTrue('another' in message.all_unrecognized_fields()) + self.assertEquals((9.5, messages.Variant.DOUBLE), + message.get_unrecognized_field_info('exists')) + self.assertEquals(('value', messages.Variant.STRING), + message.get_unrecognized_field_info('another')) + + message.set_unrecognized_field('typetest1', ['list', 0, ('test',)], + messages.Variant.STRING) + self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), + message.get_unrecognized_field_info('typetest1')) + message.set_unrecognized_field('typetest2', '', messages.Variant.STRING) + self.assertEquals(('', messages.Variant.STRING), + message.get_unrecognized_field_info('typetest2')) + + def testPickle(self): + """Testing pickling and unpickling of Message instances.""" + global MyEnum + global AnotherMessage + global MyMessage + + class MyEnum(messages.Enum): + val1 = 1 + val2 = 2 + + class AnotherMessage(messages.Message): + string = messages.StringField(1, repeated=True) + + class MyMessage(messages.Message): + field1 = messages.IntegerField(1) + field2 = messages.EnumField(MyEnum, 2) + field3 = messages.MessageField(AnotherMessage, 3) + + message = MyMessage(field1=1, field2=MyEnum.val2, + field3=AnotherMessage(string=['a', 'b', 'c'])) + message.set_unrecognized_field('exists', 'value', messages.Variant.STRING) + message.set_unrecognized_field('repeated', ['list', 0, ('test',)], + messages.Variant.STRING) + unpickled = pickle.loads(pickle.dumps(message)) + self.assertEquals(message, unpickled) + self.assertTrue(AnotherMessage.string is unpickled.field3.string.field) + self.assertTrue('exists' in message.all_unrecognized_fields()) + self.assertEquals(('value', messages.Variant.STRING), + message.get_unrecognized_field_info('exists')) + self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), + message.get_unrecognized_field_info('repeated')) + + +class FindDefinitionTest(test_util.TestCase): + """Test finding definitions relative to various definitions and modules.""" + + def setUp(self): + """Set up module-space. Starts off empty.""" + self.modules = {} + + def DefineModule(self, name): + """Define a module and its parents in module space. + + Modules that are already defined in self.modules are not re-created. + + Args: + name: Fully qualified name of modules to create. + + Returns: + Deepest nested module. For example: + + DefineModule('a.b.c') # Returns c. + """ + name_path = name.split('.') + full_path = [] + for node in name_path: + full_path.append(node) + full_name = '.'.join(full_path) + self.modules.setdefault(full_name, types.ModuleType(full_name)) + return self.modules[name] + + def DefineMessage(self, module, name, children={}, add_to_module=True): + """Define a new Message class in the context of a module. + + Used for easily describing complex Message hierarchy. Message is defined + including all child definitions. + + Args: + module: Fully qualified name of module to place Message class in. + name: Name of Message to define within module. + children: Define any level of nesting of children definitions. To define + a message, map the name to another dictionary. The dictionary can + itself contain additional definitions, and so on. To map to an Enum, + define the Enum class separately and map it by name. + add_to_module: If True, new Message class is added to module. If False, + new Message is not added. + """ + # Make sure module exists. + module_instance = self.DefineModule(module) + + # Recursively define all child messages. + for attribute, value in children.items(): + if isinstance(value, dict): + children[attribute] = self.DefineMessage( + module, attribute, value, False) + + # Override default __module__ variable. + children['__module__'] = module + + # Instantiate and possibly add to module. + message_class = type(name, (messages.Message,), dict(children)) + if add_to_module: + setattr(module_instance, name, message_class) + return message_class + + def Importer(self, module, globals='', locals='', fromlist=None): + """Importer function. + + Acts like __import__. Only loads modules from self.modules. Does not + try to load real modules defined elsewhere. Does not try to handle relative + imports. + + Args: + module: Fully qualified name of module to load from self.modules. + """ + if fromlist is None: + module = module.split('.')[0] + try: + return self.modules[module] + except KeyError: + raise ImportError() + + def testNoSuchModule(self): + """Test searching for definitions that do no exist.""" + self.assertRaises(messages.DefinitionNotFoundError, + messages.find_definition, + 'does.not.exist', + importer=self.Importer) + + def testRefersToModule(self): + """Test that referring to a module does not return that module.""" + self.DefineModule('i.am.a.module') + self.assertRaises(messages.DefinitionNotFoundError, + messages.find_definition, + 'i.am.a.module', + importer=self.Importer) + + def testNoDefinition(self): + """Test not finding a definition in an existing module.""" + self.DefineModule('i.am.a.module') + self.assertRaises(messages.DefinitionNotFoundError, + messages.find_definition, + 'i.am.a.module.MyMessage', + importer=self.Importer) + + def testNotADefinition(self): + """Test trying to fetch something that is not a definition.""" + module = self.DefineModule('i.am.a.module') + setattr(module, 'A', 'a string') + self.assertRaises(messages.DefinitionNotFoundError, + messages.find_definition, + 'i.am.a.module.A', + importer=self.Importer) + + def testGlobalFind(self): + """Test finding definitions from fully qualified module names.""" + A = self.DefineMessage('a.b.c', 'A', {}) + self.assertEquals(A, messages.find_definition('a.b.c.A', + importer=self.Importer)) + B = self.DefineMessage('a.b.c', 'B', {'C':{}}) + self.assertEquals(B.C, messages.find_definition('a.b.c.B.C', + importer=self.Importer)) + + def testRelativeToModule(self): + """Test finding definitions relative to modules.""" + # Define modules. + a = self.DefineModule('a') + b = self.DefineModule('a.b') + c = self.DefineModule('a.b.c') + + # Define messages. + A = self.DefineMessage('a', 'A') + B = self.DefineMessage('a.b', 'B') + C = self.DefineMessage('a.b.c', 'C') + D = self.DefineMessage('a.b.d', 'D') + + # Find A, B, C and D relative to a. + self.assertEquals(A, messages.find_definition( + 'A', a, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'b.B', a, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'b.c.C', a, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'b.d.D', a, importer=self.Importer)) + + # Find A, B, C and D relative to b. + self.assertEquals(A, messages.find_definition( + 'A', b, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'B', b, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'c.C', b, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'd.D', b, importer=self.Importer)) + + # Find A, B, C and D relative to c. Module d is the same case as c. + self.assertEquals(A, messages.find_definition( + 'A', c, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'B', c, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'C', c, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'd.D', c, importer=self.Importer)) + + def testRelativeToMessages(self): + """Test finding definitions relative to Message definitions.""" + A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}}) + B = A.B + C = A.B.C + D = A.B.D + + # Find relative to A. + self.assertEquals(A, messages.find_definition( + 'A', A, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'B', A, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'B.C', A, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'B.D', A, importer=self.Importer)) + + # Find relative to B. + self.assertEquals(A, messages.find_definition( + 'A', B, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'B', B, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'C', B, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'D', B, importer=self.Importer)) + + # Find relative to C. + self.assertEquals(A, messages.find_definition( + 'A', C, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'B', C, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'C', C, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'D', C, importer=self.Importer)) + + # Find relative to C searching from c. + self.assertEquals(A, messages.find_definition( + 'b.A', C, importer=self.Importer)) + self.assertEquals(B, messages.find_definition( + 'b.A.B', C, importer=self.Importer)) + self.assertEquals(C, messages.find_definition( + 'b.A.B.C', C, importer=self.Importer)) + self.assertEquals(D, messages.find_definition( + 'b.A.B.D', C, importer=self.Importer)) + + def testAbsoluteReference(self): + """Test finding absolute definition names.""" + # Define modules. + a = self.DefineModule('a') + b = self.DefineModule('a.a') + + # Define messages. + aA = self.DefineMessage('a', 'A') + aaA = self.DefineMessage('a.a', 'A') + + # Always find a.A. + self.assertEquals(aA, messages.find_definition('.a.A', None, + importer=self.Importer)) + self.assertEquals(aA, messages.find_definition('.a.A', a, + importer=self.Importer)) + self.assertEquals(aA, messages.find_definition('.a.A', aA, + importer=self.Importer)) + self.assertEquals(aA, messages.find_definition('.a.A', aaA, + importer=self.Importer)) + + def testFindEnum(self): + """Test that Enums are found.""" + class Color(messages.Enum): + pass + A = self.DefineMessage('a', 'A', {'Color': Color}) + + self.assertEquals( + Color, + messages.find_definition('Color', A, importer=self.Importer)) + + def testFalseScope(self): + """Test that Message definitions nested in strange objects are hidden.""" + global X + class X(object): + class A(messages.Message): + pass + + self.assertRaises(TypeError, messages.find_definition, 'A', X) + self.assertRaises(messages.DefinitionNotFoundError, + messages.find_definition, + 'X.A', sys.modules[__name__]) + + def testSearchAttributeFirst(self): + """Make sure not faked out by module, but continues searching.""" + A = self.DefineMessage('a', 'A') + module_A = self.DefineModule('a.A') + + self.assertEquals(A, messages.find_definition( + 'a.A', None, importer=self.Importer)) + + +class FindDefinitionUnicodeTests(test_util.TestCase): + + # TODO(craigcitro): Fix this test and re-enable it. + def notatestUnicodeString(self): + """Test using unicode names.""" + from protorpc import registry + self.assertEquals('ServiceMapping', + messages.find_definition( + u'protorpc.registry.ServiceMapping', + None).__name__) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/non_sdk_imports.py b/endpoints/internal/protorpc/non_sdk_imports.py new file mode 100644 index 0000000..5b971ec --- /dev/null +++ b/endpoints/internal/protorpc/non_sdk_imports.py @@ -0,0 +1,21 @@ +"""Dynamically decide from where to import other non SDK Google modules. + +All other protorpc code should import other non SDK modules from +this module. If necessary, add new imports here (in both places). +""" + +__author__ = 'yey@google.com (Ye Yuan)' + +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import + +try: + from google.protobuf import descriptor + normal_environment = True +except ImportError: + normal_environment = False + +if normal_environment: + from google.protobuf import descriptor_pb2 + from google.protobuf import message + from google.protobuf import reflection diff --git a/endpoints/internal/protorpc/protobuf.py b/endpoints/internal/protorpc/protobuf.py new file mode 100644 index 0000000..18d0074 --- /dev/null +++ b/endpoints/internal/protorpc/protobuf.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Protocol buffer support for message types. + +For more details about protocol buffer encoding and decoding please see: + + http://code.google.com/apis/protocolbuffers/docs/encoding.html + +Public Exceptions: + DecodeError: Raised when a decode error occurs from incorrect protobuf format. + +Public Functions: + encode_message: Encodes a message in to a protocol buffer string. + decode_message: Decode from a protocol buffer string to a message. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import array + +from . import message_types +from . import messages +from . import util +from .google_imports import ProtocolBuffer + + +__all__ = ['ALTERNATIVE_CONTENT_TYPES', + 'CONTENT_TYPE', + 'encode_message', + 'decode_message', + ] + +CONTENT_TYPE = 'application/octet-stream' + +ALTERNATIVE_CONTENT_TYPES = ['application/x-google-protobuf'] + + +class _Encoder(ProtocolBuffer.Encoder): + """Extension of protocol buffer encoder. + + Original protocol buffer encoder does not have complete set of methods + for handling required encoding. This class adds them. + """ + + # TODO(rafek): Implement the missing encoding types. + def no_encoding(self, value): + """No encoding available for type. + + Args: + value: Value to encode. + + Raises: + NotImplementedError at all times. + """ + raise NotImplementedError() + + def encode_enum(self, value): + """Encode an enum value. + + Args: + value: Enum to encode. + """ + self.putVarInt32(value.number) + + def encode_message(self, value): + """Encode a Message in to an embedded message. + + Args: + value: Message instance to encode. + """ + self.putPrefixedString(encode_message(value)) + + + def encode_unicode_string(self, value): + """Helper to properly pb encode unicode strings to UTF-8. + + Args: + value: String value to encode. + """ + if isinstance(value, six.text_type): + value = value.encode('utf-8') + self.putPrefixedString(value) + + +class _Decoder(ProtocolBuffer.Decoder): + """Extension of protocol buffer decoder. + + Original protocol buffer decoder does not have complete set of methods + for handling required decoding. This class adds them. + """ + + # TODO(rafek): Implement the missing encoding types. + def no_decoding(self): + """No decoding available for type. + + Raises: + NotImplementedError at all times. + """ + raise NotImplementedError() + + def decode_string(self): + """Decode a unicode string. + + Returns: + Next value in stream as a unicode string. + """ + return self.getPrefixedString().decode('UTF-8') + + def decode_boolean(self): + """Decode a boolean value. + + Returns: + Next value in stream as a boolean. + """ + return bool(self.getBoolean()) + + +# Number of bits used to describe a protocol buffer bits used for the variant. +_WIRE_TYPE_BITS = 3 +_WIRE_TYPE_MASK = 7 + + +# Maps variant to underlying wire type. Many variants map to same type. +_VARIANT_TO_WIRE_TYPE = { + messages.Variant.DOUBLE: _Encoder.DOUBLE, + messages.Variant.FLOAT: _Encoder.FLOAT, + messages.Variant.INT64: _Encoder.NUMERIC, + messages.Variant.UINT64: _Encoder.NUMERIC, + messages.Variant.INT32: _Encoder.NUMERIC, + messages.Variant.BOOL: _Encoder.NUMERIC, + messages.Variant.STRING: _Encoder.STRING, + messages.Variant.MESSAGE: _Encoder.STRING, + messages.Variant.BYTES: _Encoder.STRING, + messages.Variant.UINT32: _Encoder.NUMERIC, + messages.Variant.ENUM: _Encoder.NUMERIC, + messages.Variant.SINT32: _Encoder.NUMERIC, + messages.Variant.SINT64: _Encoder.NUMERIC, +} + + +# Maps variant to encoder method. +_VARIANT_TO_ENCODER_MAP = { + messages.Variant.DOUBLE: _Encoder.putDouble, + messages.Variant.FLOAT: _Encoder.putFloat, + messages.Variant.INT64: _Encoder.putVarInt64, + messages.Variant.UINT64: _Encoder.putVarUint64, + messages.Variant.INT32: _Encoder.putVarInt32, + messages.Variant.BOOL: _Encoder.putBoolean, + messages.Variant.STRING: _Encoder.encode_unicode_string, + messages.Variant.MESSAGE: _Encoder.encode_message, + messages.Variant.BYTES: _Encoder.encode_unicode_string, + messages.Variant.UINT32: _Encoder.no_encoding, + messages.Variant.ENUM: _Encoder.encode_enum, + messages.Variant.SINT32: _Encoder.no_encoding, + messages.Variant.SINT64: _Encoder.no_encoding, +} + + +# Basic wire format decoders. Used for reading unknown values. +_WIRE_TYPE_TO_DECODER_MAP = { + _Encoder.NUMERIC: _Decoder.getVarInt64, + _Encoder.DOUBLE: _Decoder.getDouble, + _Encoder.STRING: _Decoder.getPrefixedString, + _Encoder.FLOAT: _Decoder.getFloat, +} + + +# Map wire type to variant. Used to find a variant for unknown values. +_WIRE_TYPE_TO_VARIANT_MAP = { + _Encoder.NUMERIC: messages.Variant.INT64, + _Encoder.DOUBLE: messages.Variant.DOUBLE, + _Encoder.STRING: messages.Variant.STRING, + _Encoder.FLOAT: messages.Variant.FLOAT, +} + + +# Wire type to name mapping for error messages. +_WIRE_TYPE_NAME = { + _Encoder.NUMERIC: 'NUMERIC', + _Encoder.DOUBLE: 'DOUBLE', + _Encoder.STRING: 'STRING', + _Encoder.FLOAT: 'FLOAT', +} + + +# Maps variant to decoder method. +_VARIANT_TO_DECODER_MAP = { + messages.Variant.DOUBLE: _Decoder.getDouble, + messages.Variant.FLOAT: _Decoder.getFloat, + messages.Variant.INT64: _Decoder.getVarInt64, + messages.Variant.UINT64: _Decoder.getVarUint64, + messages.Variant.INT32: _Decoder.getVarInt32, + messages.Variant.BOOL: _Decoder.decode_boolean, + messages.Variant.STRING: _Decoder.decode_string, + messages.Variant.MESSAGE: _Decoder.getPrefixedString, + messages.Variant.BYTES: _Decoder.getPrefixedString, + messages.Variant.UINT32: _Decoder.no_decoding, + messages.Variant.ENUM: _Decoder.getVarInt32, + messages.Variant.SINT32: _Decoder.no_decoding, + messages.Variant.SINT64: _Decoder.no_decoding, +} + + +def encode_message(message): + """Encode Message instance to protocol buffer. + + Args: + Message instance to encode in to protocol buffer. + + Returns: + String encoding of Message instance in protocol buffer format. + + Raises: + messages.ValidationError if message is not initialized. + """ + message.check_initialized() + encoder = _Encoder() + + # Get all fields, from the known fields we parsed and the unknown fields + # we saved. Note which ones were known, so we can process them differently. + all_fields = [(field.number, field) for field in message.all_fields()] + all_fields.extend((key, None) + for key in message.all_unrecognized_fields() + if isinstance(key, six.integer_types)) + all_fields.sort() + for field_num, field in all_fields: + if field: + # Known field. + value = message.get_assigned_value(field.name) + if value is None: + continue + variant = field.variant + repeated = field.repeated + else: + # Unrecognized field. + value, variant = message.get_unrecognized_field_info(field_num) + if not isinstance(variant, messages.Variant): + continue + repeated = isinstance(value, (list, tuple)) + + tag = ((field_num << _WIRE_TYPE_BITS) | _VARIANT_TO_WIRE_TYPE[variant]) + + # Write value to wire. + if repeated: + values = value + else: + values = [value] + for next in values: + encoder.putVarInt32(tag) + if isinstance(field, messages.MessageField): + next = field.value_to_message(next) + field_encoder = _VARIANT_TO_ENCODER_MAP[variant] + field_encoder(encoder, next) + + return encoder.buffer().tostring() + + +def decode_message(message_type, encoded_message): + """Decode protocol buffer to Message instance. + + Args: + message_type: Message type to decode data to. + encoded_message: Encoded version of message as string. + + Returns: + Decoded instance of message_type. + + Raises: + DecodeError if an error occurs during decoding, such as incompatible + wire format for a field. + messages.ValidationError if merged message is not initialized. + """ + message = message_type() + message_array = array.array('B') + message_array.fromstring(encoded_message) + try: + decoder = _Decoder(message_array, 0, len(message_array)) + + while decoder.avail() > 0: + # Decode tag and variant information. + encoded_tag = decoder.getVarInt32() + tag = encoded_tag >> _WIRE_TYPE_BITS + wire_type = encoded_tag & _WIRE_TYPE_MASK + try: + found_wire_type_decoder = _WIRE_TYPE_TO_DECODER_MAP[wire_type] + except: + raise messages.DecodeError('No such wire type %d' % wire_type) + + if tag < 1: + raise messages.DecodeError('Invalid tag value %d' % tag) + + try: + field = message.field_by_number(tag) + except KeyError: + # Unexpected tags are ok. + field = None + wire_type_decoder = found_wire_type_decoder + else: + expected_wire_type = _VARIANT_TO_WIRE_TYPE[field.variant] + if expected_wire_type != wire_type: + raise messages.DecodeError('Expected wire type %s but found %s' % ( + _WIRE_TYPE_NAME[expected_wire_type], + _WIRE_TYPE_NAME[wire_type])) + + wire_type_decoder = _VARIANT_TO_DECODER_MAP[field.variant] + + value = wire_type_decoder(decoder) + + # Save unknown fields and skip additional processing. + if not field: + # When saving this, save it under the tag number (which should + # be unique), and set the variant and value so we know how to + # interpret the value later. + variant = _WIRE_TYPE_TO_VARIANT_MAP.get(wire_type) + if variant: + message.set_unrecognized_field(tag, value, variant) + continue + + # Special case Enum and Message types. + if isinstance(field, messages.EnumField): + try: + value = field.type(value) + except TypeError: + raise messages.DecodeError('Invalid enum value %s' % value) + elif isinstance(field, messages.MessageField): + value = decode_message(field.message_type, value) + value = field.value_from_message(value) + + # Merge value in to message. + if field.repeated: + values = getattr(message, field.name) + if values is None: + setattr(message, field.name, [value]) + else: + values.append(value) + else: + setattr(message, field.name, value) + except ProtocolBuffer.ProtocolBufferDecodeError as err: + raise messages.DecodeError('Decoding error: %s' % str(err)) + + message.check_initialized() + return message diff --git a/endpoints/internal/protorpc/protobuf_test.py b/endpoints/internal/protorpc/protobuf_test.py new file mode 100644 index 0000000..9a65824 --- /dev/null +++ b/endpoints/internal/protorpc/protobuf_test.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.protobuf.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import datetime +import unittest + +from protorpc import message_types +from protorpc import messages +from protorpc import protobuf +from protorpc import protorpc_test_pb2 +from protorpc import test_util +from protorpc import util + +# TODO: Add DateTimeFields to protorpc_test.proto when definition.py +# supports date time fields. +class HasDateTimeMessage(messages.Message): + value = message_types.DateTimeField(1) + +class NestedDateTimeMessage(messages.Message): + value = messages.MessageField(message_types.DateTimeMessage, 1) + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = protobuf + + +class EncodeMessageTest(test_util.TestCase, + test_util.ProtoConformanceTestBase): + """Test message to protocol buffer encoding.""" + + PROTOLIB = protobuf + + def assertErrorIs(self, exception, message, function, *params, **kwargs): + try: + function(*params, **kwargs) + self.fail('Expected to raise exception %s but did not.' % exception) + except exception as err: + self.assertEquals(message, str(err)) + + @property + def encoded_partial(self): + proto = protorpc_test_pb2.OptionalMessage() + proto.double_value = 1.23 + proto.int64_value = -100000000000 + proto.int32_value = 1020 + proto.string_value = u'a string' + proto.enum_value = protorpc_test_pb2.OptionalMessage.VAL2 + + return proto.SerializeToString() + + @property + def encoded_full(self): + proto = protorpc_test_pb2.OptionalMessage() + proto.double_value = 1.23 + proto.float_value = -2.5 + proto.int64_value = -100000000000 + proto.uint64_value = 102020202020 + proto.int32_value = 1020 + proto.bool_value = True + proto.string_value = u'a string\u044f' + proto.bytes_value = b'a bytes\xff\xfe' + proto.enum_value = protorpc_test_pb2.OptionalMessage.VAL2 + + return proto.SerializeToString() + + @property + def encoded_repeated(self): + proto = protorpc_test_pb2.RepeatedMessage() + proto.double_value.append(1.23) + proto.double_value.append(2.3) + proto.float_value.append(-2.5) + proto.float_value.append(0.5) + proto.int64_value.append(-100000000000) + proto.int64_value.append(20) + proto.uint64_value.append(102020202020) + proto.uint64_value.append(10) + proto.int32_value.append(1020) + proto.int32_value.append(718) + proto.bool_value.append(True) + proto.bool_value.append(False) + proto.string_value.append(u'a string\u044f') + proto.string_value.append(u'another string') + proto.bytes_value.append(b'a bytes\xff\xfe') + proto.bytes_value.append(b'another bytes') + proto.enum_value.append(protorpc_test_pb2.RepeatedMessage.VAL2) + proto.enum_value.append(protorpc_test_pb2.RepeatedMessage.VAL1) + + return proto.SerializeToString() + + @property + def encoded_nested(self): + proto = protorpc_test_pb2.HasNestedMessage() + proto.nested.a_value = 'a string' + + return proto.SerializeToString() + + @property + def encoded_repeated_nested(self): + proto = protorpc_test_pb2.HasNestedMessage() + proto.repeated_nested.add().a_value = 'a string' + proto.repeated_nested.add().a_value = 'another string' + + return proto.SerializeToString() + + unexpected_tag_message = ( + chr((15 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.NUMERIC) + + chr(5)) + + @property + def encoded_default_assigned(self): + proto = protorpc_test_pb2.HasDefault() + proto.a_value = test_util.HasDefault.a_value.default + return proto.SerializeToString() + + @property + def encoded_nested_empty(self): + proto = protorpc_test_pb2.HasOptionalNestedMessage() + proto.nested.Clear() + return proto.SerializeToString() + + @property + def encoded_repeated_nested_empty(self): + proto = protorpc_test_pb2.HasOptionalNestedMessage() + proto.repeated_nested.add() + proto.repeated_nested.add() + return proto.SerializeToString() + + @property + def encoded_extend_message(self): + proto = protorpc_test_pb2.RepeatedMessage() + proto.add_int64_value(400) + proto.add_int64_value(50) + proto.add_int64_value(6000) + return proto.SerializeToString() + + @property + def encoded_string_types(self): + proto = protorpc_test_pb2.OptionalMessage() + proto.string_value = u'Latin' + return proto.SerializeToString() + + @property + def encoded_invalid_enum(self): + encoder = protobuf._Encoder() + field_num = test_util.OptionalMessage.enum_value.number + tag = (field_num << protobuf._WIRE_TYPE_BITS) | encoder.NUMERIC + encoder.putVarInt32(tag) + encoder.putVarInt32(1000) + return encoder.buffer().tostring() + + def testDecodeWrongWireFormat(self): + """Test what happens when wrong wire format found in protobuf.""" + class ExpectedProto(messages.Message): + value = messages.StringField(1) + + class WrongVariant(messages.Message): + value = messages.IntegerField(1) + + original = WrongVariant() + original.value = 10 + self.assertErrorIs(messages.DecodeError, + 'Expected wire type STRING but found NUMERIC', + protobuf.decode_message, + ExpectedProto, + protobuf.encode_message(original)) + + def testDecodeBadWireType(self): + """Test what happens when non-existant wire type found in protobuf.""" + # Message has tag 1, type 3 which does not exist. + bad_wire_type_message = chr((1 << protobuf._WIRE_TYPE_BITS) | 3) + + self.assertErrorIs(messages.DecodeError, + 'No such wire type 3', + protobuf.decode_message, + test_util.OptionalMessage, + bad_wire_type_message) + + def testUnexpectedTagBelowOne(self): + """Test that completely invalid tags generate an error.""" + # Message has tag 0, type NUMERIC. + invalid_tag_message = chr(protobuf._Encoder.NUMERIC) + + self.assertErrorIs(messages.DecodeError, + 'Invalid tag value 0', + protobuf.decode_message, + test_util.OptionalMessage, + invalid_tag_message) + + def testProtocolBufferDecodeError(self): + """Test what happens when there a ProtocolBufferDecodeError. + + This is what happens when the underlying ProtocolBuffer library raises + it's own decode error. + """ + # Message has tag 1, type DOUBLE, missing value. + truncated_message = ( + chr((1 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.DOUBLE)) + + self.assertErrorIs(messages.DecodeError, + 'Decoding error: truncated', + protobuf.decode_message, + test_util.OptionalMessage, + truncated_message) + + def testProtobufUnrecognizedField(self): + """Test that unrecognized fields are serialized and can be accessed.""" + decoded = protobuf.decode_message(test_util.OptionalMessage, + self.unexpected_tag_message) + self.assertEquals(1, len(decoded.all_unrecognized_fields())) + self.assertEquals(15, decoded.all_unrecognized_fields()[0]) + self.assertEquals((5, messages.Variant.INT64), + decoded.get_unrecognized_field_info(15)) + + def testUnrecognizedFieldWrongFormat(self): + """Test that unrecognized fields in the wrong format are skipped.""" + + class SimpleMessage(messages.Message): + value = messages.IntegerField(1) + + message = SimpleMessage(value=3) + message.set_unrecognized_field('from_json', 'test', messages.Variant.STRING) + + encoded = protobuf.encode_message(message) + expected = ( + chr((1 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.NUMERIC) + + chr(3)) + self.assertEquals(encoded, expected) + + def testProtobufDecodeDateTimeMessage(self): + """Test what happens when decoding a DateTimeMessage.""" + + nested = NestedDateTimeMessage() + nested.value = message_types.DateTimeMessage(milliseconds=2500) + value = protobuf.decode_message(HasDateTimeMessage, + protobuf.encode_message(nested)).value + self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 2, 500000), value) + + def testProtobufDecodeDateTimeMessageWithTimeZone(self): + """Test what happens when decoding a DateTimeMessage with a time zone.""" + nested = NestedDateTimeMessage() + nested.value = message_types.DateTimeMessage(milliseconds=12345678, + time_zone_offset=60) + value = protobuf.decode_message(HasDateTimeMessage, + protobuf.encode_message(nested)).value + self.assertEqual(datetime.datetime(1970, 1, 1, 3, 25, 45, 678000, + tzinfo=util.TimeZoneOffset(60)), + value) + + def testProtobufEncodeDateTimeMessage(self): + """Test what happens when encoding a DateTimeField.""" + mine = HasDateTimeMessage(value=datetime.datetime(1970, 1, 1)) + nested = NestedDateTimeMessage() + nested.value = message_types.DateTimeMessage(milliseconds=0) + + my_encoded = protobuf.encode_message(mine) + encoded = protobuf.encode_message(nested) + self.assertEquals(my_encoded, encoded) + + def testProtobufEncodeDateTimeMessageWithTimeZone(self): + """Test what happens when encoding a DateTimeField with a time zone.""" + for tz_offset in (30, -30, 8 * 60, 0): + mine = HasDateTimeMessage(value=datetime.datetime( + 1970, 1, 1, tzinfo=util.TimeZoneOffset(tz_offset))) + nested = NestedDateTimeMessage() + nested.value = message_types.DateTimeMessage( + milliseconds=0, time_zone_offset=tz_offset) + + my_encoded = protobuf.encode_message(mine) + encoded = protobuf.encode_message(nested) + self.assertEquals(my_encoded, encoded) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/protojson.py b/endpoints/internal/protorpc/protojson.py new file mode 100644 index 0000000..8e2c94e --- /dev/null +++ b/endpoints/internal/protorpc/protojson.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""JSON support for message types. + +Public classes: + MessageJSONEncoder: JSON encoder for message objects. + +Public functions: + encode_message: Encodes a message in to a JSON string. + decode_message: Merge from a JSON string in to a message. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import base64 +import binascii +import logging + +from . import message_types +from . import messages +from . import util + +__all__ = [ + 'ALTERNATIVE_CONTENT_TYPES', + 'CONTENT_TYPE', + 'MessageJSONEncoder', + 'encode_message', + 'decode_message', + 'ProtoJson', +] + + +def _load_json_module(): + """Try to load a valid json module. + + There are more than one json modules that might be installed. They are + mostly compatible with one another but some versions may be different. + This function attempts to load various json modules in a preferred order. + It does a basic check to guess if a loaded version of json is compatible. + + Returns: + Compatible json module. + + Raises: + ImportError if there are no json modules or the loaded json module is + not compatible with ProtoRPC. + """ + first_import_error = None + for module_name in ['json', + 'simplejson']: + try: + module = __import__(module_name, {}, {}, 'json') + if not hasattr(module, 'JSONEncoder'): + message = ('json library "%s" is not compatible with ProtoRPC' % + module_name) + logging.warning(message) + raise ImportError(message) + else: + return module + except ImportError as err: + if not first_import_error: + first_import_error = err + + logging.error('Must use valid json library (Python 2.6 json or simplejson)') + raise first_import_error +json = _load_json_module() + + +# TODO: Rename this to MessageJsonEncoder. +class MessageJSONEncoder(json.JSONEncoder): + """Message JSON encoder class. + + Extension of JSONEncoder that can build JSON from a message object. + """ + + def __init__(self, protojson_protocol=None, **kwargs): + """Constructor. + + Args: + protojson_protocol: ProtoJson instance. + """ + super(MessageJSONEncoder, self).__init__(**kwargs) + self.__protojson_protocol = protojson_protocol or ProtoJson.get_default() + + def default(self, value): + """Return dictionary instance from a message object. + + Args: + value: Value to get dictionary for. If not encodable, will + call superclasses default method. + """ + if isinstance(value, messages.Enum): + return str(value) + + if six.PY3 and isinstance(value, bytes): + return value.decode('utf8') + + if isinstance(value, messages.Message): + result = {} + for field in value.all_fields(): + item = value.get_assigned_value(field.name) + if item not in (None, [], ()): + result[field.name] = self.__protojson_protocol.encode_field( + field, item) + # Handle unrecognized fields, so they're included when a message is + # decoded then encoded. + for unknown_key in value.all_unrecognized_fields(): + unrecognized_field, _ = value.get_unrecognized_field_info(unknown_key) + result[unknown_key] = unrecognized_field + return result + else: + return super(MessageJSONEncoder, self).default(value) + + +class ProtoJson(object): + """ProtoRPC JSON implementation class. + + Implementation of JSON based protocol used for serializing and deserializing + message objects. Instances of remote.ProtocolConfig constructor or used with + remote.Protocols.add_protocol. See the remote.py module for more details. + """ + + CONTENT_TYPE = 'application/json' + ALTERNATIVE_CONTENT_TYPES = [ + 'application/x-javascript', + 'text/javascript', + 'text/x-javascript', + 'text/x-json', + 'text/json', + ] + + def encode_field(self, field, value): + """Encode a python field value to a JSON value. + + Args: + field: A ProtoRPC field instance. + value: A python value supported by field. + + Returns: + A JSON serializable value appropriate for field. + """ + if isinstance(field, messages.BytesField): + if field.repeated: + value = [base64.b64encode(byte) for byte in value] + else: + value = base64.b64encode(value) + elif isinstance(field, message_types.DateTimeField): + # DateTimeField stores its data as a RFC 3339 compliant string. + if field.repeated: + value = [i.isoformat() for i in value] + else: + value = value.isoformat() + return value + + def encode_message(self, message): + """Encode Message instance to JSON string. + + Args: + Message instance to encode in to JSON string. + + Returns: + String encoding of Message instance in protocol JSON format. + + Raises: + messages.ValidationError if message is not initialized. + """ + message.check_initialized() + + return json.dumps(message, cls=MessageJSONEncoder, protojson_protocol=self) + + def decode_message(self, message_type, encoded_message): + """Merge JSON structure to Message instance. + + Args: + message_type: Message to decode data to. + encoded_message: JSON encoded version of message. + + Returns: + Decoded instance of message_type. + + Raises: + ValueError: If encoded_message is not valid JSON. + messages.ValidationError if merged message is not initialized. + """ + dictionary = json.loads(encoded_message) if encoded_message.strip() else {} + message = self.__decode_dictionary(message_type, dictionary) + message.check_initialized() + return message + + def __find_variant(self, value): + """Find the messages.Variant type that describes this value. + + Args: + value: The value whose variant type is being determined. + + Returns: + The messages.Variant value that best describes value's type, or None if + it's a type we don't know how to handle. + """ + if isinstance(value, bool): + return messages.Variant.BOOL + elif isinstance(value, six.integer_types): + return messages.Variant.INT64 + elif isinstance(value, float): + return messages.Variant.DOUBLE + elif isinstance(value, six.string_types): + return messages.Variant.STRING + elif isinstance(value, (list, tuple)): + # Find the most specific variant that covers all elements. + variant_priority = [None, messages.Variant.INT64, messages.Variant.DOUBLE, + messages.Variant.STRING] + chosen_priority = 0 + for v in value: + variant = self.__find_variant(v) + try: + priority = variant_priority.index(variant) + except IndexError: + priority = -1 + if priority > chosen_priority: + chosen_priority = priority + return variant_priority[chosen_priority] + # Unrecognized type. + return None + + def __decode_dictionary(self, message_type, dictionary): + """Merge dictionary in to message. + + Args: + message: Message to merge dictionary in to. + dictionary: Dictionary to extract information from. Dictionary + is as parsed from JSON. Nested objects will also be dictionaries. + """ + message = message_type() + for key, value in six.iteritems(dictionary): + if value is None: + try: + message.reset(key) + except AttributeError: + pass # This is an unrecognized field, skip it. + continue + + try: + field = message.field_by_name(key) + except KeyError: + # Save unknown values. + variant = self.__find_variant(value) + if variant: + if key.isdigit(): + key = int(key) + message.set_unrecognized_field(key, value, variant) + else: + logging.warning('No variant found for unrecognized field: %s', key) + continue + + # Normalize values in to a list. + if isinstance(value, list): + if not value: + continue + else: + value = [value] + + valid_value = [] + for item in value: + valid_value.append(self.decode_field(field, item)) + + if field.repeated: + existing_value = getattr(message, field.name) + setattr(message, field.name, valid_value) + else: + setattr(message, field.name, valid_value[-1]) + return message + + def decode_field(self, field, value): + """Decode a JSON value to a python value. + + Args: + field: A ProtoRPC field instance. + value: A serialized JSON value. + + Return: + A Python value compatible with field. + """ + if isinstance(field, messages.EnumField): + try: + return field.type(value) + except TypeError: + raise messages.DecodeError('Invalid enum value "%s"' % (value or '')) + + elif isinstance(field, messages.BytesField): + try: + return base64.b64decode(value) + except (binascii.Error, TypeError) as err: + raise messages.DecodeError('Base64 decoding error: %s' % err) + + elif isinstance(field, message_types.DateTimeField): + try: + return util.decode_datetime(value) + except ValueError as err: + raise messages.DecodeError(err) + + elif (isinstance(field, messages.MessageField) and + issubclass(field.type, messages.Message)): + return self.__decode_dictionary(field.type, value) + + elif (isinstance(field, messages.FloatField) and + isinstance(value, (six.integer_types, six.string_types))): + try: + return float(value) + except: + pass + + elif (isinstance(field, messages.IntegerField) and + isinstance(value, six.string_types)): + try: + return int(value) + except: + pass + + return value + + @staticmethod + def get_default(): + """Get default instanceof ProtoJson.""" + try: + return ProtoJson.__default + except AttributeError: + ProtoJson.__default = ProtoJson() + return ProtoJson.__default + + @staticmethod + def set_default(protocol): + """Set the default instance of ProtoJson. + + Args: + protocol: A ProtoJson instance. + """ + if not isinstance(protocol, ProtoJson): + raise TypeError('Expected protocol of type ProtoJson') + ProtoJson.__default = protocol + +CONTENT_TYPE = ProtoJson.CONTENT_TYPE + +ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES + +encode_message = ProtoJson.get_default().encode_message + +decode_message = ProtoJson.get_default().decode_message diff --git a/endpoints/internal/protorpc/protojson_test.py b/endpoints/internal/protorpc/protojson_test.py new file mode 100644 index 0000000..b71f93f --- /dev/null +++ b/endpoints/internal/protorpc/protojson_test.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.protojson.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import datetime +import imp +import sys +import unittest + +from protorpc import message_types +from protorpc import messages +from protorpc import protojson +from protorpc import test_util + +try: + import json +except ImportError: + import simplejson as json + + +class CustomField(messages.MessageField): + """Custom MessageField class.""" + + type = int + message_type = message_types.VoidMessage + + def __init__(self, number, **kwargs): + super(CustomField, self).__init__(self.message_type, number, **kwargs) + + def value_to_message(self, value): + return self.message_type() + + +class MyMessage(messages.Message): + """Test message containing various types.""" + + class Color(messages.Enum): + + RED = 1 + GREEN = 2 + BLUE = 3 + + class Nested(messages.Message): + + nested_value = messages.StringField(1) + + a_string = messages.StringField(2) + an_integer = messages.IntegerField(3) + a_float = messages.FloatField(4) + a_boolean = messages.BooleanField(5) + an_enum = messages.EnumField(Color, 6) + a_nested = messages.MessageField(Nested, 7) + a_repeated = messages.IntegerField(8, repeated=True) + a_repeated_float = messages.FloatField(9, repeated=True) + a_datetime = message_types.DateTimeField(10) + a_repeated_datetime = message_types.DateTimeField(11, repeated=True) + a_custom = CustomField(12) + a_repeated_custom = CustomField(13, repeated=True) + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = protojson + + +# TODO(rafek): Convert this test to the compliance test in test_util. +class ProtojsonTest(test_util.TestCase, + test_util.ProtoConformanceTestBase): + """Test JSON encoding and decoding.""" + + PROTOLIB = protojson + + def CompareEncoded(self, expected_encoded, actual_encoded): + """JSON encoding will be laundered to remove string differences.""" + self.assertEquals(json.loads(expected_encoded), + json.loads(actual_encoded)) + + encoded_empty_message = '{}' + + encoded_partial = """{ + "double_value": 1.23, + "int64_value": -100000000000, + "int32_value": 1020, + "string_value": "a string", + "enum_value": "VAL2" + } + """ + + encoded_full = """{ + "double_value": 1.23, + "float_value": -2.5, + "int64_value": -100000000000, + "uint64_value": 102020202020, + "int32_value": 1020, + "bool_value": true, + "string_value": "a string\u044f", + "bytes_value": "YSBieXRlc//+", + "enum_value": "VAL2" + } + """ + + encoded_repeated = """{ + "double_value": [1.23, 2.3], + "float_value": [-2.5, 0.5], + "int64_value": [-100000000000, 20], + "uint64_value": [102020202020, 10], + "int32_value": [1020, 718], + "bool_value": [true, false], + "string_value": ["a string\u044f", "another string"], + "bytes_value": ["YSBieXRlc//+", "YW5vdGhlciBieXRlcw=="], + "enum_value": ["VAL2", "VAL1"] + } + """ + + encoded_nested = """{ + "nested": { + "a_value": "a string" + } + } + """ + + encoded_repeated_nested = """{ + "repeated_nested": [{"a_value": "a string"}, + {"a_value": "another string"}] + } + """ + + unexpected_tag_message = '{"unknown": "value"}' + + encoded_default_assigned = '{"a_value": "a default"}' + + encoded_nested_empty = '{"nested": {}}' + + encoded_repeated_nested_empty = '{"repeated_nested": [{}, {}]}' + + encoded_extend_message = '{"int64_value": [400, 50, 6000]}' + + encoded_string_types = '{"string_value": "Latin"}' + + encoded_invalid_enum = '{"enum_value": "undefined"}' + + def testConvertIntegerToFloat(self): + """Test that integers passed in to float fields are converted. + + This is necessary because JSON outputs integers for numbers with 0 decimals. + """ + message = protojson.decode_message(MyMessage, '{"a_float": 10}') + + self.assertTrue(isinstance(message.a_float, float)) + self.assertEquals(10.0, message.a_float) + + def testConvertStringToNumbers(self): + """Test that strings passed to integer fields are converted.""" + message = protojson.decode_message(MyMessage, + """{"an_integer": "10", + "a_float": "3.5", + "a_repeated": ["1", "2"], + "a_repeated_float": ["1.5", "2", 10] + }""") + + self.assertEquals(MyMessage(an_integer=10, + a_float=3.5, + a_repeated=[1, 2], + a_repeated_float=[1.5, 2.0, 10.0]), + message) + + def testWrongTypeAssignment(self): + """Test when wrong type is assigned to a field.""" + self.assertRaises(messages.ValidationError, + protojson.decode_message, + MyMessage, '{"a_string": 10}') + self.assertRaises(messages.ValidationError, + protojson.decode_message, + MyMessage, '{"an_integer": 10.2}') + self.assertRaises(messages.ValidationError, + protojson.decode_message, + MyMessage, '{"an_integer": "10.2"}') + + def testNumericEnumeration(self): + """Test that numbers work for enum values.""" + message = protojson.decode_message(MyMessage, '{"an_enum": 2}') + + expected_message = MyMessage() + expected_message.an_enum = MyMessage.Color.GREEN + + self.assertEquals(expected_message, message) + + def testNumericEnumerationNegativeTest(self): + """Test with an invalid number for the enum value.""" + self.assertRaisesRegexp( + messages.DecodeError, + 'Invalid enum value "89"', + protojson.decode_message, + MyMessage, + '{"an_enum": 89}') + + def testAlphaEnumeration(self): + """Test that alpha enum values work.""" + message = protojson.decode_message(MyMessage, '{"an_enum": "RED"}') + + expected_message = MyMessage() + expected_message.an_enum = MyMessage.Color.RED + + self.assertEquals(expected_message, message) + + def testAlphaEnumerationNegativeTest(self): + """The alpha enum value is invalid.""" + self.assertRaisesRegexp( + messages.DecodeError, + 'Invalid enum value "IAMINVALID"', + protojson.decode_message, + MyMessage, + '{"an_enum": "IAMINVALID"}') + + def testEnumerationNegativeTestWithEmptyString(self): + """The enum value is an empty string.""" + self.assertRaisesRegexp( + messages.DecodeError, + 'Invalid enum value ""', + protojson.decode_message, + MyMessage, + '{"an_enum": ""}') + + def testNullValues(self): + """Test that null values overwrite existing values.""" + self.assertEquals(MyMessage(), + protojson.decode_message(MyMessage, + ('{"an_integer": null,' + ' "a_nested": null,' + ' "an_enum": null' + '}'))) + + def testEmptyList(self): + """Test that empty lists are ignored.""" + self.assertEquals(MyMessage(), + protojson.decode_message(MyMessage, + '{"a_repeated": []}')) + + def testNotJSON(self): + """Test error when string is not valid JSON.""" + self.assertRaises(ValueError, + protojson.decode_message, MyMessage, '{this is not json}') + + def testDoNotEncodeStrangeObjects(self): + """Test trying to encode a strange object. + + The main purpose of this test is to complete coverage. It ensures that + the default behavior of the JSON encoder is preserved when someone tries to + serialized an unexpected type. + """ + class BogusObject(object): + + def check_initialized(self): + pass + + self.assertRaises(TypeError, + protojson.encode_message, + BogusObject()) + + def testMergeEmptyString(self): + """Test merging the empty or space only string.""" + message = protojson.decode_message(test_util.OptionalMessage, '') + self.assertEquals(test_util.OptionalMessage(), message) + + message = protojson.decode_message(test_util.OptionalMessage, ' ') + self.assertEquals(test_util.OptionalMessage(), message) + + def testMeregeInvalidEmptyMessage(self): + self.assertRaisesWithRegexpMatch(messages.ValidationError, + 'Message NestedMessage is missing ' + 'required field a_value', + self.PROTOLIB.decode_message, + test_util.NestedMessage, + '') + + def testProtojsonUnrecognizedFieldName(self): + """Test that unrecognized fields are saved and can be accessed.""" + decoded = protojson.decode_message(MyMessage, + ('{"an_integer": 1, "unknown_val": 2}')) + self.assertEquals(decoded.an_integer, 1) + self.assertEquals(1, len(decoded.all_unrecognized_fields())) + self.assertEquals('unknown_val', decoded.all_unrecognized_fields()[0]) + self.assertEquals((2, messages.Variant.INT64), + decoded.get_unrecognized_field_info('unknown_val')) + + def testProtojsonUnrecognizedFieldNumber(self): + """Test that unrecognized fields are saved and can be accessed.""" + decoded = protojson.decode_message( + MyMessage, + '{"an_integer": 1, "1001": "unknown", "-123": "negative", ' + '"456_mixed": 2}') + self.assertEquals(decoded.an_integer, 1) + self.assertEquals(3, len(decoded.all_unrecognized_fields())) + self.assertTrue(1001 in decoded.all_unrecognized_fields()) + self.assertEquals(('unknown', messages.Variant.STRING), + decoded.get_unrecognized_field_info(1001)) + self.assertTrue('-123' in decoded.all_unrecognized_fields()) + self.assertEquals(('negative', messages.Variant.STRING), + decoded.get_unrecognized_field_info('-123')) + self.assertTrue('456_mixed' in decoded.all_unrecognized_fields()) + self.assertEquals((2, messages.Variant.INT64), + decoded.get_unrecognized_field_info('456_mixed')) + + def testProtojsonUnrecognizedNull(self): + """Test that unrecognized fields that are None are skipped.""" + decoded = protojson.decode_message( + MyMessage, + '{"an_integer": 1, "unrecognized_null": null}') + self.assertEquals(decoded.an_integer, 1) + self.assertEquals(decoded.all_unrecognized_fields(), []) + + def testUnrecognizedFieldVariants(self): + """Test that unrecognized fields are mapped to the right variants.""" + for encoded, expected_variant in ( + ('{"an_integer": 1, "unknown_val": 2}', messages.Variant.INT64), + ('{"an_integer": 1, "unknown_val": 2.0}', messages.Variant.DOUBLE), + ('{"an_integer": 1, "unknown_val": "string value"}', + messages.Variant.STRING), + ('{"an_integer": 1, "unknown_val": [1, 2, 3]}', messages.Variant.INT64), + ('{"an_integer": 1, "unknown_val": [1, 2.0, 3]}', + messages.Variant.DOUBLE), + ('{"an_integer": 1, "unknown_val": [1, "foo", 3]}', + messages.Variant.STRING), + ('{"an_integer": 1, "unknown_val": true}', messages.Variant.BOOL)): + decoded = protojson.decode_message(MyMessage, encoded) + self.assertEquals(decoded.an_integer, 1) + self.assertEquals(1, len(decoded.all_unrecognized_fields())) + self.assertEquals('unknown_val', decoded.all_unrecognized_fields()[0]) + _, decoded_variant = decoded.get_unrecognized_field_info('unknown_val') + self.assertEquals(expected_variant, decoded_variant) + + def testDecodeDateTime(self): + for datetime_string, datetime_vals in ( + ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)), + ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): + message = protojson.decode_message( + MyMessage, '{"a_datetime": "%s"}' % datetime_string) + expected_message = MyMessage( + a_datetime=datetime.datetime(*datetime_vals)) + + self.assertEquals(expected_message, message) + + def testDecodeInvalidDateTime(self): + self.assertRaises(messages.DecodeError, protojson.decode_message, + MyMessage, '{"a_datetime": "invalid"}') + + def testEncodeDateTime(self): + for datetime_string, datetime_vals in ( + ('2012-09-30T15:31:50.262000', (2012, 9, 30, 15, 31, 50, 262000)), + ('2012-09-30T15:31:50.262123', (2012, 9, 30, 15, 31, 50, 262123)), + ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): + decoded_message = protojson.encode_message( + MyMessage(a_datetime=datetime.datetime(*datetime_vals))) + expected_decoding = '{"a_datetime": "%s"}' % datetime_string + self.CompareEncoded(expected_decoding, decoded_message) + + def testDecodeRepeatedDateTime(self): + message = protojson.decode_message( + MyMessage, + '{"a_repeated_datetime": ["2012-09-30T15:31:50.262", ' + '"2010-01-21T09:52:00", "2000-01-01T01:00:59.999999"]}') + expected_message = MyMessage( + a_repeated_datetime=[ + datetime.datetime(2012, 9, 30, 15, 31, 50, 262000), + datetime.datetime(2010, 1, 21, 9, 52), + datetime.datetime(2000, 1, 1, 1, 0, 59, 999999)]) + + self.assertEquals(expected_message, message) + + def testDecodeCustom(self): + message = protojson.decode_message(MyMessage, '{"a_custom": 1}') + self.assertEquals(MyMessage(a_custom=1), message) + + def testDecodeInvalidCustom(self): + self.assertRaises(messages.ValidationError, protojson.decode_message, + MyMessage, '{"a_custom": "invalid"}') + + def testEncodeCustom(self): + decoded_message = protojson.encode_message(MyMessage(a_custom=1)) + self.CompareEncoded('{"a_custom": 1}', decoded_message) + + def testDecodeRepeatedCustom(self): + message = protojson.decode_message( + MyMessage, '{"a_repeated_custom": [1, 2, 3]}') + self.assertEquals(MyMessage(a_repeated_custom=[1, 2, 3]), message) + + def testDecodeBadBase64BytesField(self): + """Test decoding improperly encoded base64 bytes value.""" + self.assertRaisesWithRegexpMatch( + messages.DecodeError, + 'Base64 decoding error: Incorrect padding', + protojson.decode_message, + test_util.OptionalMessage, + '{"bytes_value": "abcdefghijklmnopq"}') + + +class CustomProtoJson(protojson.ProtoJson): + + def encode_field(self, field, value): + return '{encoded}' + value + + def decode_field(self, field, value): + return '{decoded}' + value + + +class CustomProtoJsonTest(test_util.TestCase): + """Tests for serialization overriding functionality.""" + + def setUp(self): + self.protojson = CustomProtoJson() + + def testEncode(self): + self.assertEqual(u'{"a_string": "{encoded}xyz"}', + self.protojson.encode_message(MyMessage(a_string=u'xyz'))) + + def testDecode(self): + self.assertEqual( + MyMessage(a_string=u'{decoded}xyz'), + self.protojson.decode_message(MyMessage, u'{"a_string": "xyz"}')) + + def testDecodeEmptyMessage(self): + self.assertEqual( + MyMessage(a_string=u'{decoded}'), + self.protojson.decode_message(MyMessage, u'{"a_string": ""}')) + + def testDefault(self): + self.assertTrue(protojson.ProtoJson.get_default(), + protojson.ProtoJson.get_default()) + + instance = CustomProtoJson() + protojson.ProtoJson.set_default(instance) + self.assertTrue(instance is protojson.ProtoJson.get_default()) + + +class InvalidJsonModule(object): + pass + + +class ValidJsonModule(object): + class JSONEncoder(object): + pass + + +class TestJsonDependencyLoading(test_util.TestCase): + """Test loading various implementations of json.""" + + def get_import(self): + """Get __import__ method. + + Returns: + The current __import__ method. + """ + if isinstance(__builtins__, dict): + return __builtins__['__import__'] + else: + return __builtins__.__import__ + + def set_import(self, new_import): + """Set __import__ method. + + Args: + new_import: Function to replace __import__. + """ + if isinstance(__builtins__, dict): + __builtins__['__import__'] = new_import + else: + __builtins__.__import__ = new_import + + def setUp(self): + """Save original import function.""" + self.simplejson = sys.modules.pop('simplejson', None) + self.json = sys.modules.pop('json', None) + self.original_import = self.get_import() + def block_all_jsons(name, *args, **kwargs): + if 'json' in name: + if name in sys.modules: + module = sys.modules[name] + module.name = name + return module + raise ImportError('Unable to find %s' % name) + else: + return self.original_import(name, *args, **kwargs) + self.set_import(block_all_jsons) + + def tearDown(self): + """Restore original import functions and any loaded modules.""" + + def reset_module(name, module): + if module: + sys.modules[name] = module + else: + sys.modules.pop(name, None) + reset_module('simplejson', self.simplejson) + reset_module('json', self.json) + imp.reload(protojson) + + def testLoadProtojsonWithValidJsonModule(self): + """Test loading protojson module with a valid json dependency.""" + sys.modules['json'] = ValidJsonModule + + # This will cause protojson to reload with the default json module + # instead of simplejson. + imp.reload(protojson) + self.assertEquals('json', protojson.json.name) + + def testLoadProtojsonWithSimplejsonModule(self): + """Test loading protojson module with simplejson dependency.""" + sys.modules['simplejson'] = ValidJsonModule + + # This will cause protojson to reload with the default json module + # instead of simplejson. + imp.reload(protojson) + self.assertEquals('simplejson', protojson.json.name) + + def testLoadProtojsonWithInvalidJsonModule(self): + """Loading protojson module with an invalid json defaults to simplejson.""" + sys.modules['json'] = InvalidJsonModule + sys.modules['simplejson'] = ValidJsonModule + + # Ignore bad module and default back to simplejson. + imp.reload(protojson) + self.assertEquals('simplejson', protojson.json.name) + + def testLoadProtojsonWithInvalidJsonModuleAndNoSimplejson(self): + """Loading protojson module with invalid json and no simplejson.""" + sys.modules['json'] = InvalidJsonModule + + # Bad module without simplejson back raises errors. + self.assertRaisesWithRegexpMatch( + ImportError, + 'json library "json" is not compatible with ProtoRPC', + imp.reload, + protojson) + + def testLoadProtojsonWithNoJsonModules(self): + """Loading protojson module with invalid json and no simplejson.""" + # No json modules raise the first exception. + self.assertRaisesWithRegexpMatch( + ImportError, + 'Unable to find json', + imp.reload, + protojson) + + +if __name__ == '__main__': + unittest.main() diff --git a/endpoints/internal/protorpc/protorpc_test.proto b/endpoints/internal/protorpc/protorpc_test.proto new file mode 100644 index 0000000..50d76e0 --- /dev/null +++ b/endpoints/internal/protorpc/protorpc_test.proto @@ -0,0 +1,83 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package protorpc; + +// Message used to nest inside another message. +message NestedMessage { + required string a_value = 1; +} + +// Message that contains nested messages. +message HasNestedMessage { + optional NestedMessage nested = 1; + repeated NestedMessage repeated_nested = 2; +} + +message HasDefault { + optional string a_value = 1 [default="a default"]; +} + +// Message that contains all variants as optional fields. +message OptionalMessage { + enum SimpleEnum { + VAL1 = 1; + VAL2 = 2; + } + + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional bool bool_value = 6; + optional string string_value = 7; + optional bytes bytes_value = 8; + optional SimpleEnum enum_value = 10; + + // TODO(rafek): Add support for these variants. + // optional uint32 uint32_value = 9; + // optional sint32 sint32_value = 11; + // optional sint64 sint64_value = 12; +} + +// Message that contains all variants as repeated fields. +message RepeatedMessage { + enum SimpleEnum { + VAL1 = 1; + VAL2 = 2; + } + + repeated double double_value = 1; + repeated float float_value = 2; + repeated int64 int64_value = 3; + repeated uint64 uint64_value = 4; + repeated int32 int32_value = 5; + repeated bool bool_value = 6; + repeated string string_value = 7; + repeated bytes bytes_value = 8; + repeated SimpleEnum enum_value = 10; + + // TODO(rafek): Add support for these variants. + // repeated uint32 uint32_value = 9; + // repeated sint32 sint32_value = 11; + // repeated sint64 sint64_value = 12; +} + +// Message that has nested message with all optional fields. +message HasOptionalNestedMessage { + optional OptionalMessage nested = 1; + repeated OptionalMessage repeated_nested = 2; +} diff --git a/endpoints/internal/protorpc/protorpc_test_pb2.py b/endpoints/internal/protorpc/protorpc_test_pb2.py new file mode 100644 index 0000000..1dc3852 --- /dev/null +++ b/endpoints/internal/protorpc/protorpc_test_pb2.py @@ -0,0 +1,405 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT (except the imports)! + +# Replace auto generated imports with .non_sdk_imports manually! +# Do the replacement and copy this comment everytime! +from .non_sdk_imports import descriptor +from .non_sdk_imports import message +from .non_sdk_imports import reflection +from .non_sdk_imports import descriptor_pb2 +import six + +# @@protoc_insertion_point(imports) + + + +DESCRIPTOR = descriptor.FileDescriptor( + name='protorpc_test.proto', + package='protorpc', + serialized_pb='\n\x13protorpc_test.proto\x12\x08protorpc\" \n\rNestedMessage\x12\x0f\n\x07\x61_value\x18\x01 \x02(\t\"m\n\x10HasNestedMessage\x12\'\n\x06nested\x18\x01 \x01(\x0b\x32\x17.protorpc.NestedMessage\x12\x30\n\x0frepeated_nested\x18\x02 \x03(\x0b\x32\x17.protorpc.NestedMessage\"(\n\nHasDefault\x12\x1a\n\x07\x61_value\x18\x01 \x01(\t:\ta default\"\x97\x02\n\x0fOptionalMessage\x12\x14\n\x0c\x64ouble_value\x18\x01 \x01(\x01\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x13\n\x0bint64_value\x18\x03 \x01(\x03\x12\x14\n\x0cuint64_value\x18\x04 \x01(\x04\x12\x13\n\x0bint32_value\x18\x05 \x01(\x05\x12\x12\n\nbool_value\x18\x06 \x01(\x08\x12\x14\n\x0cstring_value\x18\x07 \x01(\t\x12\x13\n\x0b\x62ytes_value\x18\x08 \x01(\x0c\x12\x38\n\nenum_value\x18\n \x01(\x0e\x32$.protorpc.OptionalMessage.SimpleEnum\" \n\nSimpleEnum\x12\x08\n\x04VAL1\x10\x01\x12\x08\n\x04VAL2\x10\x02\"\x97\x02\n\x0fRepeatedMessage\x12\x14\n\x0c\x64ouble_value\x18\x01 \x03(\x01\x12\x13\n\x0b\x66loat_value\x18\x02 \x03(\x02\x12\x13\n\x0bint64_value\x18\x03 \x03(\x03\x12\x14\n\x0cuint64_value\x18\x04 \x03(\x04\x12\x13\n\x0bint32_value\x18\x05 \x03(\x05\x12\x12\n\nbool_value\x18\x06 \x03(\x08\x12\x14\n\x0cstring_value\x18\x07 \x03(\t\x12\x13\n\x0b\x62ytes_value\x18\x08 \x03(\x0c\x12\x38\n\nenum_value\x18\n \x03(\x0e\x32$.protorpc.RepeatedMessage.SimpleEnum\" \n\nSimpleEnum\x12\x08\n\x04VAL1\x10\x01\x12\x08\n\x04VAL2\x10\x02\"y\n\x18HasOptionalNestedMessage\x12)\n\x06nested\x18\x01 \x01(\x0b\x32\x19.protorpc.OptionalMessage\x12\x32\n\x0frepeated_nested\x18\x02 \x03(\x0b\x32\x19.protorpc.OptionalMessage') + + + +_OPTIONALMESSAGE_SIMPLEENUM = descriptor.EnumDescriptor( + name='SimpleEnum', + full_name='protorpc.OptionalMessage.SimpleEnum', + filename=None, + file=DESCRIPTOR, + values=[ + descriptor.EnumValueDescriptor( + name='VAL1', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='VAL2', index=1, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=468, + serialized_end=500, +) + +_REPEATEDMESSAGE_SIMPLEENUM = descriptor.EnumDescriptor( + name='SimpleEnum', + full_name='protorpc.RepeatedMessage.SimpleEnum', + filename=None, + file=DESCRIPTOR, + values=[ + descriptor.EnumValueDescriptor( + name='VAL1', index=0, number=1, + options=None, + type=None), + descriptor.EnumValueDescriptor( + name='VAL2', index=1, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=468, + serialized_end=500, +) + + +_NESTEDMESSAGE = descriptor.Descriptor( + name='NestedMessage', + full_name='protorpc.NestedMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='a_value', full_name='protorpc.NestedMessage.a_value', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=six.text_type("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + extension_ranges=[], + serialized_start=33, + serialized_end=65, +) + + +_HASNESTEDMESSAGE = descriptor.Descriptor( + name='HasNestedMessage', + full_name='protorpc.HasNestedMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='nested', full_name='protorpc.HasNestedMessage.nested', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='repeated_nested', full_name='protorpc.HasNestedMessage.repeated_nested', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + extension_ranges=[], + serialized_start=67, + serialized_end=176, +) + + +_HASDEFAULT = descriptor.Descriptor( + name='HasDefault', + full_name='protorpc.HasDefault', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='a_value', full_name='protorpc.HasDefault.a_value', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=True, default_value=six.text_type("a default", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + extension_ranges=[], + serialized_start=178, + serialized_end=218, +) + + +_OPTIONALMESSAGE = descriptor.Descriptor( + name='OptionalMessage', + full_name='protorpc.OptionalMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='double_value', full_name='protorpc.OptionalMessage.double_value', index=0, + number=1, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='float_value', full_name='protorpc.OptionalMessage.float_value', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='int64_value', full_name='protorpc.OptionalMessage.int64_value', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='uint64_value', full_name='protorpc.OptionalMessage.uint64_value', index=3, + number=4, type=4, cpp_type=4, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='int32_value', full_name='protorpc.OptionalMessage.int32_value', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bool_value', full_name='protorpc.OptionalMessage.bool_value', index=5, + number=6, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='string_value', full_name='protorpc.OptionalMessage.string_value', index=6, + number=7, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=six.text_type("", "utf-8"), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bytes_value', full_name='protorpc.OptionalMessage.bytes_value', index=7, + number=8, type=12, cpp_type=9, label=1, + has_default_value=False, default_value="", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='enum_value', full_name='protorpc.OptionalMessage.enum_value', index=8, + number=10, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _OPTIONALMESSAGE_SIMPLEENUM, + ], + options=None, + is_extendable=False, + extension_ranges=[], + serialized_start=221, + serialized_end=500, +) + + +_REPEATEDMESSAGE = descriptor.Descriptor( + name='RepeatedMessage', + full_name='protorpc.RepeatedMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='double_value', full_name='protorpc.RepeatedMessage.double_value', index=0, + number=1, type=1, cpp_type=5, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='float_value', full_name='protorpc.RepeatedMessage.float_value', index=1, + number=2, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='int64_value', full_name='protorpc.RepeatedMessage.int64_value', index=2, + number=3, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='uint64_value', full_name='protorpc.RepeatedMessage.uint64_value', index=3, + number=4, type=4, cpp_type=4, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='int32_value', full_name='protorpc.RepeatedMessage.int32_value', index=4, + number=5, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bool_value', full_name='protorpc.RepeatedMessage.bool_value', index=5, + number=6, type=8, cpp_type=7, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='string_value', full_name='protorpc.RepeatedMessage.string_value', index=6, + number=7, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='bytes_value', full_name='protorpc.RepeatedMessage.bytes_value', index=7, + number=8, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='enum_value', full_name='protorpc.RepeatedMessage.enum_value', index=8, + number=10, type=14, cpp_type=8, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _REPEATEDMESSAGE_SIMPLEENUM, + ], + options=None, + is_extendable=False, + extension_ranges=[], + serialized_start=503, + serialized_end=782, +) + + +_HASOPTIONALNESTEDMESSAGE = descriptor.Descriptor( + name='HasOptionalNestedMessage', + full_name='protorpc.HasOptionalNestedMessage', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='nested', full_name='protorpc.HasOptionalNestedMessage.nested', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + descriptor.FieldDescriptor( + name='repeated_nested', full_name='protorpc.HasOptionalNestedMessage.repeated_nested', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + extension_ranges=[], + serialized_start=784, + serialized_end=905, +) + +_HASNESTEDMESSAGE.fields_by_name['nested'].message_type = _NESTEDMESSAGE +_HASNESTEDMESSAGE.fields_by_name['repeated_nested'].message_type = _NESTEDMESSAGE +_OPTIONALMESSAGE.fields_by_name['enum_value'].enum_type = _OPTIONALMESSAGE_SIMPLEENUM +_OPTIONALMESSAGE_SIMPLEENUM.containing_type = _OPTIONALMESSAGE; +_REPEATEDMESSAGE.fields_by_name['enum_value'].enum_type = _REPEATEDMESSAGE_SIMPLEENUM +_REPEATEDMESSAGE_SIMPLEENUM.containing_type = _REPEATEDMESSAGE; +_HASOPTIONALNESTEDMESSAGE.fields_by_name['nested'].message_type = _OPTIONALMESSAGE +_HASOPTIONALNESTEDMESSAGE.fields_by_name['repeated_nested'].message_type = _OPTIONALMESSAGE +DESCRIPTOR.message_types_by_name['NestedMessage'] = _NESTEDMESSAGE +DESCRIPTOR.message_types_by_name['HasNestedMessage'] = _HASNESTEDMESSAGE +DESCRIPTOR.message_types_by_name['HasDefault'] = _HASDEFAULT +DESCRIPTOR.message_types_by_name['OptionalMessage'] = _OPTIONALMESSAGE +DESCRIPTOR.message_types_by_name['RepeatedMessage'] = _REPEATEDMESSAGE +DESCRIPTOR.message_types_by_name['HasOptionalNestedMessage'] = _HASOPTIONALNESTEDMESSAGE + +class NestedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + DESCRIPTOR = _NESTEDMESSAGE + + # @@protoc_insertion_point(class_scope:protorpc.NestedMessage) + +class HasNestedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + DESCRIPTOR = _HASNESTEDMESSAGE + + # @@protoc_insertion_point(class_scope:protorpc.HasNestedMessage) + +class HasDefault(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + DESCRIPTOR = _HASDEFAULT + + # @@protoc_insertion_point(class_scope:protorpc.HasDefault) + +class OptionalMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + DESCRIPTOR = _OPTIONALMESSAGE + + # @@protoc_insertion_point(class_scope:protorpc.OptionalMessage) + +class RepeatedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + DESCRIPTOR = _REPEATEDMESSAGE + + # @@protoc_insertion_point(class_scope:protorpc.RepeatedMessage) + +class HasOptionalNestedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + DESCRIPTOR = _HASOPTIONALNESTEDMESSAGE + + # @@protoc_insertion_point(class_scope:protorpc.HasOptionalNestedMessage) + +# @@protoc_insertion_point(module_scope) diff --git a/endpoints/internal/protorpc/protourlencode.py b/endpoints/internal/protorpc/protourlencode.py new file mode 100644 index 0000000..9f6059e --- /dev/null +++ b/endpoints/internal/protorpc/protourlencode.py @@ -0,0 +1,563 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""URL encoding support for messages types. + +Protocol support for URL encoded form parameters. + +Nested Fields: + Nested fields are repesented by dot separated names. For example, consider + the following messages: + + class WebPage(Message): + + title = StringField(1) + tags = StringField(2, repeated=True) + + class WebSite(Message): + + name = StringField(1) + home = MessageField(WebPage, 2) + pages = MessageField(WebPage, 3, repeated=True) + + And consider the object: + + page = WebPage() + page.title = 'Welcome to NewSite 2010' + + site = WebSite() + site.name = 'NewSite 2010' + site.home = page + + The URL encoded representation of this constellation of objects is. + + name=NewSite+2010&home.title=Welcome+to+NewSite+2010 + + An object that exists but does not have any state can be represented with + a reference to its name alone with no value assigned to it. For example: + + page = WebSite() + page.name = 'My Empty Site' + page.home = WebPage() + + is represented as: + + name=My+Empty+Site&home= + + This represents a site with an empty uninitialized home page. + +Repeated Fields: + Repeated fields are represented by the name of and the index of each value + separated by a dash. For example, consider the following message: + + home = Page() + home.title = 'Nome' + + news = Page() + news.title = 'News' + news.tags = ['news', 'articles'] + + instance = WebSite() + instance.name = 'Super fun site' + instance.pages = [home, news, preferences] + + An instance of this message can be represented as: + + name=Super+fun+site&page-0.title=Home&pages-1.title=News&... + pages-1.tags-0=new&pages-1.tags-1=articles + +Helper classes: + + URLEncodedRequestBuilder: Used for encapsulating the logic used for building + a request message from a URL encoded RPC. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import cgi +import re +import urllib + +from . import message_types +from . import messages +from . import util + +__all__ = ['CONTENT_TYPE', + 'URLEncodedRequestBuilder', + 'encode_message', + 'decode_message', + ] + +CONTENT_TYPE = 'application/x-www-form-urlencoded' + +_FIELD_NAME_REGEX = re.compile(r'^([a-zA-Z_][a-zA-Z_0-9]*)(?:-([0-9]+))?$') + + +class URLEncodedRequestBuilder(object): + """Helper that encapsulates the logic used for building URL encoded messages. + + This helper is used to map query parameters from a URL encoded RPC to a + message instance. + """ + + @util.positional(2) + def __init__(self, message, prefix=''): + """Constructor. + + Args: + message: Message instance to build from parameters. + prefix: Prefix expected at the start of valid parameters. + """ + self.__parameter_prefix = prefix + + # The empty tuple indicates the root message, which has no path. + # __messages is a full cache that makes it very easy to look up message + # instances by their paths. See make_path for details about what a path + # is. + self.__messages = {(): message} + + # This is a cache that stores paths which have been checked for + # correctness. Correctness means that an index is present for repeated + # fields on the path and absent for non-repeated fields. The cache is + # also used to check that indexes are added in the right order so that + # dicontiguous ranges of indexes are ignored. + self.__checked_indexes = set([()]) + + def make_path(self, parameter_name): + """Parse a parameter name and build a full path to a message value. + + The path of a method is a tuple of 2-tuples describing the names and + indexes within repeated fields from the root message (the message being + constructed by the builder) to an arbitrarily nested message within it. + + Each 2-tuple node of a path (name, index) is: + name: The name of the field that refers to the message instance. + index: The index within a repeated field that refers to the message + instance, None if not a repeated field. + + For example, consider: + + class VeryInner(messages.Message): + ... + + class Inner(messages.Message): + + very_inner = messages.MessageField(VeryInner, 1, repeated=True) + + class Outer(messages.Message): + + inner = messages.MessageField(Inner, 1) + + If this builder is building an instance of Outer, that instance is + referred to in the URL encoded parameters without a path. Therefore + its path is (). + + The child 'inner' is referred to by its path (('inner', None)). + + The first child of repeated field 'very_inner' on the Inner instance + is referred to by (('inner', None), ('very_inner', 0)). + + Examples: + # Correct reference to model where nation is a Message, district is + # repeated Message and county is any not repeated field type. + >>> make_path('nation.district-2.county') + (('nation', None), ('district', 2), ('county', None)) + + # Field is not part of model. + >>> make_path('nation.made_up_field') + None + + # nation field is not repeated and index provided. + >>> make_path('nation-1') + None + + # district field is repeated and no index provided. + >>> make_path('nation.district') + None + + Args: + parameter_name: Name of query parameter as passed in from the request. + in order to make a path, this parameter_name must point to a valid + field within the message structure. Nodes of the path that refer to + repeated fields must be indexed with a number, non repeated nodes must + not have an index. + + Returns: + Parsed version of the parameter_name as a tuple of tuples: + attribute: Name of attribute associated with path. + index: Postitive integer index when it is a repeated field, else None. + Will return None if the parameter_name does not have the right prefix, + does not point to a field within the message structure, does not have + an index if it is a repeated field or has an index but is not a repeated + field. + """ + if parameter_name.startswith(self.__parameter_prefix): + parameter_name = parameter_name[len(self.__parameter_prefix):] + else: + return None + + path = [] + name = [] + message_type = type(self.__messages[()]) # Get root message. + + for item in parameter_name.split('.'): + # This will catch sub_message.real_message_field.not_real_field + if not message_type: + return None + + item_match = _FIELD_NAME_REGEX.match(item) + if not item_match: + return None + attribute = item_match.group(1) + index = item_match.group(2) + if index: + index = int(index) + + try: + field = message_type.field_by_name(attribute) + except KeyError: + return None + + if field.repeated != (index is not None): + return None + + if isinstance(field, messages.MessageField): + message_type = field.message_type + else: + message_type = None + + # Path is valid so far. Append node and continue. + path.append((attribute, index)) + + return tuple(path) + + def __check_index(self, parent_path, name, index): + """Check correct index use and value relative to a given path. + + Check that for a given path the index is present for repeated fields + and that it is in range for the existing list that it will be inserted + in to or appended to. + + Args: + parent_path: Path to check against name and index. + name: Name of field to check for existance. + index: Index to check. If field is repeated, should be a number within + range of the length of the field, or point to the next item for + appending. + """ + # Don't worry about non-repeated fields. + # It's also ok if index is 0 because that means next insert will append. + if not index: + return True + + parent = self.__messages.get(parent_path, None) + value_list = getattr(parent, name, None) + # If the list does not exist then the index should be 0. Since it is + # not, path is not valid. + if not value_list: + return False + + # The index must either point to an element of the list or to the tail. + return len(value_list) >= index + + def __check_indexes(self, path): + """Check that all indexes are valid and in the right order. + + This method must iterate over the path and check that all references + to indexes point to an existing message or to the end of the list, meaning + the next value should be appended to the repeated field. + + Args: + path: Path to check indexes for. Tuple of 2-tuples (name, index). See + make_path for more information. + + Returns: + True if all the indexes of the path are within range, else False. + """ + if path in self.__checked_indexes: + return True + + # Start with the root message. + parent_path = () + + for name, index in path: + next_path = parent_path + ((name, index),) + # First look in the checked indexes cache. + if next_path not in self.__checked_indexes: + if not self.__check_index(parent_path, name, index): + return False + self.__checked_indexes.add(next_path) + + parent_path = next_path + + return True + + def __get_or_create_path(self, path): + """Get a message from the messages cache or create it and add it. + + This method will also create any parent messages based on the path. + + When a new instance of a given message is created, it is stored in + __message by its path. + + Args: + path: Path of message to get. Path must be valid, in other words + __check_index(path) returns true. Tuple of 2-tuples (name, index). + See make_path for more information. + + Returns: + Message instance if the field being pointed to by the path is a + message, else will return None for non-message fields. + """ + message = self.__messages.get(path, None) + if message: + return message + + parent_path = () + parent = self.__messages[()] # Get the root object + + for name, index in path: + field = parent.field_by_name(name) + next_path = parent_path + ((name, index),) + next_message = self.__messages.get(next_path, None) + if next_message is None: + next_message = field.message_type() + self.__messages[next_path] = next_message + if not field.repeated: + setattr(parent, field.name, next_message) + else: + list_value = getattr(parent, field.name, None) + if list_value is None: + setattr(parent, field.name, [next_message]) + else: + list_value.append(next_message) + + parent_path = next_path + parent = next_message + + return parent + + def add_parameter(self, parameter, values): + """Add a single parameter. + + Adds a single parameter and its value to the request message. + + Args: + parameter: Query string parameter to map to request. + values: List of values to assign to request message. + + Returns: + True if parameter was valid and added to the message, else False. + + Raises: + DecodeError if the parameter refers to a valid field, and the values + parameter does not have one and only one value. Non-valid query + parameters may have multiple values and should not cause an error. + """ + path = self.make_path(parameter) + + if not path: + return False + + # Must check that all indexes of all items in the path are correct before + # instantiating any of them. For example, consider: + # + # class Repeated(object): + # ... + # + # class Inner(object): + # + # repeated = messages.MessageField(Repeated, 1, repeated=True) + # + # class Outer(object): + # + # inner = messages.MessageField(Inner, 1) + # + # instance = Outer() + # builder = URLEncodedRequestBuilder(instance) + # builder.add_parameter('inner.repeated') + # + # assert not hasattr(instance, 'inner') + # + # The check is done relative to the instance of Outer pass in to the + # constructor of the builder. This instance is not referred to at all + # because all names are assumed to be relative to it. + # + # The 'repeated' part of the path is not correct because it is missing an + # index. Because it is missing an index, it should not create an instance + # of Repeated. In this case add_parameter will return False and have no + # side effects. + # + # A correct path that would cause a new Inner instance to be inserted at + # instance.inner and a new Repeated instance to be appended to the + # instance.inner.repeated list would be 'inner.repeated-0'. + if not self.__check_indexes(path): + return False + + # Ok to build objects. + parent_path = path[:-1] + parent = self.__get_or_create_path(parent_path) + name, index = path[-1] + field = parent.field_by_name(name) + + if len(values) != 1: + raise messages.DecodeError( + 'Found repeated values for field %s.' % field.name) + + value = values[0] + + if isinstance(field, messages.IntegerField): + converted_value = int(value) + elif isinstance(field, message_types.DateTimeField): + try: + converted_value = util.decode_datetime(value) + except ValueError as e: + raise messages.DecodeError(e) + elif isinstance(field, messages.MessageField): + # Just make sure it's instantiated. Assignment to field or + # appending to list is done in __get_or_create_path. + self.__get_or_create_path(path) + return True + elif isinstance(field, messages.StringField): + converted_value = value.decode('utf-8') + elif isinstance(field, messages.BooleanField): + converted_value = value.lower() == 'true' and True or False + else: + try: + converted_value = field.type(value) + except TypeError: + raise messages.DecodeError('Invalid enum value "%s"' % value) + + if field.repeated: + value_list = getattr(parent, field.name, None) + if value_list is None: + setattr(parent, field.name, [converted_value]) + else: + if index == len(value_list): + value_list.append(converted_value) + else: + # Index should never be above len(value_list) because it was + # verified during the index check above. + value_list[index] = converted_value + else: + setattr(parent, field.name, converted_value) + + return True + + +@util.positional(1) +def encode_message(message, prefix=''): + """Encode Message instance to url-encoded string. + + Args: + message: Message instance to encode in to url-encoded string. + prefix: Prefix to append to field names of contained values. + + Returns: + String encoding of Message in URL encoded format. + + Raises: + messages.ValidationError if message is not initialized. + """ + message.check_initialized() + + parameters = [] + def build_message(parent, prefix): + """Recursively build parameter list for URL response. + + Args: + parent: Message to build parameters for. + prefix: Prefix to append to field names of contained values. + + Returns: + True if some value of parent was added to the parameters list, + else False, meaning the object contained no values. + """ + has_any_values = False + for field in sorted(parent.all_fields(), key=lambda f: f.number): + next_value = parent.get_assigned_value(field.name) + if next_value is None: + continue + + # Found a value. Ultimate return value should be True. + has_any_values = True + + # Normalize all values in to a list. + if not field.repeated: + next_value = [next_value] + + for index, item in enumerate(next_value): + # Create a name with an index if it is a repeated field. + if field.repeated: + field_name = '%s%s-%s' % (prefix, field.name, index) + else: + field_name = prefix + field.name + + if isinstance(field, message_types.DateTimeField): + # DateTimeField stores its data as a RFC 3339 compliant string. + parameters.append((field_name, item.isoformat())) + elif isinstance(field, messages.MessageField): + # Message fields must be recursed in to in order to construct + # their component parameter values. + if not build_message(item, field_name + '.'): + # The nested message is empty. Append an empty value to + # represent it. + parameters.append((field_name, '')) + elif isinstance(field, messages.BooleanField): + parameters.append((field_name, item and 'true' or 'false')) + else: + if isinstance(item, six.text_type): + item = item.encode('utf-8') + parameters.append((field_name, str(item))) + + return has_any_values + + build_message(message, prefix) + + # Also add any unrecognized values from the decoded string. + for key in message.all_unrecognized_fields(): + values, _ = message.get_unrecognized_field_info(key) + if not isinstance(values, (list, tuple)): + values = (values,) + for value in values: + parameters.append((key, value)) + + return urllib.urlencode(parameters) + + +def decode_message(message_type, encoded_message, **kwargs): + """Decode urlencoded content to message. + + Args: + message_type: Message instance to merge URL encoded content into. + encoded_message: URL encoded message. + prefix: Prefix to append to field names of contained values. + + Returns: + Decoded instance of message_type. + """ + message = message_type() + builder = URLEncodedRequestBuilder(message, **kwargs) + arguments = cgi.parse_qs(encoded_message, keep_blank_values=True) + for argument, values in sorted(six.iteritems(arguments)): + added = builder.add_parameter(argument, values) + # Save off any unknown values, so they're still accessible. + if not added: + message.set_unrecognized_field(argument, values, messages.Variant.STRING) + message.check_initialized() + return message diff --git a/endpoints/internal/protorpc/protourlencode_test.py b/endpoints/internal/protorpc/protourlencode_test.py new file mode 100644 index 0000000..0121896 --- /dev/null +++ b/endpoints/internal/protorpc/protourlencode_test.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.protourlencode.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import cgi +import logging +import unittest +import urllib + +from protorpc import message_types +from protorpc import messages +from protorpc import protourlencode +from protorpc import test_util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = protourlencode + + + +class SuperMessage(messages.Message): + """A test message with a nested message field.""" + + sub_message = messages.MessageField(test_util.OptionalMessage, 1) + sub_messages = messages.MessageField(test_util.OptionalMessage, + 2, + repeated=True) + + +class SuperSuperMessage(messages.Message): + """A test message with two levels of nested.""" + + sub_message = messages.MessageField(SuperMessage, 1) + sub_messages = messages.MessageField(SuperMessage, 2, repeated=True) + + +class URLEncodedRequestBuilderTest(test_util.TestCase): + """Test the URL Encoded request builder.""" + + def testMakePath(self): + builder = protourlencode.URLEncodedRequestBuilder(SuperSuperMessage(), + prefix='pre.') + + self.assertEquals(None, builder.make_path('')) + self.assertEquals(None, builder.make_path('no_such_field')) + self.assertEquals(None, builder.make_path('pre.no_such_field')) + + # Missing prefix. + self.assertEquals(None, builder.make_path('sub_message')) + + # Valid parameters. + self.assertEquals((('sub_message', None),), + builder.make_path('pre.sub_message')) + self.assertEquals((('sub_message', None), ('sub_messages', 1)), + builder.make_path('pre.sub_message.sub_messages-1')) + self.assertEquals( + (('sub_message', None), + ('sub_messages', 1), + ('int64_value', None)), + builder.make_path('pre.sub_message.sub_messages-1.int64_value')) + + # Missing index. + self.assertEquals( + None, + builder.make_path('pre.sub_message.sub_messages.integer_field')) + + # Has unexpected index. + self.assertEquals( + None, + builder.make_path('pre.sub_message.sub_message-1.integer_field')) + + def testAddParameter_SimpleAttributes(self): + message = test_util.OptionalMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + self.assertTrue(builder.add_parameter('pre.int64_value', ['10'])) + self.assertTrue(builder.add_parameter('pre.string_value', ['a string'])) + self.assertTrue(builder.add_parameter('pre.enum_value', ['VAL1'])) + self.assertEquals(10, message.int64_value) + self.assertEquals('a string', message.string_value) + self.assertEquals(test_util.OptionalMessage.SimpleEnum.VAL1, + message.enum_value) + + def testAddParameter_InvalidAttributes(self): + message = SuperSuperMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + def assert_empty(): + self.assertEquals(None, getattr(message, 'sub_message')) + self.assertEquals([], getattr(message, 'sub_messages')) + + self.assertFalse(builder.add_parameter('pre.nothing', ['x'])) + assert_empty() + + self.assertFalse(builder.add_parameter('pre.sub_messages', ['x'])) + self.assertFalse(builder.add_parameter('pre.sub_messages-1.nothing', ['x'])) + assert_empty() + + def testAddParameter_NestedAttributes(self): + message = SuperSuperMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + # Set an empty message fields. + self.assertTrue(builder.add_parameter('pre.sub_message', [''])) + self.assertTrue(isinstance(message.sub_message, SuperMessage)) + + # Add a basic attribute. + self.assertTrue(builder.add_parameter( + 'pre.sub_message.sub_message.int64_value', ['10'])) + self.assertTrue(builder.add_parameter( + 'pre.sub_message.sub_message.string_value', ['hello'])) + + self.assertTrue(10, message.sub_message.sub_message.int64_value) + self.assertTrue('hello', message.sub_message.sub_message.string_value) + + + def testAddParameter_NestedMessages(self): + message = SuperSuperMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + # Add a repeated empty message. + self.assertTrue(builder.add_parameter( + 'pre.sub_message.sub_messages-0', [''])) + sub_message = message.sub_message.sub_messages[0] + self.assertTrue(1, len(message.sub_message.sub_messages)) + self.assertTrue(isinstance(sub_message, + test_util.OptionalMessage)) + self.assertEquals(None, getattr(sub_message, 'int64_value')) + self.assertEquals(None, getattr(sub_message, 'string_value')) + self.assertEquals(None, getattr(sub_message, 'enum_value')) + + # Add a repeated message with value. + self.assertTrue(builder.add_parameter( + 'pre.sub_message.sub_messages-1.int64_value', ['10'])) + self.assertTrue(2, len(message.sub_message.sub_messages)) + self.assertTrue(10, message.sub_message.sub_messages[1].int64_value) + + # Add another value to the same nested message. + self.assertTrue(builder.add_parameter( + 'pre.sub_message.sub_messages-1.string_value', ['a string'])) + self.assertTrue(2, len(message.sub_message.sub_messages)) + self.assertEquals(10, message.sub_message.sub_messages[1].int64_value) + self.assertEquals('a string', + message.sub_message.sub_messages[1].string_value) + + def testAddParameter_RepeatedValues(self): + message = test_util.RepeatedMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + self.assertTrue(builder.add_parameter('pre.int64_value-0', ['20'])) + self.assertTrue(builder.add_parameter('pre.int64_value-1', ['30'])) + self.assertEquals([20, 30], message.int64_value) + + self.assertTrue(builder.add_parameter('pre.string_value-0', ['hi'])) + self.assertTrue(builder.add_parameter('pre.string_value-1', ['lo'])) + self.assertTrue(builder.add_parameter('pre.string_value-1', ['dups overwrite'])) + self.assertEquals(['hi', 'dups overwrite'], message.string_value) + + def testAddParameter_InvalidValuesMayRepeat(self): + message = test_util.OptionalMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + self.assertFalse(builder.add_parameter('nothing', [1, 2, 3])) + + def testAddParameter_RepeatedParameters(self): + message = test_util.OptionalMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + self.assertRaises(messages.DecodeError, + builder.add_parameter, + 'pre.int64_value', + [1, 2, 3]) + self.assertRaises(messages.DecodeError, + builder.add_parameter, + 'pre.int64_value', + []) + + def testAddParameter_UnexpectedNestedValue(self): + """Test getting a nested value on a non-message sub-field.""" + message = test_util.HasNestedMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, 'pre.') + + self.assertFalse(builder.add_parameter('pre.nested.a_value.whatever', + ['1'])) + + def testInvalidFieldFormat(self): + message = test_util.OptionalMessage() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + self.assertFalse(builder.add_parameter('pre.illegal%20', ['1'])) + + def testAddParameter_UnexpectedNestedValue(self): + """Test getting a nested value on a non-message sub-field + + There is an odd corner case where if trying to insert a repeated value + on an nested repeated message that would normally succeed in being created + should fail. This case can only be tested when the first message of the + nested messages already exists. + + Another case is trying to access an indexed value nested within a + non-message field. + """ + class HasRepeated(messages.Message): + + values = messages.IntegerField(1, repeated=True) + + class HasNestedRepeated(messages.Message): + + nested = messages.MessageField(HasRepeated, 1, repeated=True) + + + message = HasNestedRepeated() + builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') + + self.assertTrue(builder.add_parameter('pre.nested-0.values-0', ['1'])) + # Try to create an indexed value on a non-message field. + self.assertFalse(builder.add_parameter('pre.nested-0.values-0.unknown-0', + ['1'])) + # Try to create an out of range indexed field on an otherwise valid + # repeated message field. + self.assertFalse(builder.add_parameter('pre.nested-1.values-1', ['1'])) + + +class ProtourlencodeConformanceTest(test_util.TestCase, + test_util.ProtoConformanceTestBase): + + PROTOLIB = protourlencode + + encoded_partial = urllib.urlencode([('double_value', 1.23), + ('int64_value', -100000000000), + ('int32_value', 1020), + ('string_value', u'a string'), + ('enum_value', 'VAL2'), + ]) + + encoded_full = urllib.urlencode([('double_value', 1.23), + ('float_value', -2.5), + ('int64_value', -100000000000), + ('uint64_value', 102020202020), + ('int32_value', 1020), + ('bool_value', 'true'), + ('string_value', + u'a string\u044f'.encode('utf-8')), + ('bytes_value', b'a bytes\xff\xfe'), + ('enum_value', 'VAL2'), + ]) + + encoded_repeated = urllib.urlencode([('double_value-0', 1.23), + ('double_value-1', 2.3), + ('float_value-0', -2.5), + ('float_value-1', 0.5), + ('int64_value-0', -100000000000), + ('int64_value-1', 20), + ('uint64_value-0', 102020202020), + ('uint64_value-1', 10), + ('int32_value-0', 1020), + ('int32_value-1', 718), + ('bool_value-0', 'true'), + ('bool_value-1', 'false'), + ('string_value-0', + u'a string\u044f'.encode('utf-8')), + ('string_value-1', + u'another string'.encode('utf-8')), + ('bytes_value-0', b'a bytes\xff\xfe'), + ('bytes_value-1', b'another bytes'), + ('enum_value-0', 'VAL2'), + ('enum_value-1', 'VAL1'), + ]) + + encoded_nested = urllib.urlencode([('nested.a_value', 'a string'), + ]) + + encoded_repeated_nested = urllib.urlencode( + [('repeated_nested-0.a_value', 'a string'), + ('repeated_nested-1.a_value', 'another string'), + ]) + + unexpected_tag_message = 'unexpected=whatever' + + encoded_default_assigned = urllib.urlencode([('a_value', 'a default'), + ]) + + encoded_nested_empty = urllib.urlencode([('nested', '')]) + + encoded_repeated_nested_empty = urllib.urlencode([('repeated_nested-0', ''), + ('repeated_nested-1', '')]) + + encoded_extend_message = urllib.urlencode([('int64_value-0', 400), + ('int64_value-1', 50), + ('int64_value-2', 6000)]) + + encoded_string_types = urllib.urlencode( + [('string_value', 'Latin')]) + + encoded_invalid_enum = urllib.urlencode([('enum_value', 'undefined')]) + + def testParameterPrefix(self): + """Test using the 'prefix' parameter to encode_message.""" + class MyMessage(messages.Message): + number = messages.IntegerField(1) + names = messages.StringField(2, repeated=True) + + message = MyMessage() + message.number = 10 + message.names = [u'Fred', u'Lisa'] + + encoded_message = protourlencode.encode_message(message, prefix='prefix-') + self.assertEquals({'prefix-number': ['10'], + 'prefix-names-0': ['Fred'], + 'prefix-names-1': ['Lisa'], + }, + cgi.parse_qs(encoded_message)) + + self.assertEquals(message, protourlencode.decode_message(MyMessage, + encoded_message, + prefix='prefix-')) + + def testProtourlencodeUnrecognizedField(self): + """Test that unrecognized fields are saved and can be accessed.""" + + class MyMessage(messages.Message): + number = messages.IntegerField(1) + + decoded = protourlencode.decode_message(MyMessage, + self.unexpected_tag_message) + self.assertEquals(1, len(decoded.all_unrecognized_fields())) + self.assertEquals('unexpected', decoded.all_unrecognized_fields()[0]) + # Unknown values set to a list of however many values had that name. + self.assertEquals((['whatever'], messages.Variant.STRING), + decoded.get_unrecognized_field_info('unexpected')) + + repeated_unknown = urllib.urlencode([('repeated', 400), + ('repeated', 'test'), + ('repeated', '123.456')]) + decoded2 = protourlencode.decode_message(MyMessage, repeated_unknown) + self.assertEquals((['400', 'test', '123.456'], messages.Variant.STRING), + decoded2.get_unrecognized_field_info('repeated')) + + def testDecodeInvalidDateTime(self): + + class MyMessage(messages.Message): + a_datetime = message_types.DateTimeField(1) + + self.assertRaises(messages.DecodeError, protourlencode.decode_message, + MyMessage, 'a_datetime=invalid') + + +if __name__ == '__main__': + unittest.main() diff --git a/endpoints/internal/protorpc/registry.py b/endpoints/internal/protorpc/registry.py new file mode 100644 index 0000000..23ba876 --- /dev/null +++ b/endpoints/internal/protorpc/registry.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Service regsitry for service discovery. + +The registry service can be deployed on a server in order to provide a +central place where remote clients can discover available. + +On the server side, each service is registered by their name which is unique +to the registry. Typically this name provides enough information to identify +the service and locate it within a server. For example, for an HTTP based +registry the name is the URL path on the host where the service is invocable. + +The registry is also able to resolve the full descriptor.FileSet necessary to +describe the service and all required data-types (messages and enums). + +A configured registry is itself a remote service and should reference itself. +""" + +import sys + +from . import descriptor +from . import messages +from . import remote +from . import util + + +__all__ = [ + 'ServiceMapping', + 'ServicesResponse', + 'GetFileSetRequest', + 'GetFileSetResponse', + 'RegistryService', +] + + +class ServiceMapping(messages.Message): + """Description of registered service. + + Fields: + name: Name of service. On HTTP based services this will be the + URL path used for invocation. + definition: Fully qualified name of the service definition. Useful + for clients that can look up service definitions based on an existing + repository of definitions. + """ + + name = messages.StringField(1, required=True) + definition = messages.StringField(2, required=True) + + +class ServicesResponse(messages.Message): + """Response containing all registered services. + + May also contain complete descriptor file-set for all services known by the + registry. + + Fields: + services: Service mappings for all registered services in registry. + file_set: Descriptor file-set describing all services, messages and enum + types needed for use with all requested services if asked for in the + request. + """ + + services = messages.MessageField(ServiceMapping, 1, repeated=True) + + +class GetFileSetRequest(messages.Message): + """Request for service descriptor file-set. + + Request to retrieve file sets for specific services. + + Fields: + names: Names of services to retrieve file-set for. + """ + + names = messages.StringField(1, repeated=True) + + +class GetFileSetResponse(messages.Message): + """Descriptor file-set for all names in GetFileSetRequest. + + Fields: + file_set: Descriptor file-set containing all descriptors for services, + messages and enum types needed for listed names in request. + """ + + file_set = messages.MessageField(descriptor.FileSet, 1, required=True) + + +class RegistryService(remote.Service): + """Registry service. + + Maps names to services and is able to describe all descriptor file-sets + necessary to use contined services. + + On an HTTP based server, the name is the URL path to the service. + """ + + @util.positional(2) + def __init__(self, registry, modules=None): + """Constructor. + + Args: + registry: Map of name to service class. This map is not copied and may + be modified after the reigstry service has been configured. + modules: Module dict to draw descriptors from. Defaults to sys.modules. + """ + # Private Attributes: + # __registry: Map of name to service class. Refers to same instance as + # registry parameter. + # __modules: Mapping of module name to module. + # __definition_to_modules: Mapping of definition types to set of modules + # that they refer to. This cache is used to make repeated look-ups + # faster and to prevent circular references from causing endless loops. + + self.__registry = registry + if modules is None: + modules = sys.modules + self.__modules = modules + # This cache will only last for a single request. + self.__definition_to_modules = {} + + def __find_modules_for_message(self, message_type): + """Find modules referred to by a message type. + + Determines the entire list of modules ultimately referred to by message_type + by iterating over all of its message and enum fields. Includes modules + referred to fields within its referred messages. + + Args: + message_type: Message type to find all referring modules for. + + Returns: + Set of modules referred to by message_type by traversing all its + message and enum fields. + """ + # TODO(rafek): Maybe this should be a method on Message and Service? + def get_dependencies(message_type, seen=None): + """Get all dependency definitions of a message type. + + This function works by collecting the types of all enumeration and message + fields defined within the message type. When encountering a message + field, it will recursivly find all of the associated message's + dependencies. It will terminate on circular dependencies by keeping track + of what definitions it already via the seen set. + + Args: + message_type: Message type to get dependencies for. + seen: Set of definitions that have already been visited. + + Returns: + All dependency message and enumerated types associated with this message + including the message itself. + """ + if seen is None: + seen = set() + seen.add(message_type) + + for field in message_type.all_fields(): + if isinstance(field, messages.MessageField): + if field.message_type not in seen: + get_dependencies(field.message_type, seen) + elif isinstance(field, messages.EnumField): + seen.add(field.type) + + return seen + + found_modules = self.__definition_to_modules.setdefault(message_type, set()) + if not found_modules: + dependencies = get_dependencies(message_type) + found_modules.update(self.__modules[definition.__module__] + for definition in dependencies) + + return found_modules + + def __describe_file_set(self, names): + """Get file-set for named services. + + Args: + names: List of names to get file-set for. + + Returns: + descriptor.FileSet containing all the descriptors for all modules + ultimately referred to by all service types request by names parameter. + """ + service_modules = set() + if names: + for service in (self.__registry[name] for name in names): + found_modules = self.__definition_to_modules.setdefault(service, set()) + if not found_modules: + found_modules.add(self.__modules[service.__module__]) + for method_name in service.all_remote_methods(): + method = getattr(service, method_name) + for message_type in (method.remote.request_type, + method.remote.response_type): + found_modules.update( + self.__find_modules_for_message(message_type)) + service_modules.update(found_modules) + + return descriptor.describe_file_set(service_modules) + + @property + def registry(self): + """Get service registry associated with this service instance.""" + return self.__registry + + @remote.method(response_type=ServicesResponse) + def services(self, request): + """Get all registered services.""" + response = ServicesResponse() + response.services = [] + for name, service_class in self.__registry.items(): + mapping = ServiceMapping() + mapping.name = name.decode('utf-8') + mapping.definition = service_class.definition_name().decode('utf-8') + response.services.append(mapping) + + return response + + @remote.method(GetFileSetRequest, GetFileSetResponse) + def get_file_set(self, request): + """Get file-set for registered servies.""" + response = GetFileSetResponse() + response.file_set = self.__describe_file_set(request.names) + return response diff --git a/endpoints/internal/protorpc/registry_test.py b/endpoints/internal/protorpc/registry_test.py new file mode 100644 index 0000000..ec30a3f --- /dev/null +++ b/endpoints/internal/protorpc/registry_test.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.message.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import sys +import unittest + +from protorpc import descriptor +from protorpc import message_types +from protorpc import messages +from protorpc import registry +from protorpc import remote +from protorpc import test_util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = registry + + +class MyService1(remote.Service): + """Test service that refers to messages in another module.""" + + @remote.method(test_util.NestedMessage, test_util.NestedMessage) + def a_method(self, request): + pass + + +class MyService2(remote.Service): + """Test service that does not refer to messages in another module.""" + + +class RegistryServiceTest(test_util.TestCase): + + def setUp(self): + self.registry = { + 'my-service1': MyService1, + 'my-service2': MyService2, + } + + self.modules = { + __name__: sys.modules[__name__], + test_util.__name__: test_util, + } + + self.registry_service = registry.RegistryService(self.registry, + modules=self.modules) + + def CheckServiceMappings(self, mappings): + module_name = test_util.get_module_name(RegistryServiceTest) + service1_mapping = registry.ServiceMapping() + service1_mapping.name = 'my-service1' + service1_mapping.definition = '%s.MyService1' % module_name + + service2_mapping = registry.ServiceMapping() + service2_mapping.name = 'my-service2' + service2_mapping.definition = '%s.MyService2' % module_name + + self.assertIterEqual(mappings, [service1_mapping, service2_mapping]) + + def testServices(self): + response = self.registry_service.services(message_types.VoidMessage()) + + self.CheckServiceMappings(response.services) + + def testGetFileSet_All(self): + request = registry.GetFileSetRequest() + request.names = ['my-service1', 'my-service2'] + response = self.registry_service.get_file_set(request) + + expected_file_set = descriptor.describe_file_set(list(self.modules.values())) + self.assertIterEqual(expected_file_set.files, response.file_set.files) + + def testGetFileSet_None(self): + request = registry.GetFileSetRequest() + response = self.registry_service.get_file_set(request) + + self.assertEquals(descriptor.FileSet(), + response.file_set) + + def testGetFileSet_ReferenceOtherModules(self): + request = registry.GetFileSetRequest() + request.names = ['my-service1'] + response = self.registry_service.get_file_set(request) + + # Will suck in and describe the test_util module. + expected_file_set = descriptor.describe_file_set(list(self.modules.values())) + self.assertIterEqual(expected_file_set.files, response.file_set.files) + + def testGetFileSet_DoNotReferenceOtherModules(self): + request = registry.GetFileSetRequest() + request.names = ['my-service2'] + response = self.registry_service.get_file_set(request) + + # Service does not reference test_util, so will only describe this module. + expected_file_set = descriptor.describe_file_set([self.modules[__name__]]) + self.assertIterEqual(expected_file_set.files, response.file_set.files) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/remote.py b/endpoints/internal/protorpc/remote.py new file mode 100644 index 0000000..61fe6c8 --- /dev/null +++ b/endpoints/internal/protorpc/remote.py @@ -0,0 +1,1248 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Remote service library. + +This module contains classes that are useful for building remote services that +conform to a standard request and response model. To conform to this model +a service must be like the following class: + + # Each service instance only handles a single request and is then discarded. + # Make these objects light weight. + class Service(object): + + # It must be possible to construct service objects without any parameters. + # If your constructor needs extra information you should provide a + # no-argument factory function to create service instances. + def __init__(self): + ... + + # Each remote method must use the 'method' decorator, passing the request + # and response message types. The remote method itself must take a single + # parameter which is an instance of RequestMessage and return an instance + # of ResponseMessage. + @method(RequestMessage, ResponseMessage) + def remote_method(self, request): + # Return an instance of ResponseMessage. + + # A service object may optionally implement an 'initialize_request_state' + # method that takes as a parameter a single instance of a RequestState. If + # a service does not implement this method it will not receive the request + # state. + def initialize_request_state(self, state): + ... + +The 'Service' class is provided as a convenient base class that provides the +above functionality. It implements all required and optional methods for a +service. It also has convenience methods for creating factory functions that +can pass persistent global state to a new service instance. + +The 'method' decorator is used to declare which methods of a class are +meant to service RPCs. While this decorator is not responsible for handling +actual remote method invocations, such as handling sockets, handling various +RPC protocols and checking messages for correctness, it does attach information +to methods that responsible classes can examine and ensure the correctness +of the RPC. + +When the method decorator is used on a method, the wrapper method will have a +'remote' property associated with it. The 'remote' property contains the +request_type and response_type expected by the methods implementation. + +On its own, the method decorator does not provide any support for subclassing +remote methods. In order to extend a service, one would need to redecorate +the sub-classes methods. For example: + + class MyService(Service): + + @method(DoSomethingRequest, DoSomethingResponse) + def do_stuff(self, request): + ... implement do_stuff ... + + class MyBetterService(MyService): + + @method(DoSomethingRequest, DoSomethingResponse) + def do_stuff(self, request): + response = super(MyBetterService, self).do_stuff.remote.method(request) + ... do stuff with response ... + return response + +A Service subclass also has a Stub class that can be used with a transport for +making RPCs. When a stub is created, it is capable of doing both synchronous +and asynchronous RPCs if the underlying transport supports it. To make a stub +using an HTTP transport do: + + my_service = MyService.Stub(HttpTransport('')) + +For synchronous calls, just call the expected methods on the service stub: + + request = DoSomethingRequest() + ... + response = my_service.do_something(request) + +Each stub instance has an async object that can be used for initiating +asynchronous RPCs if the underlying protocol transport supports it. To +make an asynchronous call, do: + + rpc = my_service.async.do_something(request) + response = rpc.get_response() +""" + +from __future__ import with_statement +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import functools +import logging +import sys +import threading +from wsgiref import headers as wsgi_headers + +from . import message_types +from . import messages +from . import protobuf +from . import protojson +from . import util + + +__all__ = [ + 'ApplicationError', + 'MethodNotFoundError', + 'NetworkError', + 'RequestError', + 'RpcError', + 'ServerError', + 'ServiceConfigurationError', + 'ServiceDefinitionError', + + 'HttpRequestState', + 'ProtocolConfig', + 'Protocols', + 'RequestState', + 'RpcState', + 'RpcStatus', + 'Service', + 'StubBase', + 'check_rpc_status', + 'get_remote_method_info', + 'is_error_status', + 'method', + 'remote', +] + + +class ServiceDefinitionError(messages.Error): + """Raised when a service is improperly defined.""" + + +class ServiceConfigurationError(messages.Error): + """Raised when a service is incorrectly configured.""" + + +# TODO: Use error_name to map to specific exception message types. +class RpcStatus(messages.Message): + """Status of on-going or complete RPC. + + Fields: + state: State of RPC. + error_name: Error name set by application. Only set when + status is APPLICATION_ERROR. For use by application to transmit + specific reason for error. + error_message: Error message associated with status. + """ + + class State(messages.Enum): + """Enumeration of possible RPC states. + + Values: + OK: Completed successfully. + RUNNING: Still running, not complete. + REQUEST_ERROR: Request was malformed or incomplete. + SERVER_ERROR: Server experienced an unexpected error. + NETWORK_ERROR: An error occured on the network. + APPLICATION_ERROR: The application is indicating an error. + When in this state, RPC should also set application_error. + """ + OK = 0 + RUNNING = 1 + + REQUEST_ERROR = 2 + SERVER_ERROR = 3 + NETWORK_ERROR = 4 + APPLICATION_ERROR = 5 + METHOD_NOT_FOUND_ERROR = 6 + + state = messages.EnumField(State, 1, required=True) + error_message = messages.StringField(2) + error_name = messages.StringField(3) + + +RpcState = RpcStatus.State + + +class RpcError(messages.Error): + """Base class for RPC errors. + + Each sub-class of RpcError is associated with an error value from RpcState + and has an attribute STATE that refers to that value. + """ + + def __init__(self, message, cause=None): + super(RpcError, self).__init__(message) + self.cause = cause + + @classmethod + def from_state(cls, state): + """Get error class from RpcState. + + Args: + state: RpcState value. Can be enum value itself, string or int. + + Returns: + Exception class mapped to value if state is an error. Returns None + if state is OK or RUNNING. + """ + return _RPC_STATE_TO_ERROR.get(RpcState(state)) + + +class RequestError(RpcError): + """Raised when wrong request objects received during method invocation.""" + + STATE = RpcState.REQUEST_ERROR + + +class MethodNotFoundError(RequestError): + """Raised when unknown method requested by RPC.""" + + STATE = RpcState.METHOD_NOT_FOUND_ERROR + + +class NetworkError(RpcError): + """Raised when network error occurs during RPC.""" + + STATE = RpcState.NETWORK_ERROR + + +class ServerError(RpcError): + """Unexpected error occured on server.""" + + STATE = RpcState.SERVER_ERROR + + +class ApplicationError(RpcError): + """Raised for application specific errors. + + Attributes: + error_name: Application specific error name for exception. + """ + + STATE = RpcState.APPLICATION_ERROR + + def __init__(self, message, error_name=None): + """Constructor. + + Args: + message: Application specific error message. + error_name: Application specific error name. Must be None, string + or unicode string. + """ + super(ApplicationError, self).__init__(message) + self.error_name = error_name + + def __str__(self): + return self.args[0] or '' + + def __repr__(self): + if self.error_name is None: + error_format = '' + else: + error_format = ', %r' % self.error_name + return '%s(%r%s)' % (type(self).__name__, self.args[0], error_format) + + +_RPC_STATE_TO_ERROR = { + RpcState.REQUEST_ERROR: RequestError, + RpcState.NETWORK_ERROR: NetworkError, + RpcState.SERVER_ERROR: ServerError, + RpcState.APPLICATION_ERROR: ApplicationError, + RpcState.METHOD_NOT_FOUND_ERROR: MethodNotFoundError, +} + +class _RemoteMethodInfo(object): + """Object for encapsulating remote method information. + + An instance of this method is associated with the 'remote' attribute + of the methods 'invoke_remote_method' instance. + + Instances of this class are created by the remote decorator and should not + be created directly. + """ + + def __init__(self, + method, + request_type, + response_type): + """Constructor. + + Args: + method: The method which implements the remote method. This is a + function that will act as an instance method of a class definition + that is decorated by '@method'. It must always take 'self' as its + first parameter. + request_type: Expected request type for the remote method. + response_type: Expected response type for the remote method. + """ + self.__method = method + self.__request_type = request_type + self.__response_type = response_type + + @property + def method(self): + """Original undecorated method.""" + return self.__method + + @property + def request_type(self): + """Expected request type for remote method.""" + if isinstance(self.__request_type, six.string_types): + self.__request_type = messages.find_definition( + self.__request_type, + relative_to=sys.modules[self.__method.__module__]) + return self.__request_type + + @property + def response_type(self): + """Expected response type for remote method.""" + if isinstance(self.__response_type, six.string_types): + self.__response_type = messages.find_definition( + self.__response_type, + relative_to=sys.modules[self.__method.__module__]) + return self.__response_type + + +def method(request_type=message_types.VoidMessage, + response_type=message_types.VoidMessage): + """Method decorator for creating remote methods. + + Args: + request_type: Message type of expected request. + response_type: Message type of expected response. + + Returns: + 'remote_method_wrapper' function. + + Raises: + TypeError: if the request_type or response_type parameters are not + proper subclasses of messages.Message. + """ + if (not isinstance(request_type, six.string_types) and + (not isinstance(request_type, type) or + not issubclass(request_type, messages.Message) or + request_type is messages.Message)): + raise TypeError( + 'Must provide message class for request-type. Found %s', + request_type) + + if (not isinstance(response_type, six.string_types) and + (not isinstance(response_type, type) or + not issubclass(response_type, messages.Message) or + response_type is messages.Message)): + raise TypeError( + 'Must provide message class for response-type. Found %s', + response_type) + + def remote_method_wrapper(method): + """Decorator used to wrap method. + + Args: + method: Original method being wrapped. + + Returns: + 'invoke_remote_method' function responsible for actual invocation. + This invocation function instance is assigned an attribute 'remote' + which contains information about the remote method: + request_type: Expected request type for remote method. + response_type: Response type returned from remote method. + + Raises: + TypeError: If request_type or response_type is not a subclass of Message + or is the Message class itself. + """ + + @functools.wraps(method) + def invoke_remote_method(service_instance, request): + """Function used to replace original method. + + Invoke wrapped remote method. Checks to ensure that request and + response objects are the correct types. + + Does not check whether messages are initialized. + + Args: + service_instance: The service object whose method is being invoked. + This is passed to 'self' during the invocation of the original + method. + request: Request message. + + Returns: + Results of calling wrapped remote method. + + Raises: + RequestError: Request object is not of the correct type. + ServerError: Response object is not of the correct type. + """ + if not isinstance(request, remote_method_info.request_type): + raise RequestError('Method %s.%s expected request type %s, ' + 'received %s' % + (type(service_instance).__name__, + method.__name__, + remote_method_info.request_type, + type(request))) + response = method(service_instance, request) + if not isinstance(response, remote_method_info.response_type): + raise ServerError('Method %s.%s expected response type %s, ' + 'sent %s' % + (type(service_instance).__name__, + method.__name__, + remote_method_info.response_type, + type(response))) + return response + + remote_method_info = _RemoteMethodInfo(method, + request_type, + response_type) + + invoke_remote_method.remote = remote_method_info + return invoke_remote_method + + return remote_method_wrapper + + +def remote(request_type, response_type): + """Temporary backward compatibility alias for method.""" + logging.warning('The remote decorator has been renamed method. It will be ' + 'removed in very soon from future versions of ProtoRPC.') + return method(request_type, response_type) + + +def get_remote_method_info(method): + """Get remote method info object from remote method. + + Returns: + Remote method info object if method is a remote method, else None. + """ + if not callable(method): + return None + + try: + method_info = method.remote + except AttributeError: + return None + + if not isinstance(method_info, _RemoteMethodInfo): + return None + + return method_info + + +class StubBase(object): + """Base class for client side service stubs. + + The remote method stubs are created by the _ServiceClass meta-class + when a Service class is first created. The resulting stub will + extend both this class and the service class it handles communications for. + + Assume that there is a service: + + class NewContactRequest(messages.Message): + + name = messages.StringField(1, required=True) + phone = messages.StringField(2) + email = messages.StringField(3) + + class NewContactResponse(message.Message): + + contact_id = messages.StringField(1) + + class AccountService(remote.Service): + + @remote.method(NewContactRequest, NewContactResponse): + def new_contact(self, request): + ... implementation ... + + A stub of this service can be called in two ways. The first is to pass in a + correctly initialized NewContactRequest message: + + request = NewContactRequest() + request.name = 'Bob Somebody' + request.phone = '+1 415 555 1234' + + response = account_service_stub.new_contact(request) + + The second way is to pass in keyword parameters that correspond with the root + request message type: + + account_service_stub.new_contact(name='Bob Somebody', + phone='+1 415 555 1234') + + The second form will create a request message of the appropriate type. + """ + + def __init__(self, transport): + """Constructor. + + Args: + transport: Underlying transport to communicate with remote service. + """ + self.__transport = transport + + @property + def transport(self): + """Transport used to communicate with remote service.""" + return self.__transport + + +class _ServiceClass(type): + """Meta-class for service class.""" + + def __new_async_method(cls, remote): + """Create asynchronous method for Async handler. + + Args: + remote: RemoteInfo to create method for. + """ + def async_method(self, *args, **kwargs): + """Asynchronous remote method. + + Args: + self: Instance of StubBase.Async subclass. + + Stub methods either take a single positional argument when a full + request message is passed in, or keyword arguments, but not both. + + See docstring for StubBase for more information on how to use remote + stub methods. + + Returns: + Rpc instance used to represent asynchronous RPC. + """ + if args and kwargs: + raise TypeError('May not provide both args and kwargs') + + if not args: + # Construct request object from arguments. + request = remote.request_type() + for name, value in six.iteritems(kwargs): + setattr(request, name, value) + else: + # First argument is request object. + request = args[0] + + return self.transport.send_rpc(remote, request) + + async_method.__name__ = remote.method.__name__ + async_method = util.positional(2)(async_method) + async_method.remote = remote + return async_method + + def __new_sync_method(cls, async_method): + """Create synchronous method for stub. + + Args: + async_method: asynchronous method to delegate calls to. + """ + def sync_method(self, *args, **kwargs): + """Synchronous remote method. + + Args: + self: Instance of StubBase.Async subclass. + args: Tuple (request,): + request: Request object. + kwargs: Field values for request. Must be empty if request object + is provided. + + Returns: + Response message from synchronized RPC. + """ + return async_method(self.async, *args, **kwargs).response + sync_method.__name__ = async_method.__name__ + sync_method.remote = async_method.remote + return sync_method + + def __create_async_methods(cls, remote_methods): + """Construct a dictionary of asynchronous methods based on remote methods. + + Args: + remote_methods: Dictionary of methods with associated RemoteInfo objects. + + Returns: + Dictionary of asynchronous methods with assocaited RemoteInfo objects. + Results added to AsyncStub subclass. + """ + async_methods = {} + for method_name, method in remote_methods.items(): + async_methods[method_name] = cls.__new_async_method(method.remote) + return async_methods + + def __create_sync_methods(cls, async_methods): + """Construct a dictionary of synchronous methods based on remote methods. + + Args: + async_methods: Dictionary of async methods to delegate calls to. + + Returns: + Dictionary of synchronous methods with assocaited RemoteInfo objects. + Results added to Stub subclass. + """ + sync_methods = {} + for method_name, async_method in async_methods.items(): + sync_methods[method_name] = cls.__new_sync_method(async_method) + return sync_methods + + def __new__(cls, name, bases, dct): + """Instantiate new service class instance.""" + if StubBase not in bases: + # Collect existing remote methods. + base_methods = {} + for base in bases: + try: + remote_methods = base.__remote_methods + except AttributeError: + pass + else: + base_methods.update(remote_methods) + + # Set this class private attribute so that base_methods do not have + # to be recacluated in __init__. + dct['_ServiceClass__base_methods'] = base_methods + + for attribute, value in dct.items(): + base_method = base_methods.get(attribute, None) + if base_method: + if not callable(value): + raise ServiceDefinitionError( + 'Must override %s in %s with a method.' % ( + attribute, name)) + + if get_remote_method_info(value): + raise ServiceDefinitionError( + 'Do not use method decorator when overloading remote method %s ' + 'on service %s.' % + (attribute, name)) + + base_remote_method_info = get_remote_method_info(base_method) + remote_decorator = method( + base_remote_method_info.request_type, + base_remote_method_info.response_type) + new_remote_method = remote_decorator(value) + dct[attribute] = new_remote_method + + return type.__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct): + """Create uninitialized state on new class.""" + type.__init__(cls, name, bases, dct) + + # Only service implementation classes should have remote methods and stub + # sub classes created. Stub implementations have their own methods passed + # in to the type constructor. + if StubBase not in bases: + # Create list of remote methods. + cls.__remote_methods = dict(cls.__base_methods) + + for attribute, value in dct.items(): + value = getattr(cls, attribute) + remote_method_info = get_remote_method_info(value) + if remote_method_info: + cls.__remote_methods[attribute] = value + + # Build asynchronous stub class. + stub_attributes = {'Service': cls} + async_methods = cls.__create_async_methods(cls.__remote_methods) + stub_attributes.update(async_methods) + async_class = type('AsyncStub', (StubBase, cls), stub_attributes) + cls.AsyncStub = async_class + + # Constructor for synchronous stub class. + def __init__(self, transport): + """Constructor. + + Args: + transport: Underlying transport to communicate with remote service. + """ + super(cls.Stub, self).__init__(transport) + self.async = cls.AsyncStub(transport) + + # Build synchronous stub class. + stub_attributes = {'Service': cls, + '__init__': __init__} + stub_attributes.update(cls.__create_sync_methods(async_methods)) + + cls.Stub = type('Stub', (StubBase, cls), stub_attributes) + + @staticmethod + def all_remote_methods(cls): + """Get all remote methods of service. + + Returns: + Dict from method name to unbound method. + """ + return dict(cls.__remote_methods) + + +class RequestState(object): + """Request state information. + + Properties: + remote_host: Remote host name where request originated. + remote_address: IP address where request originated. + server_host: Host of server within which service resides. + server_port: Post which service has recevied request from. + """ + + @util.positional(1) + def __init__(self, + remote_host=None, + remote_address=None, + server_host=None, + server_port=None): + """Constructor. + + Args: + remote_host: Assigned to property. + remote_address: Assigned to property. + server_host: Assigned to property. + server_port: Assigned to property. + """ + self.__remote_host = remote_host + self.__remote_address = remote_address + self.__server_host = server_host + self.__server_port = server_port + + @property + def remote_host(self): + return self.__remote_host + + @property + def remote_address(self): + return self.__remote_address + + @property + def server_host(self): + return self.__server_host + + @property + def server_port(self): + return self.__server_port + + def _repr_items(self): + for name in ['remote_host', + 'remote_address', + 'server_host', + 'server_port']: + yield name, getattr(self, name) + + def __repr__(self): + """String representation of state.""" + state = [self.__class__.__name__] + for name, value in self._repr_items(): + if value: + state.append('%s=%r' % (name, value)) + + return '<%s>' % (' '.join(state),) + + +class HttpRequestState(RequestState): + """HTTP request state information. + + NOTE: Does not attempt to represent certain types of information from the + request such as the query string as query strings are not permitted in + ProtoRPC URLs unless required by the underlying message format. + + Properties: + headers: wsgiref.headers.Headers instance of HTTP request headers. + http_method: HTTP method as a string. + service_path: Path on HTTP service where service is mounted. This path + will not include the remote method name. + """ + + @util.positional(1) + def __init__(self, + http_method=None, + service_path=None, + headers=None, + **kwargs): + """Constructor. + + Args: + Same as RequestState, including: + http_method: Assigned to property. + service_path: Assigned to property. + headers: HTTP request headers. If instance of Headers, assigned to + property without copying. If dict, will convert to name value pairs + for use with Headers constructor. Otherwise, passed as parameters to + Headers constructor. + """ + super(HttpRequestState, self).__init__(**kwargs) + + self.__http_method = http_method + self.__service_path = service_path + + # Initialize headers. + if isinstance(headers, dict): + header_list = [] + for key, value in sorted(headers.items()): + if not isinstance(value, list): + value = [value] + for item in value: + header_list.append((key, item)) + headers = header_list + self.__headers = wsgi_headers.Headers(headers or []) + + @property + def http_method(self): + return self.__http_method + + @property + def service_path(self): + return self.__service_path + + @property + def headers(self): + return self.__headers + + def _repr_items(self): + for item in super(HttpRequestState, self)._repr_items(): + yield item + + for name in ['http_method', 'service_path']: + yield name, getattr(self, name) + + yield 'headers', list(self.headers.items()) + + +class Service(six.with_metaclass(_ServiceClass, object)): + """Service base class. + + Base class used for defining remote services. Contains reflection functions, + useful helpers and built-in remote methods. + + Services are expected to be constructed via either a constructor or factory + which takes no parameters. However, it might be required that some state or + configuration is passed in to a service across multiple requests. + + To do this, define parameters to the constructor of the service and use + the 'new_factory' class method to build a constructor that will transmit + parameters to the constructor. For example: + + class MyService(Service): + + def __init__(self, configuration, state): + self.configuration = configuration + self.state = state + + configuration = MyServiceConfiguration() + global_state = MyServiceState() + + my_service_factory = MyService.new_factory(configuration, + state=global_state) + + The contract with any service handler is that a new service object is created + to handle each user request, and that the construction does not take any + parameters. The factory satisfies this condition: + + new_instance = my_service_factory() + assert new_instance.state is global_state + + Attributes: + request_state: RequestState set via initialize_request_state. + """ + + __request_state = None + + @classmethod + def all_remote_methods(cls): + """Get all remote methods for service class. + + Built-in methods do not appear in the dictionary of remote methods. + + Returns: + Dictionary mapping method name to remote method. + """ + return _ServiceClass.all_remote_methods(cls) + + @classmethod + def new_factory(cls, *args, **kwargs): + """Create factory for service. + + Useful for passing configuration or state objects to the service. Accepts + arbitrary parameters and keywords, however, underlying service must accept + also accept not other parameters in its constructor. + + Args: + args: Args to pass to service constructor. + kwargs: Keyword arguments to pass to service constructor. + + Returns: + Factory function that will create a new instance and forward args and + keywords to the constructor. + """ + + def service_factory(): + return cls(*args, **kwargs) + + # Update docstring so that it is easier to debug. + full_class_name = '%s.%s' % (cls.__module__, cls.__name__) + service_factory.__doc__ = ( + 'Creates new instances of service %s.\n\n' + 'Returns:\n' + ' New instance of %s.' + % (cls.__name__, full_class_name)) + + # Update name so that it is easier to debug the factory function. + service_factory.__name__ = '%s_service_factory' % cls.__name__ + + service_factory.service_class = cls + + return service_factory + + def initialize_request_state(self, request_state): + """Save request state for use in remote method. + + Args: + request_state: RequestState instance. + """ + self.__request_state = request_state + + @classmethod + def definition_name(cls): + """Get definition name for Service class. + + Package name is determined by the global 'package' attribute in the + module that contains the Service definition. If no 'package' attribute + is available, uses module name. If no module is found, just uses class + name as name. + + Returns: + Fully qualified service name. + """ + try: + return cls.__definition_name + except AttributeError: + outer_definition_name = cls.outer_definition_name() + if outer_definition_name is None: + cls.__definition_name = cls.__name__ + else: + cls.__definition_name = '%s.%s' % (outer_definition_name, cls.__name__) + + return cls.__definition_name + + @classmethod + def outer_definition_name(cls): + """Get outer definition name. + + Returns: + Package for service. Services are never nested inside other definitions. + """ + return cls.definition_package() + + @classmethod + def definition_package(cls): + """Get package for service. + + Returns: + Package name for service. + """ + try: + return cls.__definition_package + except AttributeError: + cls.__definition_package = util.get_package_for_module(cls.__module__) + + return cls.__definition_package + + @property + def request_state(self): + """Request state associated with this Service instance.""" + return self.__request_state + + +def is_error_status(status): + """Function that determines whether the RPC status is an error. + + Args: + status: Initialized RpcStatus message to check for errors. + """ + status.check_initialized() + return RpcError.from_state(status.state) is not None + + +def check_rpc_status(status): + """Function converts an error status to a raised exception. + + Args: + status: Initialized RpcStatus message to check for errors. + + Raises: + RpcError according to state set on status, if it is an error state. + """ + status.check_initialized() + error_class = RpcError.from_state(status.state) + if error_class is not None: + if error_class is ApplicationError: + raise error_class(status.error_message, status.error_name) + else: + raise error_class(status.error_message) + + +class ProtocolConfig(object): + """Configuration for single protocol mapping. + + A read-only protocol configuration provides a given protocol implementation + with a name and a set of content-types that it recognizes. + + Properties: + protocol: The protocol implementation for configuration (usually a module, + for example, protojson, protobuf, etc.). This is an object that has the + following attributes: + CONTENT_TYPE: Used as the default content-type if default_content_type + is not set. + ALTERNATIVE_CONTENT_TYPES (optional): A list of alternative + content-types to the default that indicate the same protocol. + encode_message: Function that matches the signature of + ProtocolConfig.encode_message. Used for encoding a ProtoRPC message. + decode_message: Function that matches the signature of + ProtocolConfig.decode_message. Used for decoding a ProtoRPC message. + name: Name of protocol configuration. + default_content_type: The default content type for the protocol. Overrides + CONTENT_TYPE defined on protocol. + alternative_content_types: A list of alternative content-types supported + by the protocol. Must not contain the default content-type, nor + duplicates. Overrides ALTERNATIVE_CONTENT_TYPE defined on protocol. + content_types: A list of all content-types supported by configuration. + Combination of default content-type and alternatives. + """ + + def __init__(self, + protocol, + name, + default_content_type=None, + alternative_content_types=None): + """Constructor. + + Args: + protocol: The protocol implementation for configuration. + name: The name of the protocol configuration. + default_content_type: The default content-type for protocol. If none + provided it will check protocol.CONTENT_TYPE. + alternative_content_types: A list of content-types. If none provided, + it will check protocol.ALTERNATIVE_CONTENT_TYPES. If that attribute + does not exist, will be an empty tuple. + + Raises: + ServiceConfigurationError if there are any duplicate content-types. + """ + self.__protocol = protocol + self.__name = name + self.__default_content_type = (default_content_type or + protocol.CONTENT_TYPE).lower() + if alternative_content_types is None: + alternative_content_types = getattr(protocol, + 'ALTERNATIVE_CONTENT_TYPES', + ()) + self.__alternative_content_types = tuple( + content_type.lower() for content_type in alternative_content_types) + self.__content_types = ( + (self.__default_content_type,) + self.__alternative_content_types) + + # Detect duplicate content types in definition. + previous_type = None + for content_type in sorted(self.content_types): + if content_type == previous_type: + raise ServiceConfigurationError( + 'Duplicate content-type %s' % content_type) + previous_type = content_type + + @property + def protocol(self): + return self.__protocol + + @property + def name(self): + return self.__name + + @property + def default_content_type(self): + return self.__default_content_type + + @property + def alternate_content_types(self): + return self.__alternative_content_types + + @property + def content_types(self): + return self.__content_types + + def encode_message(self, message): + """Encode message. + + Args: + message: Message instance to encode. + + Returns: + String encoding of Message instance encoded in protocol's format. + """ + return self.__protocol.encode_message(message) + + def decode_message(self, message_type, encoded_message): + """Decode buffer to Message instance. + + Args: + message_type: Message type to decode data to. + encoded_message: Encoded version of message as string. + + Returns: + Decoded instance of message_type. + """ + return self.__protocol.decode_message(message_type, encoded_message) + + +class Protocols(object): + """Collection of protocol configurations. + + Used to describe a complete set of content-type mappings for multiple + protocol configurations. + + Properties: + names: Sorted list of the names of registered protocols. + content_types: Sorted list of supported content-types. + """ + + __default_protocols = None + __lock = threading.Lock() + + def __init__(self): + """Constructor.""" + self.__by_name = {} + self.__by_content_type = {} + + def add_protocol_config(self, config): + """Add a protocol configuration to protocol mapping. + + Args: + config: A ProtocolConfig. + + Raises: + ServiceConfigurationError if protocol.name is already registered + or any of it's content-types are already registered. + """ + if config.name in self.__by_name: + raise ServiceConfigurationError( + 'Protocol name %r is already in use' % config.name) + for content_type in config.content_types: + if content_type in self.__by_content_type: + raise ServiceConfigurationError( + 'Content type %r is already in use' % content_type) + + self.__by_name[config.name] = config + self.__by_content_type.update((t, config) for t in config.content_types) + + def add_protocol(self, *args, **kwargs): + """Add a protocol configuration from basic parameters. + + Simple helper method that creates and registeres a ProtocolConfig instance. + """ + self.add_protocol_config(ProtocolConfig(*args, **kwargs)) + + @property + def names(self): + return tuple(sorted(self.__by_name)) + + @property + def content_types(self): + return tuple(sorted(self.__by_content_type)) + + def lookup_by_name(self, name): + """Look up a ProtocolConfig by name. + + Args: + name: Name of protocol to look for. + + Returns: + ProtocolConfig associated with name. + + Raises: + KeyError if there is no protocol for name. + """ + return self.__by_name[name.lower()] + + def lookup_by_content_type(self, content_type): + """Look up a ProtocolConfig by content-type. + + Args: + content_type: Content-type to find protocol configuration for. + + Returns: + ProtocolConfig associated with content-type. + + Raises: + KeyError if there is no protocol for content-type. + """ + return self.__by_content_type[content_type.lower()] + + @classmethod + def new_default(cls): + """Create default protocols configuration. + + Returns: + New Protocols instance configured for protobuf and protorpc. + """ + protocols = cls() + protocols.add_protocol(protobuf, 'protobuf') + protocols.add_protocol(protojson.ProtoJson.get_default(), 'protojson') + return protocols + + @classmethod + def get_default(cls): + """Get the global default Protocols instance. + + Returns: + Current global default Protocols instance. + """ + default_protocols = cls.__default_protocols + if default_protocols is None: + with cls.__lock: + default_protocols = cls.__default_protocols + if default_protocols is None: + default_protocols = cls.new_default() + cls.__default_protocols = default_protocols + return default_protocols + + @classmethod + def set_default(cls, protocols): + """Set the global default Protocols instance. + + Args: + protocols: A Protocols instance. + + Raises: + TypeError: If protocols is not an instance of Protocols. + """ + if not isinstance(protocols, Protocols): + raise TypeError( + 'Expected value of type "Protocols", found %r' % protocols) + with cls.__lock: + cls.__default_protocols = protocols diff --git a/endpoints/internal/protorpc/remote_test.py b/endpoints/internal/protorpc/remote_test.py new file mode 100644 index 0000000..155dcb8 --- /dev/null +++ b/endpoints/internal/protorpc/remote_test.py @@ -0,0 +1,933 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.remote.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import sys +import types +import unittest +from wsgiref import headers + +from protorpc import descriptor +from protorpc import message_types +from protorpc import messages +from protorpc import protobuf +from protorpc import protojson +from protorpc import remote +from protorpc import test_util +from protorpc import transport + +import mox + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = remote + + +class Request(messages.Message): + """Test request message.""" + + value = messages.StringField(1) + + +class Response(messages.Message): + """Test response message.""" + + value = messages.StringField(1) + + +class MyService(remote.Service): + + @remote.method(Request, Response) + def remote_method(self, request): + response = Response() + response.value = request.value + return response + + +class SimpleRequest(messages.Message): + """Simple request message type used for tests.""" + + param1 = messages.StringField(1) + param2 = messages.StringField(2) + + +class SimpleResponse(messages.Message): + """Simple response message type used for tests.""" + + +class BasicService(remote.Service): + """A basic service with decorated remote method.""" + + def __init__(self): + self.request_ids = [] + + @remote.method(SimpleRequest, SimpleResponse) + def remote_method(self, request): + """BasicService remote_method docstring.""" + self.request_ids.append(id(request)) + return SimpleResponse() + + +class RpcErrorTest(test_util.TestCase): + + def testFromStatus(self): + for state in remote.RpcState: + exception = remote.RpcError.from_state + self.assertEquals(remote.ServerError, + remote.RpcError.from_state('SERVER_ERROR')) + + +class ApplicationErrorTest(test_util.TestCase): + + def testErrorCode(self): + self.assertEquals('blam', + remote.ApplicationError('an error', 'blam').error_name) + + def testStr(self): + self.assertEquals('an error', str(remote.ApplicationError('an error', 1))) + + def testRepr(self): + self.assertEquals("ApplicationError('an error', 1)", + repr(remote.ApplicationError('an error', 1))) + + self.assertEquals("ApplicationError('an error')", + repr(remote.ApplicationError('an error'))) + + +class MethodTest(test_util.TestCase): + """Test remote method decorator.""" + + def testMethod(self): + """Test use of remote decorator.""" + self.assertEquals(SimpleRequest, + BasicService.remote_method.remote.request_type) + self.assertEquals(SimpleResponse, + BasicService.remote_method.remote.response_type) + self.assertTrue(isinstance(BasicService.remote_method.remote.method, + types.FunctionType)) + + def testMethodMessageResolution(self): + """Test use of remote decorator to resolve message types by name.""" + class OtherService(remote.Service): + + @remote.method('SimpleRequest', 'SimpleResponse') + def remote_method(self, request): + pass + + self.assertEquals(SimpleRequest, + OtherService.remote_method.remote.request_type) + self.assertEquals(SimpleResponse, + OtherService.remote_method.remote.response_type) + + def testMethodMessageResolution_NotFound(self): + """Test failure to find message types.""" + class OtherService(remote.Service): + + @remote.method('NoSuchRequest', 'NoSuchResponse') + def remote_method(self, request): + pass + + self.assertRaisesWithRegexpMatch( + messages.DefinitionNotFoundError, + 'Could not find definition for NoSuchRequest', + getattr, + OtherService.remote_method.remote, + 'request_type') + + self.assertRaisesWithRegexpMatch( + messages.DefinitionNotFoundError, + 'Could not find definition for NoSuchResponse', + getattr, + OtherService.remote_method.remote, + 'response_type') + + def testInvocation(self): + """Test that invocation passes request through properly.""" + service = BasicService() + request = SimpleRequest() + self.assertEquals(SimpleResponse(), service.remote_method(request)) + self.assertEquals([id(request)], service.request_ids) + + def testInvocation_WrongRequestType(self): + """Wrong request type passed to remote method.""" + service = BasicService() + + self.assertRaises(remote.RequestError, + service.remote_method, + 'wrong') + + self.assertRaises(remote.RequestError, + service.remote_method, + None) + + self.assertRaises(remote.RequestError, + service.remote_method, + SimpleResponse()) + + def testInvocation_WrongResponseType(self): + """Wrong response type returned from remote method.""" + + class AnotherService(object): + + @remote.method(SimpleRequest, SimpleResponse) + def remote_method(self, unused_request): + return self.return_this + + service = AnotherService() + + service.return_this = 'wrong' + self.assertRaises(remote.ServerError, + service.remote_method, + SimpleRequest()) + service.return_this = None + self.assertRaises(remote.ServerError, + service.remote_method, + SimpleRequest()) + service.return_this = SimpleRequest() + self.assertRaises(remote.ServerError, + service.remote_method, + SimpleRequest()) + + def testBadRequestType(self): + """Test bad request types used in remote definition.""" + + for request_type in (None, 1020, messages.Message, str): + + def declare(): + class BadService(object): + + @remote.method(request_type, SimpleResponse) + def remote_method(self, request): + pass + + self.assertRaises(TypeError, declare) + + def testBadResponseType(self): + """Test bad response types used in remote definition.""" + + for response_type in (None, 1020, messages.Message, str): + + def declare(): + class BadService(object): + + @remote.method(SimpleRequest, response_type) + def remote_method(self, request): + pass + + self.assertRaises(TypeError, declare) + + def testDocString(self): + """Test that the docstring comes from the original method.""" + service = BasicService() + self.assertEquals('BasicService remote_method docstring.', + service.remote_method.__doc__) + + +class GetRemoteMethodTest(test_util.TestCase): + """Test for is_remote_method.""" + + def testGetRemoteMethod(self): + """Test valid remote method detection.""" + + class Service(object): + + @remote.method(Request, Response) + def remote_method(self, request): + pass + + self.assertEquals(Service.remote_method.remote, + remote.get_remote_method_info(Service.remote_method)) + self.assertTrue(Service.remote_method.remote, + remote.get_remote_method_info(Service().remote_method)) + + def testGetNotRemoteMethod(self): + """Test positive result on a remote method.""" + + class NotService(object): + + def not_remote_method(self, request): + pass + + def fn(self): + pass + + class NotReallyRemote(object): + """Test negative result on many bad values for remote methods.""" + + def not_really(self, request): + pass + + not_really.remote = 'something else' + + for not_remote in [NotService.not_remote_method, + NotService().not_remote_method, + NotReallyRemote.not_really, + NotReallyRemote().not_really, + None, + 1, + 'a string', + fn]: + self.assertEquals(None, remote.get_remote_method_info(not_remote)) + + +class RequestStateTest(test_util.TestCase): + """Test request state.""" + + STATE_CLASS = remote.RequestState + + def testConstructor(self): + """Test constructor.""" + state = self.STATE_CLASS(remote_host='remote-host', + remote_address='remote-address', + server_host='server-host', + server_port=10) + self.assertEquals('remote-host', state.remote_host) + self.assertEquals('remote-address', state.remote_address) + self.assertEquals('server-host', state.server_host) + self.assertEquals(10, state.server_port) + + state = self.STATE_CLASS() + self.assertEquals(None, state.remote_host) + self.assertEquals(None, state.remote_address) + self.assertEquals(None, state.server_host) + self.assertEquals(None, state.server_port) + + def testConstructorError(self): + """Test unexpected keyword argument.""" + self.assertRaises(TypeError, + self.STATE_CLASS, + x=10) + + def testRepr(self): + """Test string representation.""" + self.assertEquals('<%s>' % self.STATE_CLASS.__name__, + repr(self.STATE_CLASS())) + self.assertEquals("<%s remote_host='abc'>" % self.STATE_CLASS.__name__, + repr(self.STATE_CLASS(remote_host='abc'))) + self.assertEquals("<%s remote_host='abc' " + "remote_address='def'>" % self.STATE_CLASS.__name__, + repr(self.STATE_CLASS(remote_host='abc', + remote_address='def'))) + self.assertEquals("<%s remote_host='abc' " + "remote_address='def' " + "server_host='ghi'>" % self.STATE_CLASS.__name__, + repr(self.STATE_CLASS(remote_host='abc', + remote_address='def', + server_host='ghi'))) + self.assertEquals("<%s remote_host='abc' " + "remote_address='def' " + "server_host='ghi' " + 'server_port=102>' % self.STATE_CLASS.__name__, + repr(self.STATE_CLASS(remote_host='abc', + remote_address='def', + server_host='ghi', + server_port=102))) + + +class HttpRequestStateTest(RequestStateTest): + + STATE_CLASS = remote.HttpRequestState + + def testHttpMethod(self): + state = remote.HttpRequestState(http_method='GET') + self.assertEquals('GET', state.http_method) + + def testHttpMethod(self): + state = remote.HttpRequestState(service_path='/bar') + self.assertEquals('/bar', state.service_path) + + def testHeadersList(self): + state = remote.HttpRequestState( + headers=[('a', 'b'), ('c', 'd'), ('c', 'e')]) + + self.assertEquals(['a', 'c', 'c'], list(state.headers.keys())) + self.assertEquals(['b'], state.headers.get_all('a')) + self.assertEquals(['d', 'e'], state.headers.get_all('c')) + + def testHeadersDict(self): + state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']}) + + self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys())) + self.assertEquals(['b'], state.headers.get_all('a')) + self.assertEquals(['d', 'e'], state.headers.get_all('c')) + + def testRepr(self): + super(HttpRequestStateTest, self).testRepr() + + self.assertEquals("<%s remote_host='abc' " + "remote_address='def' " + "server_host='ghi' " + 'server_port=102 ' + "http_method='POST' " + "service_path='/bar' " + "headers=[('a', 'b'), ('c', 'd')]>" % + self.STATE_CLASS.__name__, + repr(self.STATE_CLASS(remote_host='abc', + remote_address='def', + server_host='ghi', + server_port=102, + http_method='POST', + service_path='/bar', + headers={'a': 'b', 'c': 'd'}, + ))) + + +class ServiceTest(test_util.TestCase): + """Test Service class.""" + + def testServiceBase_AllRemoteMethods(self): + """Test that service base class has no remote methods.""" + self.assertEquals({}, remote.Service.all_remote_methods()) + + def testAllRemoteMethods(self): + """Test all_remote_methods with properly Service subclass.""" + self.assertEquals({'remote_method': MyService.remote_method}, + MyService.all_remote_methods()) + + def testAllRemoteMethods_SubClass(self): + """Test all_remote_methods on a sub-class of a service.""" + class SubClass(MyService): + + @remote.method(Request, Response) + def sub_class_method(self, request): + pass + + self.assertEquals({'remote_method': SubClass.remote_method, + 'sub_class_method': SubClass.sub_class_method, + }, + SubClass.all_remote_methods()) + + def testOverrideMethod(self): + """Test that trying to override a remote method with remote decorator.""" + class SubClass(MyService): + + def remote_method(self, request): + response = super(SubClass, self).remote_method(request) + response.value = '(%s)' % response.value + return response + + self.assertEquals({'remote_method': SubClass.remote_method, + }, + SubClass.all_remote_methods()) + + instance = SubClass() + self.assertEquals('(Hello)', + instance.remote_method(Request(value='Hello')).value) + self.assertEquals(Request, SubClass.remote_method.remote.request_type) + self.assertEquals(Response, SubClass.remote_method.remote.response_type) + + def testOverrideMethodWithRemote(self): + """Test trying to override a remote method with remote decorator.""" + def do_override(): + class SubClass(MyService): + + @remote.method(Request, Response) + def remote_method(self, request): + pass + + self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, + 'Do not use method decorator when ' + 'overloading remote method remote_method ' + 'on service SubClass', + do_override) + + def testOverrideMethodWithInvalidValue(self): + """Test trying to override a remote method with remote decorator.""" + def do_override(bad_value): + class SubClass(MyService): + + remote_method = bad_value + + for bad_value in [None, 1, 'string', {}]: + self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, + 'Must override remote_method in ' + 'SubClass with a method', + do_override, bad_value) + + def testCallingRemoteMethod(self): + """Test invoking a remote method.""" + expected = Response() + expected.value = 'what was passed in' + + request = Request() + request.value = 'what was passed in' + + service = MyService() + self.assertEquals(expected, service.remote_method(request)) + + def testFactory(self): + """Test using factory to pass in state.""" + class StatefulService(remote.Service): + + def __init__(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + state = [1, 2, 3] + + factory = StatefulService.new_factory(1, state) + + module_name = ServiceTest.__module__ + pattern = ('Creates new instances of service StatefulService.\n\n' + 'Returns:\n' + ' New instance of %s.StatefulService.' % module_name) + self.assertEqual(pattern, factory.__doc__) + self.assertEquals('StatefulService_service_factory', factory.__name__) + self.assertEquals(StatefulService, factory.service_class) + + service = factory() + self.assertEquals(1, service.a) + self.assertEquals(id(state), id(service.b)) + self.assertEquals(None, service.c) + + factory = StatefulService.new_factory(2, b=3, c=4) + service = factory() + self.assertEquals(2, service.a) + self.assertEquals(3, service.b) + self.assertEquals(4, service.c) + + def testFactoryError(self): + """Test misusing a factory.""" + # Passing positional argument that is not accepted by class. + self.assertRaises(TypeError, remote.Service.new_factory(1)) + + # Passing keyword argument that is not accepted by class. + self.assertRaises(TypeError, remote.Service.new_factory(x=1)) + + class StatefulService(remote.Service): + + def __init__(self, a): + pass + + # Missing required parameter. + self.assertRaises(TypeError, StatefulService.new_factory()) + + def testDefinitionName(self): + """Test getting service definition name.""" + class TheService(remote.Service): + pass + + module_name = test_util.get_module_name(ServiceTest) + self.assertEqual(TheService.definition_name(), + '%s.TheService' % module_name) + self.assertTrue(TheService.outer_definition_name(), + module_name) + self.assertTrue(TheService.definition_package(), + module_name) + + def testDefinitionNameWithPackage(self): + """Test getting service definition name when package defined.""" + global package + package = 'my.package' + try: + class TheService(remote.Service): + pass + + self.assertEquals('my.package.TheService', TheService.definition_name()) + self.assertEquals('my.package', TheService.outer_definition_name()) + self.assertEquals('my.package', TheService.definition_package()) + finally: + del package + + def testDefinitionNameWithNoModule(self): + """Test getting service definition name when package defined.""" + module = sys.modules[__name__] + try: + del sys.modules[__name__] + class TheService(remote.Service): + pass + + self.assertEquals('TheService', TheService.definition_name()) + self.assertEquals(None, TheService.outer_definition_name()) + self.assertEquals(None, TheService.definition_package()) + finally: + sys.modules[__name__] = module + + +class StubTest(test_util.TestCase): + + def setUp(self): + self.mox = mox.Mox() + self.transport = self.mox.CreateMockAnything() + + def testDefinitionName(self): + self.assertEquals(BasicService.definition_name(), + BasicService.Stub.definition_name()) + self.assertEquals(BasicService.outer_definition_name(), + BasicService.Stub.outer_definition_name()) + self.assertEquals(BasicService.definition_package(), + BasicService.Stub.definition_package()) + + def testRemoteMethods(self): + self.assertEquals(BasicService.all_remote_methods(), + BasicService.Stub.all_remote_methods()) + + def testSync_WithRequest(self): + stub = BasicService.Stub(self.transport) + + request = SimpleRequest() + request.param1 = 'val1' + request.param2 = 'val2' + response = SimpleResponse() + + rpc = transport.Rpc(request) + rpc.set_response(response) + self.transport.send_rpc(BasicService.remote_method.remote, + request).AndReturn(rpc) + + self.mox.ReplayAll() + + self.assertEquals(SimpleResponse(), stub.remote_method(request)) + + self.mox.VerifyAll() + + def testSync_WithKwargs(self): + stub = BasicService.Stub(self.transport) + + + request = SimpleRequest() + request.param1 = 'val1' + request.param2 = 'val2' + response = SimpleResponse() + + rpc = transport.Rpc(request) + rpc.set_response(response) + self.transport.send_rpc(BasicService.remote_method.remote, + request).AndReturn(rpc) + + self.mox.ReplayAll() + + self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1', + param2='val2')) + + self.mox.VerifyAll() + + def testAsync_WithRequest(self): + stub = BasicService.Stub(self.transport) + + request = SimpleRequest() + request.param1 = 'val1' + request.param2 = 'val2' + response = SimpleResponse() + + rpc = transport.Rpc(request) + + self.transport.send_rpc(BasicService.remote_method.remote, + request).AndReturn(rpc) + + self.mox.ReplayAll() + + self.assertEquals(rpc, stub.async.remote_method(request)) + + self.mox.VerifyAll() + + def testAsync_WithKwargs(self): + stub = BasicService.Stub(self.transport) + + request = SimpleRequest() + request.param1 = 'val1' + request.param2 = 'val2' + response = SimpleResponse() + + rpc = transport.Rpc(request) + + self.transport.send_rpc(BasicService.remote_method.remote, + request).AndReturn(rpc) + + self.mox.ReplayAll() + + self.assertEquals(rpc, stub.async.remote_method(param1='val1', + param2='val2')) + + self.mox.VerifyAll() + + def testAsync_WithRequestAndKwargs(self): + stub = BasicService.Stub(self.transport) + + request = SimpleRequest() + request.param1 = 'val1' + request.param2 = 'val2' + response = SimpleResponse() + + self.mox.ReplayAll() + + self.assertRaisesWithRegexpMatch( + TypeError, + r'May not provide both args and kwargs', + stub.async.remote_method, + request, + param1='val1', + param2='val2') + + self.mox.VerifyAll() + + def testAsync_WithTooManyPositionals(self): + stub = BasicService.Stub(self.transport) + + request = SimpleRequest() + request.param1 = 'val1' + request.param2 = 'val2' + response = SimpleResponse() + + self.mox.ReplayAll() + + self.assertRaisesWithRegexpMatch( + TypeError, + r'remote_method\(\) takes at most 2 positional arguments \(3 given\)', + stub.async.remote_method, + request, 'another value') + + self.mox.VerifyAll() + + +class IsErrorStatusTest(test_util.TestCase): + + def testIsError(self): + for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING): + status = remote.RpcStatus(state=state) + self.assertTrue(remote.is_error_status(status)) + + def testIsNotError(self): + for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING): + status = remote.RpcStatus(state=state) + self.assertFalse(remote.is_error_status(status)) + + def testStateNone(self): + self.assertRaises(messages.ValidationError, + remote.is_error_status, remote.RpcStatus()) + + +class CheckRpcStatusTest(test_util.TestCase): + + def testStateNone(self): + self.assertRaises(messages.ValidationError, + remote.check_rpc_status, remote.RpcStatus()) + + def testNoError(self): + for state in (remote.RpcState.OK, remote.RpcState.RUNNING): + remote.check_rpc_status(remote.RpcStatus(state=state)) + + def testErrorState(self): + status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR, + error_message='a request error') + self.assertRaisesWithRegexpMatch(remote.RequestError, + 'a request error', + remote.check_rpc_status, status) + + def testApplicationErrorState(self): + status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, + error_message='an application error', + error_name='blam') + try: + remote.check_rpc_status(status) + self.fail('Should have raised application error.') + except remote.ApplicationError as err: + self.assertEquals('an application error', str(err)) + self.assertEquals('blam', err.error_name) + + +class ProtocolConfigTest(test_util.TestCase): + + def testConstructor(self): + config = remote.ProtocolConfig( + protojson, + 'proto1', + 'application/X-Json', + iter(['text/Json', 'text/JavaScript'])) + self.assertEquals(protojson, config.protocol) + self.assertEquals('proto1', config.name) + self.assertEquals('application/x-json', config.default_content_type) + self.assertEquals(('text/json', 'text/javascript'), + config.alternate_content_types) + self.assertEquals(('application/x-json', 'text/json', 'text/javascript'), + config.content_types) + + def testConstructorDefaults(self): + config = remote.ProtocolConfig(protojson, 'proto2') + self.assertEquals(protojson, config.protocol) + self.assertEquals('proto2', config.name) + self.assertEquals('application/json', config.default_content_type) + self.assertEquals(('application/x-javascript', + 'text/javascript', + 'text/x-javascript', + 'text/x-json', + 'text/json'), + config.alternate_content_types) + self.assertEquals(('application/json', + 'application/x-javascript', + 'text/javascript', + 'text/x-javascript', + 'text/x-json', + 'text/json'), config.content_types) + + def testEmptyAlternativeTypes(self): + config = remote.ProtocolConfig(protojson, 'proto2', + alternative_content_types=()) + self.assertEquals(protojson, config.protocol) + self.assertEquals('proto2', config.name) + self.assertEquals('application/json', config.default_content_type) + self.assertEquals((), config.alternate_content_types) + self.assertEquals(('application/json',), config.content_types) + + def testDuplicateContentTypes(self): + self.assertRaises(remote.ServiceConfigurationError, + remote.ProtocolConfig, + protojson, + 'json', + 'text/plain', + ('text/plain',)) + + self.assertRaises(remote.ServiceConfigurationError, + remote.ProtocolConfig, + protojson, + 'json', + 'text/plain', + ('text/html', 'text/html')) + + def testEncodeMessage(self): + config = remote.ProtocolConfig(protojson, 'proto2') + encoded_message = config.encode_message( + remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, + error_message='bad error')) + + # Convert back to a dictionary from JSON. + dict_message = protojson.json.loads(encoded_message) + self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'}, + dict_message) + + def testDecodeMessage(self): + config = remote.ProtocolConfig(protojson, 'proto2') + self.assertEquals( + remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, + error_message="bad error"), + config.decode_message( + remote.RpcStatus, + '{"state": "SERVER_ERROR", "error_message": "bad error"}')) + + +class ProtocolsTest(test_util.TestCase): + + def setUp(self): + self.protocols = remote.Protocols() + + def testEmpty(self): + self.assertEquals((), self.protocols.names) + self.assertEquals((), self.protocols.content_types) + + def testAddProtocolAllDefaults(self): + self.protocols.add_protocol(protojson, 'json') + self.assertEquals(('json',), self.protocols.names) + self.assertEquals(('application/json', + 'application/x-javascript', + 'text/javascript', + 'text/json', + 'text/x-javascript', + 'text/x-json'), + self.protocols.content_types) + + def testAddProtocolNoDefaultAlternatives(self): + class Protocol(object): + CONTENT_TYPE = 'text/plain' + + self.protocols.add_protocol(Protocol, 'text') + self.assertEquals(('text',), self.protocols.names) + self.assertEquals(('text/plain',), self.protocols.content_types) + + def testAddProtocolOverrideDefaults(self): + self.protocols.add_protocol(protojson, 'json', + default_content_type='text/blar', + alternative_content_types=('text/blam', + 'text/blim')) + self.assertEquals(('json',), self.protocols.names) + self.assertEquals(('text/blam', 'text/blar', 'text/blim'), + self.protocols.content_types) + + def testLookupByName(self): + self.protocols.add_protocol(protojson, 'json') + self.protocols.add_protocol(protojson, 'json2', + default_content_type='text/plain', + alternative_content_types=()) + + self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name) + self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name) + + def testLookupByContentType(self): + self.protocols.add_protocol(protojson, 'json') + self.protocols.add_protocol(protojson, 'json2', + default_content_type='text/plain', + alternative_content_types=()) + + self.assertEquals( + 'json', + self.protocols.lookup_by_content_type('AppliCation/Json').name) + + self.assertEquals( + 'json', + self.protocols.lookup_by_content_type('text/x-Json').name) + + self.assertEquals( + 'json2', + self.protocols.lookup_by_content_type('text/Plain').name) + + def testNewDefault(self): + protocols = remote.Protocols.new_default() + self.assertEquals(('protobuf', 'protojson'), protocols.names) + + protobuf_protocol = protocols.lookup_by_name('protobuf') + self.assertEquals(protobuf, protobuf_protocol.protocol) + + protojson_protocol = protocols.lookup_by_name('protojson') + self.assertEquals(protojson.ProtoJson.get_default(), + protojson_protocol.protocol) + + def testGetDefaultProtocols(self): + protocols = remote.Protocols.get_default() + self.assertEquals(('protobuf', 'protojson'), protocols.names) + + protobuf_protocol = protocols.lookup_by_name('protobuf') + self.assertEquals(protobuf, protobuf_protocol.protocol) + + protojson_protocol = protocols.lookup_by_name('protojson') + self.assertEquals(protojson.ProtoJson.get_default(), + protojson_protocol.protocol) + + self.assertTrue(protocols is remote.Protocols.get_default()) + + def testSetDefaultProtocols(self): + protocols = remote.Protocols() + remote.Protocols.set_default(protocols) + self.assertTrue(protocols is remote.Protocols.get_default()) + + def testSetDefaultWithoutProtocols(self): + self.assertRaises(TypeError, remote.Protocols.set_default, None) + self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols') + self.assertRaises(TypeError, remote.Protocols.set_default, {}) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/static/base.html b/endpoints/internal/protorpc/static/base.html new file mode 100644 index 0000000..a62db7c --- /dev/null +++ b/endpoints/internal/protorpc/static/base.html @@ -0,0 +1,57 @@ + + + + + + {% block title%}Need title{% endblock %} + + + + + + + + + {% block top %}Need top{% endblock %} + +
+ + {% block body %}Need body{% endblock %} + + + diff --git a/endpoints/internal/protorpc/static/forms.html b/endpoints/internal/protorpc/static/forms.html new file mode 100644 index 0000000..9ba22ec --- /dev/null +++ b/endpoints/internal/protorpc/static/forms.html @@ -0,0 +1,31 @@ + + +{% extends 'base.html' %} + +{% block title %}ProtoRPC Methods for {{hostname|escape}}{% endblock %} + +{% block top %} +

ProtoRPC Methods for {{hostname|escape}}

+{% endblock %} + +{% block body %} +
+{% endblock %} + +{% block call %} +loadServices(showMethods); +{% endblock %} diff --git a/endpoints/internal/protorpc/static/forms.js b/endpoints/internal/protorpc/static/forms.js new file mode 100644 index 0000000..3c59252 --- /dev/null +++ b/endpoints/internal/protorpc/static/forms.js @@ -0,0 +1,685 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +/** + * @fileoverview Render form appropriate for RPC method. + * @author rafek@google.com (Rafe Kaplan) + */ + + +var FORM_VISIBILITY = { + SHOW_FORM: 'Show Form', + HIDE_FORM: 'Hide Form' +}; + + +var LABEL = { + OPTIONAL: 'OPTIONAL', + REQUIRED: 'REQUIRED', + REPEATED: 'REPEATED' +}; + + +var objectId = 0; + + +/** + * Variants defined in protorpc/messages.py. + */ +var VARIANT = { + DOUBLE: 'DOUBLE', + FLOAT: 'FLOAT', + INT64: 'INT64', + UINT64: 'UINT64', + INT32: 'INT32', + BOOL: 'BOOL', + STRING: 'STRING', + MESSAGE: 'MESSAGE', + BYTES: 'BYTES', + UINT32: 'UINT32', + ENUM: 'ENUM', + SINT32: 'SINT32', + SINT64: 'SINT64' +}; + + +/** + * Data structure used to represent a form to data element. + * @param {Object} field Field descriptor that form element represents. + * @param {Object} container Element that contains field. + * @return {FormElement} New object representing a form element. Element + * starts enabled. + * @constructor + */ +function FormElement(field, container) { + this.field = field; + this.container = container; + this.enabled = true; +} + + +/** + * Display error message in error panel. + * @param {string} message Message to display in panel. + */ +function error(message) { + $('
').appendTo($('#error-messages')).text(message); +} + + +/** + * Display request errors in error panel. + * @param {object} XMLHttpRequest object. + */ +function handleRequestError(response) { + var contentType = response.getResponseHeader('content-type'); + if (contentType == 'application/json') { + var response_error = $.parseJSON(response.responseText); + var error_message = response_error.error_message; + if (error.state == 'APPLICATION_ERROR' && error.error_name) { + error_message = error_message + ' (' + error.error_name + ')'; + } + } else { + error_message = '' + response.status + ': ' + response.statusText; + } + + error(error_message); +} + + +/** + * Send JSON RPC to remote method. + * @param {string} path Path of service on originating server to send request. + * @param {string} method Name of method to invoke. + * @param {Object} request Message to send as request. + * @param {function} on_success Function to call upon successful request. + */ +function sendRequest(path, method, request, onSuccess) { + $.ajax({url: path + '.' + method, + type: 'POST', + contentType: 'application/json', + data: $.toJSON(request), + dataType: 'json', + success: onSuccess, + error: handleRequestError + }); +} + + +/** + * Create callback that enables and disables field element when associated + * checkbox is clicked. + * @param {Element} checkbox Checkbox that will be clicked. + * @param {FormElement} form Form element that will be toggled for editing. + * @param {Object} disableMessage HTML element to display in place of element. + * @return Callback that is invoked every time checkbox is clicked. + */ +function toggleInput(checkbox, form, disableMessage) { + return function() { + var checked = checkbox.checked; + if (checked) { + buildIndividualForm(form); + form.enabled = true; + disableMessage.hide(); + } else { + form.display.empty(); + form.enabled = false; + disableMessage.show(); + } + }; +} + + +/** + * Build an enum field. + * @param {FormElement} form Form to build element for. + */ +function buildEnumField(form) { + form.descriptor = enumDescriptors[form.field.type_name]; + form.input = $(''); + form.input[0].checked = Boolean(form.field.default_value); +} + + +/** + * Build text field. + * @param {FormElement} form Form to build element for. + */ +function buildTextField(form) { + form.input = $(''); + form.input. + attr('value', form.field.default_value || ''); +} + + +/** + * Build individual input element. + * @param {FormElement} form Form to build element for. + */ +function buildIndividualForm(form) { + form.required = form.label == LABEL.REQUIRED; + + if (form.field.variant == VARIANT.ENUM) { + buildEnumField(form); + } else if (form.field.variant == VARIANT.MESSAGE) { + buildMessageField(form); + } else if (form.field.variant == VARIANT.BOOL) { + buildBooleanField(form); + } else { + buildTextField(form); + } + + form.display.append(form.input); + + // TODO: Handle base64 encoding for BYTES field. + if (form.field.variant == VARIANT.BYTES) { + $("use base64 encoding").appendTo(form.display); + } +} + + +/** + * Add repeated field. This function is called when an item is added + * @param {FormElement} form Repeated form element to create item for. + */ +function addRepeatedFieldItem(form) { + var row = $('').appendTo(form.display); + subForm = new FormElement(form.field, row); + form.fields.push(subForm); + buildFieldForm(subForm, false); +} + + +/** + * Build repeated field. Contains a button that can be used for adding new + * items. + * @param {FormElement} form Form to build element for. + */ +function buildRepeatedForm(form) { + form.fields = []; + form.display = $(''). + appendTo(form.container); + var header_row = $('').appendTo(form.display); + var header = $('').appendTo(form.table); + var fieldForm = new FormElement(field, row); + fieldForm.parent = form; + buildFieldForm(fieldForm, true); + form.fields.push(fieldForm); + }); + } +} + + +/** + * HTML Escape a string + */ +function htmlEscape(value) { + if (typeof(value) == "string") { + return value + .replace(/&/g, '&') + .replace(/>/g, '>') + .replace(/'; + }); + result += indentation + ']'; + } else { + result += '{
'; + $.each(value, function(name, item) { + result += (indentation + htmlEscape(name) + ': ' + + formatJSON(item, indent + 1) + ',
'); + }); + result += indentation + '}'; + } + } else { + result += htmlEscape(value); + } + + return result; +} + + +/** + * Construct array from repeated form element. + * @param {FormElement} form Form element to build array from. + * @return {Array} Array of repeated elements read from input form. + */ +function fromRepeatedForm(form) { + var values = []; + $.each(form.fields, function(index, subForm) { + values.push(fromIndividualForm(subForm)); + }); + return values; +} + + +/** + * Construct value from individual form element. + * @param {FormElement} form Form element to get value from. + * @return {string, Float, Integer, Boolean, object} Value extracted from + * individual field. The type depends on the field variant. + */ +function fromIndividualForm(form) { + switch(form.field.variant) { + case VARIANT.MESSAGE: + return fromMessageForm(form); + break; + + case VARIANT.DOUBLE: + case VARIANT.FLOAT: + return parseFloat(form.input.val()); + + case VARIANT.BOOL: + return form.input[0].checked; + break; + + case VARIANT.ENUM: + case VARIANT.STRING: + case VARIANT.BYTES: + return form.input.val(); + + default: + break; + } + return parseInt(form.input.val(), 10); +} + + +/** + * Extract entire message from a complete form. + * @param {FormElement} form Form to extract message from. + * @return {Object} Fully populated message object ready to transmit + * as JSON message. + */ +function fromMessageForm(form) { + var message = {}; + $.each(form.fields, function(index, subForm) { + if (subForm.enabled) { + var subMessage = undefined; + if (subForm.field.label == LABEL.REPEATED) { + subMessage = fromRepeatedForm(subForm); + } else { + subMessage = fromIndividualForm(subForm); + } + + message[subForm.field.name] = subMessage; + } + }); + + return message; +} + + +/** + * Send form as an RPC. Extracts message from root form and transmits to + * originating ProtoRPC server. Response is formatted as JSON and displayed + * to user. + */ +function sendForm() { + $('#error-messages').empty(); + $('#form-response').empty(); + message = fromMessageForm(root_form); + if (message === null) { + return; + } + + sendRequest(servicePath, methodName, message, function(response) { + $('#form-response').html(formatJSON(response, 0)); + hideForm(); + }); +} + + +/** + * Reset form to original state. Deletes existing form and rebuilds a new + * one from scratch. + */ +function resetForm() { + var panel = $('#form-panel'); + var serviceType = serviceMap[servicePath]; + var service = serviceDescriptors[serviceType]; + + panel.empty(); + + function formGenerationError(message) { + error(message); + panel.html('
' + + 'There was an error generating the service form' + + '
'); + } + + // Find method. + var requestTypeName = null; + $.each(service.methods, function(index, method) { + if (method.name == methodName) { + requestTypeName = method.request_type; + } + }); + + if (!requestTypeName) { + formGenerationError('No such method definition for: ' + methodName); + return; + } + + requestType = messageDescriptors[requestTypeName]; + if (!requestType) { + formGenerationError('No such message-type: ' + requestTypeName); + return; + } + + var root = $('
').appendTo(header_row); + var add_button = $(''); + + // Set name. + if (allowRepeated) { + var nameData = $(''); + nameData.text(form.field.name + ':'); + form.container.append(nameData); + } + + // Set input. + form.repeated = form.field.label == LABEL.REPEATED; + if (allowRepeated && form.repeated) { + inputData.attr('colspan', '2'); + buildRepeatedForm(form); + } else { + if (!allowRepeated) { + inputData.attr('colspan', '2'); + } + + form.display = $('
'); + + var controlData = $('
'); + if (form.field.label != LABEL.REQUIRED && allowRepeated) { + form.enabled = false; + var checkbox_id = 'checkbox-' + objectId; + objectId++; + $('').appendTo(controlData); + var checkbox = $('').appendTo(controlData); + var disableMessage = $('
').appendTo(inputData); + checkbox.change(toggleInput(checkbox[0], form, disableMessage)); + } else { + buildIndividualForm(form); + } + + if (form.repeated) { + // TODO: Implement deletion of repeated items. Needs to delete + // from DOM and also delete from form model. + } + + form.container.append(controlData); + } + + inputData.append(form.display); + form.container.append(inputData); +} + + +/** + * Top level function for building an entire message form. Called once at form + * creation and may be called again for nested message fields. Constructs a + * a table and builds a row for each sub-field. + * @params {FormElement} form Form to build message form for. + */ +function buildMessageForm(form, messageType) { + form.fields = []; + form.descriptor = messageType; + if (messageType.fields) { + $.each(messageType.fields, function(index, field) { + var row = $('
'). + appendTo(panel); + + root_form = new FormElement(null, null); + root_form.table = root; + buildMessageForm(root_form, requestType); + $('
a"; +var e=d.getElementsByTagName("*"),j=d.getElementsByTagName("a")[0];if(!(!e||!e.length||!j)){c.support={leadingWhitespace:d.firstChild.nodeType===3,tbody:!d.getElementsByTagName("tbody").length,htmlSerialize:!!d.getElementsByTagName("link").length,style:/red/.test(j.getAttribute("style")),hrefNormalized:j.getAttribute("href")==="/a",opacity:/^0.55$/.test(j.style.opacity),cssFloat:!!j.style.cssFloat,checkOn:d.getElementsByTagName("input")[0].value==="on",optSelected:s.createElement("select").appendChild(s.createElement("option")).selected, +parentNode:d.removeChild(d.appendChild(s.createElement("div"))).parentNode===null,deleteExpando:true,checkClone:false,scriptEval:false,noCloneEvent:true,boxModel:null};b.type="text/javascript";try{b.appendChild(s.createTextNode("window."+f+"=1;"))}catch(i){}a.insertBefore(b,a.firstChild);if(A[f]){c.support.scriptEval=true;delete A[f]}try{delete b.test}catch(o){c.support.deleteExpando=false}a.removeChild(b);if(d.attachEvent&&d.fireEvent){d.attachEvent("onclick",function k(){c.support.noCloneEvent= +false;d.detachEvent("onclick",k)});d.cloneNode(true).fireEvent("onclick")}d=s.createElement("div");d.innerHTML="";a=s.createDocumentFragment();a.appendChild(d.firstChild);c.support.checkClone=a.cloneNode(true).cloneNode(true).lastChild.checked;c(function(){var k=s.createElement("div");k.style.width=k.style.paddingLeft="1px";s.body.appendChild(k);c.boxModel=c.support.boxModel=k.offsetWidth===2;s.body.removeChild(k).style.display="none"});a=function(k){var n= +s.createElement("div");k="on"+k;var r=k in n;if(!r){n.setAttribute(k,"return;");r=typeof n[k]==="function"}return r};c.support.submitBubbles=a("submit");c.support.changeBubbles=a("change");a=b=d=e=j=null}})();c.props={"for":"htmlFor","class":"className",readonly:"readOnly",maxlength:"maxLength",cellspacing:"cellSpacing",rowspan:"rowSpan",colspan:"colSpan",tabindex:"tabIndex",usemap:"useMap",frameborder:"frameBorder"};var G="jQuery"+J(),Ya=0,za={};c.extend({cache:{},expando:G,noData:{embed:true,object:true, +applet:true},data:function(a,b,d){if(!(a.nodeName&&c.noData[a.nodeName.toLowerCase()])){a=a==A?za:a;var f=a[G],e=c.cache;if(!f&&typeof b==="string"&&d===w)return null;f||(f=++Ya);if(typeof b==="object"){a[G]=f;e[f]=c.extend(true,{},b)}else if(!e[f]){a[G]=f;e[f]={}}a=e[f];if(d!==w)a[b]=d;return typeof b==="string"?a[b]:a}},removeData:function(a,b){if(!(a.nodeName&&c.noData[a.nodeName.toLowerCase()])){a=a==A?za:a;var d=a[G],f=c.cache,e=f[d];if(b){if(e){delete e[b];c.isEmptyObject(e)&&c.removeData(a)}}else{if(c.support.deleteExpando)delete a[c.expando]; +else a.removeAttribute&&a.removeAttribute(c.expando);delete f[d]}}}});c.fn.extend({data:function(a,b){if(typeof a==="undefined"&&this.length)return c.data(this[0]);else if(typeof a==="object")return this.each(function(){c.data(this,a)});var d=a.split(".");d[1]=d[1]?"."+d[1]:"";if(b===w){var f=this.triggerHandler("getData"+d[1]+"!",[d[0]]);if(f===w&&this.length)f=c.data(this[0],a);return f===w&&d[1]?this.data(d[0]):f}else return this.trigger("setData"+d[1]+"!",[d[0],b]).each(function(){c.data(this, +a,b)})},removeData:function(a){return this.each(function(){c.removeData(this,a)})}});c.extend({queue:function(a,b,d){if(a){b=(b||"fx")+"queue";var f=c.data(a,b);if(!d)return f||[];if(!f||c.isArray(d))f=c.data(a,b,c.makeArray(d));else f.push(d);return f}},dequeue:function(a,b){b=b||"fx";var d=c.queue(a,b),f=d.shift();if(f==="inprogress")f=d.shift();if(f){b==="fx"&&d.unshift("inprogress");f.call(a,function(){c.dequeue(a,b)})}}});c.fn.extend({queue:function(a,b){if(typeof a!=="string"){b=a;a="fx"}if(b=== +w)return c.queue(this[0],a);return this.each(function(){var d=c.queue(this,a,b);a==="fx"&&d[0]!=="inprogress"&&c.dequeue(this,a)})},dequeue:function(a){return this.each(function(){c.dequeue(this,a)})},delay:function(a,b){a=c.fx?c.fx.speeds[a]||a:a;b=b||"fx";return this.queue(b,function(){var d=this;setTimeout(function(){c.dequeue(d,b)},a)})},clearQueue:function(a){return this.queue(a||"fx",[])}});var Aa=/[\n\t]/g,ca=/\s+/,Za=/\r/g,$a=/href|src|style/,ab=/(button|input)/i,bb=/(button|input|object|select|textarea)/i, +cb=/^(a|area)$/i,Ba=/radio|checkbox/;c.fn.extend({attr:function(a,b){return X(this,a,b,true,c.attr)},removeAttr:function(a){return this.each(function(){c.attr(this,a,"");this.nodeType===1&&this.removeAttribute(a)})},addClass:function(a){if(c.isFunction(a))return this.each(function(n){var r=c(this);r.addClass(a.call(this,n,r.attr("class")))});if(a&&typeof a==="string")for(var b=(a||"").split(ca),d=0,f=this.length;d-1)return true;return false},val:function(a){if(a===w){var b=this[0];if(b){if(c.nodeName(b,"option"))return(b.attributes.value||{}).specified?b.value:b.text;if(c.nodeName(b,"select")){var d=b.selectedIndex,f=[],e=b.options;b=b.type==="select-one";if(d<0)return null;var j=b?d:0;for(d=b?d+1:e.length;j=0;else if(c.nodeName(this,"select")){var u=c.makeArray(r);c("option",this).each(function(){this.selected= +c.inArray(c(this).val(),u)>=0});if(!u.length)this.selectedIndex=-1}else this.value=r}})}});c.extend({attrFn:{val:true,css:true,html:true,text:true,data:true,width:true,height:true,offset:true},attr:function(a,b,d,f){if(!a||a.nodeType===3||a.nodeType===8)return w;if(f&&b in c.attrFn)return c(a)[b](d);f=a.nodeType!==1||!c.isXMLDoc(a);var e=d!==w;b=f&&c.props[b]||b;if(a.nodeType===1){var j=$a.test(b);if(b in a&&f&&!j){if(e){b==="type"&&ab.test(a.nodeName)&&a.parentNode&&c.error("type property can't be changed"); +a[b]=d}if(c.nodeName(a,"form")&&a.getAttributeNode(b))return a.getAttributeNode(b).nodeValue;if(b==="tabIndex")return(b=a.getAttributeNode("tabIndex"))&&b.specified?b.value:bb.test(a.nodeName)||cb.test(a.nodeName)&&a.href?0:w;return a[b]}if(!c.support.style&&f&&b==="style"){if(e)a.style.cssText=""+d;return a.style.cssText}e&&a.setAttribute(b,""+d);a=!c.support.hrefNormalized&&f&&j?a.getAttribute(b,2):a.getAttribute(b);return a===null?w:a}return c.style(a,b,d)}});var O=/\.(.*)$/,db=function(a){return a.replace(/[^\w\s\.\|`]/g, +function(b){return"\\"+b})};c.event={add:function(a,b,d,f){if(!(a.nodeType===3||a.nodeType===8)){if(a.setInterval&&a!==A&&!a.frameElement)a=A;var e,j;if(d.handler){e=d;d=e.handler}if(!d.guid)d.guid=c.guid++;if(j=c.data(a)){var i=j.events=j.events||{},o=j.handle;if(!o)j.handle=o=function(){return typeof c!=="undefined"&&!c.event.triggered?c.event.handle.apply(o.elem,arguments):w};o.elem=a;b=b.split(" ");for(var k,n=0,r;k=b[n++];){j=e?c.extend({},e):{handler:d,data:f};if(k.indexOf(".")>-1){r=k.split("."); +k=r.shift();j.namespace=r.slice(0).sort().join(".")}else{r=[];j.namespace=""}j.type=k;j.guid=d.guid;var u=i[k],z=c.event.special[k]||{};if(!u){u=i[k]=[];if(!z.setup||z.setup.call(a,f,r,o)===false)if(a.addEventListener)a.addEventListener(k,o,false);else a.attachEvent&&a.attachEvent("on"+k,o)}if(z.add){z.add.call(a,j);if(!j.handler.guid)j.handler.guid=d.guid}u.push(j);c.event.global[k]=true}a=null}}},global:{},remove:function(a,b,d,f){if(!(a.nodeType===3||a.nodeType===8)){var e,j=0,i,o,k,n,r,u,z=c.data(a), +C=z&&z.events;if(z&&C){if(b&&b.type){d=b.handler;b=b.type}if(!b||typeof b==="string"&&b.charAt(0)==="."){b=b||"";for(e in C)c.event.remove(a,e+b)}else{for(b=b.split(" ");e=b[j++];){n=e;i=e.indexOf(".")<0;o=[];if(!i){o=e.split(".");e=o.shift();k=new RegExp("(^|\\.)"+c.map(o.slice(0).sort(),db).join("\\.(?:.*\\.)?")+"(\\.|$)")}if(r=C[e])if(d){n=c.event.special[e]||{};for(B=f||0;B=0){a.type= +e=e.slice(0,-1);a.exclusive=true}if(!d){a.stopPropagation();c.event.global[e]&&c.each(c.cache,function(){this.events&&this.events[e]&&c.event.trigger(a,b,this.handle.elem)})}if(!d||d.nodeType===3||d.nodeType===8)return w;a.result=w;a.target=d;b=c.makeArray(b);b.unshift(a)}a.currentTarget=d;(f=c.data(d,"handle"))&&f.apply(d,b);f=d.parentNode||d.ownerDocument;try{if(!(d&&d.nodeName&&c.noData[d.nodeName.toLowerCase()]))if(d["on"+e]&&d["on"+e].apply(d,b)===false)a.result=false}catch(j){}if(!a.isPropagationStopped()&& +f)c.event.trigger(a,b,f,true);else if(!a.isDefaultPrevented()){f=a.target;var i,o=c.nodeName(f,"a")&&e==="click",k=c.event.special[e]||{};if((!k._default||k._default.call(d,a)===false)&&!o&&!(f&&f.nodeName&&c.noData[f.nodeName.toLowerCase()])){try{if(f[e]){if(i=f["on"+e])f["on"+e]=null;c.event.triggered=true;f[e]()}}catch(n){}if(i)f["on"+e]=i;c.event.triggered=false}}},handle:function(a){var b,d,f,e;a=arguments[0]=c.event.fix(a||A.event);a.currentTarget=this;b=a.type.indexOf(".")<0&&!a.exclusive; +if(!b){d=a.type.split(".");a.type=d.shift();f=new RegExp("(^|\\.)"+d.slice(0).sort().join("\\.(?:.*\\.)?")+"(\\.|$)")}e=c.data(this,"events");d=e[a.type];if(e&&d){d=d.slice(0);e=0;for(var j=d.length;e-1?c.map(a.options,function(f){return f.selected}).join("-"):"";else if(a.nodeName.toLowerCase()==="select")d=a.selectedIndex;return d},fa=function(a,b){var d=a.target,f,e;if(!(!da.test(d.nodeName)||d.readOnly)){f=c.data(d,"_change_data");e=Fa(d);if(a.type!=="focusout"||d.type!=="radio")c.data(d,"_change_data", +e);if(!(f===w||e===f))if(f!=null||e){a.type="change";return c.event.trigger(a,b,d)}}};c.event.special.change={filters:{focusout:fa,click:function(a){var b=a.target,d=b.type;if(d==="radio"||d==="checkbox"||b.nodeName.toLowerCase()==="select")return fa.call(this,a)},keydown:function(a){var b=a.target,d=b.type;if(a.keyCode===13&&b.nodeName.toLowerCase()!=="textarea"||a.keyCode===32&&(d==="checkbox"||d==="radio")||d==="select-multiple")return fa.call(this,a)},beforeactivate:function(a){a=a.target;c.data(a, +"_change_data",Fa(a))}},setup:function(){if(this.type==="file")return false;for(var a in ea)c.event.add(this,a+".specialChange",ea[a]);return da.test(this.nodeName)},teardown:function(){c.event.remove(this,".specialChange");return da.test(this.nodeName)}};ea=c.event.special.change.filters}s.addEventListener&&c.each({focus:"focusin",blur:"focusout"},function(a,b){function d(f){f=c.event.fix(f);f.type=b;return c.event.handle.call(this,f)}c.event.special[b]={setup:function(){this.addEventListener(a, +d,true)},teardown:function(){this.removeEventListener(a,d,true)}}});c.each(["bind","one"],function(a,b){c.fn[b]=function(d,f,e){if(typeof d==="object"){for(var j in d)this[b](j,f,d[j],e);return this}if(c.isFunction(f)){e=f;f=w}var i=b==="one"?c.proxy(e,function(k){c(this).unbind(k,i);return e.apply(this,arguments)}):e;if(d==="unload"&&b!=="one")this.one(d,f,e);else{j=0;for(var o=this.length;j0){y=t;break}}t=t[g]}m[q]=y}}}var f=/((?:\((?:\([^()]+\)|[^()]+)+\)|\[(?:\[[^[\]]*\]|['"][^'"]*['"]|[^[\]'"]+)+\]|\\.|[^ >+~,(\[\\]+)+|[>+~])(\s*,\s*)?((?:.|\r|\n)*)/g, +e=0,j=Object.prototype.toString,i=false,o=true;[0,0].sort(function(){o=false;return 0});var k=function(g,h,l,m){l=l||[];var q=h=h||s;if(h.nodeType!==1&&h.nodeType!==9)return[];if(!g||typeof g!=="string")return l;for(var p=[],v,t,y,S,H=true,M=x(h),I=g;(f.exec(""),v=f.exec(I))!==null;){I=v[3];p.push(v[1]);if(v[2]){S=v[3];break}}if(p.length>1&&r.exec(g))if(p.length===2&&n.relative[p[0]])t=ga(p[0]+p[1],h);else for(t=n.relative[p[0]]?[h]:k(p.shift(),h);p.length;){g=p.shift();if(n.relative[g])g+=p.shift(); +t=ga(g,t)}else{if(!m&&p.length>1&&h.nodeType===9&&!M&&n.match.ID.test(p[0])&&!n.match.ID.test(p[p.length-1])){v=k.find(p.shift(),h,M);h=v.expr?k.filter(v.expr,v.set)[0]:v.set[0]}if(h){v=m?{expr:p.pop(),set:z(m)}:k.find(p.pop(),p.length===1&&(p[0]==="~"||p[0]==="+")&&h.parentNode?h.parentNode:h,M);t=v.expr?k.filter(v.expr,v.set):v.set;if(p.length>0)y=z(t);else H=false;for(;p.length;){var D=p.pop();v=D;if(n.relative[D])v=p.pop();else D="";if(v==null)v=h;n.relative[D](y,v,M)}}else y=[]}y||(y=t);y||k.error(D|| +g);if(j.call(y)==="[object Array]")if(H)if(h&&h.nodeType===1)for(g=0;y[g]!=null;g++){if(y[g]&&(y[g]===true||y[g].nodeType===1&&E(h,y[g])))l.push(t[g])}else for(g=0;y[g]!=null;g++)y[g]&&y[g].nodeType===1&&l.push(t[g]);else l.push.apply(l,y);else z(y,l);if(S){k(S,q,l,m);k.uniqueSort(l)}return l};k.uniqueSort=function(g){if(B){i=o;g.sort(B);if(i)for(var h=1;h":function(g,h){var l=typeof h==="string";if(l&&!/\W/.test(h)){h=h.toLowerCase();for(var m=0,q=g.length;m=0))l||m.push(v);else if(l)h[p]=false;return false},ID:function(g){return g[1].replace(/\\/g,"")},TAG:function(g){return g[1].toLowerCase()}, +CHILD:function(g){if(g[1]==="nth"){var h=/(-?)(\d*)n((?:\+|-)?\d*)/.exec(g[2]==="even"&&"2n"||g[2]==="odd"&&"2n+1"||!/\D/.test(g[2])&&"0n+"+g[2]||g[2]);g[2]=h[1]+(h[2]||1)-0;g[3]=h[3]-0}g[0]=e++;return g},ATTR:function(g,h,l,m,q,p){h=g[1].replace(/\\/g,"");if(!p&&n.attrMap[h])g[1]=n.attrMap[h];if(g[2]==="~=")g[4]=" "+g[4]+" ";return g},PSEUDO:function(g,h,l,m,q){if(g[1]==="not")if((f.exec(g[3])||"").length>1||/^\w/.test(g[3]))g[3]=k(g[3],null,null,h);else{g=k.filter(g[3],h,l,true^q);l||m.push.apply(m, +g);return false}else if(n.match.POS.test(g[0])||n.match.CHILD.test(g[0]))return true;return g},POS:function(g){g.unshift(true);return g}},filters:{enabled:function(g){return g.disabled===false&&g.type!=="hidden"},disabled:function(g){return g.disabled===true},checked:function(g){return g.checked===true},selected:function(g){return g.selected===true},parent:function(g){return!!g.firstChild},empty:function(g){return!g.firstChild},has:function(g,h,l){return!!k(l[3],g).length},header:function(g){return/h\d/i.test(g.nodeName)}, +text:function(g){return"text"===g.type},radio:function(g){return"radio"===g.type},checkbox:function(g){return"checkbox"===g.type},file:function(g){return"file"===g.type},password:function(g){return"password"===g.type},submit:function(g){return"submit"===g.type},image:function(g){return"image"===g.type},reset:function(g){return"reset"===g.type},button:function(g){return"button"===g.type||g.nodeName.toLowerCase()==="button"},input:function(g){return/input|select|textarea|button/i.test(g.nodeName)}}, +setFilters:{first:function(g,h){return h===0},last:function(g,h,l,m){return h===m.length-1},even:function(g,h){return h%2===0},odd:function(g,h){return h%2===1},lt:function(g,h,l){return hl[3]-0},nth:function(g,h,l){return l[3]-0===h},eq:function(g,h,l){return l[3]-0===h}},filter:{PSEUDO:function(g,h,l,m){var q=h[1],p=n.filters[q];if(p)return p(g,l,h,m);else if(q==="contains")return(g.textContent||g.innerText||a([g])||"").indexOf(h[3])>=0;else if(q==="not"){h= +h[3];l=0;for(m=h.length;l=0}},ID:function(g,h){return g.nodeType===1&&g.getAttribute("id")===h},TAG:function(g,h){return h==="*"&&g.nodeType===1||g.nodeName.toLowerCase()===h},CLASS:function(g,h){return(" "+(g.className||g.getAttribute("class"))+" ").indexOf(h)>-1},ATTR:function(g,h){var l=h[1];g=n.attrHandle[l]?n.attrHandle[l](g):g[l]!=null?g[l]:g.getAttribute(l);l=g+"";var m=h[2];h=h[4];return g==null?m==="!=":m=== +"="?l===h:m==="*="?l.indexOf(h)>=0:m==="~="?(" "+l+" ").indexOf(h)>=0:!h?l&&g!==false:m==="!="?l!==h:m==="^="?l.indexOf(h)===0:m==="$="?l.substr(l.length-h.length)===h:m==="|="?l===h||l.substr(0,h.length+1)===h+"-":false},POS:function(g,h,l,m){var q=n.setFilters[h[2]];if(q)return q(g,l,h,m)}}},r=n.match.POS;for(var u in n.match){n.match[u]=new RegExp(n.match[u].source+/(?![^\[]*\])(?![^\(]*\))/.source);n.leftMatch[u]=new RegExp(/(^(?:.|\r|\n)*?)/.source+n.match[u].source.replace(/\\(\d+)/g,function(g, +h){return"\\"+(h-0+1)}))}var z=function(g,h){g=Array.prototype.slice.call(g,0);if(h){h.push.apply(h,g);return h}return g};try{Array.prototype.slice.call(s.documentElement.childNodes,0)}catch(C){z=function(g,h){h=h||[];if(j.call(g)==="[object Array]")Array.prototype.push.apply(h,g);else if(typeof g.length==="number")for(var l=0,m=g.length;l";var l=s.documentElement;l.insertBefore(g,l.firstChild);if(s.getElementById(h)){n.find.ID=function(m,q,p){if(typeof q.getElementById!=="undefined"&&!p)return(q=q.getElementById(m[1]))?q.id===m[1]||typeof q.getAttributeNode!=="undefined"&& +q.getAttributeNode("id").nodeValue===m[1]?[q]:w:[]};n.filter.ID=function(m,q){var p=typeof m.getAttributeNode!=="undefined"&&m.getAttributeNode("id");return m.nodeType===1&&p&&p.nodeValue===q}}l.removeChild(g);l=g=null})();(function(){var g=s.createElement("div");g.appendChild(s.createComment(""));if(g.getElementsByTagName("*").length>0)n.find.TAG=function(h,l){l=l.getElementsByTagName(h[1]);if(h[1]==="*"){h=[];for(var m=0;l[m];m++)l[m].nodeType===1&&h.push(l[m]);l=h}return l};g.innerHTML=""; +if(g.firstChild&&typeof g.firstChild.getAttribute!=="undefined"&&g.firstChild.getAttribute("href")!=="#")n.attrHandle.href=function(h){return h.getAttribute("href",2)};g=null})();s.querySelectorAll&&function(){var g=k,h=s.createElement("div");h.innerHTML="

";if(!(h.querySelectorAll&&h.querySelectorAll(".TEST").length===0)){k=function(m,q,p,v){q=q||s;if(!v&&q.nodeType===9&&!x(q))try{return z(q.querySelectorAll(m),p)}catch(t){}return g(m,q,p,v)};for(var l in g)k[l]=g[l];h=null}}(); +(function(){var g=s.createElement("div");g.innerHTML="
";if(!(!g.getElementsByClassName||g.getElementsByClassName("e").length===0)){g.lastChild.className="e";if(g.getElementsByClassName("e").length!==1){n.order.splice(1,0,"CLASS");n.find.CLASS=function(h,l,m){if(typeof l.getElementsByClassName!=="undefined"&&!m)return l.getElementsByClassName(h[1])};g=null}}})();var E=s.compareDocumentPosition?function(g,h){return!!(g.compareDocumentPosition(h)&16)}: +function(g,h){return g!==h&&(g.contains?g.contains(h):true)},x=function(g){return(g=(g?g.ownerDocument||g:0).documentElement)?g.nodeName!=="HTML":false},ga=function(g,h){var l=[],m="",q;for(h=h.nodeType?[h]:h;q=n.match.PSEUDO.exec(g);){m+=q[0];g=g.replace(n.match.PSEUDO,"")}g=n.relative[g]?g+"*":g;q=0;for(var p=h.length;q=0===d})};c.fn.extend({find:function(a){for(var b=this.pushStack("","find",a),d=0,f=0,e=this.length;f0)for(var j=d;j0},closest:function(a,b){if(c.isArray(a)){var d=[],f=this[0],e,j= +{},i;if(f&&a.length){e=0;for(var o=a.length;e-1:c(f).is(e)){d.push({selector:i,elem:f});delete j[i]}}f=f.parentNode}}return d}var k=c.expr.match.POS.test(a)?c(a,b||this.context):null;return this.map(function(n,r){for(;r&&r.ownerDocument&&r!==b;){if(k?k.index(r)>-1:c(r).is(a))return r;r=r.parentNode}return null})},index:function(a){if(!a||typeof a=== +"string")return c.inArray(this[0],a?c(a):this.parent().children());return c.inArray(a.jquery?a[0]:a,this)},add:function(a,b){a=typeof a==="string"?c(a,b||this.context):c.makeArray(a);b=c.merge(this.get(),a);return this.pushStack(qa(a[0])||qa(b[0])?b:c.unique(b))},andSelf:function(){return this.add(this.prevObject)}});c.each({parent:function(a){return(a=a.parentNode)&&a.nodeType!==11?a:null},parents:function(a){return c.dir(a,"parentNode")},parentsUntil:function(a,b,d){return c.dir(a,"parentNode", +d)},next:function(a){return c.nth(a,2,"nextSibling")},prev:function(a){return c.nth(a,2,"previousSibling")},nextAll:function(a){return c.dir(a,"nextSibling")},prevAll:function(a){return c.dir(a,"previousSibling")},nextUntil:function(a,b,d){return c.dir(a,"nextSibling",d)},prevUntil:function(a,b,d){return c.dir(a,"previousSibling",d)},siblings:function(a){return c.sibling(a.parentNode.firstChild,a)},children:function(a){return c.sibling(a.firstChild)},contents:function(a){return c.nodeName(a,"iframe")? +a.contentDocument||a.contentWindow.document:c.makeArray(a.childNodes)}},function(a,b){c.fn[a]=function(d,f){var e=c.map(this,b,d);eb.test(a)||(f=d);if(f&&typeof f==="string")e=c.filter(f,e);e=this.length>1?c.unique(e):e;if((this.length>1||gb.test(f))&&fb.test(a))e=e.reverse();return this.pushStack(e,a,R.call(arguments).join(","))}});c.extend({filter:function(a,b,d){if(d)a=":not("+a+")";return c.find.matches(a,b)},dir:function(a,b,d){var f=[];for(a=a[b];a&&a.nodeType!==9&&(d===w||a.nodeType!==1||!c(a).is(d));){a.nodeType=== +1&&f.push(a);a=a[b]}return f},nth:function(a,b,d){b=b||1;for(var f=0;a;a=a[d])if(a.nodeType===1&&++f===b)break;return a},sibling:function(a,b){for(var d=[];a;a=a.nextSibling)a.nodeType===1&&a!==b&&d.push(a);return d}});var Ja=/ jQuery\d+="(?:\d+|null)"/g,V=/^\s+/,Ka=/(<([\w:]+)[^>]*?)\/>/g,hb=/^(?:area|br|col|embed|hr|img|input|link|meta|param)$/i,La=/<([\w:]+)/,ib=/"},F={option:[1,""],legend:[1,"
","
"],thead:[1,"","
"],tr:[2,"","
"],td:[3,"","
"],col:[2,"","
"],area:[1,"",""],_default:[0,"",""]};F.optgroup=F.option;F.tbody=F.tfoot=F.colgroup=F.caption=F.thead;F.th=F.td;if(!c.support.htmlSerialize)F._default=[1,"div
","
"];c.fn.extend({text:function(a){if(c.isFunction(a))return this.each(function(b){var d= +c(this);d.text(a.call(this,b,d.text()))});if(typeof a!=="object"&&a!==w)return this.empty().append((this[0]&&this[0].ownerDocument||s).createTextNode(a));return c.text(this)},wrapAll:function(a){if(c.isFunction(a))return this.each(function(d){c(this).wrapAll(a.call(this,d))});if(this[0]){var b=c(a,this[0].ownerDocument).eq(0).clone(true);this[0].parentNode&&b.insertBefore(this[0]);b.map(function(){for(var d=this;d.firstChild&&d.firstChild.nodeType===1;)d=d.firstChild;return d}).append(this)}return this}, +wrapInner:function(a){if(c.isFunction(a))return this.each(function(b){c(this).wrapInner(a.call(this,b))});return this.each(function(){var b=c(this),d=b.contents();d.length?d.wrapAll(a):b.append(a)})},wrap:function(a){return this.each(function(){c(this).wrapAll(a)})},unwrap:function(){return this.parent().each(function(){c.nodeName(this,"body")||c(this).replaceWith(this.childNodes)}).end()},append:function(){return this.domManip(arguments,true,function(a){this.nodeType===1&&this.appendChild(a)})}, +prepend:function(){return this.domManip(arguments,true,function(a){this.nodeType===1&&this.insertBefore(a,this.firstChild)})},before:function(){if(this[0]&&this[0].parentNode)return this.domManip(arguments,false,function(b){this.parentNode.insertBefore(b,this)});else if(arguments.length){var a=c(arguments[0]);a.push.apply(a,this.toArray());return this.pushStack(a,"before",arguments)}},after:function(){if(this[0]&&this[0].parentNode)return this.domManip(arguments,false,function(b){this.parentNode.insertBefore(b, +this.nextSibling)});else if(arguments.length){var a=this.pushStack(this,"after",arguments);a.push.apply(a,c(arguments[0]).toArray());return a}},remove:function(a,b){for(var d=0,f;(f=this[d])!=null;d++)if(!a||c.filter(a,[f]).length){if(!b&&f.nodeType===1){c.cleanData(f.getElementsByTagName("*"));c.cleanData([f])}f.parentNode&&f.parentNode.removeChild(f)}return this},empty:function(){for(var a=0,b;(b=this[a])!=null;a++)for(b.nodeType===1&&c.cleanData(b.getElementsByTagName("*"));b.firstChild;)b.removeChild(b.firstChild); +return this},clone:function(a){var b=this.map(function(){if(!c.support.noCloneEvent&&!c.isXMLDoc(this)){var d=this.outerHTML,f=this.ownerDocument;if(!d){d=f.createElement("div");d.appendChild(this.cloneNode(true));d=d.innerHTML}return c.clean([d.replace(Ja,"").replace(/=([^="'>\s]+\/)>/g,'="$1">').replace(V,"")],f)[0]}else return this.cloneNode(true)});if(a===true){ra(this,b);ra(this.find("*"),b.find("*"))}return b},html:function(a){if(a===w)return this[0]&&this[0].nodeType===1?this[0].innerHTML.replace(Ja, +""):null;else if(typeof a==="string"&&!ta.test(a)&&(c.support.leadingWhitespace||!V.test(a))&&!F[(La.exec(a)||["",""])[1].toLowerCase()]){a=a.replace(Ka,Ma);try{for(var b=0,d=this.length;b0||e.cacheable||this.length>1?k.cloneNode(true):k)}o.length&&c.each(o,Qa)}return this}});c.fragments={};c.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(a,b){c.fn[a]=function(d){var f=[];d=c(d);var e=this.length===1&&this[0].parentNode;if(e&&e.nodeType===11&&e.childNodes.length===1&&d.length===1){d[b](this[0]); +return this}else{e=0;for(var j=d.length;e0?this.clone(true):this).get();c.fn[b].apply(c(d[e]),i);f=f.concat(i)}return this.pushStack(f,a,d.selector)}}});c.extend({clean:function(a,b,d,f){b=b||s;if(typeof b.createElement==="undefined")b=b.ownerDocument||b[0]&&b[0].ownerDocument||s;for(var e=[],j=0,i;(i=a[j])!=null;j++){if(typeof i==="number")i+="";if(i){if(typeof i==="string"&&!jb.test(i))i=b.createTextNode(i);else if(typeof i==="string"){i=i.replace(Ka,Ma);var o=(La.exec(i)||["", +""])[1].toLowerCase(),k=F[o]||F._default,n=k[0],r=b.createElement("div");for(r.innerHTML=k[1]+i+k[2];n--;)r=r.lastChild;if(!c.support.tbody){n=ib.test(i);o=o==="table"&&!n?r.firstChild&&r.firstChild.childNodes:k[1]===""&&!n?r.childNodes:[];for(k=o.length-1;k>=0;--k)c.nodeName(o[k],"tbody")&&!o[k].childNodes.length&&o[k].parentNode.removeChild(o[k])}!c.support.leadingWhitespace&&V.test(i)&&r.insertBefore(b.createTextNode(V.exec(i)[0]),r.firstChild);i=r.childNodes}if(i.nodeType)e.push(i);else e= +c.merge(e,i)}}if(d)for(j=0;e[j];j++)if(f&&c.nodeName(e[j],"script")&&(!e[j].type||e[j].type.toLowerCase()==="text/javascript"))f.push(e[j].parentNode?e[j].parentNode.removeChild(e[j]):e[j]);else{e[j].nodeType===1&&e.splice.apply(e,[j+1,0].concat(c.makeArray(e[j].getElementsByTagName("script"))));d.appendChild(e[j])}return e},cleanData:function(a){for(var b,d,f=c.cache,e=c.event.special,j=c.support.deleteExpando,i=0,o;(o=a[i])!=null;i++)if(d=o[c.expando]){b=f[d];if(b.events)for(var k in b.events)e[k]? +c.event.remove(o,k):Ca(o,k,b.handle);if(j)delete o[c.expando];else o.removeAttribute&&o.removeAttribute(c.expando);delete f[d]}}});var kb=/z-?index|font-?weight|opacity|zoom|line-?height/i,Na=/alpha\([^)]*\)/,Oa=/opacity=([^)]*)/,ha=/float/i,ia=/-([a-z])/ig,lb=/([A-Z])/g,mb=/^-?\d+(?:px)?$/i,nb=/^-?\d/,ob={position:"absolute",visibility:"hidden",display:"block"},pb=["Left","Right"],qb=["Top","Bottom"],rb=s.defaultView&&s.defaultView.getComputedStyle,Pa=c.support.cssFloat?"cssFloat":"styleFloat",ja= +function(a,b){return b.toUpperCase()};c.fn.css=function(a,b){return X(this,a,b,true,function(d,f,e){if(e===w)return c.curCSS(d,f);if(typeof e==="number"&&!kb.test(f))e+="px";c.style(d,f,e)})};c.extend({style:function(a,b,d){if(!a||a.nodeType===3||a.nodeType===8)return w;if((b==="width"||b==="height")&&parseFloat(d)<0)d=w;var f=a.style||a,e=d!==w;if(!c.support.opacity&&b==="opacity"){if(e){f.zoom=1;b=parseInt(d,10)+""==="NaN"?"":"alpha(opacity="+d*100+")";a=f.filter||c.curCSS(a,"filter")||"";f.filter= +Na.test(a)?a.replace(Na,b):b}return f.filter&&f.filter.indexOf("opacity=")>=0?parseFloat(Oa.exec(f.filter)[1])/100+"":""}if(ha.test(b))b=Pa;b=b.replace(ia,ja);if(e)f[b]=d;return f[b]},css:function(a,b,d,f){if(b==="width"||b==="height"){var e,j=b==="width"?pb:qb;function i(){e=b==="width"?a.offsetWidth:a.offsetHeight;f!=="border"&&c.each(j,function(){f||(e-=parseFloat(c.curCSS(a,"padding"+this,true))||0);if(f==="margin")e+=parseFloat(c.curCSS(a,"margin"+this,true))||0;else e-=parseFloat(c.curCSS(a, +"border"+this+"Width",true))||0})}a.offsetWidth!==0?i():c.swap(a,ob,i);return Math.max(0,Math.round(e))}return c.curCSS(a,b,d)},curCSS:function(a,b,d){var f,e=a.style;if(!c.support.opacity&&b==="opacity"&&a.currentStyle){f=Oa.test(a.currentStyle.filter||"")?parseFloat(RegExp.$1)/100+"":"";return f===""?"1":f}if(ha.test(b))b=Pa;if(!d&&e&&e[b])f=e[b];else if(rb){if(ha.test(b))b="float";b=b.replace(lb,"-$1").toLowerCase();e=a.ownerDocument.defaultView;if(!e)return null;if(a=e.getComputedStyle(a,null))f= +a.getPropertyValue(b);if(b==="opacity"&&f==="")f="1"}else if(a.currentStyle){d=b.replace(ia,ja);f=a.currentStyle[b]||a.currentStyle[d];if(!mb.test(f)&&nb.test(f)){b=e.left;var j=a.runtimeStyle.left;a.runtimeStyle.left=a.currentStyle.left;e.left=d==="fontSize"?"1em":f||0;f=e.pixelLeft+"px";e.left=b;a.runtimeStyle.left=j}}return f},swap:function(a,b,d){var f={};for(var e in b){f[e]=a.style[e];a.style[e]=b[e]}d.call(a);for(e in b)a.style[e]=f[e]}});if(c.expr&&c.expr.filters){c.expr.filters.hidden=function(a){var b= +a.offsetWidth,d=a.offsetHeight,f=a.nodeName.toLowerCase()==="tr";return b===0&&d===0&&!f?true:b>0&&d>0&&!f?false:c.curCSS(a,"display")==="none"};c.expr.filters.visible=function(a){return!c.expr.filters.hidden(a)}}var sb=J(),tb=//gi,ub=/select|textarea/i,vb=/color|date|datetime|email|hidden|month|number|password|range|search|tel|text|time|url|week/i,N=/=\?(&|$)/,ka=/\?/,wb=/(\?|&)_=.*?(&|$)/,xb=/^(\w+:)?\/\/([^\/?#]+)/,yb=/%20/g,zb=c.fn.load;c.fn.extend({load:function(a,b,d){if(typeof a!== +"string")return zb.call(this,a);else if(!this.length)return this;var f=a.indexOf(" ");if(f>=0){var e=a.slice(f,a.length);a=a.slice(0,f)}f="GET";if(b)if(c.isFunction(b)){d=b;b=null}else if(typeof b==="object"){b=c.param(b,c.ajaxSettings.traditional);f="POST"}var j=this;c.ajax({url:a,type:f,dataType:"html",data:b,complete:function(i,o){if(o==="success"||o==="notmodified")j.html(e?c("
").append(i.responseText.replace(tb,"")).find(e):i.responseText);d&&j.each(d,[i.responseText,o,i])}});return this}, +serialize:function(){return c.param(this.serializeArray())},serializeArray:function(){return this.map(function(){return this.elements?c.makeArray(this.elements):this}).filter(function(){return this.name&&!this.disabled&&(this.checked||ub.test(this.nodeName)||vb.test(this.type))}).map(function(a,b){a=c(this).val();return a==null?null:c.isArray(a)?c.map(a,function(d){return{name:b.name,value:d}}):{name:b.name,value:a}}).get()}});c.each("ajaxStart ajaxStop ajaxComplete ajaxError ajaxSuccess ajaxSend".split(" "), +function(a,b){c.fn[b]=function(d){return this.bind(b,d)}});c.extend({get:function(a,b,d,f){if(c.isFunction(b)){f=f||d;d=b;b=null}return c.ajax({type:"GET",url:a,data:b,success:d,dataType:f})},getScript:function(a,b){return c.get(a,null,b,"script")},getJSON:function(a,b,d){return c.get(a,b,d,"json")},post:function(a,b,d,f){if(c.isFunction(b)){f=f||d;d=b;b={}}return c.ajax({type:"POST",url:a,data:b,success:d,dataType:f})},ajaxSetup:function(a){c.extend(c.ajaxSettings,a)},ajaxSettings:{url:location.href, +global:true,type:"GET",contentType:"application/x-www-form-urlencoded",processData:true,async:true,xhr:A.XMLHttpRequest&&(A.location.protocol!=="file:"||!A.ActiveXObject)?function(){return new A.XMLHttpRequest}:function(){try{return new A.ActiveXObject("Microsoft.XMLHTTP")}catch(a){}},accepts:{xml:"application/xml, text/xml",html:"text/html",script:"text/javascript, application/javascript",json:"application/json, text/javascript",text:"text/plain",_default:"*/*"}},lastModified:{},etag:{},ajax:function(a){function b(){e.success&& +e.success.call(k,o,i,x);e.global&&f("ajaxSuccess",[x,e])}function d(){e.complete&&e.complete.call(k,x,i);e.global&&f("ajaxComplete",[x,e]);e.global&&!--c.active&&c.event.trigger("ajaxStop")}function f(q,p){(e.context?c(e.context):c.event).trigger(q,p)}var e=c.extend(true,{},c.ajaxSettings,a),j,i,o,k=a&&a.context||e,n=e.type.toUpperCase();if(e.data&&e.processData&&typeof e.data!=="string")e.data=c.param(e.data,e.traditional);if(e.dataType==="jsonp"){if(n==="GET")N.test(e.url)||(e.url+=(ka.test(e.url)? +"&":"?")+(e.jsonp||"callback")+"=?");else if(!e.data||!N.test(e.data))e.data=(e.data?e.data+"&":"")+(e.jsonp||"callback")+"=?";e.dataType="json"}if(e.dataType==="json"&&(e.data&&N.test(e.data)||N.test(e.url))){j=e.jsonpCallback||"jsonp"+sb++;if(e.data)e.data=(e.data+"").replace(N,"="+j+"$1");e.url=e.url.replace(N,"="+j+"$1");e.dataType="script";A[j]=A[j]||function(q){o=q;b();d();A[j]=w;try{delete A[j]}catch(p){}z&&z.removeChild(C)}}if(e.dataType==="script"&&e.cache===null)e.cache=false;if(e.cache=== +false&&n==="GET"){var r=J(),u=e.url.replace(wb,"$1_="+r+"$2");e.url=u+(u===e.url?(ka.test(e.url)?"&":"?")+"_="+r:"")}if(e.data&&n==="GET")e.url+=(ka.test(e.url)?"&":"?")+e.data;e.global&&!c.active++&&c.event.trigger("ajaxStart");r=(r=xb.exec(e.url))&&(r[1]&&r[1]!==location.protocol||r[2]!==location.host);if(e.dataType==="script"&&n==="GET"&&r){var z=s.getElementsByTagName("head")[0]||s.documentElement,C=s.createElement("script");C.src=e.url;if(e.scriptCharset)C.charset=e.scriptCharset;if(!j){var B= +false;C.onload=C.onreadystatechange=function(){if(!B&&(!this.readyState||this.readyState==="loaded"||this.readyState==="complete")){B=true;b();d();C.onload=C.onreadystatechange=null;z&&C.parentNode&&z.removeChild(C)}}}z.insertBefore(C,z.firstChild);return w}var E=false,x=e.xhr();if(x){e.username?x.open(n,e.url,e.async,e.username,e.password):x.open(n,e.url,e.async);try{if(e.data||a&&a.contentType)x.setRequestHeader("Content-Type",e.contentType);if(e.ifModified){c.lastModified[e.url]&&x.setRequestHeader("If-Modified-Since", +c.lastModified[e.url]);c.etag[e.url]&&x.setRequestHeader("If-None-Match",c.etag[e.url])}r||x.setRequestHeader("X-Requested-With","XMLHttpRequest");x.setRequestHeader("Accept",e.dataType&&e.accepts[e.dataType]?e.accepts[e.dataType]+", */*":e.accepts._default)}catch(ga){}if(e.beforeSend&&e.beforeSend.call(k,x,e)===false){e.global&&!--c.active&&c.event.trigger("ajaxStop");x.abort();return false}e.global&&f("ajaxSend",[x,e]);var g=x.onreadystatechange=function(q){if(!x||x.readyState===0||q==="abort"){E|| +d();E=true;if(x)x.onreadystatechange=c.noop}else if(!E&&x&&(x.readyState===4||q==="timeout")){E=true;x.onreadystatechange=c.noop;i=q==="timeout"?"timeout":!c.httpSuccess(x)?"error":e.ifModified&&c.httpNotModified(x,e.url)?"notmodified":"success";var p;if(i==="success")try{o=c.httpData(x,e.dataType,e)}catch(v){i="parsererror";p=v}if(i==="success"||i==="notmodified")j||b();else c.handleError(e,x,i,p);d();q==="timeout"&&x.abort();if(e.async)x=null}};try{var h=x.abort;x.abort=function(){x&&h.call(x); +g("abort")}}catch(l){}e.async&&e.timeout>0&&setTimeout(function(){x&&!E&&g("timeout")},e.timeout);try{x.send(n==="POST"||n==="PUT"||n==="DELETE"?e.data:null)}catch(m){c.handleError(e,x,null,m);d()}e.async||g();return x}},handleError:function(a,b,d,f){if(a.error)a.error.call(a.context||a,b,d,f);if(a.global)(a.context?c(a.context):c.event).trigger("ajaxError",[b,a,f])},active:0,httpSuccess:function(a){try{return!a.status&&location.protocol==="file:"||a.status>=200&&a.status<300||a.status===304||a.status=== +1223||a.status===0}catch(b){}return false},httpNotModified:function(a,b){var d=a.getResponseHeader("Last-Modified"),f=a.getResponseHeader("Etag");if(d)c.lastModified[b]=d;if(f)c.etag[b]=f;return a.status===304||a.status===0},httpData:function(a,b,d){var f=a.getResponseHeader("content-type")||"",e=b==="xml"||!b&&f.indexOf("xml")>=0;a=e?a.responseXML:a.responseText;e&&a.documentElement.nodeName==="parsererror"&&c.error("parsererror");if(d&&d.dataFilter)a=d.dataFilter(a,b);if(typeof a==="string")if(b=== +"json"||!b&&f.indexOf("json")>=0)a=c.parseJSON(a);else if(b==="script"||!b&&f.indexOf("javascript")>=0)c.globalEval(a);return a},param:function(a,b){function d(i,o){if(c.isArray(o))c.each(o,function(k,n){b||/\[\]$/.test(i)?f(i,n):d(i+"["+(typeof n==="object"||c.isArray(n)?k:"")+"]",n)});else!b&&o!=null&&typeof o==="object"?c.each(o,function(k,n){d(i+"["+k+"]",n)}):f(i,o)}function f(i,o){o=c.isFunction(o)?o():o;e[e.length]=encodeURIComponent(i)+"="+encodeURIComponent(o)}var e=[];if(b===w)b=c.ajaxSettings.traditional; +if(c.isArray(a)||a.jquery)c.each(a,function(){f(this.name,this.value)});else for(var j in a)d(j,a[j]);return e.join("&").replace(yb,"+")}});var la={},Ab=/toggle|show|hide/,Bb=/^([+-]=)?([\d+-.]+)(.*)$/,W,va=[["height","marginTop","marginBottom","paddingTop","paddingBottom"],["width","marginLeft","marginRight","paddingLeft","paddingRight"],["opacity"]];c.fn.extend({show:function(a,b){if(a||a===0)return this.animate(K("show",3),a,b);else{a=0;for(b=this.length;a").appendTo("body");f=e.css("display");if(f==="none")f="block";e.remove();la[d]=f}c.data(this[a],"olddisplay",f)}}a=0;for(b=this.length;a=0;f--)if(d[f].elem===this){b&&d[f](true);d.splice(f,1)}});b||this.dequeue();return this}});c.each({slideDown:K("show",1),slideUp:K("hide",1),slideToggle:K("toggle",1),fadeIn:{opacity:"show"},fadeOut:{opacity:"hide"}},function(a,b){c.fn[a]=function(d,f){return this.animate(b,d,f)}});c.extend({speed:function(a,b,d){var f=a&&typeof a==="object"?a:{complete:d||!d&&b||c.isFunction(a)&&a,duration:a,easing:d&&b||b&&!c.isFunction(b)&&b};f.duration=c.fx.off?0:typeof f.duration=== +"number"?f.duration:c.fx.speeds[f.duration]||c.fx.speeds._default;f.old=f.complete;f.complete=function(){f.queue!==false&&c(this).dequeue();c.isFunction(f.old)&&f.old.call(this)};return f},easing:{linear:function(a,b,d,f){return d+f*a},swing:function(a,b,d,f){return(-Math.cos(a*Math.PI)/2+0.5)*f+d}},timers:[],fx:function(a,b,d){this.options=b;this.elem=a;this.prop=d;if(!b.orig)b.orig={}}});c.fx.prototype={update:function(){this.options.step&&this.options.step.call(this.elem,this.now,this);(c.fx.step[this.prop]|| +c.fx.step._default)(this);if((this.prop==="height"||this.prop==="width")&&this.elem.style)this.elem.style.display="block"},cur:function(a){if(this.elem[this.prop]!=null&&(!this.elem.style||this.elem.style[this.prop]==null))return this.elem[this.prop];return(a=parseFloat(c.css(this.elem,this.prop,a)))&&a>-10000?a:parseFloat(c.curCSS(this.elem,this.prop))||0},custom:function(a,b,d){function f(j){return e.step(j)}this.startTime=J();this.start=a;this.end=b;this.unit=d||this.unit||"px";this.now=this.start; +this.pos=this.state=0;var e=this;f.elem=this.elem;if(f()&&c.timers.push(f)&&!W)W=setInterval(c.fx.tick,13)},show:function(){this.options.orig[this.prop]=c.style(this.elem,this.prop);this.options.show=true;this.custom(this.prop==="width"||this.prop==="height"?1:0,this.cur());c(this.elem).show()},hide:function(){this.options.orig[this.prop]=c.style(this.elem,this.prop);this.options.hide=true;this.custom(this.cur(),0)},step:function(a){var b=J(),d=true;if(a||b>=this.options.duration+this.startTime){this.now= +this.end;this.pos=this.state=1;this.update();this.options.curAnim[this.prop]=true;for(var f in this.options.curAnim)if(this.options.curAnim[f]!==true)d=false;if(d){if(this.options.display!=null){this.elem.style.overflow=this.options.overflow;a=c.data(this.elem,"olddisplay");this.elem.style.display=a?a:this.options.display;if(c.css(this.elem,"display")==="none")this.elem.style.display="block"}this.options.hide&&c(this.elem).hide();if(this.options.hide||this.options.show)for(var e in this.options.curAnim)c.style(this.elem, +e,this.options.orig[e]);this.options.complete.call(this.elem)}return false}else{e=b-this.startTime;this.state=e/this.options.duration;a=this.options.easing||(c.easing.swing?"swing":"linear");this.pos=c.easing[this.options.specialEasing&&this.options.specialEasing[this.prop]||a](this.state,e,0,1,this.options.duration);this.now=this.start+(this.end-this.start)*this.pos;this.update()}return true}};c.extend(c.fx,{tick:function(){for(var a=c.timers,b=0;b
"; +a.insertBefore(b,a.firstChild);d=b.firstChild;f=d.firstChild;e=d.nextSibling.firstChild.firstChild;this.doesNotAddBorder=f.offsetTop!==5;this.doesAddBorderForTableAndCells=e.offsetTop===5;f.style.position="fixed";f.style.top="20px";this.supportsFixedPosition=f.offsetTop===20||f.offsetTop===15;f.style.position=f.style.top="";d.style.overflow="hidden";d.style.position="relative";this.subtractsBorderForOverflowNotVisible=f.offsetTop===-5;this.doesNotIncludeMarginInBodyOffset=a.offsetTop!==j;a.removeChild(b); +c.offset.initialize=c.noop},bodyOffset:function(a){var b=a.offsetTop,d=a.offsetLeft;c.offset.initialize();if(c.offset.doesNotIncludeMarginInBodyOffset){b+=parseFloat(c.curCSS(a,"marginTop",true))||0;d+=parseFloat(c.curCSS(a,"marginLeft",true))||0}return{top:b,left:d}},setOffset:function(a,b,d){if(/static/.test(c.curCSS(a,"position")))a.style.position="relative";var f=c(a),e=f.offset(),j=parseInt(c.curCSS(a,"top",true),10)||0,i=parseInt(c.curCSS(a,"left",true),10)||0;if(c.isFunction(b))b=b.call(a, +d,e);d={top:b.top-e.top+j,left:b.left-e.left+i};"using"in b?b.using.call(a,d):f.css(d)}};c.fn.extend({position:function(){if(!this[0])return null;var a=this[0],b=this.offsetParent(),d=this.offset(),f=/^body|html$/i.test(b[0].nodeName)?{top:0,left:0}:b.offset();d.top-=parseFloat(c.curCSS(a,"marginTop",true))||0;d.left-=parseFloat(c.curCSS(a,"marginLeft",true))||0;f.top+=parseFloat(c.curCSS(b[0],"borderTopWidth",true))||0;f.left+=parseFloat(c.curCSS(b[0],"borderLeftWidth",true))||0;return{top:d.top- +f.top,left:d.left-f.left}},offsetParent:function(){return this.map(function(){for(var a=this.offsetParent||s.body;a&&!/^body|html$/i.test(a.nodeName)&&c.css(a,"position")==="static";)a=a.offsetParent;return a})}});c.each(["Left","Top"],function(a,b){var d="scroll"+b;c.fn[d]=function(f){var e=this[0],j;if(!e)return null;if(f!==w)return this.each(function(){if(j=wa(this))j.scrollTo(!a?f:c(j).scrollLeft(),a?f:c(j).scrollTop());else this[d]=f});else return(j=wa(e))?"pageXOffset"in j?j[a?"pageYOffset": +"pageXOffset"]:c.support.boxModel&&j.document.documentElement[d]||j.document.body[d]:e[d]}});c.each(["Height","Width"],function(a,b){var d=b.toLowerCase();c.fn["inner"+b]=function(){return this[0]?c.css(this[0],d,false,"padding"):null};c.fn["outer"+b]=function(f){return this[0]?c.css(this[0],d,false,f?"margin":"border"):null};c.fn[d]=function(f){var e=this[0];if(!e)return f==null?null:this;if(c.isFunction(f))return this.each(function(j){var i=c(this);i[d](f.call(this,j,i[d]()))});return"scrollTo"in +e&&e.document?e.document.compatMode==="CSS1Compat"&&e.document.documentElement["client"+b]||e.document.body["client"+b]:e.nodeType===9?Math.max(e.documentElement["client"+b],e.body["scroll"+b],e.documentElement["scroll"+b],e.body["offset"+b],e.documentElement["offset"+b]):f===w?c.css(e,d):this.css(d,typeof f==="string"?f:f+"px")}});A.jQuery=A.$=c})(window); diff --git a/endpoints/internal/protorpc/static/jquery.json-2.2.min.js b/endpoints/internal/protorpc/static/jquery.json-2.2.min.js new file mode 100644 index 0000000..bad4a0a --- /dev/null +++ b/endpoints/internal/protorpc/static/jquery.json-2.2.min.js @@ -0,0 +1,31 @@ + +(function($){$.toJSON=function(o) +{if(typeof(JSON)=='object'&&JSON.stringify) +return JSON.stringify(o);var type=typeof(o);if(o===null) +return"null";if(type=="undefined") +return undefined;if(type=="number"||type=="boolean") +return o+"";if(type=="string") +return $.quoteString(o);if(type=='object') +{if(typeof o.toJSON=="function") +return $.toJSON(o.toJSON());if(o.constructor===Date) +{var month=o.getUTCMonth()+1;if(month<10)month='0'+month;var day=o.getUTCDate();if(day<10)day='0'+day;var year=o.getUTCFullYear();var hours=o.getUTCHours();if(hours<10)hours='0'+hours;var minutes=o.getUTCMinutes();if(minutes<10)minutes='0'+minutes;var seconds=o.getUTCSeconds();if(seconds<10)seconds='0'+seconds;var milli=o.getUTCMilliseconds();if(milli<100)milli='0'+milli;if(milli<10)milli='0'+milli;return'"'+year+'-'+month+'-'+day+'T'+ +hours+':'+minutes+':'+seconds+'.'+milli+'Z"';} +if(o.constructor===Array) +{var ret=[];for(var i=0;i + +{% extends 'base.html' %} + +{% block title %}Form for {{service_path|escape}}.{{method_name|escape}}{% endblock %} + +{% block top %} +<< Back to method selection +

Form for {{service_path|escape}}.{{method_name|escape}}

+{% endblock %} + +{% block body %} + +
+
+
+ +
+{% endblock %} + +{% block call %} +loadServices(createForm); +{% endblock %} diff --git a/endpoints/internal/protorpc/test_util.py b/endpoints/internal/protorpc/test_util.py new file mode 100644 index 0000000..bcbccf6 --- /dev/null +++ b/endpoints/internal/protorpc/test_util.py @@ -0,0 +1,671 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Test utilities for message testing. + +Includes module interface test to ensure that public parts of module are +correctly declared in __all__. + +Includes message types that correspond to those defined in +services_test.proto. + +Includes additional test utilities to make sure encoding/decoding libraries +conform. +""" +from six.moves import range + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import cgi +import datetime +import inspect +import os +import re +import socket +import types +import unittest2 as unittest + +import six + +from . import message_types +from . import messages +from . import util + +# Unicode of the word "Russian" in cyrillic. +RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439' + +# All characters binary value interspersed with nulls. +BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256)) + + +class TestCase(unittest.TestCase): + + def assertRaisesWithRegexpMatch(self, + exception, + regexp, + function, + *params, + **kwargs): + """Check that exception is raised and text matches regular expression. + + Args: + exception: Exception type that is expected. + regexp: String regular expression that is expected in error message. + function: Callable to test. + params: Parameters to forward to function. + kwargs: Keyword arguments to forward to function. + """ + try: + function(*params, **kwargs) + self.fail('Expected exception %s was not raised' % exception.__name__) + except exception as err: + match = bool(re.match(regexp, str(err))) + self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp, + err)) + + def assertHeaderSame(self, header1, header2): + """Check that two HTTP headers are the same. + + Args: + header1: Header value string 1. + header2: header value string 2. + """ + value1, params1 = cgi.parse_header(header1) + value2, params2 = cgi.parse_header(header2) + self.assertEqual(value1, value2) + self.assertEqual(params1, params2) + + def assertIterEqual(self, iter1, iter2): + """Check that two iterators or iterables are equal independent of order. + + Similar to Python 2.7 assertItemsEqual. Named differently in order to + avoid potential conflict. + + Args: + iter1: An iterator or iterable. + iter2: An iterator or iterable. + """ + list1 = list(iter1) + list2 = list(iter2) + + unmatched1 = list() + + while list1: + item1 = list1[0] + del list1[0] + for index in range(len(list2)): + if item1 == list2[index]: + del list2[index] + break + else: + unmatched1.append(item1) + + error_message = [] + for item in unmatched1: + error_message.append( + ' Item from iter1 not found in iter2: %r' % item) + for item in list2: + error_message.append( + ' Item from iter2 not found in iter1: %r' % item) + if error_message: + self.fail('Collections not equivalent:\n' + '\n'.join(error_message)) + + +class ModuleInterfaceTest(object): + """Test to ensure module interface is carefully constructed. + + A module interface is the set of public objects listed in the module __all__ + attribute. Modules that that are considered public should have this interface + carefully declared. At all times, the __all__ attribute should have objects + intended to be publically used and all other objects in the module should be + considered unused. + + Protected attributes (those beginning with '_') and other imported modules + should not be part of this set of variables. An exception is for variables + that begin and end with '__' which are implicitly part of the interface + (eg. __name__, __file__, __all__ itself, etc.). + + Modules that are imported in to the tested modules are an exception and may + be left out of the __all__ definition. The test is done by checking the value + of what would otherwise be a public name and not allowing it to be exported + if it is an instance of a module. Modules that are explicitly exported are + for the time being not permitted. + + To use this test class a module should define a new class that inherits first + from ModuleInterfaceTest and then from test_util.TestCase. No other tests + should be added to this test case, making the order of inheritance less + important, but if setUp for some reason is overidden, it is important that + ModuleInterfaceTest is first in the list so that its setUp method is + invoked. + + Multiple inheretance is required so that ModuleInterfaceTest is not itself + a test, and is not itself executed as one. + + The test class is expected to have the following class attributes defined: + + MODULE: A reference to the module that is being validated for interface + correctness. + + Example: + Module definition (hello.py): + + import sys + + __all__ = ['hello'] + + def _get_outputter(): + return sys.stdout + + def hello(): + _get_outputter().write('Hello\n') + + Test definition: + + import unittest + from protorpc import test_util + + import hello + + class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = hello + + + class HelloTest(test_util.TestCase): + ... Test 'hello' module ... + + + if __name__ == '__main__': + unittest.main() + """ + + def setUp(self): + """Set up makes sure that MODULE and IMPORTED_MODULES is defined. + + This is a basic configuration test for the test itself so does not + get it's own test case. + """ + if not hasattr(self, 'MODULE'): + self.fail( + "You must define 'MODULE' on ModuleInterfaceTest sub-class %s." % + type(self).__name__) + + def testAllExist(self): + """Test that all attributes defined in __all__ exist.""" + missing_attributes = [] + for attribute in self.MODULE.__all__: + if not hasattr(self.MODULE, attribute): + missing_attributes.append(attribute) + if missing_attributes: + self.fail('%s of __all__ are not defined in module.' % + missing_attributes) + + def testAllExported(self): + """Test that all public attributes not imported are in __all__.""" + missing_attributes = [] + for attribute in dir(self.MODULE): + if not attribute.startswith('_'): + if (attribute not in self.MODULE.__all__ and + not isinstance(getattr(self.MODULE, attribute), + types.ModuleType) and + attribute != 'with_statement'): + missing_attributes.append(attribute) + if missing_attributes: + self.fail('%s are not modules and not defined in __all__.' % + missing_attributes) + + def testNoExportedProtectedVariables(self): + """Test that there are no protected variables listed in __all__.""" + protected_variables = [] + for attribute in self.MODULE.__all__: + if attribute.startswith('_'): + protected_variables.append(attribute) + if protected_variables: + self.fail('%s are protected variables and may not be exported.' % + protected_variables) + + def testNoExportedModules(self): + """Test that no modules exist in __all__.""" + exported_modules = [] + for attribute in self.MODULE.__all__: + try: + value = getattr(self.MODULE, attribute) + except AttributeError: + # This is a different error case tested for in testAllExist. + pass + else: + if isinstance(value, types.ModuleType): + exported_modules.append(attribute) + if exported_modules: + self.fail('%s are modules and may not be exported.' % exported_modules) + + +class NestedMessage(messages.Message): + """Simple message that gets nested in another message.""" + + a_value = messages.StringField(1, required=True) + + +class HasNestedMessage(messages.Message): + """Message that has another message nested in it.""" + + nested = messages.MessageField(NestedMessage, 1) + repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True) + + +class HasDefault(messages.Message): + """Has a default value.""" + + a_value = messages.StringField(1, default=u'a default') + + +class OptionalMessage(messages.Message): + """Contains all message types.""" + + class SimpleEnum(messages.Enum): + """Simple enumeration type.""" + VAL1 = 1 + VAL2 = 2 + + double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE) + float_value = messages.FloatField(2, variant=messages.Variant.FLOAT) + int64_value = messages.IntegerField(3, variant=messages.Variant.INT64) + uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64) + int32_value = messages.IntegerField(5, variant=messages.Variant.INT32) + bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL) + string_value = messages.StringField(7, variant=messages.Variant.STRING) + bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES) + enum_value = messages.EnumField(SimpleEnum, 10) + + # TODO(rafek): Add support for these variants. + # uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32) + # sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32) + # sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64) + + +class RepeatedMessage(messages.Message): + """Contains all message types as repeated fields.""" + + class SimpleEnum(messages.Enum): + """Simple enumeration type.""" + VAL1 = 1 + VAL2 = 2 + + double_value = messages.FloatField(1, + variant=messages.Variant.DOUBLE, + repeated=True) + float_value = messages.FloatField(2, + variant=messages.Variant.FLOAT, + repeated=True) + int64_value = messages.IntegerField(3, + variant=messages.Variant.INT64, + repeated=True) + uint64_value = messages.IntegerField(4, + variant=messages.Variant.UINT64, + repeated=True) + int32_value = messages.IntegerField(5, + variant=messages.Variant.INT32, + repeated=True) + bool_value = messages.BooleanField(6, + variant=messages.Variant.BOOL, + repeated=True) + string_value = messages.StringField(7, + variant=messages.Variant.STRING, + repeated=True) + bytes_value = messages.BytesField(8, + variant=messages.Variant.BYTES, + repeated=True) + #uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32) + enum_value = messages.EnumField(SimpleEnum, + 10, + repeated=True) + #sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32) + #sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64) + + +class HasOptionalNestedMessage(messages.Message): + + nested = messages.MessageField(OptionalMessage, 1) + repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True) + + +class ProtoConformanceTestBase(object): + """Protocol conformance test base class. + + Each supported protocol should implement two methods that support encoding + and decoding of Message objects in that format: + + encode_message(message) - Serialize to encoding. + encode_message(message, encoded_message) - Deserialize from encoding. + + Tests for the modules where these functions are implemented should extend + this class in order to support basic behavioral expectations. This ensures + that protocols correctly encode and decode message transparently to the + caller. + + In order to support these test, the base class should also extend the TestCase + class and implement the following class attributes which define the encoded + version of certain protocol buffers: + + encoded_partial: + + + encoded_full: + + + encoded_repeated: + + + encoded_nested: + + > + + encoded_repeated_nested: + , + + ] + > + + unexpected_tag_message: + An encoded message that has an undefined tag or number in the stream. + + encoded_default_assigned: + + + encoded_nested_empty: + + > + + encoded_invalid_enum: + + """ + + encoded_empty_message = '' + + def testEncodeInvalidMessage(self): + message = NestedMessage() + self.assertRaises(messages.ValidationError, + self.PROTOLIB.encode_message, message) + + def CompareEncoded(self, expected_encoded, actual_encoded): + """Compare two encoded protocol values. + + Can be overridden by sub-classes to special case comparison. + For example, to eliminate white space from output that is not + relevant to encoding. + + Args: + expected_encoded: Expected string encoded value. + actual_encoded: Actual string encoded value. + """ + self.assertEquals(expected_encoded, actual_encoded) + + def EncodeDecode(self, encoded, expected_message): + message = self.PROTOLIB.decode_message(type(expected_message), encoded) + self.assertEquals(expected_message, message) + self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message)) + + def testEmptyMessage(self): + self.EncodeDecode(self.encoded_empty_message, OptionalMessage()) + + def testPartial(self): + """Test message with a few values set.""" + message = OptionalMessage() + message.double_value = 1.23 + message.int64_value = -100000000000 + message.int32_value = 1020 + message.string_value = u'a string' + message.enum_value = OptionalMessage.SimpleEnum.VAL2 + + self.EncodeDecode(self.encoded_partial, message) + + def testFull(self): + """Test all types.""" + message = OptionalMessage() + message.double_value = 1.23 + message.float_value = -2.5 + message.int64_value = -100000000000 + message.uint64_value = 102020202020 + message.int32_value = 1020 + message.bool_value = True + message.string_value = u'a string\u044f' + message.bytes_value = b'a bytes\xff\xfe' + message.enum_value = OptionalMessage.SimpleEnum.VAL2 + + self.EncodeDecode(self.encoded_full, message) + + def testRepeated(self): + """Test repeated fields.""" + message = RepeatedMessage() + message.double_value = [1.23, 2.3] + message.float_value = [-2.5, 0.5] + message.int64_value = [-100000000000, 20] + message.uint64_value = [102020202020, 10] + message.int32_value = [1020, 718] + message.bool_value = [True, False] + message.string_value = [u'a string\u044f', u'another string'] + message.bytes_value = [b'a bytes\xff\xfe', b'another bytes'] + message.enum_value = [RepeatedMessage.SimpleEnum.VAL2, + RepeatedMessage.SimpleEnum.VAL1] + + self.EncodeDecode(self.encoded_repeated, message) + + def testNested(self): + """Test nested messages.""" + nested_message = NestedMessage() + nested_message.a_value = u'a string' + + message = HasNestedMessage() + message.nested = nested_message + + self.EncodeDecode(self.encoded_nested, message) + + def testRepeatedNested(self): + """Test repeated nested messages.""" + nested_message1 = NestedMessage() + nested_message1.a_value = u'a string' + nested_message2 = NestedMessage() + nested_message2.a_value = u'another string' + + message = HasNestedMessage() + message.repeated_nested = [nested_message1, nested_message2] + + self.EncodeDecode(self.encoded_repeated_nested, message) + + def testStringTypes(self): + """Test that encoding str on StringField works.""" + message = OptionalMessage() + message.string_value = u'Latin' + self.EncodeDecode(self.encoded_string_types, message) + + def testEncodeUninitialized(self): + """Test that cannot encode uninitialized message.""" + required = NestedMessage() + self.assertRaisesWithRegexpMatch(messages.ValidationError, + "Message NestedMessage is missing " + "required field a_value", + self.PROTOLIB.encode_message, + required) + + def testUnexpectedField(self): + """Test decoding and encoding unexpected fields.""" + loaded_message = self.PROTOLIB.decode_message(OptionalMessage, + self.unexpected_tag_message) + # Message should be equal to an empty message, since unknown values aren't + # included in equality. + self.assertEquals(OptionalMessage(), loaded_message) + # Verify that the encoded message matches the source, including the + # unknown value. + self.assertEquals(self.unexpected_tag_message, + self.PROTOLIB.encode_message(loaded_message)) + + def testDoNotSendDefault(self): + """Test that default is not sent when nothing is assigned.""" + self.EncodeDecode(self.encoded_empty_message, HasDefault()) + + def testSendDefaultExplicitlyAssigned(self): + """Test that default is sent when explcitly assigned.""" + message = HasDefault() + + message.a_value = HasDefault.a_value.default + + self.EncodeDecode(self.encoded_default_assigned, message) + + def testEncodingNestedEmptyMessage(self): + """Test encoding a nested empty message.""" + message = HasOptionalNestedMessage() + message.nested = OptionalMessage() + + self.EncodeDecode(self.encoded_nested_empty, message) + + def testEncodingRepeatedNestedEmptyMessage(self): + """Test encoding a nested empty message.""" + message = HasOptionalNestedMessage() + message.repeated_nested = [OptionalMessage(), OptionalMessage()] + + self.EncodeDecode(self.encoded_repeated_nested_empty, message) + + def testContentType(self): + self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str)) + + def testDecodeInvalidEnumType(self): + self.assertRaisesWithRegexpMatch(messages.DecodeError, + 'Invalid enum value ', + self.PROTOLIB.decode_message, + OptionalMessage, + self.encoded_invalid_enum) + + def testDateTimeNoTimeZone(self): + """Test that DateTimeFields are encoded/decoded correctly.""" + + class MyMessage(messages.Message): + value = message_types.DateTimeField(1) + + value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000) + message = MyMessage(value=value) + decoded = self.PROTOLIB.decode_message( + MyMessage, self.PROTOLIB.encode_message(message)) + self.assertEquals(decoded.value, value) + + def testDateTimeWithTimeZone(self): + """Test DateTimeFields with time zones.""" + + class MyMessage(messages.Message): + value = message_types.DateTimeField(1) + + value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000, + util.TimeZoneOffset(8 * 60)) + message = MyMessage(value=value) + decoded = self.PROTOLIB.decode_message( + MyMessage, self.PROTOLIB.encode_message(message)) + self.assertEquals(decoded.value, value) + + +def do_with(context, function, *args, **kwargs): + """Simulate a with statement. + + Avoids need to import with from future. + + Does not support simulation of 'as'. + + Args: + context: Context object normally used with 'with'. + function: Callable to evoke. Replaces with-block. + """ + context.__enter__() + try: + function(*args, **kwargs) + except: + context.__exit__(*sys.exc_info()) + finally: + context.__exit__(None, None, None) + + +def pick_unused_port(): + """Find an unused port to use in tests. + + Derived from Damon Kohlers example: + + http://code.activestate.com/recipes/531822-pick-unused-port + """ + try: + temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + except socket.error: + # Try IPv6 + temp = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + + try: + temp.bind(('localhost', 0)) + port = temp.getsockname()[1] + finally: + temp.close() + return port + + +def get_module_name(module_attribute): + """Get the module name. + + Args: + module_attribute: An attribute of the module. + + Returns: + The fully qualified module name or simple module name where + 'module_attribute' is defined if the module name is "__main__". + """ + if module_attribute.__module__ == '__main__': + module_file = inspect.getfile(module_attribute) + default = os.path.basename(module_file).split('.')[0] + return default + else: + return module_attribute.__module__ diff --git a/endpoints/internal/protorpc/transport.py b/endpoints/internal/protorpc/transport.py new file mode 100644 index 0000000..5d7e564 --- /dev/null +++ b/endpoints/internal/protorpc/transport.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Transport library for ProtoRPC. + +Contains underlying infrastructure used for communicating RPCs over low level +transports such as HTTP. + +Includes HTTP transport built over urllib2. +""" + +import six.moves.http_client +import logging +import os +import socket +import sys +import urlparse + +from . import messages +from . import protobuf +from . import remote +from . import util +import six + +__all__ = [ + 'RpcStateError', + + 'HttpTransport', + 'LocalTransport', + 'Rpc', + 'Transport', +] + + +class RpcStateError(messages.Error): + """Raised when trying to put RPC in to an invalid state.""" + + +class Rpc(object): + """Represents a client side RPC. + + An RPC is created by the transport class and is used with a single RPC. While + an RPC is still in process, the response is set to None. When it is complete + the response will contain the response message. + """ + + def __init__(self, request): + """Constructor. + + Args: + request: Request associated with this RPC. + """ + self.__request = request + self.__response = None + self.__state = remote.RpcState.RUNNING + self.__error_message = None + self.__error_name = None + + @property + def request(self): + """Request associated with RPC.""" + return self.__request + + @property + def response(self): + """Response associated with RPC.""" + self.wait() + self.__check_status() + return self.__response + + @property + def state(self): + """State associated with RPC.""" + return self.__state + + @property + def error_message(self): + """Error, if any, associated with RPC.""" + self.wait() + return self.__error_message + + @property + def error_name(self): + """Error name, if any, associated with RPC.""" + self.wait() + return self.__error_name + + def wait(self): + """Wait for an RPC to finish.""" + if self.__state == remote.RpcState.RUNNING: + self._wait_impl() + + def _wait_impl(self): + """Implementation for wait().""" + raise NotImplementedError() + + def __check_status(self): + error_class = remote.RpcError.from_state(self.__state) + if error_class is not None: + if error_class is remote.ApplicationError: + raise error_class(self.__error_message, self.__error_name) + else: + raise error_class(self.__error_message) + + def __set_state(self, state, error_message=None, error_name=None): + if self.__state != remote.RpcState.RUNNING: + raise RpcStateError( + 'RPC must be in RUNNING state to change to %s' % state) + if state == remote.RpcState.RUNNING: + raise RpcStateError('RPC is already in RUNNING state') + self.__state = state + self.__error_message = error_message + self.__error_name = error_name + + def set_response(self, response): + # TODO: Even more specific type checking. + if not isinstance(response, messages.Message): + raise TypeError('Expected Message type, received %r' % (response)) + + self.__response = response + self.__set_state(remote.RpcState.OK) + + def set_status(self, status): + status.check_initialized() + self.__set_state(status.state, status.error_message, status.error_name) + + +class Transport(object): + """Transport base class. + + Provides basic support for implementing a ProtoRPC transport such as one + that can send and receive messages over HTTP. + + Implementations override _start_rpc. This method receives a RemoteInfo + instance and a request Message. The transport is expected to set the rpc + response or raise an exception before termination. + """ + + @util.positional(1) + def __init__(self, protocol=protobuf): + """Constructor. + + Args: + protocol: If string, will look up a protocol from the default Protocols + instance by name. Can also be an instance of remote.ProtocolConfig. + If neither, it must be an object that implements a protocol interface + by implementing encode_message, decode_message and set CONTENT_TYPE. + For example, the modules protobuf and protojson can be used directly. + """ + if isinstance(protocol, six.string_types): + protocols = remote.Protocols.get_default() + try: + protocol = protocols.lookup_by_name(protocol) + except KeyError: + protocol = protocols.lookup_by_content_type(protocol) + if isinstance(protocol, remote.ProtocolConfig): + self.__protocol = protocol.protocol + self.__protocol_config = protocol + else: + self.__protocol = protocol + self.__protocol_config = remote.ProtocolConfig( + protocol, 'default', default_content_type=protocol.CONTENT_TYPE) + + @property + def protocol(self): + """Protocol associated with this transport.""" + return self.__protocol + + @property + def protocol_config(self): + """Protocol associated with this transport.""" + return self.__protocol_config + + def send_rpc(self, remote_info, request): + """Initiate sending an RPC over the transport. + + Args: + remote_info: RemoteInfo instance describing remote method. + request: Request message to send to service. + + Returns: + An Rpc instance intialized with the request.. + """ + request.check_initialized() + + rpc = self._start_rpc(remote_info, request) + + return rpc + + def _start_rpc(self, remote_info, request): + """Start a remote procedure call. + + Args: + remote_info: RemoteInfo instance describing remote method. + request: Request message to send to service. + + Returns: + An Rpc instance initialized with the request. + """ + raise NotImplementedError() + + +class HttpTransport(Transport): + """Transport for communicating with HTTP servers.""" + + @util.positional(2) + def __init__(self, + service_url, + protocol=protobuf): + """Constructor. + + Args: + service_url: URL where the service is located. All communication via + the transport will go to this URL. + protocol: The protocol implementation. Must implement encode_message and + decode_message. Can also be an instance of remote.ProtocolConfig. + """ + super(HttpTransport, self).__init__(protocol=protocol) + self.__service_url = service_url + + def __get_rpc_status(self, response, content): + """Get RPC status from HTTP response. + + Args: + response: HTTPResponse object. + content: Content read from HTTP response. + + Returns: + RpcStatus object parsed from response, else an RpcStatus with a generic + HTTP error. + """ + # Status above 400 may have RpcStatus content. + if response.status >= 400: + content_type = response.getheader('content-type') + if content_type == self.protocol_config.default_content_type: + try: + rpc_status = self.protocol.decode_message(remote.RpcStatus, content) + except Exception as decode_err: + logging.warning( + 'An error occurred trying to parse status: %s\n%s', + str(decode_err), content) + else: + if rpc_status.is_initialized(): + return rpc_status + else: + logging.warning( + 'Body does not result in an initialized RpcStatus message:\n%s', + content) + + # If no RpcStatus message present, attempt to forward any content. If empty + # use standard error message. + if not content.strip(): + content = six.moves.http_client.responses.get(response.status, 'Unknown Error') + return remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, + error_message='HTTP Error %s: %s' % ( + response.status, content or 'Unknown Error')) + + def __set_response(self, remote_info, connection, rpc): + """Set response on RPC. + + Sets response or status from HTTP request. Implements the wait method of + Rpc instance. + + Args: + remote_info: Remote info for invoked RPC. + connection: HTTPConnection that is making request. + rpc: Rpc instance. + """ + try: + response = connection.getresponse() + + content = response.read() + + if response.status == six.moves.http_client.OK: + response = self.protocol.decode_message(remote_info.response_type, + content) + rpc.set_response(response) + else: + status = self.__get_rpc_status(response, content) + rpc.set_status(status) + finally: + connection.close() + + def _start_rpc(self, remote_info, request): + """Start a remote procedure call. + + Args: + remote_info: A RemoteInfo instance for this RPC. + request: The request message for this RPC. + + Returns: + An Rpc instance initialized with a Request. + """ + method_url = '%s.%s' % (self.__service_url, remote_info.method.__name__) + encoded_request = self.protocol.encode_message(request) + + url = urlparse.urlparse(method_url) + if url.scheme == 'https': + connection_type = six.moves.http_client.HTTPSConnection + else: + connection_type = six.moves.http_client.HTTPConnection + connection = connection_type(url.hostname, url.port) + try: + self._send_http_request(connection, url.path, encoded_request) + rpc = Rpc(request) + except remote.RpcError: + # Pass through all ProtoRPC errors + connection.close() + raise + except socket.error as err: + connection.close() + raise remote.NetworkError('Socket error: %s %r' % (type(err).__name__, + err.args), + err) + except Exception as err: + connection.close() + raise remote.NetworkError('Error communicating with HTTP server', + err) + else: + wait_impl = lambda: self.__set_response(remote_info, connection, rpc) + rpc._wait_impl = wait_impl + + return rpc + + def _send_http_request(self, connection, http_path, encoded_request): + connection.request( + 'POST', + http_path, + encoded_request, + headers={'Content-type': self.protocol_config.default_content_type, + 'Content-length': len(encoded_request)}) + + +class LocalTransport(Transport): + """Local transport that sends messages directly to services. + + Useful in tests or creating code that can work with either local or remote + services. Using LocalTransport is preferrable to simply instantiating a + single instance of a service and reusing it. The entire request process + involves instantiating a new instance of a service, initializing it with + request state and then invoking the remote method for every request. + """ + + def __init__(self, service_factory): + """Constructor. + + Args: + service_factory: Service factory or class. + """ + super(LocalTransport, self).__init__() + self.__service_class = getattr(service_factory, + 'service_class', + service_factory) + self.__service_factory = service_factory + + @property + def service_class(self): + return self.__service_class + + @property + def service_factory(self): + return self.__service_factory + + def _start_rpc(self, remote_info, request): + """Start a remote procedure call. + + Args: + remote_info: RemoteInfo instance describing remote method. + request: Request message to send to service. + + Returns: + An Rpc instance initialized with the request. + """ + rpc = Rpc(request) + def wait_impl(): + instance = self.__service_factory() + try: + initalize_request_state = instance.initialize_request_state + except AttributeError: + pass + else: + host = six.text_type(os.uname()[1]) + initalize_request_state(remote.RequestState(remote_host=host, + remote_address=u'127.0.0.1', + server_host=host, + server_port=-1)) + try: + response = remote_info.method(instance, request) + assert isinstance(response, remote_info.response_type) + except remote.ApplicationError: + raise + except: + exc_type, exc_value, traceback = sys.exc_info() + message = 'Unexpected error %s: %s' % (exc_type.__name__, exc_value) + six.reraise(remote.ServerError, message, traceback) + rpc.set_response(response) + rpc._wait_impl = wait_impl + return rpc diff --git a/endpoints/internal/protorpc/transport_test.py b/endpoints/internal/protorpc/transport_test.py new file mode 100644 index 0000000..8fa39c3 --- /dev/null +++ b/endpoints/internal/protorpc/transport_test.py @@ -0,0 +1,493 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import errno +import six.moves.http_client +import os +import socket +import unittest + +from protorpc import messages +from protorpc import protobuf +from protorpc import protojson +from protorpc import remote +from protorpc import test_util +from protorpc import transport +from protorpc import webapp_test_util +from protorpc.wsgi import util as wsgi_util + +import mox + +package = 'transport_test' + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = transport + + +class Message(messages.Message): + + value = messages.StringField(1) + + +class Service(remote.Service): + + @remote.method(Message, Message) + def method(self, request): + pass + + +# Remove when RPC is no longer subclasses. +class TestRpc(transport.Rpc): + + waited = False + + def _wait_impl(self): + self.waited = True + + +class RpcTest(test_util.TestCase): + + def setUp(self): + self.request = Message(value=u'request') + self.response = Message(value=u'response') + self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, + error_message='an error', + error_name='blam') + + self.rpc = TestRpc(self.request) + + def testConstructor(self): + self.assertEquals(self.request, self.rpc.request) + self.assertEquals(remote.RpcState.RUNNING, self.rpc.state) + self.assertEquals(None, self.rpc.error_message) + self.assertEquals(None, self.rpc.error_name) + + def response(self): + self.assertFalse(self.rpc.waited) + self.assertEquals(None, self.rpc.response) + self.assertTrue(self.rpc.waited) + + def testSetResponse(self): + self.rpc.set_response(self.response) + + self.assertEquals(self.request, self.rpc.request) + self.assertEquals(remote.RpcState.OK, self.rpc.state) + self.assertEquals(self.response, self.rpc.response) + self.assertEquals(None, self.rpc.error_message) + self.assertEquals(None, self.rpc.error_name) + + def testSetResponseAlreadySet(self): + self.rpc.set_response(self.response) + + self.assertRaisesWithRegexpMatch( + transport.RpcStateError, + 'RPC must be in RUNNING state to change to OK', + self.rpc.set_response, + self.response) + + def testSetResponseAlreadyError(self): + self.rpc.set_status(self.status) + + self.assertRaisesWithRegexpMatch( + transport.RpcStateError, + 'RPC must be in RUNNING state to change to OK', + self.rpc.set_response, + self.response) + + def testSetStatus(self): + self.rpc.set_status(self.status) + + self.assertEquals(self.request, self.rpc.request) + self.assertEquals(remote.RpcState.APPLICATION_ERROR, self.rpc.state) + self.assertEquals('an error', self.rpc.error_message) + self.assertEquals('blam', self.rpc.error_name) + self.assertRaisesWithRegexpMatch(remote.ApplicationError, + 'an error', + getattr, self.rpc, 'response') + + def testSetStatusAlreadySet(self): + self.rpc.set_response(self.response) + + self.assertRaisesWithRegexpMatch( + transport.RpcStateError, + 'RPC must be in RUNNING state to change to OK', + self.rpc.set_response, + self.response) + + def testSetNonMessage(self): + self.assertRaisesWithRegexpMatch( + TypeError, + 'Expected Message type, received 10', + self.rpc.set_response, + 10) + + def testSetStatusAlreadyError(self): + self.rpc.set_status(self.status) + + self.assertRaisesWithRegexpMatch( + transport.RpcStateError, + 'RPC must be in RUNNING state to change to OK', + self.rpc.set_response, + self.response) + + def testSetUninitializedStatus(self): + self.assertRaises(messages.ValidationError, + self.rpc.set_status, + remote.RpcStatus()) + + +class TransportTest(test_util.TestCase): + + def setUp(self): + remote.Protocols.set_default(remote.Protocols.new_default()) + + def do_test(self, protocol, trans): + request = Message() + request.value = u'request' + + response = Message() + response.value = u'response' + + encoded_request = protocol.encode_message(request) + encoded_response = protocol.encode_message(response) + + self.assertEquals(protocol, trans.protocol) + + received_rpc = [None] + def transport_rpc(remote, rpc_request): + self.assertEquals(remote, Service.method.remote) + self.assertEquals(request, rpc_request) + rpc = TestRpc(request) + rpc.set_response(response) + return rpc + trans._start_rpc = transport_rpc + + rpc = trans.send_rpc(Service.method.remote, request) + self.assertEquals(response, rpc.response) + + def testDefaultProtocol(self): + trans = transport.Transport() + self.do_test(protobuf, trans) + self.assertEquals(protobuf, trans.protocol_config.protocol) + self.assertEquals('default', trans.protocol_config.name) + + def testAlternateProtocol(self): + trans = transport.Transport(protocol=protojson) + self.do_test(protojson, trans) + self.assertEquals(protojson, trans.protocol_config.protocol) + self.assertEquals('default', trans.protocol_config.name) + + def testProtocolConfig(self): + protocol_config = remote.ProtocolConfig( + protojson, 'protoconfig', 'image/png') + trans = transport.Transport(protocol=protocol_config) + self.do_test(protojson, trans) + self.assertTrue(trans.protocol_config is protocol_config) + + def testProtocolByName(self): + remote.Protocols.get_default().add_protocol( + protojson, 'png', 'image/png', ()) + trans = transport.Transport(protocol='png') + self.do_test(protojson, trans) + + +@remote.method(Message, Message) +def my_method(self, request): + self.fail('self.my_method should not be directly invoked.') + + +class FakeConnectionClass(object): + + def __init__(self, mox): + self.request = mox.CreateMockAnything() + self.response = mox.CreateMockAnything() + + +class HttpTransportTest(webapp_test_util.WebServerTestBase): + + def setUp(self): + # Do not need much parent construction functionality. + + self.schema = 'http' + self.server = None + + self.request = Message(value=u'The request value') + self.encoded_request = protojson.encode_message(self.request) + + self.response = Message(value=u'The response value') + self.encoded_response = protojson.encode_message(self.response) + + def testCallSucceeds(self): + self.ResetServer(wsgi_util.static_page(self.encoded_response, + content_type='application/json')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + self.assertEquals(self.response, rpc.response) + + def testHttps(self): + self.schema = 'https' + self.ResetServer(wsgi_util.static_page(self.encoded_response, + content_type='application/json')) + + # Create a fake https connection function that really just calls http. + self.used_https = False + def https_connection(*args, **kwargs): + self.used_https = True + return six.moves.http_client.HTTPConnection(*args, **kwargs) + + original_https_connection = six.moves.http_client.HTTPSConnection + six.moves.http_client.HTTPSConnection = https_connection + try: + rpc = self.connection.send_rpc(my_method.remote, self.request) + finally: + six.moves.http_client.HTTPSConnection = original_https_connection + self.assertEquals(self.response, rpc.response) + self.assertTrue(self.used_https) + + def testHttpSocketError(self): + self.ResetServer(wsgi_util.static_page(self.encoded_response, + content_type='application/json')) + + bad_transport = transport.HttpTransport('http://localhost:-1/blar') + try: + bad_transport.send_rpc(my_method.remote, self.request) + except remote.NetworkError as err: + self.assertTrue(str(err).startswith('Socket error: error (')) + self.assertEquals(errno.ECONNREFUSED, err.cause.errno) + else: + self.fail('Expected error') + + def testHttpRequestError(self): + self.ResetServer(wsgi_util.static_page(self.encoded_response, + content_type='application/json')) + + def request_error(*args, **kwargs): + raise TypeError('Generic Error') + original_request = six.moves.http_client.HTTPConnection.request + six.moves.http_client.HTTPConnection.request = request_error + try: + try: + self.connection.send_rpc(my_method.remote, self.request) + except remote.NetworkError as err: + self.assertEquals('Error communicating with HTTP server', str(err)) + self.assertEquals(TypeError, type(err.cause)) + self.assertEquals('Generic Error', str(err.cause)) + else: + self.fail('Expected error') + finally: + six.moves.http_client.HTTPConnection.request = original_request + + def testHandleGenericServiceError(self): + self.ResetServer(wsgi_util.error(six.moves.http_client.INTERNAL_SERVER_ERROR, + 'arbitrary error', + content_type='text/plain')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + try: + rpc.response + except remote.ServerError as err: + self.assertEquals('HTTP Error 500: arbitrary error', str(err).strip()) + else: + self.fail('Expected ServerError') + + def testHandleGenericServiceErrorNoMessage(self): + self.ResetServer(wsgi_util.error(six.moves.http_client.NOT_IMPLEMENTED, + ' ', + content_type='text/plain')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + try: + rpc.response + except remote.ServerError as err: + self.assertEquals('HTTP Error 501: Not Implemented', str(err).strip()) + else: + self.fail('Expected ServerError') + + def testHandleStatusContent(self): + self.ResetServer(wsgi_util.static_page('{"state": "REQUEST_ERROR",' + ' "error_message": "a request error"' + '}', + status=six.moves.http_client.BAD_REQUEST, + content_type='application/json')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + try: + rpc.response + except remote.RequestError as err: + self.assertEquals('a request error', str(err)) + else: + self.fail('Expected RequestError') + + def testHandleApplicationError(self): + self.ResetServer(wsgi_util.static_page('{"state": "APPLICATION_ERROR",' + ' "error_message": "an app error",' + ' "error_name": "MY_ERROR_NAME"}', + status=six.moves.http_client.BAD_REQUEST, + content_type='application/json')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + try: + rpc.response + except remote.ApplicationError as err: + self.assertEquals('an app error', str(err)) + self.assertEquals('MY_ERROR_NAME', err.error_name) + else: + self.fail('Expected RequestError') + + def testHandleUnparsableErrorContent(self): + self.ResetServer(wsgi_util.static_page('oops', + status=six.moves.http_client.BAD_REQUEST, + content_type='application/json')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + try: + rpc.response + except remote.ServerError as err: + self.assertEquals('HTTP Error 400: oops', str(err)) + else: + self.fail('Expected ServerError') + + def testHandleEmptyBadRpcStatus(self): + self.ResetServer(wsgi_util.static_page('{"error_message": "x"}', + status=six.moves.http_client.BAD_REQUEST, + content_type='application/json')) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + try: + rpc.response + except remote.ServerError as err: + self.assertEquals('HTTP Error 400: {"error_message": "x"}', str(err)) + else: + self.fail('Expected ServerError') + + def testUseProtocolConfigContentType(self): + expected_content_type = 'image/png' + def expect_content_type(environ, start_response): + self.assertEquals(expected_content_type, environ['CONTENT_TYPE']) + app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE']) + return app(environ, start_response) + + self.ResetServer(expect_content_type) + + protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png') + self.connection = self.CreateTransport(self.service_url, protocol_config) + + rpc = self.connection.send_rpc(my_method.remote, self.request) + self.assertEquals(Message(), rpc.response) + + +class SimpleRequest(messages.Message): + + content = messages.StringField(1) + + +class SimpleResponse(messages.Message): + + content = messages.StringField(1) + factory_value = messages.StringField(2) + remote_host = messages.StringField(3) + remote_address = messages.StringField(4) + server_host = messages.StringField(5) + server_port = messages.IntegerField(6) + + +class LocalService(remote.Service): + + def __init__(self, factory_value='default'): + self.factory_value = factory_value + + @remote.method(SimpleRequest, SimpleResponse) + def call_method(self, request): + return SimpleResponse(content=request.content, + factory_value=self.factory_value, + remote_host=self.request_state.remote_host, + remote_address=self.request_state.remote_address, + server_host=self.request_state.server_host, + server_port=self.request_state.server_port) + + @remote.method() + def raise_totally_unexpected(self, request): + raise TypeError('Kablam') + + @remote.method() + def raise_unexpected(self, request): + raise remote.RequestError('Huh?') + + @remote.method() + def raise_application_error(self, request): + raise remote.ApplicationError('App error', 10) + + +class LocalTransportTest(test_util.TestCase): + + def CreateService(self, factory_value='default'): + return + + def testBasicCallWithClass(self): + stub = LocalService.Stub(transport.LocalTransport(LocalService)) + response = stub.call_method(content='Hello') + self.assertEquals(SimpleResponse(content='Hello', + factory_value='default', + remote_host=os.uname()[1], + remote_address='127.0.0.1', + server_host=os.uname()[1], + server_port=-1), + response) + + def testBasicCallWithFactory(self): + stub = LocalService.Stub( + transport.LocalTransport(LocalService.new_factory('assigned'))) + response = stub.call_method(content='Hello') + self.assertEquals(SimpleResponse(content='Hello', + factory_value='assigned', + remote_host=os.uname()[1], + remote_address='127.0.0.1', + server_host=os.uname()[1], + server_port=-1), + response) + + def testTotallyUnexpectedError(self): + stub = LocalService.Stub(transport.LocalTransport(LocalService)) + self.assertRaisesWithRegexpMatch( + remote.ServerError, + 'Unexpected error TypeError: Kablam', + stub.raise_totally_unexpected) + + def testUnexpectedError(self): + stub = LocalService.Stub(transport.LocalTransport(LocalService)) + self.assertRaisesWithRegexpMatch( + remote.ServerError, + 'Unexpected error RequestError: Huh?', + stub.raise_unexpected) + + def testApplicationError(self): + stub = LocalService.Stub(transport.LocalTransport(LocalService)) + self.assertRaisesWithRegexpMatch( + remote.ApplicationError, + 'App error', + stub.raise_application_error) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/util.py b/endpoints/internal/protorpc/util.py new file mode 100644 index 0000000..935295c --- /dev/null +++ b/endpoints/internal/protorpc/util.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Common utility library.""" + +from __future__ import with_statement +import six + +__author__ = ['rafek@google.com (Rafe Kaplan)', + 'guido@google.com (Guido van Rossum)', +] + +import cgi +import datetime +import functools +import inspect +import os +import re +import sys + +__all__ = ['AcceptItem', + 'AcceptError', + 'Error', + 'choose_content_type', + 'decode_datetime', + 'get_package_for_module', + 'pad_string', + 'parse_accept_header', + 'positional', + 'PROTORPC_PROJECT_URL', + 'TimeZoneOffset', + 'total_seconds', +] + + +class Error(Exception): + """Base class for protorpc exceptions.""" + + +class AcceptError(Error): + """Raised when there is an error parsing the accept header.""" + + +PROTORPC_PROJECT_URL = 'http://code.google.com/p/google-protorpc' + +_TIME_ZONE_RE_STRING = r""" + # Examples: + # +01:00 + # -05:30 + # Z12:00 + ((?PZ) | (?P[-+]) + (?P\d\d) : + (?P\d\d))$ +""" +_TIME_ZONE_RE = re.compile(_TIME_ZONE_RE_STRING, re.IGNORECASE | re.VERBOSE) + + +def pad_string(string): + """Pad a string for safe HTTP error responses. + + Prevents Internet Explorer from displaying their own error messages + when sent as the content of error responses. + + Args: + string: A string. + + Returns: + Formatted string left justified within a 512 byte field. + """ + return string.ljust(512) + + +def positional(max_positional_args): + """A decorator to declare that only the first N arguments may be positional. + + This decorator makes it easy to support Python 3 style keyword-only + parameters. For example, in Python 3 it is possible to write: + + def fn(pos1, *, kwonly1=None, kwonly1=None): + ... + + All named parameters after * must be a keyword: + + fn(10, 'kw1', 'kw2') # Raises exception. + fn(10, kwonly1='kw1') # Ok. + + Example: + To define a function like above, do: + + @positional(1) + def fn(pos1, kwonly1=None, kwonly2=None): + ... + + If no default value is provided to a keyword argument, it becomes a required + keyword argument: + + @positional(0) + def fn(required_kw): + ... + + This must be called with the keyword parameter: + + fn() # Raises exception. + fn(10) # Raises exception. + fn(required_kw=10) # Ok. + + When defining instance or class methods always remember to account for + 'self' and 'cls': + + class MyClass(object): + + @positional(2) + def my_method(self, pos1, kwonly1=None): + ... + + @classmethod + @positional(2) + def my_method(cls, pos1, kwonly1=None): + ... + + One can omit the argument to 'positional' altogether, and then no + arguments with default values may be passed positionally. This + would be equivalent to placing a '*' before the first argument + with a default value in Python 3. If there are no arguments with + default values, and no argument is given to 'positional', an error + is raised. + + @positional + def fn(arg1, arg2, required_kw1=None, required_kw2=0): + ... + + fn(1, 3, 5) # Raises exception. + fn(1, 3) # Ok. + fn(1, 3, required_kw1=5) # Ok. + + Args: + max_positional_arguments: Maximum number of positional arguments. All + parameters after the this index must be keyword only. + + Returns: + A decorator that prevents using arguments after max_positional_args from + being used as positional parameters. + + Raises: + TypeError if a keyword-only argument is provided as a positional parameter. + ValueError if no maximum number of arguments is provided and the function + has no arguments with default values. + """ + def positional_decorator(wrapped): + @functools.wraps(wrapped) + def positional_wrapper(*args, **kwargs): + if len(args) > max_positional_args: + plural_s = '' + if max_positional_args != 1: + plural_s = 's' + raise TypeError('%s() takes at most %d positional argument%s ' + '(%d given)' % (wrapped.__name__, + max_positional_args, + plural_s, len(args))) + return wrapped(*args, **kwargs) + return positional_wrapper + + if isinstance(max_positional_args, six.integer_types): + return positional_decorator + else: + args, _, _, defaults = inspect.getargspec(max_positional_args) + if defaults is None: + raise ValueError( + 'Functions with no keyword arguments must specify ' + 'max_positional_args') + return positional(len(args) - len(defaults))(max_positional_args) + + +# TODO(rafek): Support 'level' from the Accept header standard. +class AcceptItem(object): + """Encapsulate a single entry of an Accept header. + + Parses and extracts relevent values from an Accept header and implements + a sort order based on the priority of each requested type as defined + here: + + http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html + + Accept headers are normally a list of comma separated items. Each item + has the format of a normal HTTP header. For example: + + Accept: text/plain, text/html, text/*, */* + + This header means to prefer plain text over HTML, HTML over any other + kind of text and text over any other kind of supported format. + + This class does not attempt to parse the list of items from the Accept header. + The constructor expects the unparsed sub header and the index within the + Accept header that the fragment was found. + + Properties: + index: The index that this accept item was found in the Accept header. + main_type: The main type of the content type. + sub_type: The sub type of the content type. + q: The q value extracted from the header as a float. If there is no q + value, defaults to 1.0. + values: All header attributes parsed form the sub-header. + sort_key: A tuple (no_main_type, no_sub_type, q, no_values, index): + no_main_type: */* has the least priority. + no_sub_type: Items with no sub-type have less priority. + q: Items with lower q value have less priority. + no_values: Items with no values have less priority. + index: Index of item in accept header is the last priority. + """ + + __CONTENT_TYPE_REGEX = re.compile(r'^([^/]+)/([^/]+)$') + + def __init__(self, accept_header, index): + """Parse component of an Accept header. + + Args: + accept_header: Unparsed sub-expression of accept header. + index: The index that this accept item was found in the Accept header. + """ + accept_header = accept_header.lower() + content_type, values = cgi.parse_header(accept_header) + match = self.__CONTENT_TYPE_REGEX.match(content_type) + if not match: + raise AcceptError('Not valid Accept header: %s' % accept_header) + self.__index = index + self.__main_type = match.group(1) + self.__sub_type = match.group(2) + self.__q = float(values.get('q', 1)) + self.__values = values + + if self.__main_type == '*': + self.__main_type = None + + if self.__sub_type == '*': + self.__sub_type = None + + self.__sort_key = (not self.__main_type, + not self.__sub_type, + -self.__q, + not self.__values, + self.__index) + + @property + def index(self): + return self.__index + + @property + def main_type(self): + return self.__main_type + + @property + def sub_type(self): + return self.__sub_type + + @property + def q(self): + return self.__q + + @property + def values(self): + """Copy the dictionary of values parsed from the header fragment.""" + return dict(self.__values) + + @property + def sort_key(self): + return self.__sort_key + + def match(self, content_type): + """Determine if the given accept header matches content type. + + Args: + content_type: Unparsed content type string. + + Returns: + True if accept header matches content type, else False. + """ + content_type, _ = cgi.parse_header(content_type) + match = self.__CONTENT_TYPE_REGEX.match(content_type.lower()) + if not match: + return False + + main_type, sub_type = match.group(1), match.group(2) + if not(main_type and sub_type): + return False + + return ((self.__main_type is None or self.__main_type == main_type) and + (self.__sub_type is None or self.__sub_type == sub_type)) + + + def __cmp__(self, other): + """Comparison operator based on sort keys.""" + if not isinstance(other, AcceptItem): + return NotImplemented + return cmp(self.sort_key, other.sort_key) + + def __str__(self): + """Rebuilds Accept header.""" + content_type = '%s/%s' % (self.__main_type or '*', self.__sub_type or '*') + values = self.values + + if values: + value_strings = ['%s=%s' % (i, v) for i, v in values.items()] + return '%s; %s' % (content_type, '; '.join(value_strings)) + else: + return content_type + + def __repr__(self): + return 'AcceptItem(%r, %d)' % (str(self), self.__index) + + +def parse_accept_header(accept_header): + """Parse accept header. + + Args: + accept_header: Unparsed accept header. Does not include name of header. + + Returns: + List of AcceptItem instances sorted according to their priority. + """ + accept_items = [] + for index, header in enumerate(accept_header.split(',')): + accept_items.append(AcceptItem(header, index)) + return sorted(accept_items) + + +def choose_content_type(accept_header, supported_types): + """Choose most appropriate supported type based on what client accepts. + + Args: + accept_header: Unparsed accept header. Does not include name of header. + supported_types: List of content-types supported by the server. The index + of the supported types determines which supported type is prefered by + the server should the accept header match more than one at the same + priority. + + Returns: + The preferred supported type if the accept header matches any, else None. + """ + for accept_item in parse_accept_header(accept_header): + for supported_type in supported_types: + if accept_item.match(supported_type): + return supported_type + return None + + +@positional(1) +def get_package_for_module(module): + """Get package name for a module. + + Helper calculates the package name of a module. + + Args: + module: Module to get name for. If module is a string, try to find + module in sys.modules. + + Returns: + If module contains 'package' attribute, uses that as package name. + Else, if module is not the '__main__' module, the module __name__. + Else, the base name of the module file name. Else None. + """ + if isinstance(module, six.string_types): + try: + module = sys.modules[module] + except KeyError: + return None + + try: + return six.text_type(module.package) + except AttributeError: + if module.__name__ == '__main__': + try: + file_name = module.__file__ + except AttributeError: + pass + else: + base_name = os.path.basename(file_name) + split_name = os.path.splitext(base_name) + if len(split_name) == 1: + return six.text_type(base_name) + else: + return u'.'.join(split_name[:-1]) + + return six.text_type(module.__name__) + + +def total_seconds(offset): + """Backport of offset.total_seconds() from python 2.7+.""" + seconds = offset.days * 24 * 60 * 60 + offset.seconds + microseconds = seconds * 10**6 + offset.microseconds + return microseconds / (10**6 * 1.0) + + +class TimeZoneOffset(datetime.tzinfo): + """Time zone information as encoded/decoded for DateTimeFields.""" + + def __init__(self, offset): + """Initialize a time zone offset. + + Args: + offset: Integer or timedelta time zone offset, in minutes from UTC. This + can be negative. + """ + super(TimeZoneOffset, self).__init__() + if isinstance(offset, datetime.timedelta): + offset = total_seconds(offset) / 60 + self.__offset = offset + + def utcoffset(self, dt): + """Get the a timedelta with the time zone's offset from UTC. + + Returns: + The time zone offset from UTC, as a timedelta. + """ + return datetime.timedelta(minutes=self.__offset) + + def dst(self, dt): + """Get the daylight savings time offset. + + The formats that ProtoRPC uses to encode/decode time zone information don't + contain any information about daylight savings time. So this always + returns a timedelta of 0. + + Returns: + A timedelta of 0. + """ + return datetime.timedelta(0) + + +def decode_datetime(encoded_datetime): + """Decode a DateTimeField parameter from a string to a python datetime. + + Args: + encoded_datetime: A string in RFC 3339 format. + + Returns: + A datetime object with the date and time specified in encoded_datetime. + + Raises: + ValueError: If the string is not in a recognized format. + """ + # Check if the string includes a time zone offset. Break out the + # part that doesn't include time zone info. Convert to uppercase + # because all our comparisons should be case-insensitive. + time_zone_match = _TIME_ZONE_RE.search(encoded_datetime) + if time_zone_match: + time_string = encoded_datetime[:time_zone_match.start(1)].upper() + else: + time_string = encoded_datetime.upper() + + if '.' in time_string: + format_string = '%Y-%m-%dT%H:%M:%S.%f' + else: + format_string = '%Y-%m-%dT%H:%M:%S' + + decoded_datetime = datetime.datetime.strptime(time_string, format_string) + + if not time_zone_match: + return decoded_datetime + + # Time zone info was included in the parameter. Add a tzinfo + # object to the datetime. Datetimes can't be changed after they're + # created, so we'll need to create a new one. + if time_zone_match.group('z'): + offset_minutes = 0 + else: + sign = time_zone_match.group('sign') + hours, minutes = [int(value) for value in + time_zone_match.group('hours', 'minutes')] + offset_minutes = hours * 60 + minutes + if sign == '-': + offset_minutes *= -1 + + return datetime.datetime(decoded_datetime.year, + decoded_datetime.month, + decoded_datetime.day, + decoded_datetime.hour, + decoded_datetime.minute, + decoded_datetime.second, + decoded_datetime.microsecond, + TimeZoneOffset(offset_minutes)) diff --git a/endpoints/internal/protorpc/util_test.py b/endpoints/internal/protorpc/util_test.py new file mode 100644 index 0000000..df05c32 --- /dev/null +++ b/endpoints/internal/protorpc/util_test.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.util.""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import datetime +import random +import sys +import types +import unittest + +from protorpc import test_util +from protorpc import util + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = util + + +class PadStringTest(test_util.TestCase): + + def testPadEmptyString(self): + self.assertEquals(' ' * 512, util.pad_string('')) + + def testPadString(self): + self.assertEquals('hello' + (507 * ' '), util.pad_string('hello')) + + def testPadLongString(self): + self.assertEquals('x' * 1000, util.pad_string('x' * 1000)) + + +class UtilTest(test_util.TestCase): + + def testDecoratedFunction_LengthZero(self): + @util.positional(0) + def fn(kwonly=1): + return [kwonly] + self.assertEquals([1], fn()) + self.assertEquals([2], fn(kwonly=2)) + self.assertRaisesWithRegexpMatch(TypeError, + r'fn\(\) takes at most 0 positional ' + r'arguments \(1 given\)', + fn, 1) + + def testDecoratedFunction_LengthOne(self): + @util.positional(1) + def fn(pos, kwonly=1): + return [pos, kwonly] + self.assertEquals([1, 1], fn(1)) + self.assertEquals([2, 2], fn(2, kwonly=2)) + self.assertRaisesWithRegexpMatch(TypeError, + r'fn\(\) takes at most 1 positional ' + r'argument \(2 given\)', + fn, 2, 3) + + def testDecoratedFunction_LengthTwoWithDefault(self): + @util.positional(2) + def fn(pos1, pos2=1, kwonly=1): + return [pos1, pos2, kwonly] + self.assertEquals([1, 1, 1], fn(1)) + self.assertEquals([2, 2, 1], fn(2, 2)) + self.assertEquals([2, 3, 4], fn(2, 3, kwonly=4)) + self.assertRaisesWithRegexpMatch(TypeError, + r'fn\(\) takes at most 2 positional ' + r'arguments \(3 given\)', + fn, 2, 3, 4) + + def testDecoratedMethod(self): + class MyClass(object): + @util.positional(2) + def meth(self, pos1, kwonly=1): + return [pos1, kwonly] + self.assertEquals([1, 1], MyClass().meth(1)) + self.assertEquals([2, 2], MyClass().meth(2, kwonly=2)) + self.assertRaisesWithRegexpMatch(TypeError, + r'meth\(\) takes at most 2 positional ' + r'arguments \(3 given\)', + MyClass().meth, 2, 3) + + def testDefaultDecoration(self): + @util.positional + def fn(a, b, c=None): + return a, b, c + self.assertEquals((1, 2, 3), fn(1, 2, c=3)) + self.assertEquals((3, 4, None), fn(3, b=4)) + self.assertRaisesWithRegexpMatch(TypeError, + r'fn\(\) takes at most 2 positional ' + r'arguments \(3 given\)', + fn, 2, 3, 4) + + def testDefaultDecorationNoKwdsFails(self): + def fn(a): + return a + self.assertRaisesRegexp( + ValueError, + 'Functions with no keyword arguments must specify max_positional_args', + util.positional, fn) + + def testDecoratedFunctionDocstring(self): + @util.positional(0) + def fn(kwonly=1): + """fn docstring.""" + return [kwonly] + self.assertEquals('fn docstring.', fn.__doc__) + + +class AcceptItemTest(test_util.TestCase): + + def CheckAttributes(self, item, main_type, sub_type, q=1, values={}, index=1): + self.assertEquals(index, item.index) + self.assertEquals(main_type, item.main_type) + self.assertEquals(sub_type, item.sub_type) + self.assertEquals(q, item.q) + self.assertEquals(values, item.values) + + def testParse(self): + self.CheckAttributes(util.AcceptItem('*/*', 1), None, None) + self.CheckAttributes(util.AcceptItem('text/*', 1), 'text', None) + self.CheckAttributes(util.AcceptItem('text/plain', 1), 'text', 'plain') + self.CheckAttributes( + util.AcceptItem('text/plain; q=0.3', 1), 'text', 'plain', 0.3, + values={'q': '0.3'}) + self.CheckAttributes( + util.AcceptItem('text/plain; level=2', 1), 'text', 'plain', + values={'level': '2'}) + self.CheckAttributes( + util.AcceptItem('text/plain', 10), 'text', 'plain', index=10) + + def testCaseInsensitive(self): + self.CheckAttributes(util.AcceptItem('Text/Plain', 1), 'text', 'plain') + + def testBadValue(self): + self.assertRaises(util.AcceptError, + util.AcceptItem, 'bad value', 1) + self.assertRaises(util.AcceptError, + util.AcceptItem, 'bad value/', 1) + self.assertRaises(util.AcceptError, + util.AcceptItem, '/bad value', 1) + + def testSortKey(self): + item = util.AcceptItem('main/sub; q=0.2; level=3', 11) + self.assertEquals((False, False, -0.2, False, 11), item.sort_key) + + item = util.AcceptItem('main/*', 12) + self.assertEquals((False, True, -1, True, 12), item.sort_key) + + item = util.AcceptItem('*/*', 1) + self.assertEquals((True, True, -1, True, 1), item.sort_key) + + def testSort(self): + i1 = util.AcceptItem('text/*', 1) + i2 = util.AcceptItem('text/html', 2) + i3 = util.AcceptItem('text/html; q=0.9', 3) + i4 = util.AcceptItem('text/html; q=0.3', 4) + i5 = util.AcceptItem('text/xml', 5) + i6 = util.AcceptItem('text/html; level=1', 6) + i7 = util.AcceptItem('*/*', 7) + items = [i1, i2 ,i3 ,i4 ,i5 ,i6, i7] + random.shuffle(items) + self.assertEquals([i6, i2, i5, i3, i4, i1, i7], sorted(items)) + + def testMatchAll(self): + item = util.AcceptItem('*/*', 1) + self.assertTrue(item.match('text/html')) + self.assertTrue(item.match('text/plain; level=1')) + self.assertTrue(item.match('image/png')) + self.assertTrue(item.match('image/png; q=0.3')) + + def testMatchMainType(self): + item = util.AcceptItem('text/*', 1) + self.assertTrue(item.match('text/html')) + self.assertTrue(item.match('text/plain; level=1')) + self.assertFalse(item.match('image/png')) + self.assertFalse(item.match('image/png; q=0.3')) + + def testMatchFullType(self): + item = util.AcceptItem('text/plain', 1) + self.assertFalse(item.match('text/html')) + self.assertTrue(item.match('text/plain; level=1')) + self.assertFalse(item.match('image/png')) + self.assertFalse(item.match('image/png; q=0.3')) + + def testMatchCaseInsensitive(self): + item = util.AcceptItem('text/plain', 1) + self.assertTrue(item.match('tExt/pLain')) + + def testStr(self): + self.assertHeaderSame('*/*', str(util.AcceptItem('*/*', 1))) + self.assertHeaderSame('text/*', str(util.AcceptItem('text/*', 1))) + self.assertHeaderSame('text/plain', + str(util.AcceptItem('text/plain', 1))) + self.assertHeaderSame('text/plain; q=0.2', + str(util.AcceptItem('text/plain; q=0.2', 1))) + self.assertHeaderSame( + 'text/plain; q=0.2; level=1', + str(util.AcceptItem('text/plain; level=1; q=0.2', 1))) + + def testRepr(self): + self.assertEquals("AcceptItem('*/*', 1)", repr(util.AcceptItem('*/*', 1))) + self.assertEquals("AcceptItem('text/plain', 11)", + repr(util.AcceptItem('text/plain', 11))) + + def testValues(self): + item = util.AcceptItem('text/plain; a=1; b=2; c=3;', 1) + values = item.values + self.assertEquals(dict(a="1", b="2", c="3"), values) + values['a'] = "7" + self.assertNotEquals(values, item.values) + + +class ParseAcceptHeaderTest(test_util.TestCase): + + def testIndex(self): + accept_header = """text/*, text/html, text/html; q=0.9, + text/xml, + text/html; level=1, */*""" + accepts = util.parse_accept_header(accept_header) + self.assertEquals(6, len(accepts)) + self.assertEquals([4, 1, 3, 2, 0, 5], [a.index for a in accepts]) + + +class ChooseContentTypeTest(test_util.TestCase): + + def testIgnoreUnrequested(self): + self.assertEquals('application/json', + util.choose_content_type( + 'text/plain, application/json, */*', + ['application/X-Google-protobuf', + 'application/json' + ])) + + def testUseCorrectPreferenceIndex(self): + self.assertEquals('application/json', + util.choose_content_type( + '*/*, text/plain, application/json', + ['application/X-Google-protobuf', + 'application/json' + ])) + + def testPreferFirstInList(self): + self.assertEquals('application/X-Google-protobuf', + util.choose_content_type( + '*/*', + ['application/X-Google-protobuf', + 'application/json' + ])) + + def testCaseInsensitive(self): + self.assertEquals('application/X-Google-protobuf', + util.choose_content_type( + 'application/x-google-protobuf', + ['application/X-Google-protobuf', + 'application/json' + ])) + + +class GetPackageForModuleTest(test_util.TestCase): + + def setUp(self): + self.original_modules = dict(sys.modules) + + def tearDown(self): + sys.modules.clear() + sys.modules.update(self.original_modules) + + def CreateModule(self, name, file_name=None): + if file_name is None: + file_name = '%s.py' % name + module = types.ModuleType(name) + sys.modules[name] = module + return module + + def assertPackageEquals(self, expected, actual): + self.assertEquals(expected, actual) + if actual is not None: + self.assertTrue(isinstance(actual, six.text_type)) + + def testByString(self): + module = self.CreateModule('service_module') + module.package = 'my_package' + self.assertPackageEquals('my_package', + util.get_package_for_module('service_module')) + + def testModuleNameNotInSys(self): + self.assertPackageEquals(None, + util.get_package_for_module('service_module')) + + def testHasPackage(self): + module = self.CreateModule('service_module') + module.package = 'my_package' + self.assertPackageEquals('my_package', util.get_package_for_module(module)) + + def testHasModuleName(self): + module = self.CreateModule('service_module') + self.assertPackageEquals('service_module', + util.get_package_for_module(module)) + + def testIsMain(self): + module = self.CreateModule('__main__') + module.__file__ = '/bing/blam/bloom/blarm/my_file.py' + self.assertPackageEquals('my_file', util.get_package_for_module(module)) + + def testIsMainCompiled(self): + module = self.CreateModule('__main__') + module.__file__ = '/bing/blam/bloom/blarm/my_file.pyc' + self.assertPackageEquals('my_file', util.get_package_for_module(module)) + + def testNoExtension(self): + module = self.CreateModule('__main__') + module.__file__ = '/bing/blam/bloom/blarm/my_file' + self.assertPackageEquals('my_file', util.get_package_for_module(module)) + + def testNoPackageAtAll(self): + module = self.CreateModule('__main__') + self.assertPackageEquals('__main__', util.get_package_for_module(module)) + + +class DateTimeTests(test_util.TestCase): + + def testDecodeDateTime(self): + """Test that a RFC 3339 datetime string is decoded properly.""" + for datetime_string, datetime_vals in ( + ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)), + ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): + decoded = util.decode_datetime(datetime_string) + expected = datetime.datetime(*datetime_vals) + self.assertEquals(expected, decoded) + + def testDateTimeTimeZones(self): + """Test that a datetime string with a timezone is decoded correctly.""" + for datetime_string, datetime_vals in ( + ('2012-09-30T15:31:50.262-06:00', + (2012, 9, 30, 15, 31, 50, 262000, util.TimeZoneOffset(-360))), + ('2012-09-30T15:31:50.262+01:30', + (2012, 9, 30, 15, 31, 50, 262000, util.TimeZoneOffset(90))), + ('2012-09-30T15:31:50+00:05', + (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(5))), + ('2012-09-30T15:31:50+00:00', + (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), + ('2012-09-30t15:31:50-00:00', + (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), + ('2012-09-30t15:31:50z', + (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), + ('2012-09-30T15:31:50-23:00', + (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(-1380)))): + decoded = util.decode_datetime(datetime_string) + expected = datetime.datetime(*datetime_vals) + self.assertEquals(expected, decoded) + + def testDecodeDateTimeInvalid(self): + """Test that decoding malformed datetime strings raises execptions.""" + for datetime_string in ('invalid', + '2012-09-30T15:31:50.', + '-08:00 2012-09-30T15:31:50.262', + '2012-09-30T15:31', + '2012-09-30T15:31Z', + '2012-09-30T15:31:50ZZ', + '2012-09-30T15:31:50.262 blah blah -08:00', + '1000-99-99T25:99:99.999-99:99'): + self.assertRaises(ValueError, util.decode_datetime, datetime_string) + + def testTimeZoneOffsetDelta(self): + """Test that delta works with TimeZoneOffset.""" + time_zone = util.TimeZoneOffset(datetime.timedelta(minutes=3)) + epoch = time_zone.utcoffset(datetime.datetime.utcfromtimestamp(0)) + self.assertEqual(180, util.total_seconds(epoch)) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/webapp/__init__.py b/endpoints/internal/protorpc/webapp/__init__.py new file mode 100644 index 0000000..ce0df32 --- /dev/null +++ b/endpoints/internal/protorpc/webapp/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__author__ = 'rafek@google.com (Rafe Kaplan)' diff --git a/endpoints/internal/protorpc/webapp/forms.py b/endpoints/internal/protorpc/webapp/forms.py new file mode 100644 index 0000000..65d3b96 --- /dev/null +++ b/endpoints/internal/protorpc/webapp/forms.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Webapp forms interface to ProtoRPC services. + +This webapp application is automatically configured to work with ProtoRPCs +that have a configured protorpc.RegistryService. This webapp is +automatically added to the registry service URL at /forms +(default is /protorpc/form) when configured using the +service_handlers.service_mapping function. +""" + +import os + +from .google_imports import template +from .google_imports import webapp + + +__all__ = ['FormsHandler', + 'ResourceHandler', + + 'DEFAULT_REGISTRY_PATH', + ] + +_TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), + 'static') + +_FORMS_TEMPLATE = os.path.join(_TEMPLATES_DIR, 'forms.html') +_METHODS_TEMPLATE = os.path.join(_TEMPLATES_DIR, 'methods.html') + +DEFAULT_REGISTRY_PATH = '/protorpc' + + +class ResourceHandler(webapp.RequestHandler): + """Serves static resources without needing to add static files to app.yaml.""" + + __RESOURCE_MAP = { + 'forms.js': 'text/javascript', + 'jquery-1.4.2.min.js': 'text/javascript', + 'jquery.json-2.2.min.js': 'text/javascript', + } + + def get(self, relative): + """Serve known static files. + + If static file is not known, will return 404 to client. + + Response items are cached for 300 seconds. + + Args: + relative: Name of static file relative to main FormsHandler. + """ + content_type = self.__RESOURCE_MAP.get(relative, None) + if not content_type: + self.response.set_status(404) + self.response.out.write('Resource not found.') + return + + path = os.path.join(_TEMPLATES_DIR, relative) + self.response.headers['Content-Type'] = content_type + static_file = open(path) + try: + contents = static_file.read() + finally: + static_file.close() + self.response.out.write(contents) + + +class FormsHandler(webapp.RequestHandler): + """Handler for display HTML/javascript forms of ProtoRPC method calls. + + When accessed with no query parameters, will show a web page that displays + all services and methods on the associated registry path. Links on this + page fill in the service_path and method_name query parameters back to this + same handler. + + When provided with service_path and method_name parameters will display a + dynamic form representing the request message for that method. When sent, + the form sends a JSON request to the ProtoRPC method and displays the + response in the HTML page. + + Attribute: + registry_path: Read-only registry path known by this handler. + """ + + def __init__(self, registry_path=DEFAULT_REGISTRY_PATH): + """Constructor. + + When configuring a FormsHandler to use with a webapp application do not + pass the request handler class in directly. Instead use new_factory to + ensure that the FormsHandler is created with the correct registry path + for each request. + + Args: + registry_path: Absolute path on server where the ProtoRPC RegsitryService + is located. + """ + assert registry_path + self.__registry_path = registry_path + + @property + def registry_path(self): + return self.__registry_path + + def get(self): + """Send forms and method page to user. + + By default, displays a web page listing all services and methods registered + on the server. Methods have links to display the actual method form. + + If both parameters are set, will display form for method. + + Query Parameters: + service_path: Path to service to display method of. Optional. + method_name: Name of method to display form for. Optional. + """ + params = {'forms_path': self.request.path.rstrip('/'), + 'hostname': self.request.host, + 'registry_path': self.__registry_path, + } + service_path = self.request.get('path', None) + method_name = self.request.get('method', None) + + if service_path and method_name: + form_template = _METHODS_TEMPLATE + params['service_path'] = service_path + params['method_name'] = method_name + else: + form_template = _FORMS_TEMPLATE + + self.response.out.write(template.render(form_template, params)) + + @classmethod + def new_factory(cls, registry_path=DEFAULT_REGISTRY_PATH): + """Construct a factory for use with WSGIApplication. + + This method is called automatically with the correct registry path when + services are configured via service_handlers.service_mapping. + + Args: + registry_path: Absolute path on server where the ProtoRPC RegsitryService + is located. + + Returns: + Factory function that creates a properly configured FormsHandler instance. + """ + def forms_factory(): + return cls(registry_path) + return forms_factory diff --git a/endpoints/internal/protorpc/webapp/forms_test.py b/endpoints/internal/protorpc/webapp/forms_test.py new file mode 100644 index 0000000..dcac88d --- /dev/null +++ b/endpoints/internal/protorpc/webapp/forms_test.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.forms.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import os +import unittest + +from protorpc import test_util +from protorpc import webapp_test_util +from protorpc.webapp import forms +from protorpc.webapp.google_imports import template + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = forms + + +def RenderTemplate(name, **params): + """Load content from static file. + + Args: + name: Name of static file to load from static directory. + params: Passed in to webapp template generator. + + Returns: + Contents of static file. + """ + path = os.path.join(forms._TEMPLATES_DIR, name) + return template.render(path, params) + + +class ResourceHandlerTest(webapp_test_util.RequestHandlerTestBase): + + def CreateRequestHandler(self): + return forms.ResourceHandler() + + def DoStaticContentTest(self, name, expected_type): + """Run the static content test. + + Loads expected static content from source and compares with + results in response. Checks content-type and cache header. + + Args: + name: Name of file that should be served. + expected_type: Expected content-type of served file. + """ + self.handler.get(name) + + content = RenderTemplate(name) + self.CheckResponse('200 OK', + {'content-type': expected_type, + }, + content) + + def testGet(self): + self.DoStaticContentTest('forms.js', 'text/javascript') + + def testNoSuchFile(self): + self.handler.get('unknown.txt') + + self.CheckResponse('404 Not Found', + {}, + 'Resource not found.') + + +class FormsHandlerTest(webapp_test_util.RequestHandlerTestBase): + + def CreateRequestHandler(self): + handler = forms.FormsHandler('/myreg') + self.assertEquals('/myreg', handler.registry_path) + return handler + + def testGetForm(self): + self.handler.get() + + content = RenderTemplate( + 'forms.html', + forms_path='/tmp/myhandler', + hostname=self.request.host, + registry_path='/myreg') + + self.CheckResponse('200 OK', + {}, + content) + + def testGet_MissingPath(self): + self.ResetHandler({'QUERY_STRING': 'method=my_method'}) + + self.handler.get() + + content = RenderTemplate( + 'forms.html', + forms_path='/tmp/myhandler', + hostname=self.request.host, + registry_path='/myreg') + + self.CheckResponse('200 OK', + {}, + content) + + def testGet_MissingMethod(self): + self.ResetHandler({'QUERY_STRING': 'path=/my-path'}) + + self.handler.get() + + content = RenderTemplate( + 'forms.html', + forms_path='/tmp/myhandler', + hostname=self.request.host, + registry_path='/myreg') + + self.CheckResponse('200 OK', + {}, + content) + + def testGetMethod(self): + self.ResetHandler({'QUERY_STRING': 'path=/my-path&method=my_method'}) + + self.handler.get() + + content = RenderTemplate( + 'methods.html', + forms_path='/tmp/myhandler', + hostname=self.request.host, + registry_path='/myreg', + service_path='/my-path', + method_name='my_method') + + self.CheckResponse('200 OK', + {}, + content) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/webapp/google_imports.py b/endpoints/internal/protorpc/webapp/google_imports.py new file mode 100644 index 0000000..b7de40c --- /dev/null +++ b/endpoints/internal/protorpc/webapp/google_imports.py @@ -0,0 +1,25 @@ +"""Dynamically decide from where to import other SDK modules. + +All other protorpc.webapp code should import other SDK modules from +this module. If necessary, add new imports here (in both places). +""" + +__author__ = 'yey@google.com (Ye Yuan)' + +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import + +import os +import sys + +try: + from google.appengine import ext + normal_environment = True +except ImportError: + normal_environment = False + + +if normal_environment: + from google.appengine.ext import webapp + from google.appengine.ext.webapp import util as webapp_util + from google.appengine.ext.webapp import template diff --git a/endpoints/internal/protorpc/webapp/service_handlers.py b/endpoints/internal/protorpc/webapp/service_handlers.py new file mode 100644 index 0000000..94a0855 --- /dev/null +++ b/endpoints/internal/protorpc/webapp/service_handlers.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Handlers for remote services. + +This module contains classes that may be used to build a service +on top of the App Engine Webapp framework. + +The services request handler can be configured to handle requests in a number +of different request formats. All different request formats must have a way +to map the request to the service handlers defined request message.Message +class. The handler can also send a response in any format that can be mapped +from the response message.Message class. + +Participants in an RPC: + + There are four classes involved with the life cycle of an RPC. + + Service factory: A user-defined service factory that is responsible for + instantiating an RPC service. The methods intended for use as RPC + methods must be decorated by the 'remote' decorator. + + RPCMapper: Responsible for determining whether or not a specific request + matches a particular RPC format and translating between the actual + request/response and the underlying message types. A single instance of + an RPCMapper sub-class is required per service configuration. Each + mapper must be usable across multiple requests. + + ServiceHandler: A webapp.RequestHandler sub-class that responds to the + webapp framework. It mediates between the RPCMapper and service + implementation class during a request. As determined by the Webapp + framework, a new ServiceHandler instance is created to handle each + user request. A handler is never used to handle more than one request. + + ServiceHandlerFactory: A class that is responsible for creating new, + properly configured ServiceHandler instance for each request. The + factory is configured by providing it with a set of RPCMapper instances. + When the Webapp framework invokes the service handler, the handler + creates a new service class instance. The service class instance is + provided with a reference to the handler. A single instance of an + RPCMapper sub-class is required to configure each service. Each mapper + instance must be usable across multiple requests. + +RPC mappers: + + RPC mappers translate between a single HTTP based RPC protocol and the + underlying service implementation. Each RPC mapper must configured + with the following information to determine if it is an appropriate + mapper for a given request: + + http_methods: Set of HTTP methods supported by handler. + + content_types: Set of supported content types. + + default_content_type: Default content type for handler responses. + + Built-in mapper implementations: + + URLEncodedRPCMapper: Matches requests that are compatible with post + forms with the 'application/x-www-form-urlencoded' content-type + (this content type is the default if none is specified. It + translates post parameters into request parameters. + + ProtobufRPCMapper: Matches requests that are compatible with post + forms with the 'application/x-google-protobuf' content-type. It + reads the contents of a binary post request. + +Public Exceptions: + Error: Base class for service handler errors. + ServiceConfigurationError: Raised when a service not correctly configured. + RequestError: Raised by RPC mappers when there is an error in its request + or request format. + ResponseError: Raised by RPC mappers when there is an error in its response. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import six.moves.http_client +import logging + +from .google_imports import webapp +from .google_imports import webapp_util +from .. import messages +from .. import protobuf +from .. import protojson +from .. import protourlencode +from .. import registry +from .. import remote +from .. import util +from . import forms + +__all__ = [ + 'Error', + 'RequestError', + 'ResponseError', + 'ServiceConfigurationError', + + 'DEFAULT_REGISTRY_PATH', + + 'ProtobufRPCMapper', + 'RPCMapper', + 'ServiceHandler', + 'ServiceHandlerFactory', + 'URLEncodedRPCMapper', + 'JSONRPCMapper', + 'service_mapping', + 'run_services', +] + + +class Error(Exception): + """Base class for all errors in service handlers module.""" + + +class ServiceConfigurationError(Error): + """When service configuration is incorrect.""" + + +class RequestError(Error): + """Error occurred when building request.""" + + +class ResponseError(Error): + """Error occurred when building response.""" + + +_URLENCODED_CONTENT_TYPE = protourlencode.CONTENT_TYPE +_PROTOBUF_CONTENT_TYPE = protobuf.CONTENT_TYPE +_JSON_CONTENT_TYPE = protojson.CONTENT_TYPE + +_EXTRA_JSON_CONTENT_TYPES = ['application/x-javascript', + 'text/javascript', + 'text/x-javascript', + 'text/x-json', + 'text/json', + ] + +# The whole method pattern is an optional regex. It contains a single +# group used for mapping to the query parameter. This is passed to the +# parameters of 'get' and 'post' on the ServiceHandler. +_METHOD_PATTERN = r'(?:\.([^?]*))?' + +DEFAULT_REGISTRY_PATH = forms.DEFAULT_REGISTRY_PATH + + +class RPCMapper(object): + """Interface to mediate between request and service object. + + Request mappers are implemented to support various types of + RPC protocols. It is responsible for identifying whether a + given request matches a particular protocol, resolve the remote + method to invoke and mediate between the request and appropriate + protocol messages for the remote method. + """ + + @util.positional(4) + def __init__(self, + http_methods, + default_content_type, + protocol, + content_types=None): + """Constructor. + + Args: + http_methods: Set of HTTP methods supported by mapper. + default_content_type: Default content type supported by mapper. + protocol: The protocol implementation. Must implement encode_message and + decode_message. + content_types: Set of additionally supported content types. + """ + self.__http_methods = frozenset(http_methods) + self.__default_content_type = default_content_type + self.__protocol = protocol + + if content_types is None: + content_types = [] + self.__content_types = frozenset([self.__default_content_type] + + content_types) + + @property + def http_methods(self): + return self.__http_methods + + @property + def default_content_type(self): + return self.__default_content_type + + @property + def content_types(self): + return self.__content_types + + def build_request(self, handler, request_type): + """Build request message based on request. + + Each request mapper implementation is responsible for converting a + request to an appropriate message instance. + + Args: + handler: RequestHandler instance that is servicing request. + Must be initialized with request object and been previously determined + to matching the protocol of the RPCMapper. + request_type: Message type to build. + + Returns: + Instance of request_type populated by protocol buffer in request body. + + Raises: + RequestError if the mapper implementation is not able to correctly + convert the request to the appropriate message. + """ + try: + return self.__protocol.decode_message(request_type, handler.request.body) + except (messages.ValidationError, messages.DecodeError) as err: + raise RequestError('Unable to parse request content: %s' % err) + + def build_response(self, handler, response, pad_string=False): + """Build response based on service object response message. + + Each request mapper implementation is responsible for converting a + response message to an appropriate handler response. + + Args: + handler: RequestHandler instance that is servicing request. + Must be initialized with request object and been previously determined + to matching the protocol of the RPCMapper. + response: Response message as returned from the service object. + + Raises: + ResponseError if the mapper implementation is not able to correctly + convert the message to an appropriate response. + """ + try: + encoded_message = self.__protocol.encode_message(response) + except messages.ValidationError as err: + raise ResponseError('Unable to encode message: %s' % err) + else: + handler.response.headers['Content-Type'] = self.default_content_type + handler.response.out.write(encoded_message) + + +class ServiceHandlerFactory(object): + """Factory class used for instantiating new service handlers. + + Normally a handler class is passed directly to the webapp framework + so that it can be simply instantiated to handle a single request. + The service handler, however, must be configured with additional + information so that it knows how to instantiate a service object. + This class acts the same as a normal RequestHandler class by + overriding the __call__ method to correctly configures a ServiceHandler + instance with a new service object. + + The factory must also provide a set of RPCMapper instances which + examine a request to determine what protocol is being used and mediates + between the request and the service object. + + The mapping of a service handler must have a single group indicating the + part of the URL path that maps to the request method. This group must + exist but can be optional for the request (the group may be followed by + '?' in the regular expression matching the request). + + Usage: + + stock_factory = ServiceHandlerFactory(StockService) + ... configure stock_factory by adding RPCMapper instances ... + + application = webapp.WSGIApplication( + [stock_factory.mapping('/stocks')]) + + Default usage: + + application = webapp.WSGIApplication( + [ServiceHandlerFactory.default(StockService).mapping('/stocks')]) + """ + + def __init__(self, service_factory): + """Constructor. + + Args: + service_factory: Service factory to instantiate and provide to + service handler. + """ + self.__service_factory = service_factory + self.__request_mappers = [] + + def all_request_mappers(self): + """Get all request mappers. + + Returns: + Iterator of all request mappers used by this service factory. + """ + return iter(self.__request_mappers) + + def add_request_mapper(self, mapper): + """Add request mapper to end of request mapper list.""" + self.__request_mappers.append(mapper) + + def __call__(self): + """Construct a new service handler instance.""" + return ServiceHandler(self, self.__service_factory()) + + @property + def service_factory(self): + """Service factory associated with this factory.""" + return self.__service_factory + + @staticmethod + def __check_path(path): + """Check a path parameter. + + Make sure a provided path parameter is compatible with the + webapp URL mapping. + + Args: + path: Path to check. This is a plain path, not a regular expression. + + Raises: + ValueError if path does not start with /, path ends with /. + """ + if path.endswith('/'): + raise ValueError('Path %s must not end with /.' % path) + + def mapping(self, path): + """Convenience method to map service to application. + + Args: + path: Path to map service to. It must be a simple path + with a leading / and no trailing /. + + Returns: + Mapping from service URL to service handler factory. + """ + self.__check_path(path) + + service_url_pattern = r'(%s)%s' % (path, _METHOD_PATTERN) + + return service_url_pattern, self + + @classmethod + def default(cls, service_factory, parameter_prefix=''): + """Convenience method to map default factory configuration to application. + + Creates a standardized default service factory configuration that pre-maps + the URL encoded protocol handler to the factory. + + Args: + service_factory: Service factory to instantiate and provide to + service handler. + method_parameter: The name of the form parameter used to determine the + method to invoke used by the URLEncodedRPCMapper. If None, no + parameter is used and the mapper will only match against the form + path-name. Defaults to 'method'. + parameter_prefix: If provided, all the parameters in the form are + expected to begin with that prefix by the URLEncodedRPCMapper. + + Returns: + Mapping from service URL to service handler factory. + """ + factory = cls(service_factory) + + factory.add_request_mapper(ProtobufRPCMapper()) + factory.add_request_mapper(JSONRPCMapper()) + + return factory + + +class ServiceHandler(webapp.RequestHandler): + """Web handler for RPC service. + + Overridden methods: + get: All requests handled by 'handle' method. HTTP method stored in + attribute. Takes remote_method parameter as derived from the URL mapping. + post: All requests handled by 'handle' method. HTTP method stored in + attribute. Takes remote_method parameter as derived from the URL mapping. + redirect: Not implemented for this service handler. + + New methods: + handle: Handle request for both GET and POST. + + Attributes (in addition to attributes in RequestHandler): + service: Service instance associated with request being handled. + method: Method of request. Used by RPCMapper to determine match. + remote_method: Sub-path as provided to the 'get' and 'post' methods. + """ + + def __init__(self, factory, service): + """Constructor. + + Args: + factory: Instance of ServiceFactory used for constructing new service + instances used for handling requests. + service: Service instance used for handling RPC. + """ + self.__factory = factory + self.__service = service + + @property + def service(self): + return self.__service + + def __show_info(self, service_path, remote_method): + self.response.headers['content-type'] = 'text/plain; charset=utf-8' + response_message = [] + if remote_method: + response_message.append('%s.%s is a ProtoRPC method.\n\n' %( + service_path, remote_method)) + else: + response_message.append('%s is a ProtoRPC service.\n\n' % service_path) + definition_name_function = getattr(self.__service, 'definition_name', None) + if definition_name_function: + definition_name = definition_name_function() + else: + definition_name = '%s.%s' % (self.__service.__module__, + self.__service.__class__.__name__) + + response_message.append('Service %s\n\n' % definition_name) + response_message.append('More about ProtoRPC: ') + + response_message.append('http://code.google.com/p/google-protorpc\n') + self.response.out.write(util.pad_string(''.join(response_message))) + + def get(self, service_path, remote_method): + """Handler method for GET requests. + + Args: + service_path: Service path derived from request URL. + remote_method: Sub-path after service path has been matched. + """ + self.handle('GET', service_path, remote_method) + + def post(self, service_path, remote_method): + """Handler method for POST requests. + + Args: + service_path: Service path derived from request URL. + remote_method: Sub-path after service path has been matched. + """ + self.handle('POST', service_path, remote_method) + + def redirect(self, uri, permanent=False): + """Not supported for services.""" + raise NotImplementedError('Services do not currently support redirection.') + + def __send_error(self, + http_code, + status_state, + error_message, + mapper, + error_name=None): + status = remote.RpcStatus(state=status_state, + error_message=error_message, + error_name=error_name) + mapper.build_response(self, status) + self.response.headers['content-type'] = mapper.default_content_type + + logging.error(error_message) + response_content = self.response.out.getvalue() + padding = ' ' * max(0, 512 - len(response_content)) + self.response.out.write(padding) + + self.response.set_status(http_code, error_message) + + def __send_simple_error(self, code, message, pad=True): + """Send error to caller without embedded message.""" + self.response.headers['content-type'] = 'text/plain; charset=utf-8' + logging.error(message) + self.response.set_status(code, message) + + response_message = six.moves.http_client.responses.get(code, 'Unknown Error') + if pad: + response_message = util.pad_string(response_message) + self.response.out.write(response_message) + + def __get_content_type(self): + content_type = self.request.headers.get('content-type', None) + if not content_type: + content_type = self.request.environ.get('HTTP_CONTENT_TYPE', None) + if not content_type: + return None + + # Lop off parameters from the end (for example content-encoding) + return content_type.split(';', 1)[0].lower() + + def __headers(self, content_type): + for name in self.request.headers: + name = name.lower() + if name == 'content-type': + value = content_type + elif name == 'content-length': + value = str(len(self.request.body)) + else: + value = self.request.headers.get(name, '') + yield name, value + + def handle(self, http_method, service_path, remote_method): + """Handle a service request. + + The handle method will handle either a GET or POST response. + It is up to the individual mappers from the handler factory to determine + which request methods they can service. + + If the protocol is not recognized, the request does not provide a correct + request for that protocol or the service object does not support the + requested RPC method, will return error code 400 in the response. + + Args: + http_method: HTTP method of request. + service_path: Service path derived from request URL. + remote_method: Sub-path after service path has been matched. + """ + self.response.headers['x-content-type-options'] = 'nosniff' + if not remote_method and http_method == 'GET': + # Special case a normal get request, presumably via a browser. + self.error(405) + self.__show_info(service_path, remote_method) + return + + content_type = self.__get_content_type() + + # Provide server state to the service. If the service object does not have + # an "initialize_request_state" method, will not attempt to assign state. + try: + state_initializer = self.service.initialize_request_state + except AttributeError: + pass + else: + server_port = self.request.environ.get('SERVER_PORT', None) + if server_port: + server_port = int(server_port) + + request_state = remote.HttpRequestState( + remote_host=self.request.environ.get('REMOTE_HOST', None), + remote_address=self.request.environ.get('REMOTE_ADDR', None), + server_host=self.request.environ.get('SERVER_HOST', None), + server_port=server_port, + http_method=http_method, + service_path=service_path, + headers=list(self.__headers(content_type))) + state_initializer(request_state) + + if not content_type: + self.__send_simple_error(400, 'Invalid RPC request: missing content-type') + return + + # Search for mapper to mediate request. + for mapper in self.__factory.all_request_mappers(): + if content_type in mapper.content_types: + break + else: + if http_method == 'GET': + self.error(six.moves.http_client.UNSUPPORTED_MEDIA_TYPE) + self.__show_info(service_path, remote_method) + else: + self.__send_simple_error(six.moves.http_client.UNSUPPORTED_MEDIA_TYPE, + 'Unsupported content-type: %s' % content_type) + return + + try: + if http_method not in mapper.http_methods: + if http_method == 'GET': + self.error(six.moves.http_client.METHOD_NOT_ALLOWED) + self.__show_info(service_path, remote_method) + else: + self.__send_simple_error(six.moves.http_client.METHOD_NOT_ALLOWED, + 'Unsupported HTTP method: %s' % http_method) + return + + try: + try: + method = getattr(self.service, remote_method) + method_info = method.remote + except AttributeError as err: + self.__send_error( + 400, remote.RpcState.METHOD_NOT_FOUND_ERROR, + 'Unrecognized RPC method: %s' % remote_method, + mapper) + return + + request = mapper.build_request(self, method_info.request_type) + except (RequestError, messages.DecodeError) as err: + self.__send_error(400, + remote.RpcState.REQUEST_ERROR, + 'Error parsing ProtoRPC request (%s)' % err, + mapper) + return + + try: + response = method(request) + except remote.ApplicationError as err: + self.__send_error(400, + remote.RpcState.APPLICATION_ERROR, + unicode(err), + mapper, + err.error_name) + return + + mapper.build_response(self, response) + except Exception as err: + logging.error('An unexpected error occured when handling RPC: %s', + err, exc_info=1) + + self.__send_error(500, + remote.RpcState.SERVER_ERROR, + 'Internal Server Error', + mapper) + return + + +# TODO(rafek): Support tag-id only forms. +class URLEncodedRPCMapper(RPCMapper): + """Request mapper for application/x-www-form-urlencoded forms. + + This mapper is useful for building forms that can invoke RPC. Many services + are also configured to work using URL encoded request information because + of its perceived ease of programming and debugging. + + The mapper must be provided with at least method_parameter or + remote_method_pattern so that it is possible to determine how to determine the + requests RPC method. If both are provided, the service will respond to both + method request types, however, only one may be present in a given request. + If both types are detected, the request will not match. + """ + + def __init__(self, parameter_prefix=''): + """Constructor. + + Args: + parameter_prefix: If provided, all the parameters in the form are + expected to begin with that prefix. + """ + # Private attributes: + # __parameter_prefix: parameter prefix as provided by constructor + # parameter. + super(URLEncodedRPCMapper, self).__init__(['POST'], + _URLENCODED_CONTENT_TYPE, + self) + self.__parameter_prefix = parameter_prefix + + def encode_message(self, message): + """Encode a message using parameter prefix. + + Args: + message: Message to URL Encode. + + Returns: + URL encoded message. + """ + return protourlencode.encode_message(message, + prefix=self.__parameter_prefix) + + @property + def parameter_prefix(self): + """Prefix all form parameters are expected to begin with.""" + return self.__parameter_prefix + + def build_request(self, handler, request_type): + """Build request from URL encoded HTTP request. + + Constructs message from names of URL encoded parameters. If this service + handler has a parameter prefix, parameters must begin with it or are + ignored. + + Args: + handler: RequestHandler instance that is servicing request. + request_type: Message type to build. + + Returns: + Instance of request_type populated by protocol buffer in request + parameters. + + Raises: + RequestError if message type contains nested message field or repeated + message field. Will raise RequestError if there are any repeated + parameters. + """ + request = request_type() + builder = protourlencode.URLEncodedRequestBuilder( + request, prefix=self.__parameter_prefix) + for argument in sorted(handler.request.arguments()): + values = handler.request.get_all(argument) + try: + builder.add_parameter(argument, values) + except messages.DecodeError as err: + raise RequestError(str(err)) + return request + + +class ProtobufRPCMapper(RPCMapper): + """Request mapper for application/x-protobuf service requests. + + This mapper will parse protocol buffer from a POST body and return the request + as a protocol buffer. + """ + + def __init__(self): + super(ProtobufRPCMapper, self).__init__(['POST'], + _PROTOBUF_CONTENT_TYPE, + protobuf) + + +class JSONRPCMapper(RPCMapper): + """Request mapper for application/x-protobuf service requests. + + This mapper will parse protocol buffer from a POST body and return the request + as a protocol buffer. + """ + + def __init__(self): + super(JSONRPCMapper, self).__init__( + ['POST'], + _JSON_CONTENT_TYPE, + protojson, + content_types=_EXTRA_JSON_CONTENT_TYPES) + + +def service_mapping(services, + registry_path=DEFAULT_REGISTRY_PATH): + """Create a services mapping for use with webapp. + + Creates basic default configuration and registration for ProtoRPC services. + Each service listed in the service mapping has a standard service handler + factory created for it. + + The list of mappings can either be an explicit path to service mapping or + just services. If mappings are just services, they will automatically + be mapped to their default name. For exampel: + + package = 'my_package' + + class MyService(remote.Service): + ... + + server_mapping([('/my_path', MyService), # Maps to /my_path + MyService, # Maps to /my_package/MyService + ]) + + Specifying a service mapping: + + Normally services are mapped to URL paths by specifying a tuple + (path, service): + path: The path the service resides on. + service: The service class or service factory for creating new instances + of the service. For more information about service factories, please + see remote.Service.new_factory. + + If no tuple is provided, and therefore no path specified, a default path + is calculated by using the fully qualified service name using a URL path + separator for each of its components instead of a '.'. + + Args: + services: Can be service type, service factory or string definition name of + service being mapped or list of tuples (path, service): + path: Path on server to map service to. + service: Service type, service factory or string definition name of + service being mapped. + Can also be a dict. If so, the keys are treated as the path and values as + the service. + registry_path: Path to give to registry service. Use None to disable + registry service. + + Returns: + List of tuples defining a mapping of request handlers compatible with a + webapp application. + + Raises: + ServiceConfigurationError when duplicate paths are provided. + """ + if isinstance(services, dict): + services = six.iteritems(services) + mapping = [] + registry_map = {} + + if registry_path is not None: + registry_service = registry.RegistryService.new_factory(registry_map) + services = list(services) + [(registry_path, registry_service)] + mapping.append((registry_path + r'/form(?:/)?', + forms.FormsHandler.new_factory(registry_path))) + mapping.append((registry_path + r'/form/(.+)', forms.ResourceHandler)) + + paths = set() + for service_item in services: + infer_path = not isinstance(service_item, (list, tuple)) + if infer_path: + service = service_item + else: + service = service_item[1] + + service_class = getattr(service, 'service_class', service) + + if infer_path: + path = '/' + service_class.definition_name().replace('.', '/') + else: + path = service_item[0] + + if path in paths: + raise ServiceConfigurationError( + 'Path %r is already defined in service mapping' % path.encode('utf-8')) + else: + paths.add(path) + + # Create service mapping for webapp. + new_mapping = ServiceHandlerFactory.default(service).mapping(path) + mapping.append(new_mapping) + + # Update registry with service class. + registry_map[path] = service_class + + return mapping + + +def run_services(services, + registry_path=DEFAULT_REGISTRY_PATH): + """Handle CGI request using service mapping. + + Args: + Same as service_mapping. + """ + mappings = service_mapping(services, registry_path=registry_path) + application = webapp.WSGIApplication(mappings) + webapp_util.run_wsgi_app(application) diff --git a/endpoints/internal/protorpc/webapp/service_handlers_test.py b/endpoints/internal/protorpc/webapp/service_handlers_test.py new file mode 100644 index 0000000..baebbda --- /dev/null +++ b/endpoints/internal/protorpc/webapp/service_handlers_test.py @@ -0,0 +1,1332 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for protorpc.service_handlers.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import cgi +import cStringIO +import os +import re +import sys +import unittest +import urllib + +from protorpc import messages +from protorpc import protobuf +from protorpc import protojson +from protorpc import protourlencode +from protorpc import message_types +from protorpc import registry +from protorpc import remote +from protorpc import test_util +from protorpc import util +from protorpc import webapp_test_util +from protorpc.webapp import forms +from protorpc.webapp import service_handlers +from protorpc.webapp.google_imports import webapp + +import mox + +package = 'test_package' + + +class ModuleInterfaceTest(test_util.ModuleInterfaceTest, + test_util.TestCase): + + MODULE = service_handlers + + +class Enum1(messages.Enum): + """A test enum class.""" + + VAL1 = 1 + VAL2 = 2 + VAL3 = 3 + + +class Request1(messages.Message): + """A test request message type.""" + + integer_field = messages.IntegerField(1) + string_field = messages.StringField(2) + enum_field = messages.EnumField(Enum1, 3) + + +class Response1(messages.Message): + """A test response message type.""" + + integer_field = messages.IntegerField(1) + string_field = messages.StringField(2) + enum_field = messages.EnumField(Enum1, 3) + + +class SuperMessage(messages.Message): + """A test message with a nested message field.""" + + sub_message = messages.MessageField(Request1, 1) + sub_messages = messages.MessageField(Request1, 2, repeated=True) + + +class SuperSuperMessage(messages.Message): + """A test message with two levels of nested.""" + + sub_message = messages.MessageField(SuperMessage, 1) + sub_messages = messages.MessageField(Request1, 2, repeated=True) + + +class RepeatedMessage(messages.Message): + """A test message with a repeated field.""" + + ints = messages.IntegerField(1, repeated=True) + strings = messages.StringField(2, repeated=True) + enums = messages.EnumField(Enum1, 3, repeated=True) + + +class Service(object): + """A simple service that takes a Request1 and returns Request2.""" + + @remote.method(Request1, Response1) + def method1(self, request): + response = Response1() + if hasattr(request, 'integer_field'): + response.integer_field = request.integer_field + if hasattr(request, 'string_field'): + response.string_field = request.string_field + if hasattr(request, 'enum_field'): + response.enum_field = request.enum_field + return response + + @remote.method(RepeatedMessage, RepeatedMessage) + def repeated_method(self, request): + response = RepeatedMessage() + if hasattr(request, 'ints'): + response = request.ints + return response + + def not_remote(self): + pass + + +def VerifyResponse(test, + response, + expected_status, + expected_status_message, + expected_content, + expected_content_type='application/x-www-form-urlencoded'): + def write(content): + if expected_content == '': + test.assertEquals(util.pad_string(''), content) + else: + test.assertNotEquals(-1, content.find(expected_content), + 'Expected to find:\n%s\n\nActual content: \n%s' % ( + expected_content, content)) + + def start_response(response, headers): + status, message = response.split(' ', 1) + test.assertEquals(expected_status, status) + test.assertEquals(expected_status_message, message) + for name, value in headers: + if name.lower() == 'content-type': + test.assertEquals(expected_content_type, value) + for name, value in headers: + if name.lower() == 'x-content-type-options': + test.assertEquals('nosniff', value) + elif name.lower() == 'content-type': + test.assertFalse(value.lower().startswith('text/html')) + return write + + response.wsgi_write(start_response) + + +class ServiceHandlerFactoryTest(test_util.TestCase): + """Tests for the service handler factory.""" + + def testAllRequestMappers(self): + """Test all_request_mappers method.""" + configuration = service_handlers.ServiceHandlerFactory(Service) + mapper1 = service_handlers.RPCMapper(['whatever'], 'whatever', None) + mapper2 = service_handlers.RPCMapper(['whatever'], 'whatever', None) + + configuration.add_request_mapper(mapper1) + self.assertEquals([mapper1], list(configuration.all_request_mappers())) + + configuration.add_request_mapper(mapper2) + self.assertEquals([mapper1, mapper2], + list(configuration.all_request_mappers())) + + def testServiceFactory(self): + """Test that service_factory attribute is set.""" + handler_factory = service_handlers.ServiceHandlerFactory(Service) + self.assertEquals(Service, handler_factory.service_factory) + + def testFactoryMethod(self): + """Test that factory creates correct instance of class.""" + factory = service_handlers.ServiceHandlerFactory(Service) + handler = factory() + + self.assertTrue(isinstance(handler, service_handlers.ServiceHandler)) + self.assertTrue(isinstance(handler.service, Service)) + + def testMapping(self): + """Test the mapping method.""" + factory = service_handlers.ServiceHandlerFactory(Service) + path, mapped_factory = factory.mapping('/my_service') + + self.assertEquals(r'(/my_service)' + service_handlers._METHOD_PATTERN, path) + self.assertEquals(id(factory), id(mapped_factory)) + match = re.match(path, '/my_service.my_method') + self.assertEquals('/my_service', match.group(1)) + self.assertEquals('my_method', match.group(2)) + + path, mapped_factory = factory.mapping('/my_service/nested') + self.assertEquals('(/my_service/nested)' + + service_handlers._METHOD_PATTERN, path) + match = re.match(path, '/my_service/nested.my_method') + self.assertEquals('/my_service/nested', match.group(1)) + self.assertEquals('my_method', match.group(2)) + + def testRegexMapping(self): + """Test the mapping method using a regex.""" + factory = service_handlers.ServiceHandlerFactory(Service) + path, mapped_factory = factory.mapping('.*/my_service') + + self.assertEquals(r'(.*/my_service)' + service_handlers._METHOD_PATTERN, path) + self.assertEquals(id(factory), id(mapped_factory)) + match = re.match(path, '/whatever_preceeds/my_service.my_method') + self.assertEquals('/whatever_preceeds/my_service', match.group(1)) + self.assertEquals('my_method', match.group(2)) + match = re.match(path, '/something_else/my_service.my_other_method') + self.assertEquals('/something_else/my_service', match.group(1)) + self.assertEquals('my_other_method', match.group(2)) + + def testMapping_BadPath(self): + """Test bad parameterse to the mapping method.""" + factory = service_handlers.ServiceHandlerFactory(Service) + self.assertRaises(ValueError, factory.mapping, '/my_service/') + + def testDefault(self): + """Test the default factory convenience method.""" + handler_factory = service_handlers.ServiceHandlerFactory.default( + Service, + parameter_prefix='my_prefix.') + + self.assertEquals(Service, handler_factory.service_factory) + + mappers = handler_factory.all_request_mappers() + + # Verify Protobuf encoded mapper. + protobuf_mapper = next(mappers) + self.assertTrue(isinstance(protobuf_mapper, + service_handlers.ProtobufRPCMapper)) + + # Verify JSON encoded mapper. + json_mapper = next(mappers) + self.assertTrue(isinstance(json_mapper, + service_handlers.JSONRPCMapper)) + + # Should have no more mappers. + self.assertRaises(StopIteration, mappers.next) + + +class ServiceHandlerTest(webapp_test_util.RequestHandlerTestBase): + """Test the ServiceHandler class.""" + + def setUp(self): + self.mox = mox.Mox() + self.service_factory = Service + self.remote_host = 'remote.host.com' + self.server_host = 'server.host.com' + self.ResetRequestHandler() + + self.request = Request1() + self.request.integer_field = 1 + self.request.string_field = 'a' + self.request.enum_field = Enum1.VAL1 + + def ResetRequestHandler(self): + super(ServiceHandlerTest, self).setUp() + + def CreateService(self): + return self.service_factory() + + def CreateRequestHandler(self): + self.rpc_mapper1 = self.mox.CreateMock(service_handlers.RPCMapper) + self.rpc_mapper1.http_methods = set(['POST']) + self.rpc_mapper1.content_types = set(['application/x-www-form-urlencoded']) + self.rpc_mapper1.default_content_type = 'application/x-www-form-urlencoded' + self.rpc_mapper2 = self.mox.CreateMock(service_handlers.RPCMapper) + self.rpc_mapper2.http_methods = set(['GET']) + self.rpc_mapper2.content_types = set(['application/json']) + self.rpc_mapper2.default_content_type = 'application/json' + self.factory = service_handlers.ServiceHandlerFactory( + self.CreateService) + self.factory.add_request_mapper(self.rpc_mapper1) + self.factory.add_request_mapper(self.rpc_mapper2) + return self.factory() + + def GetEnvironment(self): + """Create handler to test.""" + environ = super(ServiceHandlerTest, self).GetEnvironment() + if self.remote_host: + environ['REMOTE_HOST'] = self.remote_host + if self.server_host: + environ['SERVER_HOST'] = self.server_host + return environ + + def VerifyResponse(self, *args, **kwargs): + VerifyResponse(self, + self.response, + *args, **kwargs) + + def ExpectRpcError(self, mapper, state, error_message, error_name=None): + mapper.build_response(self.handler, + remote.RpcStatus(state=state, + error_message=error_message, + error_name=error_name)) + + def testRedirect(self): + """Test that redirection is disabled.""" + self.assertRaises(NotImplementedError, self.handler.redirect, '/') + + def testFirstMapper(self): + """Make sure service attribute works when matches first RPCMapper.""" + self.rpc_mapper1.build_request( + self.handler, Request1).AndReturn(self.request) + + def build_response(handler, response): + output = '%s %s %s' % (response.integer_field, + response.string_field, + response.enum_field) + handler.response.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + handler.response.out.write(output) + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', '1 a VAL1') + + self.mox.VerifyAll() + + def testSecondMapper(self): + """Make sure service attribute works when matches first RPCMapper. + + Demonstrates the multiplicity of the RPCMapper configuration. + """ + self.rpc_mapper2.build_request( + self.handler, Request1).AndReturn(self.request) + + def build_response(handler, response): + output = '%s %s %s' % (response.integer_field, + response.string_field, + response.enum_field) + handler.response.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + handler.response.out.write(output) + self.rpc_mapper2.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + self.mox.ReplayAll() + + self.handler.request.headers['Content-Type'] = 'application/json' + self.handler.handle('GET', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', '1 a VAL1') + + self.mox.VerifyAll() + + def testCaseInsensitiveContentType(self): + """Ensure that matching content-type is case insensitive.""" + request = Request1() + request.integer_field = 1 + request.string_field = 'a' + request.enum_field = Enum1.VAL1 + self.rpc_mapper1.build_request(self.handler, + Request1).AndReturn(self.request) + + def build_response(handler, response): + output = '%s %s %s' % (response.integer_field, + response.string_field, + response.enum_field) + handler.response.out.write(output) + handler.response.headers['content-type'] = 'text/plain' + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + self.mox.ReplayAll() + + self.handler.request.headers['Content-Type'] = ('ApPlIcAtIoN/' + 'X-wWw-FoRm-UrLeNcOdEd') + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', '1 a VAL1', 'text/plain') + + self.mox.VerifyAll() + + def testContentTypeWithParameters(self): + """Test that content types have parameters parsed out.""" + request = Request1() + request.integer_field = 1 + request.string_field = 'a' + request.enum_field = Enum1.VAL1 + self.rpc_mapper1.build_request(self.handler, + Request1).AndReturn(self.request) + + def build_response(handler, response): + output = '%s %s %s' % (response.integer_field, + response.string_field, + response.enum_field) + handler.response.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + handler.response.out.write(output) + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + self.mox.ReplayAll() + + self.handler.request.headers['Content-Type'] = ('application/' + 'x-www-form-urlencoded' + + '; a=b; c=d') + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', '1 a VAL1') + + self.mox.VerifyAll() + + def testContentFromHeaderOnly(self): + """Test getting content-type from HTTP_CONTENT_TYPE directly. + + Some bad web server implementations might decide not to set CONTENT_TYPE for + POST requests where there is an empty body. In these cases, need to get + content-type directly from webob environ key HTTP_CONTENT_TYPE. + """ + request = Request1() + request.integer_field = 1 + request.string_field = 'a' + request.enum_field = Enum1.VAL1 + self.rpc_mapper1.build_request(self.handler, + Request1).AndReturn(self.request) + + def build_response(handler, response): + output = '%s %s %s' % (response.integer_field, + response.string_field, + response.enum_field) + handler.response.headers['Content-Type'] = ( + 'application/x-www-form-urlencoded') + handler.response.out.write(output) + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + self.mox.ReplayAll() + + self.handler.request.headers['Content-Type'] = None + self.handler.request.environ['HTTP_CONTENT_TYPE'] = ( + 'application/x-www-form-urlencoded') + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', '1 a VAL1', + 'application/x-www-form-urlencoded') + + self.mox.VerifyAll() + + def testRequestState(self): + """Make sure request state is passed in to handler that supports it.""" + class ServiceWithState(object): + + initialize_request_state = self.mox.CreateMockAnything() + + @remote.method(Request1, Response1) + def method1(self, request): + return Response1() + + self.service_factory = ServiceWithState + + # Reset handler with new service type. + self.ResetRequestHandler() + + self.rpc_mapper1.build_request( + self.handler, Request1).AndReturn(Request1()) + + def build_response(handler, response): + handler.response.headers['Content-Type'] = ( + 'application/x-www-form-urlencoded') + handler.response.out.write('whatever') + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + def verify_state(state): + return ( + 'remote.host.com' == state.remote_host and + '127.0.0.1' == state.remote_address and + 'server.host.com' == state.server_host and + 8080 == state.server_port and + 'POST' == state.http_method and + '/my_service' == state.service_path and + 'application/x-www-form-urlencoded' == state.headers['content-type'] and + 'dev_appserver_login="test:test@example.com:True"' == + state.headers['cookie']) + ServiceWithState.initialize_request_state(mox.Func(verify_state)) + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', 'whatever') + + self.mox.VerifyAll() + + def testRequestState_MissingHosts(self): + """Make sure missing state environment values are handled gracefully.""" + class ServiceWithState(object): + + initialize_request_state = self.mox.CreateMockAnything() + + @remote.method(Request1, Response1) + def method1(self, request): + return Response1() + + self.service_factory = ServiceWithState + self.remote_host = None + self.server_host = None + + # Reset handler with new service type. + self.ResetRequestHandler() + + self.rpc_mapper1.build_request( + self.handler, Request1).AndReturn(Request1()) + + def build_response(handler, response): + handler.response.headers['Content-Type'] = ( + 'application/x-www-form-urlencoded') + handler.response.out.write('whatever') + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).WithSideEffects(build_response) + + def verify_state(state): + return (None is state.remote_host and + '127.0.0.1' == state.remote_address and + None is state.server_host and + 8080 == state.server_port) + ServiceWithState.initialize_request_state(mox.Func(verify_state)) + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('200', 'OK', 'whatever') + + self.mox.VerifyAll() + + def testNoMatch_UnknownHTTPMethod(self): + """Test what happens when no RPCMapper matches.""" + self.mox.ReplayAll() + + self.handler.handle('UNKNOWN', '/my_service', 'does_not_matter') + + self.VerifyResponse('405', + 'Unsupported HTTP method: UNKNOWN', + 'Method Not Allowed', + 'text/plain; charset=utf-8') + + self.mox.VerifyAll() + + def testNoMatch_GetNotSupported(self): + """Test what happens when GET is not supported.""" + self.mox.ReplayAll() + + self.handler.handle('GET', '/my_service', 'method1') + + self.VerifyResponse('405', + 'Method Not Allowed', + '/my_service.method1 is a ProtoRPC method.\n\n' + 'Service %s.Service\n\n' + 'More about ProtoRPC: ' + 'http://code.google.com/p/google-protorpc' % + (__name__,), + 'text/plain; charset=utf-8') + + self.mox.VerifyAll() + + def testNoMatch_UnknownContentType(self): + """Test what happens when no RPCMapper matches.""" + self.mox.ReplayAll() + + self.handler.request.headers['Content-Type'] = 'image/png' + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('415', + 'Unsupported content-type: image/png', + 'Unsupported Media Type', + 'text/plain; charset=utf-8') + + self.mox.VerifyAll() + + def testNoMatch_NoContentType(self): + """Test what happens when no RPCMapper matches..""" + self.mox.ReplayAll() + + self.handler.request.environ.pop('HTTP_CONTENT_TYPE', None) + self.handler.request.headers.pop('Content-Type', None) + self.handler.handle('/my_service', 'POST', 'method1') + + self.VerifyResponse('400', 'Invalid RPC request: missing content-type', + 'Bad Request', + 'text/plain; charset=utf-8') + + self.mox.VerifyAll() + + def testNoSuchMethod(self): + """When service method not found.""" + self.ExpectRpcError(self.rpc_mapper1, + remote.RpcState.METHOD_NOT_FOUND_ERROR, + 'Unrecognized RPC method: no_such_method') + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'no_such_method') + + self.VerifyResponse('400', 'Unrecognized RPC method: no_such_method', '') + + self.mox.VerifyAll() + + def testNoSuchRemoteMethod(self): + """When service method exists but is not remote.""" + self.ExpectRpcError(self.rpc_mapper1, + remote.RpcState.METHOD_NOT_FOUND_ERROR, + 'Unrecognized RPC method: not_remote') + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'not_remote') + + self.VerifyResponse('400', 'Unrecognized RPC method: not_remote', '') + + self.mox.VerifyAll() + + def testRequestError(self): + """RequestError handling.""" + def build_request(handler, request): + raise service_handlers.RequestError('This is a request error') + self.rpc_mapper1.build_request( + self.handler, Request1).WithSideEffects(build_request) + + self.ExpectRpcError(self.rpc_mapper1, + remote.RpcState.REQUEST_ERROR, + 'Error parsing ProtoRPC request ' + '(This is a request error)') + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('400', + 'Error parsing ProtoRPC request ' + '(This is a request error)', + '') + + + self.mox.VerifyAll() + + def testDecodeError(self): + """DecodeError handling.""" + def build_request(handler, request): + raise messages.DecodeError('This is a decode error') + self.rpc_mapper1.build_request( + self.handler, Request1).WithSideEffects(build_request) + + self.ExpectRpcError(self.rpc_mapper1, + remote.RpcState.REQUEST_ERROR, + r'Error parsing ProtoRPC request ' + r'(This is a decode error)') + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('400', + 'Error parsing ProtoRPC request ' + '(This is a decode error)', + '') + + self.mox.VerifyAll() + + def testResponseException(self): + """Test what happens when build_response raises ResponseError.""" + self.rpc_mapper1.build_request( + self.handler, Request1).AndReturn(self.request) + + self.rpc_mapper1.build_response( + self.handler, mox.IsA(Response1)).AndRaise( + service_handlers.ResponseError) + + self.ExpectRpcError(self.rpc_mapper1, + remote.RpcState.SERVER_ERROR, + 'Internal Server Error') + + self.mox.ReplayAll() + + self.handler.handle('POST', '/my_service', 'method1') + + self.VerifyResponse('500', 'Internal Server Error', '') + + self.mox.VerifyAll() + + def testGet(self): + """Test that GET goes to 'handle' properly.""" + self.handler.handle = self.mox.CreateMockAnything() + self.handler.handle('GET', '/my_service', 'method1') + self.handler.handle('GET', '/my_other_service', 'method2') + + self.mox.ReplayAll() + + self.handler.get('/my_service', 'method1') + self.handler.get('/my_other_service', 'method2') + + self.mox.VerifyAll() + + def testPost(self): + """Test that POST goes to 'handle' properly.""" + self.handler.handle = self.mox.CreateMockAnything() + self.handler.handle('POST', '/my_service', 'method1') + self.handler.handle('POST', '/my_other_service', 'method2') + + self.mox.ReplayAll() + + self.handler.post('/my_service', 'method1') + self.handler.post('/my_other_service', 'method2') + + self.mox.VerifyAll() + + def testGetNoMethod(self): + self.handler.get('/my_service', '') + self.assertEquals(405, self.handler.response.status) + self.assertEquals( + util.pad_string('/my_service is a ProtoRPC service.\n\n' + 'Service %s.Service\n\n' + 'More about ProtoRPC: ' + 'http://code.google.com/p/google-protorpc\n' % + __name__), + self.handler.response.out.getvalue()) + self.assertEquals( + 'nosniff', + self.handler.response.headers['x-content-type-options']) + + def testGetNotSupported(self): + self.handler.get('/my_service', 'method1') + self.assertEquals(405, self.handler.response.status) + expected_message = ('/my_service.method1 is a ProtoRPC method.\n\n' + 'Service %s.Service\n\n' + 'More about ProtoRPC: ' + 'http://code.google.com/p/google-protorpc\n' % + __name__) + self.assertEquals(util.pad_string(expected_message), + self.handler.response.out.getvalue()) + self.assertEquals( + 'nosniff', + self.handler.response.headers['x-content-type-options']) + + def testGetUnknownContentType(self): + self.handler.request.headers['content-type'] = 'image/png' + self.handler.get('/my_service', 'method1') + self.assertEquals(415, self.handler.response.status) + self.assertEquals( + util.pad_string('/my_service.method1 is a ProtoRPC method.\n\n' + 'Service %s.Service\n\n' + 'More about ProtoRPC: ' + 'http://code.google.com/p/google-protorpc\n' % + __name__), + self.handler.response.out.getvalue()) + self.assertEquals( + 'nosniff', + self.handler.response.headers['x-content-type-options']) + + +class MissingContentLengthTests(ServiceHandlerTest): + """Test for when content-length is not set in the environment. + + This test moves CONTENT_LENGTH from the environment to the + content-length header. + """ + + def GetEnvironment(self): + environment = super(MissingContentLengthTests, self).GetEnvironment() + content_length = str(environment.pop('CONTENT_LENGTH', '0')) + environment['HTTP_CONTENT_LENGTH'] = content_length + return environment + + +class MissingContentTypeTests(ServiceHandlerTest): + """Test for when content-type is not set in the environment. + + This test moves CONTENT_TYPE from the environment to the + content-type header. + """ + + def GetEnvironment(self): + environment = super(MissingContentTypeTests, self).GetEnvironment() + content_type = str(environment.pop('CONTENT_TYPE', '')) + environment['HTTP_CONTENT_TYPE'] = content_type + return environment + + +class RPCMapperTestBase(test_util.TestCase): + + def setUp(self): + """Set up test framework.""" + self.Reinitialize() + + def Reinitialize(self, input='', + get=False, + path_method='method1', + content_type='text/plain'): + """Allows reinitialization of test with custom input values and POST. + + Args: + input: Query string or POST input. + get: Use GET method if True. Use POST if False. + """ + self.factory = service_handlers.ServiceHandlerFactory(Service) + + self.service_handler = service_handlers.ServiceHandler(self.factory, + Service()) + self.service_handler.remote_method = path_method + request_path = '/servicepath' + if path_method: + request_path += '/' + path_method + if get: + request_path += '?' + input + + if get: + environ = {'wsgi.input': cStringIO.StringIO(''), + 'CONTENT_LENGTH': '0', + 'QUERY_STRING': input, + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': request_path, + } + self.service_handler.method = 'GET' + else: + environ = {'wsgi.input': cStringIO.StringIO(input), + 'CONTENT_LENGTH': str(len(input)), + 'QUERY_STRING': '', + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': request_path, + } + self.service_handler.method = 'POST' + + self.request = webapp.Request(environ) + + self.response = webapp.Response() + + self.service_handler.initialize(self.request, self.response) + + self.service_handler.request.headers['Content-Type'] = content_type + + +class RPCMapperTest(RPCMapperTestBase, webapp_test_util.RequestHandlerTestBase): + """Test the RPCMapper base class.""" + + def setUp(self): + RPCMapperTestBase.setUp(self) + webapp_test_util.RequestHandlerTestBase.setUp(self) + self.mox = mox.Mox() + self.protocol = self.mox.CreateMockAnything() + + def GetEnvironment(self): + """Get environment. + + Return bogus content in body. + + Returns: + dict of CGI environment. + """ + environment = super(RPCMapperTest, self).GetEnvironment() + environment['wsgi.input'] = cStringIO.StringIO('my body') + environment['CONTENT_LENGTH'] = len('my body') + return environment + + def testContentTypes_JustDefault(self): + """Test content type attributes.""" + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['GET', 'POST'], + 'my-content-type', + self.protocol) + + self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods) + self.assertEquals('my-content-type', mapper.default_content_type) + self.assertEquals(frozenset(['my-content-type']), + mapper.content_types) + + self.mox.VerifyAll() + + def testContentTypes_Extended(self): + """Test content type attributes.""" + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['GET', 'POST'], + 'my-content-type', + self.protocol, + content_types=['a', 'b']) + + self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods) + self.assertEquals('my-content-type', mapper.default_content_type) + self.assertEquals(frozenset(['my-content-type', 'a', 'b']), + mapper.content_types) + + self.mox.VerifyAll() + + def testBuildRequest(self): + """Test building a request.""" + expected_request = Request1() + self.protocol.decode_message(Request1, + 'my body').AndReturn(expected_request) + + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['POST'], + 'my-content-type', + self.protocol) + + request = mapper.build_request(self.handler, Request1) + + self.assertTrue(expected_request is request) + + def testBuildRequest_ValidationError(self): + """Test building a request generating a validation error.""" + expected_request = Request1() + self.protocol.decode_message( + Request1, 'my body').AndRaise(messages.ValidationError('xyz')) + + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['POST'], + 'my-content-type', + self.protocol) + + self.assertRaisesWithRegexpMatch( + service_handlers.RequestError, + 'Unable to parse request content: xyz', + mapper.build_request, + self.handler, + Request1) + + def testBuildRequest_DecodeError(self): + """Test building a request generating a decode error.""" + expected_request = Request1() + self.protocol.decode_message( + Request1, 'my body').AndRaise(messages.DecodeError('xyz')) + + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['POST'], + 'my-content-type', + self.protocol) + + self.assertRaisesWithRegexpMatch( + service_handlers.RequestError, + 'Unable to parse request content: xyz', + mapper.build_request, + self.handler, + Request1) + + def testBuildResponse(self): + """Test building a response.""" + response = Response1() + self.protocol.encode_message(response).AndReturn('encoded') + + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['POST'], + 'my-content-type', + self.protocol) + + request = mapper.build_response(self.handler, response) + + self.assertEquals('my-content-type', + self.handler.response.headers['Content-Type']) + self.assertEquals('encoded', self.handler.response.out.getvalue()) + + def testBuildResponse(self): + """Test building a response.""" + response = Response1() + self.protocol.encode_message(response).AndRaise( + messages.ValidationError('xyz')) + + self.mox.ReplayAll() + + mapper = service_handlers.RPCMapper(['POST'], + 'my-content-type', + self.protocol) + + self.assertRaisesWithRegexpMatch(service_handlers.ResponseError, + 'Unable to encode message: xyz', + mapper.build_response, + self.handler, + response) + + +class ProtocolMapperTestBase(object): + """Base class for basic protocol mapper tests.""" + + def setUp(self): + """Reinitialize test specifically for protocol buffer mapper.""" + super(ProtocolMapperTestBase, self).setUp() + self.Reinitialize(path_method='my_method', + content_type='application/x-google-protobuf') + + self.request_message = Request1() + self.request_message.integer_field = 1 + self.request_message.string_field = u'something' + self.request_message.enum_field = Enum1.VAL1 + + self.response_message = Response1() + self.response_message.integer_field = 1 + self.response_message.string_field = u'something' + self.response_message.enum_field = Enum1.VAL1 + + def testBuildRequest(self): + """Test request building.""" + self.Reinitialize(self.protocol.encode_message(self.request_message), + content_type=self.content_type) + + mapper = self.mapper() + parsed_request = mapper.build_request(self.service_handler, + Request1) + self.assertEquals(self.request_message, parsed_request) + + def testBuildResponse(self): + """Test response building.""" + + mapper = self.mapper() + mapper.build_response(self.service_handler, self.response_message) + self.assertEquals(self.protocol.encode_message(self.response_message), + self.service_handler.response.out.getvalue()) + + def testWholeRequest(self): + """Test the basic flow of a request with mapper class.""" + body = self.protocol.encode_message(self.request_message) + self.Reinitialize(input=body, + content_type=self.content_type) + self.factory.add_request_mapper(self.mapper()) + self.service_handler.handle('POST', '/my_service', 'method1') + VerifyResponse(self, + self.service_handler.response, + '200', + 'OK', + self.protocol.encode_message(self.response_message), + self.content_type) + + +class URLEncodedRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): + """Test the URL encoded RPC mapper.""" + + content_type = 'application/x-www-form-urlencoded' + protocol = protourlencode + mapper = service_handlers.URLEncodedRPCMapper + + def testBuildRequest_Prefix(self): + """Test building request with parameter prefix.""" + self.Reinitialize(urllib.urlencode([('prefix_integer_field', '10'), + ('prefix_string_field', 'a string'), + ('prefix_enum_field', 'VAL1'), + ]), + self.content_type) + + url_encoded_mapper = service_handlers.URLEncodedRPCMapper( + parameter_prefix='prefix_') + request = url_encoded_mapper.build_request(self.service_handler, + Request1) + self.assertEquals(10, request.integer_field) + self.assertEquals('a string', request.string_field) + self.assertEquals(Enum1.VAL1, request.enum_field) + + def testBuildRequest_DecodeError(self): + """Test trying to build request that causes a decode error.""" + self.Reinitialize(urllib.urlencode((('integer_field', '10'), + ('integer_field', '20'), + )), + content_type=self.content_type) + + url_encoded_mapper = service_handlers.URLEncodedRPCMapper() + + self.assertRaises(service_handlers.RequestError, + url_encoded_mapper.build_request, + self.service_handler, + Service.method1.remote.request_type) + + def testBuildResponse_Prefix(self): + """Test building a response with parameter prefix.""" + response = Response1() + response.integer_field = 10 + response.string_field = u'a string' + response.enum_field = Enum1.VAL3 + + url_encoded_mapper = service_handlers.URLEncodedRPCMapper( + parameter_prefix='prefix_') + + url_encoded_mapper.build_response(self.service_handler, response) + self.assertEquals('application/x-www-form-urlencoded', + self.response.headers['content-type']) + self.assertEquals(cgi.parse_qs(self.response.out.getvalue(), True, True), + {'prefix_integer_field': ['10'], + 'prefix_string_field': [u'a string'], + 'prefix_enum_field': ['VAL3'], + }) + + +class ProtobufRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): + """Test the protobuf encoded RPC mapper.""" + + content_type = 'application/octet-stream' + protocol = protobuf + mapper = service_handlers.ProtobufRPCMapper + + +class JSONRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): + """Test the URL encoded RPC mapper.""" + + content_type = 'application/json' + protocol = protojson + mapper = service_handlers.JSONRPCMapper + + +class MyService(remote.Service): + + def __init__(self, value='default'): + self.value = value + + +class ServiceMappingTest(test_util.TestCase): + + def CheckFormMappings(self, mapping, registry_path='/protorpc'): + """Check to make sure that form mapping is configured as expected. + + Args: + mapping: Mapping that should contain forms handlers. + """ + pattern, factory = mapping[0] + self.assertEquals('%s/form(?:/)?' % registry_path, pattern) + handler = factory() + self.assertTrue(isinstance(handler, forms.FormsHandler)) + self.assertEquals(registry_path, handler.registry_path) + + pattern, factory = mapping[1] + self.assertEquals('%s/form/(.+)' % registry_path, pattern) + self.assertEquals(forms.ResourceHandler, factory) + + + def DoMappingTest(self, + services, + registry_path='/myreg', + expected_paths=None): + mapped_services = mapping = service_handlers.service_mapping(services, + registry_path) + if registry_path: + form_mapping = mapping[:2] + mapped_registry_path, mapped_registry_factory = mapping[-1] + mapped_services = mapping[2:-1] + self.CheckFormMappings(form_mapping, registry_path=registry_path) + + self.assertEquals(r'(%s)%s' % (registry_path, + service_handlers._METHOD_PATTERN), + mapped_registry_path) + self.assertEquals(registry.RegistryService, + mapped_registry_factory.service_factory.service_class) + + # Verify registry knows about other services. + expected_registry = {registry_path: registry.RegistryService} + for path, factory in dict(services).items(): + if isinstance(factory, type) and issubclass(factory, remote.Service): + expected_registry[path] = factory + else: + expected_registry[path] = factory.service_class + self.assertEquals(expected_registry, + mapped_registry_factory().service.registry) + + # Verify that services are mapped to URL. + self.assertEquals(len(services), len(mapped_services)) + for path, service in dict(services).items(): + mapped_path = r'(%s)%s' % (path, service_handlers._METHOD_PATTERN) + mapped_factory = dict(mapped_services)[mapped_path] + self.assertEquals(service, mapped_factory.service_factory) + + def testServiceMapping_Empty(self): + """Test an empty service mapping.""" + self.DoMappingTest({}) + + def testServiceMapping_ByClass(self): + """Test mapping a service by class.""" + self.DoMappingTest({'/my-service': MyService}) + + def testServiceMapping_ByFactory(self): + """Test mapping a service by factory.""" + self.DoMappingTest({'/my-service': MyService.new_factory('new-value')}) + + def testServiceMapping_ByList(self): + """Test mapping a service by factory.""" + self.DoMappingTest( + [('/my-service1', MyService.new_factory('service1')), + ('/my-service2', MyService.new_factory('service2')), + ]) + + def testServiceMapping_NoRegistry(self): + """Test mapping a service by class.""" + mapping = self.DoMappingTest({'/my-service': MyService}, None) + + def testDefaultMappingWithClass(self): + """Test setting path just from the class. + + Path of the mapping will be the fully qualified ProtoRPC service name with + '.' replaced with '/'. For example: + + com.nowhere.service.TheService -> /com/nowhere/service/TheService + """ + mapping = service_handlers.service_mapping([MyService]) + mapped_services = mapping[2:-1] + self.assertEquals(1, len(mapped_services)) + path, factory = mapped_services[0] + + self.assertEquals( + r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN, + path) + self.assertEquals(MyService, factory.service_factory) + + def testDefaultMappingWithFactory(self): + mapping = service_handlers.service_mapping( + [MyService.new_factory('service1')]) + mapped_services = mapping[2:-1] + self.assertEquals(1, len(mapped_services)) + path, factory = mapped_services[0] + + self.assertEquals( + r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN, + path) + self.assertEquals(MyService, factory.service_factory.service_class) + + def testMappingDuplicateExplicitServiceName(self): + self.assertRaisesWithRegexpMatch( + service_handlers.ServiceConfigurationError, + "Path '/my_path' is already defined in service mapping", + service_handlers.service_mapping, + [('/my_path', MyService), + ('/my_path', MyService), + ]) + + def testMappingDuplicateServiceName(self): + self.assertRaisesWithRegexpMatch( + service_handlers.ServiceConfigurationError, + "Path '/test_package/MyService' is already defined in service mapping", + service_handlers.service_mapping, + [MyService, MyService]) + + +class GetCalled(remote.Service): + + def __init__(self, test): + self.test = test + + @remote.method(Request1, Response1) + def my_method(self, request): + self.test.request = request + return Response1(string_field='a response') + + +class TestRunServices(test_util.TestCase): + + def DoRequest(self, + path, + request, + response_type, + reg_path='/protorpc'): + stdin = sys.stdin + stdout = sys.stdout + environ = os.environ + try: + sys.stdin = cStringIO.StringIO(protojson.encode_message(request)) + sys.stdout = cStringIO.StringIO() + + os.environ = webapp_test_util.GetDefaultEnvironment() + os.environ['PATH_INFO'] = path + os.environ['REQUEST_METHOD'] = 'POST' + os.environ['CONTENT_TYPE'] = 'application/json' + os.environ['wsgi.input'] = sys.stdin + os.environ['wsgi.output'] = sys.stdout + os.environ['CONTENT_LENGTH'] = len(sys.stdin.getvalue()) + + service_handlers.run_services( + [('/my_service', GetCalled.new_factory(self))], reg_path) + + header, body = sys.stdout.getvalue().split('\n\n', 1) + + return (header.split('\n')[0], + protojson.decode_message(response_type, body)) + finally: + sys.stdin = stdin + sys.stdout = stdout + os.environ = environ + + def testRequest(self): + request = Request1(string_field='request value') + + status, response = self.DoRequest('/my_service.my_method', + request, + Response1) + self.assertEquals('Status: 200 OK', status) + self.assertEquals(request, self.request) + self.assertEquals(Response1(string_field='a response'), response) + + def testRegistry(self): + request = Request1(string_field='request value') + status, response = self.DoRequest('/protorpc.services', + message_types.VoidMessage(), + registry.ServicesResponse) + + self.assertEquals('Status: 200 OK', status) + self.assertIterEqual([ + registry.ServiceMapping( + name='/protorpc', + definition='protorpc.registry.RegistryService'), + registry.ServiceMapping( + name='/my_service', + definition='test_package.GetCalled'), + ], response.services) + + def testRunServicesWithOutRegistry(self): + request = Request1(string_field='request value') + + status, response = self.DoRequest('/protorpc.services', + message_types.VoidMessage(), + registry.ServicesResponse, + reg_path=None) + self.assertEquals('Status: 404 Not Found', status) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/webapp_test_util.py b/endpoints/internal/protorpc/webapp_test_util.py new file mode 100644 index 0000000..6481dc3 --- /dev/null +++ b/endpoints/internal/protorpc/webapp_test_util.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Testing utilities for the webapp libraries. + + GetDefaultEnvironment: Method for easily setting up CGI environment. + RequestHandlerTestBase: Base class for setting up handler tests. +""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import cStringIO +import socket +import threading +import urllib2 +from wsgiref import simple_server +from wsgiref import validate + +from . import protojson +from . import remote +from . import test_util +from . import transport +from .webapp import service_handlers +from .webapp.google_imports import webapp + + +class TestService(remote.Service): + """Service used to do end to end tests with.""" + + @remote.method(test_util.OptionalMessage, + test_util.OptionalMessage) + def optional_message(self, request): + if request.string_value: + request.string_value = '+%s' % request.string_value + return request + + +def GetDefaultEnvironment(): + """Function for creating a default CGI environment.""" + return { + 'LC_NUMERIC': 'C', + 'wsgi.multiprocess': True, + 'SERVER_PROTOCOL': 'HTTP/1.0', + 'SERVER_SOFTWARE': 'Dev AppServer 0.1', + 'SCRIPT_NAME': '', + 'LOGNAME': 'nickjohnson', + 'USER': 'nickjohnson', + 'QUERY_STRING': 'foo=bar&foo=baz&foo2=123', + 'PATH': '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/bin/X11', + 'LANG': 'en_US', + 'LANGUAGE': 'en', + 'REMOTE_ADDR': '127.0.0.1', + 'LC_MONETARY': 'C', + 'CONTENT_TYPE': 'application/x-www-form-urlencoded', + 'wsgi.url_scheme': 'http', + 'SERVER_PORT': '8080', + 'HOME': '/home/mruser', + 'USERNAME': 'mruser', + 'CONTENT_LENGTH': '', + 'USER_IS_ADMIN': '1', + 'PYTHONPATH': '/tmp/setup', + 'LC_TIME': 'C', + 'HTTP_USER_AGENT': 'Mozilla/5.0 (X11; U; Linux i686 (x86_64); en-US; ' + 'rv:1.8.1.6) Gecko/20070725 Firefox/2.0.0.6', + 'wsgi.multithread': False, + 'wsgi.version': (1, 0), + 'USER_EMAIL': 'test@example.com', + 'USER_EMAIL': '112', + 'wsgi.input': cStringIO.StringIO(), + 'PATH_TRANSLATED': '/tmp/request.py', + 'SERVER_NAME': 'localhost', + 'GATEWAY_INTERFACE': 'CGI/1.1', + 'wsgi.run_once': True, + 'LC_COLLATE': 'C', + 'HOSTNAME': 'myhost', + 'wsgi.errors': cStringIO.StringIO(), + 'PWD': '/tmp', + 'REQUEST_METHOD': 'GET', + 'MAIL': '/dev/null', + 'MAILCHECK': '0', + 'USER_NICKNAME': 'test', + 'HTTP_COOKIE': 'dev_appserver_login="test:test@example.com:True"', + 'PATH_INFO': '/tmp/myhandler' + } + + +class RequestHandlerTestBase(test_util.TestCase): + """Base class for writing RequestHandler tests. + + To test a specific request handler override CreateRequestHandler. + To change the environment for that handler override GetEnvironment. + """ + + def setUp(self): + """Set up test for request handler.""" + self.ResetHandler() + + def GetEnvironment(self): + """Get environment. + + Override for more specific configurations. + + Returns: + dict of CGI environment. + """ + return GetDefaultEnvironment() + + def CreateRequestHandler(self): + """Create RequestHandler instances. + + Override to create more specific kinds of RequestHandler instances. + + Returns: + RequestHandler instance used in test. + """ + return webapp.RequestHandler() + + def CheckResponse(self, + expected_status, + expected_headers, + expected_content): + """Check that the web response is as expected. + + Args: + expected_status: Expected status message. + expected_headers: Dictionary of expected headers. Will ignore unexpected + headers and only check the value of those expected. + expected_content: Expected body. + """ + def check_content(content): + self.assertEquals(expected_content, content) + + def start_response(status, headers): + self.assertEquals(expected_status, status) + + found_keys = set() + for name, value in headers: + name = name.lower() + try: + expected_value = expected_headers[name] + except KeyError: + pass + else: + found_keys.add(name) + self.assertEquals(expected_value, value) + + missing_headers = set(expected_headers.keys()) - found_keys + if missing_headers: + self.fail('Expected keys %r not found' % (list(missing_headers),)) + + return check_content + + self.handler.response.wsgi_write(start_response) + + def ResetHandler(self, change_environ=None): + """Reset this tests environment with environment changes. + + Resets the entire test with a new handler which includes some changes to + the default request environment. + + Args: + change_environ: Dictionary of values that are added to default + environment. + """ + environment = self.GetEnvironment() + environment.update(change_environ or {}) + + self.request = webapp.Request(environment) + self.response = webapp.Response() + self.handler = self.CreateRequestHandler() + self.handler.initialize(self.request, self.response) + + +class SyncedWSGIServer(simple_server.WSGIServer): + pass + + +class WSGIServerIPv6(simple_server.WSGIServer): + address_family = socket.AF_INET6 + + +class ServerThread(threading.Thread): + """Thread responsible for managing wsgi server. + + This server does not just attach to the socket and listen for requests. This + is because the server classes in Python 2.5 or less have no way to shut them + down. Instead, the thread must be notified of how many requests it will + receive so that it listens for each one individually. Tests should tell how + many requests to listen for using the handle_request method. + """ + + def __init__(self, server, *args, **kwargs): + """Constructor. + + Args: + server: The WSGI server that is served by this thread. + As per threading.Thread base class. + + State: + __serving: Server is still expected to be serving. When False server + knows to shut itself down. + """ + self.server = server + # This timeout is for the socket when a connection is made. + self.server.socket.settimeout(None) + # This timeout is for when waiting for a connection. The allows + # server.handle_request() to listen for a short time, then timeout, + # allowing the server to check for shutdown. + self.server.timeout = 0.05 + self.__serving = True + + super(ServerThread, self).__init__(*args, **kwargs) + + def shutdown(self): + """Notify server that it must shutdown gracefully.""" + self.__serving = False + + def run(self): + """Handle incoming requests until shutdown.""" + while self.__serving: + self.server.handle_request() + + self.server = None + + +class TestService(remote.Service): + """Service used to do end to end tests with.""" + + def __init__(self, message='uninitialized'): + self.__message = message + + @remote.method(test_util.OptionalMessage, test_util.OptionalMessage) + def optional_message(self, request): + if request.string_value: + request.string_value = '+%s' % request.string_value + return request + + @remote.method(response_type=test_util.OptionalMessage) + def init_parameter(self, request): + return test_util.OptionalMessage(string_value=self.__message) + + @remote.method(test_util.NestedMessage, test_util.NestedMessage) + def nested_message(self, request): + request.string_value = '+%s' % request.string_value + return request + + @remote.method() + def raise_application_error(self, request): + raise remote.ApplicationError('This is an application error', 'ERROR_NAME') + + @remote.method() + def raise_unexpected_error(self, request): + raise TypeError('Unexpected error') + + @remote.method() + def raise_rpc_error(self, request): + raise remote.NetworkError('Uncaught network error') + + @remote.method(response_type=test_util.NestedMessage) + def return_bad_message(self, request): + return test_util.NestedMessage() + + +class AlternateService(remote.Service): + """Service used to requesting non-existant methods.""" + + @remote.method() + def does_not_exist(self, request): + raise NotImplementedError('Not implemented') + + +class WebServerTestBase(test_util.TestCase): + + SERVICE_PATH = '/my/service' + + def setUp(self): + self.server = None + self.schema = 'http' + self.ResetServer() + + self.bad_path_connection = self.CreateTransport(self.service_url + '_x') + self.bad_path_stub = TestService.Stub(self.bad_path_connection) + super(WebServerTestBase, self).setUp() + + def tearDown(self): + self.server.shutdown() + super(WebServerTestBase, self).tearDown() + + def ResetServer(self, application=None): + """Reset web server. + + Shuts down existing server if necessary and starts a new one. + + Args: + application: Optional WSGI function. If none provided will use + tests CreateWsgiApplication method. + """ + if self.server: + self.server.shutdown() + + self.port = test_util.pick_unused_port() + self.server, self.application = self.StartWebServer(self.port, application) + + self.connection = self.CreateTransport(self.service_url) + + def CreateTransport(self, service_url, protocol=protojson): + """Create a new transportation object.""" + return transport.HttpTransport(service_url, protocol=protocol) + + def StartWebServer(self, port, application=None): + """Start web server. + + Args: + port: Port to start application on. + application: Optional WSGI function. If none provided will use + tests CreateWsgiApplication method. + + Returns: + A tuple (server, application): + server: An instance of ServerThread. + application: Application that web server responds with. + """ + if not application: + application = self.CreateWsgiApplication() + validated_application = validate.validator(application) + + try: + server = simple_server.make_server( + 'localhost', port, validated_application) + except socket.error: + # Try IPv6 + server = simple_server.make_server( + 'localhost', port, validated_application, server_class=WSGIServerIPv6) + + server = ServerThread(server) + server.start() + return server, application + + def make_service_url(self, path): + """Make service URL using current schema and port.""" + return '%s://localhost:%d%s' % (self.schema, self.port, path) + + @property + def service_url(self): + return self.make_service_url(self.SERVICE_PATH) + + +class EndToEndTestBase(WebServerTestBase): + + # Sub-classes may override to create alternate configurations. + DEFAULT_MAPPING = service_handlers.service_mapping( + [('/my/service', TestService), + ('/my/other_service', TestService.new_factory('initialized')), + ]) + + def setUp(self): + super(EndToEndTestBase, self).setUp() + + self.stub = TestService.Stub(self.connection) + + self.other_connection = self.CreateTransport(self.other_service_url) + self.other_stub = TestService.Stub(self.other_connection) + + self.mismatched_stub = AlternateService.Stub(self.connection) + + @property + def other_service_url(self): + return 'http://localhost:%d/my/other_service' % self.port + + def CreateWsgiApplication(self): + """Create WSGI application used on the server side for testing.""" + return webapp.WSGIApplication(self.DEFAULT_MAPPING, True) + + def DoRawRequest(self, + method, + content='', + content_type='application/json', + headers=None): + headers = headers or {} + headers.update({'content-length': len(content or ''), + 'content-type': content_type, + }) + request = urllib2.Request('%s.%s' % (self.service_url, method), + content, + headers) + return urllib2.urlopen(request) + + def RawRequestError(self, + method, + content=None, + content_type='application/json', + headers=None): + try: + self.DoRawRequest(method, content, content_type, headers) + self.fail('Expected HTTP error') + except urllib2.HTTPError as err: + return err.code, err.read(), err.headers diff --git a/endpoints/internal/protorpc/wsgi/__init__.py b/endpoints/internal/protorpc/wsgi/__init__.py new file mode 100644 index 0000000..00be5b0 --- /dev/null +++ b/endpoints/internal/protorpc/wsgi/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/endpoints/internal/protorpc/wsgi/service.py b/endpoints/internal/protorpc/wsgi/service.py new file mode 100644 index 0000000..954658a --- /dev/null +++ b/endpoints/internal/protorpc/wsgi/service.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""ProtoRPC WSGI service applications. + +Use functions in this module to configure ProtoRPC services for use with +WSGI applications. For more information about WSGI, please see: + + http://wsgi.org/wsgi + http://docs.python.org/library/wsgiref.html +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import cgi +import six.moves.http_client +import logging +import re + +from .. import messages +from .. import registry +from .. import remote +from .. import util +from . import util as wsgi_util + +__all__ = [ + 'DEFAULT_REGISTRY_PATH', + 'service_app', +] + +_METHOD_PATTERN = r'(?:\.([^?]+))' +_REQUEST_PATH_PATTERN = r'^(%%s)%s$' % _METHOD_PATTERN + +_HTTP_BAD_REQUEST = wsgi_util.error(six.moves.http_client.BAD_REQUEST) +_HTTP_NOT_FOUND = wsgi_util.error(six.moves.http_client.NOT_FOUND) +_HTTP_UNSUPPORTED_MEDIA_TYPE = wsgi_util.error(six.moves.http_client.UNSUPPORTED_MEDIA_TYPE) + +DEFAULT_REGISTRY_PATH = '/protorpc' + + +@util.positional(2) +def service_mapping(service_factory, service_path=r'.*', protocols=None): + """WSGI application that handles a single ProtoRPC service mapping. + + Args: + service_factory: Service factory for creating instances of service request + handlers. Either callable that takes no parameters and returns a service + instance or a service class whose constructor requires no parameters. + service_path: Regular expression for matching requests against. Requests + that do not have matching paths will cause a 404 (Not Found) response. + protocols: remote.Protocols instance that configures supported protocols + on server. + """ + service_class = getattr(service_factory, 'service_class', service_factory) + remote_methods = service_class.all_remote_methods() + path_matcher = re.compile(_REQUEST_PATH_PATTERN % service_path) + + def protorpc_service_app(environ, start_response): + """Actual WSGI application function.""" + path_match = path_matcher.match(environ['PATH_INFO']) + if not path_match: + return _HTTP_NOT_FOUND(environ, start_response) + service_path = path_match.group(1) + method_name = path_match.group(2) + + content_type = environ.get('CONTENT_TYPE') + if not content_type: + content_type = environ.get('HTTP_CONTENT_TYPE') + if not content_type: + return _HTTP_BAD_REQUEST(environ, start_response) + + # TODO(rafek): Handle alternate encodings. + content_type = cgi.parse_header(content_type)[0] + + request_method = environ['REQUEST_METHOD'] + if request_method != 'POST': + content = ('%s.%s is a ProtoRPC method.\n\n' + 'Service %s\n\n' + 'More about ProtoRPC: ' + '%s\n' % + (service_path, + method_name, + service_class.definition_name().encode('utf-8'), + util.PROTORPC_PROJECT_URL)) + error_handler = wsgi_util.error( + six.moves.http_client.METHOD_NOT_ALLOWED, + six.moves.http_client.responses[six.moves.http_client.METHOD_NOT_ALLOWED], + content=content, + content_type='text/plain; charset=utf-8') + return error_handler(environ, start_response) + + local_protocols = protocols or remote.Protocols.get_default() + try: + protocol = local_protocols.lookup_by_content_type(content_type) + except KeyError: + return _HTTP_UNSUPPORTED_MEDIA_TYPE(environ,start_response) + + def send_rpc_error(status_code, state, message, error_name=None): + """Helper function to send an RpcStatus message as response. + + Will create static error handler and begin response. + + Args: + status_code: HTTP integer status code. + state: remote.RpcState enum value to send as response. + message: Helpful message to send in response. + error_name: Error name if applicable. + + Returns: + List containing encoded content response using the same content-type as + the request. + """ + status = remote.RpcStatus(state=state, + error_message=message, + error_name=error_name) + encoded_status = protocol.encode_message(status) + error_handler = wsgi_util.error( + status_code, + content_type=protocol.default_content_type, + content=encoded_status) + return error_handler(environ, start_response) + + method = remote_methods.get(method_name) + if not method: + return send_rpc_error(six.moves.http_client.BAD_REQUEST, + remote.RpcState.METHOD_NOT_FOUND_ERROR, + 'Unrecognized RPC method: %s' % method_name) + + content_length = int(environ.get('CONTENT_LENGTH') or '0') + + remote_info = method.remote + try: + request = protocol.decode_message( + remote_info.request_type, environ['wsgi.input'].read(content_length)) + except (messages.ValidationError, messages.DecodeError) as err: + return send_rpc_error(six.moves.http_client.BAD_REQUEST, + remote.RpcState.REQUEST_ERROR, + 'Error parsing ProtoRPC request ' + '(Unable to parse request content: %s)' % err) + + instance = service_factory() + + initialize_request_state = getattr( + instance, 'initialize_request_state', None) + if initialize_request_state: + # TODO(rafek): This is not currently covered by tests. + server_port = environ.get('SERVER_PORT', None) + if server_port: + server_port = int(server_port) + + headers = [] + for name, value in six.iteritems(environ): + if name.startswith('HTTP_'): + headers.append((name[len('HTTP_'):].lower().replace('_', '-'), value)) + request_state = remote.HttpRequestState( + remote_host=environ.get('REMOTE_HOST', None), + remote_address=environ.get('REMOTE_ADDR', None), + server_host=environ.get('SERVER_HOST', None), + server_port=server_port, + http_method=request_method, + service_path=service_path, + headers=headers) + + initialize_request_state(request_state) + + try: + response = method(instance, request) + encoded_response = protocol.encode_message(response) + except remote.ApplicationError as err: + return send_rpc_error(six.moves.http_client.BAD_REQUEST, + remote.RpcState.APPLICATION_ERROR, + unicode(err), + err.error_name) + except Exception as err: + logging.exception('Encountered unexpected error from ProtoRPC ' + 'method implementation: %s (%s)' % + (err.__class__.__name__, err)) + return send_rpc_error(six.moves.http_client.INTERNAL_SERVER_ERROR, + remote.RpcState.SERVER_ERROR, + 'Internal Server Error') + + response_headers = [('content-type', content_type)] + start_response('%d %s' % (six.moves.http_client.OK, six.moves.http_client.responses[six.moves.http_client.OK],), + response_headers) + return [encoded_response] + + # Return WSGI application. + return protorpc_service_app + + +@util.positional(1) +def service_mappings(services, registry_path=DEFAULT_REGISTRY_PATH): + """Create multiple service mappings with optional RegistryService. + + Use this function to create single WSGI application that maps to + multiple ProtoRPC services plus an optional RegistryService. + + Example: + services = service.service_mappings( + [(r'/time', TimeService), + (r'/weather', WeatherService) + ]) + + In this example, the services WSGI application will map to two services, + TimeService and WeatherService to the '/time' and '/weather' paths + respectively. In addition, it will also add a ProtoRPC RegistryService + configured to serve information about both services at the (default) path + '/protorpc'. + + Args: + services: If a dictionary is provided instead of a list of tuples, the + dictionary item pairs are used as the mappings instead. + Otherwise, a list of tuples (service_path, service_factory): + service_path: The path to mount service on. + service_factory: A service class or service instance factory. + registry_path: A string to change where the registry is mapped (the default + location is '/protorpc'). When None, no registry is created or mounted. + + Returns: + WSGI application that serves ProtoRPC services on their respective URLs + plus optional RegistryService. + """ + if isinstance(services, dict): + services = six.iteritems(services) + + final_mapping = [] + paths = set() + registry_map = {} if registry_path else None + + for service_path, service_factory in services: + try: + service_class = service_factory.service_class + except AttributeError: + service_class = service_factory + + if service_path not in paths: + paths.add(service_path) + else: + raise remote.ServiceConfigurationError( + 'Path %r is already defined in service mapping' % + service_path.encode('utf-8')) + + if registry_map is not None: + registry_map[service_path] = service_class + + final_mapping.append(service_mapping(service_factory, service_path)) + + if registry_map is not None: + final_mapping.append(service_mapping( + registry.RegistryService.new_factory(registry_map), registry_path)) + + return wsgi_util.first_found(final_mapping) diff --git a/endpoints/internal/protorpc/wsgi/service_test.py b/endpoints/internal/protorpc/wsgi/service_test.py new file mode 100644 index 0000000..c94d648 --- /dev/null +++ b/endpoints/internal/protorpc/wsgi/service_test.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""WSGI application tests.""" + +__author__ = 'rafek@google.com (Rafe Kaplan)' + + +import unittest + + +from protorpc import end2end_test +from protorpc import protojson +from protorpc import remote +from protorpc import registry +from protorpc import transport +from protorpc import test_util +from protorpc import webapp_test_util +from protorpc.wsgi import service +from protorpc.wsgi import util + + +class ServiceMappingTest(end2end_test.EndToEndTest): + + def setUp(self): + self.protocols = None + remote.Protocols.set_default(remote.Protocols.new_default()) + super(ServiceMappingTest, self).setUp() + + def CreateServices(self): + + return my_service, my_other_service + + def CreateWsgiApplication(self): + """Create WSGI application used on the server side for testing.""" + my_service = service.service_mapping(webapp_test_util.TestService, + '/my/service') + my_other_service = service.service_mapping( + webapp_test_util.TestService.new_factory('initialized'), + '/my/other_service', + protocols=self.protocols) + + return util.first_found([my_service, my_other_service]) + + def testAlternateProtocols(self): + self.protocols = remote.Protocols() + self.protocols.add_protocol(protojson, 'altproto', 'image/png') + + global_protocols = remote.Protocols() + global_protocols.add_protocol(protojson, 'server-side-name', 'image/png') + remote.Protocols.set_default(global_protocols) + self.ResetServer() + + self.connection = transport.HttpTransport( + self.service_url, protocol=self.protocols.lookup_by_name('altproto')) + self.stub = webapp_test_util.TestService.Stub(self.connection) + + self.stub.optional_message(string_value='alternate-protocol') + + def testAlwaysUseDefaults(self): + new_protocols = remote.Protocols() + new_protocols.add_protocol(protojson, 'altproto', 'image/png') + + self.connection = transport.HttpTransport( + self.service_url, protocol=new_protocols.lookup_by_name('altproto')) + self.stub = webapp_test_util.TestService.Stub(self.connection) + + self.assertRaisesWithRegexpMatch( + remote.ServerError, + 'HTTP Error 415: Unsupported Media Type', + self.stub.optional_message, string_value='alternate-protocol') + + remote.Protocols.set_default(new_protocols) + + self.stub.optional_message(string_value='alternate-protocol') + + +class ProtoServiceMappingsTest(ServiceMappingTest): + + def CreateWsgiApplication(self): + """Create WSGI application used on the server side for testing.""" + return service.service_mappings( + [('/my/service', webapp_test_util.TestService), + ('/my/other_service', + webapp_test_util.TestService.new_factory('initialized')) + ]) + + def GetRegistryStub(self, path='/protorpc'): + service_url = self.make_service_url(path) + transport = self.CreateTransport(service_url) + return registry.RegistryService.Stub(transport) + + def testRegistry(self): + registry_client = self.GetRegistryStub() + response = registry_client.services() + self.assertIterEqual([ + registry.ServiceMapping( + name='/my/other_service', + definition='protorpc.webapp_test_util.TestService'), + registry.ServiceMapping( + name='/my/service', + definition='protorpc.webapp_test_util.TestService'), + ], response.services) + + def testRegistryDictionary(self): + self.ResetServer(service.service_mappings( + {'/my/service': webapp_test_util.TestService, + '/my/other_service': + webapp_test_util.TestService.new_factory('initialized'), + })) + registry_client = self.GetRegistryStub() + response = registry_client.services() + self.assertIterEqual([ + registry.ServiceMapping( + name='/my/other_service', + definition='protorpc.webapp_test_util.TestService'), + registry.ServiceMapping( + name='/my/service', + definition='protorpc.webapp_test_util.TestService'), + ], response.services) + + def testNoRegistry(self): + self.ResetServer(service.service_mappings( + [('/my/service', webapp_test_util.TestService), + ('/my/other_service', + webapp_test_util.TestService.new_factory('initialized')) + ], + registry_path=None)) + registry_client = self.GetRegistryStub() + self.assertRaisesWithRegexpMatch( + remote.ServerError, + 'HTTP Error 404: Not Found', + registry_client.services) + + def testAltRegistry(self): + self.ResetServer(service.service_mappings( + [('/my/service', webapp_test_util.TestService), + ('/my/other_service', + webapp_test_util.TestService.new_factory('initialized')) + ], + registry_path='/registry')) + registry_client = self.GetRegistryStub('/registry') + services = registry_client.services() + self.assertTrue(isinstance(services, registry.ServicesResponse)) + self.assertIterEqual( + [registry.ServiceMapping( + name='/my/other_service', + definition='protorpc.webapp_test_util.TestService'), + registry.ServiceMapping( + name='/my/service', + definition='protorpc.webapp_test_util.TestService'), + ], + services.services) + + def testDuplicateRegistryEntry(self): + self.assertRaisesWithRegexpMatch( + remote.ServiceConfigurationError, + "Path '/my/service' is already defined in service mapping", + service.service_mappings, + [('/my/service', webapp_test_util.TestService), + ('/my/service', + webapp_test_util.TestService.new_factory('initialized')) + ]) + + def testRegex(self): + self.ResetServer(service.service_mappings( + [('/my/[0-9]+', webapp_test_util.TestService.new_factory('service')), + ('/my/[a-z]+', + webapp_test_util.TestService.new_factory('other-service')), + ])) + my_service_url = 'http://localhost:%d/my/12345' % self.port + my_other_service_url = 'http://localhost:%d/my/blarblar' % self.port + + my_service = webapp_test_util.TestService.Stub( + transport.HttpTransport(my_service_url)) + my_other_service = webapp_test_util.TestService.Stub( + transport.HttpTransport(my_other_service_url)) + + response = my_service.init_parameter() + self.assertEquals('service', response.string_value) + + response = my_other_service.init_parameter() + self.assertEquals('other-service', response.string_value) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/endpoints/internal/protorpc/wsgi/util.py b/endpoints/internal/protorpc/wsgi/util.py new file mode 100644 index 0000000..344a6bd --- /dev/null +++ b/endpoints/internal/protorpc/wsgi/util.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""WSGI utilities + +Small collection of helpful utilities for working with WSGI. +""" +import six + +__author__ = 'rafek@google.com (Rafe Kaplan)' + +import six.moves.http_client +import re + +from .. import util + +__all__ = ['static_page', + 'error', + 'first_found', +] + +_STATUS_PATTERN = re.compile('^(\d{3})\s') + + +@util.positional(1) +def static_page(content='', + status='200 OK', + content_type='text/html; charset=utf-8', + headers=None): + """Create a WSGI application that serves static content. + + A static page is one that will be the same every time it receives a request. + It will always serve the same status, content and headers. + + Args: + content: Content to serve in response to HTTP request. + status: Status to serve in response to HTTP request. If string, status + is served as is without any error checking. If integer, will look up + status message. Otherwise, parameter is tuple (status, description): + status: Integer status of response. + description: Brief text description of response. + content_type: Convenient parameter for content-type header. Will appear + before any content-type header that appears in 'headers' parameter. + headers: Dictionary of headers or iterable of tuples (name, value): + name: String name of header. + value: String value of header. + + Returns: + WSGI application that serves static content. + """ + if isinstance(status, six.integer_types): + status = '%d %s' % (status, six.moves.http_client.responses.get(status, 'Unknown Error')) + elif not isinstance(status, six.string_types): + status = '%d %s' % tuple(status) + + if isinstance(headers, dict): + headers = six.iteritems(headers) + + headers = [('content-length', str(len(content))), + ('content-type', content_type), + ] + list(headers or []) + + # Ensure all headers are str. + for index, (key, value) in enumerate(headers): + if isinstance(value, six.text_type): + value = value.encode('utf-8') + headers[index] = key, value + + if not isinstance(key, str): + raise TypeError('Header key must be str, found: %r' % (key,)) + + if not isinstance(value, str): + raise TypeError( + 'Header %r must be type str or unicode, found: %r' % (key, value)) + + def static_page_application(environ, start_response): + start_response(status, headers) + return [content] + + return static_page_application + + +@util.positional(2) +def error(status_code, status_message=None, + content_type='text/plain; charset=utf-8', + headers=None, content=None): + """Create WSGI application that statically serves an error page. + + Creates a static error page specifically for non-200 HTTP responses. + + Browsers such as Internet Explorer will display their own error pages for + error content responses smaller than 512 bytes. For this reason all responses + are right-padded up to 512 bytes. + + Error pages that are not provided will content will contain the standard HTTP + status message as their content. + + Args: + status_code: Integer status code of error. + status_message: Status message. + + Returns: + Static WSGI application that sends static error response. + """ + if status_message is None: + status_message = six.moves.http_client.responses.get(status_code, 'Unknown Error') + + if content is None: + content = status_message + + content = util.pad_string(content) + + return static_page(content, + status=(status_code, status_message), + content_type=content_type, + headers=headers) + + +def first_found(apps): + """Serve the first application that does not response with 404 Not Found. + + If no application serves content, will respond with generic 404 Not Found. + + Args: + apps: List of WSGI applications to search through. Will serve the content + of the first of these that does not return a 404 Not Found. Applications + in this list must not modify the environment or any objects in it if they + do not match. Applications that do not obey this restriction can create + unpredictable results. + + Returns: + Compound application that serves the contents of the first application that + does not response with 404 Not Found. + """ + apps = tuple(apps) + not_found = error(six.moves.http_client.NOT_FOUND) + + def first_found_app(environ, start_response): + """Compound application returned from the first_found function.""" + final_result = {} # Used in absence of Python local scoping. + + def first_found_start_response(status, response_headers): + """Replacement for start_response as passed in to first_found_app. + + Called by each application in apps instead of the real start response. + Checks the response status, and if anything other than 404, sets 'status' + and 'response_headers' in final_result. + """ + status_match = _STATUS_PATTERN.match(status) + assert status_match, ('Status must be a string beginning ' + 'with 3 digit number. Found: %s' % status) + status_code = status_match.group(0) + if int(status_code) == six.moves.http_client.NOT_FOUND: + return + + final_result['status'] = status + final_result['response_headers'] = response_headers + + for app in apps: + response = app(environ, first_found_start_response) + if final_result: + start_response(final_result['status'], final_result['response_headers']) + return response + + return not_found(environ, start_response) + return first_found_app diff --git a/endpoints/internal/protorpc/wsgi/util_test.py b/endpoints/internal/protorpc/wsgi/util_test.py new file mode 100644 index 0000000..60a79af --- /dev/null +++ b/endpoints/internal/protorpc/wsgi/util_test.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python +# +# Copyright 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""WSGI utility library tests.""" +import six +from six.moves import filter + +__author__ = 'rafe@google.com (Rafe Kaplan)' + + +import six.moves.http_client +import unittest + +from protorpc import test_util +from protorpc import util +from protorpc import webapp_test_util +from protorpc.wsgi import util as wsgi_util + +APP1 = wsgi_util.static_page('App1') +APP2 = wsgi_util.static_page('App2') +NOT_FOUND = wsgi_util.error(six.moves.http_client.NOT_FOUND) + + +class WsgiTestBase(webapp_test_util.WebServerTestBase): + + server_thread = None + + def CreateWsgiApplication(self): + return None + + def DoHttpRequest(self, + path='/', + content=None, + content_type='text/plain; charset=utf-8', + headers=None): + connection = six.moves.http_client.HTTPConnection('localhost', self.port) + if content is None: + method = 'GET' + else: + method = 'POST' + headers = {'content=type': content_type} + headers.update(headers) + connection.request(method, path, content, headers) + response = connection.getresponse() + + not_date_or_server = lambda header: header[0] not in ('date', 'server') + headers = list(filter(not_date_or_server, response.getheaders())) + + return response.status, response.reason, response.read(), dict(headers) + + +class StaticPageBase(WsgiTestBase): + + def testDefault(self): + default_page = wsgi_util.static_page() + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(200, status) + self.assertEquals('OK', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testHasContent(self): + default_page = wsgi_util.static_page('my content') + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(200, status) + self.assertEquals('OK', reason) + self.assertEquals('my content', content) + self.assertEquals({'content-length': str(len('my content')), + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testHasContentType(self): + default_page = wsgi_util.static_page(content_type='text/plain') + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(200, status) + self.assertEquals('OK', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/plain', + }, + headers) + + def testHasStatus(self): + default_page = wsgi_util.static_page(status='400 Not Good Request') + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(400, status) + self.assertEquals('Not Good Request', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testHasStatusInt(self): + default_page = wsgi_util.static_page(status=401) + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(401, status) + self.assertEquals('Unauthorized', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testHasStatusUnknown(self): + default_page = wsgi_util.static_page(status=909) + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(909, status) + self.assertEquals('Unknown Error', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testHasStatusTuple(self): + default_page = wsgi_util.static_page(status=(500, 'Bad Thing')) + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(500, status) + self.assertEquals('Bad Thing', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testHasHeaders(self): + default_page = wsgi_util.static_page(headers=[('x', 'foo'), + ('a', 'bar'), + ('z', 'bin')]) + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(200, status) + self.assertEquals('OK', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + 'x': 'foo', + 'a': 'bar', + 'z': 'bin', + }, + headers) + + def testHeadersUnicodeSafe(self): + default_page = wsgi_util.static_page(headers=[('x', u'foo')]) + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(200, status) + self.assertEquals('OK', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + 'x': 'foo', + }, + headers) + self.assertTrue(isinstance(headers['x'], str)) + + def testHasHeadersDict(self): + default_page = wsgi_util.static_page(headers={'x': 'foo', + 'a': 'bar', + 'z': 'bin'}) + self.ResetServer(default_page) + status, reason, content, headers = self.DoHttpRequest() + self.assertEquals(200, status) + self.assertEquals('OK', reason) + self.assertEquals('', content) + self.assertEquals({'content-length': '0', + 'content-type': 'text/html; charset=utf-8', + 'x': 'foo', + 'a': 'bar', + 'z': 'bin', + }, + headers) + + +class FirstFoundTest(WsgiTestBase): + + def testEmptyConfiguration(self): + self.ResetServer(wsgi_util.first_found([])) + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(six.moves.http_client.NOT_FOUND, status) + self.assertEquals(six.moves.http_client.responses[six.moves.http_client.NOT_FOUND], status_text) + self.assertEquals(util.pad_string(six.moves.http_client.responses[six.moves.http_client.NOT_FOUND]), + content) + self.assertEquals({'content-length': '512', + 'content-type': 'text/plain; charset=utf-8', + }, + headers) + + def testOneApp(self): + self.ResetServer(wsgi_util.first_found([APP1])) + + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(six.moves.http_client.OK, status) + self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) + self.assertEquals('App1', content) + self.assertEquals({'content-length': '4', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testIterator(self): + self.ResetServer(wsgi_util.first_found(iter([APP1]))) + + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(six.moves.http_client.OK, status) + self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) + self.assertEquals('App1', content) + self.assertEquals({'content-length': '4', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + # Do request again to make sure iterator was properly copied. + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(six.moves.http_client.OK, status) + self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) + self.assertEquals('App1', content) + self.assertEquals({'content-length': '4', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testTwoApps(self): + self.ResetServer(wsgi_util.first_found([APP1, APP2])) + + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(six.moves.http_client.OK, status) + self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) + self.assertEquals('App1', content) + self.assertEquals({'content-length': '4', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testFirstNotFound(self): + self.ResetServer(wsgi_util.first_found([NOT_FOUND, APP2])) + + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(six.moves.http_client.OK, status) + self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) + self.assertEquals('App2', content) + self.assertEquals({'content-length': '4', + 'content-type': 'text/html; charset=utf-8', + }, + headers) + + def testOnlyNotFound(self): + def current_error(environ, start_response): + """The variable current_status is defined in loop after ResetServer.""" + headers = [('content-type', 'text/plain')] + status_line = '%03d Whatever' % current_status + start_response(status_line, headers) + return [] + + self.ResetServer(wsgi_util.first_found([current_error, APP2])) + + statuses_to_check = sorted(httplib.responses.keys()) + # 100, 204 and 304 have slightly different expectations, so they are left + # out of this test in order to keep the code simple. + for dont_check in (100, 200, 204, 304, 404): + statuses_to_check.remove(dont_check) + for current_status in statuses_to_check: + status, status_text, content, headers = self.DoHttpRequest('/') + self.assertEquals(current_status, status) + self.assertEquals('Whatever', status_text) + + +if __name__ == '__main__': + unittest.main() From 3bb0396cab2c2998482ae6c390d1afb001a8d881 Mon Sep 17 00:00:00 2001 From: Rose Davidson Date: Thu, 26 Jul 2018 16:56:19 -0700 Subject: [PATCH 2/6] Remove included protorpc tests. --- .../internal/protorpc/definition_test.py | 657 ----- .../internal/protorpc/descriptor_test.py | 649 ----- endpoints/internal/protorpc/end2end_test.py | 148 -- .../internal/protorpc/generate_proto_test.py | 197 -- .../internal/protorpc/generate_python_test.py | 362 --- endpoints/internal/protorpc/generate_test.py | 152 -- .../internal/protorpc/message_types_test.py | 88 - endpoints/internal/protorpc/messages_test.py | 2109 ----------------- endpoints/internal/protorpc/protobuf_test.py | 299 --- endpoints/internal/protorpc/protojson_test.py | 565 ----- .../internal/protorpc/protorpc_test_pb2.py | 405 ---- .../internal/protorpc/protourlencode_test.py | 369 --- endpoints/internal/protorpc/registry_test.py | 124 - endpoints/internal/protorpc/remote_test.py | 933 -------- endpoints/internal/protorpc/test_util.py | 671 ------ endpoints/internal/protorpc/transport_test.py | 493 ---- endpoints/internal/protorpc/util_test.py | 394 --- .../internal/protorpc/webapp/forms_test.py | 159 -- .../protorpc/webapp/service_handlers_test.py | 1332 ----------- .../internal/protorpc/webapp_test_util.py | 411 ---- .../internal/protorpc/wsgi/service_test.py | 205 -- endpoints/internal/protorpc/wsgi/util_test.py | 295 --- 22 files changed, 11017 deletions(-) delete mode 100644 endpoints/internal/protorpc/definition_test.py delete mode 100644 endpoints/internal/protorpc/descriptor_test.py delete mode 100644 endpoints/internal/protorpc/end2end_test.py delete mode 100644 endpoints/internal/protorpc/generate_proto_test.py delete mode 100644 endpoints/internal/protorpc/generate_python_test.py delete mode 100644 endpoints/internal/protorpc/generate_test.py delete mode 100644 endpoints/internal/protorpc/message_types_test.py delete mode 100644 endpoints/internal/protorpc/messages_test.py delete mode 100644 endpoints/internal/protorpc/protobuf_test.py delete mode 100644 endpoints/internal/protorpc/protojson_test.py delete mode 100644 endpoints/internal/protorpc/protorpc_test_pb2.py delete mode 100644 endpoints/internal/protorpc/protourlencode_test.py delete mode 100644 endpoints/internal/protorpc/registry_test.py delete mode 100644 endpoints/internal/protorpc/remote_test.py delete mode 100644 endpoints/internal/protorpc/test_util.py delete mode 100644 endpoints/internal/protorpc/transport_test.py delete mode 100644 endpoints/internal/protorpc/util_test.py delete mode 100644 endpoints/internal/protorpc/webapp/forms_test.py delete mode 100644 endpoints/internal/protorpc/webapp/service_handlers_test.py delete mode 100644 endpoints/internal/protorpc/webapp_test_util.py delete mode 100644 endpoints/internal/protorpc/wsgi/service_test.py delete mode 100644 endpoints/internal/protorpc/wsgi/util_test.py diff --git a/endpoints/internal/protorpc/definition_test.py b/endpoints/internal/protorpc/definition_test.py deleted file mode 100644 index 992220e..0000000 --- a/endpoints/internal/protorpc/definition_test.py +++ /dev/null @@ -1,657 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.stub.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import StringIO -import sys -import types -import unittest - -from protorpc import definition -from protorpc import descriptor -from protorpc import message_types -from protorpc import messages -from protorpc import protobuf -from protorpc import remote -from protorpc import test_util - -import mox - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = definition - - -class DefineEnumTest(test_util.TestCase): - """Test for define_enum.""" - - def testDefineEnum_Empty(self): - """Test defining an empty enum.""" - enum_descriptor = descriptor.EnumDescriptor() - enum_descriptor.name = 'Empty' - - enum_class = definition.define_enum(enum_descriptor, 'whatever') - - self.assertEquals('Empty', enum_class.__name__) - self.assertEquals('whatever', enum_class.__module__) - - self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) - - def testDefineEnum(self): - """Test defining an enum.""" - red = descriptor.EnumValueDescriptor() - green = descriptor.EnumValueDescriptor() - blue = descriptor.EnumValueDescriptor() - - red.name = 'RED' - red.number = 1 - green.name = 'GREEN' - green.number = 2 - blue.name = 'BLUE' - blue.number = 3 - - enum_descriptor = descriptor.EnumDescriptor() - enum_descriptor.name = 'Colors' - enum_descriptor.values = [red, green, blue] - - enum_class = definition.define_enum(enum_descriptor, 'whatever') - - self.assertEquals('Colors', enum_class.__name__) - self.assertEquals('whatever', enum_class.__module__) - - self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) - - -class DefineFieldTest(test_util.TestCase): - """Test for define_field.""" - - def testDefineField_Optional(self): - """Test defining an optional field instance from a method descriptor.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT32 - field_descriptor.label = descriptor.FieldDescriptor.Label.OPTIONAL - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.IntegerField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.INT32, field.variant) - self.assertFalse(field.required) - self.assertFalse(field.repeated) - - def testDefineField_Required(self): - """Test defining a required field instance from a method descriptor.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING - field_descriptor.label = descriptor.FieldDescriptor.Label.REQUIRED - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.StringField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) - self.assertTrue(field.required) - self.assertFalse(field.repeated) - - def testDefineField_Repeated(self): - """Test defining a repeated field instance from a method descriptor.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.DOUBLE - field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.FloatField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.DOUBLE, field.variant) - self.assertFalse(field.required) - self.assertTrue(field.repeated) - - def testDefineField_Message(self): - """Test defining a message field.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE - field_descriptor.type_name = 'something.yet.to.be.Defined' - field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.MessageField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) - self.assertFalse(field.required) - self.assertTrue(field.repeated) - self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, - 'Could not find definition for ' - 'something.yet.to.be.Defined', - getattr, field, 'type') - - def testDefineField_DateTime(self): - """Test defining a date time field.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_timestamp' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE - field_descriptor.type_name = 'protorpc.message_types.DateTimeMessage' - field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, message_types.DateTimeField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) - self.assertFalse(field.required) - self.assertTrue(field.repeated) - - def testDefineField_Enum(self): - """Test defining an enum field.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.ENUM - field_descriptor.type_name = 'something.yet.to.be.Defined' - field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.EnumField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.ENUM, field.variant) - self.assertFalse(field.required) - self.assertTrue(field.repeated) - self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, - 'Could not find definition for ' - 'something.yet.to.be.Defined', - getattr, field, 'type') - - def testDefineField_Default_Bool(self): - """Test defining a default value for a bool.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.BOOL - field_descriptor.default_value = u'true' - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.BooleanField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.BOOL, field.variant) - self.assertFalse(field.required) - self.assertFalse(field.repeated) - self.assertEqual(field.default, True) - - field_descriptor.default_value = u'false' - - field = definition.define_field(field_descriptor) - - self.assertEqual(field.default, False) - - def testDefineField_Default_Float(self): - """Test defining a default value for a float.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.FLOAT - field_descriptor.default_value = u'34.567' - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.FloatField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.FLOAT, field.variant) - self.assertFalse(field.required) - self.assertFalse(field.repeated) - self.assertEqual(field.default, 34.567) - - def testDefineField_Default_Int(self): - """Test defining a default value for an int.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 - field_descriptor.default_value = u'34' - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.IntegerField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.INT64, field.variant) - self.assertFalse(field.required) - self.assertFalse(field.repeated) - self.assertEqual(field.default, 34) - - def testDefineField_Default_Str(self): - """Test defining a default value for a str.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING - field_descriptor.default_value = u'Test' - - field = definition.define_field(field_descriptor) - - # Name will not be set from the original descriptor. - self.assertFalse(hasattr(field, 'name')) - - self.assertTrue(isinstance(field, messages.StringField)) - self.assertEquals(1, field.number) - self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) - self.assertFalse(field.required) - self.assertFalse(field.repeated) - self.assertEqual(field.default, u'Test') - - def testDefineField_Default_Invalid(self): - """Test defining a default value that is not valid.""" - field_descriptor = descriptor.FieldDescriptor() - - field_descriptor.name = 'a_field' - field_descriptor.number = 1 - field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 - field_descriptor.default_value = u'Test' - - # Verify that the string is passed to the Constructor. - mock = mox.Mox() - mock.StubOutWithMock(messages.IntegerField, '__init__') - messages.IntegerField.__init__( - default=u'Test', - number=1, - variant=messages.Variant.INT64 - ).AndRaise(messages.InvalidDefaultError) - - mock.ReplayAll() - self.assertRaises(messages.InvalidDefaultError, - definition.define_field, field_descriptor) - mock.VerifyAll() - - mock.ResetAll() - mock.UnsetStubs() - - -class DefineMessageTest(test_util.TestCase): - """Test for define_message.""" - - def testDefineMessageEmpty(self): - """Test definition a message with no fields or enums.""" - - class AMessage(messages.Message): - pass - - message_descriptor = descriptor.describe_message(AMessage) - - message_class = definition.define_message(message_descriptor, '__main__') - - self.assertEquals('AMessage', message_class.__name__) - self.assertEquals('__main__', message_class.__module__) - - self.assertEquals(message_descriptor, - descriptor.describe_message(message_class)) - - def testDefineMessageEnumOnly(self): - """Test definition a message with only enums.""" - - class AMessage(messages.Message): - class NestedEnum(messages.Enum): - pass - - message_descriptor = descriptor.describe_message(AMessage) - - message_class = definition.define_message(message_descriptor, '__main__') - - self.assertEquals('AMessage', message_class.__name__) - self.assertEquals('__main__', message_class.__module__) - - self.assertEquals(message_descriptor, - descriptor.describe_message(message_class)) - - def testDefineMessageFieldsOnly(self): - """Test definition a message with only fields.""" - - class AMessage(messages.Message): - - field1 = messages.IntegerField(1) - field2 = messages.StringField(2) - - message_descriptor = descriptor.describe_message(AMessage) - - message_class = definition.define_message(message_descriptor, '__main__') - - self.assertEquals('AMessage', message_class.__name__) - self.assertEquals('__main__', message_class.__module__) - - self.assertEquals(message_descriptor, - descriptor.describe_message(message_class)) - - def testDefineMessage(self): - """Test defining Message class from descriptor.""" - - class AMessage(messages.Message): - class NestedEnum(messages.Enum): - pass - - field1 = messages.IntegerField(1) - field2 = messages.StringField(2) - - message_descriptor = descriptor.describe_message(AMessage) - - message_class = definition.define_message(message_descriptor, '__main__') - - self.assertEquals('AMessage', message_class.__name__) - self.assertEquals('__main__', message_class.__module__) - - self.assertEquals(message_descriptor, - descriptor.describe_message(message_class)) - - -class DefineServiceTest(test_util.TestCase): - """Test service proxy definition.""" - - def setUp(self): - """Set up mock and request classes.""" - self.module = types.ModuleType('stocks') - - class GetQuoteRequest(messages.Message): - __module__ = 'stocks' - - symbols = messages.StringField(1, repeated=True) - - class GetQuoteResponse(messages.Message): - __module__ = 'stocks' - - prices = messages.IntegerField(1, repeated=True) - - self.module.GetQuoteRequest = GetQuoteRequest - self.module.GetQuoteResponse = GetQuoteResponse - - def testDefineService(self): - """Test service definition from descriptor.""" - method_descriptor = descriptor.MethodDescriptor() - method_descriptor.name = 'get_quote' - method_descriptor.request_type = 'GetQuoteRequest' - method_descriptor.response_type = 'GetQuoteResponse' - - service_descriptor = descriptor.ServiceDescriptor() - service_descriptor.name = 'Stocks' - service_descriptor.methods = [method_descriptor] - - StockService = definition.define_service(service_descriptor, self.module) - - self.assertTrue(issubclass(StockService, remote.Service)) - self.assertTrue(issubclass(StockService.Stub, remote.StubBase)) - - request = self.module.GetQuoteRequest() - service = StockService() - self.assertRaises(NotImplementedError, - service.get_quote, request) - - self.assertEquals(self.module.GetQuoteRequest, - service.get_quote.remote.request_type) - self.assertEquals(self.module.GetQuoteResponse, - service.get_quote.remote.response_type) - - -class ModuleTest(test_util.TestCase): - """Test for module creation and importation functions.""" - - def MakeFileDescriptor(self, package): - """Helper method to construct FileDescriptors. - - Creates FileDescriptor with a MessageDescriptor and an EnumDescriptor. - - Args: - package: Package name to give new file descriptors. - - Returns: - New FileDescriptor instance. - """ - enum_descriptor = descriptor.EnumDescriptor() - enum_descriptor.name = u'MyEnum' - - message_descriptor = descriptor.MessageDescriptor() - message_descriptor.name = u'MyMessage' - - service_descriptor = descriptor.ServiceDescriptor() - service_descriptor.name = u'MyService' - - file_descriptor = descriptor.FileDescriptor() - file_descriptor.package = package - file_descriptor.enum_types = [enum_descriptor] - file_descriptor.message_types = [message_descriptor] - file_descriptor.service_types = [service_descriptor] - - return file_descriptor - - def testDefineModule(self): - """Test define_module function.""" - file_descriptor = self.MakeFileDescriptor('my.package') - - module = definition.define_file(file_descriptor) - - self.assertEquals('my.package', module.__name__) - self.assertEquals('my.package', module.MyEnum.__module__) - self.assertEquals('my.package', module.MyMessage.__module__) - self.assertEquals('my.package', module.MyService.__module__) - - self.assertEquals(file_descriptor, descriptor.describe_file(module)) - - def testDefineModule_ReuseModule(self): - """Test updating module with additional definitions.""" - file_descriptor = self.MakeFileDescriptor('my.package') - - module = types.ModuleType('override') - self.assertEquals(module, definition.define_file(file_descriptor, module)) - - self.assertEquals('override', module.MyEnum.__module__) - self.assertEquals('override', module.MyMessage.__module__) - self.assertEquals('override', module.MyService.__module__) - - # One thing is different between original descriptor and new. - file_descriptor.package = 'override' - self.assertEquals(file_descriptor, descriptor.describe_file(module)) - - def testImportFile(self): - """Test importing FileDescriptor in to module space.""" - modules = {} - file_descriptor = self.MakeFileDescriptor('standalone') - definition.import_file(file_descriptor, modules=modules) - self.assertEquals(file_descriptor, - descriptor.describe_file(modules['standalone'])) - - def testImportFile_InToExisting(self): - """Test importing FileDescriptor in to existing module.""" - module = types.ModuleType('standalone') - modules = {'standalone': module} - file_descriptor = self.MakeFileDescriptor('standalone') - definition.import_file(file_descriptor, modules=modules) - self.assertEquals(module, modules['standalone']) - self.assertEquals(file_descriptor, - descriptor.describe_file(modules['standalone'])) - - def testImportFile_InToGlobalModules(self): - """Test importing FileDescriptor in to global modules.""" - original_modules = sys.modules - try: - sys.modules = dict(sys.modules) - if 'standalone' in sys.modules: - del sys.modules['standalone'] - file_descriptor = self.MakeFileDescriptor('standalone') - definition.import_file(file_descriptor) - self.assertEquals(file_descriptor, - descriptor.describe_file(sys.modules['standalone'])) - finally: - sys.modules = original_modules - - def testImportFile_Nested(self): - """Test importing FileDescriptor in to existing nested module.""" - modules = {} - file_descriptor = self.MakeFileDescriptor('root.nested') - definition.import_file(file_descriptor, modules=modules) - self.assertEquals(modules['root'].nested, modules['root.nested']) - self.assertEquals(file_descriptor, - descriptor.describe_file(modules['root.nested'])) - - def testImportFile_NoPackage(self): - """Test importing FileDescriptor with no package.""" - file_descriptor = self.MakeFileDescriptor('does not matter') - file_descriptor.reset('package') - self.assertRaisesWithRegexpMatch(ValueError, - 'File descriptor must have package name', - definition.import_file, - file_descriptor) - - def testImportFileSet(self): - """Test importing a whole file set.""" - file_set = descriptor.FileSet() - file_set.files = [self.MakeFileDescriptor(u'standalone'), - self.MakeFileDescriptor(u'root.nested'), - self.MakeFileDescriptor(u'root.nested.nested'), - ] - - root = types.ModuleType('root') - nested = types.ModuleType('root.nested') - root.nested = nested - modules = { - 'root': root, - 'root.nested': nested, - } - - definition.import_file_set(file_set, modules=modules) - - self.assertEquals(root, modules['root']) - self.assertEquals(nested, modules['root.nested']) - self.assertEquals(nested.nested, modules['root.nested.nested']) - - self.assertEquals(file_set, - descriptor.describe_file_set( - [modules['standalone'], - modules['root.nested'], - modules['root.nested.nested'], - ])) - - def testImportFileSetFromFile(self): - """Test importing a whole file set from a file.""" - file_set = descriptor.FileSet() - file_set.files = [self.MakeFileDescriptor(u'standalone'), - self.MakeFileDescriptor(u'root.nested'), - self.MakeFileDescriptor(u'root.nested.nested'), - ] - - stream = StringIO.StringIO(protobuf.encode_message(file_set)) - - self.mox = mox.Mox() - opener = self.mox.CreateMockAnything() - opener('my-file.dat', 'rb').AndReturn(stream) - - self.mox.ReplayAll() - - modules = {} - definition.import_file_set('my-file.dat', modules=modules, _open=opener) - - self.assertEquals(file_set, - descriptor.describe_file_set( - [modules['standalone'], - modules['root.nested'], - modules['root.nested.nested'], - ])) - - def testImportBuiltInProtorpcClasses(self): - """Test that built in Protorpc classes are skipped.""" - file_set = descriptor.FileSet() - file_set.files = [self.MakeFileDescriptor(u'standalone'), - self.MakeFileDescriptor(u'root.nested'), - self.MakeFileDescriptor(u'root.nested.nested'), - descriptor.describe_file(descriptor), - ] - - root = types.ModuleType('root') - nested = types.ModuleType('root.nested') - root.nested = nested - modules = { - 'root': root, - 'root.nested': nested, - 'protorpc.descriptor': descriptor, - } - - definition.import_file_set(file_set, modules=modules) - - self.assertEquals(root, modules['root']) - self.assertEquals(nested, modules['root.nested']) - self.assertEquals(nested.nested, modules['root.nested.nested']) - self.assertEquals(descriptor, modules['protorpc.descriptor']) - - self.assertEquals(file_set, - descriptor.describe_file_set( - [modules['standalone'], - modules['root.nested'], - modules['root.nested.nested'], - modules['protorpc.descriptor'], - ])) - - -if __name__ == '__main__': - unittest.main() diff --git a/endpoints/internal/protorpc/descriptor_test.py b/endpoints/internal/protorpc/descriptor_test.py deleted file mode 100644 index 5047e8e..0000000 --- a/endpoints/internal/protorpc/descriptor_test.py +++ /dev/null @@ -1,649 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.descriptor.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import types -import unittest - -from protorpc import descriptor -from protorpc import message_types -from protorpc import messages -from protorpc import registry -from protorpc import remote -from protorpc import test_util - - -RUSSIA = u'\u0420\u043e\u0441\u0441\u0438\u044f' - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = descriptor - - -class DescribeEnumValueTest(test_util.TestCase): - - def testDescribe(self): - class MyEnum(messages.Enum): - MY_NAME = 10 - - expected = descriptor.EnumValueDescriptor() - expected.name = 'MY_NAME' - expected.number = 10 - - described = descriptor.describe_enum_value(MyEnum.MY_NAME) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeEnumTest(test_util.TestCase): - - def testEmptyEnum(self): - class EmptyEnum(messages.Enum): - pass - - expected = descriptor.EnumDescriptor() - expected.name = 'EmptyEnum' - - described = descriptor.describe_enum(EmptyEnum) - described.check_initialized() - self.assertEquals(expected, described) - - def testNestedEnum(self): - class MyScope(messages.Message): - class NestedEnum(messages.Enum): - pass - - expected = descriptor.EnumDescriptor() - expected.name = 'NestedEnum' - - described = descriptor.describe_enum(MyScope.NestedEnum) - described.check_initialized() - self.assertEquals(expected, described) - - def testEnumWithItems(self): - class EnumWithItems(messages.Enum): - A = 3 - B = 1 - C = 2 - - expected = descriptor.EnumDescriptor() - expected.name = 'EnumWithItems' - - a = descriptor.EnumValueDescriptor() - a.name = 'A' - a.number = 3 - - b = descriptor.EnumValueDescriptor() - b.name = 'B' - b.number = 1 - - c = descriptor.EnumValueDescriptor() - c.name = 'C' - c.number = 2 - - expected.values = [b, c, a] - - described = descriptor.describe_enum(EnumWithItems) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeFieldTest(test_util.TestCase): - - def testLabel(self): - for repeated, required, expected_label in ( - (True, False, descriptor.FieldDescriptor.Label.REPEATED), - (False, True, descriptor.FieldDescriptor.Label.REQUIRED), - (False, False, descriptor.FieldDescriptor.Label.OPTIONAL)): - field = messages.IntegerField(10, required=required, repeated=repeated) - field.name = 'a_field' - - expected = descriptor.FieldDescriptor() - expected.name = 'a_field' - expected.number = 10 - expected.label = expected_label - expected.variant = descriptor.FieldDescriptor.Variant.INT64 - - described = descriptor.describe_field(field) - described.check_initialized() - self.assertEquals(expected, described) - - def testDefault(self): - for field_class, default, expected_default in ( - (messages.IntegerField, 200, '200'), - (messages.FloatField, 1.5, '1.5'), - (messages.FloatField, 1e6, '1000000.0'), - (messages.BooleanField, True, 'true'), - (messages.BooleanField, False, 'false'), - (messages.BytesField, 'ab\xF1', 'ab\\xf1'), - (messages.StringField, RUSSIA, RUSSIA), - ): - field = field_class(10, default=default) - field.name = u'a_field' - - expected = descriptor.FieldDescriptor() - expected.name = u'a_field' - expected.number = 10 - expected.label = descriptor.FieldDescriptor.Label.OPTIONAL - expected.variant = field_class.DEFAULT_VARIANT - expected.default_value = expected_default - - described = descriptor.describe_field(field) - described.check_initialized() - self.assertEquals(expected, described) - - def testDefault_EnumField(self): - class MyEnum(messages.Enum): - - VAL = 1 - - module_name = test_util.get_module_name(MyEnum) - field = messages.EnumField(MyEnum, 10, default=MyEnum.VAL) - field.name = 'a_field' - - expected = descriptor.FieldDescriptor() - expected.name = 'a_field' - expected.number = 10 - expected.label = descriptor.FieldDescriptor.Label.OPTIONAL - expected.variant = messages.EnumField.DEFAULT_VARIANT - expected.type_name = '%s.MyEnum' % module_name - expected.default_value = '1' - - described = descriptor.describe_field(field) - self.assertEquals(expected, described) - - def testMessageField(self): - field = messages.MessageField(descriptor.FieldDescriptor, 10) - field.name = 'a_field' - - expected = descriptor.FieldDescriptor() - expected.name = 'a_field' - expected.number = 10 - expected.label = descriptor.FieldDescriptor.Label.OPTIONAL - expected.variant = messages.MessageField.DEFAULT_VARIANT - expected.type_name = ('protorpc.descriptor.FieldDescriptor') - - described = descriptor.describe_field(field) - described.check_initialized() - self.assertEquals(expected, described) - - def testDateTimeField(self): - field = message_types.DateTimeField(20) - field.name = 'a_timestamp' - - expected = descriptor.FieldDescriptor() - expected.name = 'a_timestamp' - expected.number = 20 - expected.label = descriptor.FieldDescriptor.Label.OPTIONAL - expected.variant = messages.MessageField.DEFAULT_VARIANT - expected.type_name = ('protorpc.message_types.DateTimeMessage') - - described = descriptor.describe_field(field) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeMessageTest(test_util.TestCase): - - def testEmptyDefinition(self): - class MyMessage(messages.Message): - pass - - expected = descriptor.MessageDescriptor() - expected.name = 'MyMessage' - - described = descriptor.describe_message(MyMessage) - described.check_initialized() - self.assertEquals(expected, described) - - def testDefinitionWithFields(self): - class MessageWithFields(messages.Message): - field1 = messages.IntegerField(10) - field2 = messages.StringField(30) - field3 = messages.IntegerField(20) - - expected = descriptor.MessageDescriptor() - expected.name = 'MessageWithFields' - - expected.fields = [ - descriptor.describe_field(MessageWithFields.field_by_name('field1')), - descriptor.describe_field(MessageWithFields.field_by_name('field3')), - descriptor.describe_field(MessageWithFields.field_by_name('field2')), - ] - - described = descriptor.describe_message(MessageWithFields) - described.check_initialized() - self.assertEquals(expected, described) - - def testNestedEnum(self): - class MessageWithEnum(messages.Message): - class Mood(messages.Enum): - GOOD = 1 - BAD = 2 - UGLY = 3 - - class Music(messages.Enum): - CLASSIC = 1 - JAZZ = 2 - BLUES = 3 - - expected = descriptor.MessageDescriptor() - expected.name = 'MessageWithEnum' - - expected.enum_types = [descriptor.describe_enum(MessageWithEnum.Mood), - descriptor.describe_enum(MessageWithEnum.Music)] - - described = descriptor.describe_message(MessageWithEnum) - described.check_initialized() - self.assertEquals(expected, described) - - def testNestedMessage(self): - class MessageWithMessage(messages.Message): - class Nesty(messages.Message): - pass - - expected = descriptor.MessageDescriptor() - expected.name = 'MessageWithMessage' - - expected.message_types = [ - descriptor.describe_message(MessageWithMessage.Nesty)] - - described = descriptor.describe_message(MessageWithMessage) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeMethodTest(test_util.TestCase): - """Test describing remote methods.""" - - def testDescribe(self): - class Request(messages.Message): - pass - - class Response(messages.Message): - pass - - @remote.method(Request, Response) - def remote_method(request): - pass - - module_name = test_util.get_module_name(DescribeMethodTest) - expected = descriptor.MethodDescriptor() - expected.name = 'remote_method' - expected.request_type = '%s.Request' % module_name - expected.response_type = '%s.Response' % module_name - - described = descriptor.describe_method(remote_method) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeServiceTest(test_util.TestCase): - """Test describing service classes.""" - - def testDescribe(self): - class Request1(messages.Message): - pass - - class Response1(messages.Message): - pass - - class Request2(messages.Message): - pass - - class Response2(messages.Message): - pass - - class MyService(remote.Service): - - @remote.method(Request1, Response1) - def method1(self, request): - pass - - @remote.method(Request2, Response2) - def method2(self, request): - pass - - expected = descriptor.ServiceDescriptor() - expected.name = 'MyService' - expected.methods = [] - - expected.methods.append(descriptor.describe_method(MyService.method1)) - expected.methods.append(descriptor.describe_method(MyService.method2)) - - described = descriptor.describe_service(MyService) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeFileTest(test_util.TestCase): - """Test describing modules.""" - - def LoadModule(self, module_name, source): - result = {'__name__': module_name, - 'messages': messages, - 'remote': remote, - } - exec(source, result) - - module = types.ModuleType(module_name) - for name, value in result.items(): - setattr(module, name, value) - - return module - - def testEmptyModule(self): - """Test describing an empty file.""" - module = types.ModuleType('my.package.name') - - expected = descriptor.FileDescriptor() - expected.package = 'my.package.name' - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - def testNoPackageName(self): - """Test describing a module with no module name.""" - module = types.ModuleType('') - - expected = descriptor.FileDescriptor() - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - def testPackageName(self): - """Test using the 'package' module attribute.""" - module = types.ModuleType('my.module.name') - module.package = 'my.package.name' - - expected = descriptor.FileDescriptor() - expected.package = 'my.package.name' - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - def testMain(self): - """Test using the 'package' module attribute.""" - module = types.ModuleType('__main__') - module.__file__ = '/blim/blam/bloom/my_package.py' - - expected = descriptor.FileDescriptor() - expected.package = 'my_package' - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - def testMessages(self): - """Test that messages are described.""" - module = self.LoadModule('my.package', - 'class Message1(messages.Message): pass\n' - 'class Message2(messages.Message): pass\n') - - message1 = descriptor.MessageDescriptor() - message1.name = 'Message1' - - message2 = descriptor.MessageDescriptor() - message2.name = 'Message2' - - expected = descriptor.FileDescriptor() - expected.package = 'my.package' - expected.message_types = [message1, message2] - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - def testEnums(self): - """Test that enums are described.""" - module = self.LoadModule('my.package', - 'class Enum1(messages.Enum): pass\n' - 'class Enum2(messages.Enum): pass\n') - - enum1 = descriptor.EnumDescriptor() - enum1.name = 'Enum1' - - enum2 = descriptor.EnumDescriptor() - enum2.name = 'Enum2' - - expected = descriptor.FileDescriptor() - expected.package = 'my.package' - expected.enum_types = [enum1, enum2] - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - def testServices(self): - """Test that services are described.""" - module = self.LoadModule('my.package', - 'class Service1(remote.Service): pass\n' - 'class Service2(remote.Service): pass\n') - - service1 = descriptor.ServiceDescriptor() - service1.name = 'Service1' - - service2 = descriptor.ServiceDescriptor() - service2.name = 'Service2' - - expected = descriptor.FileDescriptor() - expected.package = 'my.package' - expected.service_types = [service1, service2] - - described = descriptor.describe_file(module) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeFileSetTest(test_util.TestCase): - """Test describing multiple modules.""" - - def testNoModules(self): - """Test what happens when no modules provided.""" - described = descriptor.describe_file_set([]) - described.check_initialized() - # The described FileSet.files will be None. - self.assertEquals(descriptor.FileSet(), described) - - def testWithModules(self): - """Test what happens when no modules provided.""" - modules = [types.ModuleType('package1'), types.ModuleType('package1')] - - file1 = descriptor.FileDescriptor() - file1.package = 'package1' - file2 = descriptor.FileDescriptor() - file2.package = 'package2' - - expected = descriptor.FileSet() - expected.files = [file1, file1] - - described = descriptor.describe_file_set(modules) - described.check_initialized() - self.assertEquals(expected, described) - - -class DescribeTest(test_util.TestCase): - - def testModule(self): - self.assertEquals(descriptor.describe_file(test_util), - descriptor.describe(test_util)) - - def testMethod(self): - class Param(messages.Message): - pass - - class Service(remote.Service): - - @remote.method(Param, Param) - def fn(self): - return Param() - - self.assertEquals(descriptor.describe_method(Service.fn), - descriptor.describe(Service.fn)) - - def testField(self): - self.assertEquals( - descriptor.describe_field(test_util.NestedMessage.a_value), - descriptor.describe(test_util.NestedMessage.a_value)) - - def testEnumValue(self): - self.assertEquals( - descriptor.describe_enum_value( - test_util.OptionalMessage.SimpleEnum.VAL1), - descriptor.describe(test_util.OptionalMessage.SimpleEnum.VAL1)) - - def testMessage(self): - self.assertEquals(descriptor.describe_message(test_util.NestedMessage), - descriptor.describe(test_util.NestedMessage)) - - def testEnum(self): - self.assertEquals( - descriptor.describe_enum(test_util.OptionalMessage.SimpleEnum), - descriptor.describe(test_util.OptionalMessage.SimpleEnum)) - - def testService(self): - class Service(remote.Service): - pass - - self.assertEquals(descriptor.describe_service(Service), - descriptor.describe(Service)) - - def testService(self): - class Service(remote.Service): - pass - - self.assertEquals(descriptor.describe_service(Service), - descriptor.describe(Service)) - - def testUndescribable(self): - class NonService(object): - - def fn(self): - pass - - for value in (NonService, - NonService.fn, - 1, - 'string', - 1.2, - None): - self.assertEquals(None, descriptor.describe(value)) - - -class ModuleFinderTest(test_util.TestCase): - - def testFindModule(self): - self.assertEquals(descriptor.describe_file(registry), - descriptor.import_descriptor_loader('protorpc.registry')) - - def testFindMessage(self): - self.assertEquals( - descriptor.describe_message(descriptor.FileSet), - descriptor.import_descriptor_loader('protorpc.descriptor.FileSet')) - - def testFindField(self): - self.assertEquals( - descriptor.describe_field(descriptor.FileSet.files), - descriptor.import_descriptor_loader('protorpc.descriptor.FileSet.files')) - - def testFindEnumValue(self): - self.assertEquals( - descriptor.describe_enum_value(test_util.OptionalMessage.SimpleEnum.VAL1), - descriptor.import_descriptor_loader( - 'protorpc.test_util.OptionalMessage.SimpleEnum.VAL1')) - - def testFindMethod(self): - self.assertEquals( - descriptor.describe_method(registry.RegistryService.services), - descriptor.import_descriptor_loader( - 'protorpc.registry.RegistryService.services')) - - def testFindService(self): - self.assertEquals( - descriptor.describe_service(registry.RegistryService), - descriptor.import_descriptor_loader('protorpc.registry.RegistryService')) - - def testFindWithAbsoluteName(self): - self.assertEquals( - descriptor.describe_service(registry.RegistryService), - descriptor.import_descriptor_loader('.protorpc.registry.RegistryService')) - - def testFindWrongThings(self): - for name in ('a', 'protorpc.registry.RegistryService.__init__', '', ): - self.assertRaisesWithRegexpMatch( - messages.DefinitionNotFoundError, - 'Could not find definition for %s' % name, - descriptor.import_descriptor_loader, name) - - -class DescriptorLibraryTest(test_util.TestCase): - - def setUp(self): - self.packageless = descriptor.MessageDescriptor() - self.packageless.name = 'Packageless' - self.library = descriptor.DescriptorLibrary( - descriptors={ - 'not.real.Packageless': self.packageless, - 'Packageless': self.packageless, - }) - - def testLookupPackage(self): - self.assertEquals('csv', self.library.lookup_package('csv')) - self.assertEquals('protorpc', self.library.lookup_package('protorpc')) - self.assertEquals('protorpc.registry', - self.library.lookup_package('protorpc.registry')) - self.assertEquals('protorpc.registry', - self.library.lookup_package('.protorpc.registry')) - self.assertEquals( - 'protorpc.registry', - self.library.lookup_package('protorpc.registry.RegistryService')) - self.assertEquals( - 'protorpc.registry', - self.library.lookup_package( - 'protorpc.registry.RegistryService.services')) - - def testLookupNonPackages(self): - for name in ('', 'a', 'protorpc.descriptor.DescriptorLibrary'): - self.assertRaisesWithRegexpMatch( - messages.DefinitionNotFoundError, - 'Could not find definition for %s' % name, - self.library.lookup_package, name) - - def testNoPackage(self): - self.assertRaisesWithRegexpMatch( - messages.DefinitionNotFoundError, - 'Could not find definition for not.real', - self.library.lookup_package, 'not.real.Packageless') - - self.assertEquals(None, self.library.lookup_package('Packageless')) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/end2end_test.py b/endpoints/internal/protorpc/end2end_test.py deleted file mode 100644 index c3e0141..0000000 --- a/endpoints/internal/protorpc/end2end_test.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2011 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""End to end tests for ProtoRPC.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import unittest - -from protorpc import protojson -from protorpc import remote -from protorpc import test_util -from protorpc import util -from protorpc import webapp_test_util - -package = 'test_package' - - -class EndToEndTest(webapp_test_util.EndToEndTestBase): - - def testSimpleRequest(self): - self.assertEquals(test_util.OptionalMessage(string_value='+blar'), - self.stub.optional_message(string_value='blar')) - - def testSimpleRequestComplexContentType(self): - response = self.DoRawRequest( - 'optional_message', - content='{"string_value": "blar"}', - content_type='application/json; charset=utf-8') - headers = response.headers - self.assertEquals(200, response.code) - self.assertEquals('{"string_value": "+blar"}', response.read()) - self.assertEquals('application/json', headers['content-type']) - - def testInitParameter(self): - self.assertEquals(test_util.OptionalMessage(string_value='uninitialized'), - self.stub.init_parameter()) - self.assertEquals(test_util.OptionalMessage(string_value='initialized'), - self.other_stub.init_parameter()) - - def testMissingContentType(self): - code, content, headers = self.RawRequestError( - 'optional_message', - content='{"string_value": "blar"}', - content_type='') - self.assertEquals(400, code) - self.assertEquals(util.pad_string('Bad Request'), content) - self.assertEquals('text/plain; charset=utf-8', headers['content-type']) - - def testWrongPath(self): - self.assertRaisesWithRegexpMatch(remote.ServerError, - 'HTTP Error 404: Not Found', - self.bad_path_stub.optional_message) - - def testUnsupportedContentType(self): - code, content, headers = self.RawRequestError( - 'optional_message', - content='{"string_value": "blar"}', - content_type='image/png') - self.assertEquals(415, code) - self.assertEquals(util.pad_string('Unsupported Media Type'), content) - self.assertEquals(headers['content-type'], 'text/plain; charset=utf-8') - - def testUnsupportedHttpMethod(self): - code, content, headers = self.RawRequestError('optional_message') - self.assertEquals(405, code) - self.assertEquals( - util.pad_string('/my/service.optional_message is a ProtoRPC method.\n\n' - 'Service protorpc.webapp_test_util.TestService\n\n' - 'More about ProtoRPC: ' - 'http://code.google.com/p/google-protorpc\n'), - content) - self.assertEquals(headers['content-type'], 'text/plain; charset=utf-8') - - def testMethodNotFound(self): - self.assertRaisesWithRegexpMatch(remote.MethodNotFoundError, - 'Unrecognized RPC method: does_not_exist', - self.mismatched_stub.does_not_exist) - - def testBadMessageError(self): - code, content, headers = self.RawRequestError('nested_message', - content='{}') - self.assertEquals(400, code) - - expected_content = protojson.encode_message(remote.RpcStatus( - state=remote.RpcState.REQUEST_ERROR, - error_message=('Error parsing ProtoRPC request ' - '(Unable to parse request content: ' - 'Message NestedMessage is missing ' - 'required field a_value)'))) - self.assertEquals(util.pad_string(expected_content), content) - self.assertEquals(headers['content-type'], 'application/json') - - def testApplicationError(self): - try: - self.stub.raise_application_error() - except remote.ApplicationError as err: - self.assertEquals('This is an application error', unicode(err)) - self.assertEquals('ERROR_NAME', err.error_name) - else: - self.fail('Expected application error') - - def testRpcError(self): - try: - self.stub.raise_rpc_error() - except remote.ServerError as err: - self.assertEquals('Internal Server Error', unicode(err)) - else: - self.fail('Expected server error') - - def testUnexpectedError(self): - try: - self.stub.raise_unexpected_error() - except remote.ServerError as err: - self.assertEquals('Internal Server Error', unicode(err)) - else: - self.fail('Expected server error') - - def testBadResponse(self): - try: - self.stub.return_bad_message() - except remote.ServerError as err: - self.assertEquals('Internal Server Error', unicode(err)) - else: - self.fail('Expected server error') - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/generate_proto_test.py b/endpoints/internal/protorpc/generate_proto_test.py deleted file mode 100644 index 43469b5..0000000 --- a/endpoints/internal/protorpc/generate_proto_test.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.generate_proto_test.""" - - -import os -import shutil -import cStringIO -import sys -import tempfile -import unittest - -from protorpc import descriptor -from protorpc import generate_proto -from protorpc import test_util -from protorpc import util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = generate_proto - - -class FormatProtoFileTest(test_util.TestCase): - - def setUp(self): - self.file_descriptor = descriptor.FileDescriptor() - self.output = cStringIO.StringIO() - - @property - def result(self): - return self.output.getvalue() - - def MakeMessage(self, name='MyMessage', fields=[]): - message = descriptor.MessageDescriptor() - message.name = name - message.fields = fields - - messages_list = getattr(self.file_descriptor, 'fields', []) - messages_list.append(message) - self.file_descriptor.message_types = messages_list - - def testBlankPackage(self): - self.file_descriptor.package = None - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('', self.result) - - def testEmptyPackage(self): - self.file_descriptor.package = 'my_package' - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('package my_package;\n', self.result) - - def testSingleField(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.INT64 - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - ' optional int64 integer_field = 1;\n' - '}\n', - self.result) - - def testSingleFieldWithDefault(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.INT64 - field.default_value = '10' - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - ' optional int64 integer_field = 1 [default=10];\n' - '}\n', - self.result) - - def testRepeatedFieldWithDefault(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.REPEATED - field.variant = descriptor.FieldDescriptor.Variant.INT64 - field.default_value = '[10, 20]' - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - ' repeated int64 integer_field = 1;\n' - '}\n', - self.result) - - def testSingleFieldWithDefaultString(self): - field = descriptor.FieldDescriptor() - field.name = 'string_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.STRING - field.default_value = 'hello' - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - " optional string string_field = 1 [default='hello'];\n" - '}\n', - self.result) - - def testSingleFieldWithDefaultEmptyString(self): - field = descriptor.FieldDescriptor() - field.name = 'string_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.STRING - field.default_value = '' - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - " optional string string_field = 1 [default=''];\n" - '}\n', - self.result) - - def testSingleFieldWithDefaultMessage(self): - field = descriptor.FieldDescriptor() - field.name = 'message_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.MESSAGE - field.type_name = 'MyNestedMessage' - field.default_value = 'not valid' - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - " optional MyNestedMessage message_field = 1;\n" - '}\n', - self.result) - - def testSingleFieldWithDefaultEnum(self): - field = descriptor.FieldDescriptor() - field.name = 'enum_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.ENUM - field.type_name = 'my_package.MyEnum' - field.default_value = '17' - - self.MakeMessage(fields=[field]) - - generate_proto.format_proto_file(self.file_descriptor, self.output) - self.assertEquals('\n\n' - 'message MyMessage {\n' - " optional my_package.MyEnum enum_field = 1 " - "[default=17];\n" - '}\n', - self.result) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() - diff --git a/endpoints/internal/protorpc/generate_python_test.py b/endpoints/internal/protorpc/generate_python_test.py deleted file mode 100644 index 21a05cc..0000000 --- a/endpoints/internal/protorpc/generate_python_test.py +++ /dev/null @@ -1,362 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.generate_python_test.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import os -import shutil -import sys -import tempfile -import unittest - -from protorpc import descriptor -from protorpc import generate_python -from protorpc import test_util -from protorpc import util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = generate_python - - -class FormatPythonFileTest(test_util.TestCase): - - def setUp(self): - self.original_path = list(sys.path) - self.original_modules = dict(sys.modules) - sys.path = list(sys.path) - self.file_descriptor = descriptor.FileDescriptor() - - # Create temporary directory and add to Python path so that generated - # Python code can be easily parsed, imported and executed. - self.temp_dir = tempfile.mkdtemp() - sys.path.append(self.temp_dir) - - def tearDown(self): - # Reset path. - sys.path[:] = [] - sys.path.extend(self.original_path) - - # Reset modules. - sys.modules.clear() - sys.modules.update(self.original_modules) - - # Remove temporary directory. - try: - shutil.rmtree(self.temp_dir) - except IOError: - pass - - def DoPythonTest(self, file_descriptor): - """Execute python test based on a FileDescriptor object. - - The full test of the Python code generation is to generate a Python source - code file, import the module and regenerate the FileDescriptor from it. - If the generated FileDescriptor is the same as the original, it means that - the generated source code correctly implements the actual FileDescriptor. - """ - file_name = os.path.join(self.temp_dir, - '%s.py' % (file_descriptor.package or 'blank',)) - source_file = open(file_name, 'wt') - try: - generate_python.format_python_file(file_descriptor, source_file) - finally: - source_file.close() - - module_to_import = file_descriptor.package or 'blank' - module = __import__(module_to_import) - - if not file_descriptor.package: - self.assertFalse(hasattr(module, 'package')) - module.package = '' # Create package name so that comparison will work. - - reloaded_descriptor = descriptor.describe_file(module) - - # Need to sort both message_types fields because document order is never - # Ensured. - # TODO(rafek): Ensure document order. - if reloaded_descriptor.message_types: - reloaded_descriptor.message_types = sorted( - reloaded_descriptor.message_types, key=lambda v: v.name) - - if file_descriptor.message_types: - file_descriptor.message_types = sorted( - file_descriptor.message_types, key=lambda v: v.name) - - self.assertEquals(file_descriptor, reloaded_descriptor) - - @util.positional(2) - def DoMessageTest(self, - field_descriptors, - message_types=None, - enum_types=None): - """Execute message generation test based on FieldDescriptor objects. - - Args: - field_descriptor: List of FieldDescriptor object to generate and test. - message_types: List of other MessageDescriptor objects that the new - Message class depends on. - enum_types: List of EnumDescriptor objects that the new Message class - depends on. - """ - file_descriptor = descriptor.FileDescriptor() - file_descriptor.package = 'my_package' - - message_descriptor = descriptor.MessageDescriptor() - message_descriptor.name = 'MyMessage' - - message_descriptor.fields = list(field_descriptors) - - file_descriptor.message_types = message_types or [] - file_descriptor.message_types.append(message_descriptor) - - if enum_types: - file_descriptor.enum_types = list(enum_types) - - self.DoPythonTest(file_descriptor) - - def testBlankPackage(self): - self.DoPythonTest(descriptor.FileDescriptor()) - - def testEmptyPackage(self): - file_descriptor = descriptor.FileDescriptor() - file_descriptor.package = 'mypackage' - self.DoPythonTest(file_descriptor) - - def testSingleField(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.INT64 - - self.DoMessageTest([field]) - - def testMessageField_InternalReference(self): - other_message = descriptor.MessageDescriptor() - other_message.name = 'OtherMessage' - - field = descriptor.FieldDescriptor() - field.name = 'message_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.MESSAGE - field.type_name = 'my_package.OtherMessage' - - self.DoMessageTest([field], message_types=[other_message]) - - def testMessageField_ExternalReference(self): - field = descriptor.FieldDescriptor() - field.name = 'message_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.MESSAGE - field.type_name = 'protorpc.registry.GetFileSetResponse' - - self.DoMessageTest([field]) - - def testEnumField_InternalReference(self): - enum = descriptor.EnumDescriptor() - enum.name = 'Color' - - field = descriptor.FieldDescriptor() - field.name = 'color' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.ENUM - field.type_name = 'my_package.Color' - - self.DoMessageTest([field], enum_types=[enum]) - - def testEnumField_ExternalReference(self): - field = descriptor.FieldDescriptor() - field.name = 'color' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.ENUM - field.type_name = 'protorpc.descriptor.FieldDescriptor.Label' - - self.DoMessageTest([field]) - - def testDateTimeField(self): - field = descriptor.FieldDescriptor() - field.name = 'timestamp' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.MESSAGE - field.type_name = 'protorpc.message_types.DateTimeMessage' - - self.DoMessageTest([field]) - - def testNonDefaultVariant(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.UINT64 - - self.DoMessageTest([field]) - - def testRequiredField(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.REQUIRED - field.variant = descriptor.FieldDescriptor.Variant.INT64 - - self.DoMessageTest([field]) - - def testRepeatedField(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.REPEATED - field.variant = descriptor.FieldDescriptor.Variant.INT64 - - self.DoMessageTest([field]) - - def testIntegerDefaultValue(self): - field = descriptor.FieldDescriptor() - field.name = 'integer_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.INT64 - field.default_value = '10' - - self.DoMessageTest([field]) - - def testFloatDefaultValue(self): - field = descriptor.FieldDescriptor() - field.name = 'float_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.DOUBLE - field.default_value = '10.1' - - self.DoMessageTest([field]) - - def testStringDefaultValue(self): - field = descriptor.FieldDescriptor() - field.name = 'string_field' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.STRING - field.default_value = u'a nice lovely string\'s "string"' - - self.DoMessageTest([field]) - - def testEnumDefaultValue(self): - field = descriptor.FieldDescriptor() - field.name = 'label' - field.number = 1 - field.label = descriptor.FieldDescriptor.Label.OPTIONAL - field.variant = descriptor.FieldDescriptor.Variant.ENUM - field.type_name = 'protorpc.descriptor.FieldDescriptor.Label' - field.default_value = '2' - - self.DoMessageTest([field]) - - def testMultiFields(self): - field1 = descriptor.FieldDescriptor() - field1.name = 'integer_field' - field1.number = 1 - field1.label = descriptor.FieldDescriptor.Label.OPTIONAL - field1.variant = descriptor.FieldDescriptor.Variant.INT64 - - field2 = descriptor.FieldDescriptor() - field2.name = 'string_field' - field2.number = 2 - field2.label = descriptor.FieldDescriptor.Label.OPTIONAL - field2.variant = descriptor.FieldDescriptor.Variant.STRING - - field3 = descriptor.FieldDescriptor() - field3.name = 'unsigned_integer_field' - field3.number = 3 - field3.label = descriptor.FieldDescriptor.Label.OPTIONAL - field3.variant = descriptor.FieldDescriptor.Variant.UINT64 - - self.DoMessageTest([field1, field2, field3]) - - def testNestedMessage(self): - message = descriptor.MessageDescriptor() - message.name = 'OuterMessage' - - inner_message = descriptor.MessageDescriptor() - inner_message.name = 'InnerMessage' - - inner_inner_message = descriptor.MessageDescriptor() - inner_inner_message.name = 'InnerInnerMessage' - - inner_message.message_types = [inner_inner_message] - - message.message_types = [inner_message] - - file_descriptor = descriptor.FileDescriptor() - file_descriptor.message_types = [message] - - self.DoPythonTest(file_descriptor) - - def testNestedEnum(self): - message = descriptor.MessageDescriptor() - message.name = 'OuterMessage' - - inner_enum = descriptor.EnumDescriptor() - inner_enum.name = 'InnerEnum' - - message.enum_types = [inner_enum] - - file_descriptor = descriptor.FileDescriptor() - file_descriptor.message_types = [message] - - self.DoPythonTest(file_descriptor) - - def testService(self): - service = descriptor.ServiceDescriptor() - service.name = 'TheService' - - method1 = descriptor.MethodDescriptor() - method1.name = 'method1' - method1.request_type = 'protorpc.descriptor.FileDescriptor' - method1.response_type = 'protorpc.descriptor.MethodDescriptor' - - service.methods = [method1] - - file_descriptor = descriptor.FileDescriptor() - file_descriptor.service_types = [service] - - self.DoPythonTest(file_descriptor) - - # Test to make sure that implementation methods raise an exception. - import blank - service_instance = blank.TheService() - self.assertRaisesWithRegexpMatch(NotImplementedError, - 'Method method1 is not implemented', - service_instance.method1, - descriptor.FileDescriptor()) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/generate_test.py b/endpoints/internal/protorpc/generate_test.py deleted file mode 100644 index 7b9893a..0000000 --- a/endpoints/internal/protorpc/generate_test.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.generate.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -import operator - -import cStringIO -import sys -import unittest - -from protorpc import generate -from protorpc import test_util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = generate - - -class IndentWriterTest(test_util.TestCase): - - def setUp(self): - self.out = cStringIO.StringIO() - self.indent_writer = generate.IndentWriter(self.out) - - def testWriteLine(self): - self.indent_writer.write_line('This is a line') - self.indent_writer.write_line('This is another line') - - self.assertEquals('This is a line\n' - 'This is another line\n', - self.out.getvalue()) - - def testLeftShift(self): - self.run_count = 0 - def mock_write_line(line): - self.run_count += 1 - self.assertEquals('same as calling write_line', line) - - self.indent_writer.write_line = mock_write_line - self.indent_writer << 'same as calling write_line' - self.assertEquals(1, self.run_count) - - def testIndentation(self): - self.indent_writer << 'indent 0' - self.indent_writer.begin_indent() - self.indent_writer << 'indent 1' - self.indent_writer.begin_indent() - self.indent_writer << 'indent 2' - self.indent_writer.end_indent() - self.indent_writer << 'end 2' - self.indent_writer.end_indent() - self.indent_writer << 'end 1' - self.assertRaises(generate.IndentationError, - self.indent_writer.end_indent) - - self.assertEquals('indent 0\n' - ' indent 1\n' - ' indent 2\n' - ' end 2\n' - 'end 1\n', - self.out.getvalue()) - - def testBlankLine(self): - self.indent_writer << '' - self.indent_writer.begin_indent() - self.indent_writer << '' - self.assertEquals('\n\n', self.out.getvalue()) - - def testNoneInvalid(self): - self.assertRaises( - TypeError, operator.lshift, self.indent_writer, None) - - def testAltIndentation(self): - self.indent_writer = generate.IndentWriter(self.out, indent_space=3) - self.indent_writer << 'indent 0' - self.assertEquals(0, self.indent_writer.indent_level) - self.indent_writer.begin_indent() - self.indent_writer << 'indent 1' - self.assertEquals(1, self.indent_writer.indent_level) - self.indent_writer.begin_indent() - self.indent_writer << 'indent 2' - self.assertEquals(2, self.indent_writer.indent_level) - self.indent_writer.end_indent() - self.indent_writer << 'end 2' - self.assertEquals(1, self.indent_writer.indent_level) - self.indent_writer.end_indent() - self.indent_writer << 'end 1' - self.assertEquals(0, self.indent_writer.indent_level) - self.assertRaises(generate.IndentationError, - self.indent_writer.end_indent) - self.assertEquals(0, self.indent_writer.indent_level) - - self.assertEquals('indent 0\n' - ' indent 1\n' - ' indent 2\n' - ' end 2\n' - 'end 1\n', - self.out.getvalue()) - - def testIndent(self): - self.indent_writer << 'indent 0' - self.assertEquals(0, self.indent_writer.indent_level) - - def indent1(): - self.indent_writer << 'indent 1' - self.assertEquals(1, self.indent_writer.indent_level) - - def indent2(): - self.indent_writer << 'indent 2' - self.assertEquals(2, self.indent_writer.indent_level) - test_util.do_with(self.indent_writer.indent(), indent2) - - self.assertEquals(1, self.indent_writer.indent_level) - self.indent_writer << 'end 2' - test_util.do_with(self.indent_writer.indent(), indent1) - - self.assertEquals(0, self.indent_writer.indent_level) - self.indent_writer << 'end 1' - - self.assertEquals('indent 0\n' - ' indent 1\n' - ' indent 2\n' - ' end 2\n' - 'end 1\n', - self.out.getvalue()) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/message_types_test.py b/endpoints/internal/protorpc/message_types_test.py deleted file mode 100644 index b061cdf..0000000 --- a/endpoints/internal/protorpc/message_types_test.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2013 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.message_types.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import datetime - -import unittest - -from protorpc import message_types -from protorpc import messages -from protorpc import test_util -from protorpc import util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = message_types - - -class DateTimeFieldTest(test_util.TestCase): - - def testValueToMessage(self): - field = message_types.DateTimeField(1) - message = field.value_to_message(datetime.datetime(2033, 2, 4, 11, 22, 10)) - self.assertEqual(message_types.DateTimeMessage(milliseconds=1991128930000), - message) - - def testValueToMessageBadValue(self): - field = message_types.DateTimeField(1) - self.assertRaisesWithRegexpMatch( - messages.EncodeError, - 'Expected type datetime, got int: 20', - field.value_to_message, 20) - - def testValueToMessageWithTimeZone(self): - time_zone = util.TimeZoneOffset(60 * 10) - field = message_types.DateTimeField(1) - message = field.value_to_message( - datetime.datetime(2033, 2, 4, 11, 22, 10, tzinfo=time_zone)) - self.assertEqual(message_types.DateTimeMessage(milliseconds=1991128930000, - time_zone_offset=600), - message) - - def testValueFromMessage(self): - message = message_types.DateTimeMessage(milliseconds=1991128000000) - field = message_types.DateTimeField(1) - timestamp = field.value_from_message(message) - self.assertEqual(datetime.datetime(2033, 2, 4, 11, 6, 40), - timestamp) - - def testValueFromMessageBadValue(self): - field = message_types.DateTimeField(1) - self.assertRaisesWithRegexpMatch( - messages.DecodeError, - 'Expected type DateTimeMessage, got VoidMessage: ', - field.value_from_message, message_types.VoidMessage()) - - def testValueFromMessageWithTimeZone(self): - message = message_types.DateTimeMessage(milliseconds=1991128000000, - time_zone_offset=300) - field = message_types.DateTimeField(1) - timestamp = field.value_from_message(message) - time_zone = util.TimeZoneOffset(60 * 5) - self.assertEqual(datetime.datetime(2033, 2, 4, 11, 6, 40, tzinfo=time_zone), - timestamp) - - -if __name__ == '__main__': - unittest.main() diff --git a/endpoints/internal/protorpc/messages_test.py b/endpoints/internal/protorpc/messages_test.py deleted file mode 100644 index 9460b31..0000000 --- a/endpoints/internal/protorpc/messages_test.py +++ /dev/null @@ -1,2109 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.messages.""" -import six - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import pickle -import re -import sys -import types -import unittest - -from protorpc import descriptor -from protorpc import message_types -from protorpc import messages -from protorpc import test_util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = messages - - -class ValidationErrorTest(test_util.TestCase): - - def testStr_NoFieldName(self): - """Test string version of ValidationError when no name provided.""" - self.assertEquals('Validation error', - str(messages.ValidationError('Validation error'))) - - def testStr_FieldName(self): - """Test string version of ValidationError when no name provided.""" - validation_error = messages.ValidationError('Validation error') - validation_error.field_name = 'a_field' - self.assertEquals('Validation error', str(validation_error)) - - -class EnumTest(test_util.TestCase): - - def setUp(self): - """Set up tests.""" - # Redefine Color class in case so that changes to it (an error) in one test - # does not affect other tests. - global Color - class Color(messages.Enum): - RED = 20 - ORANGE = 2 - YELLOW = 40 - GREEN = 4 - BLUE = 50 - INDIGO = 5 - VIOLET = 80 - - def testNames(self): - """Test that names iterates over enum names.""" - self.assertEquals( - set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']), - set(Color.names())) - - def testNumbers(self): - """Tests that numbers iterates of enum numbers.""" - self.assertEquals(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers())) - - def testIterate(self): - """Test that __iter__ iterates over all enum values.""" - self.assertEquals(set(Color), - set([Color.RED, - Color.ORANGE, - Color.YELLOW, - Color.GREEN, - Color.BLUE, - Color.INDIGO, - Color.VIOLET])) - - def testNaturalOrder(self): - """Test that natural order enumeration is in numeric order.""" - self.assertEquals([Color.ORANGE, - Color.GREEN, - Color.INDIGO, - Color.RED, - Color.YELLOW, - Color.BLUE, - Color.VIOLET], - sorted(Color)) - - def testByName(self): - """Test look-up by name.""" - self.assertEquals(Color.RED, Color.lookup_by_name('RED')) - self.assertRaises(KeyError, Color.lookup_by_name, 20) - self.assertRaises(KeyError, Color.lookup_by_name, Color.RED) - - def testByNumber(self): - """Test look-up by number.""" - self.assertRaises(KeyError, Color.lookup_by_number, 'RED') - self.assertEquals(Color.RED, Color.lookup_by_number(20)) - self.assertRaises(KeyError, Color.lookup_by_number, Color.RED) - - def testConstructor(self): - """Test that constructor look-up by name or number.""" - self.assertEquals(Color.RED, Color('RED')) - self.assertEquals(Color.RED, Color(u'RED')) - self.assertEquals(Color.RED, Color(20)) - if six.PY2: - self.assertEquals(Color.RED, Color(long(20))) - self.assertEquals(Color.RED, Color(Color.RED)) - self.assertRaises(TypeError, Color, 'Not exists') - self.assertRaises(TypeError, Color, 'Red') - self.assertRaises(TypeError, Color, 100) - self.assertRaises(TypeError, Color, 10.0) - - def testLen(self): - """Test that len function works to count enums.""" - self.assertEquals(7, len(Color)) - - def testNoSubclasses(self): - """Test that it is not possible to sub-class enum classes.""" - def declare_subclass(): - class MoreColor(Color): - pass - self.assertRaises(messages.EnumDefinitionError, - declare_subclass) - - def testClassNotMutable(self): - """Test that enum classes themselves are not mutable.""" - self.assertRaises(AttributeError, - setattr, - Color, - 'something_new', - 10) - - def testInstancesMutable(self): - """Test that enum instances are not mutable.""" - self.assertRaises(TypeError, - setattr, - Color.RED, - 'something_new', - 10) - - def testDefEnum(self): - """Test def_enum works by building enum class from dict.""" - WeekDay = messages.Enum.def_enum({'Monday': 1, - 'Tuesday': 2, - 'Wednesday': 3, - 'Thursday': 4, - 'Friday': 6, - 'Saturday': 7, - 'Sunday': 8}, - 'WeekDay') - self.assertEquals('Wednesday', WeekDay(3).name) - self.assertEquals(6, WeekDay('Friday').number) - self.assertEquals(WeekDay.Sunday, WeekDay('Sunday')) - - def testNonInt(self): - """Test that non-integer values rejection by enum def.""" - self.assertRaises(messages.EnumDefinitionError, - messages.Enum.def_enum, - {'Bad': '1'}, - 'BadEnum') - - def testNegativeInt(self): - """Test that negative numbers rejection by enum def.""" - self.assertRaises(messages.EnumDefinitionError, - messages.Enum.def_enum, - {'Bad': -1}, - 'BadEnum') - - def testLowerBound(self): - """Test that zero is accepted by enum def.""" - class NotImportant(messages.Enum): - """Testing for value zero""" - VALUE = 0 - - self.assertEquals(0, int(NotImportant.VALUE)) - - def testTooLargeInt(self): - """Test that numbers too large are rejected.""" - self.assertRaises(messages.EnumDefinitionError, - messages.Enum.def_enum, - {'Bad': (2 ** 29)}, - 'BadEnum') - - def testRepeatedInt(self): - """Test duplicated numbers are forbidden.""" - self.assertRaises(messages.EnumDefinitionError, - messages.Enum.def_enum, - {'Ok': 1, 'Repeated': 1}, - 'BadEnum') - - def testStr(self): - """Test converting to string.""" - self.assertEquals('RED', str(Color.RED)) - self.assertEquals('ORANGE', str(Color.ORANGE)) - - def testInt(self): - """Test converting to int.""" - self.assertEquals(20, int(Color.RED)) - self.assertEquals(2, int(Color.ORANGE)) - - def testRepr(self): - """Test enum representation.""" - self.assertEquals('Color(RED, 20)', repr(Color.RED)) - self.assertEquals('Color(YELLOW, 40)', repr(Color.YELLOW)) - - def testDocstring(self): - """Test that docstring is supported ok.""" - class NotImportant(messages.Enum): - """I have a docstring.""" - - VALUE1 = 1 - - self.assertEquals('I have a docstring.', NotImportant.__doc__) - - def testDeleteEnumValue(self): - """Test that enum values cannot be deleted.""" - self.assertRaises(TypeError, delattr, Color, 'RED') - - def testEnumName(self): - """Test enum name.""" - module_name = test_util.get_module_name(EnumTest) - self.assertEquals('%s.Color' % module_name, Color.definition_name()) - self.assertEquals(module_name, Color.outer_definition_name()) - self.assertEquals(module_name, Color.definition_package()) - - def testDefinitionName_OverrideModule(self): - """Test enum module is overriden by module package name.""" - global package - try: - package = 'my.package' - self.assertEquals('my.package.Color', Color.definition_name()) - self.assertEquals('my.package', Color.outer_definition_name()) - self.assertEquals('my.package', Color.definition_package()) - finally: - del package - - def testDefinitionName_NoModule(self): - """Test what happens when there is no module for enum.""" - class Enum1(messages.Enum): - pass - - original_modules = sys.modules - sys.modules = dict(sys.modules) - try: - del sys.modules[__name__] - self.assertEquals('Enum1', Enum1.definition_name()) - self.assertEquals(None, Enum1.outer_definition_name()) - self.assertEquals(None, Enum1.definition_package()) - self.assertEquals(six.text_type, type(Enum1.definition_name())) - finally: - sys.modules = original_modules - - def testDefinitionName_Nested(self): - """Test nested Enum names.""" - class MyMessage(messages.Message): - - class NestedEnum(messages.Enum): - - pass - - class NestedMessage(messages.Message): - - class NestedEnum(messages.Enum): - - pass - - module_name = test_util.get_module_name(EnumTest) - self.assertEquals('%s.MyMessage.NestedEnum' % module_name, - MyMessage.NestedEnum.definition_name()) - self.assertEquals('%s.MyMessage' % module_name, - MyMessage.NestedEnum.outer_definition_name()) - self.assertEquals(module_name, - MyMessage.NestedEnum.definition_package()) - - self.assertEquals('%s.MyMessage.NestedMessage.NestedEnum' % module_name, - MyMessage.NestedMessage.NestedEnum.definition_name()) - self.assertEquals( - '%s.MyMessage.NestedMessage' % module_name, - MyMessage.NestedMessage.NestedEnum.outer_definition_name()) - self.assertEquals(module_name, - MyMessage.NestedMessage.NestedEnum.definition_package()) - - def testMessageDefinition(self): - """Test that enumeration knows its enclosing message definition.""" - class OuterEnum(messages.Enum): - pass - - self.assertEquals(None, OuterEnum.message_definition()) - - class OuterMessage(messages.Message): - - class InnerEnum(messages.Enum): - pass - - self.assertEquals(OuterMessage, OuterMessage.InnerEnum.message_definition()) - - def testComparison(self): - """Test comparing various enums to different types.""" - class Enum1(messages.Enum): - VAL1 = 1 - VAL2 = 2 - - class Enum2(messages.Enum): - VAL1 = 1 - - self.assertEquals(Enum1.VAL1, Enum1.VAL1) - self.assertNotEquals(Enum1.VAL1, Enum1.VAL2) - self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) - self.assertNotEquals(Enum1.VAL1, 'VAL1') - self.assertNotEquals(Enum1.VAL1, 1) - self.assertNotEquals(Enum1.VAL1, 2) - self.assertNotEquals(Enum1.VAL1, None) - self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) - - self.assertTrue(Enum1.VAL1 < Enum1.VAL2) - self.assertTrue(Enum1.VAL2 > Enum1.VAL1) - - self.assertNotEquals(1, Enum2.VAL1) - - def testPickle(self): - """Testing pickling and unpickling of Enum instances.""" - colors = list(Color) - unpickled = pickle.loads(pickle.dumps(colors)) - self.assertEquals(colors, unpickled) - # Unpickling shouldn't create new enum instances. - for i, color in enumerate(colors): - self.assertTrue(color is unpickled[i]) - - -class FieldListTest(test_util.TestCase): - - def setUp(self): - self.integer_field = messages.IntegerField(1, repeated=True) - - def testConstructor(self): - self.assertEquals([1, 2, 3], - messages.FieldList(self.integer_field, [1, 2, 3])) - self.assertEquals([1, 2, 3], - messages.FieldList(self.integer_field, (1, 2, 3))) - self.assertEquals([], messages.FieldList(self.integer_field, [])) - - def testNone(self): - self.assertRaises(TypeError, messages.FieldList, self.integer_field, None) - - def testDoNotAutoConvertString(self): - string_field = messages.StringField(1, repeated=True) - self.assertRaises(messages.ValidationError, - messages.FieldList, string_field, 'abc') - - def testConstructorCopies(self): - a_list = [1, 3, 6] - field_list = messages.FieldList(self.integer_field, a_list) - self.assertFalse(a_list is field_list) - self.assertFalse(field_list is - messages.FieldList(self.integer_field, field_list)) - - def testNonRepeatedField(self): - self.assertRaisesWithRegexpMatch( - messages.FieldDefinitionError, - 'FieldList may only accept repeated fields', - messages.FieldList, - messages.IntegerField(1), - []) - - def testConstructor_InvalidValues(self): - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - re.escape("Expected type %r " - "for IntegerField, found 1 (type %r)" - % (six.integer_types, str)), - messages.FieldList, self.integer_field, ["1", "2", "3"]) - - def testConstructor_Scalars(self): - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - "IntegerField is repeated. Found: 3", - messages.FieldList, self.integer_field, 3) - - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - "IntegerField is repeated. Found: <(list[_]?|sequence)iterator object", - messages.FieldList, self.integer_field, iter([1, 2, 3])) - - def testSetSlice(self): - field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) - field_list[1:3] = [10, 20] - self.assertEquals([1, 10, 20, 4, 5], field_list) - - def testSetSlice_InvalidValues(self): - field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) - - def setslice(): - field_list[1:3] = ['10', '20'] - - msg_re = re.escape("Expected type %r " - "for IntegerField, found 10 (type %r)" - % (six.integer_types, str)) - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - msg_re, - setslice) - - def testSetItem(self): - field_list = messages.FieldList(self.integer_field, [2]) - field_list[0] = 10 - self.assertEquals([10], field_list) - - def testSetItem_InvalidValues(self): - field_list = messages.FieldList(self.integer_field, [2]) - - def setitem(): - field_list[0] = '10' - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - re.escape("Expected type %r " - "for IntegerField, found 10 (type %r)" - % (six.integer_types, str)), - setitem) - - def testAppend(self): - field_list = messages.FieldList(self.integer_field, [2]) - field_list.append(10) - self.assertEquals([2, 10], field_list) - - def testAppend_InvalidValues(self): - field_list = messages.FieldList(self.integer_field, [2]) - field_list.name = 'a_field' - - def append(): - field_list.append('10') - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - re.escape("Expected type %r " - "for IntegerField, found 10 (type %r)" - % (six.integer_types, str)), - append) - - def testExtend(self): - field_list = messages.FieldList(self.integer_field, [2]) - field_list.extend([10]) - self.assertEquals([2, 10], field_list) - - def testExtend_InvalidValues(self): - field_list = messages.FieldList(self.integer_field, [2]) - - def extend(): - field_list.extend(['10']) - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - re.escape("Expected type %r " - "for IntegerField, found 10 (type %r)" - % (six.integer_types, str)), - extend) - - def testInsert(self): - field_list = messages.FieldList(self.integer_field, [2, 3]) - field_list.insert(1, 10) - self.assertEquals([2, 10, 3], field_list) - - def testInsert_InvalidValues(self): - field_list = messages.FieldList(self.integer_field, [2, 3]) - - def insert(): - field_list.insert(1, '10') - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - re.escape("Expected type %r " - "for IntegerField, found 10 (type %r)" - % (six.integer_types, str)), - insert) - - def testPickle(self): - """Testing pickling and unpickling of disconnected FieldList instances.""" - field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) - unpickled = pickle.loads(pickle.dumps(field_list)) - self.assertEquals(field_list, unpickled) - self.assertIsInstance(unpickled.field, messages.IntegerField) - self.assertEquals(1, unpickled.field.number) - self.assertTrue(unpickled.field.repeated) - - -class FieldTest(test_util.TestCase): - - def ActionOnAllFieldClasses(self, action): - """Test all field classes except Message and Enum. - - Message and Enum require separate tests. - - Args: - action: Callable that takes the field class as a parameter. - """ - for field_class in (messages.IntegerField, - messages.FloatField, - messages.BooleanField, - messages.BytesField, - messages.StringField, - ): - action(field_class) - - def testNumberAttribute(self): - """Test setting the number attribute.""" - def action(field_class): - # Check range. - self.assertRaises(messages.InvalidNumberError, - field_class, - 0) - self.assertRaises(messages.InvalidNumberError, - field_class, - -1) - self.assertRaises(messages.InvalidNumberError, - field_class, - messages.MAX_FIELD_NUMBER + 1) - - # Check reserved. - self.assertRaises(messages.InvalidNumberError, - field_class, - messages.FIRST_RESERVED_FIELD_NUMBER) - self.assertRaises(messages.InvalidNumberError, - field_class, - messages.LAST_RESERVED_FIELD_NUMBER) - self.assertRaises(messages.InvalidNumberError, - field_class, - '1') - - # This one should work. - field_class(number=1) - self.ActionOnAllFieldClasses(action) - - def testRequiredAndRepeated(self): - """Test setting the required and repeated fields.""" - def action(field_class): - field_class(1, required=True) - field_class(1, repeated=True) - self.assertRaises(messages.FieldDefinitionError, - field_class, - 1, - required=True, - repeated=True) - self.ActionOnAllFieldClasses(action) - - def testInvalidVariant(self): - """Test field with invalid variants.""" - def action(field_class): - if field_class is not message_types.DateTimeField: - self.assertRaises(messages.InvalidVariantError, - field_class, - 1, - variant=messages.Variant.ENUM) - self.ActionOnAllFieldClasses(action) - - def testDefaultVariant(self): - """Test that default variant is used when not set.""" - def action(field_class): - field = field_class(1) - self.assertEquals(field_class.DEFAULT_VARIANT, field.variant) - - self.ActionOnAllFieldClasses(action) - - def testAlternateVariant(self): - """Test that default variant is used when not set.""" - field = messages.IntegerField(1, variant=messages.Variant.UINT32) - self.assertEquals(messages.Variant.UINT32, field.variant) - - def testDefaultFields_Single(self): - """Test default field is correct type (single).""" - defaults = {messages.IntegerField: 10, - messages.FloatField: 1.5, - messages.BooleanField: False, - messages.BytesField: b'abc', - messages.StringField: u'abc', - } - - def action(field_class): - field_class(1, default=defaults[field_class]) - self.ActionOnAllFieldClasses(action) - - # Run defaults test again checking for str/unicode compatiblity. - defaults[messages.StringField] = 'abc' - self.ActionOnAllFieldClasses(action) - - def testStringField_BadUnicodeInDefault(self): - """Test binary values in string field.""" - self.assertRaisesWithRegexpMatch( - messages.InvalidDefaultError, - r"Invalid default value for StringField:.*: " - r"Field encountered non-ASCII string .*: " - r"'ascii' codec can't decode byte 0x89 in position 0: " - r"ordinal not in range", - messages.StringField, 1, default=b'\x89') - - def testDefaultFields_InvalidSingle(self): - """Test default field is correct type (invalid single).""" - def action(field_class): - self.assertRaises(messages.InvalidDefaultError, - field_class, - 1, - default=object()) - self.ActionOnAllFieldClasses(action) - - def testDefaultFields_InvalidRepeated(self): - """Test default field does not accept defaults.""" - self.assertRaisesWithRegexpMatch( - messages.FieldDefinitionError, - 'Repeated fields may not have defaults', - messages.StringField, 1, repeated=True, default=[1, 2, 3]) - - def testDefaultFields_None(self): - """Test none is always acceptable.""" - def action(field_class): - field_class(1, default=None) - field_class(1, required=True, default=None) - field_class(1, repeated=True, default=None) - self.ActionOnAllFieldClasses(action) - - def testDefaultFields_Enum(self): - """Test the default for enum fields.""" - class Symbol(messages.Enum): - - ALPHA = 1 - BETA = 2 - GAMMA = 3 - - field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA) - - self.assertEquals(Symbol.ALPHA, field.default) - - def testDefaultFields_EnumStringDelayedResolution(self): - """Test that enum fields resolve default strings.""" - field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', - 1, - default='OPTIONAL') - - self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default) - - def testDefaultFields_EnumIntDelayedResolution(self): - """Test that enum fields resolve default integers.""" - field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', - 1, - default=2) - - self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default) - - def testDefaultFields_EnumOkIfTypeKnown(self): - """Test that enum fields accept valid default values when type is known.""" - field = messages.EnumField(descriptor.FieldDescriptor.Label, - 1, - default='REPEATED') - - self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default) - - def testDefaultFields_EnumForceCheckIfTypeKnown(self): - """Test that enum fields validate default values if type is known.""" - self.assertRaisesWithRegexpMatch(TypeError, - 'No such value for NOT_A_LABEL in ' - 'Enum Label', - messages.EnumField, - descriptor.FieldDescriptor.Label, - 1, - default='NOT_A_LABEL') - - def testDefaultFields_EnumInvalidDelayedResolution(self): - """Test that enum fields raise errors upon delayed resolution error.""" - field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', - 1, - default=200) - - self.assertRaisesWithRegexpMatch(TypeError, - 'No such value for 200 in Enum Label', - getattr, - field, - 'default') - - def testValidate_Valid(self): - """Test validation of valid values.""" - values = {messages.IntegerField: 10, - messages.FloatField: 1.5, - messages.BooleanField: False, - messages.BytesField: b'abc', - messages.StringField: u'abc', - } - def action(field_class): - # Optional. - field = field_class(1) - field.validate(values[field_class]) - - # Required. - field = field_class(1, required=True) - field.validate(values[field_class]) - - # Repeated. - field = field_class(1, repeated=True) - field.validate([]) - field.validate(()) - field.validate([values[field_class]]) - field.validate((values[field_class],)) - - # Right value, but not repeated. - self.assertRaises(messages.ValidationError, - field.validate, - values[field_class]) - self.assertRaises(messages.ValidationError, - field.validate, - values[field_class]) - - self.ActionOnAllFieldClasses(action) - - def testValidate_Invalid(self): - """Test validation of valid values.""" - values = {messages.IntegerField: "10", - messages.FloatField: "blah", - messages.BooleanField: 0, - messages.BytesField: 10.20, - messages.StringField: 42, - } - def action(field_class): - # Optional. - field = field_class(1) - self.assertRaises(messages.ValidationError, - field.validate, - values[field_class]) - - # Required. - field = field_class(1, required=True) - self.assertRaises(messages.ValidationError, - field.validate, - values[field_class]) - - # Repeated. - field = field_class(1, repeated=True) - self.assertRaises(messages.ValidationError, - field.validate, - [values[field_class]]) - self.assertRaises(messages.ValidationError, - field.validate, - (values[field_class],)) - self.ActionOnAllFieldClasses(action) - - def testValidate_None(self): - """Test that None is valid for non-required fields.""" - def action(field_class): - # Optional. - field = field_class(1) - field.validate(None) - - # Required. - field = field_class(1, required=True) - self.assertRaisesWithRegexpMatch(messages.ValidationError, - 'Required field is missing', - field.validate, - None) - - # Repeated. - field = field_class(1, repeated=True) - field.validate(None) - self.assertRaisesWithRegexpMatch(messages.ValidationError, - 'Repeated values for %s may ' - 'not be None' % field_class.__name__, - field.validate, - [None]) - self.assertRaises(messages.ValidationError, - field.validate, - (None,)) - self.ActionOnAllFieldClasses(action) - - def testValidateElement(self): - """Test validation of valid values.""" - values = {messages.IntegerField: (10, -1, 0), - messages.FloatField: (1.5, -1.5, 3), # for json it is all a number - messages.BooleanField: (True, False), - messages.BytesField: (b'abc',), - messages.StringField: (u'abc',), - } - def action(field_class): - # Optional. - field = field_class(1) - for value in values[field_class]: - field.validate_element(value) - - # Required. - field = field_class(1, required=True) - for value in values[field_class]: - field.validate_element(value) - - # Repeated. - field = field_class(1, repeated=True) - self.assertRaises(messages.ValidationError, - field.validate_element, - []) - self.assertRaises(messages.ValidationError, - field.validate_element, - ()) - for value in values[field_class]: - field.validate_element(value) - - # Right value, but repeated. - self.assertRaises(messages.ValidationError, - field.validate_element, - list(values[field_class])) # testing list - self.assertRaises(messages.ValidationError, - field.validate_element, - values[field_class]) # testing tuple - self.ActionOnAllFieldClasses(action) - - def testValidateCastingElement(self): - field = messages.FloatField(1) - self.assertEquals(type(field.validate_element(12)), float) - self.assertEquals(type(field.validate_element(12.0)), float) - field = messages.IntegerField(1) - self.assertEquals(type(field.validate_element(12)), int) - self.assertRaises(messages.ValidationError, - field.validate_element, - 12.0) # should fail from float to int - - def testReadOnly(self): - """Test that objects are all read-only.""" - def action(field_class): - field = field_class(10) - self.assertRaises(AttributeError, - setattr, - field, - 'number', - 20) - self.assertRaises(AttributeError, - setattr, - field, - 'anything_else', - 'whatever') - self.ActionOnAllFieldClasses(action) - - def testMessageField(self): - """Test the construction of message fields.""" - self.assertRaises(messages.FieldDefinitionError, - messages.MessageField, - str, - 10) - - self.assertRaises(messages.FieldDefinitionError, - messages.MessageField, - messages.Message, - 10) - - class MyMessage(messages.Message): - pass - - field = messages.MessageField(MyMessage, 10) - self.assertEquals(MyMessage, field.type) - - def testMessageField_ForwardReference(self): - """Test the construction of forward reference message fields.""" - global MyMessage - global ForwardMessage - try: - class MyMessage(messages.Message): - - self_reference = messages.MessageField('MyMessage', 1) - forward = messages.MessageField('ForwardMessage', 2) - nested = messages.MessageField('ForwardMessage.NestedMessage', 3) - inner = messages.MessageField('Inner', 4) - - class Inner(messages.Message): - - sibling = messages.MessageField('Sibling', 1) - - class Sibling(messages.Message): - - pass - - class ForwardMessage(messages.Message): - - class NestedMessage(messages.Message): - - pass - - self.assertEquals(MyMessage, - MyMessage.field_by_name('self_reference').type) - - self.assertEquals(ForwardMessage, - MyMessage.field_by_name('forward').type) - - self.assertEquals(ForwardMessage.NestedMessage, - MyMessage.field_by_name('nested').type) - - self.assertEquals(MyMessage.Inner, - MyMessage.field_by_name('inner').type) - - self.assertEquals(MyMessage.Sibling, - MyMessage.Inner.field_by_name('sibling').type) - finally: - try: - del MyMessage - del ForwardMessage - except: - pass - - def testMessageField_WrongType(self): - """Test that forward referencing the wrong type raises an error.""" - global AnEnum - try: - class AnEnum(messages.Enum): - pass - - class AnotherMessage(messages.Message): - - a_field = messages.MessageField('AnEnum', 1) - - self.assertRaises(messages.FieldDefinitionError, - getattr, - AnotherMessage.field_by_name('a_field'), - 'type') - finally: - del AnEnum - - def testMessageFieldValidate(self): - """Test validation on message field.""" - class MyMessage(messages.Message): - pass - - class AnotherMessage(messages.Message): - pass - - field = messages.MessageField(MyMessage, 10) - field.validate(MyMessage()) - - self.assertRaises(messages.ValidationError, - field.validate, - AnotherMessage()) - - def testMessageFieldMessageType(self): - """Test message_type property.""" - class MyMessage(messages.Message): - pass - - class HasMessage(messages.Message): - field = messages.MessageField(MyMessage, 1) - - self.assertEqual(HasMessage.field.type, HasMessage.field.message_type) - - def testMessageFieldValueFromMessage(self): - class MyMessage(messages.Message): - pass - - class HasMessage(messages.Message): - field = messages.MessageField(MyMessage, 1) - - instance = MyMessage() - - self.assertTrue(instance is HasMessage.field.value_from_message(instance)) - - def testMessageFieldValueFromMessageWrongType(self): - class MyMessage(messages.Message): - pass - - class HasMessage(messages.Message): - field = messages.MessageField(MyMessage, 1) - - self.assertRaisesWithRegexpMatch( - messages.DecodeError, - 'Expected type MyMessage, got int: 10', - HasMessage.field.value_from_message, 10) - - def testMessageFieldValueToMessage(self): - class MyMessage(messages.Message): - pass - - class HasMessage(messages.Message): - field = messages.MessageField(MyMessage, 1) - - instance = MyMessage() - - self.assertTrue(instance is HasMessage.field.value_to_message(instance)) - - def testMessageFieldValueToMessageWrongType(self): - class MyMessage(messages.Message): - pass - - class MyOtherMessage(messages.Message): - pass - - class HasMessage(messages.Message): - field = messages.MessageField(MyMessage, 1) - - instance = MyOtherMessage() - - self.assertRaisesWithRegexpMatch( - messages.EncodeError, - 'Expected type MyMessage, got MyOtherMessage: ', - HasMessage.field.value_to_message, instance) - - def testIntegerField_AllowLong(self): - """Test that the integer field allows for longs.""" - if six.PY2: - messages.IntegerField(10, default=long(10)) - - def testMessageFieldValidate_Initialized(self): - """Test validation on message field.""" - class MyMessage(messages.Message): - field1 = messages.IntegerField(1, required=True) - - field = messages.MessageField(MyMessage, 10) - - # Will validate messages where is_initialized() is False. - message = MyMessage() - field.validate(message) - message.field1 = 20 - field.validate(message) - - def testEnumField(self): - """Test the construction of enum fields.""" - self.assertRaises(messages.FieldDefinitionError, - messages.EnumField, - str, - 10) - - self.assertRaises(messages.FieldDefinitionError, - messages.EnumField, - messages.Enum, - 10) - - class Color(messages.Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - field = messages.EnumField(Color, 10) - self.assertEquals(Color, field.type) - - class Another(messages.Enum): - VALUE = 1 - - self.assertRaises(messages.InvalidDefaultError, - messages.EnumField, - Color, - 10, - default=Another.VALUE) - - def testEnumField_ForwardReference(self): - """Test the construction of forward reference enum fields.""" - global MyMessage - global ForwardEnum - global ForwardMessage - try: - class MyMessage(messages.Message): - - forward = messages.EnumField('ForwardEnum', 1) - nested = messages.EnumField('ForwardMessage.NestedEnum', 2) - inner = messages.EnumField('Inner', 3) - - class Inner(messages.Enum): - pass - - class ForwardEnum(messages.Enum): - pass - - class ForwardMessage(messages.Message): - - class NestedEnum(messages.Enum): - pass - - self.assertEquals(ForwardEnum, - MyMessage.field_by_name('forward').type) - - self.assertEquals(ForwardMessage.NestedEnum, - MyMessage.field_by_name('nested').type) - - self.assertEquals(MyMessage.Inner, - MyMessage.field_by_name('inner').type) - finally: - try: - del MyMessage - del ForwardEnum - del ForwardMessage - except: - pass - - def testEnumField_WrongType(self): - """Test that forward referencing the wrong type raises an error.""" - global AMessage - try: - class AMessage(messages.Message): - pass - - class AnotherMessage(messages.Message): - - a_field = messages.EnumField('AMessage', 1) - - self.assertRaises(messages.FieldDefinitionError, - getattr, - AnotherMessage.field_by_name('a_field'), - 'type') - finally: - del AMessage - - def testMessageDefinition(self): - """Test that message definition is set on fields.""" - class MyMessage(messages.Message): - - my_field = messages.StringField(1) - - self.assertEquals(MyMessage, - MyMessage.field_by_name('my_field').message_definition()) - - def testNoneAssignment(self): - """Test that assigning None does not change comparison.""" - class MyMessage(messages.Message): - - my_field = messages.StringField(1) - - m1 = MyMessage() - m2 = MyMessage() - m2.my_field = None - self.assertEquals(m1, m2) - - def testNonAsciiStr(self): - """Test validation fails for non-ascii StringField values.""" - class Thing(messages.Message): - string_field = messages.StringField(2) - - thing = Thing() - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - 'Field string_field encountered non-ASCII string', - setattr, thing, 'string_field', test_util.BINARY) - - -class MessageTest(test_util.TestCase): - """Tests for message class.""" - - def CreateMessageClass(self): - """Creates a simple message class with 3 fields. - - Fields are defined in alphabetical order but with conflicting numeric - order. - """ - class ComplexMessage(messages.Message): - a3 = messages.IntegerField(3) - b1 = messages.StringField(1) - c2 = messages.StringField(2) - - return ComplexMessage - - def testSameNumbers(self): - """Test that cannot assign two fields with same numbers.""" - - def action(): - class BadMessage(messages.Message): - f1 = messages.IntegerField(1) - f2 = messages.IntegerField(1) - self.assertRaises(messages.DuplicateNumberError, - action) - - def testStrictAssignment(self): - """Tests that cannot assign to unknown or non-reserved attributes.""" - class SimpleMessage(messages.Message): - field = messages.IntegerField(1) - - simple_message = SimpleMessage() - self.assertRaises(AttributeError, - setattr, - simple_message, - 'does_not_exist', - 10) - - def testListAssignmentDoesNotCopy(self): - class SimpleMessage(messages.Message): - repeated = messages.IntegerField(1, repeated=True) - - message = SimpleMessage() - original = message.repeated - message.repeated = [] - self.assertFalse(original is message.repeated) - - def testValidate_Optional(self): - """Tests validation of optional fields.""" - class SimpleMessage(messages.Message): - non_required = messages.IntegerField(1) - - simple_message = SimpleMessage() - simple_message.check_initialized() - simple_message.non_required = 10 - simple_message.check_initialized() - - def testValidate_Required(self): - """Tests validation of required fields.""" - class SimpleMessage(messages.Message): - required = messages.IntegerField(1, required=True) - - simple_message = SimpleMessage() - self.assertRaises(messages.ValidationError, - simple_message.check_initialized) - simple_message.required = 10 - simple_message.check_initialized() - - def testValidate_Repeated(self): - """Tests validation of repeated fields.""" - class SimpleMessage(messages.Message): - repeated = messages.IntegerField(1, repeated=True) - - simple_message = SimpleMessage() - - # Check valid values. - for valid_value in [], [10], [10, 20], (), (10,), (10, 20): - simple_message.repeated = valid_value - simple_message.check_initialized() - - # Check cleared. - simple_message.repeated = [] - simple_message.check_initialized() - - # Check invalid values. - for invalid_value in 10, ['10', '20'], [None], (None,): - self.assertRaises(messages.ValidationError, - setattr, simple_message, 'repeated', invalid_value) - - def testIsInitialized(self): - """Tests is_initialized.""" - class SimpleMessage(messages.Message): - required = messages.IntegerField(1, required=True) - - simple_message = SimpleMessage() - self.assertFalse(simple_message.is_initialized()) - - simple_message.required = 10 - - self.assertTrue(simple_message.is_initialized()) - - def testIsInitializedNestedField(self): - """Tests is_initialized for nested fields.""" - class SimpleMessage(messages.Message): - required = messages.IntegerField(1, required=True) - - class NestedMessage(messages.Message): - simple = messages.MessageField(SimpleMessage, 1) - - simple_message = SimpleMessage() - self.assertFalse(simple_message.is_initialized()) - nested_message = NestedMessage(simple=simple_message) - self.assertFalse(nested_message.is_initialized()) - - simple_message.required = 10 - - self.assertTrue(simple_message.is_initialized()) - self.assertTrue(nested_message.is_initialized()) - - def testInitializeNestedFieldFromDict(self): - """Tests initializing nested fields from dict.""" - class SimpleMessage(messages.Message): - required = messages.IntegerField(1, required=True) - - class NestedMessage(messages.Message): - simple = messages.MessageField(SimpleMessage, 1) - - class RepeatedMessage(messages.Message): - simple = messages.MessageField(SimpleMessage, 1, repeated=True) - - nested_message1 = NestedMessage(simple={'required': 10}) - self.assertTrue(nested_message1.is_initialized()) - self.assertTrue(nested_message1.simple.is_initialized()) - - nested_message2 = NestedMessage() - nested_message2.simple = {'required': 10} - self.assertTrue(nested_message2.is_initialized()) - self.assertTrue(nested_message2.simple.is_initialized()) - - repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)] - - repeated_message1 = RepeatedMessage(simple=repeated_values) - self.assertEquals(3, len(repeated_message1.simple)) - self.assertFalse(repeated_message1.is_initialized()) - - repeated_message1.simple[0].required = 0 - self.assertTrue(repeated_message1.is_initialized()) - - repeated_message2 = RepeatedMessage() - repeated_message2.simple = repeated_values - self.assertEquals(3, len(repeated_message2.simple)) - self.assertFalse(repeated_message2.is_initialized()) - - repeated_message2.simple[0].required = 0 - self.assertTrue(repeated_message2.is_initialized()) - - def testNestedMethodsNotAllowed(self): - """Test that method definitions on Message classes are not allowed.""" - def action(): - class WithMethods(messages.Message): - def not_allowed(self): - pass - - self.assertRaises(messages.MessageDefinitionError, - action) - - def testNestedAttributesNotAllowed(self): - """Test that attribute assignment on Message classes are not allowed.""" - def int_attribute(): - class WithMethods(messages.Message): - not_allowed = 1 - - def string_attribute(): - class WithMethods(messages.Message): - not_allowed = 'not allowed' - - def enum_attribute(): - class WithMethods(messages.Message): - not_allowed = Color.RED - - for action in (int_attribute, string_attribute, enum_attribute): - self.assertRaises(messages.MessageDefinitionError, - action) - - def testNameIsSetOnFields(self): - """Make sure name is set on fields after Message class init.""" - class HasNamedFields(messages.Message): - field = messages.StringField(1) - - self.assertEquals('field', HasNamedFields.field_by_number(1).name) - - def testSubclassingMessageDisallowed(self): - """Not permitted to create sub-classes of message classes.""" - class SuperClass(messages.Message): - pass - - def action(): - class SubClass(SuperClass): - pass - - self.assertRaises(messages.MessageDefinitionError, - action) - - def testAllFields(self): - """Test all_fields method.""" - ComplexMessage = self.CreateMessageClass() - fields = list(ComplexMessage.all_fields()) - - # Order does not matter, so sort now. - fields = sorted(fields, key=lambda f: f.name) - - self.assertEquals(3, len(fields)) - self.assertEquals('a3', fields[0].name) - self.assertEquals('b1', fields[1].name) - self.assertEquals('c2', fields[2].name) - - def testFieldByName(self): - """Test getting field by name.""" - ComplexMessage = self.CreateMessageClass() - - self.assertEquals(3, ComplexMessage.field_by_name('a3').number) - self.assertEquals(1, ComplexMessage.field_by_name('b1').number) - self.assertEquals(2, ComplexMessage.field_by_name('c2').number) - - self.assertRaises(KeyError, - ComplexMessage.field_by_name, - 'unknown') - - def testFieldByNumber(self): - """Test getting field by number.""" - ComplexMessage = self.CreateMessageClass() - - self.assertEquals('a3', ComplexMessage.field_by_number(3).name) - self.assertEquals('b1', ComplexMessage.field_by_number(1).name) - self.assertEquals('c2', ComplexMessage.field_by_number(2).name) - - self.assertRaises(KeyError, - ComplexMessage.field_by_number, - 4) - - def testGetAssignedValue(self): - """Test getting the assigned value of a field.""" - class SomeMessage(messages.Message): - a_value = messages.StringField(1, default=u'a default') - - message = SomeMessage() - self.assertEquals(None, message.get_assigned_value('a_value')) - - message.a_value = u'a string' - self.assertEquals(u'a string', message.get_assigned_value('a_value')) - - message.a_value = u'a default' - self.assertEquals(u'a default', message.get_assigned_value('a_value')) - - self.assertRaisesWithRegexpMatch( - AttributeError, - 'Message SomeMessage has no field no_such_field', - message.get_assigned_value, - 'no_such_field') - - def testReset(self): - """Test resetting a field value.""" - class SomeMessage(messages.Message): - a_value = messages.StringField(1, default=u'a default') - repeated = messages.IntegerField(2, repeated=True) - - message = SomeMessage() - - self.assertRaises(AttributeError, message.reset, 'unknown') - - self.assertEquals(u'a default', message.a_value) - message.reset('a_value') - self.assertEquals(u'a default', message.a_value) - - message.a_value = u'a new value' - self.assertEquals(u'a new value', message.a_value) - message.reset('a_value') - self.assertEquals(u'a default', message.a_value) - - message.repeated = [1, 2, 3] - self.assertEquals([1, 2, 3], message.repeated) - saved = message.repeated - message.reset('repeated') - self.assertEquals([], message.repeated) - self.assertIsInstance(message.repeated, messages.FieldList) - self.assertEquals([1, 2, 3], saved) - - def testAllowNestedEnums(self): - """Test allowing nested enums in a message definition.""" - class Trade(messages.Message): - class Duration(messages.Enum): - GTC = 1 - DAY = 2 - - class Currency(messages.Enum): - USD = 1 - GBP = 2 - INR = 3 - - # Sorted by name order seems to be the only feasible option. - self.assertEquals(['Currency', 'Duration'], Trade.__enums__) - - # Message definition will now be set on Enumerated objects. - self.assertEquals(Trade, Trade.Duration.message_definition()) - - def testAllowNestedMessages(self): - """Test allowing nested messages in a message definition.""" - class Trade(messages.Message): - class Lot(messages.Message): - pass - - class Agent(messages.Message): - pass - - # Sorted by name order seems to be the only feasible option. - self.assertEquals(['Agent', 'Lot'], Trade.__messages__) - self.assertEquals(Trade, Trade.Agent.message_definition()) - self.assertEquals(Trade, Trade.Lot.message_definition()) - - # But not Message itself. - def action(): - class Trade(messages.Message): - NiceTry = messages.Message - self.assertRaises(messages.MessageDefinitionError, action) - - def testDisallowClassAssignments(self): - """Test setting class attributes may not happen.""" - class MyMessage(messages.Message): - pass - - self.assertRaises(AttributeError, - setattr, - MyMessage, - 'x', - 'do not assign') - - def testEquality(self): - """Test message class equality.""" - # Comparison against enums must work. - class MyEnum(messages.Enum): - val1 = 1 - val2 = 2 - - # Comparisons against nested messages must work. - class AnotherMessage(messages.Message): - string = messages.StringField(1) - - class MyMessage(messages.Message): - field1 = messages.IntegerField(1) - field2 = messages.EnumField(MyEnum, 2) - field3 = messages.MessageField(AnotherMessage, 3) - - message1 = MyMessage() - - self.assertNotEquals('hi', message1) - self.assertNotEquals(AnotherMessage(), message1) - self.assertEquals(message1, message1) - - message2 = MyMessage() - - self.assertEquals(message1, message2) - - message1.field1 = 10 - self.assertNotEquals(message1, message2) - - message2.field1 = 20 - self.assertNotEquals(message1, message2) - - message2.field1 = 10 - self.assertEquals(message1, message2) - - message1.field2 = MyEnum.val1 - self.assertNotEquals(message1, message2) - - message2.field2 = MyEnum.val2 - self.assertNotEquals(message1, message2) - - message2.field2 = MyEnum.val1 - self.assertEquals(message1, message2) - - message1.field3 = AnotherMessage() - message1.field3.string = u'value1' - self.assertNotEquals(message1, message2) - - message2.field3 = AnotherMessage() - message2.field3.string = u'value2' - self.assertNotEquals(message1, message2) - - message2.field3.string = u'value1' - self.assertEquals(message1, message2) - - def testEqualityWithUnknowns(self): - """Test message class equality with unknown fields.""" - - class MyMessage(messages.Message): - field1 = messages.IntegerField(1) - - message1 = MyMessage() - message2 = MyMessage() - self.assertEquals(message1, message2) - message1.set_unrecognized_field('unknown1', 'value1', - messages.Variant.STRING) - self.assertEquals(message1, message2) - - message1.set_unrecognized_field('unknown2', ['asdf', 3], - messages.Variant.STRING) - message1.set_unrecognized_field('unknown3', 4.7, - messages.Variant.DOUBLE) - self.assertEquals(message1, message2) - - def testUnrecognizedFieldInvalidVariant(self): - class MyMessage(messages.Message): - field1 = messages.IntegerField(1) - - message1 = MyMessage() - self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', - {'unhandled': 'type'}, None) - self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', - {'unhandled': 'type'}, 123) - - def testRepr(self): - """Test represtation of Message object.""" - class MyMessage(messages.Message): - integer_value = messages.IntegerField(1) - string_value = messages.StringField(2) - unassigned = messages.StringField(3) - unassigned_with_default = messages.StringField(4, default=u'a default') - - my_message = MyMessage() - my_message.integer_value = 42 - my_message.string_value = u'A string' - - pat = re.compile(r"") - self.assertTrue(pat.match(repr(my_message)) is not None) - - def testValidation(self): - """Test validation of message values.""" - # Test optional. - class SubMessage(messages.Message): - pass - - class Message(messages.Message): - val = messages.MessageField(SubMessage, 1) - - message = Message() - - message_field = messages.MessageField(Message, 1) - message_field.validate(message) - message.val = SubMessage() - message_field.validate(message) - self.assertRaises(messages.ValidationError, - setattr, message, 'val', [SubMessage()]) - - # Test required. - class Message(messages.Message): - val = messages.MessageField(SubMessage, 1, required=True) - - message = Message() - - message_field = messages.MessageField(Message, 1) - message_field.validate(message) - message.val = SubMessage() - message_field.validate(message) - self.assertRaises(messages.ValidationError, - setattr, message, 'val', [SubMessage()]) - - # Test repeated. - class Message(messages.Message): - val = messages.MessageField(SubMessage, 1, repeated=True) - - message = Message() - - message_field = messages.MessageField(Message, 1) - message_field.validate(message) - self.assertRaisesWithRegexpMatch( - messages.ValidationError, - "Field val is repeated. Found: ", - setattr, message, 'val', SubMessage()) - message.val = [SubMessage()] - message_field.validate(message) - - def testDefinitionName(self): - """Test message name.""" - class MyMessage(messages.Message): - pass - - module_name = test_util.get_module_name(FieldTest) - self.assertEquals('%s.MyMessage' % module_name, - MyMessage.definition_name()) - self.assertEquals(module_name, MyMessage.outer_definition_name()) - self.assertEquals(module_name, MyMessage.definition_package()) - - self.assertEquals(six.text_type, type(MyMessage.definition_name())) - self.assertEquals(six.text_type, type(MyMessage.outer_definition_name())) - self.assertEquals(six.text_type, type(MyMessage.definition_package())) - - def testDefinitionName_OverrideModule(self): - """Test message module is overriden by module package name.""" - class MyMessage(messages.Message): - pass - - global package - package = 'my.package' - - try: - self.assertEquals('my.package.MyMessage', MyMessage.definition_name()) - self.assertEquals('my.package', MyMessage.outer_definition_name()) - self.assertEquals('my.package', MyMessage.definition_package()) - - self.assertEquals(six.text_type, type(MyMessage.definition_name())) - self.assertEquals(six.text_type, type(MyMessage.outer_definition_name())) - self.assertEquals(six.text_type, type(MyMessage.definition_package())) - finally: - del package - - def testDefinitionName_NoModule(self): - """Test what happens when there is no module for message.""" - class MyMessage(messages.Message): - pass - - original_modules = sys.modules - sys.modules = dict(sys.modules) - try: - del sys.modules[__name__] - self.assertEquals('MyMessage', MyMessage.definition_name()) - self.assertEquals(None, MyMessage.outer_definition_name()) - self.assertEquals(None, MyMessage.definition_package()) - - self.assertEquals(six.text_type, type(MyMessage.definition_name())) - finally: - sys.modules = original_modules - - def testDefinitionName_Nested(self): - """Test nested message names.""" - class MyMessage(messages.Message): - - class NestedMessage(messages.Message): - - class NestedMessage(messages.Message): - - pass - - module_name = test_util.get_module_name(MessageTest) - self.assertEquals('%s.MyMessage.NestedMessage' % module_name, - MyMessage.NestedMessage.definition_name()) - self.assertEquals('%s.MyMessage' % module_name, - MyMessage.NestedMessage.outer_definition_name()) - self.assertEquals(module_name, - MyMessage.NestedMessage.definition_package()) - - self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name, - MyMessage.NestedMessage.NestedMessage.definition_name()) - self.assertEquals( - '%s.MyMessage.NestedMessage' % module_name, - MyMessage.NestedMessage.NestedMessage.outer_definition_name()) - self.assertEquals( - module_name, - MyMessage.NestedMessage.NestedMessage.definition_package()) - - - def testMessageDefinition(self): - """Test that enumeration knows its enclosing message definition.""" - class OuterMessage(messages.Message): - - class InnerMessage(messages.Message): - pass - - self.assertEquals(None, OuterMessage.message_definition()) - self.assertEquals(OuterMessage, - OuterMessage.InnerMessage.message_definition()) - - def testConstructorKwargs(self): - """Test kwargs via constructor.""" - class SomeMessage(messages.Message): - name = messages.StringField(1) - number = messages.IntegerField(2) - - expected = SomeMessage() - expected.name = 'my name' - expected.number = 200 - self.assertEquals(expected, SomeMessage(name='my name', number=200)) - - def testConstructorNotAField(self): - """Test kwargs via constructor with wrong names.""" - class SomeMessage(messages.Message): - pass - - self.assertRaisesWithRegexpMatch( - AttributeError, - 'May not assign arbitrary value does_not_exist to message SomeMessage', - SomeMessage, - does_not_exist=10) - - def testGetUnsetRepeatedValue(self): - class SomeMessage(messages.Message): - repeated = messages.IntegerField(1, repeated=True) - - instance = SomeMessage() - self.assertEquals([], instance.repeated) - self.assertTrue(isinstance(instance.repeated, messages.FieldList)) - - def testCompareAutoInitializedRepeatedFields(self): - class SomeMessage(messages.Message): - repeated = messages.IntegerField(1, repeated=True) - - message1 = SomeMessage(repeated=[]) - message2 = SomeMessage() - self.assertEquals(message1, message2) - - def testUnknownValues(self): - """Test message class equality with unknown fields.""" - class MyMessage(messages.Message): - field1 = messages.IntegerField(1) - - message = MyMessage() - self.assertEquals([], message.all_unrecognized_fields()) - self.assertEquals((None, None), - message.get_unrecognized_field_info('doesntexist')) - self.assertEquals((None, None), - message.get_unrecognized_field_info( - 'doesntexist', None, None)) - self.assertEquals(('defaultvalue', 'defaultwire'), - message.get_unrecognized_field_info( - 'doesntexist', 'defaultvalue', 'defaultwire')) - self.assertEquals((3, None), - message.get_unrecognized_field_info( - 'doesntexist', value_default=3)) - - message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE) - self.assertEquals(1, len(message.all_unrecognized_fields())) - self.assertTrue('exists' in message.all_unrecognized_fields()) - self.assertEquals((9.5, messages.Variant.DOUBLE), - message.get_unrecognized_field_info('exists')) - self.assertEquals((9.5, messages.Variant.DOUBLE), - message.get_unrecognized_field_info('exists', 'type', - 1234)) - self.assertEquals((1234, None), - message.get_unrecognized_field_info('doesntexist', 1234)) - - message.set_unrecognized_field('another', 'value', messages.Variant.STRING) - self.assertEquals(2, len(message.all_unrecognized_fields())) - self.assertTrue('exists' in message.all_unrecognized_fields()) - self.assertTrue('another' in message.all_unrecognized_fields()) - self.assertEquals((9.5, messages.Variant.DOUBLE), - message.get_unrecognized_field_info('exists')) - self.assertEquals(('value', messages.Variant.STRING), - message.get_unrecognized_field_info('another')) - - message.set_unrecognized_field('typetest1', ['list', 0, ('test',)], - messages.Variant.STRING) - self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), - message.get_unrecognized_field_info('typetest1')) - message.set_unrecognized_field('typetest2', '', messages.Variant.STRING) - self.assertEquals(('', messages.Variant.STRING), - message.get_unrecognized_field_info('typetest2')) - - def testPickle(self): - """Testing pickling and unpickling of Message instances.""" - global MyEnum - global AnotherMessage - global MyMessage - - class MyEnum(messages.Enum): - val1 = 1 - val2 = 2 - - class AnotherMessage(messages.Message): - string = messages.StringField(1, repeated=True) - - class MyMessage(messages.Message): - field1 = messages.IntegerField(1) - field2 = messages.EnumField(MyEnum, 2) - field3 = messages.MessageField(AnotherMessage, 3) - - message = MyMessage(field1=1, field2=MyEnum.val2, - field3=AnotherMessage(string=['a', 'b', 'c'])) - message.set_unrecognized_field('exists', 'value', messages.Variant.STRING) - message.set_unrecognized_field('repeated', ['list', 0, ('test',)], - messages.Variant.STRING) - unpickled = pickle.loads(pickle.dumps(message)) - self.assertEquals(message, unpickled) - self.assertTrue(AnotherMessage.string is unpickled.field3.string.field) - self.assertTrue('exists' in message.all_unrecognized_fields()) - self.assertEquals(('value', messages.Variant.STRING), - message.get_unrecognized_field_info('exists')) - self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), - message.get_unrecognized_field_info('repeated')) - - -class FindDefinitionTest(test_util.TestCase): - """Test finding definitions relative to various definitions and modules.""" - - def setUp(self): - """Set up module-space. Starts off empty.""" - self.modules = {} - - def DefineModule(self, name): - """Define a module and its parents in module space. - - Modules that are already defined in self.modules are not re-created. - - Args: - name: Fully qualified name of modules to create. - - Returns: - Deepest nested module. For example: - - DefineModule('a.b.c') # Returns c. - """ - name_path = name.split('.') - full_path = [] - for node in name_path: - full_path.append(node) - full_name = '.'.join(full_path) - self.modules.setdefault(full_name, types.ModuleType(full_name)) - return self.modules[name] - - def DefineMessage(self, module, name, children={}, add_to_module=True): - """Define a new Message class in the context of a module. - - Used for easily describing complex Message hierarchy. Message is defined - including all child definitions. - - Args: - module: Fully qualified name of module to place Message class in. - name: Name of Message to define within module. - children: Define any level of nesting of children definitions. To define - a message, map the name to another dictionary. The dictionary can - itself contain additional definitions, and so on. To map to an Enum, - define the Enum class separately and map it by name. - add_to_module: If True, new Message class is added to module. If False, - new Message is not added. - """ - # Make sure module exists. - module_instance = self.DefineModule(module) - - # Recursively define all child messages. - for attribute, value in children.items(): - if isinstance(value, dict): - children[attribute] = self.DefineMessage( - module, attribute, value, False) - - # Override default __module__ variable. - children['__module__'] = module - - # Instantiate and possibly add to module. - message_class = type(name, (messages.Message,), dict(children)) - if add_to_module: - setattr(module_instance, name, message_class) - return message_class - - def Importer(self, module, globals='', locals='', fromlist=None): - """Importer function. - - Acts like __import__. Only loads modules from self.modules. Does not - try to load real modules defined elsewhere. Does not try to handle relative - imports. - - Args: - module: Fully qualified name of module to load from self.modules. - """ - if fromlist is None: - module = module.split('.')[0] - try: - return self.modules[module] - except KeyError: - raise ImportError() - - def testNoSuchModule(self): - """Test searching for definitions that do no exist.""" - self.assertRaises(messages.DefinitionNotFoundError, - messages.find_definition, - 'does.not.exist', - importer=self.Importer) - - def testRefersToModule(self): - """Test that referring to a module does not return that module.""" - self.DefineModule('i.am.a.module') - self.assertRaises(messages.DefinitionNotFoundError, - messages.find_definition, - 'i.am.a.module', - importer=self.Importer) - - def testNoDefinition(self): - """Test not finding a definition in an existing module.""" - self.DefineModule('i.am.a.module') - self.assertRaises(messages.DefinitionNotFoundError, - messages.find_definition, - 'i.am.a.module.MyMessage', - importer=self.Importer) - - def testNotADefinition(self): - """Test trying to fetch something that is not a definition.""" - module = self.DefineModule('i.am.a.module') - setattr(module, 'A', 'a string') - self.assertRaises(messages.DefinitionNotFoundError, - messages.find_definition, - 'i.am.a.module.A', - importer=self.Importer) - - def testGlobalFind(self): - """Test finding definitions from fully qualified module names.""" - A = self.DefineMessage('a.b.c', 'A', {}) - self.assertEquals(A, messages.find_definition('a.b.c.A', - importer=self.Importer)) - B = self.DefineMessage('a.b.c', 'B', {'C':{}}) - self.assertEquals(B.C, messages.find_definition('a.b.c.B.C', - importer=self.Importer)) - - def testRelativeToModule(self): - """Test finding definitions relative to modules.""" - # Define modules. - a = self.DefineModule('a') - b = self.DefineModule('a.b') - c = self.DefineModule('a.b.c') - - # Define messages. - A = self.DefineMessage('a', 'A') - B = self.DefineMessage('a.b', 'B') - C = self.DefineMessage('a.b.c', 'C') - D = self.DefineMessage('a.b.d', 'D') - - # Find A, B, C and D relative to a. - self.assertEquals(A, messages.find_definition( - 'A', a, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'b.B', a, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'b.c.C', a, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'b.d.D', a, importer=self.Importer)) - - # Find A, B, C and D relative to b. - self.assertEquals(A, messages.find_definition( - 'A', b, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'B', b, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'c.C', b, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'd.D', b, importer=self.Importer)) - - # Find A, B, C and D relative to c. Module d is the same case as c. - self.assertEquals(A, messages.find_definition( - 'A', c, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'B', c, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'C', c, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'd.D', c, importer=self.Importer)) - - def testRelativeToMessages(self): - """Test finding definitions relative to Message definitions.""" - A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}}) - B = A.B - C = A.B.C - D = A.B.D - - # Find relative to A. - self.assertEquals(A, messages.find_definition( - 'A', A, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'B', A, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'B.C', A, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'B.D', A, importer=self.Importer)) - - # Find relative to B. - self.assertEquals(A, messages.find_definition( - 'A', B, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'B', B, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'C', B, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'D', B, importer=self.Importer)) - - # Find relative to C. - self.assertEquals(A, messages.find_definition( - 'A', C, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'B', C, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'C', C, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'D', C, importer=self.Importer)) - - # Find relative to C searching from c. - self.assertEquals(A, messages.find_definition( - 'b.A', C, importer=self.Importer)) - self.assertEquals(B, messages.find_definition( - 'b.A.B', C, importer=self.Importer)) - self.assertEquals(C, messages.find_definition( - 'b.A.B.C', C, importer=self.Importer)) - self.assertEquals(D, messages.find_definition( - 'b.A.B.D', C, importer=self.Importer)) - - def testAbsoluteReference(self): - """Test finding absolute definition names.""" - # Define modules. - a = self.DefineModule('a') - b = self.DefineModule('a.a') - - # Define messages. - aA = self.DefineMessage('a', 'A') - aaA = self.DefineMessage('a.a', 'A') - - # Always find a.A. - self.assertEquals(aA, messages.find_definition('.a.A', None, - importer=self.Importer)) - self.assertEquals(aA, messages.find_definition('.a.A', a, - importer=self.Importer)) - self.assertEquals(aA, messages.find_definition('.a.A', aA, - importer=self.Importer)) - self.assertEquals(aA, messages.find_definition('.a.A', aaA, - importer=self.Importer)) - - def testFindEnum(self): - """Test that Enums are found.""" - class Color(messages.Enum): - pass - A = self.DefineMessage('a', 'A', {'Color': Color}) - - self.assertEquals( - Color, - messages.find_definition('Color', A, importer=self.Importer)) - - def testFalseScope(self): - """Test that Message definitions nested in strange objects are hidden.""" - global X - class X(object): - class A(messages.Message): - pass - - self.assertRaises(TypeError, messages.find_definition, 'A', X) - self.assertRaises(messages.DefinitionNotFoundError, - messages.find_definition, - 'X.A', sys.modules[__name__]) - - def testSearchAttributeFirst(self): - """Make sure not faked out by module, but continues searching.""" - A = self.DefineMessage('a', 'A') - module_A = self.DefineModule('a.A') - - self.assertEquals(A, messages.find_definition( - 'a.A', None, importer=self.Importer)) - - -class FindDefinitionUnicodeTests(test_util.TestCase): - - # TODO(craigcitro): Fix this test and re-enable it. - def notatestUnicodeString(self): - """Test using unicode names.""" - from protorpc import registry - self.assertEquals('ServiceMapping', - messages.find_definition( - u'protorpc.registry.ServiceMapping', - None).__name__) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/protobuf_test.py b/endpoints/internal/protorpc/protobuf_test.py deleted file mode 100644 index 9a65824..0000000 --- a/endpoints/internal/protorpc/protobuf_test.py +++ /dev/null @@ -1,299 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.protobuf.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import datetime -import unittest - -from protorpc import message_types -from protorpc import messages -from protorpc import protobuf -from protorpc import protorpc_test_pb2 -from protorpc import test_util -from protorpc import util - -# TODO: Add DateTimeFields to protorpc_test.proto when definition.py -# supports date time fields. -class HasDateTimeMessage(messages.Message): - value = message_types.DateTimeField(1) - -class NestedDateTimeMessage(messages.Message): - value = messages.MessageField(message_types.DateTimeMessage, 1) - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = protobuf - - -class EncodeMessageTest(test_util.TestCase, - test_util.ProtoConformanceTestBase): - """Test message to protocol buffer encoding.""" - - PROTOLIB = protobuf - - def assertErrorIs(self, exception, message, function, *params, **kwargs): - try: - function(*params, **kwargs) - self.fail('Expected to raise exception %s but did not.' % exception) - except exception as err: - self.assertEquals(message, str(err)) - - @property - def encoded_partial(self): - proto = protorpc_test_pb2.OptionalMessage() - proto.double_value = 1.23 - proto.int64_value = -100000000000 - proto.int32_value = 1020 - proto.string_value = u'a string' - proto.enum_value = protorpc_test_pb2.OptionalMessage.VAL2 - - return proto.SerializeToString() - - @property - def encoded_full(self): - proto = protorpc_test_pb2.OptionalMessage() - proto.double_value = 1.23 - proto.float_value = -2.5 - proto.int64_value = -100000000000 - proto.uint64_value = 102020202020 - proto.int32_value = 1020 - proto.bool_value = True - proto.string_value = u'a string\u044f' - proto.bytes_value = b'a bytes\xff\xfe' - proto.enum_value = protorpc_test_pb2.OptionalMessage.VAL2 - - return proto.SerializeToString() - - @property - def encoded_repeated(self): - proto = protorpc_test_pb2.RepeatedMessage() - proto.double_value.append(1.23) - proto.double_value.append(2.3) - proto.float_value.append(-2.5) - proto.float_value.append(0.5) - proto.int64_value.append(-100000000000) - proto.int64_value.append(20) - proto.uint64_value.append(102020202020) - proto.uint64_value.append(10) - proto.int32_value.append(1020) - proto.int32_value.append(718) - proto.bool_value.append(True) - proto.bool_value.append(False) - proto.string_value.append(u'a string\u044f') - proto.string_value.append(u'another string') - proto.bytes_value.append(b'a bytes\xff\xfe') - proto.bytes_value.append(b'another bytes') - proto.enum_value.append(protorpc_test_pb2.RepeatedMessage.VAL2) - proto.enum_value.append(protorpc_test_pb2.RepeatedMessage.VAL1) - - return proto.SerializeToString() - - @property - def encoded_nested(self): - proto = protorpc_test_pb2.HasNestedMessage() - proto.nested.a_value = 'a string' - - return proto.SerializeToString() - - @property - def encoded_repeated_nested(self): - proto = protorpc_test_pb2.HasNestedMessage() - proto.repeated_nested.add().a_value = 'a string' - proto.repeated_nested.add().a_value = 'another string' - - return proto.SerializeToString() - - unexpected_tag_message = ( - chr((15 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.NUMERIC) + - chr(5)) - - @property - def encoded_default_assigned(self): - proto = protorpc_test_pb2.HasDefault() - proto.a_value = test_util.HasDefault.a_value.default - return proto.SerializeToString() - - @property - def encoded_nested_empty(self): - proto = protorpc_test_pb2.HasOptionalNestedMessage() - proto.nested.Clear() - return proto.SerializeToString() - - @property - def encoded_repeated_nested_empty(self): - proto = protorpc_test_pb2.HasOptionalNestedMessage() - proto.repeated_nested.add() - proto.repeated_nested.add() - return proto.SerializeToString() - - @property - def encoded_extend_message(self): - proto = protorpc_test_pb2.RepeatedMessage() - proto.add_int64_value(400) - proto.add_int64_value(50) - proto.add_int64_value(6000) - return proto.SerializeToString() - - @property - def encoded_string_types(self): - proto = protorpc_test_pb2.OptionalMessage() - proto.string_value = u'Latin' - return proto.SerializeToString() - - @property - def encoded_invalid_enum(self): - encoder = protobuf._Encoder() - field_num = test_util.OptionalMessage.enum_value.number - tag = (field_num << protobuf._WIRE_TYPE_BITS) | encoder.NUMERIC - encoder.putVarInt32(tag) - encoder.putVarInt32(1000) - return encoder.buffer().tostring() - - def testDecodeWrongWireFormat(self): - """Test what happens when wrong wire format found in protobuf.""" - class ExpectedProto(messages.Message): - value = messages.StringField(1) - - class WrongVariant(messages.Message): - value = messages.IntegerField(1) - - original = WrongVariant() - original.value = 10 - self.assertErrorIs(messages.DecodeError, - 'Expected wire type STRING but found NUMERIC', - protobuf.decode_message, - ExpectedProto, - protobuf.encode_message(original)) - - def testDecodeBadWireType(self): - """Test what happens when non-existant wire type found in protobuf.""" - # Message has tag 1, type 3 which does not exist. - bad_wire_type_message = chr((1 << protobuf._WIRE_TYPE_BITS) | 3) - - self.assertErrorIs(messages.DecodeError, - 'No such wire type 3', - protobuf.decode_message, - test_util.OptionalMessage, - bad_wire_type_message) - - def testUnexpectedTagBelowOne(self): - """Test that completely invalid tags generate an error.""" - # Message has tag 0, type NUMERIC. - invalid_tag_message = chr(protobuf._Encoder.NUMERIC) - - self.assertErrorIs(messages.DecodeError, - 'Invalid tag value 0', - protobuf.decode_message, - test_util.OptionalMessage, - invalid_tag_message) - - def testProtocolBufferDecodeError(self): - """Test what happens when there a ProtocolBufferDecodeError. - - This is what happens when the underlying ProtocolBuffer library raises - it's own decode error. - """ - # Message has tag 1, type DOUBLE, missing value. - truncated_message = ( - chr((1 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.DOUBLE)) - - self.assertErrorIs(messages.DecodeError, - 'Decoding error: truncated', - protobuf.decode_message, - test_util.OptionalMessage, - truncated_message) - - def testProtobufUnrecognizedField(self): - """Test that unrecognized fields are serialized and can be accessed.""" - decoded = protobuf.decode_message(test_util.OptionalMessage, - self.unexpected_tag_message) - self.assertEquals(1, len(decoded.all_unrecognized_fields())) - self.assertEquals(15, decoded.all_unrecognized_fields()[0]) - self.assertEquals((5, messages.Variant.INT64), - decoded.get_unrecognized_field_info(15)) - - def testUnrecognizedFieldWrongFormat(self): - """Test that unrecognized fields in the wrong format are skipped.""" - - class SimpleMessage(messages.Message): - value = messages.IntegerField(1) - - message = SimpleMessage(value=3) - message.set_unrecognized_field('from_json', 'test', messages.Variant.STRING) - - encoded = protobuf.encode_message(message) - expected = ( - chr((1 << protobuf._WIRE_TYPE_BITS) | protobuf._Encoder.NUMERIC) + - chr(3)) - self.assertEquals(encoded, expected) - - def testProtobufDecodeDateTimeMessage(self): - """Test what happens when decoding a DateTimeMessage.""" - - nested = NestedDateTimeMessage() - nested.value = message_types.DateTimeMessage(milliseconds=2500) - value = protobuf.decode_message(HasDateTimeMessage, - protobuf.encode_message(nested)).value - self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 2, 500000), value) - - def testProtobufDecodeDateTimeMessageWithTimeZone(self): - """Test what happens when decoding a DateTimeMessage with a time zone.""" - nested = NestedDateTimeMessage() - nested.value = message_types.DateTimeMessage(milliseconds=12345678, - time_zone_offset=60) - value = protobuf.decode_message(HasDateTimeMessage, - protobuf.encode_message(nested)).value - self.assertEqual(datetime.datetime(1970, 1, 1, 3, 25, 45, 678000, - tzinfo=util.TimeZoneOffset(60)), - value) - - def testProtobufEncodeDateTimeMessage(self): - """Test what happens when encoding a DateTimeField.""" - mine = HasDateTimeMessage(value=datetime.datetime(1970, 1, 1)) - nested = NestedDateTimeMessage() - nested.value = message_types.DateTimeMessage(milliseconds=0) - - my_encoded = protobuf.encode_message(mine) - encoded = protobuf.encode_message(nested) - self.assertEquals(my_encoded, encoded) - - def testProtobufEncodeDateTimeMessageWithTimeZone(self): - """Test what happens when encoding a DateTimeField with a time zone.""" - for tz_offset in (30, -30, 8 * 60, 0): - mine = HasDateTimeMessage(value=datetime.datetime( - 1970, 1, 1, tzinfo=util.TimeZoneOffset(tz_offset))) - nested = NestedDateTimeMessage() - nested.value = message_types.DateTimeMessage( - milliseconds=0, time_zone_offset=tz_offset) - - my_encoded = protobuf.encode_message(mine) - encoded = protobuf.encode_message(nested) - self.assertEquals(my_encoded, encoded) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/protojson_test.py b/endpoints/internal/protorpc/protojson_test.py deleted file mode 100644 index b71f93f..0000000 --- a/endpoints/internal/protorpc/protojson_test.py +++ /dev/null @@ -1,565 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.protojson.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import datetime -import imp -import sys -import unittest - -from protorpc import message_types -from protorpc import messages -from protorpc import protojson -from protorpc import test_util - -try: - import json -except ImportError: - import simplejson as json - - -class CustomField(messages.MessageField): - """Custom MessageField class.""" - - type = int - message_type = message_types.VoidMessage - - def __init__(self, number, **kwargs): - super(CustomField, self).__init__(self.message_type, number, **kwargs) - - def value_to_message(self, value): - return self.message_type() - - -class MyMessage(messages.Message): - """Test message containing various types.""" - - class Color(messages.Enum): - - RED = 1 - GREEN = 2 - BLUE = 3 - - class Nested(messages.Message): - - nested_value = messages.StringField(1) - - a_string = messages.StringField(2) - an_integer = messages.IntegerField(3) - a_float = messages.FloatField(4) - a_boolean = messages.BooleanField(5) - an_enum = messages.EnumField(Color, 6) - a_nested = messages.MessageField(Nested, 7) - a_repeated = messages.IntegerField(8, repeated=True) - a_repeated_float = messages.FloatField(9, repeated=True) - a_datetime = message_types.DateTimeField(10) - a_repeated_datetime = message_types.DateTimeField(11, repeated=True) - a_custom = CustomField(12) - a_repeated_custom = CustomField(13, repeated=True) - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = protojson - - -# TODO(rafek): Convert this test to the compliance test in test_util. -class ProtojsonTest(test_util.TestCase, - test_util.ProtoConformanceTestBase): - """Test JSON encoding and decoding.""" - - PROTOLIB = protojson - - def CompareEncoded(self, expected_encoded, actual_encoded): - """JSON encoding will be laundered to remove string differences.""" - self.assertEquals(json.loads(expected_encoded), - json.loads(actual_encoded)) - - encoded_empty_message = '{}' - - encoded_partial = """{ - "double_value": 1.23, - "int64_value": -100000000000, - "int32_value": 1020, - "string_value": "a string", - "enum_value": "VAL2" - } - """ - - encoded_full = """{ - "double_value": 1.23, - "float_value": -2.5, - "int64_value": -100000000000, - "uint64_value": 102020202020, - "int32_value": 1020, - "bool_value": true, - "string_value": "a string\u044f", - "bytes_value": "YSBieXRlc//+", - "enum_value": "VAL2" - } - """ - - encoded_repeated = """{ - "double_value": [1.23, 2.3], - "float_value": [-2.5, 0.5], - "int64_value": [-100000000000, 20], - "uint64_value": [102020202020, 10], - "int32_value": [1020, 718], - "bool_value": [true, false], - "string_value": ["a string\u044f", "another string"], - "bytes_value": ["YSBieXRlc//+", "YW5vdGhlciBieXRlcw=="], - "enum_value": ["VAL2", "VAL1"] - } - """ - - encoded_nested = """{ - "nested": { - "a_value": "a string" - } - } - """ - - encoded_repeated_nested = """{ - "repeated_nested": [{"a_value": "a string"}, - {"a_value": "another string"}] - } - """ - - unexpected_tag_message = '{"unknown": "value"}' - - encoded_default_assigned = '{"a_value": "a default"}' - - encoded_nested_empty = '{"nested": {}}' - - encoded_repeated_nested_empty = '{"repeated_nested": [{}, {}]}' - - encoded_extend_message = '{"int64_value": [400, 50, 6000]}' - - encoded_string_types = '{"string_value": "Latin"}' - - encoded_invalid_enum = '{"enum_value": "undefined"}' - - def testConvertIntegerToFloat(self): - """Test that integers passed in to float fields are converted. - - This is necessary because JSON outputs integers for numbers with 0 decimals. - """ - message = protojson.decode_message(MyMessage, '{"a_float": 10}') - - self.assertTrue(isinstance(message.a_float, float)) - self.assertEquals(10.0, message.a_float) - - def testConvertStringToNumbers(self): - """Test that strings passed to integer fields are converted.""" - message = protojson.decode_message(MyMessage, - """{"an_integer": "10", - "a_float": "3.5", - "a_repeated": ["1", "2"], - "a_repeated_float": ["1.5", "2", 10] - }""") - - self.assertEquals(MyMessage(an_integer=10, - a_float=3.5, - a_repeated=[1, 2], - a_repeated_float=[1.5, 2.0, 10.0]), - message) - - def testWrongTypeAssignment(self): - """Test when wrong type is assigned to a field.""" - self.assertRaises(messages.ValidationError, - protojson.decode_message, - MyMessage, '{"a_string": 10}') - self.assertRaises(messages.ValidationError, - protojson.decode_message, - MyMessage, '{"an_integer": 10.2}') - self.assertRaises(messages.ValidationError, - protojson.decode_message, - MyMessage, '{"an_integer": "10.2"}') - - def testNumericEnumeration(self): - """Test that numbers work for enum values.""" - message = protojson.decode_message(MyMessage, '{"an_enum": 2}') - - expected_message = MyMessage() - expected_message.an_enum = MyMessage.Color.GREEN - - self.assertEquals(expected_message, message) - - def testNumericEnumerationNegativeTest(self): - """Test with an invalid number for the enum value.""" - self.assertRaisesRegexp( - messages.DecodeError, - 'Invalid enum value "89"', - protojson.decode_message, - MyMessage, - '{"an_enum": 89}') - - def testAlphaEnumeration(self): - """Test that alpha enum values work.""" - message = protojson.decode_message(MyMessage, '{"an_enum": "RED"}') - - expected_message = MyMessage() - expected_message.an_enum = MyMessage.Color.RED - - self.assertEquals(expected_message, message) - - def testAlphaEnumerationNegativeTest(self): - """The alpha enum value is invalid.""" - self.assertRaisesRegexp( - messages.DecodeError, - 'Invalid enum value "IAMINVALID"', - protojson.decode_message, - MyMessage, - '{"an_enum": "IAMINVALID"}') - - def testEnumerationNegativeTestWithEmptyString(self): - """The enum value is an empty string.""" - self.assertRaisesRegexp( - messages.DecodeError, - 'Invalid enum value ""', - protojson.decode_message, - MyMessage, - '{"an_enum": ""}') - - def testNullValues(self): - """Test that null values overwrite existing values.""" - self.assertEquals(MyMessage(), - protojson.decode_message(MyMessage, - ('{"an_integer": null,' - ' "a_nested": null,' - ' "an_enum": null' - '}'))) - - def testEmptyList(self): - """Test that empty lists are ignored.""" - self.assertEquals(MyMessage(), - protojson.decode_message(MyMessage, - '{"a_repeated": []}')) - - def testNotJSON(self): - """Test error when string is not valid JSON.""" - self.assertRaises(ValueError, - protojson.decode_message, MyMessage, '{this is not json}') - - def testDoNotEncodeStrangeObjects(self): - """Test trying to encode a strange object. - - The main purpose of this test is to complete coverage. It ensures that - the default behavior of the JSON encoder is preserved when someone tries to - serialized an unexpected type. - """ - class BogusObject(object): - - def check_initialized(self): - pass - - self.assertRaises(TypeError, - protojson.encode_message, - BogusObject()) - - def testMergeEmptyString(self): - """Test merging the empty or space only string.""" - message = protojson.decode_message(test_util.OptionalMessage, '') - self.assertEquals(test_util.OptionalMessage(), message) - - message = protojson.decode_message(test_util.OptionalMessage, ' ') - self.assertEquals(test_util.OptionalMessage(), message) - - def testMeregeInvalidEmptyMessage(self): - self.assertRaisesWithRegexpMatch(messages.ValidationError, - 'Message NestedMessage is missing ' - 'required field a_value', - self.PROTOLIB.decode_message, - test_util.NestedMessage, - '') - - def testProtojsonUnrecognizedFieldName(self): - """Test that unrecognized fields are saved and can be accessed.""" - decoded = protojson.decode_message(MyMessage, - ('{"an_integer": 1, "unknown_val": 2}')) - self.assertEquals(decoded.an_integer, 1) - self.assertEquals(1, len(decoded.all_unrecognized_fields())) - self.assertEquals('unknown_val', decoded.all_unrecognized_fields()[0]) - self.assertEquals((2, messages.Variant.INT64), - decoded.get_unrecognized_field_info('unknown_val')) - - def testProtojsonUnrecognizedFieldNumber(self): - """Test that unrecognized fields are saved and can be accessed.""" - decoded = protojson.decode_message( - MyMessage, - '{"an_integer": 1, "1001": "unknown", "-123": "negative", ' - '"456_mixed": 2}') - self.assertEquals(decoded.an_integer, 1) - self.assertEquals(3, len(decoded.all_unrecognized_fields())) - self.assertTrue(1001 in decoded.all_unrecognized_fields()) - self.assertEquals(('unknown', messages.Variant.STRING), - decoded.get_unrecognized_field_info(1001)) - self.assertTrue('-123' in decoded.all_unrecognized_fields()) - self.assertEquals(('negative', messages.Variant.STRING), - decoded.get_unrecognized_field_info('-123')) - self.assertTrue('456_mixed' in decoded.all_unrecognized_fields()) - self.assertEquals((2, messages.Variant.INT64), - decoded.get_unrecognized_field_info('456_mixed')) - - def testProtojsonUnrecognizedNull(self): - """Test that unrecognized fields that are None are skipped.""" - decoded = protojson.decode_message( - MyMessage, - '{"an_integer": 1, "unrecognized_null": null}') - self.assertEquals(decoded.an_integer, 1) - self.assertEquals(decoded.all_unrecognized_fields(), []) - - def testUnrecognizedFieldVariants(self): - """Test that unrecognized fields are mapped to the right variants.""" - for encoded, expected_variant in ( - ('{"an_integer": 1, "unknown_val": 2}', messages.Variant.INT64), - ('{"an_integer": 1, "unknown_val": 2.0}', messages.Variant.DOUBLE), - ('{"an_integer": 1, "unknown_val": "string value"}', - messages.Variant.STRING), - ('{"an_integer": 1, "unknown_val": [1, 2, 3]}', messages.Variant.INT64), - ('{"an_integer": 1, "unknown_val": [1, 2.0, 3]}', - messages.Variant.DOUBLE), - ('{"an_integer": 1, "unknown_val": [1, "foo", 3]}', - messages.Variant.STRING), - ('{"an_integer": 1, "unknown_val": true}', messages.Variant.BOOL)): - decoded = protojson.decode_message(MyMessage, encoded) - self.assertEquals(decoded.an_integer, 1) - self.assertEquals(1, len(decoded.all_unrecognized_fields())) - self.assertEquals('unknown_val', decoded.all_unrecognized_fields()[0]) - _, decoded_variant = decoded.get_unrecognized_field_info('unknown_val') - self.assertEquals(expected_variant, decoded_variant) - - def testDecodeDateTime(self): - for datetime_string, datetime_vals in ( - ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)), - ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): - message = protojson.decode_message( - MyMessage, '{"a_datetime": "%s"}' % datetime_string) - expected_message = MyMessage( - a_datetime=datetime.datetime(*datetime_vals)) - - self.assertEquals(expected_message, message) - - def testDecodeInvalidDateTime(self): - self.assertRaises(messages.DecodeError, protojson.decode_message, - MyMessage, '{"a_datetime": "invalid"}') - - def testEncodeDateTime(self): - for datetime_string, datetime_vals in ( - ('2012-09-30T15:31:50.262000', (2012, 9, 30, 15, 31, 50, 262000)), - ('2012-09-30T15:31:50.262123', (2012, 9, 30, 15, 31, 50, 262123)), - ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): - decoded_message = protojson.encode_message( - MyMessage(a_datetime=datetime.datetime(*datetime_vals))) - expected_decoding = '{"a_datetime": "%s"}' % datetime_string - self.CompareEncoded(expected_decoding, decoded_message) - - def testDecodeRepeatedDateTime(self): - message = protojson.decode_message( - MyMessage, - '{"a_repeated_datetime": ["2012-09-30T15:31:50.262", ' - '"2010-01-21T09:52:00", "2000-01-01T01:00:59.999999"]}') - expected_message = MyMessage( - a_repeated_datetime=[ - datetime.datetime(2012, 9, 30, 15, 31, 50, 262000), - datetime.datetime(2010, 1, 21, 9, 52), - datetime.datetime(2000, 1, 1, 1, 0, 59, 999999)]) - - self.assertEquals(expected_message, message) - - def testDecodeCustom(self): - message = protojson.decode_message(MyMessage, '{"a_custom": 1}') - self.assertEquals(MyMessage(a_custom=1), message) - - def testDecodeInvalidCustom(self): - self.assertRaises(messages.ValidationError, protojson.decode_message, - MyMessage, '{"a_custom": "invalid"}') - - def testEncodeCustom(self): - decoded_message = protojson.encode_message(MyMessage(a_custom=1)) - self.CompareEncoded('{"a_custom": 1}', decoded_message) - - def testDecodeRepeatedCustom(self): - message = protojson.decode_message( - MyMessage, '{"a_repeated_custom": [1, 2, 3]}') - self.assertEquals(MyMessage(a_repeated_custom=[1, 2, 3]), message) - - def testDecodeBadBase64BytesField(self): - """Test decoding improperly encoded base64 bytes value.""" - self.assertRaisesWithRegexpMatch( - messages.DecodeError, - 'Base64 decoding error: Incorrect padding', - protojson.decode_message, - test_util.OptionalMessage, - '{"bytes_value": "abcdefghijklmnopq"}') - - -class CustomProtoJson(protojson.ProtoJson): - - def encode_field(self, field, value): - return '{encoded}' + value - - def decode_field(self, field, value): - return '{decoded}' + value - - -class CustomProtoJsonTest(test_util.TestCase): - """Tests for serialization overriding functionality.""" - - def setUp(self): - self.protojson = CustomProtoJson() - - def testEncode(self): - self.assertEqual(u'{"a_string": "{encoded}xyz"}', - self.protojson.encode_message(MyMessage(a_string=u'xyz'))) - - def testDecode(self): - self.assertEqual( - MyMessage(a_string=u'{decoded}xyz'), - self.protojson.decode_message(MyMessage, u'{"a_string": "xyz"}')) - - def testDecodeEmptyMessage(self): - self.assertEqual( - MyMessage(a_string=u'{decoded}'), - self.protojson.decode_message(MyMessage, u'{"a_string": ""}')) - - def testDefault(self): - self.assertTrue(protojson.ProtoJson.get_default(), - protojson.ProtoJson.get_default()) - - instance = CustomProtoJson() - protojson.ProtoJson.set_default(instance) - self.assertTrue(instance is protojson.ProtoJson.get_default()) - - -class InvalidJsonModule(object): - pass - - -class ValidJsonModule(object): - class JSONEncoder(object): - pass - - -class TestJsonDependencyLoading(test_util.TestCase): - """Test loading various implementations of json.""" - - def get_import(self): - """Get __import__ method. - - Returns: - The current __import__ method. - """ - if isinstance(__builtins__, dict): - return __builtins__['__import__'] - else: - return __builtins__.__import__ - - def set_import(self, new_import): - """Set __import__ method. - - Args: - new_import: Function to replace __import__. - """ - if isinstance(__builtins__, dict): - __builtins__['__import__'] = new_import - else: - __builtins__.__import__ = new_import - - def setUp(self): - """Save original import function.""" - self.simplejson = sys.modules.pop('simplejson', None) - self.json = sys.modules.pop('json', None) - self.original_import = self.get_import() - def block_all_jsons(name, *args, **kwargs): - if 'json' in name: - if name in sys.modules: - module = sys.modules[name] - module.name = name - return module - raise ImportError('Unable to find %s' % name) - else: - return self.original_import(name, *args, **kwargs) - self.set_import(block_all_jsons) - - def tearDown(self): - """Restore original import functions and any loaded modules.""" - - def reset_module(name, module): - if module: - sys.modules[name] = module - else: - sys.modules.pop(name, None) - reset_module('simplejson', self.simplejson) - reset_module('json', self.json) - imp.reload(protojson) - - def testLoadProtojsonWithValidJsonModule(self): - """Test loading protojson module with a valid json dependency.""" - sys.modules['json'] = ValidJsonModule - - # This will cause protojson to reload with the default json module - # instead of simplejson. - imp.reload(protojson) - self.assertEquals('json', protojson.json.name) - - def testLoadProtojsonWithSimplejsonModule(self): - """Test loading protojson module with simplejson dependency.""" - sys.modules['simplejson'] = ValidJsonModule - - # This will cause protojson to reload with the default json module - # instead of simplejson. - imp.reload(protojson) - self.assertEquals('simplejson', protojson.json.name) - - def testLoadProtojsonWithInvalidJsonModule(self): - """Loading protojson module with an invalid json defaults to simplejson.""" - sys.modules['json'] = InvalidJsonModule - sys.modules['simplejson'] = ValidJsonModule - - # Ignore bad module and default back to simplejson. - imp.reload(protojson) - self.assertEquals('simplejson', protojson.json.name) - - def testLoadProtojsonWithInvalidJsonModuleAndNoSimplejson(self): - """Loading protojson module with invalid json and no simplejson.""" - sys.modules['json'] = InvalidJsonModule - - # Bad module without simplejson back raises errors. - self.assertRaisesWithRegexpMatch( - ImportError, - 'json library "json" is not compatible with ProtoRPC', - imp.reload, - protojson) - - def testLoadProtojsonWithNoJsonModules(self): - """Loading protojson module with invalid json and no simplejson.""" - # No json modules raise the first exception. - self.assertRaisesWithRegexpMatch( - ImportError, - 'Unable to find json', - imp.reload, - protojson) - - -if __name__ == '__main__': - unittest.main() diff --git a/endpoints/internal/protorpc/protorpc_test_pb2.py b/endpoints/internal/protorpc/protorpc_test_pb2.py deleted file mode 100644 index 1dc3852..0000000 --- a/endpoints/internal/protorpc/protorpc_test_pb2.py +++ /dev/null @@ -1,405 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT (except the imports)! - -# Replace auto generated imports with .non_sdk_imports manually! -# Do the replacement and copy this comment everytime! -from .non_sdk_imports import descriptor -from .non_sdk_imports import message -from .non_sdk_imports import reflection -from .non_sdk_imports import descriptor_pb2 -import six - -# @@protoc_insertion_point(imports) - - - -DESCRIPTOR = descriptor.FileDescriptor( - name='protorpc_test.proto', - package='protorpc', - serialized_pb='\n\x13protorpc_test.proto\x12\x08protorpc\" \n\rNestedMessage\x12\x0f\n\x07\x61_value\x18\x01 \x02(\t\"m\n\x10HasNestedMessage\x12\'\n\x06nested\x18\x01 \x01(\x0b\x32\x17.protorpc.NestedMessage\x12\x30\n\x0frepeated_nested\x18\x02 \x03(\x0b\x32\x17.protorpc.NestedMessage\"(\n\nHasDefault\x12\x1a\n\x07\x61_value\x18\x01 \x01(\t:\ta default\"\x97\x02\n\x0fOptionalMessage\x12\x14\n\x0c\x64ouble_value\x18\x01 \x01(\x01\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x13\n\x0bint64_value\x18\x03 \x01(\x03\x12\x14\n\x0cuint64_value\x18\x04 \x01(\x04\x12\x13\n\x0bint32_value\x18\x05 \x01(\x05\x12\x12\n\nbool_value\x18\x06 \x01(\x08\x12\x14\n\x0cstring_value\x18\x07 \x01(\t\x12\x13\n\x0b\x62ytes_value\x18\x08 \x01(\x0c\x12\x38\n\nenum_value\x18\n \x01(\x0e\x32$.protorpc.OptionalMessage.SimpleEnum\" \n\nSimpleEnum\x12\x08\n\x04VAL1\x10\x01\x12\x08\n\x04VAL2\x10\x02\"\x97\x02\n\x0fRepeatedMessage\x12\x14\n\x0c\x64ouble_value\x18\x01 \x03(\x01\x12\x13\n\x0b\x66loat_value\x18\x02 \x03(\x02\x12\x13\n\x0bint64_value\x18\x03 \x03(\x03\x12\x14\n\x0cuint64_value\x18\x04 \x03(\x04\x12\x13\n\x0bint32_value\x18\x05 \x03(\x05\x12\x12\n\nbool_value\x18\x06 \x03(\x08\x12\x14\n\x0cstring_value\x18\x07 \x03(\t\x12\x13\n\x0b\x62ytes_value\x18\x08 \x03(\x0c\x12\x38\n\nenum_value\x18\n \x03(\x0e\x32$.protorpc.RepeatedMessage.SimpleEnum\" \n\nSimpleEnum\x12\x08\n\x04VAL1\x10\x01\x12\x08\n\x04VAL2\x10\x02\"y\n\x18HasOptionalNestedMessage\x12)\n\x06nested\x18\x01 \x01(\x0b\x32\x19.protorpc.OptionalMessage\x12\x32\n\x0frepeated_nested\x18\x02 \x03(\x0b\x32\x19.protorpc.OptionalMessage') - - - -_OPTIONALMESSAGE_SIMPLEENUM = descriptor.EnumDescriptor( - name='SimpleEnum', - full_name='protorpc.OptionalMessage.SimpleEnum', - filename=None, - file=DESCRIPTOR, - values=[ - descriptor.EnumValueDescriptor( - name='VAL1', index=0, number=1, - options=None, - type=None), - descriptor.EnumValueDescriptor( - name='VAL2', index=1, number=2, - options=None, - type=None), - ], - containing_type=None, - options=None, - serialized_start=468, - serialized_end=500, -) - -_REPEATEDMESSAGE_SIMPLEENUM = descriptor.EnumDescriptor( - name='SimpleEnum', - full_name='protorpc.RepeatedMessage.SimpleEnum', - filename=None, - file=DESCRIPTOR, - values=[ - descriptor.EnumValueDescriptor( - name='VAL1', index=0, number=1, - options=None, - type=None), - descriptor.EnumValueDescriptor( - name='VAL2', index=1, number=2, - options=None, - type=None), - ], - containing_type=None, - options=None, - serialized_start=468, - serialized_end=500, -) - - -_NESTEDMESSAGE = descriptor.Descriptor( - name='NestedMessage', - full_name='protorpc.NestedMessage', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='a_value', full_name='protorpc.NestedMessage.a_value', index=0, - number=1, type=9, cpp_type=9, label=2, - has_default_value=False, default_value=six.text_type("", "utf-8"), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - extension_ranges=[], - serialized_start=33, - serialized_end=65, -) - - -_HASNESTEDMESSAGE = descriptor.Descriptor( - name='HasNestedMessage', - full_name='protorpc.HasNestedMessage', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='nested', full_name='protorpc.HasNestedMessage.nested', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='repeated_nested', full_name='protorpc.HasNestedMessage.repeated_nested', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - extension_ranges=[], - serialized_start=67, - serialized_end=176, -) - - -_HASDEFAULT = descriptor.Descriptor( - name='HasDefault', - full_name='protorpc.HasDefault', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='a_value', full_name='protorpc.HasDefault.a_value', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=True, default_value=six.text_type("a default", "utf-8"), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - extension_ranges=[], - serialized_start=178, - serialized_end=218, -) - - -_OPTIONALMESSAGE = descriptor.Descriptor( - name='OptionalMessage', - full_name='protorpc.OptionalMessage', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='double_value', full_name='protorpc.OptionalMessage.double_value', index=0, - number=1, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='float_value', full_name='protorpc.OptionalMessage.float_value', index=1, - number=2, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='int64_value', full_name='protorpc.OptionalMessage.int64_value', index=2, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='uint64_value', full_name='protorpc.OptionalMessage.uint64_value', index=3, - number=4, type=4, cpp_type=4, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='int32_value', full_name='protorpc.OptionalMessage.int32_value', index=4, - number=5, type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='bool_value', full_name='protorpc.OptionalMessage.bool_value', index=5, - number=6, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='string_value', full_name='protorpc.OptionalMessage.string_value', index=6, - number=7, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=six.text_type("", "utf-8"), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='bytes_value', full_name='protorpc.OptionalMessage.bytes_value', index=7, - number=8, type=12, cpp_type=9, label=1, - has_default_value=False, default_value="", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='enum_value', full_name='protorpc.OptionalMessage.enum_value', index=8, - number=10, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=1, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - _OPTIONALMESSAGE_SIMPLEENUM, - ], - options=None, - is_extendable=False, - extension_ranges=[], - serialized_start=221, - serialized_end=500, -) - - -_REPEATEDMESSAGE = descriptor.Descriptor( - name='RepeatedMessage', - full_name='protorpc.RepeatedMessage', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='double_value', full_name='protorpc.RepeatedMessage.double_value', index=0, - number=1, type=1, cpp_type=5, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='float_value', full_name='protorpc.RepeatedMessage.float_value', index=1, - number=2, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='int64_value', full_name='protorpc.RepeatedMessage.int64_value', index=2, - number=3, type=3, cpp_type=2, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='uint64_value', full_name='protorpc.RepeatedMessage.uint64_value', index=3, - number=4, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='int32_value', full_name='protorpc.RepeatedMessage.int32_value', index=4, - number=5, type=5, cpp_type=1, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='bool_value', full_name='protorpc.RepeatedMessage.bool_value', index=5, - number=6, type=8, cpp_type=7, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='string_value', full_name='protorpc.RepeatedMessage.string_value', index=6, - number=7, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='bytes_value', full_name='protorpc.RepeatedMessage.bytes_value', index=7, - number=8, type=12, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='enum_value', full_name='protorpc.RepeatedMessage.enum_value', index=8, - number=10, type=14, cpp_type=8, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - _REPEATEDMESSAGE_SIMPLEENUM, - ], - options=None, - is_extendable=False, - extension_ranges=[], - serialized_start=503, - serialized_end=782, -) - - -_HASOPTIONALNESTEDMESSAGE = descriptor.Descriptor( - name='HasOptionalNestedMessage', - full_name='protorpc.HasOptionalNestedMessage', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='nested', full_name='protorpc.HasOptionalNestedMessage.nested', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - descriptor.FieldDescriptor( - name='repeated_nested', full_name='protorpc.HasOptionalNestedMessage.repeated_nested', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - extension_ranges=[], - serialized_start=784, - serialized_end=905, -) - -_HASNESTEDMESSAGE.fields_by_name['nested'].message_type = _NESTEDMESSAGE -_HASNESTEDMESSAGE.fields_by_name['repeated_nested'].message_type = _NESTEDMESSAGE -_OPTIONALMESSAGE.fields_by_name['enum_value'].enum_type = _OPTIONALMESSAGE_SIMPLEENUM -_OPTIONALMESSAGE_SIMPLEENUM.containing_type = _OPTIONALMESSAGE; -_REPEATEDMESSAGE.fields_by_name['enum_value'].enum_type = _REPEATEDMESSAGE_SIMPLEENUM -_REPEATEDMESSAGE_SIMPLEENUM.containing_type = _REPEATEDMESSAGE; -_HASOPTIONALNESTEDMESSAGE.fields_by_name['nested'].message_type = _OPTIONALMESSAGE -_HASOPTIONALNESTEDMESSAGE.fields_by_name['repeated_nested'].message_type = _OPTIONALMESSAGE -DESCRIPTOR.message_types_by_name['NestedMessage'] = _NESTEDMESSAGE -DESCRIPTOR.message_types_by_name['HasNestedMessage'] = _HASNESTEDMESSAGE -DESCRIPTOR.message_types_by_name['HasDefault'] = _HASDEFAULT -DESCRIPTOR.message_types_by_name['OptionalMessage'] = _OPTIONALMESSAGE -DESCRIPTOR.message_types_by_name['RepeatedMessage'] = _REPEATEDMESSAGE -DESCRIPTOR.message_types_by_name['HasOptionalNestedMessage'] = _HASOPTIONALNESTEDMESSAGE - -class NestedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): - DESCRIPTOR = _NESTEDMESSAGE - - # @@protoc_insertion_point(class_scope:protorpc.NestedMessage) - -class HasNestedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): - DESCRIPTOR = _HASNESTEDMESSAGE - - # @@protoc_insertion_point(class_scope:protorpc.HasNestedMessage) - -class HasDefault(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): - DESCRIPTOR = _HASDEFAULT - - # @@protoc_insertion_point(class_scope:protorpc.HasDefault) - -class OptionalMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): - DESCRIPTOR = _OPTIONALMESSAGE - - # @@protoc_insertion_point(class_scope:protorpc.OptionalMessage) - -class RepeatedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): - DESCRIPTOR = _REPEATEDMESSAGE - - # @@protoc_insertion_point(class_scope:protorpc.RepeatedMessage) - -class HasOptionalNestedMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): - DESCRIPTOR = _HASOPTIONALNESTEDMESSAGE - - # @@protoc_insertion_point(class_scope:protorpc.HasOptionalNestedMessage) - -# @@protoc_insertion_point(module_scope) diff --git a/endpoints/internal/protorpc/protourlencode_test.py b/endpoints/internal/protorpc/protourlencode_test.py deleted file mode 100644 index 0121896..0000000 --- a/endpoints/internal/protorpc/protourlencode_test.py +++ /dev/null @@ -1,369 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.protourlencode.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import cgi -import logging -import unittest -import urllib - -from protorpc import message_types -from protorpc import messages -from protorpc import protourlencode -from protorpc import test_util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = protourlencode - - - -class SuperMessage(messages.Message): - """A test message with a nested message field.""" - - sub_message = messages.MessageField(test_util.OptionalMessage, 1) - sub_messages = messages.MessageField(test_util.OptionalMessage, - 2, - repeated=True) - - -class SuperSuperMessage(messages.Message): - """A test message with two levels of nested.""" - - sub_message = messages.MessageField(SuperMessage, 1) - sub_messages = messages.MessageField(SuperMessage, 2, repeated=True) - - -class URLEncodedRequestBuilderTest(test_util.TestCase): - """Test the URL Encoded request builder.""" - - def testMakePath(self): - builder = protourlencode.URLEncodedRequestBuilder(SuperSuperMessage(), - prefix='pre.') - - self.assertEquals(None, builder.make_path('')) - self.assertEquals(None, builder.make_path('no_such_field')) - self.assertEquals(None, builder.make_path('pre.no_such_field')) - - # Missing prefix. - self.assertEquals(None, builder.make_path('sub_message')) - - # Valid parameters. - self.assertEquals((('sub_message', None),), - builder.make_path('pre.sub_message')) - self.assertEquals((('sub_message', None), ('sub_messages', 1)), - builder.make_path('pre.sub_message.sub_messages-1')) - self.assertEquals( - (('sub_message', None), - ('sub_messages', 1), - ('int64_value', None)), - builder.make_path('pre.sub_message.sub_messages-1.int64_value')) - - # Missing index. - self.assertEquals( - None, - builder.make_path('pre.sub_message.sub_messages.integer_field')) - - # Has unexpected index. - self.assertEquals( - None, - builder.make_path('pre.sub_message.sub_message-1.integer_field')) - - def testAddParameter_SimpleAttributes(self): - message = test_util.OptionalMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - self.assertTrue(builder.add_parameter('pre.int64_value', ['10'])) - self.assertTrue(builder.add_parameter('pre.string_value', ['a string'])) - self.assertTrue(builder.add_parameter('pre.enum_value', ['VAL1'])) - self.assertEquals(10, message.int64_value) - self.assertEquals('a string', message.string_value) - self.assertEquals(test_util.OptionalMessage.SimpleEnum.VAL1, - message.enum_value) - - def testAddParameter_InvalidAttributes(self): - message = SuperSuperMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - def assert_empty(): - self.assertEquals(None, getattr(message, 'sub_message')) - self.assertEquals([], getattr(message, 'sub_messages')) - - self.assertFalse(builder.add_parameter('pre.nothing', ['x'])) - assert_empty() - - self.assertFalse(builder.add_parameter('pre.sub_messages', ['x'])) - self.assertFalse(builder.add_parameter('pre.sub_messages-1.nothing', ['x'])) - assert_empty() - - def testAddParameter_NestedAttributes(self): - message = SuperSuperMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - # Set an empty message fields. - self.assertTrue(builder.add_parameter('pre.sub_message', [''])) - self.assertTrue(isinstance(message.sub_message, SuperMessage)) - - # Add a basic attribute. - self.assertTrue(builder.add_parameter( - 'pre.sub_message.sub_message.int64_value', ['10'])) - self.assertTrue(builder.add_parameter( - 'pre.sub_message.sub_message.string_value', ['hello'])) - - self.assertTrue(10, message.sub_message.sub_message.int64_value) - self.assertTrue('hello', message.sub_message.sub_message.string_value) - - - def testAddParameter_NestedMessages(self): - message = SuperSuperMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - # Add a repeated empty message. - self.assertTrue(builder.add_parameter( - 'pre.sub_message.sub_messages-0', [''])) - sub_message = message.sub_message.sub_messages[0] - self.assertTrue(1, len(message.sub_message.sub_messages)) - self.assertTrue(isinstance(sub_message, - test_util.OptionalMessage)) - self.assertEquals(None, getattr(sub_message, 'int64_value')) - self.assertEquals(None, getattr(sub_message, 'string_value')) - self.assertEquals(None, getattr(sub_message, 'enum_value')) - - # Add a repeated message with value. - self.assertTrue(builder.add_parameter( - 'pre.sub_message.sub_messages-1.int64_value', ['10'])) - self.assertTrue(2, len(message.sub_message.sub_messages)) - self.assertTrue(10, message.sub_message.sub_messages[1].int64_value) - - # Add another value to the same nested message. - self.assertTrue(builder.add_parameter( - 'pre.sub_message.sub_messages-1.string_value', ['a string'])) - self.assertTrue(2, len(message.sub_message.sub_messages)) - self.assertEquals(10, message.sub_message.sub_messages[1].int64_value) - self.assertEquals('a string', - message.sub_message.sub_messages[1].string_value) - - def testAddParameter_RepeatedValues(self): - message = test_util.RepeatedMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - self.assertTrue(builder.add_parameter('pre.int64_value-0', ['20'])) - self.assertTrue(builder.add_parameter('pre.int64_value-1', ['30'])) - self.assertEquals([20, 30], message.int64_value) - - self.assertTrue(builder.add_parameter('pre.string_value-0', ['hi'])) - self.assertTrue(builder.add_parameter('pre.string_value-1', ['lo'])) - self.assertTrue(builder.add_parameter('pre.string_value-1', ['dups overwrite'])) - self.assertEquals(['hi', 'dups overwrite'], message.string_value) - - def testAddParameter_InvalidValuesMayRepeat(self): - message = test_util.OptionalMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - self.assertFalse(builder.add_parameter('nothing', [1, 2, 3])) - - def testAddParameter_RepeatedParameters(self): - message = test_util.OptionalMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - self.assertRaises(messages.DecodeError, - builder.add_parameter, - 'pre.int64_value', - [1, 2, 3]) - self.assertRaises(messages.DecodeError, - builder.add_parameter, - 'pre.int64_value', - []) - - def testAddParameter_UnexpectedNestedValue(self): - """Test getting a nested value on a non-message sub-field.""" - message = test_util.HasNestedMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, 'pre.') - - self.assertFalse(builder.add_parameter('pre.nested.a_value.whatever', - ['1'])) - - def testInvalidFieldFormat(self): - message = test_util.OptionalMessage() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - self.assertFalse(builder.add_parameter('pre.illegal%20', ['1'])) - - def testAddParameter_UnexpectedNestedValue(self): - """Test getting a nested value on a non-message sub-field - - There is an odd corner case where if trying to insert a repeated value - on an nested repeated message that would normally succeed in being created - should fail. This case can only be tested when the first message of the - nested messages already exists. - - Another case is trying to access an indexed value nested within a - non-message field. - """ - class HasRepeated(messages.Message): - - values = messages.IntegerField(1, repeated=True) - - class HasNestedRepeated(messages.Message): - - nested = messages.MessageField(HasRepeated, 1, repeated=True) - - - message = HasNestedRepeated() - builder = protourlencode.URLEncodedRequestBuilder(message, prefix='pre.') - - self.assertTrue(builder.add_parameter('pre.nested-0.values-0', ['1'])) - # Try to create an indexed value on a non-message field. - self.assertFalse(builder.add_parameter('pre.nested-0.values-0.unknown-0', - ['1'])) - # Try to create an out of range indexed field on an otherwise valid - # repeated message field. - self.assertFalse(builder.add_parameter('pre.nested-1.values-1', ['1'])) - - -class ProtourlencodeConformanceTest(test_util.TestCase, - test_util.ProtoConformanceTestBase): - - PROTOLIB = protourlencode - - encoded_partial = urllib.urlencode([('double_value', 1.23), - ('int64_value', -100000000000), - ('int32_value', 1020), - ('string_value', u'a string'), - ('enum_value', 'VAL2'), - ]) - - encoded_full = urllib.urlencode([('double_value', 1.23), - ('float_value', -2.5), - ('int64_value', -100000000000), - ('uint64_value', 102020202020), - ('int32_value', 1020), - ('bool_value', 'true'), - ('string_value', - u'a string\u044f'.encode('utf-8')), - ('bytes_value', b'a bytes\xff\xfe'), - ('enum_value', 'VAL2'), - ]) - - encoded_repeated = urllib.urlencode([('double_value-0', 1.23), - ('double_value-1', 2.3), - ('float_value-0', -2.5), - ('float_value-1', 0.5), - ('int64_value-0', -100000000000), - ('int64_value-1', 20), - ('uint64_value-0', 102020202020), - ('uint64_value-1', 10), - ('int32_value-0', 1020), - ('int32_value-1', 718), - ('bool_value-0', 'true'), - ('bool_value-1', 'false'), - ('string_value-0', - u'a string\u044f'.encode('utf-8')), - ('string_value-1', - u'another string'.encode('utf-8')), - ('bytes_value-0', b'a bytes\xff\xfe'), - ('bytes_value-1', b'another bytes'), - ('enum_value-0', 'VAL2'), - ('enum_value-1', 'VAL1'), - ]) - - encoded_nested = urllib.urlencode([('nested.a_value', 'a string'), - ]) - - encoded_repeated_nested = urllib.urlencode( - [('repeated_nested-0.a_value', 'a string'), - ('repeated_nested-1.a_value', 'another string'), - ]) - - unexpected_tag_message = 'unexpected=whatever' - - encoded_default_assigned = urllib.urlencode([('a_value', 'a default'), - ]) - - encoded_nested_empty = urllib.urlencode([('nested', '')]) - - encoded_repeated_nested_empty = urllib.urlencode([('repeated_nested-0', ''), - ('repeated_nested-1', '')]) - - encoded_extend_message = urllib.urlencode([('int64_value-0', 400), - ('int64_value-1', 50), - ('int64_value-2', 6000)]) - - encoded_string_types = urllib.urlencode( - [('string_value', 'Latin')]) - - encoded_invalid_enum = urllib.urlencode([('enum_value', 'undefined')]) - - def testParameterPrefix(self): - """Test using the 'prefix' parameter to encode_message.""" - class MyMessage(messages.Message): - number = messages.IntegerField(1) - names = messages.StringField(2, repeated=True) - - message = MyMessage() - message.number = 10 - message.names = [u'Fred', u'Lisa'] - - encoded_message = protourlencode.encode_message(message, prefix='prefix-') - self.assertEquals({'prefix-number': ['10'], - 'prefix-names-0': ['Fred'], - 'prefix-names-1': ['Lisa'], - }, - cgi.parse_qs(encoded_message)) - - self.assertEquals(message, protourlencode.decode_message(MyMessage, - encoded_message, - prefix='prefix-')) - - def testProtourlencodeUnrecognizedField(self): - """Test that unrecognized fields are saved and can be accessed.""" - - class MyMessage(messages.Message): - number = messages.IntegerField(1) - - decoded = protourlencode.decode_message(MyMessage, - self.unexpected_tag_message) - self.assertEquals(1, len(decoded.all_unrecognized_fields())) - self.assertEquals('unexpected', decoded.all_unrecognized_fields()[0]) - # Unknown values set to a list of however many values had that name. - self.assertEquals((['whatever'], messages.Variant.STRING), - decoded.get_unrecognized_field_info('unexpected')) - - repeated_unknown = urllib.urlencode([('repeated', 400), - ('repeated', 'test'), - ('repeated', '123.456')]) - decoded2 = protourlencode.decode_message(MyMessage, repeated_unknown) - self.assertEquals((['400', 'test', '123.456'], messages.Variant.STRING), - decoded2.get_unrecognized_field_info('repeated')) - - def testDecodeInvalidDateTime(self): - - class MyMessage(messages.Message): - a_datetime = message_types.DateTimeField(1) - - self.assertRaises(messages.DecodeError, protourlencode.decode_message, - MyMessage, 'a_datetime=invalid') - - -if __name__ == '__main__': - unittest.main() diff --git a/endpoints/internal/protorpc/registry_test.py b/endpoints/internal/protorpc/registry_test.py deleted file mode 100644 index ec30a3f..0000000 --- a/endpoints/internal/protorpc/registry_test.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.message.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import sys -import unittest - -from protorpc import descriptor -from protorpc import message_types -from protorpc import messages -from protorpc import registry -from protorpc import remote -from protorpc import test_util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = registry - - -class MyService1(remote.Service): - """Test service that refers to messages in another module.""" - - @remote.method(test_util.NestedMessage, test_util.NestedMessage) - def a_method(self, request): - pass - - -class MyService2(remote.Service): - """Test service that does not refer to messages in another module.""" - - -class RegistryServiceTest(test_util.TestCase): - - def setUp(self): - self.registry = { - 'my-service1': MyService1, - 'my-service2': MyService2, - } - - self.modules = { - __name__: sys.modules[__name__], - test_util.__name__: test_util, - } - - self.registry_service = registry.RegistryService(self.registry, - modules=self.modules) - - def CheckServiceMappings(self, mappings): - module_name = test_util.get_module_name(RegistryServiceTest) - service1_mapping = registry.ServiceMapping() - service1_mapping.name = 'my-service1' - service1_mapping.definition = '%s.MyService1' % module_name - - service2_mapping = registry.ServiceMapping() - service2_mapping.name = 'my-service2' - service2_mapping.definition = '%s.MyService2' % module_name - - self.assertIterEqual(mappings, [service1_mapping, service2_mapping]) - - def testServices(self): - response = self.registry_service.services(message_types.VoidMessage()) - - self.CheckServiceMappings(response.services) - - def testGetFileSet_All(self): - request = registry.GetFileSetRequest() - request.names = ['my-service1', 'my-service2'] - response = self.registry_service.get_file_set(request) - - expected_file_set = descriptor.describe_file_set(list(self.modules.values())) - self.assertIterEqual(expected_file_set.files, response.file_set.files) - - def testGetFileSet_None(self): - request = registry.GetFileSetRequest() - response = self.registry_service.get_file_set(request) - - self.assertEquals(descriptor.FileSet(), - response.file_set) - - def testGetFileSet_ReferenceOtherModules(self): - request = registry.GetFileSetRequest() - request.names = ['my-service1'] - response = self.registry_service.get_file_set(request) - - # Will suck in and describe the test_util module. - expected_file_set = descriptor.describe_file_set(list(self.modules.values())) - self.assertIterEqual(expected_file_set.files, response.file_set.files) - - def testGetFileSet_DoNotReferenceOtherModules(self): - request = registry.GetFileSetRequest() - request.names = ['my-service2'] - response = self.registry_service.get_file_set(request) - - # Service does not reference test_util, so will only describe this module. - expected_file_set = descriptor.describe_file_set([self.modules[__name__]]) - self.assertIterEqual(expected_file_set.files, response.file_set.files) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/remote_test.py b/endpoints/internal/protorpc/remote_test.py deleted file mode 100644 index 155dcb8..0000000 --- a/endpoints/internal/protorpc/remote_test.py +++ /dev/null @@ -1,933 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.remote.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import sys -import types -import unittest -from wsgiref import headers - -from protorpc import descriptor -from protorpc import message_types -from protorpc import messages -from protorpc import protobuf -from protorpc import protojson -from protorpc import remote -from protorpc import test_util -from protorpc import transport - -import mox - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = remote - - -class Request(messages.Message): - """Test request message.""" - - value = messages.StringField(1) - - -class Response(messages.Message): - """Test response message.""" - - value = messages.StringField(1) - - -class MyService(remote.Service): - - @remote.method(Request, Response) - def remote_method(self, request): - response = Response() - response.value = request.value - return response - - -class SimpleRequest(messages.Message): - """Simple request message type used for tests.""" - - param1 = messages.StringField(1) - param2 = messages.StringField(2) - - -class SimpleResponse(messages.Message): - """Simple response message type used for tests.""" - - -class BasicService(remote.Service): - """A basic service with decorated remote method.""" - - def __init__(self): - self.request_ids = [] - - @remote.method(SimpleRequest, SimpleResponse) - def remote_method(self, request): - """BasicService remote_method docstring.""" - self.request_ids.append(id(request)) - return SimpleResponse() - - -class RpcErrorTest(test_util.TestCase): - - def testFromStatus(self): - for state in remote.RpcState: - exception = remote.RpcError.from_state - self.assertEquals(remote.ServerError, - remote.RpcError.from_state('SERVER_ERROR')) - - -class ApplicationErrorTest(test_util.TestCase): - - def testErrorCode(self): - self.assertEquals('blam', - remote.ApplicationError('an error', 'blam').error_name) - - def testStr(self): - self.assertEquals('an error', str(remote.ApplicationError('an error', 1))) - - def testRepr(self): - self.assertEquals("ApplicationError('an error', 1)", - repr(remote.ApplicationError('an error', 1))) - - self.assertEquals("ApplicationError('an error')", - repr(remote.ApplicationError('an error'))) - - -class MethodTest(test_util.TestCase): - """Test remote method decorator.""" - - def testMethod(self): - """Test use of remote decorator.""" - self.assertEquals(SimpleRequest, - BasicService.remote_method.remote.request_type) - self.assertEquals(SimpleResponse, - BasicService.remote_method.remote.response_type) - self.assertTrue(isinstance(BasicService.remote_method.remote.method, - types.FunctionType)) - - def testMethodMessageResolution(self): - """Test use of remote decorator to resolve message types by name.""" - class OtherService(remote.Service): - - @remote.method('SimpleRequest', 'SimpleResponse') - def remote_method(self, request): - pass - - self.assertEquals(SimpleRequest, - OtherService.remote_method.remote.request_type) - self.assertEquals(SimpleResponse, - OtherService.remote_method.remote.response_type) - - def testMethodMessageResolution_NotFound(self): - """Test failure to find message types.""" - class OtherService(remote.Service): - - @remote.method('NoSuchRequest', 'NoSuchResponse') - def remote_method(self, request): - pass - - self.assertRaisesWithRegexpMatch( - messages.DefinitionNotFoundError, - 'Could not find definition for NoSuchRequest', - getattr, - OtherService.remote_method.remote, - 'request_type') - - self.assertRaisesWithRegexpMatch( - messages.DefinitionNotFoundError, - 'Could not find definition for NoSuchResponse', - getattr, - OtherService.remote_method.remote, - 'response_type') - - def testInvocation(self): - """Test that invocation passes request through properly.""" - service = BasicService() - request = SimpleRequest() - self.assertEquals(SimpleResponse(), service.remote_method(request)) - self.assertEquals([id(request)], service.request_ids) - - def testInvocation_WrongRequestType(self): - """Wrong request type passed to remote method.""" - service = BasicService() - - self.assertRaises(remote.RequestError, - service.remote_method, - 'wrong') - - self.assertRaises(remote.RequestError, - service.remote_method, - None) - - self.assertRaises(remote.RequestError, - service.remote_method, - SimpleResponse()) - - def testInvocation_WrongResponseType(self): - """Wrong response type returned from remote method.""" - - class AnotherService(object): - - @remote.method(SimpleRequest, SimpleResponse) - def remote_method(self, unused_request): - return self.return_this - - service = AnotherService() - - service.return_this = 'wrong' - self.assertRaises(remote.ServerError, - service.remote_method, - SimpleRequest()) - service.return_this = None - self.assertRaises(remote.ServerError, - service.remote_method, - SimpleRequest()) - service.return_this = SimpleRequest() - self.assertRaises(remote.ServerError, - service.remote_method, - SimpleRequest()) - - def testBadRequestType(self): - """Test bad request types used in remote definition.""" - - for request_type in (None, 1020, messages.Message, str): - - def declare(): - class BadService(object): - - @remote.method(request_type, SimpleResponse) - def remote_method(self, request): - pass - - self.assertRaises(TypeError, declare) - - def testBadResponseType(self): - """Test bad response types used in remote definition.""" - - for response_type in (None, 1020, messages.Message, str): - - def declare(): - class BadService(object): - - @remote.method(SimpleRequest, response_type) - def remote_method(self, request): - pass - - self.assertRaises(TypeError, declare) - - def testDocString(self): - """Test that the docstring comes from the original method.""" - service = BasicService() - self.assertEquals('BasicService remote_method docstring.', - service.remote_method.__doc__) - - -class GetRemoteMethodTest(test_util.TestCase): - """Test for is_remote_method.""" - - def testGetRemoteMethod(self): - """Test valid remote method detection.""" - - class Service(object): - - @remote.method(Request, Response) - def remote_method(self, request): - pass - - self.assertEquals(Service.remote_method.remote, - remote.get_remote_method_info(Service.remote_method)) - self.assertTrue(Service.remote_method.remote, - remote.get_remote_method_info(Service().remote_method)) - - def testGetNotRemoteMethod(self): - """Test positive result on a remote method.""" - - class NotService(object): - - def not_remote_method(self, request): - pass - - def fn(self): - pass - - class NotReallyRemote(object): - """Test negative result on many bad values for remote methods.""" - - def not_really(self, request): - pass - - not_really.remote = 'something else' - - for not_remote in [NotService.not_remote_method, - NotService().not_remote_method, - NotReallyRemote.not_really, - NotReallyRemote().not_really, - None, - 1, - 'a string', - fn]: - self.assertEquals(None, remote.get_remote_method_info(not_remote)) - - -class RequestStateTest(test_util.TestCase): - """Test request state.""" - - STATE_CLASS = remote.RequestState - - def testConstructor(self): - """Test constructor.""" - state = self.STATE_CLASS(remote_host='remote-host', - remote_address='remote-address', - server_host='server-host', - server_port=10) - self.assertEquals('remote-host', state.remote_host) - self.assertEquals('remote-address', state.remote_address) - self.assertEquals('server-host', state.server_host) - self.assertEquals(10, state.server_port) - - state = self.STATE_CLASS() - self.assertEquals(None, state.remote_host) - self.assertEquals(None, state.remote_address) - self.assertEquals(None, state.server_host) - self.assertEquals(None, state.server_port) - - def testConstructorError(self): - """Test unexpected keyword argument.""" - self.assertRaises(TypeError, - self.STATE_CLASS, - x=10) - - def testRepr(self): - """Test string representation.""" - self.assertEquals('<%s>' % self.STATE_CLASS.__name__, - repr(self.STATE_CLASS())) - self.assertEquals("<%s remote_host='abc'>" % self.STATE_CLASS.__name__, - repr(self.STATE_CLASS(remote_host='abc'))) - self.assertEquals("<%s remote_host='abc' " - "remote_address='def'>" % self.STATE_CLASS.__name__, - repr(self.STATE_CLASS(remote_host='abc', - remote_address='def'))) - self.assertEquals("<%s remote_host='abc' " - "remote_address='def' " - "server_host='ghi'>" % self.STATE_CLASS.__name__, - repr(self.STATE_CLASS(remote_host='abc', - remote_address='def', - server_host='ghi'))) - self.assertEquals("<%s remote_host='abc' " - "remote_address='def' " - "server_host='ghi' " - 'server_port=102>' % self.STATE_CLASS.__name__, - repr(self.STATE_CLASS(remote_host='abc', - remote_address='def', - server_host='ghi', - server_port=102))) - - -class HttpRequestStateTest(RequestStateTest): - - STATE_CLASS = remote.HttpRequestState - - def testHttpMethod(self): - state = remote.HttpRequestState(http_method='GET') - self.assertEquals('GET', state.http_method) - - def testHttpMethod(self): - state = remote.HttpRequestState(service_path='/bar') - self.assertEquals('/bar', state.service_path) - - def testHeadersList(self): - state = remote.HttpRequestState( - headers=[('a', 'b'), ('c', 'd'), ('c', 'e')]) - - self.assertEquals(['a', 'c', 'c'], list(state.headers.keys())) - self.assertEquals(['b'], state.headers.get_all('a')) - self.assertEquals(['d', 'e'], state.headers.get_all('c')) - - def testHeadersDict(self): - state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']}) - - self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys())) - self.assertEquals(['b'], state.headers.get_all('a')) - self.assertEquals(['d', 'e'], state.headers.get_all('c')) - - def testRepr(self): - super(HttpRequestStateTest, self).testRepr() - - self.assertEquals("<%s remote_host='abc' " - "remote_address='def' " - "server_host='ghi' " - 'server_port=102 ' - "http_method='POST' " - "service_path='/bar' " - "headers=[('a', 'b'), ('c', 'd')]>" % - self.STATE_CLASS.__name__, - repr(self.STATE_CLASS(remote_host='abc', - remote_address='def', - server_host='ghi', - server_port=102, - http_method='POST', - service_path='/bar', - headers={'a': 'b', 'c': 'd'}, - ))) - - -class ServiceTest(test_util.TestCase): - """Test Service class.""" - - def testServiceBase_AllRemoteMethods(self): - """Test that service base class has no remote methods.""" - self.assertEquals({}, remote.Service.all_remote_methods()) - - def testAllRemoteMethods(self): - """Test all_remote_methods with properly Service subclass.""" - self.assertEquals({'remote_method': MyService.remote_method}, - MyService.all_remote_methods()) - - def testAllRemoteMethods_SubClass(self): - """Test all_remote_methods on a sub-class of a service.""" - class SubClass(MyService): - - @remote.method(Request, Response) - def sub_class_method(self, request): - pass - - self.assertEquals({'remote_method': SubClass.remote_method, - 'sub_class_method': SubClass.sub_class_method, - }, - SubClass.all_remote_methods()) - - def testOverrideMethod(self): - """Test that trying to override a remote method with remote decorator.""" - class SubClass(MyService): - - def remote_method(self, request): - response = super(SubClass, self).remote_method(request) - response.value = '(%s)' % response.value - return response - - self.assertEquals({'remote_method': SubClass.remote_method, - }, - SubClass.all_remote_methods()) - - instance = SubClass() - self.assertEquals('(Hello)', - instance.remote_method(Request(value='Hello')).value) - self.assertEquals(Request, SubClass.remote_method.remote.request_type) - self.assertEquals(Response, SubClass.remote_method.remote.response_type) - - def testOverrideMethodWithRemote(self): - """Test trying to override a remote method with remote decorator.""" - def do_override(): - class SubClass(MyService): - - @remote.method(Request, Response) - def remote_method(self, request): - pass - - self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, - 'Do not use method decorator when ' - 'overloading remote method remote_method ' - 'on service SubClass', - do_override) - - def testOverrideMethodWithInvalidValue(self): - """Test trying to override a remote method with remote decorator.""" - def do_override(bad_value): - class SubClass(MyService): - - remote_method = bad_value - - for bad_value in [None, 1, 'string', {}]: - self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError, - 'Must override remote_method in ' - 'SubClass with a method', - do_override, bad_value) - - def testCallingRemoteMethod(self): - """Test invoking a remote method.""" - expected = Response() - expected.value = 'what was passed in' - - request = Request() - request.value = 'what was passed in' - - service = MyService() - self.assertEquals(expected, service.remote_method(request)) - - def testFactory(self): - """Test using factory to pass in state.""" - class StatefulService(remote.Service): - - def __init__(self, a, b, c=None): - self.a = a - self.b = b - self.c = c - - state = [1, 2, 3] - - factory = StatefulService.new_factory(1, state) - - module_name = ServiceTest.__module__ - pattern = ('Creates new instances of service StatefulService.\n\n' - 'Returns:\n' - ' New instance of %s.StatefulService.' % module_name) - self.assertEqual(pattern, factory.__doc__) - self.assertEquals('StatefulService_service_factory', factory.__name__) - self.assertEquals(StatefulService, factory.service_class) - - service = factory() - self.assertEquals(1, service.a) - self.assertEquals(id(state), id(service.b)) - self.assertEquals(None, service.c) - - factory = StatefulService.new_factory(2, b=3, c=4) - service = factory() - self.assertEquals(2, service.a) - self.assertEquals(3, service.b) - self.assertEquals(4, service.c) - - def testFactoryError(self): - """Test misusing a factory.""" - # Passing positional argument that is not accepted by class. - self.assertRaises(TypeError, remote.Service.new_factory(1)) - - # Passing keyword argument that is not accepted by class. - self.assertRaises(TypeError, remote.Service.new_factory(x=1)) - - class StatefulService(remote.Service): - - def __init__(self, a): - pass - - # Missing required parameter. - self.assertRaises(TypeError, StatefulService.new_factory()) - - def testDefinitionName(self): - """Test getting service definition name.""" - class TheService(remote.Service): - pass - - module_name = test_util.get_module_name(ServiceTest) - self.assertEqual(TheService.definition_name(), - '%s.TheService' % module_name) - self.assertTrue(TheService.outer_definition_name(), - module_name) - self.assertTrue(TheService.definition_package(), - module_name) - - def testDefinitionNameWithPackage(self): - """Test getting service definition name when package defined.""" - global package - package = 'my.package' - try: - class TheService(remote.Service): - pass - - self.assertEquals('my.package.TheService', TheService.definition_name()) - self.assertEquals('my.package', TheService.outer_definition_name()) - self.assertEquals('my.package', TheService.definition_package()) - finally: - del package - - def testDefinitionNameWithNoModule(self): - """Test getting service definition name when package defined.""" - module = sys.modules[__name__] - try: - del sys.modules[__name__] - class TheService(remote.Service): - pass - - self.assertEquals('TheService', TheService.definition_name()) - self.assertEquals(None, TheService.outer_definition_name()) - self.assertEquals(None, TheService.definition_package()) - finally: - sys.modules[__name__] = module - - -class StubTest(test_util.TestCase): - - def setUp(self): - self.mox = mox.Mox() - self.transport = self.mox.CreateMockAnything() - - def testDefinitionName(self): - self.assertEquals(BasicService.definition_name(), - BasicService.Stub.definition_name()) - self.assertEquals(BasicService.outer_definition_name(), - BasicService.Stub.outer_definition_name()) - self.assertEquals(BasicService.definition_package(), - BasicService.Stub.definition_package()) - - def testRemoteMethods(self): - self.assertEquals(BasicService.all_remote_methods(), - BasicService.Stub.all_remote_methods()) - - def testSync_WithRequest(self): - stub = BasicService.Stub(self.transport) - - request = SimpleRequest() - request.param1 = 'val1' - request.param2 = 'val2' - response = SimpleResponse() - - rpc = transport.Rpc(request) - rpc.set_response(response) - self.transport.send_rpc(BasicService.remote_method.remote, - request).AndReturn(rpc) - - self.mox.ReplayAll() - - self.assertEquals(SimpleResponse(), stub.remote_method(request)) - - self.mox.VerifyAll() - - def testSync_WithKwargs(self): - stub = BasicService.Stub(self.transport) - - - request = SimpleRequest() - request.param1 = 'val1' - request.param2 = 'val2' - response = SimpleResponse() - - rpc = transport.Rpc(request) - rpc.set_response(response) - self.transport.send_rpc(BasicService.remote_method.remote, - request).AndReturn(rpc) - - self.mox.ReplayAll() - - self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1', - param2='val2')) - - self.mox.VerifyAll() - - def testAsync_WithRequest(self): - stub = BasicService.Stub(self.transport) - - request = SimpleRequest() - request.param1 = 'val1' - request.param2 = 'val2' - response = SimpleResponse() - - rpc = transport.Rpc(request) - - self.transport.send_rpc(BasicService.remote_method.remote, - request).AndReturn(rpc) - - self.mox.ReplayAll() - - self.assertEquals(rpc, stub.async.remote_method(request)) - - self.mox.VerifyAll() - - def testAsync_WithKwargs(self): - stub = BasicService.Stub(self.transport) - - request = SimpleRequest() - request.param1 = 'val1' - request.param2 = 'val2' - response = SimpleResponse() - - rpc = transport.Rpc(request) - - self.transport.send_rpc(BasicService.remote_method.remote, - request).AndReturn(rpc) - - self.mox.ReplayAll() - - self.assertEquals(rpc, stub.async.remote_method(param1='val1', - param2='val2')) - - self.mox.VerifyAll() - - def testAsync_WithRequestAndKwargs(self): - stub = BasicService.Stub(self.transport) - - request = SimpleRequest() - request.param1 = 'val1' - request.param2 = 'val2' - response = SimpleResponse() - - self.mox.ReplayAll() - - self.assertRaisesWithRegexpMatch( - TypeError, - r'May not provide both args and kwargs', - stub.async.remote_method, - request, - param1='val1', - param2='val2') - - self.mox.VerifyAll() - - def testAsync_WithTooManyPositionals(self): - stub = BasicService.Stub(self.transport) - - request = SimpleRequest() - request.param1 = 'val1' - request.param2 = 'val2' - response = SimpleResponse() - - self.mox.ReplayAll() - - self.assertRaisesWithRegexpMatch( - TypeError, - r'remote_method\(\) takes at most 2 positional arguments \(3 given\)', - stub.async.remote_method, - request, 'another value') - - self.mox.VerifyAll() - - -class IsErrorStatusTest(test_util.TestCase): - - def testIsError(self): - for state in (s for s in remote.RpcState if s > remote.RpcState.RUNNING): - status = remote.RpcStatus(state=state) - self.assertTrue(remote.is_error_status(status)) - - def testIsNotError(self): - for state in (s for s in remote.RpcState if s <= remote.RpcState.RUNNING): - status = remote.RpcStatus(state=state) - self.assertFalse(remote.is_error_status(status)) - - def testStateNone(self): - self.assertRaises(messages.ValidationError, - remote.is_error_status, remote.RpcStatus()) - - -class CheckRpcStatusTest(test_util.TestCase): - - def testStateNone(self): - self.assertRaises(messages.ValidationError, - remote.check_rpc_status, remote.RpcStatus()) - - def testNoError(self): - for state in (remote.RpcState.OK, remote.RpcState.RUNNING): - remote.check_rpc_status(remote.RpcStatus(state=state)) - - def testErrorState(self): - status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR, - error_message='a request error') - self.assertRaisesWithRegexpMatch(remote.RequestError, - 'a request error', - remote.check_rpc_status, status) - - def testApplicationErrorState(self): - status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, - error_message='an application error', - error_name='blam') - try: - remote.check_rpc_status(status) - self.fail('Should have raised application error.') - except remote.ApplicationError as err: - self.assertEquals('an application error', str(err)) - self.assertEquals('blam', err.error_name) - - -class ProtocolConfigTest(test_util.TestCase): - - def testConstructor(self): - config = remote.ProtocolConfig( - protojson, - 'proto1', - 'application/X-Json', - iter(['text/Json', 'text/JavaScript'])) - self.assertEquals(protojson, config.protocol) - self.assertEquals('proto1', config.name) - self.assertEquals('application/x-json', config.default_content_type) - self.assertEquals(('text/json', 'text/javascript'), - config.alternate_content_types) - self.assertEquals(('application/x-json', 'text/json', 'text/javascript'), - config.content_types) - - def testConstructorDefaults(self): - config = remote.ProtocolConfig(protojson, 'proto2') - self.assertEquals(protojson, config.protocol) - self.assertEquals('proto2', config.name) - self.assertEquals('application/json', config.default_content_type) - self.assertEquals(('application/x-javascript', - 'text/javascript', - 'text/x-javascript', - 'text/x-json', - 'text/json'), - config.alternate_content_types) - self.assertEquals(('application/json', - 'application/x-javascript', - 'text/javascript', - 'text/x-javascript', - 'text/x-json', - 'text/json'), config.content_types) - - def testEmptyAlternativeTypes(self): - config = remote.ProtocolConfig(protojson, 'proto2', - alternative_content_types=()) - self.assertEquals(protojson, config.protocol) - self.assertEquals('proto2', config.name) - self.assertEquals('application/json', config.default_content_type) - self.assertEquals((), config.alternate_content_types) - self.assertEquals(('application/json',), config.content_types) - - def testDuplicateContentTypes(self): - self.assertRaises(remote.ServiceConfigurationError, - remote.ProtocolConfig, - protojson, - 'json', - 'text/plain', - ('text/plain',)) - - self.assertRaises(remote.ServiceConfigurationError, - remote.ProtocolConfig, - protojson, - 'json', - 'text/plain', - ('text/html', 'text/html')) - - def testEncodeMessage(self): - config = remote.ProtocolConfig(protojson, 'proto2') - encoded_message = config.encode_message( - remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, - error_message='bad error')) - - # Convert back to a dictionary from JSON. - dict_message = protojson.json.loads(encoded_message) - self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'}, - dict_message) - - def testDecodeMessage(self): - config = remote.ProtocolConfig(protojson, 'proto2') - self.assertEquals( - remote.RpcStatus(state=remote.RpcState.SERVER_ERROR, - error_message="bad error"), - config.decode_message( - remote.RpcStatus, - '{"state": "SERVER_ERROR", "error_message": "bad error"}')) - - -class ProtocolsTest(test_util.TestCase): - - def setUp(self): - self.protocols = remote.Protocols() - - def testEmpty(self): - self.assertEquals((), self.protocols.names) - self.assertEquals((), self.protocols.content_types) - - def testAddProtocolAllDefaults(self): - self.protocols.add_protocol(protojson, 'json') - self.assertEquals(('json',), self.protocols.names) - self.assertEquals(('application/json', - 'application/x-javascript', - 'text/javascript', - 'text/json', - 'text/x-javascript', - 'text/x-json'), - self.protocols.content_types) - - def testAddProtocolNoDefaultAlternatives(self): - class Protocol(object): - CONTENT_TYPE = 'text/plain' - - self.protocols.add_protocol(Protocol, 'text') - self.assertEquals(('text',), self.protocols.names) - self.assertEquals(('text/plain',), self.protocols.content_types) - - def testAddProtocolOverrideDefaults(self): - self.protocols.add_protocol(protojson, 'json', - default_content_type='text/blar', - alternative_content_types=('text/blam', - 'text/blim')) - self.assertEquals(('json',), self.protocols.names) - self.assertEquals(('text/blam', 'text/blar', 'text/blim'), - self.protocols.content_types) - - def testLookupByName(self): - self.protocols.add_protocol(protojson, 'json') - self.protocols.add_protocol(protojson, 'json2', - default_content_type='text/plain', - alternative_content_types=()) - - self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name) - self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name) - - def testLookupByContentType(self): - self.protocols.add_protocol(protojson, 'json') - self.protocols.add_protocol(protojson, 'json2', - default_content_type='text/plain', - alternative_content_types=()) - - self.assertEquals( - 'json', - self.protocols.lookup_by_content_type('AppliCation/Json').name) - - self.assertEquals( - 'json', - self.protocols.lookup_by_content_type('text/x-Json').name) - - self.assertEquals( - 'json2', - self.protocols.lookup_by_content_type('text/Plain').name) - - def testNewDefault(self): - protocols = remote.Protocols.new_default() - self.assertEquals(('protobuf', 'protojson'), protocols.names) - - protobuf_protocol = protocols.lookup_by_name('protobuf') - self.assertEquals(protobuf, protobuf_protocol.protocol) - - protojson_protocol = protocols.lookup_by_name('protojson') - self.assertEquals(protojson.ProtoJson.get_default(), - protojson_protocol.protocol) - - def testGetDefaultProtocols(self): - protocols = remote.Protocols.get_default() - self.assertEquals(('protobuf', 'protojson'), protocols.names) - - protobuf_protocol = protocols.lookup_by_name('protobuf') - self.assertEquals(protobuf, protobuf_protocol.protocol) - - protojson_protocol = protocols.lookup_by_name('protojson') - self.assertEquals(protojson.ProtoJson.get_default(), - protojson_protocol.protocol) - - self.assertTrue(protocols is remote.Protocols.get_default()) - - def testSetDefaultProtocols(self): - protocols = remote.Protocols() - remote.Protocols.set_default(protocols) - self.assertTrue(protocols is remote.Protocols.get_default()) - - def testSetDefaultWithoutProtocols(self): - self.assertRaises(TypeError, remote.Protocols.set_default, None) - self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols') - self.assertRaises(TypeError, remote.Protocols.set_default, {}) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/test_util.py b/endpoints/internal/protorpc/test_util.py deleted file mode 100644 index bcbccf6..0000000 --- a/endpoints/internal/protorpc/test_util.py +++ /dev/null @@ -1,671 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Test utilities for message testing. - -Includes module interface test to ensure that public parts of module are -correctly declared in __all__. - -Includes message types that correspond to those defined in -services_test.proto. - -Includes additional test utilities to make sure encoding/decoding libraries -conform. -""" -from six.moves import range - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -import cgi -import datetime -import inspect -import os -import re -import socket -import types -import unittest2 as unittest - -import six - -from . import message_types -from . import messages -from . import util - -# Unicode of the word "Russian" in cyrillic. -RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439' - -# All characters binary value interspersed with nulls. -BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256)) - - -class TestCase(unittest.TestCase): - - def assertRaisesWithRegexpMatch(self, - exception, - regexp, - function, - *params, - **kwargs): - """Check that exception is raised and text matches regular expression. - - Args: - exception: Exception type that is expected. - regexp: String regular expression that is expected in error message. - function: Callable to test. - params: Parameters to forward to function. - kwargs: Keyword arguments to forward to function. - """ - try: - function(*params, **kwargs) - self.fail('Expected exception %s was not raised' % exception.__name__) - except exception as err: - match = bool(re.match(regexp, str(err))) - self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp, - err)) - - def assertHeaderSame(self, header1, header2): - """Check that two HTTP headers are the same. - - Args: - header1: Header value string 1. - header2: header value string 2. - """ - value1, params1 = cgi.parse_header(header1) - value2, params2 = cgi.parse_header(header2) - self.assertEqual(value1, value2) - self.assertEqual(params1, params2) - - def assertIterEqual(self, iter1, iter2): - """Check that two iterators or iterables are equal independent of order. - - Similar to Python 2.7 assertItemsEqual. Named differently in order to - avoid potential conflict. - - Args: - iter1: An iterator or iterable. - iter2: An iterator or iterable. - """ - list1 = list(iter1) - list2 = list(iter2) - - unmatched1 = list() - - while list1: - item1 = list1[0] - del list1[0] - for index in range(len(list2)): - if item1 == list2[index]: - del list2[index] - break - else: - unmatched1.append(item1) - - error_message = [] - for item in unmatched1: - error_message.append( - ' Item from iter1 not found in iter2: %r' % item) - for item in list2: - error_message.append( - ' Item from iter2 not found in iter1: %r' % item) - if error_message: - self.fail('Collections not equivalent:\n' + '\n'.join(error_message)) - - -class ModuleInterfaceTest(object): - """Test to ensure module interface is carefully constructed. - - A module interface is the set of public objects listed in the module __all__ - attribute. Modules that that are considered public should have this interface - carefully declared. At all times, the __all__ attribute should have objects - intended to be publically used and all other objects in the module should be - considered unused. - - Protected attributes (those beginning with '_') and other imported modules - should not be part of this set of variables. An exception is for variables - that begin and end with '__' which are implicitly part of the interface - (eg. __name__, __file__, __all__ itself, etc.). - - Modules that are imported in to the tested modules are an exception and may - be left out of the __all__ definition. The test is done by checking the value - of what would otherwise be a public name and not allowing it to be exported - if it is an instance of a module. Modules that are explicitly exported are - for the time being not permitted. - - To use this test class a module should define a new class that inherits first - from ModuleInterfaceTest and then from test_util.TestCase. No other tests - should be added to this test case, making the order of inheritance less - important, but if setUp for some reason is overidden, it is important that - ModuleInterfaceTest is first in the list so that its setUp method is - invoked. - - Multiple inheretance is required so that ModuleInterfaceTest is not itself - a test, and is not itself executed as one. - - The test class is expected to have the following class attributes defined: - - MODULE: A reference to the module that is being validated for interface - correctness. - - Example: - Module definition (hello.py): - - import sys - - __all__ = ['hello'] - - def _get_outputter(): - return sys.stdout - - def hello(): - _get_outputter().write('Hello\n') - - Test definition: - - import unittest - from protorpc import test_util - - import hello - - class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = hello - - - class HelloTest(test_util.TestCase): - ... Test 'hello' module ... - - - if __name__ == '__main__': - unittest.main() - """ - - def setUp(self): - """Set up makes sure that MODULE and IMPORTED_MODULES is defined. - - This is a basic configuration test for the test itself so does not - get it's own test case. - """ - if not hasattr(self, 'MODULE'): - self.fail( - "You must define 'MODULE' on ModuleInterfaceTest sub-class %s." % - type(self).__name__) - - def testAllExist(self): - """Test that all attributes defined in __all__ exist.""" - missing_attributes = [] - for attribute in self.MODULE.__all__: - if not hasattr(self.MODULE, attribute): - missing_attributes.append(attribute) - if missing_attributes: - self.fail('%s of __all__ are not defined in module.' % - missing_attributes) - - def testAllExported(self): - """Test that all public attributes not imported are in __all__.""" - missing_attributes = [] - for attribute in dir(self.MODULE): - if not attribute.startswith('_'): - if (attribute not in self.MODULE.__all__ and - not isinstance(getattr(self.MODULE, attribute), - types.ModuleType) and - attribute != 'with_statement'): - missing_attributes.append(attribute) - if missing_attributes: - self.fail('%s are not modules and not defined in __all__.' % - missing_attributes) - - def testNoExportedProtectedVariables(self): - """Test that there are no protected variables listed in __all__.""" - protected_variables = [] - for attribute in self.MODULE.__all__: - if attribute.startswith('_'): - protected_variables.append(attribute) - if protected_variables: - self.fail('%s are protected variables and may not be exported.' % - protected_variables) - - def testNoExportedModules(self): - """Test that no modules exist in __all__.""" - exported_modules = [] - for attribute in self.MODULE.__all__: - try: - value = getattr(self.MODULE, attribute) - except AttributeError: - # This is a different error case tested for in testAllExist. - pass - else: - if isinstance(value, types.ModuleType): - exported_modules.append(attribute) - if exported_modules: - self.fail('%s are modules and may not be exported.' % exported_modules) - - -class NestedMessage(messages.Message): - """Simple message that gets nested in another message.""" - - a_value = messages.StringField(1, required=True) - - -class HasNestedMessage(messages.Message): - """Message that has another message nested in it.""" - - nested = messages.MessageField(NestedMessage, 1) - repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True) - - -class HasDefault(messages.Message): - """Has a default value.""" - - a_value = messages.StringField(1, default=u'a default') - - -class OptionalMessage(messages.Message): - """Contains all message types.""" - - class SimpleEnum(messages.Enum): - """Simple enumeration type.""" - VAL1 = 1 - VAL2 = 2 - - double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE) - float_value = messages.FloatField(2, variant=messages.Variant.FLOAT) - int64_value = messages.IntegerField(3, variant=messages.Variant.INT64) - uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64) - int32_value = messages.IntegerField(5, variant=messages.Variant.INT32) - bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL) - string_value = messages.StringField(7, variant=messages.Variant.STRING) - bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES) - enum_value = messages.EnumField(SimpleEnum, 10) - - # TODO(rafek): Add support for these variants. - # uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32) - # sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32) - # sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64) - - -class RepeatedMessage(messages.Message): - """Contains all message types as repeated fields.""" - - class SimpleEnum(messages.Enum): - """Simple enumeration type.""" - VAL1 = 1 - VAL2 = 2 - - double_value = messages.FloatField(1, - variant=messages.Variant.DOUBLE, - repeated=True) - float_value = messages.FloatField(2, - variant=messages.Variant.FLOAT, - repeated=True) - int64_value = messages.IntegerField(3, - variant=messages.Variant.INT64, - repeated=True) - uint64_value = messages.IntegerField(4, - variant=messages.Variant.UINT64, - repeated=True) - int32_value = messages.IntegerField(5, - variant=messages.Variant.INT32, - repeated=True) - bool_value = messages.BooleanField(6, - variant=messages.Variant.BOOL, - repeated=True) - string_value = messages.StringField(7, - variant=messages.Variant.STRING, - repeated=True) - bytes_value = messages.BytesField(8, - variant=messages.Variant.BYTES, - repeated=True) - #uint32_value = messages.IntegerField(9, variant=messages.Variant.UINT32) - enum_value = messages.EnumField(SimpleEnum, - 10, - repeated=True) - #sint32_value = messages.IntegerField(11, variant=messages.Variant.SINT32) - #sint64_value = messages.IntegerField(12, variant=messages.Variant.SINT64) - - -class HasOptionalNestedMessage(messages.Message): - - nested = messages.MessageField(OptionalMessage, 1) - repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True) - - -class ProtoConformanceTestBase(object): - """Protocol conformance test base class. - - Each supported protocol should implement two methods that support encoding - and decoding of Message objects in that format: - - encode_message(message) - Serialize to encoding. - encode_message(message, encoded_message) - Deserialize from encoding. - - Tests for the modules where these functions are implemented should extend - this class in order to support basic behavioral expectations. This ensures - that protocols correctly encode and decode message transparently to the - caller. - - In order to support these test, the base class should also extend the TestCase - class and implement the following class attributes which define the encoded - version of certain protocol buffers: - - encoded_partial: - - - encoded_full: - - - encoded_repeated: - - - encoded_nested: - - > - - encoded_repeated_nested: - , - - ] - > - - unexpected_tag_message: - An encoded message that has an undefined tag or number in the stream. - - encoded_default_assigned: - - - encoded_nested_empty: - - > - - encoded_invalid_enum: - - """ - - encoded_empty_message = '' - - def testEncodeInvalidMessage(self): - message = NestedMessage() - self.assertRaises(messages.ValidationError, - self.PROTOLIB.encode_message, message) - - def CompareEncoded(self, expected_encoded, actual_encoded): - """Compare two encoded protocol values. - - Can be overridden by sub-classes to special case comparison. - For example, to eliminate white space from output that is not - relevant to encoding. - - Args: - expected_encoded: Expected string encoded value. - actual_encoded: Actual string encoded value. - """ - self.assertEquals(expected_encoded, actual_encoded) - - def EncodeDecode(self, encoded, expected_message): - message = self.PROTOLIB.decode_message(type(expected_message), encoded) - self.assertEquals(expected_message, message) - self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message)) - - def testEmptyMessage(self): - self.EncodeDecode(self.encoded_empty_message, OptionalMessage()) - - def testPartial(self): - """Test message with a few values set.""" - message = OptionalMessage() - message.double_value = 1.23 - message.int64_value = -100000000000 - message.int32_value = 1020 - message.string_value = u'a string' - message.enum_value = OptionalMessage.SimpleEnum.VAL2 - - self.EncodeDecode(self.encoded_partial, message) - - def testFull(self): - """Test all types.""" - message = OptionalMessage() - message.double_value = 1.23 - message.float_value = -2.5 - message.int64_value = -100000000000 - message.uint64_value = 102020202020 - message.int32_value = 1020 - message.bool_value = True - message.string_value = u'a string\u044f' - message.bytes_value = b'a bytes\xff\xfe' - message.enum_value = OptionalMessage.SimpleEnum.VAL2 - - self.EncodeDecode(self.encoded_full, message) - - def testRepeated(self): - """Test repeated fields.""" - message = RepeatedMessage() - message.double_value = [1.23, 2.3] - message.float_value = [-2.5, 0.5] - message.int64_value = [-100000000000, 20] - message.uint64_value = [102020202020, 10] - message.int32_value = [1020, 718] - message.bool_value = [True, False] - message.string_value = [u'a string\u044f', u'another string'] - message.bytes_value = [b'a bytes\xff\xfe', b'another bytes'] - message.enum_value = [RepeatedMessage.SimpleEnum.VAL2, - RepeatedMessage.SimpleEnum.VAL1] - - self.EncodeDecode(self.encoded_repeated, message) - - def testNested(self): - """Test nested messages.""" - nested_message = NestedMessage() - nested_message.a_value = u'a string' - - message = HasNestedMessage() - message.nested = nested_message - - self.EncodeDecode(self.encoded_nested, message) - - def testRepeatedNested(self): - """Test repeated nested messages.""" - nested_message1 = NestedMessage() - nested_message1.a_value = u'a string' - nested_message2 = NestedMessage() - nested_message2.a_value = u'another string' - - message = HasNestedMessage() - message.repeated_nested = [nested_message1, nested_message2] - - self.EncodeDecode(self.encoded_repeated_nested, message) - - def testStringTypes(self): - """Test that encoding str on StringField works.""" - message = OptionalMessage() - message.string_value = u'Latin' - self.EncodeDecode(self.encoded_string_types, message) - - def testEncodeUninitialized(self): - """Test that cannot encode uninitialized message.""" - required = NestedMessage() - self.assertRaisesWithRegexpMatch(messages.ValidationError, - "Message NestedMessage is missing " - "required field a_value", - self.PROTOLIB.encode_message, - required) - - def testUnexpectedField(self): - """Test decoding and encoding unexpected fields.""" - loaded_message = self.PROTOLIB.decode_message(OptionalMessage, - self.unexpected_tag_message) - # Message should be equal to an empty message, since unknown values aren't - # included in equality. - self.assertEquals(OptionalMessage(), loaded_message) - # Verify that the encoded message matches the source, including the - # unknown value. - self.assertEquals(self.unexpected_tag_message, - self.PROTOLIB.encode_message(loaded_message)) - - def testDoNotSendDefault(self): - """Test that default is not sent when nothing is assigned.""" - self.EncodeDecode(self.encoded_empty_message, HasDefault()) - - def testSendDefaultExplicitlyAssigned(self): - """Test that default is sent when explcitly assigned.""" - message = HasDefault() - - message.a_value = HasDefault.a_value.default - - self.EncodeDecode(self.encoded_default_assigned, message) - - def testEncodingNestedEmptyMessage(self): - """Test encoding a nested empty message.""" - message = HasOptionalNestedMessage() - message.nested = OptionalMessage() - - self.EncodeDecode(self.encoded_nested_empty, message) - - def testEncodingRepeatedNestedEmptyMessage(self): - """Test encoding a nested empty message.""" - message = HasOptionalNestedMessage() - message.repeated_nested = [OptionalMessage(), OptionalMessage()] - - self.EncodeDecode(self.encoded_repeated_nested_empty, message) - - def testContentType(self): - self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str)) - - def testDecodeInvalidEnumType(self): - self.assertRaisesWithRegexpMatch(messages.DecodeError, - 'Invalid enum value ', - self.PROTOLIB.decode_message, - OptionalMessage, - self.encoded_invalid_enum) - - def testDateTimeNoTimeZone(self): - """Test that DateTimeFields are encoded/decoded correctly.""" - - class MyMessage(messages.Message): - value = message_types.DateTimeField(1) - - value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000) - message = MyMessage(value=value) - decoded = self.PROTOLIB.decode_message( - MyMessage, self.PROTOLIB.encode_message(message)) - self.assertEquals(decoded.value, value) - - def testDateTimeWithTimeZone(self): - """Test DateTimeFields with time zones.""" - - class MyMessage(messages.Message): - value = message_types.DateTimeField(1) - - value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000, - util.TimeZoneOffset(8 * 60)) - message = MyMessage(value=value) - decoded = self.PROTOLIB.decode_message( - MyMessage, self.PROTOLIB.encode_message(message)) - self.assertEquals(decoded.value, value) - - -def do_with(context, function, *args, **kwargs): - """Simulate a with statement. - - Avoids need to import with from future. - - Does not support simulation of 'as'. - - Args: - context: Context object normally used with 'with'. - function: Callable to evoke. Replaces with-block. - """ - context.__enter__() - try: - function(*args, **kwargs) - except: - context.__exit__(*sys.exc_info()) - finally: - context.__exit__(None, None, None) - - -def pick_unused_port(): - """Find an unused port to use in tests. - - Derived from Damon Kohlers example: - - http://code.activestate.com/recipes/531822-pick-unused-port - """ - try: - temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - except socket.error: - # Try IPv6 - temp = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - - try: - temp.bind(('localhost', 0)) - port = temp.getsockname()[1] - finally: - temp.close() - return port - - -def get_module_name(module_attribute): - """Get the module name. - - Args: - module_attribute: An attribute of the module. - - Returns: - The fully qualified module name or simple module name where - 'module_attribute' is defined if the module name is "__main__". - """ - if module_attribute.__module__ == '__main__': - module_file = inspect.getfile(module_attribute) - default = os.path.basename(module_file).split('.')[0] - return default - else: - return module_attribute.__module__ diff --git a/endpoints/internal/protorpc/transport_test.py b/endpoints/internal/protorpc/transport_test.py deleted file mode 100644 index 8fa39c3..0000000 --- a/endpoints/internal/protorpc/transport_test.py +++ /dev/null @@ -1,493 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import errno -import six.moves.http_client -import os -import socket -import unittest - -from protorpc import messages -from protorpc import protobuf -from protorpc import protojson -from protorpc import remote -from protorpc import test_util -from protorpc import transport -from protorpc import webapp_test_util -from protorpc.wsgi import util as wsgi_util - -import mox - -package = 'transport_test' - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = transport - - -class Message(messages.Message): - - value = messages.StringField(1) - - -class Service(remote.Service): - - @remote.method(Message, Message) - def method(self, request): - pass - - -# Remove when RPC is no longer subclasses. -class TestRpc(transport.Rpc): - - waited = False - - def _wait_impl(self): - self.waited = True - - -class RpcTest(test_util.TestCase): - - def setUp(self): - self.request = Message(value=u'request') - self.response = Message(value=u'response') - self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR, - error_message='an error', - error_name='blam') - - self.rpc = TestRpc(self.request) - - def testConstructor(self): - self.assertEquals(self.request, self.rpc.request) - self.assertEquals(remote.RpcState.RUNNING, self.rpc.state) - self.assertEquals(None, self.rpc.error_message) - self.assertEquals(None, self.rpc.error_name) - - def response(self): - self.assertFalse(self.rpc.waited) - self.assertEquals(None, self.rpc.response) - self.assertTrue(self.rpc.waited) - - def testSetResponse(self): - self.rpc.set_response(self.response) - - self.assertEquals(self.request, self.rpc.request) - self.assertEquals(remote.RpcState.OK, self.rpc.state) - self.assertEquals(self.response, self.rpc.response) - self.assertEquals(None, self.rpc.error_message) - self.assertEquals(None, self.rpc.error_name) - - def testSetResponseAlreadySet(self): - self.rpc.set_response(self.response) - - self.assertRaisesWithRegexpMatch( - transport.RpcStateError, - 'RPC must be in RUNNING state to change to OK', - self.rpc.set_response, - self.response) - - def testSetResponseAlreadyError(self): - self.rpc.set_status(self.status) - - self.assertRaisesWithRegexpMatch( - transport.RpcStateError, - 'RPC must be in RUNNING state to change to OK', - self.rpc.set_response, - self.response) - - def testSetStatus(self): - self.rpc.set_status(self.status) - - self.assertEquals(self.request, self.rpc.request) - self.assertEquals(remote.RpcState.APPLICATION_ERROR, self.rpc.state) - self.assertEquals('an error', self.rpc.error_message) - self.assertEquals('blam', self.rpc.error_name) - self.assertRaisesWithRegexpMatch(remote.ApplicationError, - 'an error', - getattr, self.rpc, 'response') - - def testSetStatusAlreadySet(self): - self.rpc.set_response(self.response) - - self.assertRaisesWithRegexpMatch( - transport.RpcStateError, - 'RPC must be in RUNNING state to change to OK', - self.rpc.set_response, - self.response) - - def testSetNonMessage(self): - self.assertRaisesWithRegexpMatch( - TypeError, - 'Expected Message type, received 10', - self.rpc.set_response, - 10) - - def testSetStatusAlreadyError(self): - self.rpc.set_status(self.status) - - self.assertRaisesWithRegexpMatch( - transport.RpcStateError, - 'RPC must be in RUNNING state to change to OK', - self.rpc.set_response, - self.response) - - def testSetUninitializedStatus(self): - self.assertRaises(messages.ValidationError, - self.rpc.set_status, - remote.RpcStatus()) - - -class TransportTest(test_util.TestCase): - - def setUp(self): - remote.Protocols.set_default(remote.Protocols.new_default()) - - def do_test(self, protocol, trans): - request = Message() - request.value = u'request' - - response = Message() - response.value = u'response' - - encoded_request = protocol.encode_message(request) - encoded_response = protocol.encode_message(response) - - self.assertEquals(protocol, trans.protocol) - - received_rpc = [None] - def transport_rpc(remote, rpc_request): - self.assertEquals(remote, Service.method.remote) - self.assertEquals(request, rpc_request) - rpc = TestRpc(request) - rpc.set_response(response) - return rpc - trans._start_rpc = transport_rpc - - rpc = trans.send_rpc(Service.method.remote, request) - self.assertEquals(response, rpc.response) - - def testDefaultProtocol(self): - trans = transport.Transport() - self.do_test(protobuf, trans) - self.assertEquals(protobuf, trans.protocol_config.protocol) - self.assertEquals('default', trans.protocol_config.name) - - def testAlternateProtocol(self): - trans = transport.Transport(protocol=protojson) - self.do_test(protojson, trans) - self.assertEquals(protojson, trans.protocol_config.protocol) - self.assertEquals('default', trans.protocol_config.name) - - def testProtocolConfig(self): - protocol_config = remote.ProtocolConfig( - protojson, 'protoconfig', 'image/png') - trans = transport.Transport(protocol=protocol_config) - self.do_test(protojson, trans) - self.assertTrue(trans.protocol_config is protocol_config) - - def testProtocolByName(self): - remote.Protocols.get_default().add_protocol( - protojson, 'png', 'image/png', ()) - trans = transport.Transport(protocol='png') - self.do_test(protojson, trans) - - -@remote.method(Message, Message) -def my_method(self, request): - self.fail('self.my_method should not be directly invoked.') - - -class FakeConnectionClass(object): - - def __init__(self, mox): - self.request = mox.CreateMockAnything() - self.response = mox.CreateMockAnything() - - -class HttpTransportTest(webapp_test_util.WebServerTestBase): - - def setUp(self): - # Do not need much parent construction functionality. - - self.schema = 'http' - self.server = None - - self.request = Message(value=u'The request value') - self.encoded_request = protojson.encode_message(self.request) - - self.response = Message(value=u'The response value') - self.encoded_response = protojson.encode_message(self.response) - - def testCallSucceeds(self): - self.ResetServer(wsgi_util.static_page(self.encoded_response, - content_type='application/json')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - self.assertEquals(self.response, rpc.response) - - def testHttps(self): - self.schema = 'https' - self.ResetServer(wsgi_util.static_page(self.encoded_response, - content_type='application/json')) - - # Create a fake https connection function that really just calls http. - self.used_https = False - def https_connection(*args, **kwargs): - self.used_https = True - return six.moves.http_client.HTTPConnection(*args, **kwargs) - - original_https_connection = six.moves.http_client.HTTPSConnection - six.moves.http_client.HTTPSConnection = https_connection - try: - rpc = self.connection.send_rpc(my_method.remote, self.request) - finally: - six.moves.http_client.HTTPSConnection = original_https_connection - self.assertEquals(self.response, rpc.response) - self.assertTrue(self.used_https) - - def testHttpSocketError(self): - self.ResetServer(wsgi_util.static_page(self.encoded_response, - content_type='application/json')) - - bad_transport = transport.HttpTransport('http://localhost:-1/blar') - try: - bad_transport.send_rpc(my_method.remote, self.request) - except remote.NetworkError as err: - self.assertTrue(str(err).startswith('Socket error: error (')) - self.assertEquals(errno.ECONNREFUSED, err.cause.errno) - else: - self.fail('Expected error') - - def testHttpRequestError(self): - self.ResetServer(wsgi_util.static_page(self.encoded_response, - content_type='application/json')) - - def request_error(*args, **kwargs): - raise TypeError('Generic Error') - original_request = six.moves.http_client.HTTPConnection.request - six.moves.http_client.HTTPConnection.request = request_error - try: - try: - self.connection.send_rpc(my_method.remote, self.request) - except remote.NetworkError as err: - self.assertEquals('Error communicating with HTTP server', str(err)) - self.assertEquals(TypeError, type(err.cause)) - self.assertEquals('Generic Error', str(err.cause)) - else: - self.fail('Expected error') - finally: - six.moves.http_client.HTTPConnection.request = original_request - - def testHandleGenericServiceError(self): - self.ResetServer(wsgi_util.error(six.moves.http_client.INTERNAL_SERVER_ERROR, - 'arbitrary error', - content_type='text/plain')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - try: - rpc.response - except remote.ServerError as err: - self.assertEquals('HTTP Error 500: arbitrary error', str(err).strip()) - else: - self.fail('Expected ServerError') - - def testHandleGenericServiceErrorNoMessage(self): - self.ResetServer(wsgi_util.error(six.moves.http_client.NOT_IMPLEMENTED, - ' ', - content_type='text/plain')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - try: - rpc.response - except remote.ServerError as err: - self.assertEquals('HTTP Error 501: Not Implemented', str(err).strip()) - else: - self.fail('Expected ServerError') - - def testHandleStatusContent(self): - self.ResetServer(wsgi_util.static_page('{"state": "REQUEST_ERROR",' - ' "error_message": "a request error"' - '}', - status=six.moves.http_client.BAD_REQUEST, - content_type='application/json')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - try: - rpc.response - except remote.RequestError as err: - self.assertEquals('a request error', str(err)) - else: - self.fail('Expected RequestError') - - def testHandleApplicationError(self): - self.ResetServer(wsgi_util.static_page('{"state": "APPLICATION_ERROR",' - ' "error_message": "an app error",' - ' "error_name": "MY_ERROR_NAME"}', - status=six.moves.http_client.BAD_REQUEST, - content_type='application/json')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - try: - rpc.response - except remote.ApplicationError as err: - self.assertEquals('an app error', str(err)) - self.assertEquals('MY_ERROR_NAME', err.error_name) - else: - self.fail('Expected RequestError') - - def testHandleUnparsableErrorContent(self): - self.ResetServer(wsgi_util.static_page('oops', - status=six.moves.http_client.BAD_REQUEST, - content_type='application/json')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - try: - rpc.response - except remote.ServerError as err: - self.assertEquals('HTTP Error 400: oops', str(err)) - else: - self.fail('Expected ServerError') - - def testHandleEmptyBadRpcStatus(self): - self.ResetServer(wsgi_util.static_page('{"error_message": "x"}', - status=six.moves.http_client.BAD_REQUEST, - content_type='application/json')) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - try: - rpc.response - except remote.ServerError as err: - self.assertEquals('HTTP Error 400: {"error_message": "x"}', str(err)) - else: - self.fail('Expected ServerError') - - def testUseProtocolConfigContentType(self): - expected_content_type = 'image/png' - def expect_content_type(environ, start_response): - self.assertEquals(expected_content_type, environ['CONTENT_TYPE']) - app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE']) - return app(environ, start_response) - - self.ResetServer(expect_content_type) - - protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png') - self.connection = self.CreateTransport(self.service_url, protocol_config) - - rpc = self.connection.send_rpc(my_method.remote, self.request) - self.assertEquals(Message(), rpc.response) - - -class SimpleRequest(messages.Message): - - content = messages.StringField(1) - - -class SimpleResponse(messages.Message): - - content = messages.StringField(1) - factory_value = messages.StringField(2) - remote_host = messages.StringField(3) - remote_address = messages.StringField(4) - server_host = messages.StringField(5) - server_port = messages.IntegerField(6) - - -class LocalService(remote.Service): - - def __init__(self, factory_value='default'): - self.factory_value = factory_value - - @remote.method(SimpleRequest, SimpleResponse) - def call_method(self, request): - return SimpleResponse(content=request.content, - factory_value=self.factory_value, - remote_host=self.request_state.remote_host, - remote_address=self.request_state.remote_address, - server_host=self.request_state.server_host, - server_port=self.request_state.server_port) - - @remote.method() - def raise_totally_unexpected(self, request): - raise TypeError('Kablam') - - @remote.method() - def raise_unexpected(self, request): - raise remote.RequestError('Huh?') - - @remote.method() - def raise_application_error(self, request): - raise remote.ApplicationError('App error', 10) - - -class LocalTransportTest(test_util.TestCase): - - def CreateService(self, factory_value='default'): - return - - def testBasicCallWithClass(self): - stub = LocalService.Stub(transport.LocalTransport(LocalService)) - response = stub.call_method(content='Hello') - self.assertEquals(SimpleResponse(content='Hello', - factory_value='default', - remote_host=os.uname()[1], - remote_address='127.0.0.1', - server_host=os.uname()[1], - server_port=-1), - response) - - def testBasicCallWithFactory(self): - stub = LocalService.Stub( - transport.LocalTransport(LocalService.new_factory('assigned'))) - response = stub.call_method(content='Hello') - self.assertEquals(SimpleResponse(content='Hello', - factory_value='assigned', - remote_host=os.uname()[1], - remote_address='127.0.0.1', - server_host=os.uname()[1], - server_port=-1), - response) - - def testTotallyUnexpectedError(self): - stub = LocalService.Stub(transport.LocalTransport(LocalService)) - self.assertRaisesWithRegexpMatch( - remote.ServerError, - 'Unexpected error TypeError: Kablam', - stub.raise_totally_unexpected) - - def testUnexpectedError(self): - stub = LocalService.Stub(transport.LocalTransport(LocalService)) - self.assertRaisesWithRegexpMatch( - remote.ServerError, - 'Unexpected error RequestError: Huh?', - stub.raise_unexpected) - - def testApplicationError(self): - stub = LocalService.Stub(transport.LocalTransport(LocalService)) - self.assertRaisesWithRegexpMatch( - remote.ApplicationError, - 'App error', - stub.raise_application_error) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/util_test.py b/endpoints/internal/protorpc/util_test.py deleted file mode 100644 index df05c32..0000000 --- a/endpoints/internal/protorpc/util_test.py +++ /dev/null @@ -1,394 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.util.""" -import six - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import datetime -import random -import sys -import types -import unittest - -from protorpc import test_util -from protorpc import util - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = util - - -class PadStringTest(test_util.TestCase): - - def testPadEmptyString(self): - self.assertEquals(' ' * 512, util.pad_string('')) - - def testPadString(self): - self.assertEquals('hello' + (507 * ' '), util.pad_string('hello')) - - def testPadLongString(self): - self.assertEquals('x' * 1000, util.pad_string('x' * 1000)) - - -class UtilTest(test_util.TestCase): - - def testDecoratedFunction_LengthZero(self): - @util.positional(0) - def fn(kwonly=1): - return [kwonly] - self.assertEquals([1], fn()) - self.assertEquals([2], fn(kwonly=2)) - self.assertRaisesWithRegexpMatch(TypeError, - r'fn\(\) takes at most 0 positional ' - r'arguments \(1 given\)', - fn, 1) - - def testDecoratedFunction_LengthOne(self): - @util.positional(1) - def fn(pos, kwonly=1): - return [pos, kwonly] - self.assertEquals([1, 1], fn(1)) - self.assertEquals([2, 2], fn(2, kwonly=2)) - self.assertRaisesWithRegexpMatch(TypeError, - r'fn\(\) takes at most 1 positional ' - r'argument \(2 given\)', - fn, 2, 3) - - def testDecoratedFunction_LengthTwoWithDefault(self): - @util.positional(2) - def fn(pos1, pos2=1, kwonly=1): - return [pos1, pos2, kwonly] - self.assertEquals([1, 1, 1], fn(1)) - self.assertEquals([2, 2, 1], fn(2, 2)) - self.assertEquals([2, 3, 4], fn(2, 3, kwonly=4)) - self.assertRaisesWithRegexpMatch(TypeError, - r'fn\(\) takes at most 2 positional ' - r'arguments \(3 given\)', - fn, 2, 3, 4) - - def testDecoratedMethod(self): - class MyClass(object): - @util.positional(2) - def meth(self, pos1, kwonly=1): - return [pos1, kwonly] - self.assertEquals([1, 1], MyClass().meth(1)) - self.assertEquals([2, 2], MyClass().meth(2, kwonly=2)) - self.assertRaisesWithRegexpMatch(TypeError, - r'meth\(\) takes at most 2 positional ' - r'arguments \(3 given\)', - MyClass().meth, 2, 3) - - def testDefaultDecoration(self): - @util.positional - def fn(a, b, c=None): - return a, b, c - self.assertEquals((1, 2, 3), fn(1, 2, c=3)) - self.assertEquals((3, 4, None), fn(3, b=4)) - self.assertRaisesWithRegexpMatch(TypeError, - r'fn\(\) takes at most 2 positional ' - r'arguments \(3 given\)', - fn, 2, 3, 4) - - def testDefaultDecorationNoKwdsFails(self): - def fn(a): - return a - self.assertRaisesRegexp( - ValueError, - 'Functions with no keyword arguments must specify max_positional_args', - util.positional, fn) - - def testDecoratedFunctionDocstring(self): - @util.positional(0) - def fn(kwonly=1): - """fn docstring.""" - return [kwonly] - self.assertEquals('fn docstring.', fn.__doc__) - - -class AcceptItemTest(test_util.TestCase): - - def CheckAttributes(self, item, main_type, sub_type, q=1, values={}, index=1): - self.assertEquals(index, item.index) - self.assertEquals(main_type, item.main_type) - self.assertEquals(sub_type, item.sub_type) - self.assertEquals(q, item.q) - self.assertEquals(values, item.values) - - def testParse(self): - self.CheckAttributes(util.AcceptItem('*/*', 1), None, None) - self.CheckAttributes(util.AcceptItem('text/*', 1), 'text', None) - self.CheckAttributes(util.AcceptItem('text/plain', 1), 'text', 'plain') - self.CheckAttributes( - util.AcceptItem('text/plain; q=0.3', 1), 'text', 'plain', 0.3, - values={'q': '0.3'}) - self.CheckAttributes( - util.AcceptItem('text/plain; level=2', 1), 'text', 'plain', - values={'level': '2'}) - self.CheckAttributes( - util.AcceptItem('text/plain', 10), 'text', 'plain', index=10) - - def testCaseInsensitive(self): - self.CheckAttributes(util.AcceptItem('Text/Plain', 1), 'text', 'plain') - - def testBadValue(self): - self.assertRaises(util.AcceptError, - util.AcceptItem, 'bad value', 1) - self.assertRaises(util.AcceptError, - util.AcceptItem, 'bad value/', 1) - self.assertRaises(util.AcceptError, - util.AcceptItem, '/bad value', 1) - - def testSortKey(self): - item = util.AcceptItem('main/sub; q=0.2; level=3', 11) - self.assertEquals((False, False, -0.2, False, 11), item.sort_key) - - item = util.AcceptItem('main/*', 12) - self.assertEquals((False, True, -1, True, 12), item.sort_key) - - item = util.AcceptItem('*/*', 1) - self.assertEquals((True, True, -1, True, 1), item.sort_key) - - def testSort(self): - i1 = util.AcceptItem('text/*', 1) - i2 = util.AcceptItem('text/html', 2) - i3 = util.AcceptItem('text/html; q=0.9', 3) - i4 = util.AcceptItem('text/html; q=0.3', 4) - i5 = util.AcceptItem('text/xml', 5) - i6 = util.AcceptItem('text/html; level=1', 6) - i7 = util.AcceptItem('*/*', 7) - items = [i1, i2 ,i3 ,i4 ,i5 ,i6, i7] - random.shuffle(items) - self.assertEquals([i6, i2, i5, i3, i4, i1, i7], sorted(items)) - - def testMatchAll(self): - item = util.AcceptItem('*/*', 1) - self.assertTrue(item.match('text/html')) - self.assertTrue(item.match('text/plain; level=1')) - self.assertTrue(item.match('image/png')) - self.assertTrue(item.match('image/png; q=0.3')) - - def testMatchMainType(self): - item = util.AcceptItem('text/*', 1) - self.assertTrue(item.match('text/html')) - self.assertTrue(item.match('text/plain; level=1')) - self.assertFalse(item.match('image/png')) - self.assertFalse(item.match('image/png; q=0.3')) - - def testMatchFullType(self): - item = util.AcceptItem('text/plain', 1) - self.assertFalse(item.match('text/html')) - self.assertTrue(item.match('text/plain; level=1')) - self.assertFalse(item.match('image/png')) - self.assertFalse(item.match('image/png; q=0.3')) - - def testMatchCaseInsensitive(self): - item = util.AcceptItem('text/plain', 1) - self.assertTrue(item.match('tExt/pLain')) - - def testStr(self): - self.assertHeaderSame('*/*', str(util.AcceptItem('*/*', 1))) - self.assertHeaderSame('text/*', str(util.AcceptItem('text/*', 1))) - self.assertHeaderSame('text/plain', - str(util.AcceptItem('text/plain', 1))) - self.assertHeaderSame('text/plain; q=0.2', - str(util.AcceptItem('text/plain; q=0.2', 1))) - self.assertHeaderSame( - 'text/plain; q=0.2; level=1', - str(util.AcceptItem('text/plain; level=1; q=0.2', 1))) - - def testRepr(self): - self.assertEquals("AcceptItem('*/*', 1)", repr(util.AcceptItem('*/*', 1))) - self.assertEquals("AcceptItem('text/plain', 11)", - repr(util.AcceptItem('text/plain', 11))) - - def testValues(self): - item = util.AcceptItem('text/plain; a=1; b=2; c=3;', 1) - values = item.values - self.assertEquals(dict(a="1", b="2", c="3"), values) - values['a'] = "7" - self.assertNotEquals(values, item.values) - - -class ParseAcceptHeaderTest(test_util.TestCase): - - def testIndex(self): - accept_header = """text/*, text/html, text/html; q=0.9, - text/xml, - text/html; level=1, */*""" - accepts = util.parse_accept_header(accept_header) - self.assertEquals(6, len(accepts)) - self.assertEquals([4, 1, 3, 2, 0, 5], [a.index for a in accepts]) - - -class ChooseContentTypeTest(test_util.TestCase): - - def testIgnoreUnrequested(self): - self.assertEquals('application/json', - util.choose_content_type( - 'text/plain, application/json, */*', - ['application/X-Google-protobuf', - 'application/json' - ])) - - def testUseCorrectPreferenceIndex(self): - self.assertEquals('application/json', - util.choose_content_type( - '*/*, text/plain, application/json', - ['application/X-Google-protobuf', - 'application/json' - ])) - - def testPreferFirstInList(self): - self.assertEquals('application/X-Google-protobuf', - util.choose_content_type( - '*/*', - ['application/X-Google-protobuf', - 'application/json' - ])) - - def testCaseInsensitive(self): - self.assertEquals('application/X-Google-protobuf', - util.choose_content_type( - 'application/x-google-protobuf', - ['application/X-Google-protobuf', - 'application/json' - ])) - - -class GetPackageForModuleTest(test_util.TestCase): - - def setUp(self): - self.original_modules = dict(sys.modules) - - def tearDown(self): - sys.modules.clear() - sys.modules.update(self.original_modules) - - def CreateModule(self, name, file_name=None): - if file_name is None: - file_name = '%s.py' % name - module = types.ModuleType(name) - sys.modules[name] = module - return module - - def assertPackageEquals(self, expected, actual): - self.assertEquals(expected, actual) - if actual is not None: - self.assertTrue(isinstance(actual, six.text_type)) - - def testByString(self): - module = self.CreateModule('service_module') - module.package = 'my_package' - self.assertPackageEquals('my_package', - util.get_package_for_module('service_module')) - - def testModuleNameNotInSys(self): - self.assertPackageEquals(None, - util.get_package_for_module('service_module')) - - def testHasPackage(self): - module = self.CreateModule('service_module') - module.package = 'my_package' - self.assertPackageEquals('my_package', util.get_package_for_module(module)) - - def testHasModuleName(self): - module = self.CreateModule('service_module') - self.assertPackageEquals('service_module', - util.get_package_for_module(module)) - - def testIsMain(self): - module = self.CreateModule('__main__') - module.__file__ = '/bing/blam/bloom/blarm/my_file.py' - self.assertPackageEquals('my_file', util.get_package_for_module(module)) - - def testIsMainCompiled(self): - module = self.CreateModule('__main__') - module.__file__ = '/bing/blam/bloom/blarm/my_file.pyc' - self.assertPackageEquals('my_file', util.get_package_for_module(module)) - - def testNoExtension(self): - module = self.CreateModule('__main__') - module.__file__ = '/bing/blam/bloom/blarm/my_file' - self.assertPackageEquals('my_file', util.get_package_for_module(module)) - - def testNoPackageAtAll(self): - module = self.CreateModule('__main__') - self.assertPackageEquals('__main__', util.get_package_for_module(module)) - - -class DateTimeTests(test_util.TestCase): - - def testDecodeDateTime(self): - """Test that a RFC 3339 datetime string is decoded properly.""" - for datetime_string, datetime_vals in ( - ('2012-09-30T15:31:50.262', (2012, 9, 30, 15, 31, 50, 262000)), - ('2012-09-30T15:31:50', (2012, 9, 30, 15, 31, 50, 0))): - decoded = util.decode_datetime(datetime_string) - expected = datetime.datetime(*datetime_vals) - self.assertEquals(expected, decoded) - - def testDateTimeTimeZones(self): - """Test that a datetime string with a timezone is decoded correctly.""" - for datetime_string, datetime_vals in ( - ('2012-09-30T15:31:50.262-06:00', - (2012, 9, 30, 15, 31, 50, 262000, util.TimeZoneOffset(-360))), - ('2012-09-30T15:31:50.262+01:30', - (2012, 9, 30, 15, 31, 50, 262000, util.TimeZoneOffset(90))), - ('2012-09-30T15:31:50+00:05', - (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(5))), - ('2012-09-30T15:31:50+00:00', - (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), - ('2012-09-30t15:31:50-00:00', - (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), - ('2012-09-30t15:31:50z', - (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(0))), - ('2012-09-30T15:31:50-23:00', - (2012, 9, 30, 15, 31, 50, 0, util.TimeZoneOffset(-1380)))): - decoded = util.decode_datetime(datetime_string) - expected = datetime.datetime(*datetime_vals) - self.assertEquals(expected, decoded) - - def testDecodeDateTimeInvalid(self): - """Test that decoding malformed datetime strings raises execptions.""" - for datetime_string in ('invalid', - '2012-09-30T15:31:50.', - '-08:00 2012-09-30T15:31:50.262', - '2012-09-30T15:31', - '2012-09-30T15:31Z', - '2012-09-30T15:31:50ZZ', - '2012-09-30T15:31:50.262 blah blah -08:00', - '1000-99-99T25:99:99.999-99:99'): - self.assertRaises(ValueError, util.decode_datetime, datetime_string) - - def testTimeZoneOffsetDelta(self): - """Test that delta works with TimeZoneOffset.""" - time_zone = util.TimeZoneOffset(datetime.timedelta(minutes=3)) - epoch = time_zone.utcoffset(datetime.datetime.utcfromtimestamp(0)) - self.assertEqual(180, util.total_seconds(epoch)) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/webapp/forms_test.py b/endpoints/internal/protorpc/webapp/forms_test.py deleted file mode 100644 index dcac88d..0000000 --- a/endpoints/internal/protorpc/webapp/forms_test.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.forms.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import os -import unittest - -from protorpc import test_util -from protorpc import webapp_test_util -from protorpc.webapp import forms -from protorpc.webapp.google_imports import template - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = forms - - -def RenderTemplate(name, **params): - """Load content from static file. - - Args: - name: Name of static file to load from static directory. - params: Passed in to webapp template generator. - - Returns: - Contents of static file. - """ - path = os.path.join(forms._TEMPLATES_DIR, name) - return template.render(path, params) - - -class ResourceHandlerTest(webapp_test_util.RequestHandlerTestBase): - - def CreateRequestHandler(self): - return forms.ResourceHandler() - - def DoStaticContentTest(self, name, expected_type): - """Run the static content test. - - Loads expected static content from source and compares with - results in response. Checks content-type and cache header. - - Args: - name: Name of file that should be served. - expected_type: Expected content-type of served file. - """ - self.handler.get(name) - - content = RenderTemplate(name) - self.CheckResponse('200 OK', - {'content-type': expected_type, - }, - content) - - def testGet(self): - self.DoStaticContentTest('forms.js', 'text/javascript') - - def testNoSuchFile(self): - self.handler.get('unknown.txt') - - self.CheckResponse('404 Not Found', - {}, - 'Resource not found.') - - -class FormsHandlerTest(webapp_test_util.RequestHandlerTestBase): - - def CreateRequestHandler(self): - handler = forms.FormsHandler('/myreg') - self.assertEquals('/myreg', handler.registry_path) - return handler - - def testGetForm(self): - self.handler.get() - - content = RenderTemplate( - 'forms.html', - forms_path='/tmp/myhandler', - hostname=self.request.host, - registry_path='/myreg') - - self.CheckResponse('200 OK', - {}, - content) - - def testGet_MissingPath(self): - self.ResetHandler({'QUERY_STRING': 'method=my_method'}) - - self.handler.get() - - content = RenderTemplate( - 'forms.html', - forms_path='/tmp/myhandler', - hostname=self.request.host, - registry_path='/myreg') - - self.CheckResponse('200 OK', - {}, - content) - - def testGet_MissingMethod(self): - self.ResetHandler({'QUERY_STRING': 'path=/my-path'}) - - self.handler.get() - - content = RenderTemplate( - 'forms.html', - forms_path='/tmp/myhandler', - hostname=self.request.host, - registry_path='/myreg') - - self.CheckResponse('200 OK', - {}, - content) - - def testGetMethod(self): - self.ResetHandler({'QUERY_STRING': 'path=/my-path&method=my_method'}) - - self.handler.get() - - content = RenderTemplate( - 'methods.html', - forms_path='/tmp/myhandler', - hostname=self.request.host, - registry_path='/myreg', - service_path='/my-path', - method_name='my_method') - - self.CheckResponse('200 OK', - {}, - content) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/webapp/service_handlers_test.py b/endpoints/internal/protorpc/webapp/service_handlers_test.py deleted file mode 100644 index baebbda..0000000 --- a/endpoints/internal/protorpc/webapp/service_handlers_test.py +++ /dev/null @@ -1,1332 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Tests for protorpc.service_handlers.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import cgi -import cStringIO -import os -import re -import sys -import unittest -import urllib - -from protorpc import messages -from protorpc import protobuf -from protorpc import protojson -from protorpc import protourlencode -from protorpc import message_types -from protorpc import registry -from protorpc import remote -from protorpc import test_util -from protorpc import util -from protorpc import webapp_test_util -from protorpc.webapp import forms -from protorpc.webapp import service_handlers -from protorpc.webapp.google_imports import webapp - -import mox - -package = 'test_package' - - -class ModuleInterfaceTest(test_util.ModuleInterfaceTest, - test_util.TestCase): - - MODULE = service_handlers - - -class Enum1(messages.Enum): - """A test enum class.""" - - VAL1 = 1 - VAL2 = 2 - VAL3 = 3 - - -class Request1(messages.Message): - """A test request message type.""" - - integer_field = messages.IntegerField(1) - string_field = messages.StringField(2) - enum_field = messages.EnumField(Enum1, 3) - - -class Response1(messages.Message): - """A test response message type.""" - - integer_field = messages.IntegerField(1) - string_field = messages.StringField(2) - enum_field = messages.EnumField(Enum1, 3) - - -class SuperMessage(messages.Message): - """A test message with a nested message field.""" - - sub_message = messages.MessageField(Request1, 1) - sub_messages = messages.MessageField(Request1, 2, repeated=True) - - -class SuperSuperMessage(messages.Message): - """A test message with two levels of nested.""" - - sub_message = messages.MessageField(SuperMessage, 1) - sub_messages = messages.MessageField(Request1, 2, repeated=True) - - -class RepeatedMessage(messages.Message): - """A test message with a repeated field.""" - - ints = messages.IntegerField(1, repeated=True) - strings = messages.StringField(2, repeated=True) - enums = messages.EnumField(Enum1, 3, repeated=True) - - -class Service(object): - """A simple service that takes a Request1 and returns Request2.""" - - @remote.method(Request1, Response1) - def method1(self, request): - response = Response1() - if hasattr(request, 'integer_field'): - response.integer_field = request.integer_field - if hasattr(request, 'string_field'): - response.string_field = request.string_field - if hasattr(request, 'enum_field'): - response.enum_field = request.enum_field - return response - - @remote.method(RepeatedMessage, RepeatedMessage) - def repeated_method(self, request): - response = RepeatedMessage() - if hasattr(request, 'ints'): - response = request.ints - return response - - def not_remote(self): - pass - - -def VerifyResponse(test, - response, - expected_status, - expected_status_message, - expected_content, - expected_content_type='application/x-www-form-urlencoded'): - def write(content): - if expected_content == '': - test.assertEquals(util.pad_string(''), content) - else: - test.assertNotEquals(-1, content.find(expected_content), - 'Expected to find:\n%s\n\nActual content: \n%s' % ( - expected_content, content)) - - def start_response(response, headers): - status, message = response.split(' ', 1) - test.assertEquals(expected_status, status) - test.assertEquals(expected_status_message, message) - for name, value in headers: - if name.lower() == 'content-type': - test.assertEquals(expected_content_type, value) - for name, value in headers: - if name.lower() == 'x-content-type-options': - test.assertEquals('nosniff', value) - elif name.lower() == 'content-type': - test.assertFalse(value.lower().startswith('text/html')) - return write - - response.wsgi_write(start_response) - - -class ServiceHandlerFactoryTest(test_util.TestCase): - """Tests for the service handler factory.""" - - def testAllRequestMappers(self): - """Test all_request_mappers method.""" - configuration = service_handlers.ServiceHandlerFactory(Service) - mapper1 = service_handlers.RPCMapper(['whatever'], 'whatever', None) - mapper2 = service_handlers.RPCMapper(['whatever'], 'whatever', None) - - configuration.add_request_mapper(mapper1) - self.assertEquals([mapper1], list(configuration.all_request_mappers())) - - configuration.add_request_mapper(mapper2) - self.assertEquals([mapper1, mapper2], - list(configuration.all_request_mappers())) - - def testServiceFactory(self): - """Test that service_factory attribute is set.""" - handler_factory = service_handlers.ServiceHandlerFactory(Service) - self.assertEquals(Service, handler_factory.service_factory) - - def testFactoryMethod(self): - """Test that factory creates correct instance of class.""" - factory = service_handlers.ServiceHandlerFactory(Service) - handler = factory() - - self.assertTrue(isinstance(handler, service_handlers.ServiceHandler)) - self.assertTrue(isinstance(handler.service, Service)) - - def testMapping(self): - """Test the mapping method.""" - factory = service_handlers.ServiceHandlerFactory(Service) - path, mapped_factory = factory.mapping('/my_service') - - self.assertEquals(r'(/my_service)' + service_handlers._METHOD_PATTERN, path) - self.assertEquals(id(factory), id(mapped_factory)) - match = re.match(path, '/my_service.my_method') - self.assertEquals('/my_service', match.group(1)) - self.assertEquals('my_method', match.group(2)) - - path, mapped_factory = factory.mapping('/my_service/nested') - self.assertEquals('(/my_service/nested)' + - service_handlers._METHOD_PATTERN, path) - match = re.match(path, '/my_service/nested.my_method') - self.assertEquals('/my_service/nested', match.group(1)) - self.assertEquals('my_method', match.group(2)) - - def testRegexMapping(self): - """Test the mapping method using a regex.""" - factory = service_handlers.ServiceHandlerFactory(Service) - path, mapped_factory = factory.mapping('.*/my_service') - - self.assertEquals(r'(.*/my_service)' + service_handlers._METHOD_PATTERN, path) - self.assertEquals(id(factory), id(mapped_factory)) - match = re.match(path, '/whatever_preceeds/my_service.my_method') - self.assertEquals('/whatever_preceeds/my_service', match.group(1)) - self.assertEquals('my_method', match.group(2)) - match = re.match(path, '/something_else/my_service.my_other_method') - self.assertEquals('/something_else/my_service', match.group(1)) - self.assertEquals('my_other_method', match.group(2)) - - def testMapping_BadPath(self): - """Test bad parameterse to the mapping method.""" - factory = service_handlers.ServiceHandlerFactory(Service) - self.assertRaises(ValueError, factory.mapping, '/my_service/') - - def testDefault(self): - """Test the default factory convenience method.""" - handler_factory = service_handlers.ServiceHandlerFactory.default( - Service, - parameter_prefix='my_prefix.') - - self.assertEquals(Service, handler_factory.service_factory) - - mappers = handler_factory.all_request_mappers() - - # Verify Protobuf encoded mapper. - protobuf_mapper = next(mappers) - self.assertTrue(isinstance(protobuf_mapper, - service_handlers.ProtobufRPCMapper)) - - # Verify JSON encoded mapper. - json_mapper = next(mappers) - self.assertTrue(isinstance(json_mapper, - service_handlers.JSONRPCMapper)) - - # Should have no more mappers. - self.assertRaises(StopIteration, mappers.next) - - -class ServiceHandlerTest(webapp_test_util.RequestHandlerTestBase): - """Test the ServiceHandler class.""" - - def setUp(self): - self.mox = mox.Mox() - self.service_factory = Service - self.remote_host = 'remote.host.com' - self.server_host = 'server.host.com' - self.ResetRequestHandler() - - self.request = Request1() - self.request.integer_field = 1 - self.request.string_field = 'a' - self.request.enum_field = Enum1.VAL1 - - def ResetRequestHandler(self): - super(ServiceHandlerTest, self).setUp() - - def CreateService(self): - return self.service_factory() - - def CreateRequestHandler(self): - self.rpc_mapper1 = self.mox.CreateMock(service_handlers.RPCMapper) - self.rpc_mapper1.http_methods = set(['POST']) - self.rpc_mapper1.content_types = set(['application/x-www-form-urlencoded']) - self.rpc_mapper1.default_content_type = 'application/x-www-form-urlencoded' - self.rpc_mapper2 = self.mox.CreateMock(service_handlers.RPCMapper) - self.rpc_mapper2.http_methods = set(['GET']) - self.rpc_mapper2.content_types = set(['application/json']) - self.rpc_mapper2.default_content_type = 'application/json' - self.factory = service_handlers.ServiceHandlerFactory( - self.CreateService) - self.factory.add_request_mapper(self.rpc_mapper1) - self.factory.add_request_mapper(self.rpc_mapper2) - return self.factory() - - def GetEnvironment(self): - """Create handler to test.""" - environ = super(ServiceHandlerTest, self).GetEnvironment() - if self.remote_host: - environ['REMOTE_HOST'] = self.remote_host - if self.server_host: - environ['SERVER_HOST'] = self.server_host - return environ - - def VerifyResponse(self, *args, **kwargs): - VerifyResponse(self, - self.response, - *args, **kwargs) - - def ExpectRpcError(self, mapper, state, error_message, error_name=None): - mapper.build_response(self.handler, - remote.RpcStatus(state=state, - error_message=error_message, - error_name=error_name)) - - def testRedirect(self): - """Test that redirection is disabled.""" - self.assertRaises(NotImplementedError, self.handler.redirect, '/') - - def testFirstMapper(self): - """Make sure service attribute works when matches first RPCMapper.""" - self.rpc_mapper1.build_request( - self.handler, Request1).AndReturn(self.request) - - def build_response(handler, response): - output = '%s %s %s' % (response.integer_field, - response.string_field, - response.enum_field) - handler.response.headers['content-type'] = ( - 'application/x-www-form-urlencoded') - handler.response.out.write(output) - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', '1 a VAL1') - - self.mox.VerifyAll() - - def testSecondMapper(self): - """Make sure service attribute works when matches first RPCMapper. - - Demonstrates the multiplicity of the RPCMapper configuration. - """ - self.rpc_mapper2.build_request( - self.handler, Request1).AndReturn(self.request) - - def build_response(handler, response): - output = '%s %s %s' % (response.integer_field, - response.string_field, - response.enum_field) - handler.response.headers['content-type'] = ( - 'application/x-www-form-urlencoded') - handler.response.out.write(output) - self.rpc_mapper2.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - self.mox.ReplayAll() - - self.handler.request.headers['Content-Type'] = 'application/json' - self.handler.handle('GET', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', '1 a VAL1') - - self.mox.VerifyAll() - - def testCaseInsensitiveContentType(self): - """Ensure that matching content-type is case insensitive.""" - request = Request1() - request.integer_field = 1 - request.string_field = 'a' - request.enum_field = Enum1.VAL1 - self.rpc_mapper1.build_request(self.handler, - Request1).AndReturn(self.request) - - def build_response(handler, response): - output = '%s %s %s' % (response.integer_field, - response.string_field, - response.enum_field) - handler.response.out.write(output) - handler.response.headers['content-type'] = 'text/plain' - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - self.mox.ReplayAll() - - self.handler.request.headers['Content-Type'] = ('ApPlIcAtIoN/' - 'X-wWw-FoRm-UrLeNcOdEd') - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', '1 a VAL1', 'text/plain') - - self.mox.VerifyAll() - - def testContentTypeWithParameters(self): - """Test that content types have parameters parsed out.""" - request = Request1() - request.integer_field = 1 - request.string_field = 'a' - request.enum_field = Enum1.VAL1 - self.rpc_mapper1.build_request(self.handler, - Request1).AndReturn(self.request) - - def build_response(handler, response): - output = '%s %s %s' % (response.integer_field, - response.string_field, - response.enum_field) - handler.response.headers['content-type'] = ( - 'application/x-www-form-urlencoded') - handler.response.out.write(output) - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - self.mox.ReplayAll() - - self.handler.request.headers['Content-Type'] = ('application/' - 'x-www-form-urlencoded' + - '; a=b; c=d') - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', '1 a VAL1') - - self.mox.VerifyAll() - - def testContentFromHeaderOnly(self): - """Test getting content-type from HTTP_CONTENT_TYPE directly. - - Some bad web server implementations might decide not to set CONTENT_TYPE for - POST requests where there is an empty body. In these cases, need to get - content-type directly from webob environ key HTTP_CONTENT_TYPE. - """ - request = Request1() - request.integer_field = 1 - request.string_field = 'a' - request.enum_field = Enum1.VAL1 - self.rpc_mapper1.build_request(self.handler, - Request1).AndReturn(self.request) - - def build_response(handler, response): - output = '%s %s %s' % (response.integer_field, - response.string_field, - response.enum_field) - handler.response.headers['Content-Type'] = ( - 'application/x-www-form-urlencoded') - handler.response.out.write(output) - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - self.mox.ReplayAll() - - self.handler.request.headers['Content-Type'] = None - self.handler.request.environ['HTTP_CONTENT_TYPE'] = ( - 'application/x-www-form-urlencoded') - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', '1 a VAL1', - 'application/x-www-form-urlencoded') - - self.mox.VerifyAll() - - def testRequestState(self): - """Make sure request state is passed in to handler that supports it.""" - class ServiceWithState(object): - - initialize_request_state = self.mox.CreateMockAnything() - - @remote.method(Request1, Response1) - def method1(self, request): - return Response1() - - self.service_factory = ServiceWithState - - # Reset handler with new service type. - self.ResetRequestHandler() - - self.rpc_mapper1.build_request( - self.handler, Request1).AndReturn(Request1()) - - def build_response(handler, response): - handler.response.headers['Content-Type'] = ( - 'application/x-www-form-urlencoded') - handler.response.out.write('whatever') - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - def verify_state(state): - return ( - 'remote.host.com' == state.remote_host and - '127.0.0.1' == state.remote_address and - 'server.host.com' == state.server_host and - 8080 == state.server_port and - 'POST' == state.http_method and - '/my_service' == state.service_path and - 'application/x-www-form-urlencoded' == state.headers['content-type'] and - 'dev_appserver_login="test:test@example.com:True"' == - state.headers['cookie']) - ServiceWithState.initialize_request_state(mox.Func(verify_state)) - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', 'whatever') - - self.mox.VerifyAll() - - def testRequestState_MissingHosts(self): - """Make sure missing state environment values are handled gracefully.""" - class ServiceWithState(object): - - initialize_request_state = self.mox.CreateMockAnything() - - @remote.method(Request1, Response1) - def method1(self, request): - return Response1() - - self.service_factory = ServiceWithState - self.remote_host = None - self.server_host = None - - # Reset handler with new service type. - self.ResetRequestHandler() - - self.rpc_mapper1.build_request( - self.handler, Request1).AndReturn(Request1()) - - def build_response(handler, response): - handler.response.headers['Content-Type'] = ( - 'application/x-www-form-urlencoded') - handler.response.out.write('whatever') - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).WithSideEffects(build_response) - - def verify_state(state): - return (None is state.remote_host and - '127.0.0.1' == state.remote_address and - None is state.server_host and - 8080 == state.server_port) - ServiceWithState.initialize_request_state(mox.Func(verify_state)) - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('200', 'OK', 'whatever') - - self.mox.VerifyAll() - - def testNoMatch_UnknownHTTPMethod(self): - """Test what happens when no RPCMapper matches.""" - self.mox.ReplayAll() - - self.handler.handle('UNKNOWN', '/my_service', 'does_not_matter') - - self.VerifyResponse('405', - 'Unsupported HTTP method: UNKNOWN', - 'Method Not Allowed', - 'text/plain; charset=utf-8') - - self.mox.VerifyAll() - - def testNoMatch_GetNotSupported(self): - """Test what happens when GET is not supported.""" - self.mox.ReplayAll() - - self.handler.handle('GET', '/my_service', 'method1') - - self.VerifyResponse('405', - 'Method Not Allowed', - '/my_service.method1 is a ProtoRPC method.\n\n' - 'Service %s.Service\n\n' - 'More about ProtoRPC: ' - 'http://code.google.com/p/google-protorpc' % - (__name__,), - 'text/plain; charset=utf-8') - - self.mox.VerifyAll() - - def testNoMatch_UnknownContentType(self): - """Test what happens when no RPCMapper matches.""" - self.mox.ReplayAll() - - self.handler.request.headers['Content-Type'] = 'image/png' - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('415', - 'Unsupported content-type: image/png', - 'Unsupported Media Type', - 'text/plain; charset=utf-8') - - self.mox.VerifyAll() - - def testNoMatch_NoContentType(self): - """Test what happens when no RPCMapper matches..""" - self.mox.ReplayAll() - - self.handler.request.environ.pop('HTTP_CONTENT_TYPE', None) - self.handler.request.headers.pop('Content-Type', None) - self.handler.handle('/my_service', 'POST', 'method1') - - self.VerifyResponse('400', 'Invalid RPC request: missing content-type', - 'Bad Request', - 'text/plain; charset=utf-8') - - self.mox.VerifyAll() - - def testNoSuchMethod(self): - """When service method not found.""" - self.ExpectRpcError(self.rpc_mapper1, - remote.RpcState.METHOD_NOT_FOUND_ERROR, - 'Unrecognized RPC method: no_such_method') - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'no_such_method') - - self.VerifyResponse('400', 'Unrecognized RPC method: no_such_method', '') - - self.mox.VerifyAll() - - def testNoSuchRemoteMethod(self): - """When service method exists but is not remote.""" - self.ExpectRpcError(self.rpc_mapper1, - remote.RpcState.METHOD_NOT_FOUND_ERROR, - 'Unrecognized RPC method: not_remote') - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'not_remote') - - self.VerifyResponse('400', 'Unrecognized RPC method: not_remote', '') - - self.mox.VerifyAll() - - def testRequestError(self): - """RequestError handling.""" - def build_request(handler, request): - raise service_handlers.RequestError('This is a request error') - self.rpc_mapper1.build_request( - self.handler, Request1).WithSideEffects(build_request) - - self.ExpectRpcError(self.rpc_mapper1, - remote.RpcState.REQUEST_ERROR, - 'Error parsing ProtoRPC request ' - '(This is a request error)') - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('400', - 'Error parsing ProtoRPC request ' - '(This is a request error)', - '') - - - self.mox.VerifyAll() - - def testDecodeError(self): - """DecodeError handling.""" - def build_request(handler, request): - raise messages.DecodeError('This is a decode error') - self.rpc_mapper1.build_request( - self.handler, Request1).WithSideEffects(build_request) - - self.ExpectRpcError(self.rpc_mapper1, - remote.RpcState.REQUEST_ERROR, - r'Error parsing ProtoRPC request ' - r'(This is a decode error)') - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('400', - 'Error parsing ProtoRPC request ' - '(This is a decode error)', - '') - - self.mox.VerifyAll() - - def testResponseException(self): - """Test what happens when build_response raises ResponseError.""" - self.rpc_mapper1.build_request( - self.handler, Request1).AndReturn(self.request) - - self.rpc_mapper1.build_response( - self.handler, mox.IsA(Response1)).AndRaise( - service_handlers.ResponseError) - - self.ExpectRpcError(self.rpc_mapper1, - remote.RpcState.SERVER_ERROR, - 'Internal Server Error') - - self.mox.ReplayAll() - - self.handler.handle('POST', '/my_service', 'method1') - - self.VerifyResponse('500', 'Internal Server Error', '') - - self.mox.VerifyAll() - - def testGet(self): - """Test that GET goes to 'handle' properly.""" - self.handler.handle = self.mox.CreateMockAnything() - self.handler.handle('GET', '/my_service', 'method1') - self.handler.handle('GET', '/my_other_service', 'method2') - - self.mox.ReplayAll() - - self.handler.get('/my_service', 'method1') - self.handler.get('/my_other_service', 'method2') - - self.mox.VerifyAll() - - def testPost(self): - """Test that POST goes to 'handle' properly.""" - self.handler.handle = self.mox.CreateMockAnything() - self.handler.handle('POST', '/my_service', 'method1') - self.handler.handle('POST', '/my_other_service', 'method2') - - self.mox.ReplayAll() - - self.handler.post('/my_service', 'method1') - self.handler.post('/my_other_service', 'method2') - - self.mox.VerifyAll() - - def testGetNoMethod(self): - self.handler.get('/my_service', '') - self.assertEquals(405, self.handler.response.status) - self.assertEquals( - util.pad_string('/my_service is a ProtoRPC service.\n\n' - 'Service %s.Service\n\n' - 'More about ProtoRPC: ' - 'http://code.google.com/p/google-protorpc\n' % - __name__), - self.handler.response.out.getvalue()) - self.assertEquals( - 'nosniff', - self.handler.response.headers['x-content-type-options']) - - def testGetNotSupported(self): - self.handler.get('/my_service', 'method1') - self.assertEquals(405, self.handler.response.status) - expected_message = ('/my_service.method1 is a ProtoRPC method.\n\n' - 'Service %s.Service\n\n' - 'More about ProtoRPC: ' - 'http://code.google.com/p/google-protorpc\n' % - __name__) - self.assertEquals(util.pad_string(expected_message), - self.handler.response.out.getvalue()) - self.assertEquals( - 'nosniff', - self.handler.response.headers['x-content-type-options']) - - def testGetUnknownContentType(self): - self.handler.request.headers['content-type'] = 'image/png' - self.handler.get('/my_service', 'method1') - self.assertEquals(415, self.handler.response.status) - self.assertEquals( - util.pad_string('/my_service.method1 is a ProtoRPC method.\n\n' - 'Service %s.Service\n\n' - 'More about ProtoRPC: ' - 'http://code.google.com/p/google-protorpc\n' % - __name__), - self.handler.response.out.getvalue()) - self.assertEquals( - 'nosniff', - self.handler.response.headers['x-content-type-options']) - - -class MissingContentLengthTests(ServiceHandlerTest): - """Test for when content-length is not set in the environment. - - This test moves CONTENT_LENGTH from the environment to the - content-length header. - """ - - def GetEnvironment(self): - environment = super(MissingContentLengthTests, self).GetEnvironment() - content_length = str(environment.pop('CONTENT_LENGTH', '0')) - environment['HTTP_CONTENT_LENGTH'] = content_length - return environment - - -class MissingContentTypeTests(ServiceHandlerTest): - """Test for when content-type is not set in the environment. - - This test moves CONTENT_TYPE from the environment to the - content-type header. - """ - - def GetEnvironment(self): - environment = super(MissingContentTypeTests, self).GetEnvironment() - content_type = str(environment.pop('CONTENT_TYPE', '')) - environment['HTTP_CONTENT_TYPE'] = content_type - return environment - - -class RPCMapperTestBase(test_util.TestCase): - - def setUp(self): - """Set up test framework.""" - self.Reinitialize() - - def Reinitialize(self, input='', - get=False, - path_method='method1', - content_type='text/plain'): - """Allows reinitialization of test with custom input values and POST. - - Args: - input: Query string or POST input. - get: Use GET method if True. Use POST if False. - """ - self.factory = service_handlers.ServiceHandlerFactory(Service) - - self.service_handler = service_handlers.ServiceHandler(self.factory, - Service()) - self.service_handler.remote_method = path_method - request_path = '/servicepath' - if path_method: - request_path += '/' + path_method - if get: - request_path += '?' + input - - if get: - environ = {'wsgi.input': cStringIO.StringIO(''), - 'CONTENT_LENGTH': '0', - 'QUERY_STRING': input, - 'REQUEST_METHOD': 'GET', - 'PATH_INFO': request_path, - } - self.service_handler.method = 'GET' - else: - environ = {'wsgi.input': cStringIO.StringIO(input), - 'CONTENT_LENGTH': str(len(input)), - 'QUERY_STRING': '', - 'REQUEST_METHOD': 'POST', - 'PATH_INFO': request_path, - } - self.service_handler.method = 'POST' - - self.request = webapp.Request(environ) - - self.response = webapp.Response() - - self.service_handler.initialize(self.request, self.response) - - self.service_handler.request.headers['Content-Type'] = content_type - - -class RPCMapperTest(RPCMapperTestBase, webapp_test_util.RequestHandlerTestBase): - """Test the RPCMapper base class.""" - - def setUp(self): - RPCMapperTestBase.setUp(self) - webapp_test_util.RequestHandlerTestBase.setUp(self) - self.mox = mox.Mox() - self.protocol = self.mox.CreateMockAnything() - - def GetEnvironment(self): - """Get environment. - - Return bogus content in body. - - Returns: - dict of CGI environment. - """ - environment = super(RPCMapperTest, self).GetEnvironment() - environment['wsgi.input'] = cStringIO.StringIO('my body') - environment['CONTENT_LENGTH'] = len('my body') - return environment - - def testContentTypes_JustDefault(self): - """Test content type attributes.""" - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['GET', 'POST'], - 'my-content-type', - self.protocol) - - self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods) - self.assertEquals('my-content-type', mapper.default_content_type) - self.assertEquals(frozenset(['my-content-type']), - mapper.content_types) - - self.mox.VerifyAll() - - def testContentTypes_Extended(self): - """Test content type attributes.""" - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['GET', 'POST'], - 'my-content-type', - self.protocol, - content_types=['a', 'b']) - - self.assertEquals(frozenset(['GET', 'POST']), mapper.http_methods) - self.assertEquals('my-content-type', mapper.default_content_type) - self.assertEquals(frozenset(['my-content-type', 'a', 'b']), - mapper.content_types) - - self.mox.VerifyAll() - - def testBuildRequest(self): - """Test building a request.""" - expected_request = Request1() - self.protocol.decode_message(Request1, - 'my body').AndReturn(expected_request) - - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['POST'], - 'my-content-type', - self.protocol) - - request = mapper.build_request(self.handler, Request1) - - self.assertTrue(expected_request is request) - - def testBuildRequest_ValidationError(self): - """Test building a request generating a validation error.""" - expected_request = Request1() - self.protocol.decode_message( - Request1, 'my body').AndRaise(messages.ValidationError('xyz')) - - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['POST'], - 'my-content-type', - self.protocol) - - self.assertRaisesWithRegexpMatch( - service_handlers.RequestError, - 'Unable to parse request content: xyz', - mapper.build_request, - self.handler, - Request1) - - def testBuildRequest_DecodeError(self): - """Test building a request generating a decode error.""" - expected_request = Request1() - self.protocol.decode_message( - Request1, 'my body').AndRaise(messages.DecodeError('xyz')) - - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['POST'], - 'my-content-type', - self.protocol) - - self.assertRaisesWithRegexpMatch( - service_handlers.RequestError, - 'Unable to parse request content: xyz', - mapper.build_request, - self.handler, - Request1) - - def testBuildResponse(self): - """Test building a response.""" - response = Response1() - self.protocol.encode_message(response).AndReturn('encoded') - - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['POST'], - 'my-content-type', - self.protocol) - - request = mapper.build_response(self.handler, response) - - self.assertEquals('my-content-type', - self.handler.response.headers['Content-Type']) - self.assertEquals('encoded', self.handler.response.out.getvalue()) - - def testBuildResponse(self): - """Test building a response.""" - response = Response1() - self.protocol.encode_message(response).AndRaise( - messages.ValidationError('xyz')) - - self.mox.ReplayAll() - - mapper = service_handlers.RPCMapper(['POST'], - 'my-content-type', - self.protocol) - - self.assertRaisesWithRegexpMatch(service_handlers.ResponseError, - 'Unable to encode message: xyz', - mapper.build_response, - self.handler, - response) - - -class ProtocolMapperTestBase(object): - """Base class for basic protocol mapper tests.""" - - def setUp(self): - """Reinitialize test specifically for protocol buffer mapper.""" - super(ProtocolMapperTestBase, self).setUp() - self.Reinitialize(path_method='my_method', - content_type='application/x-google-protobuf') - - self.request_message = Request1() - self.request_message.integer_field = 1 - self.request_message.string_field = u'something' - self.request_message.enum_field = Enum1.VAL1 - - self.response_message = Response1() - self.response_message.integer_field = 1 - self.response_message.string_field = u'something' - self.response_message.enum_field = Enum1.VAL1 - - def testBuildRequest(self): - """Test request building.""" - self.Reinitialize(self.protocol.encode_message(self.request_message), - content_type=self.content_type) - - mapper = self.mapper() - parsed_request = mapper.build_request(self.service_handler, - Request1) - self.assertEquals(self.request_message, parsed_request) - - def testBuildResponse(self): - """Test response building.""" - - mapper = self.mapper() - mapper.build_response(self.service_handler, self.response_message) - self.assertEquals(self.protocol.encode_message(self.response_message), - self.service_handler.response.out.getvalue()) - - def testWholeRequest(self): - """Test the basic flow of a request with mapper class.""" - body = self.protocol.encode_message(self.request_message) - self.Reinitialize(input=body, - content_type=self.content_type) - self.factory.add_request_mapper(self.mapper()) - self.service_handler.handle('POST', '/my_service', 'method1') - VerifyResponse(self, - self.service_handler.response, - '200', - 'OK', - self.protocol.encode_message(self.response_message), - self.content_type) - - -class URLEncodedRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): - """Test the URL encoded RPC mapper.""" - - content_type = 'application/x-www-form-urlencoded' - protocol = protourlencode - mapper = service_handlers.URLEncodedRPCMapper - - def testBuildRequest_Prefix(self): - """Test building request with parameter prefix.""" - self.Reinitialize(urllib.urlencode([('prefix_integer_field', '10'), - ('prefix_string_field', 'a string'), - ('prefix_enum_field', 'VAL1'), - ]), - self.content_type) - - url_encoded_mapper = service_handlers.URLEncodedRPCMapper( - parameter_prefix='prefix_') - request = url_encoded_mapper.build_request(self.service_handler, - Request1) - self.assertEquals(10, request.integer_field) - self.assertEquals('a string', request.string_field) - self.assertEquals(Enum1.VAL1, request.enum_field) - - def testBuildRequest_DecodeError(self): - """Test trying to build request that causes a decode error.""" - self.Reinitialize(urllib.urlencode((('integer_field', '10'), - ('integer_field', '20'), - )), - content_type=self.content_type) - - url_encoded_mapper = service_handlers.URLEncodedRPCMapper() - - self.assertRaises(service_handlers.RequestError, - url_encoded_mapper.build_request, - self.service_handler, - Service.method1.remote.request_type) - - def testBuildResponse_Prefix(self): - """Test building a response with parameter prefix.""" - response = Response1() - response.integer_field = 10 - response.string_field = u'a string' - response.enum_field = Enum1.VAL3 - - url_encoded_mapper = service_handlers.URLEncodedRPCMapper( - parameter_prefix='prefix_') - - url_encoded_mapper.build_response(self.service_handler, response) - self.assertEquals('application/x-www-form-urlencoded', - self.response.headers['content-type']) - self.assertEquals(cgi.parse_qs(self.response.out.getvalue(), True, True), - {'prefix_integer_field': ['10'], - 'prefix_string_field': [u'a string'], - 'prefix_enum_field': ['VAL3'], - }) - - -class ProtobufRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): - """Test the protobuf encoded RPC mapper.""" - - content_type = 'application/octet-stream' - protocol = protobuf - mapper = service_handlers.ProtobufRPCMapper - - -class JSONRPCMapperTest(ProtocolMapperTestBase, RPCMapperTestBase): - """Test the URL encoded RPC mapper.""" - - content_type = 'application/json' - protocol = protojson - mapper = service_handlers.JSONRPCMapper - - -class MyService(remote.Service): - - def __init__(self, value='default'): - self.value = value - - -class ServiceMappingTest(test_util.TestCase): - - def CheckFormMappings(self, mapping, registry_path='/protorpc'): - """Check to make sure that form mapping is configured as expected. - - Args: - mapping: Mapping that should contain forms handlers. - """ - pattern, factory = mapping[0] - self.assertEquals('%s/form(?:/)?' % registry_path, pattern) - handler = factory() - self.assertTrue(isinstance(handler, forms.FormsHandler)) - self.assertEquals(registry_path, handler.registry_path) - - pattern, factory = mapping[1] - self.assertEquals('%s/form/(.+)' % registry_path, pattern) - self.assertEquals(forms.ResourceHandler, factory) - - - def DoMappingTest(self, - services, - registry_path='/myreg', - expected_paths=None): - mapped_services = mapping = service_handlers.service_mapping(services, - registry_path) - if registry_path: - form_mapping = mapping[:2] - mapped_registry_path, mapped_registry_factory = mapping[-1] - mapped_services = mapping[2:-1] - self.CheckFormMappings(form_mapping, registry_path=registry_path) - - self.assertEquals(r'(%s)%s' % (registry_path, - service_handlers._METHOD_PATTERN), - mapped_registry_path) - self.assertEquals(registry.RegistryService, - mapped_registry_factory.service_factory.service_class) - - # Verify registry knows about other services. - expected_registry = {registry_path: registry.RegistryService} - for path, factory in dict(services).items(): - if isinstance(factory, type) and issubclass(factory, remote.Service): - expected_registry[path] = factory - else: - expected_registry[path] = factory.service_class - self.assertEquals(expected_registry, - mapped_registry_factory().service.registry) - - # Verify that services are mapped to URL. - self.assertEquals(len(services), len(mapped_services)) - for path, service in dict(services).items(): - mapped_path = r'(%s)%s' % (path, service_handlers._METHOD_PATTERN) - mapped_factory = dict(mapped_services)[mapped_path] - self.assertEquals(service, mapped_factory.service_factory) - - def testServiceMapping_Empty(self): - """Test an empty service mapping.""" - self.DoMappingTest({}) - - def testServiceMapping_ByClass(self): - """Test mapping a service by class.""" - self.DoMappingTest({'/my-service': MyService}) - - def testServiceMapping_ByFactory(self): - """Test mapping a service by factory.""" - self.DoMappingTest({'/my-service': MyService.new_factory('new-value')}) - - def testServiceMapping_ByList(self): - """Test mapping a service by factory.""" - self.DoMappingTest( - [('/my-service1', MyService.new_factory('service1')), - ('/my-service2', MyService.new_factory('service2')), - ]) - - def testServiceMapping_NoRegistry(self): - """Test mapping a service by class.""" - mapping = self.DoMappingTest({'/my-service': MyService}, None) - - def testDefaultMappingWithClass(self): - """Test setting path just from the class. - - Path of the mapping will be the fully qualified ProtoRPC service name with - '.' replaced with '/'. For example: - - com.nowhere.service.TheService -> /com/nowhere/service/TheService - """ - mapping = service_handlers.service_mapping([MyService]) - mapped_services = mapping[2:-1] - self.assertEquals(1, len(mapped_services)) - path, factory = mapped_services[0] - - self.assertEquals( - r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN, - path) - self.assertEquals(MyService, factory.service_factory) - - def testDefaultMappingWithFactory(self): - mapping = service_handlers.service_mapping( - [MyService.new_factory('service1')]) - mapped_services = mapping[2:-1] - self.assertEquals(1, len(mapped_services)) - path, factory = mapped_services[0] - - self.assertEquals( - r'(/test_package/MyService)' + service_handlers._METHOD_PATTERN, - path) - self.assertEquals(MyService, factory.service_factory.service_class) - - def testMappingDuplicateExplicitServiceName(self): - self.assertRaisesWithRegexpMatch( - service_handlers.ServiceConfigurationError, - "Path '/my_path' is already defined in service mapping", - service_handlers.service_mapping, - [('/my_path', MyService), - ('/my_path', MyService), - ]) - - def testMappingDuplicateServiceName(self): - self.assertRaisesWithRegexpMatch( - service_handlers.ServiceConfigurationError, - "Path '/test_package/MyService' is already defined in service mapping", - service_handlers.service_mapping, - [MyService, MyService]) - - -class GetCalled(remote.Service): - - def __init__(self, test): - self.test = test - - @remote.method(Request1, Response1) - def my_method(self, request): - self.test.request = request - return Response1(string_field='a response') - - -class TestRunServices(test_util.TestCase): - - def DoRequest(self, - path, - request, - response_type, - reg_path='/protorpc'): - stdin = sys.stdin - stdout = sys.stdout - environ = os.environ - try: - sys.stdin = cStringIO.StringIO(protojson.encode_message(request)) - sys.stdout = cStringIO.StringIO() - - os.environ = webapp_test_util.GetDefaultEnvironment() - os.environ['PATH_INFO'] = path - os.environ['REQUEST_METHOD'] = 'POST' - os.environ['CONTENT_TYPE'] = 'application/json' - os.environ['wsgi.input'] = sys.stdin - os.environ['wsgi.output'] = sys.stdout - os.environ['CONTENT_LENGTH'] = len(sys.stdin.getvalue()) - - service_handlers.run_services( - [('/my_service', GetCalled.new_factory(self))], reg_path) - - header, body = sys.stdout.getvalue().split('\n\n', 1) - - return (header.split('\n')[0], - protojson.decode_message(response_type, body)) - finally: - sys.stdin = stdin - sys.stdout = stdout - os.environ = environ - - def testRequest(self): - request = Request1(string_field='request value') - - status, response = self.DoRequest('/my_service.my_method', - request, - Response1) - self.assertEquals('Status: 200 OK', status) - self.assertEquals(request, self.request) - self.assertEquals(Response1(string_field='a response'), response) - - def testRegistry(self): - request = Request1(string_field='request value') - status, response = self.DoRequest('/protorpc.services', - message_types.VoidMessage(), - registry.ServicesResponse) - - self.assertEquals('Status: 200 OK', status) - self.assertIterEqual([ - registry.ServiceMapping( - name='/protorpc', - definition='protorpc.registry.RegistryService'), - registry.ServiceMapping( - name='/my_service', - definition='test_package.GetCalled'), - ], response.services) - - def testRunServicesWithOutRegistry(self): - request = Request1(string_field='request value') - - status, response = self.DoRequest('/protorpc.services', - message_types.VoidMessage(), - registry.ServicesResponse, - reg_path=None) - self.assertEquals('Status: 404 Not Found', status) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/webapp_test_util.py b/endpoints/internal/protorpc/webapp_test_util.py deleted file mode 100644 index 6481dc3..0000000 --- a/endpoints/internal/protorpc/webapp_test_util.py +++ /dev/null @@ -1,411 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Testing utilities for the webapp libraries. - - GetDefaultEnvironment: Method for easily setting up CGI environment. - RequestHandlerTestBase: Base class for setting up handler tests. -""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -import cStringIO -import socket -import threading -import urllib2 -from wsgiref import simple_server -from wsgiref import validate - -from . import protojson -from . import remote -from . import test_util -from . import transport -from .webapp import service_handlers -from .webapp.google_imports import webapp - - -class TestService(remote.Service): - """Service used to do end to end tests with.""" - - @remote.method(test_util.OptionalMessage, - test_util.OptionalMessage) - def optional_message(self, request): - if request.string_value: - request.string_value = '+%s' % request.string_value - return request - - -def GetDefaultEnvironment(): - """Function for creating a default CGI environment.""" - return { - 'LC_NUMERIC': 'C', - 'wsgi.multiprocess': True, - 'SERVER_PROTOCOL': 'HTTP/1.0', - 'SERVER_SOFTWARE': 'Dev AppServer 0.1', - 'SCRIPT_NAME': '', - 'LOGNAME': 'nickjohnson', - 'USER': 'nickjohnson', - 'QUERY_STRING': 'foo=bar&foo=baz&foo2=123', - 'PATH': '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/bin/X11', - 'LANG': 'en_US', - 'LANGUAGE': 'en', - 'REMOTE_ADDR': '127.0.0.1', - 'LC_MONETARY': 'C', - 'CONTENT_TYPE': 'application/x-www-form-urlencoded', - 'wsgi.url_scheme': 'http', - 'SERVER_PORT': '8080', - 'HOME': '/home/mruser', - 'USERNAME': 'mruser', - 'CONTENT_LENGTH': '', - 'USER_IS_ADMIN': '1', - 'PYTHONPATH': '/tmp/setup', - 'LC_TIME': 'C', - 'HTTP_USER_AGENT': 'Mozilla/5.0 (X11; U; Linux i686 (x86_64); en-US; ' - 'rv:1.8.1.6) Gecko/20070725 Firefox/2.0.0.6', - 'wsgi.multithread': False, - 'wsgi.version': (1, 0), - 'USER_EMAIL': 'test@example.com', - 'USER_EMAIL': '112', - 'wsgi.input': cStringIO.StringIO(), - 'PATH_TRANSLATED': '/tmp/request.py', - 'SERVER_NAME': 'localhost', - 'GATEWAY_INTERFACE': 'CGI/1.1', - 'wsgi.run_once': True, - 'LC_COLLATE': 'C', - 'HOSTNAME': 'myhost', - 'wsgi.errors': cStringIO.StringIO(), - 'PWD': '/tmp', - 'REQUEST_METHOD': 'GET', - 'MAIL': '/dev/null', - 'MAILCHECK': '0', - 'USER_NICKNAME': 'test', - 'HTTP_COOKIE': 'dev_appserver_login="test:test@example.com:True"', - 'PATH_INFO': '/tmp/myhandler' - } - - -class RequestHandlerTestBase(test_util.TestCase): - """Base class for writing RequestHandler tests. - - To test a specific request handler override CreateRequestHandler. - To change the environment for that handler override GetEnvironment. - """ - - def setUp(self): - """Set up test for request handler.""" - self.ResetHandler() - - def GetEnvironment(self): - """Get environment. - - Override for more specific configurations. - - Returns: - dict of CGI environment. - """ - return GetDefaultEnvironment() - - def CreateRequestHandler(self): - """Create RequestHandler instances. - - Override to create more specific kinds of RequestHandler instances. - - Returns: - RequestHandler instance used in test. - """ - return webapp.RequestHandler() - - def CheckResponse(self, - expected_status, - expected_headers, - expected_content): - """Check that the web response is as expected. - - Args: - expected_status: Expected status message. - expected_headers: Dictionary of expected headers. Will ignore unexpected - headers and only check the value of those expected. - expected_content: Expected body. - """ - def check_content(content): - self.assertEquals(expected_content, content) - - def start_response(status, headers): - self.assertEquals(expected_status, status) - - found_keys = set() - for name, value in headers: - name = name.lower() - try: - expected_value = expected_headers[name] - except KeyError: - pass - else: - found_keys.add(name) - self.assertEquals(expected_value, value) - - missing_headers = set(expected_headers.keys()) - found_keys - if missing_headers: - self.fail('Expected keys %r not found' % (list(missing_headers),)) - - return check_content - - self.handler.response.wsgi_write(start_response) - - def ResetHandler(self, change_environ=None): - """Reset this tests environment with environment changes. - - Resets the entire test with a new handler which includes some changes to - the default request environment. - - Args: - change_environ: Dictionary of values that are added to default - environment. - """ - environment = self.GetEnvironment() - environment.update(change_environ or {}) - - self.request = webapp.Request(environment) - self.response = webapp.Response() - self.handler = self.CreateRequestHandler() - self.handler.initialize(self.request, self.response) - - -class SyncedWSGIServer(simple_server.WSGIServer): - pass - - -class WSGIServerIPv6(simple_server.WSGIServer): - address_family = socket.AF_INET6 - - -class ServerThread(threading.Thread): - """Thread responsible for managing wsgi server. - - This server does not just attach to the socket and listen for requests. This - is because the server classes in Python 2.5 or less have no way to shut them - down. Instead, the thread must be notified of how many requests it will - receive so that it listens for each one individually. Tests should tell how - many requests to listen for using the handle_request method. - """ - - def __init__(self, server, *args, **kwargs): - """Constructor. - - Args: - server: The WSGI server that is served by this thread. - As per threading.Thread base class. - - State: - __serving: Server is still expected to be serving. When False server - knows to shut itself down. - """ - self.server = server - # This timeout is for the socket when a connection is made. - self.server.socket.settimeout(None) - # This timeout is for when waiting for a connection. The allows - # server.handle_request() to listen for a short time, then timeout, - # allowing the server to check for shutdown. - self.server.timeout = 0.05 - self.__serving = True - - super(ServerThread, self).__init__(*args, **kwargs) - - def shutdown(self): - """Notify server that it must shutdown gracefully.""" - self.__serving = False - - def run(self): - """Handle incoming requests until shutdown.""" - while self.__serving: - self.server.handle_request() - - self.server = None - - -class TestService(remote.Service): - """Service used to do end to end tests with.""" - - def __init__(self, message='uninitialized'): - self.__message = message - - @remote.method(test_util.OptionalMessage, test_util.OptionalMessage) - def optional_message(self, request): - if request.string_value: - request.string_value = '+%s' % request.string_value - return request - - @remote.method(response_type=test_util.OptionalMessage) - def init_parameter(self, request): - return test_util.OptionalMessage(string_value=self.__message) - - @remote.method(test_util.NestedMessage, test_util.NestedMessage) - def nested_message(self, request): - request.string_value = '+%s' % request.string_value - return request - - @remote.method() - def raise_application_error(self, request): - raise remote.ApplicationError('This is an application error', 'ERROR_NAME') - - @remote.method() - def raise_unexpected_error(self, request): - raise TypeError('Unexpected error') - - @remote.method() - def raise_rpc_error(self, request): - raise remote.NetworkError('Uncaught network error') - - @remote.method(response_type=test_util.NestedMessage) - def return_bad_message(self, request): - return test_util.NestedMessage() - - -class AlternateService(remote.Service): - """Service used to requesting non-existant methods.""" - - @remote.method() - def does_not_exist(self, request): - raise NotImplementedError('Not implemented') - - -class WebServerTestBase(test_util.TestCase): - - SERVICE_PATH = '/my/service' - - def setUp(self): - self.server = None - self.schema = 'http' - self.ResetServer() - - self.bad_path_connection = self.CreateTransport(self.service_url + '_x') - self.bad_path_stub = TestService.Stub(self.bad_path_connection) - super(WebServerTestBase, self).setUp() - - def tearDown(self): - self.server.shutdown() - super(WebServerTestBase, self).tearDown() - - def ResetServer(self, application=None): - """Reset web server. - - Shuts down existing server if necessary and starts a new one. - - Args: - application: Optional WSGI function. If none provided will use - tests CreateWsgiApplication method. - """ - if self.server: - self.server.shutdown() - - self.port = test_util.pick_unused_port() - self.server, self.application = self.StartWebServer(self.port, application) - - self.connection = self.CreateTransport(self.service_url) - - def CreateTransport(self, service_url, protocol=protojson): - """Create a new transportation object.""" - return transport.HttpTransport(service_url, protocol=protocol) - - def StartWebServer(self, port, application=None): - """Start web server. - - Args: - port: Port to start application on. - application: Optional WSGI function. If none provided will use - tests CreateWsgiApplication method. - - Returns: - A tuple (server, application): - server: An instance of ServerThread. - application: Application that web server responds with. - """ - if not application: - application = self.CreateWsgiApplication() - validated_application = validate.validator(application) - - try: - server = simple_server.make_server( - 'localhost', port, validated_application) - except socket.error: - # Try IPv6 - server = simple_server.make_server( - 'localhost', port, validated_application, server_class=WSGIServerIPv6) - - server = ServerThread(server) - server.start() - return server, application - - def make_service_url(self, path): - """Make service URL using current schema and port.""" - return '%s://localhost:%d%s' % (self.schema, self.port, path) - - @property - def service_url(self): - return self.make_service_url(self.SERVICE_PATH) - - -class EndToEndTestBase(WebServerTestBase): - - # Sub-classes may override to create alternate configurations. - DEFAULT_MAPPING = service_handlers.service_mapping( - [('/my/service', TestService), - ('/my/other_service', TestService.new_factory('initialized')), - ]) - - def setUp(self): - super(EndToEndTestBase, self).setUp() - - self.stub = TestService.Stub(self.connection) - - self.other_connection = self.CreateTransport(self.other_service_url) - self.other_stub = TestService.Stub(self.other_connection) - - self.mismatched_stub = AlternateService.Stub(self.connection) - - @property - def other_service_url(self): - return 'http://localhost:%d/my/other_service' % self.port - - def CreateWsgiApplication(self): - """Create WSGI application used on the server side for testing.""" - return webapp.WSGIApplication(self.DEFAULT_MAPPING, True) - - def DoRawRequest(self, - method, - content='', - content_type='application/json', - headers=None): - headers = headers or {} - headers.update({'content-length': len(content or ''), - 'content-type': content_type, - }) - request = urllib2.Request('%s.%s' % (self.service_url, method), - content, - headers) - return urllib2.urlopen(request) - - def RawRequestError(self, - method, - content=None, - content_type='application/json', - headers=None): - try: - self.DoRawRequest(method, content, content_type, headers) - self.fail('Expected HTTP error') - except urllib2.HTTPError as err: - return err.code, err.read(), err.headers diff --git a/endpoints/internal/protorpc/wsgi/service_test.py b/endpoints/internal/protorpc/wsgi/service_test.py deleted file mode 100644 index c94d648..0000000 --- a/endpoints/internal/protorpc/wsgi/service_test.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2011 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""WSGI application tests.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' - - -import unittest - - -from protorpc import end2end_test -from protorpc import protojson -from protorpc import remote -from protorpc import registry -from protorpc import transport -from protorpc import test_util -from protorpc import webapp_test_util -from protorpc.wsgi import service -from protorpc.wsgi import util - - -class ServiceMappingTest(end2end_test.EndToEndTest): - - def setUp(self): - self.protocols = None - remote.Protocols.set_default(remote.Protocols.new_default()) - super(ServiceMappingTest, self).setUp() - - def CreateServices(self): - - return my_service, my_other_service - - def CreateWsgiApplication(self): - """Create WSGI application used on the server side for testing.""" - my_service = service.service_mapping(webapp_test_util.TestService, - '/my/service') - my_other_service = service.service_mapping( - webapp_test_util.TestService.new_factory('initialized'), - '/my/other_service', - protocols=self.protocols) - - return util.first_found([my_service, my_other_service]) - - def testAlternateProtocols(self): - self.protocols = remote.Protocols() - self.protocols.add_protocol(protojson, 'altproto', 'image/png') - - global_protocols = remote.Protocols() - global_protocols.add_protocol(protojson, 'server-side-name', 'image/png') - remote.Protocols.set_default(global_protocols) - self.ResetServer() - - self.connection = transport.HttpTransport( - self.service_url, protocol=self.protocols.lookup_by_name('altproto')) - self.stub = webapp_test_util.TestService.Stub(self.connection) - - self.stub.optional_message(string_value='alternate-protocol') - - def testAlwaysUseDefaults(self): - new_protocols = remote.Protocols() - new_protocols.add_protocol(protojson, 'altproto', 'image/png') - - self.connection = transport.HttpTransport( - self.service_url, protocol=new_protocols.lookup_by_name('altproto')) - self.stub = webapp_test_util.TestService.Stub(self.connection) - - self.assertRaisesWithRegexpMatch( - remote.ServerError, - 'HTTP Error 415: Unsupported Media Type', - self.stub.optional_message, string_value='alternate-protocol') - - remote.Protocols.set_default(new_protocols) - - self.stub.optional_message(string_value='alternate-protocol') - - -class ProtoServiceMappingsTest(ServiceMappingTest): - - def CreateWsgiApplication(self): - """Create WSGI application used on the server side for testing.""" - return service.service_mappings( - [('/my/service', webapp_test_util.TestService), - ('/my/other_service', - webapp_test_util.TestService.new_factory('initialized')) - ]) - - def GetRegistryStub(self, path='/protorpc'): - service_url = self.make_service_url(path) - transport = self.CreateTransport(service_url) - return registry.RegistryService.Stub(transport) - - def testRegistry(self): - registry_client = self.GetRegistryStub() - response = registry_client.services() - self.assertIterEqual([ - registry.ServiceMapping( - name='/my/other_service', - definition='protorpc.webapp_test_util.TestService'), - registry.ServiceMapping( - name='/my/service', - definition='protorpc.webapp_test_util.TestService'), - ], response.services) - - def testRegistryDictionary(self): - self.ResetServer(service.service_mappings( - {'/my/service': webapp_test_util.TestService, - '/my/other_service': - webapp_test_util.TestService.new_factory('initialized'), - })) - registry_client = self.GetRegistryStub() - response = registry_client.services() - self.assertIterEqual([ - registry.ServiceMapping( - name='/my/other_service', - definition='protorpc.webapp_test_util.TestService'), - registry.ServiceMapping( - name='/my/service', - definition='protorpc.webapp_test_util.TestService'), - ], response.services) - - def testNoRegistry(self): - self.ResetServer(service.service_mappings( - [('/my/service', webapp_test_util.TestService), - ('/my/other_service', - webapp_test_util.TestService.new_factory('initialized')) - ], - registry_path=None)) - registry_client = self.GetRegistryStub() - self.assertRaisesWithRegexpMatch( - remote.ServerError, - 'HTTP Error 404: Not Found', - registry_client.services) - - def testAltRegistry(self): - self.ResetServer(service.service_mappings( - [('/my/service', webapp_test_util.TestService), - ('/my/other_service', - webapp_test_util.TestService.new_factory('initialized')) - ], - registry_path='/registry')) - registry_client = self.GetRegistryStub('/registry') - services = registry_client.services() - self.assertTrue(isinstance(services, registry.ServicesResponse)) - self.assertIterEqual( - [registry.ServiceMapping( - name='/my/other_service', - definition='protorpc.webapp_test_util.TestService'), - registry.ServiceMapping( - name='/my/service', - definition='protorpc.webapp_test_util.TestService'), - ], - services.services) - - def testDuplicateRegistryEntry(self): - self.assertRaisesWithRegexpMatch( - remote.ServiceConfigurationError, - "Path '/my/service' is already defined in service mapping", - service.service_mappings, - [('/my/service', webapp_test_util.TestService), - ('/my/service', - webapp_test_util.TestService.new_factory('initialized')) - ]) - - def testRegex(self): - self.ResetServer(service.service_mappings( - [('/my/[0-9]+', webapp_test_util.TestService.new_factory('service')), - ('/my/[a-z]+', - webapp_test_util.TestService.new_factory('other-service')), - ])) - my_service_url = 'http://localhost:%d/my/12345' % self.port - my_other_service_url = 'http://localhost:%d/my/blarblar' % self.port - - my_service = webapp_test_util.TestService.Stub( - transport.HttpTransport(my_service_url)) - my_other_service = webapp_test_util.TestService.Stub( - transport.HttpTransport(my_other_service_url)) - - response = my_service.init_parameter() - self.assertEquals('service', response.string_value) - - response = my_other_service.init_parameter() - self.assertEquals('other-service', response.string_value) - - -def main(): - unittest.main() - - -if __name__ == '__main__': - main() diff --git a/endpoints/internal/protorpc/wsgi/util_test.py b/endpoints/internal/protorpc/wsgi/util_test.py deleted file mode 100644 index 60a79af..0000000 --- a/endpoints/internal/protorpc/wsgi/util_test.py +++ /dev/null @@ -1,295 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2011 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""WSGI utility library tests.""" -import six -from six.moves import filter - -__author__ = 'rafe@google.com (Rafe Kaplan)' - - -import six.moves.http_client -import unittest - -from protorpc import test_util -from protorpc import util -from protorpc import webapp_test_util -from protorpc.wsgi import util as wsgi_util - -APP1 = wsgi_util.static_page('App1') -APP2 = wsgi_util.static_page('App2') -NOT_FOUND = wsgi_util.error(six.moves.http_client.NOT_FOUND) - - -class WsgiTestBase(webapp_test_util.WebServerTestBase): - - server_thread = None - - def CreateWsgiApplication(self): - return None - - def DoHttpRequest(self, - path='/', - content=None, - content_type='text/plain; charset=utf-8', - headers=None): - connection = six.moves.http_client.HTTPConnection('localhost', self.port) - if content is None: - method = 'GET' - else: - method = 'POST' - headers = {'content=type': content_type} - headers.update(headers) - connection.request(method, path, content, headers) - response = connection.getresponse() - - not_date_or_server = lambda header: header[0] not in ('date', 'server') - headers = list(filter(not_date_or_server, response.getheaders())) - - return response.status, response.reason, response.read(), dict(headers) - - -class StaticPageBase(WsgiTestBase): - - def testDefault(self): - default_page = wsgi_util.static_page() - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(200, status) - self.assertEquals('OK', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testHasContent(self): - default_page = wsgi_util.static_page('my content') - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(200, status) - self.assertEquals('OK', reason) - self.assertEquals('my content', content) - self.assertEquals({'content-length': str(len('my content')), - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testHasContentType(self): - default_page = wsgi_util.static_page(content_type='text/plain') - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(200, status) - self.assertEquals('OK', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/plain', - }, - headers) - - def testHasStatus(self): - default_page = wsgi_util.static_page(status='400 Not Good Request') - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(400, status) - self.assertEquals('Not Good Request', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testHasStatusInt(self): - default_page = wsgi_util.static_page(status=401) - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(401, status) - self.assertEquals('Unauthorized', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testHasStatusUnknown(self): - default_page = wsgi_util.static_page(status=909) - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(909, status) - self.assertEquals('Unknown Error', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testHasStatusTuple(self): - default_page = wsgi_util.static_page(status=(500, 'Bad Thing')) - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(500, status) - self.assertEquals('Bad Thing', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testHasHeaders(self): - default_page = wsgi_util.static_page(headers=[('x', 'foo'), - ('a', 'bar'), - ('z', 'bin')]) - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(200, status) - self.assertEquals('OK', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - 'x': 'foo', - 'a': 'bar', - 'z': 'bin', - }, - headers) - - def testHeadersUnicodeSafe(self): - default_page = wsgi_util.static_page(headers=[('x', u'foo')]) - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(200, status) - self.assertEquals('OK', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - 'x': 'foo', - }, - headers) - self.assertTrue(isinstance(headers['x'], str)) - - def testHasHeadersDict(self): - default_page = wsgi_util.static_page(headers={'x': 'foo', - 'a': 'bar', - 'z': 'bin'}) - self.ResetServer(default_page) - status, reason, content, headers = self.DoHttpRequest() - self.assertEquals(200, status) - self.assertEquals('OK', reason) - self.assertEquals('', content) - self.assertEquals({'content-length': '0', - 'content-type': 'text/html; charset=utf-8', - 'x': 'foo', - 'a': 'bar', - 'z': 'bin', - }, - headers) - - -class FirstFoundTest(WsgiTestBase): - - def testEmptyConfiguration(self): - self.ResetServer(wsgi_util.first_found([])) - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(six.moves.http_client.NOT_FOUND, status) - self.assertEquals(six.moves.http_client.responses[six.moves.http_client.NOT_FOUND], status_text) - self.assertEquals(util.pad_string(six.moves.http_client.responses[six.moves.http_client.NOT_FOUND]), - content) - self.assertEquals({'content-length': '512', - 'content-type': 'text/plain; charset=utf-8', - }, - headers) - - def testOneApp(self): - self.ResetServer(wsgi_util.first_found([APP1])) - - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(six.moves.http_client.OK, status) - self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) - self.assertEquals('App1', content) - self.assertEquals({'content-length': '4', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testIterator(self): - self.ResetServer(wsgi_util.first_found(iter([APP1]))) - - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(six.moves.http_client.OK, status) - self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) - self.assertEquals('App1', content) - self.assertEquals({'content-length': '4', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - # Do request again to make sure iterator was properly copied. - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(six.moves.http_client.OK, status) - self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) - self.assertEquals('App1', content) - self.assertEquals({'content-length': '4', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testTwoApps(self): - self.ResetServer(wsgi_util.first_found([APP1, APP2])) - - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(six.moves.http_client.OK, status) - self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) - self.assertEquals('App1', content) - self.assertEquals({'content-length': '4', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testFirstNotFound(self): - self.ResetServer(wsgi_util.first_found([NOT_FOUND, APP2])) - - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(six.moves.http_client.OK, status) - self.assertEquals(six.moves.http_client.responses[six.moves.http_client.OK], status_text) - self.assertEquals('App2', content) - self.assertEquals({'content-length': '4', - 'content-type': 'text/html; charset=utf-8', - }, - headers) - - def testOnlyNotFound(self): - def current_error(environ, start_response): - """The variable current_status is defined in loop after ResetServer.""" - headers = [('content-type', 'text/plain')] - status_line = '%03d Whatever' % current_status - start_response(status_line, headers) - return [] - - self.ResetServer(wsgi_util.first_found([current_error, APP2])) - - statuses_to_check = sorted(httplib.responses.keys()) - # 100, 204 and 304 have slightly different expectations, so they are left - # out of this test in order to keep the code simple. - for dont_check in (100, 200, 204, 304, 404): - statuses_to_check.remove(dont_check) - for current_status in statuses_to_check: - status, status_text, content, headers = self.DoHttpRequest('/') - self.assertEquals(current_status, status) - self.assertEquals('Whatever', status_text) - - -if __name__ == '__main__': - unittest.main() From 4d4d532f373c936f720f20b9fb163a819d9708f6 Mon Sep 17 00:00:00 2001 From: Rose Davidson Date: Fri, 27 Jul 2018 16:54:35 -0700 Subject: [PATCH 3/6] Remove some unused parts of protorpc. No endpoints service should be using these. --- .../protorpc/experimental/__init__.py | 20 - .../protorpc/experimental/parser/protobuf.g | 159 ---- .../experimental/parser/protobuf_lexer.g | 153 ---- .../protorpc/experimental/parser/pyprotobuf.g | 45 - .../protorpc/experimental/parser/test.proto | 27 - endpoints/internal/protorpc/generate.py | 128 --- endpoints/internal/protorpc/generate_proto.py | 127 --- .../internal/protorpc/generate_python.py | 218 ----- endpoints/internal/protorpc/static/base.html | 57 -- endpoints/internal/protorpc/static/forms.html | 31 - endpoints/internal/protorpc/static/forms.js | 685 -------------- .../protorpc/static/jquery-1.4.2.min.js | 154 ---- .../protorpc/static/jquery.json-2.2.min.js | 31 - .../internal/protorpc/static/methods.html | 37 - .../internal/protorpc/webapp/__init__.py | 18 - endpoints/internal/protorpc/webapp/forms.py | 163 ---- .../protorpc/webapp/google_imports.py | 25 - .../protorpc/webapp/service_handlers.py | 834 ------------------ 18 files changed, 2912 deletions(-) delete mode 100644 endpoints/internal/protorpc/experimental/__init__.py delete mode 100644 endpoints/internal/protorpc/experimental/parser/protobuf.g delete mode 100644 endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g delete mode 100644 endpoints/internal/protorpc/experimental/parser/pyprotobuf.g delete mode 100644 endpoints/internal/protorpc/experimental/parser/test.proto delete mode 100644 endpoints/internal/protorpc/generate.py delete mode 100644 endpoints/internal/protorpc/generate_proto.py delete mode 100644 endpoints/internal/protorpc/generate_python.py delete mode 100644 endpoints/internal/protorpc/static/base.html delete mode 100644 endpoints/internal/protorpc/static/forms.html delete mode 100644 endpoints/internal/protorpc/static/forms.js delete mode 100644 endpoints/internal/protorpc/static/jquery-1.4.2.min.js delete mode 100644 endpoints/internal/protorpc/static/jquery.json-2.2.min.js delete mode 100644 endpoints/internal/protorpc/static/methods.html delete mode 100644 endpoints/internal/protorpc/webapp/__init__.py delete mode 100644 endpoints/internal/protorpc/webapp/forms.py delete mode 100644 endpoints/internal/protorpc/webapp/google_imports.py delete mode 100644 endpoints/internal/protorpc/webapp/service_handlers.py diff --git a/endpoints/internal/protorpc/experimental/__init__.py b/endpoints/internal/protorpc/experimental/__init__.py deleted file mode 100644 index 419fff2..0000000 --- a/endpoints/internal/protorpc/experimental/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2011 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Main module for ProtoRPC package.""" - -__author__ = 'rafek@google.com (Rafe Kaplan)' diff --git a/endpoints/internal/protorpc/experimental/parser/protobuf.g b/endpoints/internal/protorpc/experimental/parser/protobuf.g deleted file mode 100644 index 8115be5..0000000 --- a/endpoints/internal/protorpc/experimental/parser/protobuf.g +++ /dev/null @@ -1,159 +0,0 @@ -/* !/usr/bin/env python - * - * Copyright 2011 Google Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -parser grammar protobuf; - -scalar_value - : STRING - | FLOAT - | INT - | BOOL - ; - -id - : ID - | PACKAGE - | SERVICE - | MESSAGE - | ENUM - | DATA_TYPE - | EXTENSIONS - ; - -user_option_id - : '(' name_root='.'? qualified_name ')' - -> ^(USER_OPTION_ID $name_root? qualified_name) - ; - -option_id - : (id | user_option_id) ('.'! (id | user_option_id))* - ; - -option - : option_id '=' (scalar_value | id) - -> ^(OPTION ^(OPTION_ID option_id) scalar_value? id?) - ; - -decl_options - : '[' option (',' option)* ']' - -> ^(OPTIONS option*) - ; - -qualified_name - : id ('.'! id)* - ; - -field_decl - : qualified_name id '=' INT decl_options? ';' - -> ^(FIELD_TYPE qualified_name) id INT decl_options? - | GROUP id '=' INT '{' message_def '}' - -> ^(FIELD_TYPE GROUP) id INT ^(GROUP_MESSAGE message_def) - ; - -field - : LABEL field_decl - -> ^(FIELD LABEL field_decl) - ; - -enum_decl - : id '=' INT decl_options? ';' - -> ^(ENUM_DECL id INT decl_options?) - ; - -enum_def - : ENUM id '{' (def_option | enum_decl | ';')* '}' - -> ^(ENUM id - ^(OPTIONS def_option*) - ^(ENUM_DECLS enum_decl*)) - ; - -extensions - : EXTENSIONS start=INT (TO (end=INT | end=MAX))? ';' -> ^(EXTENSION_RANGE $start $end) - ; - -message_def - : ( field - | enum_def - | message - | extension - | extensions - | def_option - | ';' - )* -> - ^(FIELDS field*) - ^(MESSAGES message*) - ^(ENUMS enum_def*) - ^(EXTENSIONS extensions*) - ^(OPTIONS def_option*) - ; - -message - : MESSAGE^ id '{'! message_def '}'! - ; - -method_options - : '{'! (def_option | ';'!)+ '}'! - ; - -method_def - : RPC id '(' qualified_name ')' - RETURNS '(' qualified_name ')' (method_options | ';') - ; - -service_defs - : (def_option | method_def | ';')+ - ; - -service - : SERVICE id '{' service_defs? '}' - ; - -extension - : EXTEND qualified_name '{' message_def '}' - ; - -import_line - : IMPORT! STRING ';'! - ; - -package_decl - : PACKAGE^ qualified_name ';'! - ; - -def_option - : OPTION option ';' -> option - ; - -proto_file - : ( package_decl - | import_line - | message - | enum_def - | service - | extension - | def_option - | ';' - )* - -> ^(PROTO_FILE package_decl* - ^(IMPORTS import_line*) - ^(MESSAGES message*) - ^(ENUMS enum_def*) - ^(SERVICES service*) - ^(EXTENSIONS extension*) - ^(OPTIONS def_option*) - ) - ; diff --git a/endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g b/endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g deleted file mode 100644 index be789b5..0000000 --- a/endpoints/internal/protorpc/experimental/parser/protobuf_lexer.g +++ /dev/null @@ -1,153 +0,0 @@ -/* !/usr/bin/env python - * - * Copyright 2011 Google Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -lexer grammar protobuf_lexer; - -tokens { - // Imaginary tree nodes. - ENUMS; - ENUM_DECL; - ENUM_DECLS; - EXTENSION_RANGE; - FIELD; - FIELDS; - FIELD_TYPE; - GROUP_MESSAGE; - IMPORTS; - MESSAGES; - NAME_ROOT; - OPTIONS; - OPTION_ID; - PROTO_FILE; - SERVICES; - USER_OPTION_ID; -} - -// Basic keyword tokens. -ENUM : 'enum'; -MESSAGE : 'message'; -IMPORT : 'import'; -OPTION : 'option'; -PACKAGE : 'package'; -RPC : 'rpc'; -SERVICE : 'service'; -RETURNS : 'returns'; -EXTEND : 'extend'; -EXTENSIONS : 'extensions'; -TO : 'to'; -GROUP : 'group'; -MAX : 'max'; - -COMMENT - : '//' ~('\n'|'\r')* '\r'? '\n' {$channel=HIDDEN;} - | '/*' ( options {greedy=false;} : . )* '*/' {$channel=HIDDEN;} - ; - -WS - : ( ' ' - | '\t' - | '\r' - | '\n' - ) {$channel=HIDDEN;} - ; - -DATA_TYPE - : 'double' - | 'float' - | 'int32' - | 'int64' - | 'uint32' - | 'uint64' - | 'sint32' - | 'sint64' - | 'fixed32' - | 'fixed64' - | 'sfixed32' - | 'sfixed64' - | 'bool' - | 'string' - | 'bytes' - ; - -LABEL - : 'required' - | 'optional' - | 'repeated' - ; - -BOOL - : 'true' - | 'false' - ; - -ID - : ('a'..'z'|'A'..'Z'|'_') ('a'..'z'|'A'..'Z'|'0'..'9'|'_')* - ; - -INT - : '-'? ('0'..'9'+ | '0x' ('a'..'f'|'A'..'F'|'0'..'9')+ | 'inf') - | 'nan' - ; - -FLOAT - : '-'? ('0'..'9')+ '.' ('0'..'9')* EXPONENT? - | '-'? '.' ('0'..'9')+ EXPONENT? - | '-'? ('0'..'9')+ EXPONENT - ; - -STRING - : '"' ( STRING_INNARDS )* '"'; - -fragment -STRING_INNARDS - : ESC_SEQ - | ~('\\'|'"') - ; - -fragment -EXPONENT - : ('e'|'E') ('+'|'-')? ('0'..'9')+ - ; - -fragment -HEX_DIGIT - : ('0'..'9'|'a'..'f'|'A'..'F') - ; - -fragment -ESC_SEQ - : '\\' ('a'|'b'|'t'|'n'|'f'|'r'|'v'|'\"'|'\''|'\\') - | UNICODE_ESC - | OCTAL_ESC - | HEX_ESC - ; - -fragment -OCTAL_ESC - : '\\' ('0'..'3') ('0'..'7') ('0'..'7') - | '\\' ('0'..'7') ('0'..'7') - | '\\' ('0'..'7') - ; - -fragment -HEX_ESC - : '\\x' HEX_DIGIT HEX_DIGIT - ; - -fragment -UNICODE_ESC - : '\\' 'u' HEX_DIGIT HEX_DIGIT HEX_DIGIT HEX_DIGIT - ; diff --git a/endpoints/internal/protorpc/experimental/parser/pyprotobuf.g b/endpoints/internal/protorpc/experimental/parser/pyprotobuf.g deleted file mode 100644 index 534e1f8..0000000 --- a/endpoints/internal/protorpc/experimental/parser/pyprotobuf.g +++ /dev/null @@ -1,45 +0,0 @@ -/* !/usr/bin/env python - * - * Copyright 2011 Google Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -grammar pyprotobuf; - -options { -// language=Python; - output = AST; - ASTLabelType = CommonTree; -} - -import protobuf_lexer, protobuf; - -// For reasons I do not understand the HIDDEN elements from the imported -// with their channel intact. - -COMMENT - : '//' ~('\n'|'\r')* '\r'? '\n' {$channel=HIDDEN;} - | '/*' ( options {greedy=false;} : . )* '*/' {$channel=HIDDEN;} - ; - -WS : ( ' ' - | '\t' - | '\r' - | '\n' - ) {$channel=HIDDEN;} - ; - -py_proto_file - : proto_file EOF^ - ; diff --git a/endpoints/internal/protorpc/experimental/parser/test.proto b/endpoints/internal/protorpc/experimental/parser/test.proto deleted file mode 100644 index 438e1e6..0000000 --- a/endpoints/internal/protorpc/experimental/parser/test.proto +++ /dev/null @@ -1,27 +0,0 @@ -/* !/usr/bin/env python - * - * Copyright 2011 Google Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package a.b.c; - -import "abc.def"; -import "from/here"; - -message MyMessage { - required int64 thing = 1 [a="b"]; - optional group whatever = 2 { - repeated int64 thing = 1; - } -} diff --git a/endpoints/internal/protorpc/generate.py b/endpoints/internal/protorpc/generate.py deleted file mode 100644 index 9a2630b..0000000 --- a/endpoints/internal/protorpc/generate.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -import contextlib - -from . import messages -from . import util - -__all__ = ['IndentationError', - 'IndentWriter', - ] - - -class IndentationError(messages.Error): - """Raised when end_indent is called too many times.""" - - -class IndentWriter(object): - """Utility class to make it easy to write formatted indented text. - - IndentWriter delegates to a file-like object and is able to keep track of the - level of indentation. Each call to write_line will write a line terminated - by a new line proceeded by a number of spaces indicated by the current level - of indentation. - - IndexWriter overloads the << operator to make line writing operations clearer. - - The indent method returns a context manager that can be used by the Python - with statement that makes generating python code easier to use. For example: - - index_writer << 'def factorial(n):' - with index_writer.indent(): - index_writer << 'if n <= 1:' - with index_writer.indent(): - index_writer << 'return 1' - index_writer << 'else:' - with index_writer.indent(): - index_writer << 'return factorial(n - 1)' - - This would generate: - - def factorial(n): - if n <= 1: - return 1 - else: - return factorial(n - 1) - """ - - @util.positional(2) - def __init__(self, output, indent_space=2): - """Constructor. - - Args: - output: File-like object to wrap. - indent_space: Number of spaces each level of indentation will be. - """ - # Private attributes: - # - # __output: The wrapped file-like object. - # __indent_space: String to append for each level of indentation. - # __indentation: The current full indentation string. - self.__output = output - self.__indent_space = indent_space * ' ' - self.__indentation = 0 - - @property - def indent_level(self): - """Current level of indentation for IndentWriter.""" - return self.__indentation - - def write_line(self, line): - """Write line to wrapped file-like object using correct indentation. - - The line is written with the current level of indentation printed before it - and terminated by a new line. - - Args: - line: Line to write to wrapped file-like object. - """ - if line != '': - self.__output.write(self.__indentation * self.__indent_space) - self.__output.write(line) - self.__output.write('\n') - - def begin_indent(self): - """Begin a level of indentation.""" - self.__indentation += 1 - - def end_indent(self): - """Undo the most recent level of indentation. - - Raises: - IndentationError when called with no indentation levels. - """ - if not self.__indentation: - raise IndentationError('Unable to un-indent further') - self.__indentation -= 1 - - @contextlib.contextmanager - def indent(self): - """Create indentation level compatible with the Python 'with' keyword.""" - self.begin_indent() - yield - self.end_indent() - - def __lshift__(self, line): - """Syntactic sugar for write_line method. - - Args: - line: Line to write to wrapped file-like object. - """ - self.write_line(line) diff --git a/endpoints/internal/protorpc/generate_proto.py b/endpoints/internal/protorpc/generate_proto.py deleted file mode 100644 index 8e4b19e..0000000 --- a/endpoints/internal/protorpc/generate_proto.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import with_statement - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -import logging - -from . import descriptor -from . import generate -from . import messages -from . import util - - -__all__ = ['format_proto_file'] - - -@util.positional(2) -def format_proto_file(file_descriptor, output, indent_space=2): - out = generate.IndentWriter(output, indent_space=indent_space) - - if file_descriptor.package: - out << 'package %s;' % file_descriptor.package - - def write_enums(enum_descriptors): - """Write nested and non-nested Enum types. - - Args: - enum_descriptors: List of EnumDescriptor objects from which to generate - enums. - """ - # Write enums. - for enum in enum_descriptors or []: - out << '' - out << '' - out << 'enum %s {' % enum.name - out << '' - - with out.indent(): - if enum.values: - for enum_value in enum.values: - out << '%s = %s;' % (enum_value.name, enum_value.number) - - out << '}' - - write_enums(file_descriptor.enum_types) - - def write_fields(field_descriptors): - """Write fields for Message types. - - Args: - field_descriptors: List of FieldDescriptor objects from which to generate - fields. - """ - for field in field_descriptors or []: - default_format = '' - if field.default_value is not None: - if field.label == descriptor.FieldDescriptor.Label.REPEATED: - logging.warning('Default value for repeated field %s is not being ' - 'written to proto file' % field.name) - else: - # Convert default value to string. - if field.variant == messages.Variant.MESSAGE: - logging.warning( - 'Message field %s should not have default values' % field.name) - default = None - elif field.variant == messages.Variant.STRING: - default = repr(field.default_value.encode('utf-8')) - elif field.variant == messages.Variant.BYTES: - default = repr(field.default_value) - else: - default = str(field.default_value) - - if default is not None: - default_format = ' [default=%s]' % default - - if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM): - field_type = field.type_name - else: - field_type = str(field.variant).lower() - - out << '%s %s %s = %s%s;' % (str(field.label).lower(), - field_type, - field.name, - field.number, - default_format) - - def write_messages(message_descriptors): - """Write nested and non-nested Message types. - - Args: - message_descriptors: List of MessageDescriptor objects from which to - generate messages. - """ - for message in message_descriptors or []: - out << '' - out << '' - out << 'message %s {' % message.name - - with out.indent(): - if message.enum_types: - write_enums(message.enum_types) - - if message.message_types: - write_messages(message.message_types) - - if message.fields: - write_fields(message.fields) - - out << '}' - - write_messages(file_descriptor.message_types) diff --git a/endpoints/internal/protorpc/generate_python.py b/endpoints/internal/protorpc/generate_python.py deleted file mode 100644 index 5234e05..0000000 --- a/endpoints/internal/protorpc/generate_python.py +++ /dev/null @@ -1,218 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import with_statement - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -from . import descriptor -from . import generate -from . import message_types -from . import messages -from . import util - - -__all__ = ['format_python_file'] - -_MESSAGE_FIELD_MAP = { - message_types.DateTimeMessage.definition_name(): message_types.DateTimeField, -} - - -def _write_enums(enum_descriptors, out): - """Write nested and non-nested Enum types. - - Args: - enum_descriptors: List of EnumDescriptor objects from which to generate - enums. - out: Indent writer used for generating text. - """ - # Write enums. - for enum in enum_descriptors or []: - out << '' - out << '' - out << 'class %s(messages.Enum):' % enum.name - out << '' - - with out.indent(): - if not enum.values: - out << 'pass' - else: - for enum_value in enum.values: - out << '%s = %s' % (enum_value.name, enum_value.number) - - -def _write_fields(field_descriptors, out): - """Write fields for Message types. - - Args: - field_descriptors: List of FieldDescriptor objects from which to generate - fields. - out: Indent writer used for generating text. - """ - out << '' - for field in field_descriptors or []: - type_format = '' - label_format = '' - - message_field = _MESSAGE_FIELD_MAP.get(field.type_name) - if message_field: - module = 'message_types' - field_type = message_field - else: - module = 'messages' - field_type = messages.Field.lookup_field_type_by_variant(field.variant) - - if field_type in (messages.EnumField, messages.MessageField): - type_format = '\'%s\', ' % field.type_name - - if field.label == descriptor.FieldDescriptor.Label.REQUIRED: - label_format = ', required=True' - - elif field.label == descriptor.FieldDescriptor.Label.REPEATED: - label_format = ', repeated=True' - - if field_type.DEFAULT_VARIANT != field.variant: - variant_format = ', variant=messages.Variant.%s' % field.variant - else: - variant_format = '' - - if field.default_value: - if field_type in [messages.BytesField, - messages.StringField, - ]: - default_value = repr(field.default_value) - elif field_type is messages.EnumField: - try: - default_value = str(int(field.default_value)) - except ValueError: - default_value = repr(field.default_value) - else: - default_value = field.default_value - - default_format = ', default=%s' % (default_value,) - else: - default_format = '' - - out << '%s = %s.%s(%s%s%s%s%s)' % (field.name, - module, - field_type.__name__, - type_format, - field.number, - label_format, - variant_format, - default_format) - - -def _write_messages(message_descriptors, out): - """Write nested and non-nested Message types. - - Args: - message_descriptors: List of MessageDescriptor objects from which to - generate messages. - out: Indent writer used for generating text. - """ - for message in message_descriptors or []: - out << '' - out << '' - out << 'class %s(messages.Message):' % message.name - - with out.indent(): - if not (message.enum_types or message.message_types or message.fields): - out << '' - out << 'pass' - else: - _write_enums(message.enum_types, out) - _write_messages(message.message_types, out) - _write_fields(message.fields, out) - - -def _write_methods(method_descriptors, out): - """Write methods of Service types. - - All service method implementations raise NotImplementedError. - - Args: - method_descriptors: List of MethodDescriptor objects from which to - generate methods. - out: Indent writer used for generating text. - """ - for method in method_descriptors: - out << '' - out << "@remote.method('%s', '%s')" % (method.request_type, - method.response_type) - out << 'def %s(self, request):' % (method.name,) - with out.indent(): - out << ('raise NotImplementedError' - "('Method %s is not implemented')" % (method.name)) - - -def _write_services(service_descriptors, out): - """Write Service types. - - Args: - service_descriptors: List of ServiceDescriptor instances from which to - generate services. - out: Indent writer used for generating text. - """ - for service in service_descriptors or []: - out << '' - out << '' - out << 'class %s(remote.Service):' % service.name - - with out.indent(): - if service.methods: - _write_methods(service.methods, out) - else: - out << '' - out << 'pass' - - -@util.positional(2) -def format_python_file(file_descriptor, output, indent_space=2): - """Format FileDescriptor object as a single Python module. - - Services generated by this function will raise NotImplementedError. - - All Python classes generated by this function use delayed binding for all - message fields, enum fields and method parameter types. For example a - service method might be generated like so: - - class MyService(remote.Service): - - @remote.method('my_package.MyRequestType', 'my_package.MyResponseType') - def my_method(self, request): - raise NotImplementedError('Method my_method is not implemented') - - Args: - file_descriptor: FileDescriptor instance to format as python module. - output: File-like object to write module source code to. - indent_space: Number of spaces for each level of Python indentation. - """ - out = generate.IndentWriter(output, indent_space=indent_space) - - out << 'from protorpc import message_types' - out << 'from protorpc import messages' - if file_descriptor.service_types: - out << 'from protorpc import remote' - - if file_descriptor.package: - out << "package = '%s'" % file_descriptor.package - - _write_enums(file_descriptor.enum_types, out) - _write_messages(file_descriptor.message_types, out) - _write_services(file_descriptor.service_types, out) diff --git a/endpoints/internal/protorpc/static/base.html b/endpoints/internal/protorpc/static/base.html deleted file mode 100644 index a62db7c..0000000 --- a/endpoints/internal/protorpc/static/base.html +++ /dev/null @@ -1,57 +0,0 @@ - - - - - - {% block title%}Need title{% endblock %} - - - - - - - - - {% block top %}Need top{% endblock %} - -
- - {% block body %}Need body{% endblock %} - - - diff --git a/endpoints/internal/protorpc/static/forms.html b/endpoints/internal/protorpc/static/forms.html deleted file mode 100644 index 9ba22ec..0000000 --- a/endpoints/internal/protorpc/static/forms.html +++ /dev/null @@ -1,31 +0,0 @@ - - -{% extends 'base.html' %} - -{% block title %}ProtoRPC Methods for {{hostname|escape}}{% endblock %} - -{% block top %} -

ProtoRPC Methods for {{hostname|escape}}

-{% endblock %} - -{% block body %} -
-{% endblock %} - -{% block call %} -loadServices(showMethods); -{% endblock %} diff --git a/endpoints/internal/protorpc/static/forms.js b/endpoints/internal/protorpc/static/forms.js deleted file mode 100644 index 3c59252..0000000 --- a/endpoints/internal/protorpc/static/forms.js +++ /dev/null @@ -1,685 +0,0 @@ -// Copyright 2010 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -/** - * @fileoverview Render form appropriate for RPC method. - * @author rafek@google.com (Rafe Kaplan) - */ - - -var FORM_VISIBILITY = { - SHOW_FORM: 'Show Form', - HIDE_FORM: 'Hide Form' -}; - - -var LABEL = { - OPTIONAL: 'OPTIONAL', - REQUIRED: 'REQUIRED', - REPEATED: 'REPEATED' -}; - - -var objectId = 0; - - -/** - * Variants defined in protorpc/messages.py. - */ -var VARIANT = { - DOUBLE: 'DOUBLE', - FLOAT: 'FLOAT', - INT64: 'INT64', - UINT64: 'UINT64', - INT32: 'INT32', - BOOL: 'BOOL', - STRING: 'STRING', - MESSAGE: 'MESSAGE', - BYTES: 'BYTES', - UINT32: 'UINT32', - ENUM: 'ENUM', - SINT32: 'SINT32', - SINT64: 'SINT64' -}; - - -/** - * Data structure used to represent a form to data element. - * @param {Object} field Field descriptor that form element represents. - * @param {Object} container Element that contains field. - * @return {FormElement} New object representing a form element. Element - * starts enabled. - * @constructor - */ -function FormElement(field, container) { - this.field = field; - this.container = container; - this.enabled = true; -} - - -/** - * Display error message in error panel. - * @param {string} message Message to display in panel. - */ -function error(message) { - $('
').appendTo($('#error-messages')).text(message); -} - - -/** - * Display request errors in error panel. - * @param {object} XMLHttpRequest object. - */ -function handleRequestError(response) { - var contentType = response.getResponseHeader('content-type'); - if (contentType == 'application/json') { - var response_error = $.parseJSON(response.responseText); - var error_message = response_error.error_message; - if (error.state == 'APPLICATION_ERROR' && error.error_name) { - error_message = error_message + ' (' + error.error_name + ')'; - } - } else { - error_message = '' + response.status + ': ' + response.statusText; - } - - error(error_message); -} - - -/** - * Send JSON RPC to remote method. - * @param {string} path Path of service on originating server to send request. - * @param {string} method Name of method to invoke. - * @param {Object} request Message to send as request. - * @param {function} on_success Function to call upon successful request. - */ -function sendRequest(path, method, request, onSuccess) { - $.ajax({url: path + '.' + method, - type: 'POST', - contentType: 'application/json', - data: $.toJSON(request), - dataType: 'json', - success: onSuccess, - error: handleRequestError - }); -} - - -/** - * Create callback that enables and disables field element when associated - * checkbox is clicked. - * @param {Element} checkbox Checkbox that will be clicked. - * @param {FormElement} form Form element that will be toggled for editing. - * @param {Object} disableMessage HTML element to display in place of element. - * @return Callback that is invoked every time checkbox is clicked. - */ -function toggleInput(checkbox, form, disableMessage) { - return function() { - var checked = checkbox.checked; - if (checked) { - buildIndividualForm(form); - form.enabled = true; - disableMessage.hide(); - } else { - form.display.empty(); - form.enabled = false; - disableMessage.show(); - } - }; -} - - -/** - * Build an enum field. - * @param {FormElement} form Form to build element for. - */ -function buildEnumField(form) { - form.descriptor = enumDescriptors[form.field.type_name]; - form.input = $(''); - form.input[0].checked = Boolean(form.field.default_value); -} - - -/** - * Build text field. - * @param {FormElement} form Form to build element for. - */ -function buildTextField(form) { - form.input = $(''); - form.input. - attr('value', form.field.default_value || ''); -} - - -/** - * Build individual input element. - * @param {FormElement} form Form to build element for. - */ -function buildIndividualForm(form) { - form.required = form.label == LABEL.REQUIRED; - - if (form.field.variant == VARIANT.ENUM) { - buildEnumField(form); - } else if (form.field.variant == VARIANT.MESSAGE) { - buildMessageField(form); - } else if (form.field.variant == VARIANT.BOOL) { - buildBooleanField(form); - } else { - buildTextField(form); - } - - form.display.append(form.input); - - // TODO: Handle base64 encoding for BYTES field. - if (form.field.variant == VARIANT.BYTES) { - $("use base64 encoding").appendTo(form.display); - } -} - - -/** - * Add repeated field. This function is called when an item is added - * @param {FormElement} form Repeated form element to create item for. - */ -function addRepeatedFieldItem(form) { - var row = $('').appendTo(form.display); - subForm = new FormElement(form.field, row); - form.fields.push(subForm); - buildFieldForm(subForm, false); -} - - -/** - * Build repeated field. Contains a button that can be used for adding new - * items. - * @param {FormElement} form Form to build element for. - */ -function buildRepeatedForm(form) { - form.fields = []; - form.display = $(''). - appendTo(form.container); - var header_row = $('').appendTo(form.display); - var header = $('').appendTo(form.table); - var fieldForm = new FormElement(field, row); - fieldForm.parent = form; - buildFieldForm(fieldForm, true); - form.fields.push(fieldForm); - }); - } -} - - -/** - * HTML Escape a string - */ -function htmlEscape(value) { - if (typeof(value) == "string") { - return value - .replace(/&/g, '&') - .replace(/>/g, '>') - .replace(/'; - }); - result += indentation + ']'; - } else { - result += '{
'; - $.each(value, function(name, item) { - result += (indentation + htmlEscape(name) + ': ' + - formatJSON(item, indent + 1) + ',
'); - }); - result += indentation + '}'; - } - } else { - result += htmlEscape(value); - } - - return result; -} - - -/** - * Construct array from repeated form element. - * @param {FormElement} form Form element to build array from. - * @return {Array} Array of repeated elements read from input form. - */ -function fromRepeatedForm(form) { - var values = []; - $.each(form.fields, function(index, subForm) { - values.push(fromIndividualForm(subForm)); - }); - return values; -} - - -/** - * Construct value from individual form element. - * @param {FormElement} form Form element to get value from. - * @return {string, Float, Integer, Boolean, object} Value extracted from - * individual field. The type depends on the field variant. - */ -function fromIndividualForm(form) { - switch(form.field.variant) { - case VARIANT.MESSAGE: - return fromMessageForm(form); - break; - - case VARIANT.DOUBLE: - case VARIANT.FLOAT: - return parseFloat(form.input.val()); - - case VARIANT.BOOL: - return form.input[0].checked; - break; - - case VARIANT.ENUM: - case VARIANT.STRING: - case VARIANT.BYTES: - return form.input.val(); - - default: - break; - } - return parseInt(form.input.val(), 10); -} - - -/** - * Extract entire message from a complete form. - * @param {FormElement} form Form to extract message from. - * @return {Object} Fully populated message object ready to transmit - * as JSON message. - */ -function fromMessageForm(form) { - var message = {}; - $.each(form.fields, function(index, subForm) { - if (subForm.enabled) { - var subMessage = undefined; - if (subForm.field.label == LABEL.REPEATED) { - subMessage = fromRepeatedForm(subForm); - } else { - subMessage = fromIndividualForm(subForm); - } - - message[subForm.field.name] = subMessage; - } - }); - - return message; -} - - -/** - * Send form as an RPC. Extracts message from root form and transmits to - * originating ProtoRPC server. Response is formatted as JSON and displayed - * to user. - */ -function sendForm() { - $('#error-messages').empty(); - $('#form-response').empty(); - message = fromMessageForm(root_form); - if (message === null) { - return; - } - - sendRequest(servicePath, methodName, message, function(response) { - $('#form-response').html(formatJSON(response, 0)); - hideForm(); - }); -} - - -/** - * Reset form to original state. Deletes existing form and rebuilds a new - * one from scratch. - */ -function resetForm() { - var panel = $('#form-panel'); - var serviceType = serviceMap[servicePath]; - var service = serviceDescriptors[serviceType]; - - panel.empty(); - - function formGenerationError(message) { - error(message); - panel.html('
' + - 'There was an error generating the service form' + - '
'); - } - - // Find method. - var requestTypeName = null; - $.each(service.methods, function(index, method) { - if (method.name == methodName) { - requestTypeName = method.request_type; - } - }); - - if (!requestTypeName) { - formGenerationError('No such method definition for: ' + methodName); - return; - } - - requestType = messageDescriptors[requestTypeName]; - if (!requestType) { - formGenerationError('No such message-type: ' + requestTypeName); - return; - } - - var root = $('
').appendTo(header_row); - var add_button = $(''); - - // Set name. - if (allowRepeated) { - var nameData = $(''); - nameData.text(form.field.name + ':'); - form.container.append(nameData); - } - - // Set input. - form.repeated = form.field.label == LABEL.REPEATED; - if (allowRepeated && form.repeated) { - inputData.attr('colspan', '2'); - buildRepeatedForm(form); - } else { - if (!allowRepeated) { - inputData.attr('colspan', '2'); - } - - form.display = $('
'); - - var controlData = $('
'); - if (form.field.label != LABEL.REQUIRED && allowRepeated) { - form.enabled = false; - var checkbox_id = 'checkbox-' + objectId; - objectId++; - $('').appendTo(controlData); - var checkbox = $('').appendTo(controlData); - var disableMessage = $('
').appendTo(inputData); - checkbox.change(toggleInput(checkbox[0], form, disableMessage)); - } else { - buildIndividualForm(form); - } - - if (form.repeated) { - // TODO: Implement deletion of repeated items. Needs to delete - // from DOM and also delete from form model. - } - - form.container.append(controlData); - } - - inputData.append(form.display); - form.container.append(inputData); -} - - -/** - * Top level function for building an entire message form. Called once at form - * creation and may be called again for nested message fields. Constructs a - * a table and builds a row for each sub-field. - * @params {FormElement} form Form to build message form for. - */ -function buildMessageForm(form, messageType) { - form.fields = []; - form.descriptor = messageType; - if (messageType.fields) { - $.each(messageType.fields, function(index, field) { - var row = $('
'). - appendTo(panel); - - root_form = new FormElement(null, null); - root_form.table = root; - buildMessageForm(root_form, requestType); - $('
a"; -var e=d.getElementsByTagName("*"),j=d.getElementsByTagName("a")[0];if(!(!e||!e.length||!j)){c.support={leadingWhitespace:d.firstChild.nodeType===3,tbody:!d.getElementsByTagName("tbody").length,htmlSerialize:!!d.getElementsByTagName("link").length,style:/red/.test(j.getAttribute("style")),hrefNormalized:j.getAttribute("href")==="/a",opacity:/^0.55$/.test(j.style.opacity),cssFloat:!!j.style.cssFloat,checkOn:d.getElementsByTagName("input")[0].value==="on",optSelected:s.createElement("select").appendChild(s.createElement("option")).selected, -parentNode:d.removeChild(d.appendChild(s.createElement("div"))).parentNode===null,deleteExpando:true,checkClone:false,scriptEval:false,noCloneEvent:true,boxModel:null};b.type="text/javascript";try{b.appendChild(s.createTextNode("window."+f+"=1;"))}catch(i){}a.insertBefore(b,a.firstChild);if(A[f]){c.support.scriptEval=true;delete A[f]}try{delete b.test}catch(o){c.support.deleteExpando=false}a.removeChild(b);if(d.attachEvent&&d.fireEvent){d.attachEvent("onclick",function k(){c.support.noCloneEvent= -false;d.detachEvent("onclick",k)});d.cloneNode(true).fireEvent("onclick")}d=s.createElement("div");d.innerHTML="";a=s.createDocumentFragment();a.appendChild(d.firstChild);c.support.checkClone=a.cloneNode(true).cloneNode(true).lastChild.checked;c(function(){var k=s.createElement("div");k.style.width=k.style.paddingLeft="1px";s.body.appendChild(k);c.boxModel=c.support.boxModel=k.offsetWidth===2;s.body.removeChild(k).style.display="none"});a=function(k){var n= -s.createElement("div");k="on"+k;var r=k in n;if(!r){n.setAttribute(k,"return;");r=typeof n[k]==="function"}return r};c.support.submitBubbles=a("submit");c.support.changeBubbles=a("change");a=b=d=e=j=null}})();c.props={"for":"htmlFor","class":"className",readonly:"readOnly",maxlength:"maxLength",cellspacing:"cellSpacing",rowspan:"rowSpan",colspan:"colSpan",tabindex:"tabIndex",usemap:"useMap",frameborder:"frameBorder"};var G="jQuery"+J(),Ya=0,za={};c.extend({cache:{},expando:G,noData:{embed:true,object:true, -applet:true},data:function(a,b,d){if(!(a.nodeName&&c.noData[a.nodeName.toLowerCase()])){a=a==A?za:a;var f=a[G],e=c.cache;if(!f&&typeof b==="string"&&d===w)return null;f||(f=++Ya);if(typeof b==="object"){a[G]=f;e[f]=c.extend(true,{},b)}else if(!e[f]){a[G]=f;e[f]={}}a=e[f];if(d!==w)a[b]=d;return typeof b==="string"?a[b]:a}},removeData:function(a,b){if(!(a.nodeName&&c.noData[a.nodeName.toLowerCase()])){a=a==A?za:a;var d=a[G],f=c.cache,e=f[d];if(b){if(e){delete e[b];c.isEmptyObject(e)&&c.removeData(a)}}else{if(c.support.deleteExpando)delete a[c.expando]; -else a.removeAttribute&&a.removeAttribute(c.expando);delete f[d]}}}});c.fn.extend({data:function(a,b){if(typeof a==="undefined"&&this.length)return c.data(this[0]);else if(typeof a==="object")return this.each(function(){c.data(this,a)});var d=a.split(".");d[1]=d[1]?"."+d[1]:"";if(b===w){var f=this.triggerHandler("getData"+d[1]+"!",[d[0]]);if(f===w&&this.length)f=c.data(this[0],a);return f===w&&d[1]?this.data(d[0]):f}else return this.trigger("setData"+d[1]+"!",[d[0],b]).each(function(){c.data(this, -a,b)})},removeData:function(a){return this.each(function(){c.removeData(this,a)})}});c.extend({queue:function(a,b,d){if(a){b=(b||"fx")+"queue";var f=c.data(a,b);if(!d)return f||[];if(!f||c.isArray(d))f=c.data(a,b,c.makeArray(d));else f.push(d);return f}},dequeue:function(a,b){b=b||"fx";var d=c.queue(a,b),f=d.shift();if(f==="inprogress")f=d.shift();if(f){b==="fx"&&d.unshift("inprogress");f.call(a,function(){c.dequeue(a,b)})}}});c.fn.extend({queue:function(a,b){if(typeof a!=="string"){b=a;a="fx"}if(b=== -w)return c.queue(this[0],a);return this.each(function(){var d=c.queue(this,a,b);a==="fx"&&d[0]!=="inprogress"&&c.dequeue(this,a)})},dequeue:function(a){return this.each(function(){c.dequeue(this,a)})},delay:function(a,b){a=c.fx?c.fx.speeds[a]||a:a;b=b||"fx";return this.queue(b,function(){var d=this;setTimeout(function(){c.dequeue(d,b)},a)})},clearQueue:function(a){return this.queue(a||"fx",[])}});var Aa=/[\n\t]/g,ca=/\s+/,Za=/\r/g,$a=/href|src|style/,ab=/(button|input)/i,bb=/(button|input|object|select|textarea)/i, -cb=/^(a|area)$/i,Ba=/radio|checkbox/;c.fn.extend({attr:function(a,b){return X(this,a,b,true,c.attr)},removeAttr:function(a){return this.each(function(){c.attr(this,a,"");this.nodeType===1&&this.removeAttribute(a)})},addClass:function(a){if(c.isFunction(a))return this.each(function(n){var r=c(this);r.addClass(a.call(this,n,r.attr("class")))});if(a&&typeof a==="string")for(var b=(a||"").split(ca),d=0,f=this.length;d-1)return true;return false},val:function(a){if(a===w){var b=this[0];if(b){if(c.nodeName(b,"option"))return(b.attributes.value||{}).specified?b.value:b.text;if(c.nodeName(b,"select")){var d=b.selectedIndex,f=[],e=b.options;b=b.type==="select-one";if(d<0)return null;var j=b?d:0;for(d=b?d+1:e.length;j=0;else if(c.nodeName(this,"select")){var u=c.makeArray(r);c("option",this).each(function(){this.selected= -c.inArray(c(this).val(),u)>=0});if(!u.length)this.selectedIndex=-1}else this.value=r}})}});c.extend({attrFn:{val:true,css:true,html:true,text:true,data:true,width:true,height:true,offset:true},attr:function(a,b,d,f){if(!a||a.nodeType===3||a.nodeType===8)return w;if(f&&b in c.attrFn)return c(a)[b](d);f=a.nodeType!==1||!c.isXMLDoc(a);var e=d!==w;b=f&&c.props[b]||b;if(a.nodeType===1){var j=$a.test(b);if(b in a&&f&&!j){if(e){b==="type"&&ab.test(a.nodeName)&&a.parentNode&&c.error("type property can't be changed"); -a[b]=d}if(c.nodeName(a,"form")&&a.getAttributeNode(b))return a.getAttributeNode(b).nodeValue;if(b==="tabIndex")return(b=a.getAttributeNode("tabIndex"))&&b.specified?b.value:bb.test(a.nodeName)||cb.test(a.nodeName)&&a.href?0:w;return a[b]}if(!c.support.style&&f&&b==="style"){if(e)a.style.cssText=""+d;return a.style.cssText}e&&a.setAttribute(b,""+d);a=!c.support.hrefNormalized&&f&&j?a.getAttribute(b,2):a.getAttribute(b);return a===null?w:a}return c.style(a,b,d)}});var O=/\.(.*)$/,db=function(a){return a.replace(/[^\w\s\.\|`]/g, -function(b){return"\\"+b})};c.event={add:function(a,b,d,f){if(!(a.nodeType===3||a.nodeType===8)){if(a.setInterval&&a!==A&&!a.frameElement)a=A;var e,j;if(d.handler){e=d;d=e.handler}if(!d.guid)d.guid=c.guid++;if(j=c.data(a)){var i=j.events=j.events||{},o=j.handle;if(!o)j.handle=o=function(){return typeof c!=="undefined"&&!c.event.triggered?c.event.handle.apply(o.elem,arguments):w};o.elem=a;b=b.split(" ");for(var k,n=0,r;k=b[n++];){j=e?c.extend({},e):{handler:d,data:f};if(k.indexOf(".")>-1){r=k.split("."); -k=r.shift();j.namespace=r.slice(0).sort().join(".")}else{r=[];j.namespace=""}j.type=k;j.guid=d.guid;var u=i[k],z=c.event.special[k]||{};if(!u){u=i[k]=[];if(!z.setup||z.setup.call(a,f,r,o)===false)if(a.addEventListener)a.addEventListener(k,o,false);else a.attachEvent&&a.attachEvent("on"+k,o)}if(z.add){z.add.call(a,j);if(!j.handler.guid)j.handler.guid=d.guid}u.push(j);c.event.global[k]=true}a=null}}},global:{},remove:function(a,b,d,f){if(!(a.nodeType===3||a.nodeType===8)){var e,j=0,i,o,k,n,r,u,z=c.data(a), -C=z&&z.events;if(z&&C){if(b&&b.type){d=b.handler;b=b.type}if(!b||typeof b==="string"&&b.charAt(0)==="."){b=b||"";for(e in C)c.event.remove(a,e+b)}else{for(b=b.split(" ");e=b[j++];){n=e;i=e.indexOf(".")<0;o=[];if(!i){o=e.split(".");e=o.shift();k=new RegExp("(^|\\.)"+c.map(o.slice(0).sort(),db).join("\\.(?:.*\\.)?")+"(\\.|$)")}if(r=C[e])if(d){n=c.event.special[e]||{};for(B=f||0;B=0){a.type= -e=e.slice(0,-1);a.exclusive=true}if(!d){a.stopPropagation();c.event.global[e]&&c.each(c.cache,function(){this.events&&this.events[e]&&c.event.trigger(a,b,this.handle.elem)})}if(!d||d.nodeType===3||d.nodeType===8)return w;a.result=w;a.target=d;b=c.makeArray(b);b.unshift(a)}a.currentTarget=d;(f=c.data(d,"handle"))&&f.apply(d,b);f=d.parentNode||d.ownerDocument;try{if(!(d&&d.nodeName&&c.noData[d.nodeName.toLowerCase()]))if(d["on"+e]&&d["on"+e].apply(d,b)===false)a.result=false}catch(j){}if(!a.isPropagationStopped()&& -f)c.event.trigger(a,b,f,true);else if(!a.isDefaultPrevented()){f=a.target;var i,o=c.nodeName(f,"a")&&e==="click",k=c.event.special[e]||{};if((!k._default||k._default.call(d,a)===false)&&!o&&!(f&&f.nodeName&&c.noData[f.nodeName.toLowerCase()])){try{if(f[e]){if(i=f["on"+e])f["on"+e]=null;c.event.triggered=true;f[e]()}}catch(n){}if(i)f["on"+e]=i;c.event.triggered=false}}},handle:function(a){var b,d,f,e;a=arguments[0]=c.event.fix(a||A.event);a.currentTarget=this;b=a.type.indexOf(".")<0&&!a.exclusive; -if(!b){d=a.type.split(".");a.type=d.shift();f=new RegExp("(^|\\.)"+d.slice(0).sort().join("\\.(?:.*\\.)?")+"(\\.|$)")}e=c.data(this,"events");d=e[a.type];if(e&&d){d=d.slice(0);e=0;for(var j=d.length;e-1?c.map(a.options,function(f){return f.selected}).join("-"):"";else if(a.nodeName.toLowerCase()==="select")d=a.selectedIndex;return d},fa=function(a,b){var d=a.target,f,e;if(!(!da.test(d.nodeName)||d.readOnly)){f=c.data(d,"_change_data");e=Fa(d);if(a.type!=="focusout"||d.type!=="radio")c.data(d,"_change_data", -e);if(!(f===w||e===f))if(f!=null||e){a.type="change";return c.event.trigger(a,b,d)}}};c.event.special.change={filters:{focusout:fa,click:function(a){var b=a.target,d=b.type;if(d==="radio"||d==="checkbox"||b.nodeName.toLowerCase()==="select")return fa.call(this,a)},keydown:function(a){var b=a.target,d=b.type;if(a.keyCode===13&&b.nodeName.toLowerCase()!=="textarea"||a.keyCode===32&&(d==="checkbox"||d==="radio")||d==="select-multiple")return fa.call(this,a)},beforeactivate:function(a){a=a.target;c.data(a, -"_change_data",Fa(a))}},setup:function(){if(this.type==="file")return false;for(var a in ea)c.event.add(this,a+".specialChange",ea[a]);return da.test(this.nodeName)},teardown:function(){c.event.remove(this,".specialChange");return da.test(this.nodeName)}};ea=c.event.special.change.filters}s.addEventListener&&c.each({focus:"focusin",blur:"focusout"},function(a,b){function d(f){f=c.event.fix(f);f.type=b;return c.event.handle.call(this,f)}c.event.special[b]={setup:function(){this.addEventListener(a, -d,true)},teardown:function(){this.removeEventListener(a,d,true)}}});c.each(["bind","one"],function(a,b){c.fn[b]=function(d,f,e){if(typeof d==="object"){for(var j in d)this[b](j,f,d[j],e);return this}if(c.isFunction(f)){e=f;f=w}var i=b==="one"?c.proxy(e,function(k){c(this).unbind(k,i);return e.apply(this,arguments)}):e;if(d==="unload"&&b!=="one")this.one(d,f,e);else{j=0;for(var o=this.length;j0){y=t;break}}t=t[g]}m[q]=y}}}var f=/((?:\((?:\([^()]+\)|[^()]+)+\)|\[(?:\[[^[\]]*\]|['"][^'"]*['"]|[^[\]'"]+)+\]|\\.|[^ >+~,(\[\\]+)+|[>+~])(\s*,\s*)?((?:.|\r|\n)*)/g, -e=0,j=Object.prototype.toString,i=false,o=true;[0,0].sort(function(){o=false;return 0});var k=function(g,h,l,m){l=l||[];var q=h=h||s;if(h.nodeType!==1&&h.nodeType!==9)return[];if(!g||typeof g!=="string")return l;for(var p=[],v,t,y,S,H=true,M=x(h),I=g;(f.exec(""),v=f.exec(I))!==null;){I=v[3];p.push(v[1]);if(v[2]){S=v[3];break}}if(p.length>1&&r.exec(g))if(p.length===2&&n.relative[p[0]])t=ga(p[0]+p[1],h);else for(t=n.relative[p[0]]?[h]:k(p.shift(),h);p.length;){g=p.shift();if(n.relative[g])g+=p.shift(); -t=ga(g,t)}else{if(!m&&p.length>1&&h.nodeType===9&&!M&&n.match.ID.test(p[0])&&!n.match.ID.test(p[p.length-1])){v=k.find(p.shift(),h,M);h=v.expr?k.filter(v.expr,v.set)[0]:v.set[0]}if(h){v=m?{expr:p.pop(),set:z(m)}:k.find(p.pop(),p.length===1&&(p[0]==="~"||p[0]==="+")&&h.parentNode?h.parentNode:h,M);t=v.expr?k.filter(v.expr,v.set):v.set;if(p.length>0)y=z(t);else H=false;for(;p.length;){var D=p.pop();v=D;if(n.relative[D])v=p.pop();else D="";if(v==null)v=h;n.relative[D](y,v,M)}}else y=[]}y||(y=t);y||k.error(D|| -g);if(j.call(y)==="[object Array]")if(H)if(h&&h.nodeType===1)for(g=0;y[g]!=null;g++){if(y[g]&&(y[g]===true||y[g].nodeType===1&&E(h,y[g])))l.push(t[g])}else for(g=0;y[g]!=null;g++)y[g]&&y[g].nodeType===1&&l.push(t[g]);else l.push.apply(l,y);else z(y,l);if(S){k(S,q,l,m);k.uniqueSort(l)}return l};k.uniqueSort=function(g){if(B){i=o;g.sort(B);if(i)for(var h=1;h":function(g,h){var l=typeof h==="string";if(l&&!/\W/.test(h)){h=h.toLowerCase();for(var m=0,q=g.length;m=0))l||m.push(v);else if(l)h[p]=false;return false},ID:function(g){return g[1].replace(/\\/g,"")},TAG:function(g){return g[1].toLowerCase()}, -CHILD:function(g){if(g[1]==="nth"){var h=/(-?)(\d*)n((?:\+|-)?\d*)/.exec(g[2]==="even"&&"2n"||g[2]==="odd"&&"2n+1"||!/\D/.test(g[2])&&"0n+"+g[2]||g[2]);g[2]=h[1]+(h[2]||1)-0;g[3]=h[3]-0}g[0]=e++;return g},ATTR:function(g,h,l,m,q,p){h=g[1].replace(/\\/g,"");if(!p&&n.attrMap[h])g[1]=n.attrMap[h];if(g[2]==="~=")g[4]=" "+g[4]+" ";return g},PSEUDO:function(g,h,l,m,q){if(g[1]==="not")if((f.exec(g[3])||"").length>1||/^\w/.test(g[3]))g[3]=k(g[3],null,null,h);else{g=k.filter(g[3],h,l,true^q);l||m.push.apply(m, -g);return false}else if(n.match.POS.test(g[0])||n.match.CHILD.test(g[0]))return true;return g},POS:function(g){g.unshift(true);return g}},filters:{enabled:function(g){return g.disabled===false&&g.type!=="hidden"},disabled:function(g){return g.disabled===true},checked:function(g){return g.checked===true},selected:function(g){return g.selected===true},parent:function(g){return!!g.firstChild},empty:function(g){return!g.firstChild},has:function(g,h,l){return!!k(l[3],g).length},header:function(g){return/h\d/i.test(g.nodeName)}, -text:function(g){return"text"===g.type},radio:function(g){return"radio"===g.type},checkbox:function(g){return"checkbox"===g.type},file:function(g){return"file"===g.type},password:function(g){return"password"===g.type},submit:function(g){return"submit"===g.type},image:function(g){return"image"===g.type},reset:function(g){return"reset"===g.type},button:function(g){return"button"===g.type||g.nodeName.toLowerCase()==="button"},input:function(g){return/input|select|textarea|button/i.test(g.nodeName)}}, -setFilters:{first:function(g,h){return h===0},last:function(g,h,l,m){return h===m.length-1},even:function(g,h){return h%2===0},odd:function(g,h){return h%2===1},lt:function(g,h,l){return hl[3]-0},nth:function(g,h,l){return l[3]-0===h},eq:function(g,h,l){return l[3]-0===h}},filter:{PSEUDO:function(g,h,l,m){var q=h[1],p=n.filters[q];if(p)return p(g,l,h,m);else if(q==="contains")return(g.textContent||g.innerText||a([g])||"").indexOf(h[3])>=0;else if(q==="not"){h= -h[3];l=0;for(m=h.length;l=0}},ID:function(g,h){return g.nodeType===1&&g.getAttribute("id")===h},TAG:function(g,h){return h==="*"&&g.nodeType===1||g.nodeName.toLowerCase()===h},CLASS:function(g,h){return(" "+(g.className||g.getAttribute("class"))+" ").indexOf(h)>-1},ATTR:function(g,h){var l=h[1];g=n.attrHandle[l]?n.attrHandle[l](g):g[l]!=null?g[l]:g.getAttribute(l);l=g+"";var m=h[2];h=h[4];return g==null?m==="!=":m=== -"="?l===h:m==="*="?l.indexOf(h)>=0:m==="~="?(" "+l+" ").indexOf(h)>=0:!h?l&&g!==false:m==="!="?l!==h:m==="^="?l.indexOf(h)===0:m==="$="?l.substr(l.length-h.length)===h:m==="|="?l===h||l.substr(0,h.length+1)===h+"-":false},POS:function(g,h,l,m){var q=n.setFilters[h[2]];if(q)return q(g,l,h,m)}}},r=n.match.POS;for(var u in n.match){n.match[u]=new RegExp(n.match[u].source+/(?![^\[]*\])(?![^\(]*\))/.source);n.leftMatch[u]=new RegExp(/(^(?:.|\r|\n)*?)/.source+n.match[u].source.replace(/\\(\d+)/g,function(g, -h){return"\\"+(h-0+1)}))}var z=function(g,h){g=Array.prototype.slice.call(g,0);if(h){h.push.apply(h,g);return h}return g};try{Array.prototype.slice.call(s.documentElement.childNodes,0)}catch(C){z=function(g,h){h=h||[];if(j.call(g)==="[object Array]")Array.prototype.push.apply(h,g);else if(typeof g.length==="number")for(var l=0,m=g.length;l";var l=s.documentElement;l.insertBefore(g,l.firstChild);if(s.getElementById(h)){n.find.ID=function(m,q,p){if(typeof q.getElementById!=="undefined"&&!p)return(q=q.getElementById(m[1]))?q.id===m[1]||typeof q.getAttributeNode!=="undefined"&& -q.getAttributeNode("id").nodeValue===m[1]?[q]:w:[]};n.filter.ID=function(m,q){var p=typeof m.getAttributeNode!=="undefined"&&m.getAttributeNode("id");return m.nodeType===1&&p&&p.nodeValue===q}}l.removeChild(g);l=g=null})();(function(){var g=s.createElement("div");g.appendChild(s.createComment(""));if(g.getElementsByTagName("*").length>0)n.find.TAG=function(h,l){l=l.getElementsByTagName(h[1]);if(h[1]==="*"){h=[];for(var m=0;l[m];m++)l[m].nodeType===1&&h.push(l[m]);l=h}return l};g.innerHTML=""; -if(g.firstChild&&typeof g.firstChild.getAttribute!=="undefined"&&g.firstChild.getAttribute("href")!=="#")n.attrHandle.href=function(h){return h.getAttribute("href",2)};g=null})();s.querySelectorAll&&function(){var g=k,h=s.createElement("div");h.innerHTML="

";if(!(h.querySelectorAll&&h.querySelectorAll(".TEST").length===0)){k=function(m,q,p,v){q=q||s;if(!v&&q.nodeType===9&&!x(q))try{return z(q.querySelectorAll(m),p)}catch(t){}return g(m,q,p,v)};for(var l in g)k[l]=g[l];h=null}}(); -(function(){var g=s.createElement("div");g.innerHTML="
";if(!(!g.getElementsByClassName||g.getElementsByClassName("e").length===0)){g.lastChild.className="e";if(g.getElementsByClassName("e").length!==1){n.order.splice(1,0,"CLASS");n.find.CLASS=function(h,l,m){if(typeof l.getElementsByClassName!=="undefined"&&!m)return l.getElementsByClassName(h[1])};g=null}}})();var E=s.compareDocumentPosition?function(g,h){return!!(g.compareDocumentPosition(h)&16)}: -function(g,h){return g!==h&&(g.contains?g.contains(h):true)},x=function(g){return(g=(g?g.ownerDocument||g:0).documentElement)?g.nodeName!=="HTML":false},ga=function(g,h){var l=[],m="",q;for(h=h.nodeType?[h]:h;q=n.match.PSEUDO.exec(g);){m+=q[0];g=g.replace(n.match.PSEUDO,"")}g=n.relative[g]?g+"*":g;q=0;for(var p=h.length;q=0===d})};c.fn.extend({find:function(a){for(var b=this.pushStack("","find",a),d=0,f=0,e=this.length;f0)for(var j=d;j0},closest:function(a,b){if(c.isArray(a)){var d=[],f=this[0],e,j= -{},i;if(f&&a.length){e=0;for(var o=a.length;e-1:c(f).is(e)){d.push({selector:i,elem:f});delete j[i]}}f=f.parentNode}}return d}var k=c.expr.match.POS.test(a)?c(a,b||this.context):null;return this.map(function(n,r){for(;r&&r.ownerDocument&&r!==b;){if(k?k.index(r)>-1:c(r).is(a))return r;r=r.parentNode}return null})},index:function(a){if(!a||typeof a=== -"string")return c.inArray(this[0],a?c(a):this.parent().children());return c.inArray(a.jquery?a[0]:a,this)},add:function(a,b){a=typeof a==="string"?c(a,b||this.context):c.makeArray(a);b=c.merge(this.get(),a);return this.pushStack(qa(a[0])||qa(b[0])?b:c.unique(b))},andSelf:function(){return this.add(this.prevObject)}});c.each({parent:function(a){return(a=a.parentNode)&&a.nodeType!==11?a:null},parents:function(a){return c.dir(a,"parentNode")},parentsUntil:function(a,b,d){return c.dir(a,"parentNode", -d)},next:function(a){return c.nth(a,2,"nextSibling")},prev:function(a){return c.nth(a,2,"previousSibling")},nextAll:function(a){return c.dir(a,"nextSibling")},prevAll:function(a){return c.dir(a,"previousSibling")},nextUntil:function(a,b,d){return c.dir(a,"nextSibling",d)},prevUntil:function(a,b,d){return c.dir(a,"previousSibling",d)},siblings:function(a){return c.sibling(a.parentNode.firstChild,a)},children:function(a){return c.sibling(a.firstChild)},contents:function(a){return c.nodeName(a,"iframe")? -a.contentDocument||a.contentWindow.document:c.makeArray(a.childNodes)}},function(a,b){c.fn[a]=function(d,f){var e=c.map(this,b,d);eb.test(a)||(f=d);if(f&&typeof f==="string")e=c.filter(f,e);e=this.length>1?c.unique(e):e;if((this.length>1||gb.test(f))&&fb.test(a))e=e.reverse();return this.pushStack(e,a,R.call(arguments).join(","))}});c.extend({filter:function(a,b,d){if(d)a=":not("+a+")";return c.find.matches(a,b)},dir:function(a,b,d){var f=[];for(a=a[b];a&&a.nodeType!==9&&(d===w||a.nodeType!==1||!c(a).is(d));){a.nodeType=== -1&&f.push(a);a=a[b]}return f},nth:function(a,b,d){b=b||1;for(var f=0;a;a=a[d])if(a.nodeType===1&&++f===b)break;return a},sibling:function(a,b){for(var d=[];a;a=a.nextSibling)a.nodeType===1&&a!==b&&d.push(a);return d}});var Ja=/ jQuery\d+="(?:\d+|null)"/g,V=/^\s+/,Ka=/(<([\w:]+)[^>]*?)\/>/g,hb=/^(?:area|br|col|embed|hr|img|input|link|meta|param)$/i,La=/<([\w:]+)/,ib=/"},F={option:[1,""],legend:[1,"
","
"],thead:[1,"","
"],tr:[2,"","
"],td:[3,"","
"],col:[2,"","
"],area:[1,"",""],_default:[0,"",""]};F.optgroup=F.option;F.tbody=F.tfoot=F.colgroup=F.caption=F.thead;F.th=F.td;if(!c.support.htmlSerialize)F._default=[1,"div
","
"];c.fn.extend({text:function(a){if(c.isFunction(a))return this.each(function(b){var d= -c(this);d.text(a.call(this,b,d.text()))});if(typeof a!=="object"&&a!==w)return this.empty().append((this[0]&&this[0].ownerDocument||s).createTextNode(a));return c.text(this)},wrapAll:function(a){if(c.isFunction(a))return this.each(function(d){c(this).wrapAll(a.call(this,d))});if(this[0]){var b=c(a,this[0].ownerDocument).eq(0).clone(true);this[0].parentNode&&b.insertBefore(this[0]);b.map(function(){for(var d=this;d.firstChild&&d.firstChild.nodeType===1;)d=d.firstChild;return d}).append(this)}return this}, -wrapInner:function(a){if(c.isFunction(a))return this.each(function(b){c(this).wrapInner(a.call(this,b))});return this.each(function(){var b=c(this),d=b.contents();d.length?d.wrapAll(a):b.append(a)})},wrap:function(a){return this.each(function(){c(this).wrapAll(a)})},unwrap:function(){return this.parent().each(function(){c.nodeName(this,"body")||c(this).replaceWith(this.childNodes)}).end()},append:function(){return this.domManip(arguments,true,function(a){this.nodeType===1&&this.appendChild(a)})}, -prepend:function(){return this.domManip(arguments,true,function(a){this.nodeType===1&&this.insertBefore(a,this.firstChild)})},before:function(){if(this[0]&&this[0].parentNode)return this.domManip(arguments,false,function(b){this.parentNode.insertBefore(b,this)});else if(arguments.length){var a=c(arguments[0]);a.push.apply(a,this.toArray());return this.pushStack(a,"before",arguments)}},after:function(){if(this[0]&&this[0].parentNode)return this.domManip(arguments,false,function(b){this.parentNode.insertBefore(b, -this.nextSibling)});else if(arguments.length){var a=this.pushStack(this,"after",arguments);a.push.apply(a,c(arguments[0]).toArray());return a}},remove:function(a,b){for(var d=0,f;(f=this[d])!=null;d++)if(!a||c.filter(a,[f]).length){if(!b&&f.nodeType===1){c.cleanData(f.getElementsByTagName("*"));c.cleanData([f])}f.parentNode&&f.parentNode.removeChild(f)}return this},empty:function(){for(var a=0,b;(b=this[a])!=null;a++)for(b.nodeType===1&&c.cleanData(b.getElementsByTagName("*"));b.firstChild;)b.removeChild(b.firstChild); -return this},clone:function(a){var b=this.map(function(){if(!c.support.noCloneEvent&&!c.isXMLDoc(this)){var d=this.outerHTML,f=this.ownerDocument;if(!d){d=f.createElement("div");d.appendChild(this.cloneNode(true));d=d.innerHTML}return c.clean([d.replace(Ja,"").replace(/=([^="'>\s]+\/)>/g,'="$1">').replace(V,"")],f)[0]}else return this.cloneNode(true)});if(a===true){ra(this,b);ra(this.find("*"),b.find("*"))}return b},html:function(a){if(a===w)return this[0]&&this[0].nodeType===1?this[0].innerHTML.replace(Ja, -""):null;else if(typeof a==="string"&&!ta.test(a)&&(c.support.leadingWhitespace||!V.test(a))&&!F[(La.exec(a)||["",""])[1].toLowerCase()]){a=a.replace(Ka,Ma);try{for(var b=0,d=this.length;b0||e.cacheable||this.length>1?k.cloneNode(true):k)}o.length&&c.each(o,Qa)}return this}});c.fragments={};c.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(a,b){c.fn[a]=function(d){var f=[];d=c(d);var e=this.length===1&&this[0].parentNode;if(e&&e.nodeType===11&&e.childNodes.length===1&&d.length===1){d[b](this[0]); -return this}else{e=0;for(var j=d.length;e0?this.clone(true):this).get();c.fn[b].apply(c(d[e]),i);f=f.concat(i)}return this.pushStack(f,a,d.selector)}}});c.extend({clean:function(a,b,d,f){b=b||s;if(typeof b.createElement==="undefined")b=b.ownerDocument||b[0]&&b[0].ownerDocument||s;for(var e=[],j=0,i;(i=a[j])!=null;j++){if(typeof i==="number")i+="";if(i){if(typeof i==="string"&&!jb.test(i))i=b.createTextNode(i);else if(typeof i==="string"){i=i.replace(Ka,Ma);var o=(La.exec(i)||["", -""])[1].toLowerCase(),k=F[o]||F._default,n=k[0],r=b.createElement("div");for(r.innerHTML=k[1]+i+k[2];n--;)r=r.lastChild;if(!c.support.tbody){n=ib.test(i);o=o==="table"&&!n?r.firstChild&&r.firstChild.childNodes:k[1]===""&&!n?r.childNodes:[];for(k=o.length-1;k>=0;--k)c.nodeName(o[k],"tbody")&&!o[k].childNodes.length&&o[k].parentNode.removeChild(o[k])}!c.support.leadingWhitespace&&V.test(i)&&r.insertBefore(b.createTextNode(V.exec(i)[0]),r.firstChild);i=r.childNodes}if(i.nodeType)e.push(i);else e= -c.merge(e,i)}}if(d)for(j=0;e[j];j++)if(f&&c.nodeName(e[j],"script")&&(!e[j].type||e[j].type.toLowerCase()==="text/javascript"))f.push(e[j].parentNode?e[j].parentNode.removeChild(e[j]):e[j]);else{e[j].nodeType===1&&e.splice.apply(e,[j+1,0].concat(c.makeArray(e[j].getElementsByTagName("script"))));d.appendChild(e[j])}return e},cleanData:function(a){for(var b,d,f=c.cache,e=c.event.special,j=c.support.deleteExpando,i=0,o;(o=a[i])!=null;i++)if(d=o[c.expando]){b=f[d];if(b.events)for(var k in b.events)e[k]? -c.event.remove(o,k):Ca(o,k,b.handle);if(j)delete o[c.expando];else o.removeAttribute&&o.removeAttribute(c.expando);delete f[d]}}});var kb=/z-?index|font-?weight|opacity|zoom|line-?height/i,Na=/alpha\([^)]*\)/,Oa=/opacity=([^)]*)/,ha=/float/i,ia=/-([a-z])/ig,lb=/([A-Z])/g,mb=/^-?\d+(?:px)?$/i,nb=/^-?\d/,ob={position:"absolute",visibility:"hidden",display:"block"},pb=["Left","Right"],qb=["Top","Bottom"],rb=s.defaultView&&s.defaultView.getComputedStyle,Pa=c.support.cssFloat?"cssFloat":"styleFloat",ja= -function(a,b){return b.toUpperCase()};c.fn.css=function(a,b){return X(this,a,b,true,function(d,f,e){if(e===w)return c.curCSS(d,f);if(typeof e==="number"&&!kb.test(f))e+="px";c.style(d,f,e)})};c.extend({style:function(a,b,d){if(!a||a.nodeType===3||a.nodeType===8)return w;if((b==="width"||b==="height")&&parseFloat(d)<0)d=w;var f=a.style||a,e=d!==w;if(!c.support.opacity&&b==="opacity"){if(e){f.zoom=1;b=parseInt(d,10)+""==="NaN"?"":"alpha(opacity="+d*100+")";a=f.filter||c.curCSS(a,"filter")||"";f.filter= -Na.test(a)?a.replace(Na,b):b}return f.filter&&f.filter.indexOf("opacity=")>=0?parseFloat(Oa.exec(f.filter)[1])/100+"":""}if(ha.test(b))b=Pa;b=b.replace(ia,ja);if(e)f[b]=d;return f[b]},css:function(a,b,d,f){if(b==="width"||b==="height"){var e,j=b==="width"?pb:qb;function i(){e=b==="width"?a.offsetWidth:a.offsetHeight;f!=="border"&&c.each(j,function(){f||(e-=parseFloat(c.curCSS(a,"padding"+this,true))||0);if(f==="margin")e+=parseFloat(c.curCSS(a,"margin"+this,true))||0;else e-=parseFloat(c.curCSS(a, -"border"+this+"Width",true))||0})}a.offsetWidth!==0?i():c.swap(a,ob,i);return Math.max(0,Math.round(e))}return c.curCSS(a,b,d)},curCSS:function(a,b,d){var f,e=a.style;if(!c.support.opacity&&b==="opacity"&&a.currentStyle){f=Oa.test(a.currentStyle.filter||"")?parseFloat(RegExp.$1)/100+"":"";return f===""?"1":f}if(ha.test(b))b=Pa;if(!d&&e&&e[b])f=e[b];else if(rb){if(ha.test(b))b="float";b=b.replace(lb,"-$1").toLowerCase();e=a.ownerDocument.defaultView;if(!e)return null;if(a=e.getComputedStyle(a,null))f= -a.getPropertyValue(b);if(b==="opacity"&&f==="")f="1"}else if(a.currentStyle){d=b.replace(ia,ja);f=a.currentStyle[b]||a.currentStyle[d];if(!mb.test(f)&&nb.test(f)){b=e.left;var j=a.runtimeStyle.left;a.runtimeStyle.left=a.currentStyle.left;e.left=d==="fontSize"?"1em":f||0;f=e.pixelLeft+"px";e.left=b;a.runtimeStyle.left=j}}return f},swap:function(a,b,d){var f={};for(var e in b){f[e]=a.style[e];a.style[e]=b[e]}d.call(a);for(e in b)a.style[e]=f[e]}});if(c.expr&&c.expr.filters){c.expr.filters.hidden=function(a){var b= -a.offsetWidth,d=a.offsetHeight,f=a.nodeName.toLowerCase()==="tr";return b===0&&d===0&&!f?true:b>0&&d>0&&!f?false:c.curCSS(a,"display")==="none"};c.expr.filters.visible=function(a){return!c.expr.filters.hidden(a)}}var sb=J(),tb=//gi,ub=/select|textarea/i,vb=/color|date|datetime|email|hidden|month|number|password|range|search|tel|text|time|url|week/i,N=/=\?(&|$)/,ka=/\?/,wb=/(\?|&)_=.*?(&|$)/,xb=/^(\w+:)?\/\/([^\/?#]+)/,yb=/%20/g,zb=c.fn.load;c.fn.extend({load:function(a,b,d){if(typeof a!== -"string")return zb.call(this,a);else if(!this.length)return this;var f=a.indexOf(" ");if(f>=0){var e=a.slice(f,a.length);a=a.slice(0,f)}f="GET";if(b)if(c.isFunction(b)){d=b;b=null}else if(typeof b==="object"){b=c.param(b,c.ajaxSettings.traditional);f="POST"}var j=this;c.ajax({url:a,type:f,dataType:"html",data:b,complete:function(i,o){if(o==="success"||o==="notmodified")j.html(e?c("
").append(i.responseText.replace(tb,"")).find(e):i.responseText);d&&j.each(d,[i.responseText,o,i])}});return this}, -serialize:function(){return c.param(this.serializeArray())},serializeArray:function(){return this.map(function(){return this.elements?c.makeArray(this.elements):this}).filter(function(){return this.name&&!this.disabled&&(this.checked||ub.test(this.nodeName)||vb.test(this.type))}).map(function(a,b){a=c(this).val();return a==null?null:c.isArray(a)?c.map(a,function(d){return{name:b.name,value:d}}):{name:b.name,value:a}}).get()}});c.each("ajaxStart ajaxStop ajaxComplete ajaxError ajaxSuccess ajaxSend".split(" "), -function(a,b){c.fn[b]=function(d){return this.bind(b,d)}});c.extend({get:function(a,b,d,f){if(c.isFunction(b)){f=f||d;d=b;b=null}return c.ajax({type:"GET",url:a,data:b,success:d,dataType:f})},getScript:function(a,b){return c.get(a,null,b,"script")},getJSON:function(a,b,d){return c.get(a,b,d,"json")},post:function(a,b,d,f){if(c.isFunction(b)){f=f||d;d=b;b={}}return c.ajax({type:"POST",url:a,data:b,success:d,dataType:f})},ajaxSetup:function(a){c.extend(c.ajaxSettings,a)},ajaxSettings:{url:location.href, -global:true,type:"GET",contentType:"application/x-www-form-urlencoded",processData:true,async:true,xhr:A.XMLHttpRequest&&(A.location.protocol!=="file:"||!A.ActiveXObject)?function(){return new A.XMLHttpRequest}:function(){try{return new A.ActiveXObject("Microsoft.XMLHTTP")}catch(a){}},accepts:{xml:"application/xml, text/xml",html:"text/html",script:"text/javascript, application/javascript",json:"application/json, text/javascript",text:"text/plain",_default:"*/*"}},lastModified:{},etag:{},ajax:function(a){function b(){e.success&& -e.success.call(k,o,i,x);e.global&&f("ajaxSuccess",[x,e])}function d(){e.complete&&e.complete.call(k,x,i);e.global&&f("ajaxComplete",[x,e]);e.global&&!--c.active&&c.event.trigger("ajaxStop")}function f(q,p){(e.context?c(e.context):c.event).trigger(q,p)}var e=c.extend(true,{},c.ajaxSettings,a),j,i,o,k=a&&a.context||e,n=e.type.toUpperCase();if(e.data&&e.processData&&typeof e.data!=="string")e.data=c.param(e.data,e.traditional);if(e.dataType==="jsonp"){if(n==="GET")N.test(e.url)||(e.url+=(ka.test(e.url)? -"&":"?")+(e.jsonp||"callback")+"=?");else if(!e.data||!N.test(e.data))e.data=(e.data?e.data+"&":"")+(e.jsonp||"callback")+"=?";e.dataType="json"}if(e.dataType==="json"&&(e.data&&N.test(e.data)||N.test(e.url))){j=e.jsonpCallback||"jsonp"+sb++;if(e.data)e.data=(e.data+"").replace(N,"="+j+"$1");e.url=e.url.replace(N,"="+j+"$1");e.dataType="script";A[j]=A[j]||function(q){o=q;b();d();A[j]=w;try{delete A[j]}catch(p){}z&&z.removeChild(C)}}if(e.dataType==="script"&&e.cache===null)e.cache=false;if(e.cache=== -false&&n==="GET"){var r=J(),u=e.url.replace(wb,"$1_="+r+"$2");e.url=u+(u===e.url?(ka.test(e.url)?"&":"?")+"_="+r:"")}if(e.data&&n==="GET")e.url+=(ka.test(e.url)?"&":"?")+e.data;e.global&&!c.active++&&c.event.trigger("ajaxStart");r=(r=xb.exec(e.url))&&(r[1]&&r[1]!==location.protocol||r[2]!==location.host);if(e.dataType==="script"&&n==="GET"&&r){var z=s.getElementsByTagName("head")[0]||s.documentElement,C=s.createElement("script");C.src=e.url;if(e.scriptCharset)C.charset=e.scriptCharset;if(!j){var B= -false;C.onload=C.onreadystatechange=function(){if(!B&&(!this.readyState||this.readyState==="loaded"||this.readyState==="complete")){B=true;b();d();C.onload=C.onreadystatechange=null;z&&C.parentNode&&z.removeChild(C)}}}z.insertBefore(C,z.firstChild);return w}var E=false,x=e.xhr();if(x){e.username?x.open(n,e.url,e.async,e.username,e.password):x.open(n,e.url,e.async);try{if(e.data||a&&a.contentType)x.setRequestHeader("Content-Type",e.contentType);if(e.ifModified){c.lastModified[e.url]&&x.setRequestHeader("If-Modified-Since", -c.lastModified[e.url]);c.etag[e.url]&&x.setRequestHeader("If-None-Match",c.etag[e.url])}r||x.setRequestHeader("X-Requested-With","XMLHttpRequest");x.setRequestHeader("Accept",e.dataType&&e.accepts[e.dataType]?e.accepts[e.dataType]+", */*":e.accepts._default)}catch(ga){}if(e.beforeSend&&e.beforeSend.call(k,x,e)===false){e.global&&!--c.active&&c.event.trigger("ajaxStop");x.abort();return false}e.global&&f("ajaxSend",[x,e]);var g=x.onreadystatechange=function(q){if(!x||x.readyState===0||q==="abort"){E|| -d();E=true;if(x)x.onreadystatechange=c.noop}else if(!E&&x&&(x.readyState===4||q==="timeout")){E=true;x.onreadystatechange=c.noop;i=q==="timeout"?"timeout":!c.httpSuccess(x)?"error":e.ifModified&&c.httpNotModified(x,e.url)?"notmodified":"success";var p;if(i==="success")try{o=c.httpData(x,e.dataType,e)}catch(v){i="parsererror";p=v}if(i==="success"||i==="notmodified")j||b();else c.handleError(e,x,i,p);d();q==="timeout"&&x.abort();if(e.async)x=null}};try{var h=x.abort;x.abort=function(){x&&h.call(x); -g("abort")}}catch(l){}e.async&&e.timeout>0&&setTimeout(function(){x&&!E&&g("timeout")},e.timeout);try{x.send(n==="POST"||n==="PUT"||n==="DELETE"?e.data:null)}catch(m){c.handleError(e,x,null,m);d()}e.async||g();return x}},handleError:function(a,b,d,f){if(a.error)a.error.call(a.context||a,b,d,f);if(a.global)(a.context?c(a.context):c.event).trigger("ajaxError",[b,a,f])},active:0,httpSuccess:function(a){try{return!a.status&&location.protocol==="file:"||a.status>=200&&a.status<300||a.status===304||a.status=== -1223||a.status===0}catch(b){}return false},httpNotModified:function(a,b){var d=a.getResponseHeader("Last-Modified"),f=a.getResponseHeader("Etag");if(d)c.lastModified[b]=d;if(f)c.etag[b]=f;return a.status===304||a.status===0},httpData:function(a,b,d){var f=a.getResponseHeader("content-type")||"",e=b==="xml"||!b&&f.indexOf("xml")>=0;a=e?a.responseXML:a.responseText;e&&a.documentElement.nodeName==="parsererror"&&c.error("parsererror");if(d&&d.dataFilter)a=d.dataFilter(a,b);if(typeof a==="string")if(b=== -"json"||!b&&f.indexOf("json")>=0)a=c.parseJSON(a);else if(b==="script"||!b&&f.indexOf("javascript")>=0)c.globalEval(a);return a},param:function(a,b){function d(i,o){if(c.isArray(o))c.each(o,function(k,n){b||/\[\]$/.test(i)?f(i,n):d(i+"["+(typeof n==="object"||c.isArray(n)?k:"")+"]",n)});else!b&&o!=null&&typeof o==="object"?c.each(o,function(k,n){d(i+"["+k+"]",n)}):f(i,o)}function f(i,o){o=c.isFunction(o)?o():o;e[e.length]=encodeURIComponent(i)+"="+encodeURIComponent(o)}var e=[];if(b===w)b=c.ajaxSettings.traditional; -if(c.isArray(a)||a.jquery)c.each(a,function(){f(this.name,this.value)});else for(var j in a)d(j,a[j]);return e.join("&").replace(yb,"+")}});var la={},Ab=/toggle|show|hide/,Bb=/^([+-]=)?([\d+-.]+)(.*)$/,W,va=[["height","marginTop","marginBottom","paddingTop","paddingBottom"],["width","marginLeft","marginRight","paddingLeft","paddingRight"],["opacity"]];c.fn.extend({show:function(a,b){if(a||a===0)return this.animate(K("show",3),a,b);else{a=0;for(b=this.length;a").appendTo("body");f=e.css("display");if(f==="none")f="block";e.remove();la[d]=f}c.data(this[a],"olddisplay",f)}}a=0;for(b=this.length;a=0;f--)if(d[f].elem===this){b&&d[f](true);d.splice(f,1)}});b||this.dequeue();return this}});c.each({slideDown:K("show",1),slideUp:K("hide",1),slideToggle:K("toggle",1),fadeIn:{opacity:"show"},fadeOut:{opacity:"hide"}},function(a,b){c.fn[a]=function(d,f){return this.animate(b,d,f)}});c.extend({speed:function(a,b,d){var f=a&&typeof a==="object"?a:{complete:d||!d&&b||c.isFunction(a)&&a,duration:a,easing:d&&b||b&&!c.isFunction(b)&&b};f.duration=c.fx.off?0:typeof f.duration=== -"number"?f.duration:c.fx.speeds[f.duration]||c.fx.speeds._default;f.old=f.complete;f.complete=function(){f.queue!==false&&c(this).dequeue();c.isFunction(f.old)&&f.old.call(this)};return f},easing:{linear:function(a,b,d,f){return d+f*a},swing:function(a,b,d,f){return(-Math.cos(a*Math.PI)/2+0.5)*f+d}},timers:[],fx:function(a,b,d){this.options=b;this.elem=a;this.prop=d;if(!b.orig)b.orig={}}});c.fx.prototype={update:function(){this.options.step&&this.options.step.call(this.elem,this.now,this);(c.fx.step[this.prop]|| -c.fx.step._default)(this);if((this.prop==="height"||this.prop==="width")&&this.elem.style)this.elem.style.display="block"},cur:function(a){if(this.elem[this.prop]!=null&&(!this.elem.style||this.elem.style[this.prop]==null))return this.elem[this.prop];return(a=parseFloat(c.css(this.elem,this.prop,a)))&&a>-10000?a:parseFloat(c.curCSS(this.elem,this.prop))||0},custom:function(a,b,d){function f(j){return e.step(j)}this.startTime=J();this.start=a;this.end=b;this.unit=d||this.unit||"px";this.now=this.start; -this.pos=this.state=0;var e=this;f.elem=this.elem;if(f()&&c.timers.push(f)&&!W)W=setInterval(c.fx.tick,13)},show:function(){this.options.orig[this.prop]=c.style(this.elem,this.prop);this.options.show=true;this.custom(this.prop==="width"||this.prop==="height"?1:0,this.cur());c(this.elem).show()},hide:function(){this.options.orig[this.prop]=c.style(this.elem,this.prop);this.options.hide=true;this.custom(this.cur(),0)},step:function(a){var b=J(),d=true;if(a||b>=this.options.duration+this.startTime){this.now= -this.end;this.pos=this.state=1;this.update();this.options.curAnim[this.prop]=true;for(var f in this.options.curAnim)if(this.options.curAnim[f]!==true)d=false;if(d){if(this.options.display!=null){this.elem.style.overflow=this.options.overflow;a=c.data(this.elem,"olddisplay");this.elem.style.display=a?a:this.options.display;if(c.css(this.elem,"display")==="none")this.elem.style.display="block"}this.options.hide&&c(this.elem).hide();if(this.options.hide||this.options.show)for(var e in this.options.curAnim)c.style(this.elem, -e,this.options.orig[e]);this.options.complete.call(this.elem)}return false}else{e=b-this.startTime;this.state=e/this.options.duration;a=this.options.easing||(c.easing.swing?"swing":"linear");this.pos=c.easing[this.options.specialEasing&&this.options.specialEasing[this.prop]||a](this.state,e,0,1,this.options.duration);this.now=this.start+(this.end-this.start)*this.pos;this.update()}return true}};c.extend(c.fx,{tick:function(){for(var a=c.timers,b=0;b
"; -a.insertBefore(b,a.firstChild);d=b.firstChild;f=d.firstChild;e=d.nextSibling.firstChild.firstChild;this.doesNotAddBorder=f.offsetTop!==5;this.doesAddBorderForTableAndCells=e.offsetTop===5;f.style.position="fixed";f.style.top="20px";this.supportsFixedPosition=f.offsetTop===20||f.offsetTop===15;f.style.position=f.style.top="";d.style.overflow="hidden";d.style.position="relative";this.subtractsBorderForOverflowNotVisible=f.offsetTop===-5;this.doesNotIncludeMarginInBodyOffset=a.offsetTop!==j;a.removeChild(b); -c.offset.initialize=c.noop},bodyOffset:function(a){var b=a.offsetTop,d=a.offsetLeft;c.offset.initialize();if(c.offset.doesNotIncludeMarginInBodyOffset){b+=parseFloat(c.curCSS(a,"marginTop",true))||0;d+=parseFloat(c.curCSS(a,"marginLeft",true))||0}return{top:b,left:d}},setOffset:function(a,b,d){if(/static/.test(c.curCSS(a,"position")))a.style.position="relative";var f=c(a),e=f.offset(),j=parseInt(c.curCSS(a,"top",true),10)||0,i=parseInt(c.curCSS(a,"left",true),10)||0;if(c.isFunction(b))b=b.call(a, -d,e);d={top:b.top-e.top+j,left:b.left-e.left+i};"using"in b?b.using.call(a,d):f.css(d)}};c.fn.extend({position:function(){if(!this[0])return null;var a=this[0],b=this.offsetParent(),d=this.offset(),f=/^body|html$/i.test(b[0].nodeName)?{top:0,left:0}:b.offset();d.top-=parseFloat(c.curCSS(a,"marginTop",true))||0;d.left-=parseFloat(c.curCSS(a,"marginLeft",true))||0;f.top+=parseFloat(c.curCSS(b[0],"borderTopWidth",true))||0;f.left+=parseFloat(c.curCSS(b[0],"borderLeftWidth",true))||0;return{top:d.top- -f.top,left:d.left-f.left}},offsetParent:function(){return this.map(function(){for(var a=this.offsetParent||s.body;a&&!/^body|html$/i.test(a.nodeName)&&c.css(a,"position")==="static";)a=a.offsetParent;return a})}});c.each(["Left","Top"],function(a,b){var d="scroll"+b;c.fn[d]=function(f){var e=this[0],j;if(!e)return null;if(f!==w)return this.each(function(){if(j=wa(this))j.scrollTo(!a?f:c(j).scrollLeft(),a?f:c(j).scrollTop());else this[d]=f});else return(j=wa(e))?"pageXOffset"in j?j[a?"pageYOffset": -"pageXOffset"]:c.support.boxModel&&j.document.documentElement[d]||j.document.body[d]:e[d]}});c.each(["Height","Width"],function(a,b){var d=b.toLowerCase();c.fn["inner"+b]=function(){return this[0]?c.css(this[0],d,false,"padding"):null};c.fn["outer"+b]=function(f){return this[0]?c.css(this[0],d,false,f?"margin":"border"):null};c.fn[d]=function(f){var e=this[0];if(!e)return f==null?null:this;if(c.isFunction(f))return this.each(function(j){var i=c(this);i[d](f.call(this,j,i[d]()))});return"scrollTo"in -e&&e.document?e.document.compatMode==="CSS1Compat"&&e.document.documentElement["client"+b]||e.document.body["client"+b]:e.nodeType===9?Math.max(e.documentElement["client"+b],e.body["scroll"+b],e.documentElement["scroll"+b],e.body["offset"+b],e.documentElement["offset"+b]):f===w?c.css(e,d):this.css(d,typeof f==="string"?f:f+"px")}});A.jQuery=A.$=c})(window); diff --git a/endpoints/internal/protorpc/static/jquery.json-2.2.min.js b/endpoints/internal/protorpc/static/jquery.json-2.2.min.js deleted file mode 100644 index bad4a0a..0000000 --- a/endpoints/internal/protorpc/static/jquery.json-2.2.min.js +++ /dev/null @@ -1,31 +0,0 @@ - -(function($){$.toJSON=function(o) -{if(typeof(JSON)=='object'&&JSON.stringify) -return JSON.stringify(o);var type=typeof(o);if(o===null) -return"null";if(type=="undefined") -return undefined;if(type=="number"||type=="boolean") -return o+"";if(type=="string") -return $.quoteString(o);if(type=='object') -{if(typeof o.toJSON=="function") -return $.toJSON(o.toJSON());if(o.constructor===Date) -{var month=o.getUTCMonth()+1;if(month<10)month='0'+month;var day=o.getUTCDate();if(day<10)day='0'+day;var year=o.getUTCFullYear();var hours=o.getUTCHours();if(hours<10)hours='0'+hours;var minutes=o.getUTCMinutes();if(minutes<10)minutes='0'+minutes;var seconds=o.getUTCSeconds();if(seconds<10)seconds='0'+seconds;var milli=o.getUTCMilliseconds();if(milli<100)milli='0'+milli;if(milli<10)milli='0'+milli;return'"'+year+'-'+month+'-'+day+'T'+ -hours+':'+minutes+':'+seconds+'.'+milli+'Z"';} -if(o.constructor===Array) -{var ret=[];for(var i=0;i - -{% extends 'base.html' %} - -{% block title %}Form for {{service_path|escape}}.{{method_name|escape}}{% endblock %} - -{% block top %} -<< Back to method selection -

Form for {{service_path|escape}}.{{method_name|escape}}

-{% endblock %} - -{% block body %} - -
-
-
- -
-{% endblock %} - -{% block call %} -loadServices(createForm); -{% endblock %} diff --git a/endpoints/internal/protorpc/webapp/__init__.py b/endpoints/internal/protorpc/webapp/__init__.py deleted file mode 100644 index ce0df32..0000000 --- a/endpoints/internal/protorpc/webapp/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2011 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -__author__ = 'rafek@google.com (Rafe Kaplan)' diff --git a/endpoints/internal/protorpc/webapp/forms.py b/endpoints/internal/protorpc/webapp/forms.py deleted file mode 100644 index 65d3b96..0000000 --- a/endpoints/internal/protorpc/webapp/forms.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Webapp forms interface to ProtoRPC services. - -This webapp application is automatically configured to work with ProtoRPCs -that have a configured protorpc.RegistryService. This webapp is -automatically added to the registry service URL at /forms -(default is /protorpc/form) when configured using the -service_handlers.service_mapping function. -""" - -import os - -from .google_imports import template -from .google_imports import webapp - - -__all__ = ['FormsHandler', - 'ResourceHandler', - - 'DEFAULT_REGISTRY_PATH', - ] - -_TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), - 'static') - -_FORMS_TEMPLATE = os.path.join(_TEMPLATES_DIR, 'forms.html') -_METHODS_TEMPLATE = os.path.join(_TEMPLATES_DIR, 'methods.html') - -DEFAULT_REGISTRY_PATH = '/protorpc' - - -class ResourceHandler(webapp.RequestHandler): - """Serves static resources without needing to add static files to app.yaml.""" - - __RESOURCE_MAP = { - 'forms.js': 'text/javascript', - 'jquery-1.4.2.min.js': 'text/javascript', - 'jquery.json-2.2.min.js': 'text/javascript', - } - - def get(self, relative): - """Serve known static files. - - If static file is not known, will return 404 to client. - - Response items are cached for 300 seconds. - - Args: - relative: Name of static file relative to main FormsHandler. - """ - content_type = self.__RESOURCE_MAP.get(relative, None) - if not content_type: - self.response.set_status(404) - self.response.out.write('Resource not found.') - return - - path = os.path.join(_TEMPLATES_DIR, relative) - self.response.headers['Content-Type'] = content_type - static_file = open(path) - try: - contents = static_file.read() - finally: - static_file.close() - self.response.out.write(contents) - - -class FormsHandler(webapp.RequestHandler): - """Handler for display HTML/javascript forms of ProtoRPC method calls. - - When accessed with no query parameters, will show a web page that displays - all services and methods on the associated registry path. Links on this - page fill in the service_path and method_name query parameters back to this - same handler. - - When provided with service_path and method_name parameters will display a - dynamic form representing the request message for that method. When sent, - the form sends a JSON request to the ProtoRPC method and displays the - response in the HTML page. - - Attribute: - registry_path: Read-only registry path known by this handler. - """ - - def __init__(self, registry_path=DEFAULT_REGISTRY_PATH): - """Constructor. - - When configuring a FormsHandler to use with a webapp application do not - pass the request handler class in directly. Instead use new_factory to - ensure that the FormsHandler is created with the correct registry path - for each request. - - Args: - registry_path: Absolute path on server where the ProtoRPC RegsitryService - is located. - """ - assert registry_path - self.__registry_path = registry_path - - @property - def registry_path(self): - return self.__registry_path - - def get(self): - """Send forms and method page to user. - - By default, displays a web page listing all services and methods registered - on the server. Methods have links to display the actual method form. - - If both parameters are set, will display form for method. - - Query Parameters: - service_path: Path to service to display method of. Optional. - method_name: Name of method to display form for. Optional. - """ - params = {'forms_path': self.request.path.rstrip('/'), - 'hostname': self.request.host, - 'registry_path': self.__registry_path, - } - service_path = self.request.get('path', None) - method_name = self.request.get('method', None) - - if service_path and method_name: - form_template = _METHODS_TEMPLATE - params['service_path'] = service_path - params['method_name'] = method_name - else: - form_template = _FORMS_TEMPLATE - - self.response.out.write(template.render(form_template, params)) - - @classmethod - def new_factory(cls, registry_path=DEFAULT_REGISTRY_PATH): - """Construct a factory for use with WSGIApplication. - - This method is called automatically with the correct registry path when - services are configured via service_handlers.service_mapping. - - Args: - registry_path: Absolute path on server where the ProtoRPC RegsitryService - is located. - - Returns: - Factory function that creates a properly configured FormsHandler instance. - """ - def forms_factory(): - return cls(registry_path) - return forms_factory diff --git a/endpoints/internal/protorpc/webapp/google_imports.py b/endpoints/internal/protorpc/webapp/google_imports.py deleted file mode 100644 index b7de40c..0000000 --- a/endpoints/internal/protorpc/webapp/google_imports.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Dynamically decide from where to import other SDK modules. - -All other protorpc.webapp code should import other SDK modules from -this module. If necessary, add new imports here (in both places). -""" - -__author__ = 'yey@google.com (Ye Yuan)' - -# pylint: disable=g-import-not-at-top -# pylint: disable=unused-import - -import os -import sys - -try: - from google.appengine import ext - normal_environment = True -except ImportError: - normal_environment = False - - -if normal_environment: - from google.appengine.ext import webapp - from google.appengine.ext.webapp import util as webapp_util - from google.appengine.ext.webapp import template diff --git a/endpoints/internal/protorpc/webapp/service_handlers.py b/endpoints/internal/protorpc/webapp/service_handlers.py deleted file mode 100644 index 94a0855..0000000 --- a/endpoints/internal/protorpc/webapp/service_handlers.py +++ /dev/null @@ -1,834 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2010 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""Handlers for remote services. - -This module contains classes that may be used to build a service -on top of the App Engine Webapp framework. - -The services request handler can be configured to handle requests in a number -of different request formats. All different request formats must have a way -to map the request to the service handlers defined request message.Message -class. The handler can also send a response in any format that can be mapped -from the response message.Message class. - -Participants in an RPC: - - There are four classes involved with the life cycle of an RPC. - - Service factory: A user-defined service factory that is responsible for - instantiating an RPC service. The methods intended for use as RPC - methods must be decorated by the 'remote' decorator. - - RPCMapper: Responsible for determining whether or not a specific request - matches a particular RPC format and translating between the actual - request/response and the underlying message types. A single instance of - an RPCMapper sub-class is required per service configuration. Each - mapper must be usable across multiple requests. - - ServiceHandler: A webapp.RequestHandler sub-class that responds to the - webapp framework. It mediates between the RPCMapper and service - implementation class during a request. As determined by the Webapp - framework, a new ServiceHandler instance is created to handle each - user request. A handler is never used to handle more than one request. - - ServiceHandlerFactory: A class that is responsible for creating new, - properly configured ServiceHandler instance for each request. The - factory is configured by providing it with a set of RPCMapper instances. - When the Webapp framework invokes the service handler, the handler - creates a new service class instance. The service class instance is - provided with a reference to the handler. A single instance of an - RPCMapper sub-class is required to configure each service. Each mapper - instance must be usable across multiple requests. - -RPC mappers: - - RPC mappers translate between a single HTTP based RPC protocol and the - underlying service implementation. Each RPC mapper must configured - with the following information to determine if it is an appropriate - mapper for a given request: - - http_methods: Set of HTTP methods supported by handler. - - content_types: Set of supported content types. - - default_content_type: Default content type for handler responses. - - Built-in mapper implementations: - - URLEncodedRPCMapper: Matches requests that are compatible with post - forms with the 'application/x-www-form-urlencoded' content-type - (this content type is the default if none is specified. It - translates post parameters into request parameters. - - ProtobufRPCMapper: Matches requests that are compatible with post - forms with the 'application/x-google-protobuf' content-type. It - reads the contents of a binary post request. - -Public Exceptions: - Error: Base class for service handler errors. - ServiceConfigurationError: Raised when a service not correctly configured. - RequestError: Raised by RPC mappers when there is an error in its request - or request format. - ResponseError: Raised by RPC mappers when there is an error in its response. -""" -import six - -__author__ = 'rafek@google.com (Rafe Kaplan)' - -import six.moves.http_client -import logging - -from .google_imports import webapp -from .google_imports import webapp_util -from .. import messages -from .. import protobuf -from .. import protojson -from .. import protourlencode -from .. import registry -from .. import remote -from .. import util -from . import forms - -__all__ = [ - 'Error', - 'RequestError', - 'ResponseError', - 'ServiceConfigurationError', - - 'DEFAULT_REGISTRY_PATH', - - 'ProtobufRPCMapper', - 'RPCMapper', - 'ServiceHandler', - 'ServiceHandlerFactory', - 'URLEncodedRPCMapper', - 'JSONRPCMapper', - 'service_mapping', - 'run_services', -] - - -class Error(Exception): - """Base class for all errors in service handlers module.""" - - -class ServiceConfigurationError(Error): - """When service configuration is incorrect.""" - - -class RequestError(Error): - """Error occurred when building request.""" - - -class ResponseError(Error): - """Error occurred when building response.""" - - -_URLENCODED_CONTENT_TYPE = protourlencode.CONTENT_TYPE -_PROTOBUF_CONTENT_TYPE = protobuf.CONTENT_TYPE -_JSON_CONTENT_TYPE = protojson.CONTENT_TYPE - -_EXTRA_JSON_CONTENT_TYPES = ['application/x-javascript', - 'text/javascript', - 'text/x-javascript', - 'text/x-json', - 'text/json', - ] - -# The whole method pattern is an optional regex. It contains a single -# group used for mapping to the query parameter. This is passed to the -# parameters of 'get' and 'post' on the ServiceHandler. -_METHOD_PATTERN = r'(?:\.([^?]*))?' - -DEFAULT_REGISTRY_PATH = forms.DEFAULT_REGISTRY_PATH - - -class RPCMapper(object): - """Interface to mediate between request and service object. - - Request mappers are implemented to support various types of - RPC protocols. It is responsible for identifying whether a - given request matches a particular protocol, resolve the remote - method to invoke and mediate between the request and appropriate - protocol messages for the remote method. - """ - - @util.positional(4) - def __init__(self, - http_methods, - default_content_type, - protocol, - content_types=None): - """Constructor. - - Args: - http_methods: Set of HTTP methods supported by mapper. - default_content_type: Default content type supported by mapper. - protocol: The protocol implementation. Must implement encode_message and - decode_message. - content_types: Set of additionally supported content types. - """ - self.__http_methods = frozenset(http_methods) - self.__default_content_type = default_content_type - self.__protocol = protocol - - if content_types is None: - content_types = [] - self.__content_types = frozenset([self.__default_content_type] + - content_types) - - @property - def http_methods(self): - return self.__http_methods - - @property - def default_content_type(self): - return self.__default_content_type - - @property - def content_types(self): - return self.__content_types - - def build_request(self, handler, request_type): - """Build request message based on request. - - Each request mapper implementation is responsible for converting a - request to an appropriate message instance. - - Args: - handler: RequestHandler instance that is servicing request. - Must be initialized with request object and been previously determined - to matching the protocol of the RPCMapper. - request_type: Message type to build. - - Returns: - Instance of request_type populated by protocol buffer in request body. - - Raises: - RequestError if the mapper implementation is not able to correctly - convert the request to the appropriate message. - """ - try: - return self.__protocol.decode_message(request_type, handler.request.body) - except (messages.ValidationError, messages.DecodeError) as err: - raise RequestError('Unable to parse request content: %s' % err) - - def build_response(self, handler, response, pad_string=False): - """Build response based on service object response message. - - Each request mapper implementation is responsible for converting a - response message to an appropriate handler response. - - Args: - handler: RequestHandler instance that is servicing request. - Must be initialized with request object and been previously determined - to matching the protocol of the RPCMapper. - response: Response message as returned from the service object. - - Raises: - ResponseError if the mapper implementation is not able to correctly - convert the message to an appropriate response. - """ - try: - encoded_message = self.__protocol.encode_message(response) - except messages.ValidationError as err: - raise ResponseError('Unable to encode message: %s' % err) - else: - handler.response.headers['Content-Type'] = self.default_content_type - handler.response.out.write(encoded_message) - - -class ServiceHandlerFactory(object): - """Factory class used for instantiating new service handlers. - - Normally a handler class is passed directly to the webapp framework - so that it can be simply instantiated to handle a single request. - The service handler, however, must be configured with additional - information so that it knows how to instantiate a service object. - This class acts the same as a normal RequestHandler class by - overriding the __call__ method to correctly configures a ServiceHandler - instance with a new service object. - - The factory must also provide a set of RPCMapper instances which - examine a request to determine what protocol is being used and mediates - between the request and the service object. - - The mapping of a service handler must have a single group indicating the - part of the URL path that maps to the request method. This group must - exist but can be optional for the request (the group may be followed by - '?' in the regular expression matching the request). - - Usage: - - stock_factory = ServiceHandlerFactory(StockService) - ... configure stock_factory by adding RPCMapper instances ... - - application = webapp.WSGIApplication( - [stock_factory.mapping('/stocks')]) - - Default usage: - - application = webapp.WSGIApplication( - [ServiceHandlerFactory.default(StockService).mapping('/stocks')]) - """ - - def __init__(self, service_factory): - """Constructor. - - Args: - service_factory: Service factory to instantiate and provide to - service handler. - """ - self.__service_factory = service_factory - self.__request_mappers = [] - - def all_request_mappers(self): - """Get all request mappers. - - Returns: - Iterator of all request mappers used by this service factory. - """ - return iter(self.__request_mappers) - - def add_request_mapper(self, mapper): - """Add request mapper to end of request mapper list.""" - self.__request_mappers.append(mapper) - - def __call__(self): - """Construct a new service handler instance.""" - return ServiceHandler(self, self.__service_factory()) - - @property - def service_factory(self): - """Service factory associated with this factory.""" - return self.__service_factory - - @staticmethod - def __check_path(path): - """Check a path parameter. - - Make sure a provided path parameter is compatible with the - webapp URL mapping. - - Args: - path: Path to check. This is a plain path, not a regular expression. - - Raises: - ValueError if path does not start with /, path ends with /. - """ - if path.endswith('/'): - raise ValueError('Path %s must not end with /.' % path) - - def mapping(self, path): - """Convenience method to map service to application. - - Args: - path: Path to map service to. It must be a simple path - with a leading / and no trailing /. - - Returns: - Mapping from service URL to service handler factory. - """ - self.__check_path(path) - - service_url_pattern = r'(%s)%s' % (path, _METHOD_PATTERN) - - return service_url_pattern, self - - @classmethod - def default(cls, service_factory, parameter_prefix=''): - """Convenience method to map default factory configuration to application. - - Creates a standardized default service factory configuration that pre-maps - the URL encoded protocol handler to the factory. - - Args: - service_factory: Service factory to instantiate and provide to - service handler. - method_parameter: The name of the form parameter used to determine the - method to invoke used by the URLEncodedRPCMapper. If None, no - parameter is used and the mapper will only match against the form - path-name. Defaults to 'method'. - parameter_prefix: If provided, all the parameters in the form are - expected to begin with that prefix by the URLEncodedRPCMapper. - - Returns: - Mapping from service URL to service handler factory. - """ - factory = cls(service_factory) - - factory.add_request_mapper(ProtobufRPCMapper()) - factory.add_request_mapper(JSONRPCMapper()) - - return factory - - -class ServiceHandler(webapp.RequestHandler): - """Web handler for RPC service. - - Overridden methods: - get: All requests handled by 'handle' method. HTTP method stored in - attribute. Takes remote_method parameter as derived from the URL mapping. - post: All requests handled by 'handle' method. HTTP method stored in - attribute. Takes remote_method parameter as derived from the URL mapping. - redirect: Not implemented for this service handler. - - New methods: - handle: Handle request for both GET and POST. - - Attributes (in addition to attributes in RequestHandler): - service: Service instance associated with request being handled. - method: Method of request. Used by RPCMapper to determine match. - remote_method: Sub-path as provided to the 'get' and 'post' methods. - """ - - def __init__(self, factory, service): - """Constructor. - - Args: - factory: Instance of ServiceFactory used for constructing new service - instances used for handling requests. - service: Service instance used for handling RPC. - """ - self.__factory = factory - self.__service = service - - @property - def service(self): - return self.__service - - def __show_info(self, service_path, remote_method): - self.response.headers['content-type'] = 'text/plain; charset=utf-8' - response_message = [] - if remote_method: - response_message.append('%s.%s is a ProtoRPC method.\n\n' %( - service_path, remote_method)) - else: - response_message.append('%s is a ProtoRPC service.\n\n' % service_path) - definition_name_function = getattr(self.__service, 'definition_name', None) - if definition_name_function: - definition_name = definition_name_function() - else: - definition_name = '%s.%s' % (self.__service.__module__, - self.__service.__class__.__name__) - - response_message.append('Service %s\n\n' % definition_name) - response_message.append('More about ProtoRPC: ') - - response_message.append('http://code.google.com/p/google-protorpc\n') - self.response.out.write(util.pad_string(''.join(response_message))) - - def get(self, service_path, remote_method): - """Handler method for GET requests. - - Args: - service_path: Service path derived from request URL. - remote_method: Sub-path after service path has been matched. - """ - self.handle('GET', service_path, remote_method) - - def post(self, service_path, remote_method): - """Handler method for POST requests. - - Args: - service_path: Service path derived from request URL. - remote_method: Sub-path after service path has been matched. - """ - self.handle('POST', service_path, remote_method) - - def redirect(self, uri, permanent=False): - """Not supported for services.""" - raise NotImplementedError('Services do not currently support redirection.') - - def __send_error(self, - http_code, - status_state, - error_message, - mapper, - error_name=None): - status = remote.RpcStatus(state=status_state, - error_message=error_message, - error_name=error_name) - mapper.build_response(self, status) - self.response.headers['content-type'] = mapper.default_content_type - - logging.error(error_message) - response_content = self.response.out.getvalue() - padding = ' ' * max(0, 512 - len(response_content)) - self.response.out.write(padding) - - self.response.set_status(http_code, error_message) - - def __send_simple_error(self, code, message, pad=True): - """Send error to caller without embedded message.""" - self.response.headers['content-type'] = 'text/plain; charset=utf-8' - logging.error(message) - self.response.set_status(code, message) - - response_message = six.moves.http_client.responses.get(code, 'Unknown Error') - if pad: - response_message = util.pad_string(response_message) - self.response.out.write(response_message) - - def __get_content_type(self): - content_type = self.request.headers.get('content-type', None) - if not content_type: - content_type = self.request.environ.get('HTTP_CONTENT_TYPE', None) - if not content_type: - return None - - # Lop off parameters from the end (for example content-encoding) - return content_type.split(';', 1)[0].lower() - - def __headers(self, content_type): - for name in self.request.headers: - name = name.lower() - if name == 'content-type': - value = content_type - elif name == 'content-length': - value = str(len(self.request.body)) - else: - value = self.request.headers.get(name, '') - yield name, value - - def handle(self, http_method, service_path, remote_method): - """Handle a service request. - - The handle method will handle either a GET or POST response. - It is up to the individual mappers from the handler factory to determine - which request methods they can service. - - If the protocol is not recognized, the request does not provide a correct - request for that protocol or the service object does not support the - requested RPC method, will return error code 400 in the response. - - Args: - http_method: HTTP method of request. - service_path: Service path derived from request URL. - remote_method: Sub-path after service path has been matched. - """ - self.response.headers['x-content-type-options'] = 'nosniff' - if not remote_method and http_method == 'GET': - # Special case a normal get request, presumably via a browser. - self.error(405) - self.__show_info(service_path, remote_method) - return - - content_type = self.__get_content_type() - - # Provide server state to the service. If the service object does not have - # an "initialize_request_state" method, will not attempt to assign state. - try: - state_initializer = self.service.initialize_request_state - except AttributeError: - pass - else: - server_port = self.request.environ.get('SERVER_PORT', None) - if server_port: - server_port = int(server_port) - - request_state = remote.HttpRequestState( - remote_host=self.request.environ.get('REMOTE_HOST', None), - remote_address=self.request.environ.get('REMOTE_ADDR', None), - server_host=self.request.environ.get('SERVER_HOST', None), - server_port=server_port, - http_method=http_method, - service_path=service_path, - headers=list(self.__headers(content_type))) - state_initializer(request_state) - - if not content_type: - self.__send_simple_error(400, 'Invalid RPC request: missing content-type') - return - - # Search for mapper to mediate request. - for mapper in self.__factory.all_request_mappers(): - if content_type in mapper.content_types: - break - else: - if http_method == 'GET': - self.error(six.moves.http_client.UNSUPPORTED_MEDIA_TYPE) - self.__show_info(service_path, remote_method) - else: - self.__send_simple_error(six.moves.http_client.UNSUPPORTED_MEDIA_TYPE, - 'Unsupported content-type: %s' % content_type) - return - - try: - if http_method not in mapper.http_methods: - if http_method == 'GET': - self.error(six.moves.http_client.METHOD_NOT_ALLOWED) - self.__show_info(service_path, remote_method) - else: - self.__send_simple_error(six.moves.http_client.METHOD_NOT_ALLOWED, - 'Unsupported HTTP method: %s' % http_method) - return - - try: - try: - method = getattr(self.service, remote_method) - method_info = method.remote - except AttributeError as err: - self.__send_error( - 400, remote.RpcState.METHOD_NOT_FOUND_ERROR, - 'Unrecognized RPC method: %s' % remote_method, - mapper) - return - - request = mapper.build_request(self, method_info.request_type) - except (RequestError, messages.DecodeError) as err: - self.__send_error(400, - remote.RpcState.REQUEST_ERROR, - 'Error parsing ProtoRPC request (%s)' % err, - mapper) - return - - try: - response = method(request) - except remote.ApplicationError as err: - self.__send_error(400, - remote.RpcState.APPLICATION_ERROR, - unicode(err), - mapper, - err.error_name) - return - - mapper.build_response(self, response) - except Exception as err: - logging.error('An unexpected error occured when handling RPC: %s', - err, exc_info=1) - - self.__send_error(500, - remote.RpcState.SERVER_ERROR, - 'Internal Server Error', - mapper) - return - - -# TODO(rafek): Support tag-id only forms. -class URLEncodedRPCMapper(RPCMapper): - """Request mapper for application/x-www-form-urlencoded forms. - - This mapper is useful for building forms that can invoke RPC. Many services - are also configured to work using URL encoded request information because - of its perceived ease of programming and debugging. - - The mapper must be provided with at least method_parameter or - remote_method_pattern so that it is possible to determine how to determine the - requests RPC method. If both are provided, the service will respond to both - method request types, however, only one may be present in a given request. - If both types are detected, the request will not match. - """ - - def __init__(self, parameter_prefix=''): - """Constructor. - - Args: - parameter_prefix: If provided, all the parameters in the form are - expected to begin with that prefix. - """ - # Private attributes: - # __parameter_prefix: parameter prefix as provided by constructor - # parameter. - super(URLEncodedRPCMapper, self).__init__(['POST'], - _URLENCODED_CONTENT_TYPE, - self) - self.__parameter_prefix = parameter_prefix - - def encode_message(self, message): - """Encode a message using parameter prefix. - - Args: - message: Message to URL Encode. - - Returns: - URL encoded message. - """ - return protourlencode.encode_message(message, - prefix=self.__parameter_prefix) - - @property - def parameter_prefix(self): - """Prefix all form parameters are expected to begin with.""" - return self.__parameter_prefix - - def build_request(self, handler, request_type): - """Build request from URL encoded HTTP request. - - Constructs message from names of URL encoded parameters. If this service - handler has a parameter prefix, parameters must begin with it or are - ignored. - - Args: - handler: RequestHandler instance that is servicing request. - request_type: Message type to build. - - Returns: - Instance of request_type populated by protocol buffer in request - parameters. - - Raises: - RequestError if message type contains nested message field or repeated - message field. Will raise RequestError if there are any repeated - parameters. - """ - request = request_type() - builder = protourlencode.URLEncodedRequestBuilder( - request, prefix=self.__parameter_prefix) - for argument in sorted(handler.request.arguments()): - values = handler.request.get_all(argument) - try: - builder.add_parameter(argument, values) - except messages.DecodeError as err: - raise RequestError(str(err)) - return request - - -class ProtobufRPCMapper(RPCMapper): - """Request mapper for application/x-protobuf service requests. - - This mapper will parse protocol buffer from a POST body and return the request - as a protocol buffer. - """ - - def __init__(self): - super(ProtobufRPCMapper, self).__init__(['POST'], - _PROTOBUF_CONTENT_TYPE, - protobuf) - - -class JSONRPCMapper(RPCMapper): - """Request mapper for application/x-protobuf service requests. - - This mapper will parse protocol buffer from a POST body and return the request - as a protocol buffer. - """ - - def __init__(self): - super(JSONRPCMapper, self).__init__( - ['POST'], - _JSON_CONTENT_TYPE, - protojson, - content_types=_EXTRA_JSON_CONTENT_TYPES) - - -def service_mapping(services, - registry_path=DEFAULT_REGISTRY_PATH): - """Create a services mapping for use with webapp. - - Creates basic default configuration and registration for ProtoRPC services. - Each service listed in the service mapping has a standard service handler - factory created for it. - - The list of mappings can either be an explicit path to service mapping or - just services. If mappings are just services, they will automatically - be mapped to their default name. For exampel: - - package = 'my_package' - - class MyService(remote.Service): - ... - - server_mapping([('/my_path', MyService), # Maps to /my_path - MyService, # Maps to /my_package/MyService - ]) - - Specifying a service mapping: - - Normally services are mapped to URL paths by specifying a tuple - (path, service): - path: The path the service resides on. - service: The service class or service factory for creating new instances - of the service. For more information about service factories, please - see remote.Service.new_factory. - - If no tuple is provided, and therefore no path specified, a default path - is calculated by using the fully qualified service name using a URL path - separator for each of its components instead of a '.'. - - Args: - services: Can be service type, service factory or string definition name of - service being mapped or list of tuples (path, service): - path: Path on server to map service to. - service: Service type, service factory or string definition name of - service being mapped. - Can also be a dict. If so, the keys are treated as the path and values as - the service. - registry_path: Path to give to registry service. Use None to disable - registry service. - - Returns: - List of tuples defining a mapping of request handlers compatible with a - webapp application. - - Raises: - ServiceConfigurationError when duplicate paths are provided. - """ - if isinstance(services, dict): - services = six.iteritems(services) - mapping = [] - registry_map = {} - - if registry_path is not None: - registry_service = registry.RegistryService.new_factory(registry_map) - services = list(services) + [(registry_path, registry_service)] - mapping.append((registry_path + r'/form(?:/)?', - forms.FormsHandler.new_factory(registry_path))) - mapping.append((registry_path + r'/form/(.+)', forms.ResourceHandler)) - - paths = set() - for service_item in services: - infer_path = not isinstance(service_item, (list, tuple)) - if infer_path: - service = service_item - else: - service = service_item[1] - - service_class = getattr(service, 'service_class', service) - - if infer_path: - path = '/' + service_class.definition_name().replace('.', '/') - else: - path = service_item[0] - - if path in paths: - raise ServiceConfigurationError( - 'Path %r is already defined in service mapping' % path.encode('utf-8')) - else: - paths.add(path) - - # Create service mapping for webapp. - new_mapping = ServiceHandlerFactory.default(service).mapping(path) - mapping.append(new_mapping) - - # Update registry with service class. - registry_map[path] = service_class - - return mapping - - -def run_services(services, - registry_path=DEFAULT_REGISTRY_PATH): - """Handle CGI request using service mapping. - - Args: - Same as service_mapping. - """ - mappings = service_mapping(services, registry_path=registry_path) - application = webapp.WSGIApplication(mappings) - webapp_util.run_wsgi_app(application) From ab0a311ccb63534d69a9ec86ff32bd6533e2e359 Mon Sep 17 00:00:00 2001 From: Rose Davidson Date: Thu, 26 Jul 2018 17:11:13 -0700 Subject: [PATCH 4/6] Ensure protorpc message types keep existing names. --- endpoints/internal/protorpc/message_types.py | 3 +++ endpoints/internal/protorpc/messages.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/endpoints/internal/protorpc/message_types.py b/endpoints/internal/protorpc/message_types.py index f707d52..a1474b7 100644 --- a/endpoints/internal/protorpc/message_types.py +++ b/endpoints/internal/protorpc/message_types.py @@ -34,6 +34,9 @@ 'VoidMessage', ] +package = 'protorpc.message_types' + + class VoidMessage(messages.Message): """Empty message.""" diff --git a/endpoints/internal/protorpc/messages.py b/endpoints/internal/protorpc/messages.py index 024039a..f86ed48 100644 --- a/endpoints/internal/protorpc/messages.py +++ b/endpoints/internal/protorpc/messages.py @@ -83,6 +83,8 @@ 'DefinitionNotFoundError', ] +package = 'protorpc.messages' + # TODO(rafek): Add extended module test to ensure all exceptions # in services extends Error. From 6044e4975dcd101f31ad1c381271726d448ca4fb Mon Sep 17 00:00:00 2001 From: Rose Davidson Date: Fri, 27 Jul 2018 15:43:55 -0700 Subject: [PATCH 5/6] Use included protorpc library instead of an external version. --- endpoints/__init__.py | 6 +++--- endpoints/_endpointscfg_impl.py | 2 +- endpoints/api_config.py | 2 +- endpoints/apiserving.py | 6 +++++- endpoints/protojson.py | 2 +- endpoints/test/apiserving_test.py | 19 +++++++++++++++++++ test-requirements.txt | 2 +- 7 files changed, 31 insertions(+), 8 deletions(-) diff --git a/endpoints/__init__.py b/endpoints/__init__.py index d811e9f..06e5852 100644 --- a/endpoints/__init__.py +++ b/endpoints/__init__.py @@ -20,9 +20,9 @@ # pylint: disable=wildcard-import from __future__ import absolute_import -from protorpc import message_types -from protorpc import messages -from protorpc import remote +from .internal.protorpc import message_types +from .internal.protorpc import messages +from .internal.protorpc import remote from .api_config import api, method from .api_config import AUTH_LEVEL, EMAIL_SCOPE diff --git a/endpoints/_endpointscfg_impl.py b/endpoints/_endpointscfg_impl.py index e5d97ba..4e8c72d 100644 --- a/endpoints/_endpointscfg_impl.py +++ b/endpoints/_endpointscfg_impl.py @@ -188,7 +188,7 @@ def GenApiConfig(service_class_names, config_string_generator=None, resolved_services.extend(service.get_api_classes()) elif (not isinstance(service, type) or not issubclass(service, remote.Service)): - raise TypeError('%s is not a ProtoRPC service' % service_class_name) + raise TypeError('%s is not a subclass of endpoints.remote.Service' % service_class_name) else: resolved_services.append(service) diff --git a/endpoints/api_config.py b/endpoints/api_config.py index 93b9113..e27ec0f 100644 --- a/endpoints/api_config.py +++ b/endpoints/api_config.py @@ -43,7 +43,7 @@ def entries_get(self, request): import attr import semver -from protorpc import util +from .internal.protorpc import util from . import api_exceptions from . import constants diff --git a/endpoints/apiserving.py b/endpoints/apiserving.py index fd2776f..737ad58 100644 --- a/endpoints/apiserving.py +++ b/endpoints/apiserving.py @@ -70,7 +70,7 @@ def list(self, request): from endpoints_management.control import client as control_client from endpoints_management.control import wsgi as control_wsgi -from protorpc.wsgi import service as wsgi_service +from .internal.protorpc.wsgi import service as wsgi_service from . import api_config from . import api_exceptions @@ -564,6 +564,10 @@ def api_server(api_services, **kwargs): if 'protocols' in kwargs: raise TypeError("__init__() got an unexpected keyword argument 'protocols'") + for service in api_services: + if not issubclass(service, remote.Service): + raise TypeError('%s is not a subclass of endpoints.remote.Service' % service) + # Construct the api serving app apis_app = _ApiServer(api_services, **kwargs) dispatcher = endpoints_dispatcher.EndpointsDispatcherMiddleware(apis_app) diff --git a/endpoints/protojson.py b/endpoints/protojson.py index 83658db..55a9924 100644 --- a/endpoints/protojson.py +++ b/endpoints/protojson.py @@ -17,7 +17,7 @@ import base64 -from protorpc import protojson +from .internal.protorpc import protojson from . import messages diff --git a/endpoints/test/apiserving_test.py b/endpoints/test/apiserving_test.py index 56cf226..a9eb137 100644 --- a/endpoints/test/apiserving_test.py +++ b/endpoints/test/apiserving_test.py @@ -28,6 +28,7 @@ import urllib2 import mock +import pytest import test_util import webtest from endpoints import api_config @@ -38,6 +39,8 @@ from endpoints import remote from endpoints import resource_container +from protorpc import remote as nonbundled_remote + package = 'endpoints.test' @@ -362,5 +365,21 @@ def testGetApiConfigs(self): self.assertEqual(TEST_SERVICE_CUSTOM_URL_API_CONFIG, configs) +@api_config.api(name='testapi', version='v3', description='A wonderful API.') +class TestNonbundledService(nonbundled_remote.Service): + + @api_config.method(test_request, + message_types.VoidMessage, + http_method='DELETE', path='items/{id}') + # Silence lint warning about method naming conventions + # pylint: disable=g-bad-name + def delete(self, unused_request): + return message_types.VoidMessage() + + +def test_nonbundled_service_error(): + with pytest.raises(TypeError): + apiserving.api_server([TestNonbundledService]) + if __name__ == '__main__': unittest.main() diff --git a/test-requirements.txt b/test-requirements.txt index 3d07a84..6547b5d 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -4,6 +4,6 @@ pytest>=2.8.3 pytest-cov>=1.8.1 pytest-timeout>=1.0.0 webtest>=2.0.23,<3.0 -git+git://github.com/inklesspen/protorpc.git@endpoints-dependency#egg=protorpc-0.12.0a0 +protorpc>=0.12.0 protobuf>=3.0.0b3 PyYAML==3.12 From f62237a0d8268586d248ee711b386e9816327932 Mon Sep 17 00:00:00 2001 From: Rose Davidson Date: Fri, 27 Jul 2018 17:06:09 -0700 Subject: [PATCH 6/6] Rename internal to bundled --- endpoints/__init__.py | 6 +++--- endpoints/api_config.py | 2 +- endpoints/apiserving.py | 2 +- endpoints/{internal => bundled}/__init__.py | 0 endpoints/{internal => bundled}/protorpc/__init__.py | 0 endpoints/{internal => bundled}/protorpc/definition.py | 0 endpoints/{internal => bundled}/protorpc/descriptor.py | 0 endpoints/{internal => bundled}/protorpc/google_imports.py | 0 endpoints/{internal => bundled}/protorpc/message_types.py | 0 endpoints/{internal => bundled}/protorpc/messages.py | 0 endpoints/{internal => bundled}/protorpc/non_sdk_imports.py | 0 endpoints/{internal => bundled}/protorpc/protobuf.py | 0 endpoints/{internal => bundled}/protorpc/protojson.py | 0 .../{internal => bundled}/protorpc/protorpc_test.proto | 0 endpoints/{internal => bundled}/protorpc/protourlencode.py | 0 endpoints/{internal => bundled}/protorpc/registry.py | 0 endpoints/{internal => bundled}/protorpc/remote.py | 0 endpoints/{internal => bundled}/protorpc/transport.py | 0 endpoints/{internal => bundled}/protorpc/util.py | 0 endpoints/{internal => bundled}/protorpc/wsgi/__init__.py | 0 endpoints/{internal => bundled}/protorpc/wsgi/service.py | 0 endpoints/{internal => bundled}/protorpc/wsgi/util.py | 0 endpoints/protojson.py | 2 +- 23 files changed, 6 insertions(+), 6 deletions(-) rename endpoints/{internal => bundled}/__init__.py (100%) rename endpoints/{internal => bundled}/protorpc/__init__.py (100%) rename endpoints/{internal => bundled}/protorpc/definition.py (100%) rename endpoints/{internal => bundled}/protorpc/descriptor.py (100%) rename endpoints/{internal => bundled}/protorpc/google_imports.py (100%) rename endpoints/{internal => bundled}/protorpc/message_types.py (100%) rename endpoints/{internal => bundled}/protorpc/messages.py (100%) rename endpoints/{internal => bundled}/protorpc/non_sdk_imports.py (100%) rename endpoints/{internal => bundled}/protorpc/protobuf.py (100%) rename endpoints/{internal => bundled}/protorpc/protojson.py (100%) rename endpoints/{internal => bundled}/protorpc/protorpc_test.proto (100%) rename endpoints/{internal => bundled}/protorpc/protourlencode.py (100%) rename endpoints/{internal => bundled}/protorpc/registry.py (100%) rename endpoints/{internal => bundled}/protorpc/remote.py (100%) rename endpoints/{internal => bundled}/protorpc/transport.py (100%) rename endpoints/{internal => bundled}/protorpc/util.py (100%) rename endpoints/{internal => bundled}/protorpc/wsgi/__init__.py (100%) rename endpoints/{internal => bundled}/protorpc/wsgi/service.py (100%) rename endpoints/{internal => bundled}/protorpc/wsgi/util.py (100%) diff --git a/endpoints/__init__.py b/endpoints/__init__.py index 06e5852..bc54a84 100644 --- a/endpoints/__init__.py +++ b/endpoints/__init__.py @@ -20,9 +20,9 @@ # pylint: disable=wildcard-import from __future__ import absolute_import -from .internal.protorpc import message_types -from .internal.protorpc import messages -from .internal.protorpc import remote +from .bundled.protorpc import message_types +from .bundled.protorpc import messages +from .bundled.protorpc import remote from .api_config import api, method from .api_config import AUTH_LEVEL, EMAIL_SCOPE diff --git a/endpoints/api_config.py b/endpoints/api_config.py index e27ec0f..d753805 100644 --- a/endpoints/api_config.py +++ b/endpoints/api_config.py @@ -43,7 +43,7 @@ def entries_get(self, request): import attr import semver -from .internal.protorpc import util +from .bundled.protorpc import util from . import api_exceptions from . import constants diff --git a/endpoints/apiserving.py b/endpoints/apiserving.py index 737ad58..5ec941f 100644 --- a/endpoints/apiserving.py +++ b/endpoints/apiserving.py @@ -70,7 +70,7 @@ def list(self, request): from endpoints_management.control import client as control_client from endpoints_management.control import wsgi as control_wsgi -from .internal.protorpc.wsgi import service as wsgi_service +from .bundled.protorpc.wsgi import service as wsgi_service from . import api_config from . import api_exceptions diff --git a/endpoints/internal/__init__.py b/endpoints/bundled/__init__.py similarity index 100% rename from endpoints/internal/__init__.py rename to endpoints/bundled/__init__.py diff --git a/endpoints/internal/protorpc/__init__.py b/endpoints/bundled/protorpc/__init__.py similarity index 100% rename from endpoints/internal/protorpc/__init__.py rename to endpoints/bundled/protorpc/__init__.py diff --git a/endpoints/internal/protorpc/definition.py b/endpoints/bundled/protorpc/definition.py similarity index 100% rename from endpoints/internal/protorpc/definition.py rename to endpoints/bundled/protorpc/definition.py diff --git a/endpoints/internal/protorpc/descriptor.py b/endpoints/bundled/protorpc/descriptor.py similarity index 100% rename from endpoints/internal/protorpc/descriptor.py rename to endpoints/bundled/protorpc/descriptor.py diff --git a/endpoints/internal/protorpc/google_imports.py b/endpoints/bundled/protorpc/google_imports.py similarity index 100% rename from endpoints/internal/protorpc/google_imports.py rename to endpoints/bundled/protorpc/google_imports.py diff --git a/endpoints/internal/protorpc/message_types.py b/endpoints/bundled/protorpc/message_types.py similarity index 100% rename from endpoints/internal/protorpc/message_types.py rename to endpoints/bundled/protorpc/message_types.py diff --git a/endpoints/internal/protorpc/messages.py b/endpoints/bundled/protorpc/messages.py similarity index 100% rename from endpoints/internal/protorpc/messages.py rename to endpoints/bundled/protorpc/messages.py diff --git a/endpoints/internal/protorpc/non_sdk_imports.py b/endpoints/bundled/protorpc/non_sdk_imports.py similarity index 100% rename from endpoints/internal/protorpc/non_sdk_imports.py rename to endpoints/bundled/protorpc/non_sdk_imports.py diff --git a/endpoints/internal/protorpc/protobuf.py b/endpoints/bundled/protorpc/protobuf.py similarity index 100% rename from endpoints/internal/protorpc/protobuf.py rename to endpoints/bundled/protorpc/protobuf.py diff --git a/endpoints/internal/protorpc/protojson.py b/endpoints/bundled/protorpc/protojson.py similarity index 100% rename from endpoints/internal/protorpc/protojson.py rename to endpoints/bundled/protorpc/protojson.py diff --git a/endpoints/internal/protorpc/protorpc_test.proto b/endpoints/bundled/protorpc/protorpc_test.proto similarity index 100% rename from endpoints/internal/protorpc/protorpc_test.proto rename to endpoints/bundled/protorpc/protorpc_test.proto diff --git a/endpoints/internal/protorpc/protourlencode.py b/endpoints/bundled/protorpc/protourlencode.py similarity index 100% rename from endpoints/internal/protorpc/protourlencode.py rename to endpoints/bundled/protorpc/protourlencode.py diff --git a/endpoints/internal/protorpc/registry.py b/endpoints/bundled/protorpc/registry.py similarity index 100% rename from endpoints/internal/protorpc/registry.py rename to endpoints/bundled/protorpc/registry.py diff --git a/endpoints/internal/protorpc/remote.py b/endpoints/bundled/protorpc/remote.py similarity index 100% rename from endpoints/internal/protorpc/remote.py rename to endpoints/bundled/protorpc/remote.py diff --git a/endpoints/internal/protorpc/transport.py b/endpoints/bundled/protorpc/transport.py similarity index 100% rename from endpoints/internal/protorpc/transport.py rename to endpoints/bundled/protorpc/transport.py diff --git a/endpoints/internal/protorpc/util.py b/endpoints/bundled/protorpc/util.py similarity index 100% rename from endpoints/internal/protorpc/util.py rename to endpoints/bundled/protorpc/util.py diff --git a/endpoints/internal/protorpc/wsgi/__init__.py b/endpoints/bundled/protorpc/wsgi/__init__.py similarity index 100% rename from endpoints/internal/protorpc/wsgi/__init__.py rename to endpoints/bundled/protorpc/wsgi/__init__.py diff --git a/endpoints/internal/protorpc/wsgi/service.py b/endpoints/bundled/protorpc/wsgi/service.py similarity index 100% rename from endpoints/internal/protorpc/wsgi/service.py rename to endpoints/bundled/protorpc/wsgi/service.py diff --git a/endpoints/internal/protorpc/wsgi/util.py b/endpoints/bundled/protorpc/wsgi/util.py similarity index 100% rename from endpoints/internal/protorpc/wsgi/util.py rename to endpoints/bundled/protorpc/wsgi/util.py diff --git a/endpoints/protojson.py b/endpoints/protojson.py index 55a9924..9d36e03 100644 --- a/endpoints/protojson.py +++ b/endpoints/protojson.py @@ -17,7 +17,7 @@ import base64 -from .internal.protorpc import protojson +from .bundled.protorpc import protojson from . import messages