-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #114 from tcdent/templates-config
Project template validation.
- Loading branch information
Showing
7 changed files
with
177 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from typing import Optional | ||
import os, sys | ||
from pathlib import Path | ||
import pydantic | ||
import requests | ||
from agentstack import ValidationError | ||
from agentstack.utils import get_package_path, open_json_file, term_color | ||
|
||
|
||
class TemplateConfig(pydantic.BaseModel): | ||
""" | ||
Interface for interacting with template configuration files. | ||
Templates are read-only. | ||
Template Schema | ||
------------- | ||
name: str | ||
The name of the project. | ||
description: str | ||
A description of the template. | ||
template_version: str | ||
The version of the template. | ||
framework: str | ||
The framework the template is for. | ||
method: str | ||
The method used by the project. ie. "sequential" | ||
agents: list[dict] | ||
A list of agents used by the project. TODO vaidate this against an agent schema | ||
tasks: list[dict] | ||
A list of tasks used by the project. TODO validate this against a task schema | ||
tools: list[dict] | ||
A list of tools used by the project. TODO validate this against a tool schema | ||
inputs: list[str] | ||
A list of inputs used by the project. | ||
""" | ||
|
||
name: str | ||
description: str | ||
template_version: int | ||
framework: str | ||
method: str | ||
agents: list[dict] | ||
tasks: list[dict] | ||
tools: list[dict] | ||
inputs: list[str] | ||
|
||
@classmethod | ||
def from_template_name(cls, name: str) -> 'TemplateConfig': | ||
path = get_package_path() / f'templates/proj_templates/{name}.json' | ||
if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli | ||
print(term_color(f'No known agentstack tool: {name}', 'red')) | ||
sys.exit(1) | ||
return cls.from_json(path) | ||
|
||
@classmethod | ||
def from_json(cls, path: Path) -> 'TemplateConfig': | ||
data = open_json_file(path) | ||
try: | ||
return cls(**data) | ||
except pydantic.ValidationError as e: | ||
# TODO raise exceptions and handle message/exit in cli | ||
print(term_color(f"Error validating template config JSON: \n{path}", 'red')) | ||
for error in e.errors(): | ||
print(f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}") | ||
sys.exit(1) | ||
|
||
@classmethod | ||
def from_url(cls, url: str) -> 'TemplateConfig': | ||
if not url.startswith("https://"): | ||
raise ValidationError(f"Invalid URL: {url}") | ||
response = requests.get(url) | ||
if response.status_code != 200: | ||
raise ValidationError(f"Failed to fetch template from {url}") | ||
return cls(**response.json()) | ||
|
||
|
||
def get_all_template_paths() -> list[Path]: | ||
paths = [] | ||
templates_dir = get_package_path() / 'templates/proj_templates' | ||
for file in templates_dir.iterdir(): | ||
if file.suffix == '.json': | ||
paths.append(file) | ||
return paths | ||
|
||
|
||
def get_all_template_names() -> list[str]: | ||
return [path.stem for path in get_all_template_paths()] | ||
|
||
|
||
def get_all_templates() -> list[TemplateConfig]: | ||
return [TemplateConfig.from_json(path) for path in get_all_template_paths()] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import subprocess | ||
import os, sys | ||
import unittest | ||
from parameterized import parameterized | ||
from pathlib import Path | ||
import shutil | ||
from agentstack.proj_templates import get_all_template_names | ||
|
||
BASE_PATH = Path(__file__).parent | ||
CLI_ENTRY = [ | ||
sys.executable, | ||
"-m", | ||
"agentstack.main", | ||
] | ||
|
||
|
||
class CLIInitTest(unittest.TestCase): | ||
def setUp(self): | ||
self.project_dir = Path(BASE_PATH / 'tmp/cli_init') | ||
os.makedirs(self.project_dir) | ||
|
||
def tearDown(self): | ||
shutil.rmtree(self.project_dir) | ||
|
||
def _run_cli(self, *args): | ||
"""Helper method to run the CLI with arguments.""" | ||
return subprocess.run([*CLI_ENTRY, *args], capture_output=True, text=True) | ||
|
||
def test_init_command(self): | ||
"""Test the 'init' command to create a project directory.""" | ||
os.chdir(self.project_dir) | ||
result = self._run_cli('init', str(self.project_dir)) | ||
self.assertEqual(result.returncode, 0) | ||
self.assertTrue(self.project_dir.exists()) | ||
|
||
@parameterized.expand([(x,) for x in get_all_template_names()]) | ||
def test_init_command_for_template(self, template_name): | ||
"""Test the 'init' command to create a project directory with a template.""" | ||
os.chdir(self.project_dir) | ||
result = self._run_cli('init', str(self.project_dir), '--template', template_name) | ||
self.assertEqual(result.returncode, 0) | ||
self.assertTrue(self.project_dir.exists()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import json | ||
import unittest | ||
from pathlib import Path | ||
from agentstack.proj_templates import TemplateConfig, get_all_template_names, get_all_template_paths | ||
|
||
BASE_PATH = Path(__file__).parent | ||
|
||
|
||
class TemplateConfigTest(unittest.TestCase): | ||
def test_all_configs_from_template_name(self): | ||
for template_name in get_all_template_names(): | ||
config = TemplateConfig.from_template_name(template_name) | ||
assert config.name == template_name | ||
# We can assume that pydantic validation caught any other issues | ||
|
||
def test_all_configs_from_template_path(self): | ||
for path in get_all_template_paths(): | ||
config = TemplateConfig.from_json(path) | ||
assert config.name == path.stem | ||
# We can assume that pydantic validation caught any other issues |