Skip to content

Commit

Permalink
ModuleRouter: support paths in BASE
Browse files Browse the repository at this point in the history
If Satosa is installed under a path which is not the root of the
webserver (ie. "https://example.com/satosa"), then endpoint routing must
take the base path into consideration.

Some modules registered some of their endpoints with the base path
included, but other times the base path was omitted, thus it made the
routing fail. Now all endpoint registrations include the base path in
their endpoint map.

Additionally, DEBUG logging was configured for the tests so that the
debug logs are accessible during testing.
  • Loading branch information
bajnokk committed Jul 20, 2022
1 parent 6a4a83b commit 862c400
Show file tree
Hide file tree
Showing 14 changed files with 127 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/satosa/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, auth_callback_func, internal_attributes, base_url, name):
self.auth_callback_func = auth_callback_func
self.internal_attributes = internal_attributes
self.converter = AttributeMapper(internal_attributes)
self.base_url = base_url
self.base_url = base_url.rstrip("/") if base_url else ""
self.name = name

def start_auth(self, context, internal_request):
Expand Down
9 changes: 7 additions & 2 deletions src/satosa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid

from saml2.s_utils import UnknownSystemEntity
from urllib.parse import urlparse

from satosa import util
from .context import Context
Expand Down Expand Up @@ -39,6 +40,8 @@ def __init__(self, config):
"""
self.config = config

base_path = urlparse(self.config["BASE"]).path.lstrip("/")

logger.info("Loading backend modules...")
backends = load_backends(self.config, self._auth_resp_callback_func,
self.config["INTERNAL_ATTRIBUTES"])
Expand All @@ -64,8 +67,10 @@ def __init__(self, config):
self.config["BASE"]))
self._link_micro_services(self.response_micro_services, self._auth_resp_finish)

self.module_router = ModuleRouter(frontends, backends,
self.request_micro_services + self.response_micro_services)
self.module_router = ModuleRouter(frontends,
backends,
self.request_micro_services + self.response_micro_services,
base_path)

def _link_micro_services(self, micro_services, finisher):
if not micro_services:
Expand Down
4 changes: 0 additions & 4 deletions src/satosa/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ def path(self, p):
raise ValueError("path can't start with '/'")
self._path = p

def target_entity_id_from_path(self):
target_entity_id = self.path.split("/")[1]
return target_entity_id

def decorate(self, key, value):
"""
Add information to the context
Expand Down
7 changes: 6 additions & 1 deletion src/satosa/frontends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""
from ..attribute_mapping import AttributeMapper

import os.path
from urllib.parse import urlparse


class FrontendModule(object):
"""
Expand All @@ -23,8 +26,10 @@ def __init__(self, auth_req_callback_func, internal_attributes, base_url, name):
self.auth_req_callback_func = auth_req_callback_func
self.internal_attributes = internal_attributes
self.converter = AttributeMapper(internal_attributes)
self.base_url = base_url
self.base_url = base_url.rstrip("/") if base_url else ""
self.name = name
self.endpoint_baseurl = os.path.join(self.base_url, self.name)
self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/")

def handle_authn_response(self, context, internal_resp):
"""
Expand Down
15 changes: 7 additions & 8 deletions src/satosa/frontends/openid_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url,
else:
cdb = {}

