Skip to content

Commit

Permalink
Make entrypoint test programmatic and not pre-loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastian-quintero committed Nov 26, 2024
1 parent bbd7082 commit 0ec0377
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 111 deletions.
20 changes: 11 additions & 9 deletions nextmv/cloud/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,17 @@ def from_model_configuration(cls, model_configuration: ModelConfiguration) -> "M
The Python manifest.
"""

manifest_python = ManifestPython.from_dict(
{
"pip-requirements": _REQUIREMENTS_FILE,
"model": {
"name": model_configuration.name,
"options": model_configuration.options.parameters_dict(),
},
}
)
manifest_python_dict = {
"pip-requirements": _REQUIREMENTS_FILE,
"model": {
"name": model_configuration.name,
},
}

if model_configuration.options is not None:
manifest_python_dict["model"]["options"] = model_configuration.options.parameters_dict()

manifest_python = ManifestPython.from_dict(manifest_python_dict)

return cls(
files=["main.py", f"{model_configuration.name}/**"],
Expand Down
59 changes: 0 additions & 59 deletions tests/test_entrypoint/nextroute_model/MLmodel

This file was deleted.

11 changes: 0 additions & 11 deletions tests/test_entrypoint/nextroute_model/conda.yaml

This file was deleted.

7 changes: 0 additions & 7 deletions tests/test_entrypoint/nextroute_model/python_env.yaml

This file was deleted.

Binary file not shown.
4 changes: 0 additions & 4 deletions tests/test_entrypoint/nextroute_model/requirements.txt

This file was deleted.

47 changes: 26 additions & 21 deletions tests/test_entrypoint/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,23 @@
import shutil
import subprocess
import sys
import time
import unittest

import nextmv
import nextmv.cloud


class SimpleDecisionModel(nextmv.Model):
def solve(self, input: nextmv.Input, options: nextmv.Options) -> nextmv.Output:
return nextmv.Output(
solution={"foo": "bar"},
statistics={"baz": "qux"},
)


class TestEntrypoint(unittest.TestCase):
TWO_DIRS_UP = os.path.join("..", "..")
MODEL_NAME = "simple_decision_model"

def setUp(self):
"""Copies the entrypoint script as the main script in the root of an
Expand All @@ -19,32 +30,18 @@ def setUp(self):
dst = self._file_name("main.py", self.TWO_DIRS_UP)
shutil.copy(src, dst)

# Copy app files.
for file in ["input.json", "app.yaml"]:
src = self._file_name(file, ".")
dst = self._file_name(file, self.TWO_DIRS_UP)
shutil.copy(src, dst)

# Copy mlflow dir.
src = self._file_name("nextroute_model", ".")
dst = self._file_name("nextroute_model", self.TWO_DIRS_UP)
shutil.copytree(src, dst, dirs_exist_ok=True)

time.sleep(10)

def tearDown(self):
"""Removes the newly created main script elements."""

filenames = [
self._file_name("main.py", self.TWO_DIRS_UP),
self._file_name("input.json", self.TWO_DIRS_UP),
self._file_name("app.yaml", self.TWO_DIRS_UP),
]

for filename in filenames:
os.remove(filename)

shutil.rmtree(self._file_name("nextroute_model", self.TWO_DIRS_UP))
shutil.rmtree(self._file_name(self.MODEL_NAME, self.TWO_DIRS_UP))
shutil.rmtree(self._file_name("mlruns", self.TWO_DIRS_UP))

def test_entrypoint(self):
Expand All @@ -55,11 +52,19 @@ def test_entrypoint(self):
"nextroute_model" directory.
"""

input_file = self._file_name("input.json", self.TWO_DIRS_UP)
with open(input_file) as f:
input_data = json.load(f)
model = SimpleDecisionModel()
options = nextmv.Options(nextmv.Parameter("param1", str, ""))

model_configuration = nextmv.ModelConfiguration(
name=self.MODEL_NAME,
options=options,
)
destination = os.path.join(os.path.dirname(__file__), self.TWO_DIRS_UP)
model.save(destination, model_configuration)

manifest = nextmv.cloud.Manifest.from_model_configuration(model_configuration)
manifest.to_yaml(dirpath=destination)

input_stream = json.dumps(input_data)
main_file = self._file_name("main.py", self.TWO_DIRS_UP)

args = [sys.executable, main_file]
Expand All @@ -70,7 +75,7 @@ def test_entrypoint(self):
check=True,
text=True,
capture_output=True,
input=input_stream,
input=json.dumps({}),
)
except subprocess.CalledProcessError as e:
print("stderr:\n", e.stderr)
Expand Down

0 comments on commit 0ec0377

Please sign in to comment.