Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extension ordering #242

Merged
merged 14 commits into from
Oct 3, 2023
8 changes: 6 additions & 2 deletions src/rocker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .core import get_rocker_version
from .core import RockerExtensionManager
from .core import DependencyMissing
from .core import RequiredExtensionMissingError

from .os_detector import detect_os

Expand Down Expand Up @@ -55,8 +56,11 @@ def main():
print('DEPRECATION Warning: --noexecute is deprecated for --mode dry-run please switch your usage by December 2020')

active_extensions = extension_manager.get_active_extensions(args_dict)
# Force user to end if present otherwise it will break other extensions
active_extensions.sort(key=lambda e:e.get_name().startswith('user'))
try:
active_extensions = extension_manager.get_active_extensions(args_dict)
except RequiredExtensionMissingError as e:
print(f"ERROR! Aborting, {str(e)}")
return 1
print("Active extensions %s" % [e.get_name() for e in active_extensions])

base_image = args.image
Expand Down
89 changes: 85 additions & 4 deletions src/rocker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import signal
import struct
import termios
import typing

SYS_STDOUT = sys.stdout

Expand All @@ -45,6 +46,10 @@ class DependencyMissing(RuntimeError):
pass


class RequiredExtensionMissingError(RuntimeError):
pass


class RockerExtension(object):
"""The base class for Rocker extension points"""

Expand All @@ -58,6 +63,24 @@ def validate_environment(self, cliargs):
necessary resources are available, like hardware."""
pass

@staticmethod
def preceding_extensions() -> typing.Set[str]:
"""
Optional extensions. This merely ensures the preceding
extensions are applied before this extension when applying
snippets and arguments if they are present.
"""
return set()

@staticmethod
def required_extensions() -> typing.Set[str]:
"""
Ensures the specified extensions are present and combined with
this extension. In addition, it orders the application of
agyoungs marked this conversation as resolved.
Show resolved Hide resolved
the required extensions before this extension.
"""
return set()

def get_preamble(self, cliargs):
return ''

Expand Down Expand Up @@ -109,10 +132,69 @@ def extend_cli_parser(self, parser, default_args={}):


def get_active_extensions(self, cli_args):
active_extensions = [e() for e in self.available_plugins.values() if e.check_args_for_activation(cli_args) and e.get_name() not in cli_args['extension_blacklist']]
active_extensions.sort(key=lambda e:e.get_name().startswith('user'))
return active_extensions
"""
Checks for missing dependencies (specified by each extension's
required_extensions() method) and additionally sorts them.
"""
active_extensions = {
name: cls for name, cls in self.available_plugins.items()
if cls.check_args_for_activation(cli_args) and cls.get_name() not in cli_args['extension_blacklist']
}
names = set(active_extensions.keys())
for name, cls in active_extensions.items():
if not cls.required_extensions().issubset(names):
raise RequiredExtensionMissingError(f"extension '{name}' is missing required extensions {list(cls.required_extensions())}")
agyoungs marked this conversation as resolved.
Show resolved Hide resolved
return self.sort_extensions(active_extensions)

@staticmethod
def sort_extensions(extensions: typing.Dict[str, typing.Type[RockerExtension]]) -> typing.List[RockerExtension]:

def topological_sort(source: typing.Dict[str, typing.Set[str]]) -> typing.List[str]:
"""Perform a topological sort on names and dependencies and returns the sorted list of names."""
names = set(source.keys())
# dependencies are merely desired, not required, so prune them if they are not active
pending = [(name, dependencies.intersection(names)) for name, dependencies in source.items()]
emitted = []
while pending:
next_pending = []
next_emitted = []
for entry in pending:
name, deps = entry
deps.difference_update(emitted) # remove dependencies already emitted
if deps: # still has dependencies? recheck during next pass
next_pending.append(entry)
else: # no more dependencies? time to emit
yield name
next_emitted.append(name) # remember what was emitted for difference_update()
if not next_emitted:
raise ValueError("cyclic dependancy detected: %r" % (next_pending,))
agyoungs marked this conversation as resolved.
Show resolved Hide resolved
pending = next_pending
emitted = next_emitted

extension_graph = {}
# assume all extensions must precede user unless explicitly stated otherwise
agyoungs marked this conversation as resolved.
Show resolved Hide resolved
extensions_preceding_user = {k for k in extensions.keys() if k != 'user'}

for name, cls in sorted(extensions.items()):
if name == 'user':
# the 'user' extension is special and handled differently
continue

if 'user' in cls.preceding_extensions() or 'user' in cls.required_extensions():
# update the set so that the "user" extension can load before this extension
extensions_preceding_user.remove(name)

extension_graph[name] = cls.required_extensions().union(cls.preceding_extensions())

if 'user' in extensions.keys():
# update the "user" extension with the additional implied preceding extensions
extension_graph['user'] = extensions['user'].required_extensions().union(
extensions['user'].preceding_extensions()).union(extensions_preceding_user)

active_extension_list = []
for name in topological_sort(extension_graph):
active_extension_list.append(extensions[name]())
return active_extension_list

def get_docker_client():
"""Simple helper function for pre 2.0 imports"""
Expand Down Expand Up @@ -254,7 +336,6 @@ def get_operating_mode(self, args):
print("No tty detected for stdin forcing non-interactive")
return operating_mode


def generate_docker_cmd(self, command='', **kwargs):
docker_args = ''

Expand Down
67 changes: 66 additions & 1 deletion test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from rocker.core import list_plugins
from rocker.core import get_docker_client
from rocker.core import get_rocker_version
from rocker.core import RockerExtension
from rocker.core import RockerExtensionManager
from rocker.core import RequiredExtensionMissingError

class RockerCoreTest(unittest.TestCase):

Expand Down Expand Up @@ -132,6 +134,70 @@ def test_extension_manager(self):
self.assertEqual(len(active_extensions), 1)
self.assertEqual(active_extensions[0].get_name(), 'user')

def test_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

@staticmethod
def required_extensions():
return {'foo'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

correct_extensions = {'bar': True, 'foo': True, 'extension_blacklist': []}
extension_manager.get_active_extensions(correct_extensions)

incorrect_extensions = {'bar': True, 'extension_blacklist': []}
self.assertRaises(RequiredExtensionMissingError,
extension_manager.get_active_extensions, incorrect_extensions)

def test_extension_sorting(self):
class AUserExtension(RockerExtension):
@classmethod
def get_name(cls):
return 'a_user_extension'

@staticmethod
def preceding_extensions():
return {'user'}

class User(RockerExtension):
@classmethod
def get_name(cls):
return 'user'

class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

@staticmethod
def required_extensions():
return {'foo'}

sorted_extensions = RockerExtensionManager.sort_extensions(
extensions={'a_user_extension': AUserExtension,
'user': User,
'bar': Bar,
'foo': Foo})
self.assertEqual(sorted_extensions[0].get_name(), 'foo')
self.assertEqual(sorted_extensions[1].get_name(), 'bar')
self.assertEqual(sorted_extensions[2].get_name(), 'user')
self.assertEqual(sorted_extensions[3].get_name(), 'a_user_extension')

def test_docker_cmd_interactive(self):
dig = DockerImageGenerator([], {}, 'ubuntu:bionic')

Expand All @@ -148,7 +214,6 @@ def test_docker_cmd_interactive(self):

self.assertNotIn('-it', dig.generate_docker_cmd(mode='non-interactive'))


def test_docker_cmd_nocleanup(self):
dig = DockerImageGenerator([], {}, 'ubuntu:bionic')

Expand Down