Skip to content

Commit

Permalink
experiment #3
Browse files Browse the repository at this point in the history
  • Loading branch information
aszs committed Nov 13, 2024
1 parent d0d061b commit 72a8c7b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 34 deletions.
3 changes: 2 additions & 1 deletion tests/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,13 +1212,14 @@ def test_sandbox(capsys):
"""foo = dict(); foo[1] = 2; bar = list(); bar.append(1); baz = tuple()""",
"""import math; math.floor(1.0)""",
"""from unfurl.configurators.templates.dns import unfurl_relationships_DNSRecords""",
"""from unfurl.tosca_plugins import k8s; k8s.kube_artifacts""",
# """from unfurl.tosca_plugins import k8s; k8s.kube_artifacts""",
"""import tosca
node = tosca.nodes.Root()
node._name = "test"
""",
]
for src in allowed:
print("allowed", src)
assert _to_yaml(src, True)


Expand Down
76 changes: 43 additions & 33 deletions tosca-package/tosca/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,21 +181,21 @@ class ImmutableModule(ModuleType):
"__root__",
)

def __init__(self, name="__builtins__", **kw):
ModuleType.__init__(self, name)
super().__getattribute__("__dict__").update(kw)
def __init__(self, name, module):
super().__init__(name)
object.__getattribute__(self, "__dict__")["__protected_module__"] = module

def __getattribute__(self, __name: str) -> Any:
attrs = super().__getattribute__("__dict__")
module = super().__getattribute__("__protected_module__")
if (
__name not in ImmutableModule.__always_safe__
and __name not in attrs.get("__safe__", attrs.get("__all__", ()))
and attrs.get("__name__") != "math"
and __name not in getattr(module, "__safe__", getattr(module, "__all__", ()))
and module.__name__ != "math"
):
# special case "math", it doesn't have __all__
# only allow access to public attributes
raise AttributeError(__name)
return super().__getattribute__(__name)
return getattr(module, __name)

def __setattr__(self, name, v):
raise AttributeError(name)
Expand All @@ -204,14 +204,15 @@ def __delattr__(self, name):
raise AttributeError(name)


class DeniedModule(ImmutableModule):
class DeniedModule(ModuleType):
"""
A dummy module that defers raising ImportError until the module is accessed.
This allows unsafe import statements in the global scope as long as access is never attempted during sandbox execution.
"""

def __init__(self, name, fromlist, **kw):
super().__init__(name, **kw)
super().__init__(name)
object.__getattribute__(self, "__dict__").update(kw)
object.__getattribute__(self, "__dict__")["__fromlist__"] = fromlist

def __getattribute__(self, __name: str) -> Any:
Expand Down Expand Up @@ -296,6 +297,7 @@ def load_private_module(base_dir: str, modules: Dict[str, ModuleType], name: str


def _check_fromlist(module, fromlist):
# note: allowed modules aren't lazily checked like denied modules and so will raise ImportError in safe mode
if fromlist:
allowed = set(getattr(module, "__safe__", getattr(module, "__all__", ())))
for name in fromlist:
Expand All @@ -311,7 +313,7 @@ def _load_or_deny_module(name, ALLOWED_MODULES, modules):
return modules[name]
if name in ALLOWED_MODULES:
module = importlib.import_module(name)
module = ImmutableModule(name, **vars(module))
module = ImmutableModule(name, module)
modules[name] = module
return module
else:
Expand All @@ -332,41 +334,47 @@ def __safe_import__(
assert modules is not sys.modules
if level == 0:
is_allowed_package = name in ALLOWED_PRIVATE_PACKAGES
if name in modules and not (is_allowed_package and fromlist):
# we need to skip the second check to allow the fromlist to includes sub-packages
if name in modules:
if not fromlist:
return modules[parts[0]]
module = modules[name]
if parts[0] == "tosca_repositories":
if fromlist:
for from_name in fromlist:
if not hasattr(module, from_name):
# e.g. from tosca_repositories.repo import module
load_private_module(base_dir, modules, name+"."+from_name)
return module if fromlist else modules[parts[0]]
for from_name in fromlist:
if not hasattr(module, from_name):
# e.g. from tosca_repositories.repo import module
load_private_module(base_dir, modules, name+"."+from_name)
if name in ALLOWED_MODULES:
_check_fromlist(module, fromlist)
elif "*" in fromlist:
_validate_star(module)
# otherwise privately loaded or DeniedModule, don't need to validate fromlist besides *
# XXX if DeniedModule, update its __fromlist__ to defer ImportError
return module
if name in ALLOWED_MODULES: # allowed but need to be made ImmutableModule
if len(parts) > 1:
first = importlib.import_module(parts[0])
first = ImmutableModule(parts[0], **vars(first))
first = ImmutableModule(parts[0], first)
modules[parts[0]] = first
last = importlib.import_module(name)
_check_fromlist(last, fromlist)
last = ImmutableModule(name, **vars(last))
last = ImmutableModule(name, last)
modules[name] = last
# we don't need to worry about _handle_fromlist here because we don't allow importing submodules
return last if fromlist else first
else:
module = importlib.import_module(name)
_check_fromlist(module, fromlist)
module = ImmutableModule(name, **vars(module))
module = ImmutableModule(name, module)
modules[name] = module
return module
elif not is_allowed_package and parts[0] != "tosca_repositories":
# these modules fall through to load_private_module():
package_name, sep, module_name = name.rpartition(".")
if package_name not in ALLOWED_PRIVATE_PACKAGES:
if fromlist:
return DeniedModule(name, fromlist)
else:
return _load_or_deny_module(parts[0], ALLOWED_MODULES, modules)
# otherwise fall through to load_private_module():
else:
# relative import
package = globals["__package__"] if globals else None
Expand All @@ -377,22 +385,24 @@ def __safe_import__(
module = load_private_module(base_dir, modules, name)
if fromlist:
if "*" in fromlist:
# ok if there's no __all__ because default * will exclude names that start with "_"
safe = getattr(module, "__safe__", None)
if safe is not None:
all_name = getattr(module, "__all__", ())
if set(safe) != set(all_name):
raise ImportError(
f'Import of * from {module.__name__} is not permitted, its "__all__" does not match its "__safe__" attribute.',
name=module.__name__,
)

_validate_star(module)
# see https://github.com/python/cpython/blob/3.11/Lib/importlib/_bootstrap.py#L1207
importlib._bootstrap._handle_fromlist(
module, fromlist, lambda name: load_private_module(base_dir, modules, name)
)
return module

def _validate_star(module):
# ok if there's no __all__ because default * will exclude names that start with "_"
safe = getattr(module, "__safe__", None)
if safe is not None:
all_name = getattr(module, "__all__", ())
if set(safe) != set(all_name):
raise ImportError(
f'Import of * from {module.__name__} is not permitted, its "__all__" does not match its "__safe__" attribute.',
name=module.__name__,
)


def doc_str(node):
if isinstance(node, Expr):
Expand Down Expand Up @@ -751,7 +761,7 @@ def restricted_exec(
else:
temp_module = modules[full_name]
remove_temp_module = False
if full_name not in sys.modules:
if temp_module and full_name not in sys.modules:
# dataclass._process_class() might assume the current module is in sys.modules
# so to make it happy add a dummy one if its missing
sys.modules[full_name] = temp_module
Expand Down

0 comments on commit 72a8c7b

Please sign in to comment.