diff --git a/src/marvin/utilities/context.py b/src/marvin/utilities/context.py index 1a0d8bf88..7495d0596 100644 --- a/src/marvin/utilities/context.py +++ b/src/marvin/utilities/context.py @@ -26,28 +26,47 @@ class ScopedContext: ``` """ - def __init__(self): - """Initializes the ScopedContext with a default empty dictionary.""" + def __init__(self, initial_value: dict = None): + """Initializes the ScopedContext with an initial valuedictionary.""" self._context_storage = contextvars.ContextVar( - "scoped_context_storage", default={} + "scoped_context_storage", default=initial_value or {} ) def get(self, key: str, default: Any = None) -> Any: return self._context_storage.get().get(key, default) + def __getitem__(self, key: str) -> Any: + notfound = object() + result = self.get(key, default=notfound) + if result == notfound: + raise KeyError(key) + return result + def set(self, **kwargs: Any) -> None: ctx = self._context_storage.get() updated_ctx = {**ctx, **kwargs} - self._context_storage.set(updated_ctx) + token = self._context_storage.set(updated_ctx) + return token @contextmanager def __call__(self, **kwargs: Any) -> Generator[None, None, Any]: - current_context = self._context_storage.get().copy() - self.set(**kwargs) + current_context_copy = self._context_storage.get().copy() + token = self.set(**kwargs) try: yield finally: - self._context_storage.set(current_context) + try: + self._context_storage.reset(token) + except ValueError as exc: + if "was created in a different context" in str(exc).lower(): + # the only way we can reach this line is if the setup and + # teardown of this context are run in different frames or + # threads (which happens with pytest fixtures!), in which case + # the token is considered invalid. This catch serves as a + # "manual" reset of the context values + self._context_storage.set(current_context_copy) + else: + raise ctx = ScopedContext()