diff --git a/src/satosa/util.py b/src/satosa/util.py index b171f1f3e..ea1728853 100644 --- a/src/satosa/util.py +++ b/src/satosa/util.py @@ -5,6 +5,7 @@ import logging import random import string +import typing logger = logging.getLogger(__name__) @@ -91,27 +92,23 @@ def rndstr(size=16, alphabet=""): return type(alphabet)().join(rng.choice(alphabet) for _ in range(size)) -def join_paths(base, *paths): +def join_paths(*paths, sep: typing.Optional[str] = None) -> str: """ - Joins strings with a "/" separator, like they were path components, but - tries to avoid adding an unnecessary separator. Note that the contents of - the strings are not sanitized in any way. If any of the components begins or - ends with a "/", the separator is not inserted, and any number of empty - strings at the beginning would not add a leading slash. Any number of empty - strings at the end only add a single trailing slash. - - Raises TypeError if any of the components are not strings. + Joins strings with a separator like they were path components. The + separator is stripped off from all path components, except for the + beginning of the first component. Empty (or falsy) components are skipped. + Note that the components are not sanitized in any other way. + + Raises TypeError if any of the components are not strings (or empty). """ - sep = "/" + sep = sep or "/" + leading = "" + if paths and paths[0] and paths[0][0] == sep: + leading = sep - path = base try: - for p in paths: - if not path or path.endswith(sep) or p.startswith(sep): - path += p - else: - path += sep + p + return leading + sep.join( + [path.strip(sep) for path in filter(lambda p: p and p.strip(sep), paths)] + ) except (AttributeError, TypeError) as err: raise TypeError("Arguments must be strings") from err - - return path diff --git a/tests/satosa/test_util.py b/tests/satosa/test_util.py index 893534b89..e74c9f842 100644 --- a/tests/satosa/test_util.py +++ b/tests/satosa/test_util.py @@ -12,23 +12,27 @@ (["foo", "/bar"], "foo/bar"), (["/foo", "baz", "/bar"], "/foo/baz/bar"), (["", "foo", "bar"], "foo/bar"), - (["", "/foo", "bar"], "/foo/bar"), - (["", "/foo/", "bar"], "/foo/bar"), - (["", "", "", "/foo", "bar"], "/foo/bar"), - (["", "", "/foo/", "", "bar"], "/foo/bar"), - (["", "", "/foo/", "", "", "bar/"], "/foo/bar/"), - (["/foo", ""], "/foo/"), - (["/foo", "", "", ""], "/foo/"), - (["/foo//", "bar"], "/foo//bar"), + (["", "/foo", "bar"], "foo/bar"), + (["", "/foo/", "bar"], "foo/bar"), + (["", "", "", "/foo", "bar"], "foo/bar"), + (["", "", "/foo/", "", "bar"], "foo/bar"), + (["", "", "/foo/", "", "", "bar/"], "foo/bar"), + (["/foo", ""], "/foo"), + (["/foo", "", "", ""], "/foo"), + (["/foo//", "bar"], "/foo/bar"), (["foo"], "foo"), ([""], ""), (["", ""], ""), (["'not ", "sanitized'\0/; rm -rf *"], "'not /sanitized'\0/; rm -rf *"), - (["foo/", "/bar"], "foo//bar"), - (["foo", "", "/bar"], "foo//bar"), + (["foo/", "/bar"], "foo/bar"), + (["foo", "", "/bar"], "foo/bar"), ([b"foo", "bar"], TypeError), (["foo", b"bar"], TypeError), - ([None, "foo"], TypeError), + ([None, "foo"], "foo"), + (["foo", [], "bar"], "foo/bar"), + (["foo", ["baz"], "bar"], TypeError), + (["/", "foo", "bar"], "/foo/bar"), + (["///foo", "bar"], "/foo/bar"), ], ) def test_join_paths(args, expected): @@ -37,3 +41,7 @@ def test_join_paths(args, expected): else: with pytest.raises(expected): _ = join_paths(*args) + + +def test_join_paths_with_separator(): + assert join_paths("this", "is", "not", "a", "path", sep="|") == "this|is|not|a|path"