diff --git a/src/satosa/frontends/base.py b/src/satosa/frontends/base.py index 08c38b79c..49224c48f 100644 --- a/src/satosa/frontends/base.py +++ b/src/satosa/frontends/base.py @@ -30,6 +30,7 @@ def __init__(self, auth_req_callback_func, internal_attributes, base_url, name): self.internal_attributes = internal_attributes self.converter = AttributeMapper(internal_attributes) self.base_url = base_url or "" + self.base_path = urlparse(self.base_url).path.lstrip("/") self.name = name self.endpoint_baseurl = join_paths(self.base_url, self.name) self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") diff --git a/src/satosa/frontends/saml2.py b/src/satosa/frontends/saml2.py index 41be65799..aeaa31e55 100644 --- a/src/satosa/frontends/saml2.py +++ b/src/satosa/frontends/saml2.py @@ -512,15 +512,12 @@ def _register_endpoints(self, providers): url_map = [] backend_providers = "|".join(providers) - base_path = urlparse(self.base_url).path.lstrip("/") - if base_path: - base_path = base_path + "/" for endp_category in self.endpoints: for binding, endp in self.endpoints[endp_category].items(): endp_path = urlparse(endp).path url_map.append( ( - "^{}({})/{}$".format(base_path, backend_providers, endp_path), + "^{}/({})/{}$".format(self.base_path, backend_providers, endp_path), functools.partial(self.handle_authn_request, binding_in=binding) ) ) @@ -770,15 +767,12 @@ def _register_endpoints(self, providers): url_map = [] backend_providers = "|".join(providers) - base_path = urlparse(self.base_url).path.lstrip("/") - if base_path: - base_path = base_path + "/" for endp_category in self.endpoints: for binding, endp in self.endpoints[endp_category].items(): endp_path = urlparse(endp).path url_map.append( ( - "^{}({})/\S+/{}$".format(base_path, backend_providers, endp_path), + "^{}/({})/\S+/{}$".format(self.base_path, backend_providers, endp_path), functools.partial(self.handle_authn_request, binding_in=binding) ) ) diff --git a/src/satosa/micro_services/account_linking.py b/src/satosa/micro_services/account_linking.py index 4cfd72c99..304980ff8 100644 --- a/src/satosa/micro_services/account_linking.py +++ b/src/satosa/micro_services/account_linking.py @@ -165,9 +165,7 @@ def register_endpoints(self): return [ ( "^{}$".format( - join_paths( - self.base_path, "account_linking", self.endpoint - ) + join_paths(self.endpoint_basepath, self.endpoint) ), self._handle_al_response, ) diff --git a/src/satosa/micro_services/base.py b/src/satosa/micro_services/base.py index 97271b013..72110daf3 100644 --- a/src/satosa/micro_services/base.py +++ b/src/satosa/micro_services/base.py @@ -16,6 +16,8 @@ def __init__(self, name, base_url, **kwargs): self.name = name self.base_url = base_url self.base_path = urlparse(base_url).path.lstrip("/") + self.endpoint_baseurl = join_paths(self.base_url, self.name) + self.endpoint_basepath = urlparse(self.endpoint_baseurl).path.lstrip("/") self.next = None def process(self, context, data): diff --git a/src/satosa/micro_services/consent.py b/src/satosa/micro_services/consent.py index c273116d5..cf70d4d5f 100644 --- a/src/satosa/micro_services/consent.py +++ b/src/satosa/micro_services/consent.py @@ -242,9 +242,7 @@ def register_endpoints(self): return [ ( "^{}$".format( - join_paths( - self.base_path, "consent", self.endpoint - ) + join_paths(self.endpoint_basepath, self.endpoint) ), self._handle_consent_response, ) diff --git a/src/satosa/routing.py b/src/satosa/routing.py index c9fa8ab8e..d739273ac 100644 --- a/src/satosa/routing.py +++ b/src/satosa/routing.py @@ -38,7 +38,7 @@ class UnknownEndpoint(ValueError): and handles the internal routing between frontends and backends. """ - def __init__(self, frontends, backends, micro_services, base_path=""): + def __init__(self, frontends, backends, micro_services, base_path=None): """ :type frontends: dict[str, satosa.frontends.base.FrontendModule] :type backends: dict[str, satosa.backends.base.BackendModule] @@ -70,7 +70,7 @@ def __init__(self, frontends, backends, micro_services, base_path=""): else: self.micro_services = {} - self.base_path = base_path + self.base_path = base_path if base_path else "" logger.debug("Loaded backends with endpoints: {}".format(backends)) logger.debug("Loaded frontends with endpoints: {}".format(frontends))