diff --git a/openhtf/core/test_descriptor.py b/openhtf/core/test_descriptor.py
index 649201253..266c26047 100644
--- a/openhtf/core/test_descriptor.py
+++ b/openhtf/core/test_descriptor.py
@@ -46,6 +46,7 @@
 from openhtf.core import test_record as htf_test_record
 from openhtf.core import test_state
 from openhtf.core.dut_id import DutIdentifier
+from openhtf.plugs import PlugManager
 from openhtf.util import configuration
 from openhtf.util import console_output
@@ -56,14 +57,14 @@
 _LOG = logging.getLogger(__name__)
 
 CONF.declare(
-    'capture_source',
-    description=textwrap.dedent(
-        """Whether to capture the source of phases and the test module. This
-    defaults to False since this potentially reads many files and makes large
-    string copies.
+  'capture_source',
+  description=textwrap.dedent(
+      """Whether to capture the source of phases and the test module. This
+defaults to False since this potentially reads many files and makes large
+string copies.
 
-    Set to 'true' if you want to capture your test's source."""),
-    default_value=False)
+Set to 'true' if you want to capture your test's source."""),
+  default_value=False)
 
 
 class MeasurementNotFoundError(Exception):
@@ -103,19 +104,19 @@ def create_arg_parser(add_help: bool = False) -> argparse.ArgumentParser:
   """
   parser = argparse.ArgumentParser(
-      'OpenHTF-based testing',
-      parents=[
-          CONF.ARG_PARSER,
-          console_output.ARG_PARSER,
-          logs.ARG_PARSER,
-          phase_executor.ARG_PARSER,
-      ],
-      add_help=add_help)
+    'OpenHTF-based testing',
+    parents=[
+      CONF.ARG_PARSER,
+      console_output.ARG_PARSER,
+      logs.ARG_PARSER,
+      phase_executor.ARG_PARSER,
+    ],
+    add_help=add_help)
   parser.add_argument(
-      '--config-help',
-      action='store_true',
-      help='Instead of executing the test, simply print all available config '
-      'keys and their description strings.')
+    '--config-help',
+    action='store_true',
+    help='Instead of executing the test, simply print all available config '
+    'keys and their description strings.')
 
   return parser
@@ -141,17 +142,19 @@ def PhaseTwo(test):
 
   DEFAULT_SIGINT_HANDLER = None
 
   def __init__(self, *nodes: phase_descriptor.PhaseCallableOrNodeT,
+               plug_manager: Optional[PlugManager] = None,
                **metadata: Any):
     # Some sanity checks on special metadata keys we automatically fill in.
     if 'config' in metadata:
       raise KeyError(
-          'Invalid metadata key "config", it will be automatically populated.')
+        'Invalid metadata key "config", it will be automatically populated.')
 
     self.created_time_millis = util.time_millis()
     self.last_run_time_millis = None
     self._test_options = TestOptions()
     self._lock = threading.Lock()
     self._executor = None
+    self._plug_manager = plug_manager
     # TODO(arsharma): Drop _flatten at some point.
     sequence = phase_collections.PhaseSequence(nodes)
     self._test_desc = TestDescriptor(sequence,
@@ -161,9 +164,9 @@ def __init__(self, *nodes: phase_descriptor.PhaseCallableOrNodeT,
     if CONF.capture_source:
       # Copy the phases with the real CodeInfo for them.
       self._test_desc.phase_sequence = (
-          self._test_desc.phase_sequence.load_code_info())
+        self._test_desc.phase_sequence.load_code_info())
       self._test_desc.code_info = (
-          htf_test_record.CodeInfo.for_module_from_stack(levels_up=2))
+        htf_test_record.CodeInfo.for_module_from_stack(levels_up=2))
 
     # Make sure configure() gets called at least once before Execute().  The
     # user might call configure() again to override options, but we don't want
@@ -253,7 +256,8 @@ def configure(self, **kwargs: Any) -> None:
   def handle_sig_int(cls, signalnum: Optional[int], handler: Any) -> None:
     """Handle the SIGINT callback."""
     if not cls.TEST_INSTANCES:
-      cls.DEFAULT_SIGINT_HANDLER(signalnum, handler)  # pylint: disable=not-callable # pytype: disable=not-callable
+      cls.DEFAULT_SIGINT_HANDLER(signalnum,
+                                 handler)  # pylint: disable=not-callable # pytype: disable=not-callable
       return
 
     _LOG.error('Received SIGINT, stopping all tests.')
@@ -297,10 +301,10 @@ def execute(self,
       InvalidTestStateError: if this test is already being executed.
     """
     phase_descriptor.check_for_duplicate_results(
-        self._test_desc.phase_sequence.all_phases(),
-        self._test_options.diagnosers)
+      self._test_desc.phase_sequence.all_phases(),
+      self._test_options.diagnosers)
     phase_collections.check_for_duplicate_subtest_names(
-        self._test_desc.phase_sequence)
+      self._test_desc.phase_sequence)
     # Lock this section so we don't .stop() the executor between instantiating
     # it and .Start()'ing it, doing so does weird things to the executor state.
     with (self._lock):
@@ -336,11 +340,13 @@ def trigger_phase(test):
         trigger.code_info = htf_test_record.CodeInfo.for_function(trigger.func)
 
       self._executor = test_executor.TestExecutor(
-          self._test_desc,
-          self.make_uid(),
-          trigger,
-          self._test_options,
-          run_with_profiling=profile_filename is not None)
+        self._test_desc,
+        self.make_uid(),
+        trigger,
+        self._test_options,
+        run_with_profiling=profile_filename is not None,
+        plug_manager=self._plug_manager,
+      )
 
       _LOG.info('Executing test: %s', self.descriptor.code_info.name)
       self.TEST_INSTANCES[self.uid] = self
@@ -377,21 +383,21 @@ def trigger_phase(test):
         else:
           colors = collections.defaultdict(lambda: colorama.Style.BRIGHT)
           colors[htf_test_record.Outcome.PASS] = ''.join(
-              (colorama.Style.BRIGHT, colorama.Fore.GREEN))  # pytype: disable=wrong-arg-types
+            (colorama.Style.BRIGHT, colorama.Fore.GREEN))  # pytype: disable=wrong-arg-types
           colors[htf_test_record.Outcome.FAIL] = ''.join(
-              (colorama.Style.BRIGHT, colorama.Fore.RED))  # pytype: disable=wrong-arg-types
+            (colorama.Style.BRIGHT, colorama.Fore.RED))  # pytype: disable=wrong-arg-types
           msg_template = (
-              'test: {name} outcome: {color}{outcome}{marginal}{rst}')
+            'test: {name} outcome: {color}{outcome}{marginal}{rst}')
           console_output.banner_print(
-              msg_template.format(
-                  name=final_state.test_record.metadata['test_name'],
-                  color=(colorama.Fore.YELLOW
-                         if final_state.test_record.marginal else
-                         colors[final_state.test_record.outcome]),
-                  outcome=final_state.test_record.outcome.name,
-                  marginal=(' (MARGINAL)'
-                            if final_state.test_record.marginal else ''),
-                  rst=colorama.Style.RESET_ALL))
+            msg_template.format(
+              name=final_state.test_record.metadata['test_name'],
+              color=(colorama.Fore.YELLOW
+                     if final_state.test_record.marginal else
+                     colors[final_state.test_record.outcome]),
+              outcome=final_state.test_record.outcome.name,
+              marginal=(' (MARGINAL)'
+                        if final_state.test_record.marginal else ''),
+              rst=colorama.Style.RESET_ALL))
     finally:
       del self.TEST_INSTANCES[self.uid]
       self._executor.close()
@@ -420,7 +426,7 @@ class TestOptions(object):
   name = attr.ib(type=Text, default='openhtf_test')
   output_callbacks = attr.ib(
-      type=List[Callable[[htf_test_record.TestRecord], None]], factory=list)
+    type=List[Callable[[htf_test_record.TestRecord], None]], factory=list)
   failure_exceptions = attr.ib(type=List[Type[Exception]], factory=list)
   default_dut_id = attr.ib(type=Text, default='UNKNOWN_DUT')
   stop_on_first_failure = attr.ib(type=bool, default=False)
@@ -574,7 +580,7 @@ def attach_from_file(
       IOError: Raised if the given filename couldn't be opened.
     """
     self._running_phase_state.attach_from_file(
-        filename, name=name, mimetype=mimetype)
+      filename, name=name, mimetype=mimetype)
 
   def get_measurement(
       self,
@@ -611,7 +617,7 @@ def get_measurement_strict(
     measurement = self._running_test_state.get_measurement(measurement_name)
     if measurement is None:
       raise MeasurementNotFoundError(
-          f'Failed to find test measurement {measurement_name}')
+        f'Failed to find test measurement {measurement_name}')
     return measurement
 
   def get_attachment(
diff --git a/openhtf/core/test_executor.py b/openhtf/core/test_executor.py
index 6457b6f9e..3ee5390ac 100644
--- a/openhtf/core/test_executor.py
+++ b/openhtf/core/test_executor.py
@@ -34,6 +34,7 @@
 from openhtf.core import phase_nodes
 from openhtf.core import test_record
 from openhtf.core import test_state
+from openhtf.plugs import PlugManager
 from openhtf.util import configuration
 from openhtf.util import threads
@@ -96,10 +97,11 @@ def __init__(self, test_descriptor: 'test_descriptor.TestDescriptor',
                execution_uid: Text,
                test_start: Optional[phase_descriptor.PhaseDescriptor],
                test_options: 'test_descriptor.TestOptions',
-               run_with_profiling: bool):
+               run_with_profiling: bool, plug_manager: Optional[PlugManager] = None):
     super(TestExecutor, self).__init__(
         name='TestExecutorThread', run_with_profiling=run_with_profiling)
     self.test_state = None  # type: Optional[test_state.TestState]
+    self._plug_manager = plug_manager
 
     self._test_descriptor = test_descriptor
     self._test_start = test_start
@@ -201,7 +203,9 @@ def _thread_proc(self) -> None:
     try:
       # Top level steps required to run a single iteration of the Test.
       self.test_state = test_state.TestState(self._test_descriptor, self.uid,
-                                             self._test_options)
+                                             self._test_options,
+                                             plug_manager=self._plug_manager
+                                             )
       phase_exec = phase_executor.PhaseExecutor(self.test_state)
 
       # Any access to self._exit_stacks must be done while holding this lock.
diff --git a/openhtf/core/test_state.py b/openhtf/core/test_state.py
index 673b77eb9..80d74537c 100644
--- a/openhtf/core/test_state.py
+++ b/openhtf/core/test_state.py
@@ -160,7 +160,8 @@ class Status(enum.Enum):
 
   def __init__(self, test_desc: 'test_descriptor.TestDescriptor',
                execution_uid: Text,
-               test_options: 'test_descriptor.TestOptions'):
+               test_options: 'test_descriptor.TestOptions',
+               plug_manager: Optional[plugs.PlugManager] = None):
     """Initializer.
 
     Args:
@@ -184,7 +185,11 @@ def __init__(self, test_desc: 'test_descriptor.TestDescriptor',
     logs.initialize_record_handler(execution_uid, self.test_record,
                                    self.notify_update)
     self.state_logger = logs.get_record_logger_for(execution_uid)
-    self.plug_manager = plugs.PlugManager(test_desc.plug_types,
+    if plug_manager is not None:
+      self.plug_manager = plug_manager
+      self.plug_manager.initialize_plugs(test_desc.plug_types)
+    else:
+      self.plug_manager = plugs.PlugManager(test_desc.plug_types,
                                           self.state_logger)
     self.diagnoses_manager = diagnoses_lib.DiagnosesManager(
         self.state_logger.getChild('diagnoses'))
diff --git a/openhtf/plugs/__init__.py b/openhtf/plugs/__init__.py
index f3dfbc0be..90dbf81e5 100644
--- a/openhtf/plugs/__init__.py
+++ b/openhtf/plugs/__init__.py
@@ -37,10 +37,10 @@
 _BASE_PLUGS_LOG = base_plugs._LOG  # pylint: disable=protected-access
 
 CONF.declare(
-    'plug_teardown_timeout_s',
-    default_value=0,
-    description='Timeout (in seconds) for each plug tearDown function if > 0; '
-    'otherwise, will wait an unlimited time.')
+  'plug_teardown_timeout_s',
+  default_value=0,
+  description='Timeout (in seconds) for each plug tearDown function if > 0; '
+  'otherwise, will wait an unlimited time.')
 
 # TODO(arsharma): Remove this aliases when users have moved to using the core
 # library.
@@ -89,8 +89,8 @@ def plug(
   if not (isinstance(a_plug, base_plugs.PlugPlaceholder) or
           issubclass(a_plug, base_plugs.BasePlug)):
     raise base_plugs.InvalidPlugError(
-        'Plug %s is not a subclass of base_plugs.BasePlug nor a placeholder '
-        'for one' % a_plug)
+      'Plug %s is not a subclass of base_plugs.BasePlug nor a placeholder '
+      'for one' % a_plug)
 
   def result(
       func: 'phase_descriptor.PhaseT') -> 'phase_descriptor.PhaseDescriptor':
@@ -114,8 +114,8 @@ def result(
                                      (duplicates, func))
 
     phase.plugs.extend([
-        base_plugs.PhasePlug(name, a_plug, update_kwargs=update_kwargs)
-        for name, a_plug in plugs_map.items()
+      base_plugs.PhasePlug(name, a_plug, update_kwargs=update_kwargs)
+      for name, a_plug in plugs_map.items()
     ])
     return phase
@@ -136,7 +136,7 @@ def _thread_proc(self) -> None:
       # Including the stack trace from ThreadTerminationErrors received when
      # killed.
       _LOG.warning(
-          'Exception calling tearDown on %s:', self._plug, exc_info=True)
+        'Exception calling tearDown on %s:', self._plug, exc_info=True)
 
 
 PlugT = TypeVar('PlugT', bound=base_plugs.BasePlug)
@@ -164,30 +164,32 @@ class PlugManager(object):
 
   def __init__(self,
                plug_types: Optional[Set[Type[base_plugs.BasePlug]]] = None,
-               record_logger: Optional[logging.Logger] = None):
+               record_logger: Optional[logging.Logger] = None
+               ):
     self._plug_types = plug_types or set()
     for plug_type in self._plug_types:
       if isinstance(plug_type, base_plugs.PlugPlaceholder):
         raise base_plugs.InvalidPlugError(
-            'Plug {} is a placeholder, replace it using with_plugs().'.format(
-                plug_type))
+          'Plug {} is a placeholder, replace it using with_plugs().'.format(
+            plug_type))
     self._plugs_by_type = {}
     self._plugs_by_name = {}
     self._plug_descriptors = {}
+    self._unmanaged_plugs = {}
     if not record_logger:
       record_logger = _LOG
     self.logger = record_logger.getChild('plug')
 
   def as_base_types(self) -> Dict[Text, Any]:
     return {
-        'plug_descriptors': {
-            name: attr.asdict(descriptor)
-            for name, descriptor in self._plug_descriptors.items()
-        },
-        'plug_states': {
-            name: data.convert_to_base_types(plug)
-            for name, plug in self._plugs_by_name.items()
-        },
+      'plug_descriptors': {
+        name: attr.asdict(descriptor)
+        for name, descriptor in self._plug_descriptors.items()
+      },
+      'plug_states': {
+        name: data.convert_to_base_types(plug)
+        for name, plug in self._plugs_by_name.items()
+      },
     }
 
   def _make_plug_descriptor(
@@ -209,12 +211,19 @@ def get_plug_mro(self, plug_type: Type[base_plugs.BasePlug]) -> List[Text]:
     """
     ignored_classes = (base_plugs.BasePlug, base_plugs.FrontendAwareBasePlug)
     return [
-        self.get_plug_name(base_class)  # pylint: disable=g-complex-comprehension
-        for base_class in plug_type.mro()
-        if (issubclass(base_class, base_plugs.BasePlug) and
-            base_class not in ignored_classes)
+      self.get_plug_name(base_class)  # pylint: disable=g-complex-comprehension
+      for base_class in plug_type.mro()
+      if (issubclass(base_class, base_plugs.BasePlug) and
+          base_class not in ignored_classes)
     ]
 
+  def add_non_managed_plug(self, plug):
+    t = type(plug)
+    if t in self._unmanaged_plugs:
+      raise Exception(f"Plug of type {t} already exists")
+    self._unmanaged_plugs[t] = plug
+    self.update_plug(t, plug)
+
   def get_plug_name(self, plug_type: Type[base_plugs.BasePlug]) -> Text:
     """Returns the plug's name, which is the class name and module.
@@ -240,6 +249,9 @@ def initialize_plugs(
     """
     types = plug_types if plug_types is not None else self._plug_types
     for plug_type in types:
+      if plug_type in self._unmanaged_plugs:
+        continue
+
       # Create a logger for this plug. All plug loggers go under the 'plug'
       # sub-logger in the logger hierarchy.
       plug_logger = self.logger.getChild(plug_type.__name__)
@@ -248,12 +260,12 @@
       try:
         if not issubclass(plug_type, base_plugs.BasePlug):
           raise base_plugs.InvalidPlugError(
-              'Plug type "{}" is not an instance of base_plugs.BasePlug'.format(
-                  plug_type))
+            'Plug type "{}" is not an instance of base_plugs.BasePlug'.format(
+              plug_type))
         if plug_type.logger != _BASE_PLUGS_LOG:
           # They put a logger attribute on the class itself, overriding ours.
           raise base_plugs.InvalidPlugError(
-              'Do not override "logger" in your plugs.', plug_type)
+            'Do not override "logger" in your plugs.', plug_type)
         # Override the logger so that __init__'s logging goes into the record.
         plug_type.logger = plug_logger
@@ -267,7 +279,7 @@
         # it set.
         if plug_instance.logger != _BASE_PLUGS_LOG:
           raise base_plugs.InvalidPlugError(
-              'Do not set "self.logger" in __init__ in your plugs', plug_type)
+            'Do not set "self.logger" in __init__ in your plugs', plug_type)
         else:
           # Now the instance has its own copy of the test logger.
           plug_instance.logger = plug_logger
@@ -319,7 +331,10 @@ def provide_plugs(
       self, plug_name_map: Iterable[Tuple[Text, Type[base_plugs.BasePlug]]]
   ) -> Dict[Text, base_plugs.BasePlug]:
     """Provide the requested plugs [(name, type),] as {name: plug instance}."""
-    return {name: self._plugs_by_type[cls] for name, cls in plug_name_map}
+    try:
+      return {name: self._plugs_by_type[cls] for name, cls in plug_name_map}
+    except Exception as e:
+      raise e
 
   def tear_down_plugs(self) -> None:
     """Call tearDown() on all instantiated plugs.
@@ -333,6 +348,8 @@
     """
     _LOG.debug('Tearing down all plugs.')
     for plug_type, plug_instance in self._plugs_by_type.items():
+      if plug_type in self._unmanaged_plugs:
+        continue
       if plug_instance.uses_base_tear_down():
         name = '' % plug_type
       else:
        name = '' % plug_type
@@ -340,8 +357,8 @@
       thread = _PlugTearDownThread(plug_instance, name=name)
       thread.start()
       timeout_s = (
-          CONF.plug_teardown_timeout_s
-          if CONF.plug_teardown_timeout_s else None)
+        CONF.plug_teardown_timeout_s
+        if CONF.plug_teardown_timeout_s else None)
       thread.join(timeout_s)
       if thread.is_alive():
         thread.kill()
@@ -349,6 +366,9 @@
                             plug_instance)
     self._plugs_by_type.clear()
     self._plugs_by_name.clear()
+    # Re-add the un-managed plugs for the next test
+    for plug_type, plug_instance in self._unmanaged_plugs.items():
+      self.update_plug(plug_type, plug_instance)
 
   def wait_for_plug_update(
       self, plug_name: Text, remote_state: Dict[Text, Any],
@@ -371,12 +391,12 @@
 
     if plug_instance is None:
       raise base_plugs.InvalidPlugError(
-          'Cannot wait on unknown plug "{}".'.format(plug_name))
+        'Cannot wait on unknown plug "{}".'.format(plug_name))
 
     if not isinstance(plug_instance, base_plugs.FrontendAwareBasePlug):
       raise base_plugs.InvalidPlugError(
-          'Cannot wait on a plug {} that is not an subclass '
-          'of FrontendAwareBasePlug.'.format(plug_name))
+        'Cannot wait on a plug {} that is not an subclass '
+        'of FrontendAwareBasePlug.'.format(plug_name))
 
     state, update_event = plug_instance.asdict_with_event()
     if state != remote_state:
@@ -388,6 +408,6 @@
   def get_frontend_aware_plug_names(self) -> List[Text]:
     """Returns the names of frontend-aware plugs."""
     return [
-        name for name, plug in self._plugs_by_name.items()
-        if isinstance(plug, base_plugs.FrontendAwareBasePlug)
+      name for name, plug in self._plugs_by_name.items()
+      if isinstance(plug, base_plugs.FrontendAwareBasePlug)
     ]
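
Below is a minimal usage sketch (not part of the diff) of the plug-manager injection this change introduces: a PlugManager built outside the test, an un-managed plug registered via the new add_non_managed_plug(), and the manager passed to Test() through the new plug_manager keyword so the same plug instance is reused across executions. SharedSerialPlug and the DUT ids are hypothetical names for illustration only; everything else follows the API shown in the diff.

import openhtf as htf
from openhtf.core import base_plugs
from openhtf.plugs import PlugManager


class SharedSerialPlug(base_plugs.BasePlug):
  """Hypothetical plug standing in for an expensive, long-lived connection."""

  def __init__(self):
    self.opened = True

  def tearDown(self):
    self.opened = False


@htf.plug(serial=SharedSerialPlug)
def read_serial(test, serial):
  # The phase sees the externally created instance, not a fresh one.
  test.logger.info('serial opened: %s', serial.opened)


if __name__ == '__main__':
  # Build one PlugManager up front and register an instance that the manager
  # must neither construct in initialize_plugs() nor tear down in
  # tear_down_plugs() (both skip entries in _unmanaged_plugs per this diff).
  manager = PlugManager()
  manager.add_non_managed_plug(SharedSerialPlug())

  test = htf.Test(read_serial, plug_manager=manager)  # new keyword from this diff
  # The same plug instance survives across executions because tear_down_plugs()
  # skips un-managed plugs and re-registers them afterwards.
  test.execute(test_start=lambda: 'dut-001')
  test.execute(test_start=lambda: 'dut-002')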