diff --git a/src/rocker/cli.py b/src/rocker/cli.py index 7d78641..2840c8b 100644 --- a/src/rocker/cli.py +++ b/src/rocker/cli.py @@ -20,6 +20,7 @@ from .core import get_rocker_version from .core import RockerExtensionManager from .core import DependencyMissing +from .core import ExtensionError from .os_detector import detect_os @@ -54,9 +55,11 @@ def main(): args_dict['mode'] = OPERATIONS_DRY_RUN print('DEPRECATION Warning: --noexecute is deprecated for --mode dry-run please switch your usage by December 2020') - active_extensions = extension_manager.get_active_extensions(args_dict) - # Force user to end if present otherwise it will break other extensions - active_extensions.sort(key=lambda e:e.get_name().startswith('user')) + try: + active_extensions = extension_manager.get_active_extensions(args_dict) + except ExtensionError as e: + print(f"ERROR! {str(e)}") + return 1 print("Active extensions %s" % [e.get_name() for e in active_extensions]) base_image = args.image diff --git a/src/rocker/core.py b/src/rocker/core.py index 2f1a535..7c23d9f 100644 --- a/src/rocker/core.py +++ b/src/rocker/core.py @@ -32,6 +32,7 @@ import signal import struct import termios +import typing SYS_STDOUT = sys.stdout @@ -45,6 +46,10 @@ class DependencyMissing(RuntimeError): pass +class ExtensionError(RuntimeError): + pass + + class RockerExtension(object): """The base class for Rocker extension points""" @@ -58,6 +63,22 @@ def validate_environment(self, cliargs): necessary resources are available, like hardware.""" pass + def invoke_after(self, cliargs) -> typing.Set[str]: + """ + This extension should be loaded after the extensions in the returned + set. These extensions are not required to be present, but if they are, + they will be loaded before this extension. + """ + return set() + + def required(self, cliargs) -> typing.Set[str]: + """ + Ensures the specified extensions are present and combined with + this extension. If the required extension should be loaded before + this extension, it should also be added to the `invoke_after` set. + """ + return set() + def get_preamble(self, cliargs): return '' @@ -106,13 +127,70 @@ def extend_cli_parser(self, parser, default_args={}): parser.add_argument('--extension-blacklist', nargs='*', default=[], help='Prevent any of these extensions from being loaded.') + parser.add_argument('--strict-extension-selection', action='store_true', + help='When enabled, causes an error if required extensions are not explicitly ' + 'called out on the command line. Otherwise, the required extensions will ' + 'automatically be loaded if available.') def get_active_extensions(self, cli_args): - active_extensions = [e() for e in self.available_plugins.values() if e.check_args_for_activation(cli_args) and e.get_name() not in cli_args['extension_blacklist']] - active_extensions.sort(key=lambda e:e.get_name().startswith('user')) - return active_extensions + """ + Checks for missing dependencies (specified by each extension's + required() method) and additionally sorts them. + """ + def sort_extensions(extensions, cli_args): + + def topological_sort(source: typing.Dict[str, typing.Set[str]]) -> typing.List[str]: + """Perform a topological sort on names and dependencies and returns the sorted list of names.""" + names = set(source.keys()) + # prune optional dependencies if they are not present (at this point the required check has already occurred) + pending = [(name, dependencies.intersection(names)) for name, dependencies in source.items()] + emitted = [] + while pending: + next_pending = [] + next_emitted = [] + for entry in pending: + name, deps = entry + deps.difference_update(emitted) # remove dependencies already emitted + if deps: # still has dependencies? recheck during next pass + next_pending.append(entry) + else: # no more dependencies? time to emit + yield name + next_emitted.append(name) # remember what was emitted for difference_update() + if not next_emitted: + raise ExtensionError("Cyclic dependancy detected: %r" % (next_pending,)) + pending = next_pending + emitted = next_emitted + + extension_graph = {name: cls.invoke_after(cli_args) for name, cls in sorted(extensions.items())} + active_extension_list = [extensions[name] for name in topological_sort(extension_graph)] + return active_extension_list + + active_extensions = {} + find_reqs = set([name for name, cls in self.available_plugins.items() if cls.check_args_for_activation(cli_args)]) + while find_reqs: + name = find_reqs.pop() + + if name in self.available_plugins.keys(): + if name not in cli_args['extension_blacklist']: + ext = self.available_plugins[name]() + active_extensions[name] = ext + else: + raise ExtensionError(f"Extension '{name}' is blacklisted.") + else: + raise ExtensionError(f"Extension '{name}' not found. Is it installed?") + + # add additional reqs for processing not already known about + known_reqs = set(active_extensions.keys()).union(find_reqs) + missing_reqs = ext.required(cli_args).difference(known_reqs) + if missing_reqs: + if cli_args['strict_extension_selection']: + raise ExtensionError(f"Extension '{name}' is missing required extension(s) {list(missing_reqs)}") + else: + print(f"Adding implicilty required extension(s) {list(missing_reqs)} required by extension '{name}'") + find_reqs = find_reqs.union(missing_reqs) + return sort_extensions(active_extensions, cli_args) def get_docker_client(): """Simple helper function for pre 2.0 imports""" @@ -254,7 +332,6 @@ def get_operating_mode(self, args): print("No tty detected for stdin forcing non-interactive") return operating_mode - def generate_docker_cmd(self, command='', **kwargs): docker_args = '' diff --git a/test/test_core.py b/test/test_core.py index f57e659..f7b12dc 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -26,7 +26,9 @@ from rocker.core import list_plugins from rocker.core import get_docker_client from rocker.core import get_rocker_version +from rocker.core import RockerExtension from rocker.core import RockerExtensionManager +from rocker.core import ExtensionError class RockerCoreTest(unittest.TestCase): @@ -128,9 +130,82 @@ def test_extension_manager(self): self.assertIn('non-interactive', help_str) self.assertIn('--extension-blacklist', help_str) - active_extensions = active_extensions = extension_manager.get_active_extensions({'user': True, 'ssh': True, 'extension_blacklist': ['ssh']}) - self.assertEqual(len(active_extensions), 1) - self.assertEqual(active_extensions[0].get_name(), 'user') + self.assertRaises(ExtensionError, + extension_manager.get_active_extensions, + {'user': True, 'ssh': True, 'extension_blacklist': ['ssh']}) + + def test_strict_required_extensions(self): + class Foo(RockerExtension): + @classmethod + def get_name(cls): + return 'foo' + + class Bar(RockerExtension): + @classmethod + def get_name(cls): + return 'bar' + + def required(self, cli_args): + return {'foo'} + + extension_manager = RockerExtensionManager() + extension_manager.available_plugins = {'foo': Foo, 'bar': Bar} + + correct_extensions_args = {'strict_extension_selection': True, 'bar': True, 'foo': True, 'extension_blacklist': []} + extension_manager.get_active_extensions(correct_extensions_args) + + incorrect_extensions_args = {'strict_extension_selection': True, 'bar': True, 'extension_blacklist': []} + self.assertRaises(ExtensionError, + extension_manager.get_active_extensions, incorrect_extensions_args) + + def test_implicit_required_extensions(self): + class Foo(RockerExtension): + @classmethod + def get_name(cls): + return 'foo' + + class Bar(RockerExtension): + @classmethod + def get_name(cls): + return 'bar' + + def required(self, cli_args): + return {'foo'} + + extension_manager = RockerExtensionManager() + extension_manager.available_plugins = {'foo': Foo, 'bar': Bar} + + implicit_extensions_args = {'strict_extension_selection': False, 'bar': True, 'extension_blacklist': []} + active_extensions = extension_manager.get_active_extensions(implicit_extensions_args) + self.assertEqual(len(active_extensions), 2) + # required extensions are not ordered, just check to make sure they are both present + if active_extensions[0].get_name() == 'foo': + self.assertEqual(active_extensions[1].get_name(), 'bar') + else: + self.assertEqual(active_extensions[0].get_name(), 'bar') + self.assertEqual(active_extensions[1].get_name(), 'foo') + + def test_extension_sorting(self): + class Foo(RockerExtension): + @classmethod + def get_name(cls): + return 'foo' + + class Bar(RockerExtension): + @classmethod + def get_name(cls): + return 'bar' + + def invoke_after(self, cli_args): + return {'foo', 'absent_extension'} + + extension_manager = RockerExtensionManager() + extension_manager.available_plugins = {'foo': Foo, 'bar': Bar} + + args = {'bar': True, 'foo': True, 'extension_blacklist': []} + active_extensions = extension_manager.get_active_extensions(args) + self.assertEqual(active_extensions[0].get_name(), 'foo') + self.assertEqual(active_extensions[1].get_name(), 'bar') def test_docker_cmd_interactive(self): dig = DockerImageGenerator([], {}, 'ubuntu:bionic') @@ -148,7 +223,6 @@ def test_docker_cmd_interactive(self): self.assertNotIn('-it', dig.generate_docker_cmd(mode='non-interactive')) - def test_docker_cmd_nocleanup(self): dig = DockerImageGenerator([], {}, 'ubuntu:bionic')