Skip to content

Commit

Permalink
[resotocore][feat] Allow authorization message as first ws message (#…
Browse files Browse the repository at this point in the history
…1308)

* [resotocore][feat] Allow authorization message as first ws message

* define handler as variable and avoid lookup

* oops

* add comment

* define separate groups to make the intent more clear
  • Loading branch information
aquamatthias authored Nov 24, 2022
1 parent 589b0d2 commit e044fa8
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 48 deletions.
9 changes: 8 additions & 1 deletion resotocore/resotocore/static/api-doc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1441,12 +1441,19 @@ paths:
/events:
get:
summary: "[WebSocket] Register as event listener and receive all events."
description:
description: |
## WebSocket Endpoint
The client needs to send all the required headers for a ws connection
and has to handle the websocket protocol.<br/>
**Note this can not be tested from within swagger!**
## Authorization
In case Resoto has a PSK infrastructure in place, the client needs to send a JWT token via the `Authorization` header
or via the `resoto_authorization` cookie.
It is also possible to omit header or cookie and send an Authorization message as first message on the websocket.
Example
{ "kind": "authorization", "jwt": "Bearer <jwt>" }
parameters:
- name: show
in: query
Expand Down
5 changes: 5 additions & 0 deletions resotocore/resotocore/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@
Periodic = periodic.Periodic


# noinspection PyUnusedLocal
async def async_noop(*args: Any, **kwargs: Any) -> None:
pass


def identity(o: AnyT) -> AnyT:
return o

Expand Down
56 changes: 32 additions & 24 deletions resotocore/resotocore/web/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from resotocore.task.subscribers import SubscriptionHandler
from resotocore.task.task_handler import TaskHandlerService
from resotocore.types import Json, JsonElement
from resotocore.util import uuid_str, force_gen, rnd_str, if_set, duration, utc_str, parse_utc
from resotocore.util import uuid_str, force_gen, rnd_str, if_set, duration, utc_str, parse_utc, async_noop
from resotocore.web.certificate_handler import CertificateHandler
from resotocore.web.content_renderer import result_binary_gen, single_result
from resotocore.web.directives import (
Expand All @@ -92,7 +92,7 @@
WorkerTaskResult,
WorkerTaskInProgress,
)
from resotolib.asynchronous.web.auth import auth_handler
from resotolib.asynchronous.web.auth import auth_handler, set_valid_jwt, raw_jwt_from_auth_message
from resotolib.asynchronous.web.ws_handler import accept_websocket, clean_ws_handler
from resotolib.jwt import encode_jwt

Expand All @@ -106,7 +106,10 @@ def section_of(request: Request) -> Optional[str]:
return section


# No Authorization required for following paths
AlwaysAllowed = {"/", "/metrics", "/api-doc.*", "/system/.*", "/ui.*", "/ca/cert", "/notebook.*"}
# Authorization is not required, but implemented as part of the request handler
DeferredCheck = {"/events"}


