From 51c2be2862f613cd4b36b5e0151ccb0dbd5ce97d Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:30:56 -0400 Subject: [PATCH 1/3] Update context.py --- src/marvin/utilities/context.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/marvin/utilities/context.py b/src/marvin/utilities/context.py index 1a0d8bf88..c6f060231 100644 --- a/src/marvin/utilities/context.py +++ b/src/marvin/utilities/context.py @@ -35,19 +35,35 @@ def __init__(self): def get(self, key: str, default: Any = None) -> Any: return self._context_storage.get().get(key, default) + def __getitem__(self, key: str) -> Any: + notfound = object() + result = self.get(key, default=notfound) + if result == notfound: + raise KeyError(key) + return result + def set(self, **kwargs: Any) -> None: ctx = self._context_storage.get() updated_ctx = {**ctx, **kwargs} - self._context_storage.set(updated_ctx) + token = self._context_storage.set(updated_ctx) + return token @contextmanager def __call__(self, **kwargs: Any) -> Generator[None, None, Any]: - current_context = self._context_storage.get().copy() - self.set(**kwargs) + current_context_copy = self._context_storage.get().copy() + token = self.set(**kwargs) try: yield finally: - self._context_storage.set(current_context) + try: + self._context_storage.reset(token) + except Exception: + # the only way we can reach this line is if the setup and + # teardown of this context are run in different frames or + # threads (which happens with pytest fixtures!), in which case + # the token is considered invalid. This catch serves as a + # "manual" reset of the context values + self._context_storage.set(current_context_copy) ctx = ScopedContext() From 142082b0220b30f181a09ad86c0e2ab13807ccc5 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:35:51 -0400 Subject: [PATCH 2/3] Improve error handling --- src/marvin/utilities/context.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/marvin/utilities/context.py b/src/marvin/utilities/context.py index c6f060231..d72e33b23 100644 --- a/src/marvin/utilities/context.py +++ b/src/marvin/utilities/context.py @@ -57,13 +57,16 @@ def __call__(self, **kwargs: Any) -> Generator[None, None, Any]: finally: try: self._context_storage.reset(token) - except Exception: - # the only way we can reach this line is if the setup and - # teardown of this context are run in different frames or - # threads (which happens with pytest fixtures!), in which case - # the token is considered invalid. This catch serves as a - # "manual" reset of the context values - self._context_storage.set(current_context_copy) + except ValueError as exc: + if "was created in a different context" in str(exc).lower(): + # the only way we can reach this line is if the setup and + # teardown of this context are run in different frames or + # threads (which happens with pytest fixtures!), in which case + # the token is considered invalid. This catch serves as a + # "manual" reset of the context values + self._context_storage.set(current_context_copy) + else: + raise ctx = ScopedContext() From 338e770de94236b957b0871c9787d15861da6430 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 24 Apr 2024 08:41:35 -0400 Subject: [PATCH 3/3] add initial value --- src/marvin/utilities/context.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/marvin/utilities/context.py b/src/marvin/utilities/context.py index d72e33b23..7495d0596 100644 --- a/src/marvin/utilities/context.py +++ b/src/marvin/utilities/context.py @@ -26,10 +26,10 @@ class ScopedContext: ``` """ - def __init__(self): - """Initializes the ScopedContext with a default empty dictionary.""" + def __init__(self, initial_value: dict = None): + """Initializes the ScopedContext with an initial valuedictionary.""" self._context_storage = contextvars.ContextVar( - "scoped_context_storage", default={} + "scoped_context_storage", default=initial_value or {} ) def get(self, key: str, default: Any = None) -> Any: