diff --git a/example/plugins/frontends/openid_connect_frontend.yaml.example b/example/plugins/frontends/openid_connect_frontend.yaml.example index d7a5584d8..7a9e78417 100644 --- a/example/plugins/frontends/openid_connect_frontend.yaml.example +++ b/example/plugins/frontends/openid_connect_frontend.yaml.example @@ -33,6 +33,7 @@ config: sub_hash_salt: randomSALTvalue provider: + issuer: client_registration_supported: Yes response_types_supported: ["code", "id_token token"] subject_types_supported: ["pairwise"] diff --git a/src/satosa/backends/base.py b/src/satosa/backends/base.py index 8d0432da8..d18dfc4d6 100644 --- a/src/satosa/backends/base.py +++ b/src/satosa/backends/base.py @@ -3,6 +3,9 @@ """ from ..attribute_mapping import AttributeMapper +from ..util import join_paths + +from urllib.parse import urlparse class BackendModule(object): @@ -29,8 +32,11 @@ def __init__(self, auth_callback_func, internal_attributes, base_url, name): self.auth_callback_func = auth_callback_func self.internal_attributes = internal_attributes self.converter = AttributeMapper(internal_attributes) - self.base_url = base_url + self.base_url = base_url.rstrip("/") if base_url else "" + self.base_path = urlparse(self.base_url).path.lstrip("/") self.name = name + self.endpoint_baseurl = join_paths(self.base_url, self.name) + self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") def start_auth(self, context, internal_request): """ diff --git a/src/satosa/base.py b/src/satosa/base.py index 404104920..67d4bc995 100644 --- a/src/satosa/base.py +++ b/src/satosa/base.py @@ -6,6 +6,7 @@ import uuid from saml2.s_utils import UnknownSystemEntity +from urllib.parse import urlparse from satosa import util from .context import Context @@ -38,6 +39,8 @@ def __init__(self, config): """ self.config = config + base_path = urlparse(self.config["BASE"]).path.lstrip("/") + logger.info("Loading backend modules...") backends = load_backends(self.config, self._auth_resp_callback_func, self.config["INTERNAL_ATTRIBUTES"]) @@ -63,8 +66,10 @@ def __init__(self, config): self.config["BASE"])) self._link_micro_services(self.response_micro_services, self._auth_resp_finish) - self.module_router = ModuleRouter(frontends, backends, - self.request_micro_services + self.response_micro_services) + self.module_router = ModuleRouter(frontends, + backends, + self.request_micro_services + self.response_micro_services, + base_path) def _link_micro_services(self, micro_services, finisher): if not micro_services: diff --git a/src/satosa/context.py b/src/satosa/context.py index 1cf140586..fd51ef581 100644 --- a/src/satosa/context.py +++ b/src/satosa/context.py @@ -76,10 +76,6 @@ def path(self, p): raise ValueError("path can't start with '/'") self._path = p - def target_entity_id_from_path(self): - target_entity_id = self.path.split("/")[1] - return target_entity_id - def decorate(self, key, value): """ Add information to the context diff --git a/src/satosa/frontends/base.py b/src/satosa/frontends/base.py index 52840a85c..49224c48f 100644 --- a/src/satosa/frontends/base.py +++ b/src/satosa/frontends/base.py @@ -2,6 +2,9 @@ Holds a base class for frontend modules used in the SATOSA proxy. """ from ..attribute_mapping import AttributeMapper +from ..util import join_paths + +from urllib.parse import urlparse class FrontendModule(object): @@ -14,17 +17,23 @@ def __init__(self, auth_req_callback_func, internal_attributes, base_url, name): :type auth_req_callback_func: (satosa.context.Context, satosa.internal.InternalData) -> satosa.response.Response :type internal_attributes: dict[str, dict[str, str | list[str]]] + :type base_url: str :type name: str :param auth_req_callback_func: Callback should be called by the module after the authorization response has been processed. + :param internal_attributes: attribute mapping + :param base_url: base url of the proxy :param name: name of the plugin """ self.auth_req_callback_func = auth_req_callback_func self.internal_attributes = internal_attributes self.converter = AttributeMapper(internal_attributes) - self.base_url = base_url + self.base_url = base_url or "" + self.base_path = urlparse(self.base_url).path.lstrip("/") self.name = name + self.endpoint_baseurl = join_paths(self.base_url, self.name) + self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") def handle_authn_response(self, context, internal_resp): """ diff --git a/src/satosa/frontends/openid_connect.py b/src/satosa/frontends/openid_connect.py index 88041b373..3097c594b 100644 --- a/src/satosa/frontends/openid_connect.py +++ b/src/satosa/frontends/openid_connect.py @@ -37,7 +37,7 @@ from ..response import BadRequest, Created from ..response import SeeOther, Response from ..response import Unauthorized -from ..util import rndstr +from ..util import join_paths, rndstr import satosa.logging_util as lu from satosa.internal import InternalData @@ -62,7 +62,8 @@ def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url, self.config = conf provider_config = self.config["provider"] - provider_config["issuer"] = base_url + if not provider_config.get("issuer"): + provider_config["issuer"] = base_url self.signing_key = RSAKey( key=rsa_load(self.config["signing_key_path"]), @@ -97,7 +98,6 @@ def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url, else: cdb = {} - self.endpoint_baseurl = "{}/{}".format(self.base_url, self.name) self.provider = _create_provider( provider_config, self.endpoint_baseurl, @@ -173,6 +173,19 @@ def register_endpoints(self, backend_names): :rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))] :raise ValueError: if more than one backend is configured """ + # See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + # + # We skip the scheme + host + port of the issuer URL, because we can only map the + # path for the provider config endpoint. We are safe to use urlparse().path here, + # because for issuer OIDC allows only https URLs without query and fragment parts. + issuer = self.provider.configuration_information["issuer"] + autoconf_path = ".well-known/openid-configuration" + provider_config = ( + "^{}$".format(join_paths(urlparse(issuer).path.lstrip("/"), autoconf_path)), + self.provider_config, + ) + jwks_uri = ("^{}/jwks$".format(self.endpoint_basepath), self.jwks) + backend_name = None if len(backend_names) != 1: # only supports one backend since there currently is no way to publish multiple authorization endpoints @@ -189,40 +202,49 @@ def register_endpoints(self, backend_names): else: backend_name = backend_names[0] - provider_config = ("^.well-known/openid-configuration$", self.provider_config) - jwks_uri = ("^{}/jwks$".format(self.name), self.jwks) - if backend_name: # if there is only one backend, include its name in the path so the default routing can work - auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name, self.name, AuthorizationEndpoint.url) + auth_endpoint = join_paths( + self.base_url, + backend_name, + self.name, + AuthorizationEndpoint.url, + ) self.provider.configuration_information["authorization_endpoint"] = auth_endpoint auth_path = urlparse(auth_endpoint).path.lstrip("/") else: - auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url) + auth_path = join_paths(self.endpoint_basepath, AuthorizationRequest.url) authentication = ("^{}$".format(auth_path), self.handle_authn_request) url_map = [provider_config, jwks_uri, authentication] if any("code" in v for v in self.provider.configuration_information["response_types_supported"]): - self.provider.configuration_information["token_endpoint"] = "{}/{}".format( - self.endpoint_baseurl, TokenEndpoint.url + self.provider.configuration_information["token_endpoint"] = join_paths( + self.endpoint_baseurl, + TokenEndpoint.url, ) token_endpoint = ( - "^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint + "^{}".format(join_paths(self.endpoint_basepath, TokenEndpoint.url)), + self.token_endpoint, ) url_map.append(token_endpoint) self.provider.configuration_information["userinfo_endpoint"] = ( - "{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url) + join_paths(self.endpoint_baseurl, UserinfoEndpoint.url) ) userinfo_endpoint = ( - "^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint + "^{}".format( + join_paths(self.endpoint_basepath, UserinfoEndpoint.url) + ), + self.userinfo_endpoint, ) url_map.append(userinfo_endpoint) if "registration_endpoint" in self.provider.configuration_information: client_registration = ( - "^{}/{}".format(self.name, RegistrationEndpoint.url), + "^{}".format( + join_paths(self.endpoint_basepath, RegistrationEndpoint.url) + ), self.client_registration, ) url_map.append(client_registration) diff --git a/src/satosa/frontends/ping.py b/src/satosa/frontends/ping.py index 27fec279c..ff609cfd2 100644 --- a/src/satosa/frontends/ping.py +++ b/src/satosa/frontends/ping.py @@ -3,6 +3,7 @@ import satosa.logging_util as lu from satosa.frontends.base import FrontendModule from satosa.response import Response +from satosa.util import join_paths logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def register_endpoints(self, backend_names): :rtype: list[(str, ((satosa.context.Context, Any) -> satosa.response.Response, Any))] :raise ValueError: if more than one backend is configured """ - url_map = [("^{}".format(self.name), self.ping_endpoint)] + url_map = [("^{}".format(join_paths(self.endpoint_basepath, self.name)), self.ping_endpoint)] return url_map diff --git a/src/satosa/frontends/saml2.py b/src/satosa/frontends/saml2.py index 655e6da68..2b8a1149a 100644 --- a/src/satosa/frontends/saml2.py +++ b/src/satosa/frontends/saml2.py @@ -33,6 +33,7 @@ from ..response import Response from ..response import ServiceError from ..saml_util import make_saml_response +from ..util import join_paths from satosa.exception import SATOSAError import satosa.util as util @@ -117,7 +118,7 @@ def register_endpoints(self, backend_names): if self.enable_metadata_reload(): url_map.append( - ("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata)) + ("^%s/%s$" % (self.endpoint_basepath, "reload-metadata"), self._reload_metadata)) self.idp_config = self._build_idp_config_endpoints( self.config[self.KEY_IDP_CONFIG], backend_names) @@ -511,15 +512,20 @@ def _register_endpoints(self, providers): """ url_map = [] + backend_providers = "(" + "|".join(providers) + ")" for endp_category in self.endpoints: for binding, endp in self.endpoints[endp_category].items(): - valid_providers = "" - for provider in providers: - valid_providers = "{}|^{}".format(valid_providers, provider) - valid_providers = valid_providers.lstrip("|") - parsed_endp = urlparse(endp) - url_map.append(("(%s)/%s$" % (valid_providers, parsed_endp.path), - functools.partial(self.handle_authn_request, binding_in=binding))) + endp_path = urlparse(endp).path + url_map.append( + ( + "^{}$".format( + join_paths(self.base_path, backend_providers, endp_path) + ), + functools.partial( + self.handle_authn_request, binding_in=binding + ), + ) + ) if self.expose_entityid_endpoint(): logger.debug("Exposing frontend entity endpoint = {}".format(self.idp.config.entityid)) @@ -675,11 +681,18 @@ def _load_idp_dynamic_endpoints(self, context): :param context: :return: An idp server """ - target_entity_id = context.target_entity_id_from_path() + target_entity_id = self._target_entity_id_from_path(context.path) idp_conf_file = self._load_endpoints_to_config(context.target_backend, target_entity_id) idp_config = IdPConfig().load(idp_conf_file) return Server(config=idp_config) + def _target_entity_id_from_path(self, request_path): + path = request_path.lstrip("/") + base_path = urlparse(self.base_url).path.lstrip("/") + if base_path and path.startswith(base_path): + path = path[len(base_path):].lstrip("/") + return path.split("/")[1] + def _load_idp_dynamic_entity_id(self, state): """ Loads an idp server with the entity id saved in state @@ -705,7 +718,7 @@ def handle_authn_request(self, context, binding_in): :type binding_in: str :rtype: satosa.response.Response """ - target_entity_id = context.target_entity_id_from_path() + target_entity_id = self._target_entity_id_from_path(context.path) target_entity_id = urlsafe_b64decode(target_entity_id).decode() context.decorate(Context.KEY_TARGET_ENTITYID, target_entity_id) @@ -723,7 +736,7 @@ def _create_state_data(self, context, resp_args, relay_state): :rtype: dict[str, dict[str, str] | str] """ state = super()._create_state_data(context, resp_args, relay_state) - state["target_entity_id"] = context.target_entity_id_from_path() + state["target_entity_id"] = self._target_entity_id_from_path(context.path) return state def handle_backend_error(self, exception): @@ -758,14 +771,20 @@ def _register_endpoints(self, providers): """ url_map = [] + backend_providers = "(" + "|".join(providers) + ")" for endp_category in self.endpoints: for binding, endp in self.endpoints[endp_category].items(): - valid_providers = "|^".join(providers) - parsed_endp = urlparse(endp) + endp_path = urlparse(endp).path url_map.append( ( - r"(^{})/\S+/{}".format(valid_providers, parsed_endp.path), - functools.partial(self.handle_authn_request, binding_in=binding) + "^{}$".format( + join_paths( + self.base_path, backend_providers, "\S+", endp_path + ) + ), + functools.partial( + self.handle_authn_request, binding_in=binding + ), ) ) diff --git a/src/satosa/micro_services/account_linking.py b/src/satosa/micro_services/account_linking.py index 7305c3d79..bebd77b7f 100644 --- a/src/satosa/micro_services/account_linking.py +++ b/src/satosa/micro_services/account_linking.py @@ -12,6 +12,7 @@ from ..exception import SATOSAAuthenticationError from ..micro_services.base import ResponseMicroService from ..response import Redirect +from ..util import join_paths import satosa.logging_util as lu logger = logging.getLogger(__name__) @@ -161,4 +162,11 @@ def register_endpoints(self): :return: A list of endpoints bound to a function """ - return [("^account_linking%s$" % self.endpoint, self._handle_al_response)] + return [ + ( + "^{}$".format( + join_paths(self.base_path, "account_linking", self.endpoint) + ), + self._handle_al_response, + ) + ] diff --git a/src/satosa/micro_services/base.py b/src/satosa/micro_services/base.py index 084cbea76..b31baf9b4 100644 --- a/src/satosa/micro_services/base.py +++ b/src/satosa/micro_services/base.py @@ -2,6 +2,9 @@ Micro service for SATOSA """ import logging +from urllib.parse import urlparse + +from ..util import join_paths logger = logging.getLogger(__name__) @@ -14,6 +17,9 @@ class MicroService(object): def __init__(self, name, base_url, **kwargs): self.name = name self.base_url = base_url + self.base_path = urlparse(base_url).path.lstrip("/") + self.endpoint_baseurl = join_paths(self.base_url, self.name) + self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") self.next = None def process(self, context, data): diff --git a/src/satosa/micro_services/consent.py b/src/satosa/micro_services/consent.py index a469e2189..fbe8e75dc 100644 --- a/src/satosa/micro_services/consent.py +++ b/src/satosa/micro_services/consent.py @@ -16,6 +16,7 @@ from satosa.internal import InternalData from satosa.micro_services.base import ResponseMicroService from satosa.response import Redirect +from satosa.util import join_paths logger = logging.getLogger(__name__) @@ -238,4 +239,11 @@ def register_endpoints(self): :return: A list of endpoints bound to a function """ - return [("^consent%s$" % self.endpoint, self._handle_consent_response)] + return [ + ( + "^{}$".format( + join_paths(self.base_path, "consent", self.endpoint) + ), + self._handle_consent_response, + ) + ] diff --git a/src/satosa/routing.py b/src/satosa/routing.py index 317b047f9..d739273ac 100644 --- a/src/satosa/routing.py +++ b/src/satosa/routing.py @@ -38,11 +38,12 @@ class UnknownEndpoint(ValueError): and handles the internal routing between frontends and backends. """ - def __init__(self, frontends, backends, micro_services): + def __init__(self, frontends, backends, micro_services, base_path=None): """ :type frontends: dict[str, satosa.frontends.base.FrontendModule] :type backends: dict[str, satosa.backends.base.BackendModule] :type micro_services: Sequence[satosa.micro_services.base.MicroService] + :type base_path: str :param frontends: All available frontends used by the proxy. Key as frontend name, value as module @@ -50,6 +51,7 @@ def __init__(self, frontends, backends, micro_services): module :param micro_services: All available micro services used by the proxy. Key as micro service name, value as module + :param base_path: Base path for endpoint mapping """ if not frontends or not backends: @@ -68,6 +70,8 @@ def __init__(self, frontends, backends, micro_services): else: self.micro_services = {} + self.base_path = base_path if base_path else "" + logger.debug("Loaded backends with endpoints: {}".format(backends)) logger.debug("Loaded frontends with endpoints: {}".format(frontends)) logger.debug("Loaded micro services with endpoints: {}".format(micro_services)) @@ -134,6 +138,19 @@ def _find_registered_endpoint(self, context, modules): raise ModuleRouter.UnknownEndpoint(context.path) + def _find_backend(self, request_path): + """ + Tries to guess the backend in use from the request. + Returns the backend name or None if the backend was not specified. + """ + request_path = request_path.lstrip("/") + if self.base_path and request_path.startswith(self.base_path): + request_path = request_path[len(self.base_path):].lstrip("/") + backend_guess = request_path.split("/")[0] + if backend_guess in self.backends: + return backend_guess + return None + def endpoint_routing(self, context): """ Finds and returns the endpoint function bound to the path @@ -155,13 +172,12 @@ def endpoint_routing(self, context): msg = "Routing path: {path}".format(path=context.path) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) - path_split = context.path.split("/") - backend = path_split[0] - if backend in self.backends: + backend = self._find_backend(context.path) + if backend is not None: context.target_backend = backend else: - msg = "Unknown backend {}".format(backend) + msg = "No backend was specified in request or no such backend {}".format(backend) logline = lu.LOG_FMT.format( id=lu.get_session_id(context.state), message=msg ) @@ -170,6 +186,8 @@ def endpoint_routing(self, context): try: name, frontend_endpoint = self._find_registered_endpoint(context, self.frontends) except ModuleRouter.UnknownEndpoint: + for frontend in self.frontends.values(): + logger.debug(f"Unable to find {context.path} in {frontend['endpoints']}") pass else: context.target_frontend = name @@ -183,7 +201,7 @@ def endpoint_routing(self, context): context.target_micro_service = name return micro_service_endpoint - if backend in self.backends: + if backend is not None: backend_endpoint = self._find_registered_backend_endpoint(context) if backend_endpoint: return backend_endpoint diff --git a/src/satosa/util.py b/src/satosa/util.py index 9b5d63fc1..75ec1f6fc 100644 --- a/src/satosa/util.py +++ b/src/satosa/util.py @@ -5,6 +5,7 @@ import logging import random import string +import typing logger = logging.getLogger(__name__) @@ -89,3 +90,25 @@ def rndstr(size=16, alphabet=""): if not alphabet: alphabet = string.ascii_letters[0:52] + string.digits return type(alphabet)().join(rng.choice(alphabet) for _ in range(size)) + + +def join_paths(*paths, sep: typing.Optional[str] = None) -> str: + """ + Joins strings with a separator like they were path components. The + separator is stripped off from all path components, except for the + beginning of the first component. Empty (or falsy) components are skipped. + Note that the components are not sanitized in any other way. + + Raises TypeError if any of the components are not strings (or empty). + """ + sep = sep or "/" + leading = "" + if paths and paths[0] and paths[0][0] == sep: + leading = sep + + try: + return leading + sep.join( + path.strip(sep) for path in filter(lambda p: p and p.strip(sep), paths) + ) + except (AttributeError, TypeError) as err: + raise TypeError("Arguments must be strings") from err diff --git a/tests/conftest.py b/tests/conftest.py index f0602a028..52c4a9fc1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from .util import create_metadata_from_config_dict from .util import generate_cert, write_cert -BASE_URL = "https://test-proxy.com" +BASE_URL = "https://test-proxy.com/satosa" @pytest.fixture(scope="session") @@ -38,7 +38,7 @@ def cert_and_key(tmpdir): @pytest.fixture def sp_conf(cert_and_key): - sp_base = "http://example.com" + sp_base = "http://example.com/test" spconfig = { "entityid": "{}/unittest_sp.xml".format(sp_base), "service": { @@ -64,7 +64,7 @@ def sp_conf(cert_and_key): @pytest.fixture def idp_conf(cert_and_key): - idp_base = "http://idp.example.com" + idp_base = "http://idp.example.com/test" idpconfig = { "entityid": "{}/{}/proxy.xml".format(idp_base, "Saml2IDP"), @@ -136,7 +136,23 @@ def satosa_config_dict(backend_plugin_config, frontend_plugin_config, request_mi "BACKEND_MODULES": [backend_plugin_config], "FRONTEND_MODULES": [frontend_plugin_config], "MICRO_SERVICES": [request_microservice_config, response_microservice_config], - "LOGGING": {"version": 1} + "LOGGING": { + "version": 1, + "handlers": { + "stdout": { + "level": "DEBUG", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "simple", + } + }, + "loggers": {"satosa": {"level": "DEBUG"}}, + "formatters": { + "simple": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s] %(message)s" + } + }, + } } return config diff --git a/tests/flows/test_account_linking.py b/tests/flows/test_account_linking.py index 94f53a431..ccb189593 100644 --- a/tests/flows/test_account_linking.py +++ b/tests/flows/test_account_linking.py @@ -1,9 +1,11 @@ import responses +from urllib.parse import urlparse from werkzeug.test import Client from werkzeug.wrappers import Response from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig +from satosa.util import join_paths class TestAccountLinking: @@ -13,6 +15,7 @@ def test_full_flow(self, satosa_config_dict, account_linking_module_config): account_linking_module_config["config"]["api_url"] = api_url account_linking_module_config["config"]["redirect_url"] = redirect_url satosa_config_dict["MICRO_SERVICES"].insert(0, account_linking_module_config) + base_path = urlparse(satosa_config_dict["BASE"]).path.lstrip("\n/") # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) @@ -36,5 +39,5 @@ def test_full_flow(self, satosa_config_dict, account_linking_module_config): rsps.add(responses.GET, "{}/get_id".format(api_url), "test_userid", status=200) # incoming account linking response - http_resp = test_client.get("/account_linking/handle_account_linking") + http_resp = test_client.get(join_paths("/", base_path, "account_linking/handle_account_linking")) assert http_resp.status_code == 200 diff --git a/tests/flows/test_consent.py b/tests/flows/test_consent.py index 76dff496b..4f616e180 100644 --- a/tests/flows/test_consent.py +++ b/tests/flows/test_consent.py @@ -2,11 +2,13 @@ import re import responses +from urllib.parse import urlparse from werkzeug.test import Client from werkzeug.wrappers import Response from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig +from satosa.util import join_paths class TestConsent: @@ -16,6 +18,7 @@ def test_full_flow(self, satosa_config_dict, consent_module_config): consent_module_config["config"]["api_url"] = api_url consent_module_config["config"]["redirect_url"] = redirect_url satosa_config_dict["MICRO_SERVICES"].append(consent_module_config) + base_path = urlparse(satosa_config_dict["BASE"]).path.lstrip("\n/") # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) @@ -43,5 +46,5 @@ def test_full_flow(self, satosa_config_dict, consent_module_config): rsps.add(responses.GET, verify_url_re, json.dumps({"foo": "bar"}), status=200) # incoming consent response - http_resp = test_client.get("/consent/handle_consent") + http_resp = test_client.get(join_paths("/", base_path, "consent/handle_consent")) assert http_resp.status_code == 200 diff --git a/tests/flows/test_oidc-saml.py b/tests/flows/test_oidc-saml.py index 2a299bfef..81acef2e4 100644 --- a/tests/flows/test_oidc-saml.py +++ b/tests/flows/test_oidc-saml.py @@ -17,6 +17,7 @@ from satosa.metadata_creation.saml_metadata import create_entity_descriptors from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig +from satosa.util import join_paths from tests.users import USERS from tests.users import OIDC_USERS from tests.util import FakeIdP @@ -27,6 +28,7 @@ CLIENT_REDIRECT_URI = "https://client.example.com/cb" REDIRECT_URI = "https://client.example.com/cb" DB_URI = "mongodb://localhost/satosa" +EXTRA_ISSUER = "/other/op" @pytest.fixture(scope="session") def client_db_path(tmpdir_factory): @@ -52,7 +54,6 @@ def oidc_frontend_config(signing_key_path): "module": "satosa.frontends.openid_connect.OpenIDConnectFrontend", "name": "OIDCFrontend", "config": { - "issuer": "https://proxy-op.example.com", "signing_key_path": signing_key_path, "provider": {"response_types_supported": ["id_token"]}, "client_db_uri": DB_URI, # use mongodb for integration testing @@ -69,7 +70,6 @@ def oidc_stateless_frontend_config(signing_key_path, client_db_path): "module": "satosa.frontends.openid_connect.OpenIDConnectFrontend", "name": "OIDCFrontend", "config": { - "issuer": "https://proxy-op.example.com", "signing_key_path": signing_key_path, "client_db_path": client_db_path, "db_uri": "stateless://user:abc123@localhost", @@ -94,11 +94,18 @@ def _client_setup(self): "response_types": ["id_token"] } + def _discover_provider(self, client, provider): + discovery_path = ( + join_paths(urlparse(provider).path, ".well-known/openid-configuration") + ) + return json.loads(client.get(discovery_path).data.decode("utf-8")) + def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_config, idp_conf): self._client_setup() subject_id = "testuser1" # proxy config + oidc_frontend_config["config"]["provider"]["issuer"] = EXTRA_ISSUER satosa_config_dict["FRONTEND_MODULES"] = [oidc_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], @@ -110,7 +117,8 @@ def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_ test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + issuer = EXTRA_ISSUER.replace("", satosa_config_dict["BASE"]) + provider_config = self._discover_provider(test_client, issuer) # create auth req claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) @@ -168,7 +176,8 @@ def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_f test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + issuer = satosa_config_dict["BASE"] + provider_config = self._discover_provider(test_client, issuer) # create auth req claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) @@ -226,7 +235,8 @@ def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_front test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + issuer = satosa_config_dict["BASE"] + provider_config = self._discover_provider(test_client, issuer) # create auth req claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) diff --git a/tests/satosa/frontends/test_openid_connect.py b/tests/satosa/frontends/test_openid_connect.py index f769b2c66..5db023dcd 100644 --- a/tests/satosa/frontends/test_openid_connect.py +++ b/tests/satosa/frontends/test_openid_connect.py @@ -28,7 +28,7 @@ INTERNAL_ATTRIBUTES = { "attributes": {"mail": {"saml": ["email"], "openid": ["email"]}} } -BASE_URL = "https://op.example.com" +BASE_URL = "https://op.example.com/satosa" CLIENT_ID = "client1" CLIENT_SECRET = "client_secret" EXTRA_CLAIMS = { @@ -44,6 +44,7 @@ EXTRA_SCOPES = { "eduperson": ["eduperson_scoped_affiliation", "eduperson_principal_name"] } +ISSUER = "https://other-op.example.com/satosa/other-op" class TestOpenIDConnectFrontend(object): @pytest.fixture @@ -86,10 +87,10 @@ def frontend_config_with_extra_id_token_claims(self, signing_key_path): return config - def create_frontend(self, frontend_config): + def create_frontend(self, frontend_config, issuer=BASE_URL): # will use in-memory storage instance = OpenIDConnectFrontend(lambda ctx, req: None, INTERNAL_ATTRIBUTES, - frontend_config, BASE_URL, "oidc_frontend") + frontend_config, issuer, "oidc_frontend") instance.register_endpoints(["foo_backend"]) return instance @@ -369,10 +370,40 @@ def test_jwks(self, context, frontend): jwks = json.loads(http_response.message) assert jwks == {"keys": [frontend.signing_key.serialize()]} - def test_register_endpoints_token_and_userinfo_endpoint_is_published_if_necessary(self, frontend): + @pytest.mark.parametrize("issuer", [ + "https://example.com", + "https://example.com/some/op", + "https://example.com/" + ]) + def test_register_endpoints_handles_path_in_issuer(self, frontend_config, issuer): + frontend = self.create_frontend(frontend_config, issuer) + issuer_path = urlparse(issuer).path[1:] + if issuer_path: + issuer_path += "/" + urls = frontend.register_endpoints(["test"]) + assert ( + "^{}{}$".format(issuer_path, ".well-known/openid-configuration"), + frontend.provider_config, + ) in urls + + assert ( + "^{}/{}".format(frontend.endpoint_basepath, TokenEndpoint.url), + frontend.token_endpoint, + ) in urls + assert ( + "^{}/{}".format(frontend.endpoint_basepath, UserinfoEndpoint.url), + frontend.userinfo_endpoint, + ) in urls + + def test_discovery_endpoint_honours_issuer_override(self, frontend_config): + frontend_config["provider"]["issuer"] = ISSUER + frontend = self.create_frontend(frontend_config) + discovery_path = urlparse(ISSUER).path[1:] urls = frontend.register_endpoints(["test"]) - assert ("^{}/{}".format(frontend.name, TokenEndpoint.url), frontend.token_endpoint) in urls - assert ("^{}/{}".format(frontend.name, UserinfoEndpoint.url), frontend.userinfo_endpoint) in urls + assert ( + "^{}/{}$".format(discovery_path, ".well-known/openid-configuration"), + frontend.provider_config, + ) in urls def test_register_endpoints_token_and_userinfo_endpoint_is_not_published_if_only_implicit_flow( self, frontend_config, context): @@ -380,8 +411,14 @@ def test_register_endpoints_token_and_userinfo_endpoint_is_not_published_if_only frontend = self.create_frontend(frontend_config) urls = frontend.register_endpoints(["test"]) - assert ("^{}/{}".format("test", TokenEndpoint.url), frontend.token_endpoint) not in urls - assert ("^{}/{}".format("test", UserinfoEndpoint.url), frontend.userinfo_endpoint) not in urls + assert ( + "^{}/{}".format(frontend.endpoint_basepath, TokenEndpoint.url), + frontend.token_endpoint, + ) not in urls + assert ( + "^{}/{}".format(frontend.endpoint_basepath, UserinfoEndpoint.url), + frontend.userinfo_endpoint, + ) not in urls http_response = frontend.provider_config(context) provider_config = ProviderConfigurationResponse().deserialize(http_response.message, "json") @@ -397,8 +434,13 @@ def test_register_endpoints_dynamic_client_registration_is_configurable( frontend = self.create_frontend(frontend_config) urls = frontend.register_endpoints(["test"]) - assert (("^{}/{}".format(frontend.name, RegistrationEndpoint.url), - frontend.client_registration) in urls) == client_registration_enabled + assert ( + ( + "^{}/{}".format(frontend.endpoint_basepath, RegistrationEndpoint.url), + frontend.client_registration, + ) + in urls + ) == client_registration_enabled provider_info = ProviderConfigurationResponse().deserialize(frontend.provider_config(None).message, "json") assert ("registration_endpoint" in provider_info) == client_registration_enabled diff --git a/tests/satosa/test_util.py b/tests/satosa/test_util.py new file mode 100644 index 000000000..e74c9f842 --- /dev/null +++ b/tests/satosa/test_util.py @@ -0,0 +1,47 @@ +import pytest +from satosa.util import join_paths + + +@pytest.mark.parametrize( + "args, expected", + [ + (["/foo", "baz", "bar"], "/foo/baz/bar"), + (["foo", "baz", "bar"], "foo/baz/bar"), + (["https://foo.baz", "bar"], "https://foo.baz/bar"), + (["https://foo.baz/", "bar"], "https://foo.baz/bar"), + (["foo", "/bar"], "foo/bar"), + (["/foo", "baz", "/bar"], "/foo/baz/bar"), + (["", "foo", "bar"], "foo/bar"), + (["", "/foo", "bar"], "foo/bar"), + (["", "/foo/", "bar"], "foo/bar"), + (["", "", "", "/foo", "bar"], "foo/bar"), + (["", "", "/foo/", "", "bar"], "foo/bar"), + (["", "", "/foo/", "", "", "bar/"], "foo/bar"), + (["/foo", ""], "/foo"), + (["/foo", "", "", ""], "/foo"), + (["/foo//", "bar"], "/foo/bar"), + (["foo"], "foo"), + ([""], ""), + (["", ""], ""), + (["'not ", "sanitized'\0/; rm -rf *"], "'not /sanitized'\0/; rm -rf *"), + (["foo/", "/bar"], "foo/bar"), + (["foo", "", "/bar"], "foo/bar"), + ([b"foo", "bar"], TypeError), + (["foo", b"bar"], TypeError), + ([None, "foo"], "foo"), + (["foo", [], "bar"], "foo/bar"), + (["foo", ["baz"], "bar"], TypeError), + (["/", "foo", "bar"], "/foo/bar"), + (["///foo", "bar"], "/foo/bar"), + ], +) +def test_join_paths(args, expected): + if isinstance(expected, str): + assert join_paths(*args) == expected + else: + with pytest.raises(expected): + _ = join_paths(*args) + + +def test_join_paths_with_separator(): + assert join_paths("this", "is", "not", "a", "path", sep="|") == "this|is|not|a|path"