Skip to content

Commit

Permalink
loader: clear modules when root changes, dont fallback to safe mode w…
Browse files Browse the repository at this point in the history
…hen creating a dsl configurator.
  • Loading branch information
aszs committed Nov 13, 2024
1 parent 31ce807 commit 8da9688
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
33 changes: 18 additions & 15 deletions tosca-package/tosca/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def exec_module(self, module):
python_filepath = str(path)
with open(path) as f:
src = f.read()
safe_mode = import_resolver.get_safe_mode() if import_resolver else True
safe_mode = import_resolver.get_safe_mode() if import_resolver else global_state.safe_mode
module.__dict__["__file__"] = python_filepath
for i in range(self.full_name.count(".")):
path = path.parent
Expand Down Expand Up @@ -333,20 +333,20 @@ def __safe_import__(
level=0,
):
parts = name.split(".")
assert modules is not sys.modules
if level == 0:
is_allowed_package = name in ALLOWED_PRIVATE_PACKAGES
if name in modules and not (is_allowed_package and fromlist):
# we need to skip the second check to allow the fromlist to includes sub-packages
module = modules[name]
if parts[0] != "tosca_repositories":
_check_fromlist(module, fromlist)
elif fromlist:
for from_name in fromlist:
if not hasattr(module, from_name):
# e.g. from tosca_repositories.repo import module
load_private_module(base_dir, modules, name+"."+from_name)
if parts[0] == "tosca_repositories":
if fromlist:
for from_name in fromlist:
if not hasattr(module, from_name):
# e.g. from tosca_repositories.repo import module
load_private_module(base_dir, modules, name+"."+from_name)
return module if fromlist else modules[parts[0]]
if name in ALLOWED_MODULES:
if name in ALLOWED_MODULES: # allowed but need to be made ImmutableModule
if len(parts) > 1:
first = importlib.import_module(parts[0])
first = ImmutableModule(parts[0], **vars(first))
Expand Down Expand Up @@ -587,6 +587,11 @@ def check_import_names(self, node):
import_resolver: Optional[ImportResolver] = None
service_template_basedir = ""

def _clear_special_modules():
# these are relative to the manifest
for name in list(sys.modules):
if name.startswith("service_template") or name.startswith("tosca_repositories"):
del sys.modules[name]

def install(import_resolver_: Optional[ImportResolver], base_dir=None) -> str:
# insert the path hook ahead of other path hooks
Expand All @@ -596,13 +601,16 @@ def install(import_resolver_: Optional[ImportResolver], base_dir=None) -> str:
old_basedir = service_template_basedir
if base_dir:
service_template_basedir = base_dir
if base_dir != old_basedir:
_clear_special_modules() # these are bad
else:
service_template_basedir = os.getcwd()
global installed
if installed:
return old_basedir

sys.meta_path.insert(0, RepositoryFinder())
if not isinstance(sys.meta_path[0], RepositoryFinder):
sys.meta_path.insert(0, RepositoryFinder())
# XXX this breaks imports in local scope somehow:
# sys.path_hooks.insert(0, FileFinder.path_hook(loader_details))
# this break some imports:
Expand Down Expand Up @@ -658,7 +666,6 @@ def restricted_exec(
)
package, sep, module_name = full_name.rpartition(".")
if modules is None:
logger.warning(f"!!!! {full_name} {safe_mode}, {global_state.modules is sys.modules}")
modules = global_state.modules if safe_mode else sys.modules

if namespace is None:
Expand Down Expand Up @@ -749,7 +756,6 @@ def restricted_exec(
sys.modules[full_name] = temp_module
previous_safe_mode = global_state.safe_mode
previous_mode = global_state.mode
logger.warning("00000000000 %s %s", full_name, temp_module)
try:
global_state.safe_mode = safe_mode
global_state.mode = "spec"
Expand All @@ -759,9 +765,6 @@ def restricted_exec(
namespace.update(temp_module.__dict__)
else:
exec(result.code, namespace)
except:
logger.warning("ASDFASDFASDFASDF %s %s", full_name, temp_module)
raise
finally:
global_state.safe_mode = previous_safe_mode
global_state.mode = previous_mode
Expand Down
7 changes: 5 additions & 2 deletions unfurl/planrequests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from toscaparser.elements.interfaces import OperationDef
from toscaparser.nodetemplate import NodeTemplate
from .spec import ArtifactSpec, EntitySpec
import tosca
import tosca.loader

if TYPE_CHECKING:
from .job import Job, ConfigTask, JobOptions
Expand Down Expand Up @@ -132,7 +132,10 @@ def create(self) -> "Configurator":

module_name, qualname, action = className.split(":")
logger.warning("loading dsl configurator %s %s", module_name, tosca.safe_mode())
assert not tosca.safe_mode(), module_name
if tosca.loader.import_resolver:
assert not tosca.loader.import_resolver.get_safe_mode(), module_name
else:
assert not tosca.safe_mode(), module_name
module = importlib.import_module(module_name)
cls_name, sep, func_name = qualname.rpartition(".")
cls = getattr(module, cls_name)
Expand Down

0 comments on commit 8da9688

Please sign in to comment.