diff --git a/ops/_main.py b/ops/_main.py index 07adc041c..d82fdc597 100644 --- a/ops/_main.py +++ b/ops/_main.py @@ -112,28 +112,6 @@ def _setup_event_links(charm_dir: Path, charm: 'ops.charm.CharmBase', juju_conte _create_event_link(charm, bound_event, link_to) -def _emit_charm_event(charm: 'ops.charm.CharmBase', event_name: str, juju_context: _JujuContext): - """Emits a charm event based on a Juju event name. - - Args: - charm: A charm instance to emit an event from. - event_name: A Juju event name to emit on a charm. - juju_context: An instance of the _JujuContext class. - """ - event_to_emit = None - try: - event_to_emit = getattr(charm.on, event_name) - except AttributeError: - logger.debug('Event %s not defined for %s.', event_name, charm) - - # If the event is not supported by the charm implementation, do - # not error out or try to emit it. This is to support rollbacks. - if event_to_emit is not None: - args, kwargs = _get_event_args(charm, event_to_emit, juju_context) - logger.debug('Emitting Juju event %s.', event_name) - event_to_emit.emit(*args, **kwargs) - - def _get_event_args( charm: 'ops.charm.CharmBase', bound_event: 'ops.framework.BoundEvent', @@ -401,8 +379,11 @@ def __init__( model_backend: Optional[ops.model._ModelBackend] = None, use_juju_for_storage: Optional[bool] = None, charm_state_path: str = CHARM_STATE_FILE, + juju_context: Optional[_JujuContext] = None, ): - self._juju_context = _JujuContext.from_dict(os.environ) + if juju_context is None: + juju_context = _JujuContext.from_dict(os.environ) + self._juju_context = juju_context self._charm_state_path = charm_state_path self._charm_class = charm_class if model_backend is None: @@ -413,7 +394,7 @@ def __init__( self._setup_root_logging() self._charm_root = self._juju_context.charm_dir - self._charm_meta = CharmMeta.from_charm_root(self._charm_root) + self._charm_meta = self._load_charm_meta() self._use_juju_for_storage = use_juju_for_storage # Set up dispatcher, framework and charm objects. @@ -423,6 +404,9 @@ def __init__( self.framework = self._make_framework(self.dispatcher) self.charm = self._make_charm(self.framework, self.dispatcher) + def _load_charm_meta(self): + return CharmMeta.from_charm_root(self._charm_root) + def _make_charm(self, framework: 'ops.framework.Framework', dispatcher: _Dispatcher): charm = self._charm_class(framework) dispatcher.ensure_event_links(charm) @@ -482,7 +466,7 @@ def _make_framework(self, dispatcher: _Dispatcher): # If we are in a RelationBroken event, we want to know which relation is # broken within the model, not only in the event's `.relation` attribute. - if self._juju_context.dispatch_path.endswith('-relation-broken'): + if self._juju_context.dispatch_path.endswith(('-relation-broken', '_relation_broken')): broken_relation_id = self._juju_context.relation_id else: broken_relation_id = None @@ -515,19 +499,50 @@ def _emit(self): self.framework.reemit() # Emit the Juju event. - _emit_charm_event(self.charm, self.dispatcher.event_name, self._juju_context) + self._emit_charm_event(self.dispatcher.event_name) # Emit collect-status events. ops.charm._evaluate_status(self.charm) + def _get_event_to_emit(self, event_name: str) -> Optional[ops.framework.BoundEvent]: + try: + return getattr(self.charm.on, event_name) + except AttributeError: + logger.debug('Event %s not defined for %s.', event_name, self.charm) + return None + + def _emit_charm_event(self, event_name: str): + """Emits a charm event based on a Juju event name. + + Args: + charm: A charm instance to emit an event from. + event_name: A Juju event name to emit on a charm. + juju_context: An instance of the _JujuContext class. + """ + event_to_emit = self._get_event_to_emit(event_name) + + # If the event is not supported by the charm implementation, do + # not error out or try to emit it. This is to support rollbacks. + if event_to_emit is None: + return + + args, kwargs = _get_event_args(self.charm, event_to_emit, self._juju_context) + logger.debug('Emitting Juju event %s.', event_name) + event_to_emit.emit(*args, **kwargs) + def _commit(self): """Commit the framework and gracefully teardown.""" self.framework.commit() + def _close(self): + """Perform any necessary cleanup before the framework is closed.""" + # Provided for child classes - nothing needs to be done in the base. + def run(self): """Emit and then commit the framework.""" try: self._emit() self._commit() + self._close() finally: self.framework.close() diff --git a/ops/testing.py b/ops/testing.py index fc916ef55..73da1b8cf 100644 --- a/ops/testing.py +++ b/ops/testing.py @@ -177,8 +177,8 @@ # monkeypatch it in, so that the ops.testing.ActionFailed object is the # one that we expect, even if people are mixing Harness and Scenario. # https://github.com/canonical/ops-scenario/issues/201 + import scenario._runtime as _runtime import scenario.context as _context - import scenario.runtime as _runtime _context.ActionFailed = ActionFailed # type: ignore[reportPrivateImportUsage] _runtime.ActionFailed = ActionFailed # type: ignore[reportPrivateImportUsage] diff --git a/test/test_main.py b/test/test_main.py index 2ce616268..65172b00e 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -97,7 +97,7 @@ def __init__( @patch('ops._main.setup_root_logging', new=lambda *a, **kw: None) # type: ignore -@patch('ops._main._emit_charm_event', new=lambda *a, **kw: None) # type: ignore +@patch('ops._main._Manager._emit_charm_event', new=lambda *a, **kw: None) # type: ignore @patch('ops.charm._evaluate_status', new=lambda *a, **kw: None) # type: ignore class TestCharmInit: @patch('sys.stderr', new_callable=io.StringIO) @@ -235,11 +235,11 @@ def __init__(self, framework: ops.Framework): dispatch.chmod(0o755) with patch.dict(os.environ, fake_environ): - with patch('ops._main._emit_charm_event') as mock_charm_event: + with patch('ops._main._Manager._emit_charm_event') as mock_charm_event: ops.main(MyCharm) assert mock_charm_event.call_count == 1 - return mock_charm_event.call_args[0][1] + return mock_charm_event.call_args[0][0] def test_most_legacy(self): """Without dispatch, sys.argv[0] is used.""" diff --git a/test/test_main_invocation.py b/test/test_main_invocation.py index 4751b7fd1..4105b3c17 100644 --- a/test/test_main_invocation.py +++ b/test/test_main_invocation.py @@ -24,7 +24,7 @@ @pytest.fixture def charm_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): monkeypatch.setattr('sys.argv', ('hooks/install',)) - monkeypatch.setattr('ops._main._emit_charm_event', Mock()) + monkeypatch.setattr('ops._main._Manager._emit_charm_event', Mock()) monkeypatch.setattr('ops._main._Manager._setup_root_logging', Mock()) monkeypatch.setattr('ops.charm._evaluate_status', Mock()) monkeypatch.setenv('JUJU_CHARM_DIR', str(tmp_path)) diff --git a/testing/src/scenario/_consistency_checker.py b/testing/src/scenario/_consistency_checker.py index 1d21a60ac..33bcc4936 100644 --- a/testing/src/scenario/_consistency_checker.py +++ b/testing/src/scenario/_consistency_checker.py @@ -39,7 +39,7 @@ ) from .errors import InconsistentScenarioError -from .runtime import logger as scenario_logger +from ._runtime import logger as scenario_logger from .state import ( CharmType, PeerRelation, @@ -179,6 +179,11 @@ def check_event_consistency( # skip everything here. Perhaps in the future, custom events could # optionally include some sort of state metadata that made testing # consistency possible? + warnings.append( + "this is a custom event; if its name makes it look like a builtin one " + "(for example, a relation event, or a workload event), you might get some false-negative " + "consistency checks.", + ) return Results(errors, warnings) if event._is_relation_event: diff --git a/testing/src/scenario/_ops_main_mock.py b/testing/src/scenario/_ops_main_mock.py new file mode 100644 index 000000000..1f29c0a39 --- /dev/null +++ b/testing/src/scenario/_ops_main_mock.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. + +import dataclasses +import marshal +import re +import sys +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Sequence, Set + +import ops +import ops.jujucontext +import ops.storage + +from ops.framework import _event_regex +from ops._main import _Dispatcher, _Manager +from ops._main import logger as ops_logger + +from .errors import BadOwnerPath, NoObserverError +from .logger import logger as scenario_logger +from .mocking import _MockModelBackend +from .state import CharmType, StoredState, DeferredEvent + +if TYPE_CHECKING: # pragma: no cover + from .context import Context + from .state import State, _CharmSpec, _Event + +EVENT_REGEX = re.compile(_event_regex) +STORED_STATE_REGEX = re.compile( + r"((?P.*)\/)?(?P<_data_type_name>\D+)\[(?P.*)\]", +) + +logger = scenario_logger.getChild("ops_main_mock") + +# pyright: reportPrivateUsage=false + + +class UnitStateDB: + """Wraps the unit-state database with convenience methods for adjusting the state.""" + + def __init__(self, underlying_store: ops.storage.SQLiteStorage): + self._db = underlying_store + + def get_stored_states(self) -> FrozenSet["StoredState"]: + """Load any StoredState data structures from the db.""" + db = self._db + stored_states: Set[StoredState] = set() + for handle_path in db.list_snapshots(): + if not EVENT_REGEX.match(handle_path) and ( + match := STORED_STATE_REGEX.match(handle_path) + ): + stored_state_snapshot = db.load_snapshot(handle_path) + kwargs = match.groupdict() + sst = StoredState(content=stored_state_snapshot, **kwargs) + stored_states.add(sst) + + return frozenset(stored_states) + + def get_deferred_events(self) -> List["DeferredEvent"]: + """Load any DeferredEvent data structures from the db.""" + db = self._db + deferred: List[DeferredEvent] = [] + for handle_path in db.list_snapshots(): + if EVENT_REGEX.match(handle_path): + notices = db.notices(handle_path) + for handle, owner, observer in notices: + try: + snapshot_data = db.load_snapshot(handle) + except ops.storage.NoSnapshotError: + snapshot_data: Dict[str, Any] = {} + + event = DeferredEvent( + handle_path=handle, + owner=owner, + observer=observer, + snapshot_data=snapshot_data, + ) + deferred.append(event) + + return deferred + + def apply_state(self, state: "State"): + """Add DeferredEvent and StoredState from this State instance to the storage.""" + db = self._db + for event in state.deferred: + db.save_notice(event.handle_path, event.owner, event.observer) + try: + marshal.dumps(event.snapshot_data) + except ValueError as e: + raise ValueError( + f"unable to save the data for {event}, it must contain only simple types.", + ) from e + db.save_snapshot(event.handle_path, event.snapshot_data) + + for stored_state in state.stored_states: + db.save_snapshot(stored_state._handle_path, stored_state.content) + + +class Ops(_Manager): + """Class to manage stepping through ops setup, event emission and framework commit.""" + + def __init__( + self, + state: "State", + event: "_Event", + context: "Context[CharmType]", + charm_spec: "_CharmSpec[CharmType]", + juju_context: ops.jujucontext._JujuContext, + ): + self.state = state + self.event = event + self.context = context + self.charm_spec = charm_spec + self.store = None + + model_backend = _MockModelBackend( + state=state, + event=event, + context=context, + charm_spec=charm_spec, + juju_context=juju_context, + ) + + super().__init__( + self.charm_spec.charm_type, model_backend, juju_context=juju_context + ) + + def _load_charm_meta(self): + metadata = (self._charm_root / "metadata.yaml").read_text() + actions_meta = self._charm_root / "actions.yaml" + if actions_meta.exists(): + actions_metadata = actions_meta.read_text() + else: + actions_metadata = None + + return ops.CharmMeta.from_yaml(metadata, actions_metadata) + + def _setup_root_logging(self): + # Ops sets sys.excepthook to go to Juju's debug-log, but that's not + # useful in a testing context, so we reset it here. + super()._setup_root_logging() + sys.excepthook = sys.__excepthook__ + + def _make_storage(self, _: _Dispatcher): + # TODO: add use_juju_for_storage support + # TODO: Pass a charm_state_path that is ':memory:' when appropriate. + charm_state_path = self._charm_root / self._charm_state_path + storage = ops.storage.SQLiteStorage(charm_state_path) + logger.info("Copying input state to storage.") + self.store = UnitStateDB(storage) + self.store.apply_state(self.state) + return storage + + def _get_event_to_emit(self, event_name: str): + owner = ( + self._get_owner(self.charm, self.event.owner_path) + if self.event + else self.charm.on + ) + + try: + event_to_emit = getattr(owner, event_name) + except AttributeError: + ops_logger.debug("Event %s not defined for %s.", event_name, self.charm) + raise NoObserverError( + f"Cannot fire {event_name!r} on {owner}: " + f"invalid event (not on charm.on).", + ) + return event_to_emit + + @staticmethod + def _get_owner(root: Any, path: Sequence[str]) -> ops.ObjectEvents: + """Walk path on root to an ObjectEvents instance.""" + obj = root + for step in path: + try: + obj = getattr(obj, step) + except AttributeError: + raise BadOwnerPath( + f"event_owner_path {path!r} invalid: {step!r} leads to nowhere.", + ) + if not isinstance(obj, ops.ObjectEvents): + raise BadOwnerPath( + f"event_owner_path {path!r} invalid: does not lead to " + f"an ObjectEvents instance.", + ) + return obj + + def _close(self): + """Now that we're done processing this event, read the charm state and expose it.""" + logger.info("Copying storage to output state.") + assert self.store is not None + deferred = self.store.get_deferred_events() + stored_state = self.store.get_stored_states() + self.state = dataclasses.replace( + self.state, deferred=deferred, stored_states=stored_state + ) diff --git a/testing/src/scenario/runtime.py b/testing/src/scenario/_runtime.py similarity index 74% rename from testing/src/scenario/runtime.py rename to testing/src/scenario/_runtime.py index 3ad2fd0a2..2dbd683b2 100644 --- a/testing/src/scenario/runtime.py +++ b/testing/src/scenario/_runtime.py @@ -6,20 +6,15 @@ import copy import dataclasses -import marshal -import re import tempfile import typing from contextlib import contextmanager from pathlib import Path from typing import ( TYPE_CHECKING, - Any, Dict, - FrozenSet, List, Optional, - Set, Type, TypeVar, Union, @@ -37,17 +32,13 @@ PreCommitEvent, ) from ops.jujucontext import _JujuContext -from ops.storage import NoSnapshotError, SQLiteStorage -from ops.framework import _event_regex from ops._private.harness import ActionFailed from .errors import NoObserverError, UncaughtCharmError from .logger import logger as scenario_logger from .state import ( - DeferredEvent, PeerRelation, Relation, - StoredState, SubordinateRelation, ) @@ -56,114 +47,10 @@ from .state import CharmType, State, _CharmSpec, _Event logger = scenario_logger.getChild("runtime") -STORED_STATE_REGEX = re.compile( - r"((?P.*)\/)?(?P<_data_type_name>\D+)\[(?P.*)\]", -) -EVENT_REGEX = re.compile(_event_regex) RUNTIME_MODULE = Path(__file__).parent -class UnitStateDB: - """Represents the unit-state.db.""" - - def __init__(self, db_path: Union[Path, str]): - self._db_path = db_path - self._state_file = Path(self._db_path) - - def _open_db(self) -> SQLiteStorage: - """Open the db.""" - return SQLiteStorage(self._state_file) - - def get_stored_states(self) -> FrozenSet["StoredState"]: - """Load any StoredState data structures from the db.""" - - db = self._open_db() - - stored_states: Set[StoredState] = set() - for handle_path in db.list_snapshots(): - if not EVENT_REGEX.match(handle_path) and ( - match := STORED_STATE_REGEX.match(handle_path) - ): - stored_state_snapshot = db.load_snapshot(handle_path) - kwargs = match.groupdict() - sst = StoredState(content=stored_state_snapshot, **kwargs) - stored_states.add(sst) - - db.close() - return frozenset(stored_states) - - def get_deferred_events(self) -> List["DeferredEvent"]: - """Load any DeferredEvent data structures from the db.""" - - db = self._open_db() - - deferred: List[DeferredEvent] = [] - for handle_path in db.list_snapshots(): - if EVENT_REGEX.match(handle_path): - notices = db.notices(handle_path) - for handle, owner, observer in notices: - try: - snapshot_data = db.load_snapshot(handle) - except NoSnapshotError: - snapshot_data: Dict[str, Any] = {} - - event = DeferredEvent( - handle_path=handle, - owner=owner, - observer=observer, - snapshot_data=snapshot_data, - ) - deferred.append(event) - - db.close() - return deferred - - def apply_state(self, state: "State"): - """Add DeferredEvent and StoredState from this State instance to the storage.""" - db = self._open_db() - for event in state.deferred: - db.save_notice(event.handle_path, event.owner, event.observer) - try: - marshal.dumps(event.snapshot_data) - except ValueError as e: - raise ValueError( - f"unable to save the data for {event}, it must contain only simple types.", - ) from e - db.save_snapshot(event.handle_path, event.snapshot_data) - - for stored_state in state.stored_states: - db.save_snapshot(stored_state._handle_path, stored_state.content) - - db.close() - - -class _OpsMainContext: # type: ignore - """Context manager representing ops.main execution context. - - When entered, ops.main sets up everything up until the charm. - When .emit() is called, ops.main proceeds with emitting the event. - When exited, if .emit has not been called manually, it is called automatically. - """ - - def __init__(self): - self._has_emitted = False - - def __enter__(self): - pass - - def emit(self): - """Emit the event. - - Within the test framework, this only requires recording that it was emitted. - """ - self._has_emitted = True - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # noqa: U100 - if not self._has_emitted: - self.emit() - - class Runtime: """Charm runtime wrapper. @@ -305,8 +192,6 @@ def _wrap(charm_type: Type["CharmType"]) -> Type["CharmType"]: class WrappedEvents(charm_type.on.__class__): """The charm's event sources, but wrapped.""" - pass - WrappedEvents.__name__ = charm_type.on.__class__.__name__ class WrappedCharm(charm_type): @@ -388,28 +273,11 @@ def _virtual_charm_root(self): # charm_virtual_root is a tempdir typing.cast(tempfile.TemporaryDirectory, charm_virtual_root).cleanup() # type: ignore - @staticmethod - def _get_state_db(temporary_charm_root: Path): - charm_state_path = temporary_charm_root / ".unit-state.db" - return UnitStateDB(charm_state_path) - - def _initialize_storage(self, state: "State", temporary_charm_root: Path): - """Before we start processing this event, store the relevant parts of State.""" - store = self._get_state_db(temporary_charm_root) - store.apply_state(state) - - def _close_storage(self, state: "State", temporary_charm_root: Path): - """Now that we're done processing this event, read the charm state and expose it.""" - store = self._get_state_db(temporary_charm_root) - deferred = store.get_deferred_events() - stored_state = store.get_stored_states() - return dataclasses.replace(state, deferred=deferred, stored_states=stored_state) - @contextmanager def _exec_ctx(self, ctx: "Context"): """python 3.8 compatibility shim""" with self._virtual_charm_root() as temporary_charm_root: - with _capture_events( + with capture_events( include_deferred=ctx.capture_deferred_events, include_framework=ctx.capture_framework_events, ) as captured: @@ -442,9 +310,6 @@ def exec( logger.info(" - generating virtual charm root") with self._exec_ctx(context) as (temporary_charm_root, captured): - logger.info(" - initializing storage") - self._initialize_storage(state, temporary_charm_root) - logger.info(" - preparing env") env = self._get_event_env( state=state, @@ -453,8 +318,8 @@ def exec( ) juju_context = _JujuContext.from_dict(env) - logger.info(" - Entering ops.main (mocked).") - from .ops_main_mock import Ops # noqa: F811 + logger.info(" - entering ops.main (mocked)") + from ._ops_main_mock import Ops # noqa: F811 try: ops = Ops( @@ -467,13 +332,9 @@ def exec( ), juju_context=juju_context, ) - ops.setup() yield ops - # if the caller did not manually emit or commit: do that. - ops.finalize() - except (NoObserverError, ActionFailed): raise # propagate along except Exception as e: @@ -482,21 +343,18 @@ def exec( ) from e finally: - logger.info(" - Exited ops.main.") - - logger.info(" - closing storage") - output_state = self._close_storage(output_state, temporary_charm_root) + logger.info(" - exited ops.main") context.emitted_events.extend(captured) logger.info("event dispatched. done.") - context._set_output_state(output_state) + context._set_output_state(ops.state) _T = TypeVar("_T", bound=EventBase) @contextmanager -def _capture_events( +def capture_events( *types: Type[EventBase], include_framework: bool = False, include_deferred: bool = True, diff --git a/testing/src/scenario/context.py b/testing/src/scenario/context.py index 8087480f2..411beed9b 100644 --- a/testing/src/scenario/context.py +++ b/testing/src/scenario/context.py @@ -31,7 +31,6 @@ MetadataNotFoundError, ) from .logger import logger as scenario_logger -from .runtime import Runtime from .state import ( CharmType, CheckInfo, @@ -43,12 +42,14 @@ _CharmSpec, _Event, ) +from ._runtime import Runtime if TYPE_CHECKING: # pragma: no cover from ops._private.harness import ExecArgs - from .ops_main_mock import Ops + from ._ops_main_mock import Ops from .state import ( AnyJson, + CharmType, JujuLogLine, RelationBase, State, @@ -83,7 +84,6 @@ def __init__( self._state_in = state_in self._emitted: bool = False - self._wrapped_ctx = None self.ops: Ops[CharmType] | None = None @@ -115,10 +115,14 @@ def run(self) -> State: """ if self._emitted: raise AlreadyEmittedError("Can only run once.") + if not self.ops: + raise RuntimeError( + "you should __enter__ this context manager before running it", + ) self._emitted = True + self.ops.run() # wrap up Runtime.exec() so that we can gather the output state - assert self._wrapped_ctx is not None self._wrapped_ctx.__exit__(None, None, None) assert self._ctx._output_state is not None @@ -127,7 +131,8 @@ def run(self) -> State: def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # noqa: U100 if not self._emitted: logger.debug( - "user didn't emit the event within the context manager scope. Doing so implicitly upon exit...", + "user didn't emit the event within the context manager scope. " + "Doing so implicitly upon exit...", ) self.run() @@ -662,8 +667,8 @@ def run(self, event: _Event, state: State) -> State: if self.action_results is not None: self.action_results.clear() self._action_failure_message = None - with self._run(event=event, state=state) as manager: - manager.emit() + with self._run(event=event, state=state) as ops: + ops.run() # We know that the output state will have been set by this point, # so let the type checkers know that too. assert self._output_state is not None diff --git a/testing/src/scenario/ops_main_mock.py b/testing/src/scenario/ops_main_mock.py deleted file mode 100644 index 5e4846eba..000000000 --- a/testing/src/scenario/ops_main_mock.py +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. - -import inspect -import os -import pathlib -import sys -from typing import TYPE_CHECKING, Any, Generic, Optional, Sequence, Type, cast - -import ops.charm -import ops.framework -import ops.jujucontext -import ops.model -import ops.storage -from ops import CharmBase - -# use logger from ops._main so that juju_log will be triggered -from ops._main import CHARM_STATE_FILE, _Dispatcher, _get_event_args -from ops._main import logger as ops_logger -from ops.charm import CharmMeta -from ops.log import setup_root_logging - -from .errors import BadOwnerPath, NoObserverError -from .state import CharmType - -if TYPE_CHECKING: # pragma: no cover - from .context import Context - from .state import State, _CharmSpec, _Event - -# pyright: reportPrivateUsage=false - - -def _get_owner(root: Any, path: Sequence[str]) -> ops.ObjectEvents: - """Walk path on root to an ObjectEvents instance.""" - obj = root - for step in path: - try: - obj = getattr(obj, step) - except AttributeError: - raise BadOwnerPath( - f"event_owner_path {path!r} invalid: {step!r} leads to nowhere.", - ) - if not isinstance(obj, ops.ObjectEvents): - raise BadOwnerPath( - f"event_owner_path {path!r} invalid: does not lead to " - f"an ObjectEvents instance.", - ) - return obj - - -def _emit_charm_event( - charm: "CharmBase", - event_name: str, - juju_context: ops.jujucontext._JujuContext, - event: Optional["_Event"] = None, -): - """Emits a charm event based on a Juju event name. - - Args: - charm: A charm instance to emit an event from. - event_name: A Juju event name to emit on a charm. - event: Event to emit. - juju_context: Juju context to use for the event. - """ - owner = _get_owner(charm, event.owner_path) if event else charm.on - - try: - event_to_emit = getattr(owner, event_name) - except AttributeError: - ops_logger.debug("Event %s not defined for %s.", event_name, charm) - raise NoObserverError( - f"Cannot fire {event_name!r} on {owner}: " - f"invalid event (not on charm.on).", - ) - - args, kwargs = _get_event_args(charm, event_to_emit, juju_context) - ops_logger.debug("Emitting Juju event %s.", event_name) - event_to_emit.emit(*args, **kwargs) - - -def setup_framework( - charm_dir: pathlib.Path, - state: "State", - event: "_Event", - context: "Context[CharmType]", - charm_spec: "_CharmSpec[CharmType]", - juju_context: Optional[ops.jujucontext._JujuContext] = None, -): - from .mocking import _MockModelBackend - - if juju_context is None: - juju_context = ops.jujucontext._JujuContext.from_dict(os.environ) - model_backend = _MockModelBackend( - state=state, - event=event, - context=context, - charm_spec=charm_spec, - juju_context=juju_context, - ) - setup_root_logging(model_backend, debug=juju_context.debug) - # ops sets sys.excepthook to go to Juju's debug-log, but that's not useful - # in a testing context, so reset it. - sys.excepthook = sys.__excepthook__ - ops_logger.debug( - "Operator Framework %s up and running.", - ops.__version__, - ) - - metadata = (charm_dir / "metadata.yaml").read_text() - actions_meta = charm_dir / "actions.yaml" - if actions_meta.exists(): - actions_metadata = actions_meta.read_text() - else: - actions_metadata = None - - meta = CharmMeta.from_yaml(metadata, actions_metadata) - - # ops >= 2.10 - if inspect.signature(ops.model.Model).parameters.get("broken_relation_id"): - # If we are in a RelationBroken event, we want to know which relation is - # broken within the model, not only in the event's `.relation` attribute. - broken_relation_id = ( - event.relation.id # type: ignore - if event.name.endswith("_relation_broken") - else None - ) - - model = ops.model.Model( - meta, - model_backend, - broken_relation_id=broken_relation_id, - ) - else: - ops_logger.warning( - "It looks like this charm is using an older `ops` version. " - "You may experience weirdness. Please update ops.", - ) - model = ops.model.Model(meta, model_backend) - - charm_state_path = charm_dir / CHARM_STATE_FILE - - # TODO: add use_juju_for_storage support - store = ops.storage.SQLiteStorage(charm_state_path) - framework = ops.Framework(store, charm_dir, meta, model) - framework.set_breakpointhook() - return framework - - -def setup_charm( - charm_class: Type[ops.CharmBase], framework: ops.Framework, dispatcher: _Dispatcher -): - sig = inspect.signature(charm_class) - sig.bind(framework) # signature check - - charm = charm_class(framework) - dispatcher.ensure_event_links(charm) - return charm - - -def setup( - state: "State", - event: "_Event", - context: "Context[CharmType]", - charm_spec: "_CharmSpec[CharmType]", - juju_context: Optional[ops.jujucontext._JujuContext] = None, -): - """Setup dispatcher, framework and charm objects.""" - charm_class = charm_spec.charm_type - if juju_context is None: - juju_context = ops.jujucontext._JujuContext.from_dict(os.environ) - charm_dir = juju_context.charm_dir - - dispatcher = _Dispatcher(charm_dir, juju_context) - dispatcher.run_any_legacy_hook() - - framework = setup_framework( - charm_dir, state, event, context, charm_spec, juju_context - ) - charm = setup_charm(charm_class, framework, dispatcher) - return dispatcher, framework, charm - - -class Ops(Generic[CharmType]): - """Class to manage stepping through ops setup, event emission and framework commit.""" - - def __init__( - self, - state: "State", - event: "_Event", - context: "Context[CharmType]", - charm_spec: "_CharmSpec[CharmType]", - juju_context: Optional[ops.jujucontext._JujuContext] = None, - ): - self.state = state - self.event = event - self.context = context - self.charm_spec = charm_spec - if juju_context is None: - juju_context = ops.jujucontext._JujuContext.from_dict(os.environ) - self.juju_context = juju_context - - # set by setup() - self.dispatcher: Optional[_Dispatcher] = None - self.framework: Optional[ops.Framework] = None - self.charm: Optional["CharmType"] = None - - self._has_setup = False - self._has_emitted = False - self._has_committed = False - - def setup(self): - """Setup framework, charm and dispatcher.""" - self._has_setup = True - self.dispatcher, self.framework, self.charm = setup( - self.state, - self.event, - self.context, - self.charm_spec, - self.juju_context, - ) - - def emit(self): - """Emit the event on the charm.""" - if not self._has_setup: - raise RuntimeError("should .setup() before you .emit()") - self._has_emitted = True - - dispatcher = cast(_Dispatcher, self.dispatcher) - charm = cast(CharmBase, self.charm) - framework = cast(ops.Framework, self.framework) - - try: - if not dispatcher.is_restricted_context(): - framework.reemit() - - _emit_charm_event( - charm, dispatcher.event_name, self.juju_context, self.event - ) - - except Exception: - framework.close() - raise - - def commit(self): - """Commit the framework and teardown.""" - if not self._has_emitted: - raise RuntimeError("should .emit() before you .commit()") - - framework = cast(ops.Framework, self.framework) - charm = cast(CharmBase, self.charm) - - # emit collect-status events - ops.charm._evaluate_status(charm) - - self._has_committed = True - - try: - framework.commit() - finally: - framework.close() - - def finalize(self): - """Step through all non-manually-called procedures and run them.""" - if not self._has_setup: - self.setup() - if not self._has_emitted: - self.emit() - if not self._has_committed: - self.commit() diff --git a/testing/tests/test_context_on.py b/testing/tests/test_context_on.py index 32759fd49..402de45ce 100644 --- a/testing/tests/test_context_on.py +++ b/testing/tests/test_context_on.py @@ -1,4 +1,5 @@ import copy +import typing import ops import pytest @@ -35,13 +36,13 @@ class ContextCharm(ops.CharmBase): - def __init__(self, framework): + def __init__(self, framework: ops.Framework): super().__init__(framework) - self.observed = [] + self.observed: typing.List[ops.EventBase] = [] for event in self.on.events().values(): framework.observe(event, self._on_event) - def _on_event(self, event): + def _on_event(self, event: ops.EventBase): self.observed.append(event) @@ -60,7 +61,7 @@ def _on_event(self, event): ("leader_elected", ops.LeaderElectedEvent), ], ) -def test_simple_events(event_name, event_kind): +def test_simple_events(event_name: str, event_kind: typing.Type[ops.EventBase]): ctx = scenario.Context(ContextCharm, meta=META, actions=ACTIONS) # These look like: # ctx.run(ctx.on.install(), state) diff --git a/testing/tests/test_e2e/test_stored_state.py b/testing/tests/test_e2e/test_stored_state.py index b4cb7c7a9..1f26e0aaa 100644 --- a/testing/tests/test_e2e/test_stored_state.py +++ b/testing/tests/test_e2e/test_stored_state.py @@ -1,6 +1,6 @@ import pytest -from ops.charm import CharmBase -from ops.framework import Framework + +import ops from ops.framework import StoredState as ops_storedstate from scenario.state import State, StoredState @@ -9,21 +9,21 @@ @pytest.fixture(scope="function") def mycharm(): - class MyCharm(CharmBase): + class MyCharm(ops.CharmBase): META = {"name": "mycharm"} _read = {} _stored = ops_storedstate() _stored2 = ops_storedstate() - def __init__(self, framework: Framework): + def __init__(self, framework: ops.Framework): super().__init__(framework) self._stored.set_default(foo="bar", baz={12: 142}) self._stored2.set_default(foo="bar", baz={12: 142}) for evt in self.on.events().values(): self.framework.observe(evt, self._on_event) - def _on_event(self, event): + def _on_event(self, _: ops.EventBase): self._read["foo"] = self._stored.foo self._read["baz"] = self._stored.baz diff --git a/testing/tests/test_emitted_events_util.py b/testing/tests/test_emitted_events_util.py index f22a69586..0714562f5 100644 --- a/testing/tests/test_emitted_events_util.py +++ b/testing/tests/test_emitted_events_util.py @@ -2,8 +2,8 @@ from ops.framework import CommitEvent, EventBase, EventSource, PreCommitEvent from scenario import State -from scenario.runtime import _capture_events from scenario.state import _Event +from scenario._runtime import capture_events from .helpers import trigger @@ -32,7 +32,7 @@ def _on_foo(self, e): def test_capture_custom_evt_nonspecific_capture_include_fw_evts(): - with _capture_events(include_framework=True) as emitted: + with capture_events(include_framework=True) as emitted: trigger(State(), "start", MyCharm, meta=MyCharm.META) assert len(emitted) == 5 @@ -44,7 +44,7 @@ def test_capture_custom_evt_nonspecific_capture_include_fw_evts(): def test_capture_juju_evt(): - with _capture_events() as emitted: + with capture_events() as emitted: trigger(State(), "start", MyCharm, meta=MyCharm.META) assert len(emitted) == 2 @@ -54,7 +54,7 @@ def test_capture_juju_evt(): def test_capture_deferred_evt(): # todo: this test should pass with ops < 2.1 as well - with _capture_events() as emitted: + with capture_events() as emitted: trigger( State(deferred=[_Event("foo").deferred(handler=MyCharm._on_foo)]), "start", @@ -70,7 +70,7 @@ def test_capture_deferred_evt(): def test_capture_no_deferred_evt(): # todo: this test should pass with ops < 2.1 as well - with _capture_events(include_deferred=False) as emitted: + with capture_events(include_deferred=False) as emitted: trigger( State(deferred=[_Event("foo").deferred(handler=MyCharm._on_foo)]), "start", diff --git a/testing/tests/test_runtime.py b/testing/tests/test_runtime.py index b303fadf8..79e465636 100644 --- a/testing/tests/test_runtime.py +++ b/testing/tests/test_runtime.py @@ -2,28 +2,28 @@ from tempfile import TemporaryDirectory import pytest -from ops.charm import CharmBase, CharmEvents -from ops.framework import EventBase + +import ops from scenario import Context -from scenario.runtime import Runtime, UncaughtCharmError from scenario.state import Relation, State, _CharmSpec, _Event +from scenario._runtime import Runtime, UncaughtCharmError def charm_type(): - class _CharmEvents(CharmEvents): + class _CharmEvents(ops.CharmEvents): pass - class MyCharm(CharmBase): - on = _CharmEvents() + class MyCharm(ops.CharmBase): + on = _CharmEvents() # type: ignore _event = None - def __init__(self, framework): + def __init__(self, framework: ops.Framework): super().__init__(framework) for evt in self.on.events().values(): self.framework.observe(evt, self._catchall) - def _catchall(self, e): + def _catchall(self, e: ops.EventBase): if self._event: return MyCharm._event = e @@ -40,7 +40,7 @@ def test_event_emission(): my_charm_type = charm_type() - class MyEvt(EventBase): + class MyEvt(ops.EventBase): pass my_charm_type.on.define_event("bar", MyEvt) @@ -56,8 +56,8 @@ class MyEvt(EventBase): state=State(), event=_Event("bar"), context=Context(my_charm_type, meta=meta), - ): - pass + ) as manager: + manager.run() assert my_charm_type._event assert isinstance(my_charm_type._event, MyEvt) @@ -109,7 +109,7 @@ def test_env_clean_on_charm_error(): event=_Event("box_relation_changed", relation=rel), context=Context(my_charm_type, meta=meta), ) as manager: - assert manager.juju_context.remote_app_name == remote_name + assert manager._juju_context.remote_app_name == remote_name assert "JUJU_REMOTE_APP" not in os.environ _ = 1 / 0 # raise some error # Ensure that some other error didn't occur (like AssertionError!).