From 8da96885d1f58641dd50861e3a06b8b2fbbb2f97 Mon Sep 17 00:00:00 2001 From: Adam Souzis Date: Tue, 12 Nov 2024 20:10:40 -0800 Subject: [PATCH] loader: clear modules when root changes, dont fallback to safe mode when creating a dsl configurator. --- tosca-package/tosca/loader.py | 33 ++++++++++++++++++--------------- unfurl/planrequests.py | 7 +++++-- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tosca-package/tosca/loader.py b/tosca-package/tosca/loader.py index e766b7a3..794e009f 100644 --- a/tosca-package/tosca/loader.py +++ b/tosca-package/tosca/loader.py @@ -162,7 +162,7 @@ def exec_module(self, module): python_filepath = str(path) with open(path) as f: src = f.read() - safe_mode = import_resolver.get_safe_mode() if import_resolver else True + safe_mode = import_resolver.get_safe_mode() if import_resolver else global_state.safe_mode module.__dict__["__file__"] = python_filepath for i in range(self.full_name.count(".")): path = path.parent @@ -333,20 +333,20 @@ def __safe_import__( level=0, ): parts = name.split(".") + assert modules is not sys.modules if level == 0: is_allowed_package = name in ALLOWED_PRIVATE_PACKAGES if name in modules and not (is_allowed_package and fromlist): # we need to skip the second check to allow the fromlist to includes sub-packages module = modules[name] - if parts[0] != "tosca_repositories": - _check_fromlist(module, fromlist) - elif fromlist: - for from_name in fromlist: - if not hasattr(module, from_name): - # e.g. from tosca_repositories.repo import module - load_private_module(base_dir, modules, name+"."+from_name) + if parts[0] == "tosca_repositories": + if fromlist: + for from_name in fromlist: + if not hasattr(module, from_name): + # e.g. from tosca_repositories.repo import module + load_private_module(base_dir, modules, name+"."+from_name) return module if fromlist else modules[parts[0]] - if name in ALLOWED_MODULES: + if name in ALLOWED_MODULES: # allowed but need to be made ImmutableModule if len(parts) > 1: first = importlib.import_module(parts[0]) first = ImmutableModule(parts[0], **vars(first)) @@ -587,6 +587,11 @@ def check_import_names(self, node): import_resolver: Optional[ImportResolver] = None service_template_basedir = "" +def _clear_special_modules(): + # these are relative to the manifest + for name in list(sys.modules): + if name.startswith("service_template") or name.startswith("tosca_repositories"): + del sys.modules[name] def install(import_resolver_: Optional[ImportResolver], base_dir=None) -> str: # insert the path hook ahead of other path hooks @@ -596,13 +601,16 @@ def install(import_resolver_: Optional[ImportResolver], base_dir=None) -> str: old_basedir = service_template_basedir if base_dir: service_template_basedir = base_dir + if base_dir != old_basedir: + _clear_special_modules() # these are bad else: service_template_basedir = os.getcwd() global installed if installed: return old_basedir - sys.meta_path.insert(0, RepositoryFinder()) + if not isinstance(sys.meta_path[0], RepositoryFinder): + sys.meta_path.insert(0, RepositoryFinder()) # XXX this breaks imports in local scope somehow: # sys.path_hooks.insert(0, FileFinder.path_hook(loader_details)) # this break some imports: @@ -658,7 +666,6 @@ def restricted_exec( ) package, sep, module_name = full_name.rpartition(".") if modules is None: - logger.warning(f"!!!! {full_name} {safe_mode}, {global_state.modules is sys.modules}") modules = global_state.modules if safe_mode else sys.modules if namespace is None: @@ -749,7 +756,6 @@ def restricted_exec( sys.modules[full_name] = temp_module previous_safe_mode = global_state.safe_mode previous_mode = global_state.mode - logger.warning("00000000000 %s %s", full_name, temp_module) try: global_state.safe_mode = safe_mode global_state.mode = "spec" @@ -759,9 +765,6 @@ def restricted_exec( namespace.update(temp_module.__dict__) else: exec(result.code, namespace) - except: - logger.warning("ASDFASDFASDFASDF %s %s", full_name, temp_module) - raise finally: global_state.safe_mode = previous_safe_mode global_state.mode = previous_mode diff --git a/unfurl/planrequests.py b/unfurl/planrequests.py index 24e256cb..b7dc9310 100644 --- a/unfurl/planrequests.py +++ b/unfurl/planrequests.py @@ -26,7 +26,7 @@ from toscaparser.elements.interfaces import OperationDef from toscaparser.nodetemplate import NodeTemplate from .spec import ArtifactSpec, EntitySpec -import tosca +import tosca.loader if TYPE_CHECKING: from .job import Job, ConfigTask, JobOptions @@ -132,7 +132,10 @@ def create(self) -> "Configurator": module_name, qualname, action = className.split(":") logger.warning("loading dsl configurator %s %s", module_name, tosca.safe_mode()) - assert not tosca.safe_mode(), module_name + if tosca.loader.import_resolver: + assert not tosca.loader.import_resolver.get_safe_mode(), module_name + else: + assert not tosca.safe_mode(), module_name module = importlib.import_module(module_name) cls_name, sep, func_name = qualname.rpartition(".") cls = getattr(module, cls_name)