diff --git a/src/xworkflows/base.py b/src/xworkflows/base.py index bac3588..4e0d10e 100644 --- a/src/xworkflows/base.py +++ b/src/xworkflows/base.py @@ -50,6 +50,9 @@ def __str__(self): def __repr__(self): return '<%s: %r>' % (self.__class__.__name__, self.name) + def __eq__(self, other): + return self.name == other.name + class StateList(object): """A list of states.""" @@ -358,15 +361,17 @@ def __init__(self, instance, field_name, transition, workflow, @property def current_state(self): - return getattr(self.instance, self.field_name) + current_state = getattr(self.instance, self.field_name) + if isinstance(current_state, StateWrapper): + current_state = current_state.state + return current_state def _pre_transition_checks(self): """Run the pre-transition checks.""" - current_state = getattr(self.instance, self.field_name) - if current_state not in self.transition.source: + if self.current_state not in self.transition.source: raise InvalidTransitionError( "Transition '%s' isn't available from state '%s'." % - (self.transition.name, current_state.name)) + (self.transition.name, self.current_state.name)) for check in self._filter_hooks(HOOK_CHECK): if not check(self.instance):