Skip to content

Commit

Permalink
Add extension ordering (#242)
Browse files Browse the repository at this point in the history
* Added extension ordering functionality based on changes made in https://github.com/stonier/groot_rocker.git

Co-authored-by: Alex Youngs <[email protected]>
  • Loading branch information
agyoungs and agyoungs authored Oct 3, 2023
1 parent 2b8d5ab commit 3afa6ac
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 11 deletions.
9 changes: 6 additions & 3 deletions src/rocker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .core import get_rocker_version
from .core import RockerExtensionManager
from .core import DependencyMissing
from .core import ExtensionError

from .os_detector import detect_os

Expand Down Expand Up @@ -54,9 +55,11 @@ def main():
args_dict['mode'] = OPERATIONS_DRY_RUN
print('DEPRECATION Warning: --noexecute is deprecated for --mode dry-run please switch your usage by December 2020')

active_extensions = extension_manager.get_active_extensions(args_dict)
# Force user to end if present otherwise it will break other extensions
active_extensions.sort(key=lambda e:e.get_name().startswith('user'))
try:
active_extensions = extension_manager.get_active_extensions(args_dict)
except ExtensionError as e:
print(f"ERROR! {str(e)}")
return 1
print("Active extensions %s" % [e.get_name() for e in active_extensions])

base_image = args.image
Expand Down
85 changes: 81 additions & 4 deletions src/rocker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import signal
import struct
import termios
import typing

SYS_STDOUT = sys.stdout

Expand All @@ -45,6 +46,10 @@ class DependencyMissing(RuntimeError):
pass


class ExtensionError(RuntimeError):
pass


class RockerExtension(object):
"""The base class for Rocker extension points"""

Expand All @@ -58,6 +63,22 @@ def validate_environment(self, cliargs):
necessary resources are available, like hardware."""
pass

def invoke_after(self, cliargs) -> typing.Set[str]:
"""
This extension should be loaded after the extensions in the returned
set. These extensions are not required to be present, but if they are,
they will be loaded before this extension.
"""
return set()

def required(self, cliargs) -> typing.Set[str]:
"""
Ensures the specified extensions are present and combined with
this extension. If the required extension should be loaded before
this extension, it should also be added to the `invoke_after` set.
"""
return set()

def get_preamble(self, cliargs):
return ''

Expand Down Expand Up @@ -106,13 +127,70 @@ def extend_cli_parser(self, parser, default_args={}):
parser.add_argument('--extension-blacklist', nargs='*',
default=[],
help='Prevent any of these extensions from being loaded.')
parser.add_argument('--strict-extension-selection', action='store_true',
help='When enabled, causes an error if required extensions are not explicitly '
'called out on the command line. Otherwise, the required extensions will '
'automatically be loaded if available.')


def get_active_extensions(self, cli_args):
active_extensions = [e() for e in self.available_plugins.values() if e.check_args_for_activation(cli_args) and e.get_name() not in cli_args['extension_blacklist']]
active_extensions.sort(key=lambda e:e.get_name().startswith('user'))
return active_extensions
"""
Checks for missing dependencies (specified by each extension's
required() method) and additionally sorts them.
"""
def sort_extensions(extensions, cli_args):

def topological_sort(source: typing.Dict[str, typing.Set[str]]) -> typing.List[str]:
"""Perform a topological sort on names and dependencies and returns the sorted list of names."""
names = set(source.keys())
# prune optional dependencies if they are not present (at this point the required check has already occurred)
pending = [(name, dependencies.intersection(names)) for name, dependencies in source.items()]
emitted = []
while pending:
next_pending = []
next_emitted = []
for entry in pending:
name, deps = entry
deps.difference_update(emitted) # remove dependencies already emitted
if deps: # still has dependencies? recheck during next pass
next_pending.append(entry)
else: # no more dependencies? time to emit
yield name
next_emitted.append(name) # remember what was emitted for difference_update()
if not next_emitted:
raise ExtensionError("Cyclic dependancy detected: %r" % (next_pending,))
pending = next_pending
emitted = next_emitted

extension_graph = {name: cls.invoke_after(cli_args) for name, cls in sorted(extensions.items())}
active_extension_list = [extensions[name] for name in topological_sort(extension_graph)]
return active_extension_list

active_extensions = {}
find_reqs = set([name for name, cls in self.available_plugins.items() if cls.check_args_for_activation(cli_args)])
while find_reqs:
name = find_reqs.pop()

if name in self.available_plugins.keys():
if name not in cli_args['extension_blacklist']:
ext = self.available_plugins[name]()
active_extensions[name] = ext
else:
raise ExtensionError(f"Extension '{name}' is blacklisted.")
else:
raise ExtensionError(f"Extension '{name}' not found. Is it installed?")

# add additional reqs for processing not already known about
known_reqs = set(active_extensions.keys()).union(find_reqs)
missing_reqs = ext.required(cli_args).difference(known_reqs)
if missing_reqs:
if cli_args['strict_extension_selection']:
raise ExtensionError(f"Extension '{name}' is missing required extension(s) {list(missing_reqs)}")
else:
print(f"Adding implicilty required extension(s) {list(missing_reqs)} required by extension '{name}'")
find_reqs = find_reqs.union(missing_reqs)

return sort_extensions(active_extensions, cli_args)

def get_docker_client():
"""Simple helper function for pre 2.0 imports"""
Expand Down Expand Up @@ -254,7 +332,6 @@ def get_operating_mode(self, args):
print("No tty detected for stdin forcing non-interactive")
return operating_mode


def generate_docker_cmd(self, command='', **kwargs):
docker_args = ''

Expand Down
82 changes: 78 additions & 4 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from rocker.core import list_plugins
from rocker.core import get_docker_client
from rocker.core import get_rocker_version
from rocker.core import RockerExtension
from rocker.core import RockerExtensionManager
from rocker.core import ExtensionError

class RockerCoreTest(unittest.TestCase):

Expand Down Expand Up @@ -128,9 +130,82 @@ def test_extension_manager(self):
self.assertIn('non-interactive', help_str)
self.assertIn('--extension-blacklist', help_str)

active_extensions = active_extensions = extension_manager.get_active_extensions({'user': True, 'ssh': True, 'extension_blacklist': ['ssh']})
self.assertEqual(len(active_extensions), 1)
self.assertEqual(active_extensions[0].get_name(), 'user')
self.assertRaises(ExtensionError,
extension_manager.get_active_extensions,
{'user': True, 'ssh': True, 'extension_blacklist': ['ssh']})

def test_strict_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

def required(self, cli_args):
return {'foo'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

correct_extensions_args = {'strict_extension_selection': True, 'bar': True, 'foo': True, 'extension_blacklist': []}
extension_manager.get_active_extensions(correct_extensions_args)

incorrect_extensions_args = {'strict_extension_selection': True, 'bar': True, 'extension_blacklist': []}
self.assertRaises(ExtensionError,
extension_manager.get_active_extensions, incorrect_extensions_args)

def test_implicit_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

def required(self, cli_args):
return {'foo'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

implicit_extensions_args = {'strict_extension_selection': False, 'bar': True, 'extension_blacklist': []}
active_extensions = extension_manager.get_active_extensions(implicit_extensions_args)
self.assertEqual(len(active_extensions), 2)
# required extensions are not ordered, just check to make sure they are both present
if active_extensions[0].get_name() == 'foo':
self.assertEqual(active_extensions[1].get_name(), 'bar')
else:
self.assertEqual(active_extensions[0].get_name(), 'bar')
self.assertEqual(active_extensions[1].get_name(), 'foo')

def test_extension_sorting(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

def invoke_after(self, cli_args):
return {'foo', 'absent_extension'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

args = {'bar': True, 'foo': True, 'extension_blacklist': []}
active_extensions = extension_manager.get_active_extensions(args)
self.assertEqual(active_extensions[0].get_name(), 'foo')
self.assertEqual(active_extensions[1].get_name(), 'bar')

def test_docker_cmd_interactive(self):
dig = DockerImageGenerator([], {}, 'ubuntu:bionic')
Expand All @@ -148,7 +223,6 @@ def test_docker_cmd_interactive(self):

self.assertNotIn('-it', dig.generate_docker_cmd(mode='non-interactive'))


def test_docker_cmd_nocleanup(self):
dig = DockerImageGenerator([], {}, 'ubuntu:bionic')

Expand Down

0 comments on commit 3afa6ac

Please sign in to comment.