Skip to content

Commit

Permalink
WIP on CORS support (#10);
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Popravka committed Apr 23, 2014
1 parent 30ad6c8 commit 24fa72c
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions aiorest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _parse_version(ver):
version_info = _parse_version(__version__)


Entry = collections.namedtuple('Entry', 'regex method handler use_request')
Entry = collections.namedtuple('Entry', 'regex method handler use_request'
' check_cors cors_options')


class Response:
Expand Down Expand Up @@ -252,9 +253,8 @@ def handle_request(self, message, payload):
resp_impl.send_headers()
resp_impl.write(bbody)
resp_impl.write_eof()
## if resp_impl.keep_alive():
## print("KEEP ALIVE")
## self.keep_alive(True)
if resp_impl.keep_alive():
self.keep_alive(True)

#self.log.debug("Fihish handle request %r at %d -> %s",
# message, time.time(), body)
Expand All @@ -272,7 +272,16 @@ class RESTServer:

METHODS = {'POST', 'GET', 'PUT', 'DELETE', 'PATCH', 'HEAD'}

def __init__(self, *, hostname, session_factory=None, loop=None, **kwargs):
CORS_OPTIONS = {
'allow-origin': lambda req: req.headers.get('ORIGIN', '*'),
'allow-credentials': False,
'allow-headers': None,
'expose-headers': None,
'max-age': 86400,
}

def __init__(self, *, hostname, session_factory=None,
enable_cors=False, loop=None, **kwargs):
assert session_factory is None or callable(session_factory), \
"session_factory must be None or callable (coroutine) function"
if loop is None:
Expand All @@ -281,6 +290,7 @@ def __init__(self, *, hostname, session_factory=None, loop=None, **kwargs):
super().__init__()
self.hostname = hostname
self.session_factory = session_factory
self._enable_cors = enable_cors
self._kwargs = kwargs
self._urls = []

Expand All @@ -290,9 +300,17 @@ def make_handler(self):
loop=self._loop,
**self._kwargs)

def add_url(self, method, path, handler, use_request=False):
@property
def cors_enabled(self):
return self._enable_cors

def add_url(self, method, path, handler, use_request=False,
check_cors=True, cors_options={}):
"""XXX"""
assert callable(handler), handler
assert not set(cors_options) - set(self.CORS_OPTIONS), \
'Got bad CORS options: {}'.format(
set(cors_options) - set(self.CORS_OPTIONS))
if isinstance(handler, MethodType):
holder = handler.__func__
else:
Expand Down Expand Up @@ -334,19 +352,27 @@ def add_url(self, method, path, handler, use_request=False):
compiled = re.compile('^' + pattern + '$')
except re.error:
raise ValueError("Invalid path '{}'".format(path))
self._urls.append(Entry(compiled, method, handler, use_request))
self._urls.append(Entry(compiled, method, handler, use_request,
check_cors, cors_options))

@asyncio.coroutine
def dispatch(self, request):
path = request.path
method = request.method
allowed_methods = set()
check_cors = False
if method == 'OPTIONS' and self.cors_enabled:
check_cors = True
method = request.headers['ACCESS-CONTROL-REQUEST-METHOD']
for entry in self._urls:
match = entry.regex.match(path)
if match is None:
continue
if entry.method != method:
allowed_methods.add(entry.method)
elif check_cors and entry.check_cors:
# yield from self._handle_cors_check(request, entry.cors_options)
return
else:
break
else:
Expand Down

0 comments on commit 24fa72c

Please sign in to comment.