Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add recursive cache for private modules #2928

Draft
wants to merge 4 commits into
base: danielsola/se-251-make-auto-cache-plugin
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
:toctree: generated/

CacheFunctionBody
CachePrivateModules
"""

from .cache_function_body import CacheFunctionBody
from .cache_private_modules import CachePrivateModules
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import hashlib
import inspect
import textwrap
from typing import Any, Callable


Expand All @@ -27,6 +28,9 @@ def __init__(self, salt: str = "salt") -> None:
self.salt = salt

def get_version(self, func: Callable[..., Any]) -> str:
    """Return the version string for `func` by delegating to `_get_version`."""
    return self._get_version(func=func)

def _get_version(self, func: Callable[..., Any]) -> str:
"""
Generate a version hash for the provided function by parsing its source code
and adding a salt before applying the SHA-256 hash function.
Expand All @@ -37,11 +41,12 @@ def get_version(self, func: Callable[..., Any]) -> str:
Returns:
str: The SHA-256 hash of the function's source code combined with the salt.
"""
# Get the source code of the function
# Get the source code of the function and dedent
source = inspect.getsource(func)
dedented_source = textwrap.dedent(source)

# Parse the source code into an Abstract Syntax Tree (AST)
parsed_ast = ast.parse(source)
parsed_ast = ast.parse(dedented_source)

# Convert the AST into a string representation (dump it)
ast_bytes = ast.dump(parsed_ast).encode("utf-8")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import ast
import hashlib
import importlib.util
import inspect
import sys
import textwrap
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Set, Union


@contextmanager
def temporarily_add_to_syspath(path):
    """Temporarily put `path` at the front of ``sys.path``.

    Args:
        path: Directory to make importable; converted with ``str()`` before insertion.

    Yields:
        None. The entry is present in ``sys.path`` for the duration of the ``with`` body.
    """
    entry = str(path)
    sys.path.insert(0, entry)
    try:
        yield
    finally:
        # Remove by value rather than by position: code running inside the
        # context may itself have prepended to sys.path, in which case
        # ``pop(0)`` would drop the wrong entry and leak ours.
        try:
            sys.path.remove(entry)
        except ValueError:
            # Someone already removed it; nothing left to clean up.
            pass


class CachePrivateModules:
    """Version a function from its own source plus the source of the user-defined code it calls.

    A callable counts as "user-defined" when the file of its module lives under
    ``root_dir`` and not under a ``site-packages`` directory found on ``sys.path``.

    Args:
        salt: String mixed into every per-function hash.
        root_dir: Root directory of the user's package; resolved to an absolute path.
    """

    def __init__(self, salt: str, root_dir: str):
        self.salt = salt
        self.root_dir = Path(root_dir).resolve()

    def get_version(self, func: Callable[..., Any]) -> str:
        """Return a SHA-256 hex digest covering `func` and all of its user-defined dependencies."""
        dependencies = self._get_function_dependencies(func, set())
        # `dependencies` is a set of function objects whose iteration order follows
        # object ids and therefore varies between runs. Sort the per-dependency
        # hashes so the combined version is reproducible.
        dependency_hashes = sorted(self._get_version(dep) for dep in dependencies)
        hash_components = [self._get_version(func), *dependency_hashes]
        return hashlib.sha256("".join(hash_components).encode("utf-8")).hexdigest()

    def _get_version(self, func: Callable[..., Any]) -> str:
        """Return the SHA-256 hex digest of the AST dump of `func`'s source combined with the salt."""
        source = inspect.getsource(func)
        # Dedent so that methods and nested functions parse as top-level code.
        dedented_source = textwrap.dedent(source)
        parsed_ast = ast.parse(dedented_source)
        ast_bytes = ast.dump(parsed_ast).encode("utf-8")
        combined_data = ast_bytes + self.salt.encode("utf-8")
        return hashlib.sha256(combined_data).hexdigest()

    def _get_function_dependencies(self, func: Callable[..., Any], visited: Set[str]) -> Set[Callable[..., Any]]:
        """Recursively gather the user-defined functions and methods called from `func`.

        Args:
            func: Function whose source is scanned for call expressions.
            visited: ``"module.qualname"`` keys of callables that were already collected.
                It is mutated in place and shared across the recursion so that each
                callable is expanded once, which also terminates on call cycles.

        Returns:
            The set of user-defined callables newly discovered from `func`.
        """
        dependencies: Set[Callable[..., Any]] = set()
        # Dedent the source code to handle class method indentation
        parsed_ast = ast.parse(textwrap.dedent(inspect.getsource(func)))

        # Names visible to the function body: module globals, shadowed by
        # imports performed inside the function itself.
        func_globals = getattr(func, "__globals__", {})
        namespace = dict(func_globals)
        namespace.update(self._collect_local_imports(parsed_ast, func_globals.get("__package__")))

        for node in ast.walk(parsed_ast):
            if not isinstance(node, ast.Call):
                continue
            func_name = self._get_callable_name(node.func)
            if not func_name:
                continue
            try:
                candidates = self._expand_callable(self._resolve_callable(func_name, namespace))
            except (NameError, AttributeError):
                # Best effort: a name that cannot be resolved statically is skipped.
                continue
            for candidate in candidates:
                # Key on the resolved object, not the call-site spelling, so that
                # same-named helpers from different modules are all collected.
                key = f"{candidate.__module__}.{candidate.__qualname__}"
                if key in visited:
                    continue
                visited.add(key)
                dependencies.add(candidate)
                dependencies.update(self._get_function_dependencies(candidate, visited))
        return dependencies

    def _expand_callable(self, obj: Any) -> list:
        """Return the user-defined functions represented by `obj` (a class yields its methods)."""
        if inspect.isclass(obj):
            if not self._is_user_defined(obj):
                return []
            # Skip methods inherited from bases that live outside the user's package.
            return [
                method
                for _, method in inspect.getmembers(obj, predicate=inspect.isfunction)
                if self._is_user_defined(method)
            ]
        if (inspect.isfunction(obj) or inspect.ismethod(obj)) and self._is_user_defined(obj):
            return [obj]
        return []

    def _collect_local_imports(self, parsed_ast: ast.AST, package: Union[str, None]) -> dict:
        """Return the names bound by ``import`` statements found inside `parsed_ast`.

        Imports that cannot be satisfied (for example optional dependencies inside a
        ``try``/``except ImportError``) are skipped instead of aborting the analysis.
        """
        bound = {}
        for node in ast.walk(parsed_ast):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    try:
                        module = importlib.import_module(alias.name)
                    except ImportError:
                        continue
                    if alias.asname:
                        bound[alias.asname] = module
                    else:
                        # ``import a.b`` binds the top-level name ``a``.
                        top_level = alias.name.partition(".")[0]
                        bound[top_level] = sys.modules.get(top_level, module)
            elif isinstance(node, ast.ImportFrom):
                if node.level and not package:
                    # A relative import cannot be resolved without the enclosing package.
                    continue
                # ``node.module`` is None for ``from . import x``.
                target = "." * node.level + (node.module or "")
                try:
                    module = importlib.import_module(target, package=package if node.level else None)
                except (ImportError, ValueError):
                    continue
                for alias in node.names:
                    imported_obj = getattr(module, alias.name, None)
                    if imported_obj is not None:
                        bound[alias.asname or alias.name] = imported_obj
        return bound

    def _get_callable_name(self, node: ast.AST) -> Union[str, None]:
        """Retrieve the name of the callable from an AST node.

        Returns ``name`` for a bare name, ``base.attr`` for a single-level attribute
        access, only ``attr`` for deeper attribute chains, and ``None`` otherwise.
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            if isinstance(node.value, ast.Name):
                return f"{node.value.id}.{node.attr}"
            return node.attr
        return None

    def _resolve_callable(self, func_name: str, globals_dict: dict) -> Union[Callable[..., Any], None]:
        """Resolve a (possibly dotted) name to a callable using `globals_dict`.

        The name is first looked up directly; if that does not yield a callable,
        every module object stored in `globals_dict` is tried as a starting point.
        Returns ``None`` when nothing callable is found.
        """
        parts = func_name.split(".")

        # First, try resolving directly from globals_dict for a straightforward reference
        obj = globals_dict.get(parts[0], None)
        for part in parts[1:]:
            if obj is None:
                break
            obj = getattr(obj, part, None)
        if callable(obj):
            return obj

        # If not found, iterate through modules in globals_dict and attempt resolution from them.
        # Iterate over a snapshot: attribute access on a module may trigger imports.
        for module in list(globals_dict.values()):
            if not isinstance(module, type(sys)):  # Only module objects are entry points
                continue
            obj = module
            for part in parts:
                obj = getattr(obj, part, None)
                if obj is None:
                    break
            if callable(obj):
                return obj
        return None

    def _is_user_defined(self, obj: Any) -> bool:
        """Check if a callable or class is defined in a file under ``root_dir`` and outside site-packages."""
        module_name = getattr(obj, "__module__", None)
        if not module_name:
            return False

        # Retrieve the module specification to get its path
        with temporarily_add_to_syspath(self.root_dir):
            try:
                spec = importlib.util.find_spec(module_name)
            except (ImportError, ValueError, AttributeError, TypeError):
                # e.g. a missing parent package, or a module whose __spec__ is None
                return False
        # Built-in and frozen modules report origins such as "built-in"/"frozen", which
        # are not file paths and would otherwise resolve relative to the current directory.
        if spec is None or not spec.origin or not spec.has_location:
            return False

        module_path = Path(spec.origin).resolve()
        if self.root_dir not in module_path.parents:
            return False

        # Exclude third-party code installed beneath the root directory (e.g. a virtualenv)
        site_packages_paths = {Path(p).resolve() for p in sys.path if "site-packages" in str(p)}
        return not any(sp in module_path.parents for sp in site_packages_paths)
Empty file.
20 changes: 20 additions & 0 deletions plugins/flytekit-auto-cache/tests/my_package/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import sys
from pathlib import Path

# Add the parent directory of `my_package` to sys.path
sys.path.append(str(Path(__file__).resolve().parent))

from module_a import helper_function
from my_dir.module_in_dir import helper_in_directory
from module_c import DummyClass
import pandas as pd # External library

def my_main_function():
    """Fixture entry point that calls helpers from sibling modules, a class, and pandas."""
    print("Main function")
    helper_in_directory()
    helper_function()
    # pandas is an external library call
    df = pd.DataFrame({"a": [1, 2, 3]})
    print(df)
    # Instantiating the class and calling a method on the instance
    dc = DummyClass()
    print(dc)
    dc.dummy_method()
8 changes: 8 additions & 0 deletions plugins/flytekit-auto-cache/tests/my_package/module_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import module_b

def helper_function():
    """Print a marker and call `module_b.another_helper` through the module attribute."""
    print("Helper function")
    module_b.another_helper()

def unused_function():
    """Print a marker; nothing visible in this fixture package calls this function."""
    print("Unused function")
5 changes: 5 additions & 0 deletions plugins/flytekit-auto-cache/tests/my_package/module_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from module_c import third_helper

def another_helper():
    """Print a marker and call `third_helper`, imported from `module_c`."""
    print("Another helper")
    third_helper()
14 changes: 14 additions & 0 deletions plugins/flytekit-auto-cache/tests/my_package/module_c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import my_dir

def third_helper():
    """Print a marker; makes no further calls into the package."""
    print("Third helper")

class DummyClass:
    """Fixture class whose methods call helpers via a nested attribute path and a function-level import."""

    def dummy_method(self) -> str:
        """Call a helper through the `my_dir.module_in_dir` attribute chain and return a greeting."""
        my_dir.module_in_dir.other_helper_in_directory()
        return "Hello from dummy method!"

    def other_dummy_method(self):
        """Import `fourth_helper` inside the method body and call it."""
        # Function-level import: the name is not a module global
        from module_d import fourth_helper
        print("Other dummy method")
        fourth_helper()
2 changes: 2 additions & 0 deletions plugins/flytekit-auto-cache/tests/my_package/module_d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def fourth_helper():
    """Print a marker; makes no further calls into the package."""
    print("Fourth helper")
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .module_in_dir import other_helper_in_directory
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def helper_in_directory():
    """Print a marker; makes no further calls into the package."""
    print("Helper in directory")

def other_helper_in_directory():
    """Print a marker; makes no further calls into the package."""
    print("Other helper in directory")
30 changes: 30 additions & 0 deletions plugins/flytekit-auto-cache/tests/test_recursive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from flytekitplugins.auto_cache import CachePrivateModules
from my_package.main import my_main_function as func



def test_dependencies():
    """Check that the recursive walk from `my_main_function` finds exactly the expected package callables."""
    from pathlib import Path

    expected_dependencies = {
        "module_a.helper_function",
        "module_b.another_helper",
        "module_c.DummyClass.dummy_method",
        "module_c.DummyClass.other_dummy_method",
        "module_c.third_helper",
        "module_d.fourth_helper",
        "my_dir.module_in_dir.helper_in_directory",
        "my_dir.module_in_dir.other_helper_in_directory",
    }

    # Anchor the package root to this file so the result does not depend on
    # the directory pytest happens to be started from.
    root_dir = Path(__file__).resolve().parent / "my_package"
    cache = CachePrivateModules(salt="salt", root_dir=str(root_dir))
    actual_dependencies = cache._get_function_dependencies(func, set())

    # Normalise to "module.qualname", dropping the package prefix if present
    actual_dependencies_str = {
        f"{dep.__module__}.{dep.__qualname__}".replace("my_package.", "")
        for dep in actual_dependencies
    }

    assert actual_dependencies_str == expected_dependencies, (
        "Dependencies do not match:\n"
        f"Expected: {expected_dependencies}\n"
        f"Actual: {actual_dependencies_str}"
    )
Loading