Skip to content

Commit

Permalink
added support for directly adding modules
Browse files Browse the repository at this point in the history
  • Loading branch information
kessler-frost committed Dec 6, 2023
1 parent 8219f3b commit dc2bba2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
1 change: 1 addition & 0 deletions covalent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ._workflow import ( # nopycln: import
DepsBash,
DepsCall,
DepsModule,
DepsPip,
Lepton,
TransportableObject,
Expand Down
1 change: 1 addition & 0 deletions covalent/_workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .depsbash import DepsBash
from .depscall import DepsCall
from .depsmodule import DepsModule
from .depspip import DepsPip
from .electron import electron
from .lattice import lattice
Expand Down
20 changes: 14 additions & 6 deletions covalent/_workflow/depsmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,40 @@
# limitations under the License.

import importlib
from types import ModuleType
from typing import Union

import cloudpickle as pickle

from .depscall import DepsCall


def _client_side_pickle_module(module_name: str):
def _client_side_pickle_module(module: Union[str, ModuleType]):
"""
Pickle a module by value on the client side
and return the pickled bytes.
Args:
module_name: The name of the module to pickle.
module: The name of the module to pickle, can also be a module.
This module must be importable on the client side.
Returns:
The pickled bytes of the module.
"""

# Import the module on the client side
module = importlib.import_module(module_name)
if isinstance(module, str):
# Import the module on the client side
module = importlib.import_module(module)

# Register the module with cloudpickle by value
pickle.register_pickle_by_value(module)

# Pickle the module
pickled_module = pickle.dumps(module)

# Unregister the module with cloudpickle
# pickle.unregister_pickle_by_value(module)

return pickled_module


Expand Down Expand Up @@ -79,9 +85,11 @@ class DepsModule(DepsCall):
module_name: A string containing the name of the module to be imported.
"""

def __init__(self, module_name: str):
def __init__(self, module: Union[str, ModuleType]):
module_name = module if isinstance(module, str) else module.__name__

# Pickle the module by value on the client side
module_pickle = _client_side_pickle_module(module_name)
module_pickle = _client_side_pickle_module(module)

# Pass the pickled module to the server side
func = _server_side_import_module
Expand Down
22 changes: 17 additions & 5 deletions covalent/_workflow/electron.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from builtins import list
from dataclasses import asdict
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union

from covalent._dispatcher_plugins.local import LocalDispatcher
Expand Down Expand Up @@ -753,12 +754,23 @@ def electron(

if deps_module:
if isinstance(deps_module, list):
deps_module = [
DepsModule(module_name=module) if isinstance(module, str) else module
for module in deps_module
]
# Convert to DepsModule objects
converted_deps = []
for dep in deps_module:
if isinstance(dep, str):
converted_deps.append(DepsModule(dep))
elif isinstance(dep, ModuleType):
converted_deps.append(DepsModule(dep))
else:
converted_deps.append(dep)
deps_module = converted_deps

elif isinstance(deps_module, str):
deps_module = [DepsModule(module_name=deps_module)]
deps_module = [DepsModule(deps_module)]

elif isinstance(deps_module, ModuleType):
deps_module = [DepsModule(deps_module)]

elif isinstance(deps_module, DepsModule):
deps_module = [deps_module]

Expand Down

0 comments on commit dc2bba2

Please sign in to comment.