self.endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
self.provider = _create_provider(
provider_config,
self.endpoint_baseurl,
Expand Down Expand Up @@ -165,6 +164,9 @@ def register_endpoints(self, backend_names):
:rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
:raise ValueError: if more than one backend is configured
"""
provider_config = ("^.well-known/openid-configuration$", self.provider_config)
jwks_uri = ("^{}/jwks$".format(self.endpoint_basepath), self.jwks)

backend_name = None
if len(backend_names) != 1:
# only supports one backend since there currently is no way to publish multiple authorization endpoints
Expand All @@ -181,16 +183,13 @@ def register_endpoints(self, backend_names):
else:
backend_name = backend_names[0]

provider_config = ("^.well-known/openid-configuration$", self.provider_config)
jwks_uri = ("^{}/jwks$".format(self.name), self.jwks)

if backend_name:
# if there is only one backend, include its name in the path so the default routing can work
auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name, self.name, AuthorizationEndpoint.url)
self.provider.configuration_information["authorization_endpoint"] = auth_endpoint
auth_path = urlparse(auth_endpoint).path.lstrip("/")
else:
auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url)
auth_path = "{}/{}".format(self.endpoint_basepath, AuthorizationEndpoint.url)

authentication = ("^{}$".format(auth_path), self.handle_authn_request)
url_map = [provider_config, jwks_uri, authentication]
Expand All @@ -200,21 +199,21 @@ def register_endpoints(self, backend_names):
self.endpoint_baseurl, TokenEndpoint.url
)
token_endpoint = (
"^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint
"^{}/{}".format(self.endpoint_basepath, TokenEndpoint.url), self.token_endpoint
)
url_map.append(token_endpoint)

self.provider.configuration_information["userinfo_endpoint"] = (
"{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url)
)
userinfo_endpoint = (
"^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint
"^{}/{}".format(self.endpoint_basepath, UserinfoEndpoint.url), self.userinfo_endpoint
)
url_map.append(userinfo_endpoint)

if "registration_endpoint" in self.provider.configuration_information:
client_registration = (
"^{}/{}".format(self.name, RegistrationEndpoint.url),
"^{}/{}".format(self.endpoint_basepath, RegistrationEndpoint.url),
self.client_registration,
)
url_map.append(client_registration)
Expand Down
3 changes: 2 additions & 1 deletion src/satosa/frontends/ping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os.path

import satosa.logging_util as lu
import satosa.micro_services.base
Expand Down Expand Up @@ -44,7 +45,7 @@ def register_endpoints(self, backend_names):
:rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))]
:raise ValueError: if more than one backend is configured
"""
url_map = [("^{}".format(self.name), self.ping_endpoint)]
url_map = [("^{}".format(os.path.join(self.endpoint_basepath, self.name)), self.ping_endpoint)]

return url_map

Expand Down
42 changes: 28 additions & 14 deletions src/satosa/frontends/saml2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def register_endpoints(self, backend_names):

if self.enable_metadata_reload():
url_map.append(
("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata))
("^%s/%s$" % (self.endpoint_basepath, "reload-metadata"), self._reload_metadata))

self.idp_config = self._build_idp_config_endpoints(
self.config[self.KEY_IDP_CONFIG], backend_names)
Expand Down Expand Up @@ -512,15 +512,19 @@ def _register_endpoints(self, providers):
"""
url_map = []

backend_providers = "|".join(providers)
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path:
base_path = base_path + "/"
for endp_category in self.endpoints:
for binding, endp in self.endpoints[endp_category].items():
valid_providers = ""
for provider in providers:
valid_providers = "{}|^{}".format(valid_providers, provider)
valid_providers = valid_providers.lstrip("|")
parsed_endp = urlparse(endp)
url_map.append(("(%s)/%s$" % (valid_providers, parsed_endp.path),
functools.partial(self.handle_authn_request, binding_in=binding)))
endp_path = urlparse(endp).path
url_map.append(
(
"^{}({})/{}$".format(base_path, backend_providers, endp_path),
functools.partial(self.handle_authn_request, binding_in=binding)
)
)

if self.expose_entityid_endpoint():
logger.debug("Exposing frontend entity endpoint = {}".format(self.idp.config.entityid))
Expand Down Expand Up @@ -676,11 +680,18 @@ def _load_idp_dynamic_endpoints(self, context):
:param context:
:return: An idp server
"""
target_entity_id = context.target_entity_id_from_path()
target_entity_id = self._target_entity_id_from_path(context.path)
idp_conf_file = self._load_endpoints_to_config(context.target_backend, target_entity_id)
idp_config = IdPConfig().load(idp_conf_file)
return Server(config=idp_config)

def _target_entity_id_from_path(self, request_path):
path = request_path.lstrip("/")
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path and path.startswith(base_path):
path = path[len(base_path):].lstrip("/")
return path.split("/")[1]

def _load_idp_dynamic_entity_id(self, state):
"""
Loads an idp server with the entity id saved in state
Expand All @@ -706,7 +717,7 @@ def handle_authn_request(self, context, binding_in):
:type binding_in: str
:rtype: satosa.response.Response
"""
target_entity_id = context.target_entity_id_from_path()
target_entity_id = self._target_entity_id_from_path(context.path)
target_entity_id = urlsafe_b64decode(target_entity_id).decode()
context.decorate(Context.KEY_TARGET_ENTITYID, target_entity_id)

Expand All @@ -724,7 +735,7 @@ def _create_state_data(self, context, resp_args, relay_state):
:rtype: dict[str, dict[str, str] | str]
"""
state = super()._create_state_data(context, resp_args, relay_state)
state["target_entity_id"] = context.target_entity_id_from_path()
state["target_entity_id"] = self._target_entity_id_from_path(context.path)
return state

def handle_backend_error(self, exception):
Expand Down Expand Up @@ -759,13 +770,16 @@ def _register_endpoints(self, providers):
"""
url_map = []

backend_providers = "|".join(providers)
base_path = urlparse(self.base_url).path.lstrip("/")
if base_path:
base_path = base_path + "/"
for endp_category in self.endpoints:
for binding, endp in self.endpoints[endp_category].items():
valid_providers = "|^".join(providers)
parsed_endp = urlparse(endp)
endp_path = urlparse(endp).path
url_map.append(
(
r"(^{})/\S+/{}".format(valid_providers, parsed_endp.path),
"^{}({})/\S+/{}$".format(base_path, backend_providers, endp_path),
functools.partial(self.handle_authn_request, binding_in=binding)
)
)
Expand Down
12 changes: 11 additions & 1 deletion src/satosa/micro_services/account_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import json
import logging
import os.path

import requests
from jwkest.jwk import rsa_load, RSAKey
Expand Down Expand Up @@ -161,4 +162,13 @@ def register_endpoints(self):
:return: A list of endpoints bound to a function
"""
return [("^account_linking%s$" % self.endpoint, self._handle_al_response)]
return [
(
"^{}$".format(
os.path.join(
self.base_path, "account_linking", self.endpoint.lstrip("/")
)
),
self._handle_al_response,
)
]
2 changes: 2 additions & 0 deletions src/satosa/micro_services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Micro service for SATOSA
"""
import logging
from urllib.parse import urlparse

logger = logging.getLogger(__name__)

Expand All @@ -14,6 +15,7 @@ class MicroService(object):
def __init__(self, name, base_url, **kwargs):
self.name = name
self.base_url = base_url
self.base_path = urlparse(base_url).path.lstrip("/")
self.next = None

def process(self, context, data):
Expand Down
12 changes: 11 additions & 1 deletion src/satosa/micro_services/consent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hashlib
import json
import logging
import os.path
from base64 import urlsafe_b64encode

import requests
Expand Down Expand Up @@ -238,4 +239,13 @@ def register_endpoints(self):
:return: A list of endpoints bound to a function
"""
return [("^consent%s$" % self.endpoint, self._handle_consent_response)]
return [
(
"^{}$".format(
os.path.join(
self.base_path, "consent", self.endpoint.lstrip("/")
)
),
self._handle_consent_response,
)
]
30 changes: 24 additions & 6 deletions src/satosa/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,20 @@ class UnknownEndpoint(ValueError):
and handles the internal routing between frontends and backends.
"""

def __init__(self, frontends, backends, micro_services):
def __init__(self, frontends, backends, micro_services, base_path=""):
"""
:type frontends: dict[str, satosa.frontends.base.FrontendModule]
:type backends: dict[str, satosa.backends.base.BackendModule]
:type micro_services: Sequence[satosa.micro_services.base.MicroService]
:type base_path: str
:param frontends: All available frontends used by the proxy. Key as frontend name, value as
module
:param backends: All available backends used by the proxy. Key as backend name, value as
module
:param micro_services: All available micro services used by the proxy. Key as micro service name, value as
module
:param base_path: Base path for endpoint mapping
"""

if not frontends or not backends:
Expand All @@ -68,6 +70,8 @@ def __init__(self, frontends, backends, micro_services):
else:
self.micro_services = {}

self.base_path = base_path

logger.debug("Loaded backends with endpoints: {}".format(backends))
logger.debug("Loaded frontends with endpoints: {}".format(frontends))
logger.debug("Loaded micro services with endpoints: {}".format(micro_services))
Expand Down Expand Up @@ -134,6 +138,19 @@ def _find_registered_endpoint(self, context, modules):

raise ModuleRouter.UnknownEndpoint(context.path)

def _find_backend(self, request_path):
"""
Tries to guess the backend in use from the request.
Returns the backend name or None if the backend was not specified.
"""
request_path = request_path.lstrip("/")
if self.base_path and request_path.startswith(self.base_path):
request_path = request_path[len(self.base_path):].lstrip("/")
backend_guess = request_path.split("/")[0]
if backend_guess in self.backends:
return backend_guess
return None

def endpoint_routing(self, context):
"""
Finds and returns the endpoint function bound to the path
Expand All @@ -155,13 +172,12 @@ def endpoint_routing(self, context):
msg = "Routing path: {path}".format(path=context.path)
logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg)
logger.debug(logline)
path_split = context.path.split("/")
backend = path_split[0]

if backend in self.backends:
backend = self._find_backend(context.path)
if backend is not None:
context.target_backend = backend
else:
msg = "Unknown backend {}".format(backend)
msg = "No backend was specified in request or no such backend {}".format(backend)
logline = lu.LOG_FMT.format(
id=lu.get_session_id(context.state), message=msg
)
Expand All @@ -170,6 +186,8 @@ def endpoint_routing(self, context):
try:
name, frontend_endpoint = self._find_registered_endpoint(context, self.frontends)
except ModuleRouter.UnknownEndpoint:
for frontend in self.frontends.values():
logger.debug(f"Unable to find {context.path} in {frontend['endpoints']}")
pass
else:
context.target_frontend = name
Expand All @@ -183,7 +201,7 @@ def endpoint_routing(self, context):
context.target_micro_service = name
return micro_service_endpoint

if backend in self.backends:
if backend is not None:
backend_endpoint = self._find_registered_backend_endpoint(context)
if backend_endpoint:
return backend_endpoint
Expand Down
Loading

0 comments on commit 862c400

Please sign in to comment.