From 98b786a5dd6ff711a23749bc6d7008a3d6737d3e Mon Sep 17 00:00:00 2001 From: pwwang Date: Wed, 21 Feb 2024 22:26:22 -0600 Subject: [PATCH] feat: allow passing arguments to `utils.is_loading_pipeline()` --- pipen/utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pipen/utils.py b/pipen/utils.py index 5b589506..f2746273 100644 --- a/pipen/utils.py +++ b/pipen/utils.py @@ -734,7 +734,7 @@ async def load_pipeline( return pipeline -def is_loading_pipeline(*flags: str) -> bool: +def is_loading_pipeline(*flags: str, args: Sequence[str] | None = None) -> bool: """Check if we are loading the pipeline. Works only when `argv0` is "@pipen" while loading the pipeline. @@ -745,13 +745,20 @@ def is_loading_pipeline(*flags: str) -> bool: Args: *flags: Additional flags to check in sys.argv (e.g. "-h", "--help") to determine if we are loading the pipeline + args: The arguments to check. sys.argv is used by default. + Note that the first argument should be included in the check. + You could typically pass `[sys.argv[0], *your_args]` to this if you want + to check if `sys.argv[0]` is "@pipen" or `your_args` contains some flags. Returns: True if we are loading the pipeline (argv[0] == "@pipen"), otherwise False """ - if sys.argv[0] == LOADING_ARGV0: + if args is None: + args = sys.argv + + if len(args) > 0 and args[0] == LOADING_ARGV0: return True if flags: - return any(flag in sys.argv for flag in flags) + return any(flag in args for flag in flags)