class Api:
Expand Down Expand Up @@ -141,7 +144,7 @@ def __init__(
# note on order: the middleware is passed in the order provided.
middlewares=[
metrics_handler,
auth_handler(config.args.psk, AlwaysAllowed),
auth_handler(config.args.psk, AlwaysAllowed | DeferredCheck),
cors_handler,
error_handler(config, event_sender),
default_middleware(self),
Expand Down Expand Up @@ -222,7 +225,6 @@ def __add_routes(self, prefix: str) -> None:
web.post(prefix + "/analytics", self.send_analytics_events),
# Worker operations
web.get(prefix + "/work/queue", self.handle_work_tasks),
web.get(prefix + "/work/create", self.create_work),
web.get(prefix + "/work/list", self.list_work),
# Serve static filed
web.get(prefix, self.forward("/ui/index.html")),
Expand Down Expand Up @@ -456,6 +458,15 @@ async def listen_to_events(
event_types: List[str],
initial_messages: Optional[Sequence[Message]] = None,
) -> WebSocketResponse:
handler: Callable[[str], Awaitable[None]] = async_noop

async def authorize_request(msg: str) -> None:
nonlocal handler
if (r := raw_jwt_from_auth_message(msg)) and set_valid_jwt(request, r, self.config.args.psk) is not None:
handler = handle_message
else:
raise ValueError("No Authorization header provided and no valid auth message sent")

async def handle_message(msg: str) -> None:
js = json.loads(msg)
if "data" in js:
Expand All @@ -475,32 +486,38 @@ async def handle_message(msg: str) -> None:
else:
await self.message_bus.emit(message)

handler = authorize_request if request.get("authorized", False) is False else handle_message
return await accept_websocket(
request,
handle_incoming=handle_message,
handle_incoming=lambda x: handler(x), # pylint: disable=unnecessary-lambda # it is required!
outgoing_context=partial(self.message_bus.subscribe, listener_id, event_types),
websocket_handler=self.websocket_handler,
initial_messages=initial_messages,
)

async def handle_work_tasks(self, request: Request) -> WebSocketResponse:
worker_id = WorkerId(uuid_str())
initialized = False
worker_descriptions: Future[List[WorkerTaskDescription]] = asyncio.get_event_loop().create_future()
handler: Callable[[str], Awaitable[None]] = async_noop

async def authorize_request(msg: str) -> None:
nonlocal handler
if (r := raw_jwt_from_auth_message(msg)) and set_valid_jwt(request, r, self.config.args.psk) is not None:
handler = handle_connect
else:
raise ValueError("No Authorization header provided and no valid auth message sent")

async def handle_connect(msg: str) -> None:
nonlocal initialized
nonlocal handler
cmds = from_js(json.loads(msg), List[WorkerCustomCommand])
print("connected: ", cmds)

description = [WorkerTaskDescription(cmd.name, cmd.filter) for cmd in cmds]
# set the future and allow attaching the worker to the task queue
worker_descriptions.set_result(description)
# register the descriptions as custom command on the CLI
for cmd in cmds:
self.cli.register_worker_custom_command(cmd)
# mark the worker as initialized
initialized = True
# the connect process is done, define the final handler
handler = handle_message

async def handle_message(msg: str) -> None:
tr = from_js(json.loads(msg), WorkerTaskResult)
Expand All @@ -523,26 +540,17 @@ async def connect_to_task_queue() -> AsyncIterator[Queue[WorkerTask]]:
async with self.worker_task_queue.attach(worker_id, descriptions) as queue:
yield queue

handler = authorize_request if request.get("authorized", False) is False else handle_connect
# noinspection PyTypeChecker
return await accept_websocket(
request,
handle_incoming=lambda msg: handle_connect(msg) if not initialized else handle_message(msg),
handle_incoming=lambda x: handler(x), # pylint: disable=unnecessary-lambda # it is required!
outgoing_context=connect_to_task_queue,
websocket_handler=self.websocket_handler,
outgoing_fn=task_json,
)

async def create_work(self, request: Request) -> StreamResponse:
attrs = {k: v for k, v in request.query.items() if k != "task"}
future = asyncio.get_event_loop().create_future()
task = WorkerTask(
TaskId(uuid_str()), "test", attrs, {"some": "data", "foo": "bla"}, future, timedelta(seconds=3)
)
await self.worker_task_queue.add_task(task)
await future
return web.HTTPOk()

async def list_work(self, request: Request) -> StreamResponse:
async def list_work(self, _: Request) -> StreamResponse:
def wt_to_js(ip: WorkerTaskInProgress) -> Json:
return {
"task": ip.task.to_json(),
Expand Down Expand Up @@ -870,7 +878,7 @@ async def execute(self, request: Request) -> StreamResponse:
temp = tempfile.mkdtemp()
temp_dir = temp
files = {}
# for now we assume that all multi-parts are file uploads
# for now, we assume that all multi-parts are file uploads
async for part in MultipartReader(request.headers, request.content):
name = part.name
if not name:
Expand Down
52 changes: 39 additions & 13 deletions resotolib/resotolib/asynchronous/web/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import re
from contextvars import ContextVar
Expand Down Expand Up @@ -25,11 +26,39 @@ async def jwt_from_context() -> JWT:
return __JWT_Context.get()


def raw_jwt_from_auth_message(msg: str) -> Optional[str]:
"""
Expected message: json object with type kind="authorization" and a jwt field
{ "kind": "authorization", "jwt": "Bearer <jwt>" }
"""
try:
js = json.loads(msg)
assert js.get("kind") == "authorization"
return js.get("jwt")
except Exception:
return None


@middleware
async def no_check(request: Request, handler: RequestHandler) -> StreamResponse:
# all requests are authorized automatically
request["authorized"] = True
return await handler(request)


def set_valid_jwt(request: Request, jwt_raw: str, psk: str) -> Optional[JWT]:
try:
# note: the expiration is already checked by this function
jwt = ck_jwt.decode_jwt_from_header_value(jwt_raw, psk)
except PyJWTError:
return None
if jwt:
request["jwt"] = jwt
request["authorized"] = True
__JWT_Context.set(jwt)
return jwt


def check_jwt(psk: str, always_allowed_paths: Set[str]) -> Middleware:
def always_allowed(request: Request) -> bool:
for path in always_allowed_paths:
Expand All @@ -40,9 +69,9 @@ def always_allowed(request: Request) -> bool:
@middleware
async def valid_jwt_handler(request: Request, handler: RequestHandler) -> StreamResponse:
auth_header = request.headers.get("Authorization") or request.cookies.get("resoto_authorization")
if always_allowed(request):
return await handler(request)
elif auth_header:
authorized = False
if auth_header:
# make sure origin and host match, so the request is valid
origin: Optional[str] = urlparse(request.headers.get("Origin")).hostname
host: Optional[str] = request.headers.get("Host")
if host is not None and origin is not None:
Expand All @@ -51,16 +80,13 @@ async def valid_jwt_handler(request: Request, handler: RequestHandler) -> Stream
if origin.lower() != host.lower():
log.warning(f"Origin {origin} is not allowed in request from {request.remote} to {request.path}")
raise web.HTTPForbidden()
try:
# note: the expiration is already checked by this function
jwt = ck_jwt.decode_jwt_from_header_value(auth_header, psk)
except PyJWTError as ex:
raise web.HTTPUnauthorized() from ex
if jwt:
__JWT_Context.set(jwt)
return await handler(request)
# if we come here, something is wrong: reject
raise web.HTTPUnauthorized()

# try to authorize the request, even if it is one of the always allowed paths
authorized = set_valid_jwt(request, auth_header, psk) is not None
if authorized or always_allowed(request):
return await handler(request)
else:
raise web.HTTPUnauthorized()

return valid_jwt_handler

Expand Down
34 changes: 24 additions & 10 deletions resotolib/resotolib/asynchronous/web/ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ async def accept_websocket(
await ws.prepare(request)
wsid = str(uuid1())

# in case we wait for an initial authorization message, only wait for a limited amount of tine
async def wait_for_authorization() -> None:
counter = 10
while request.get("authorized", False) is not True and counter >= 0:
await asyncio.sleep(1)
counter -= 1
if counter <= 0:
log.info(f"Wait for authorization: message listener {wsid}: Timeout. Hang up.")
await clean_ws_handler(wsid, websocket_handler)

async def receive() -> None:
try:
async for msg in ws:
Expand All @@ -73,6 +83,14 @@ async def receive() -> None:

async def send(ctx: Callable[[], AsyncContextManager[Queue[T]]]) -> None:
try:
# wait for the request to become authorized, before we will send any message
while request.get("authorized", False) is not True:
await asyncio.sleep(1)
# send all initial messages
if initial_messages:
for msg in initial_messages:
await ws.send_str(outgoing_fn(msg) + "\n")
# attach to the queue and wait for messages
async with ctx() as events:
while True:
event = await events.get()
Expand All @@ -83,17 +101,13 @@ async def send(ctx: Callable[[], AsyncContextManager[Queue[T]]]) -> None:
finally:
await clean_ws_handler(wsid, websocket_handler)

receive_task = asyncio.create_task(receive())
to_wait = (
asyncio.gather(receive_task, asyncio.create_task(send(outgoing_context)))
if outgoing_context is not None
else receive_task
)

if initial_messages:
for msg in initial_messages:
await ws.send_str(outgoing_fn(msg) + "\n")
tasks = [asyncio.create_task(receive())]
if outgoing_context is not None:
tasks.append(asyncio.create_task(send(outgoing_context)))
if request.get("authorized", False) is not True:
tasks.append(asyncio.create_task(wait_for_authorization()))

to_wait = asyncio.gather(*tasks)
websocket_handler[wsid] = (to_wait, ws)
await to_wait
return ws

0 comments on commit e044fa8

Please sign in to comment.