Skip to content

Commit

Permalink
Merge pull request #114 from tcdent/templates-config
Browse files Browse the repository at this point in the history
Project template validation.
  • Loading branch information
tcdent authored Dec 10, 2024
2 parents 3c17e79 + 9e781bc commit 981cded
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 51 deletions.
2 changes: 1 addition & 1 deletion agentstack/cli/agentstack_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
license: str = "",
year: int = datetime.now().year,
template: str = "none",
template_version: str = "0",
template_version: int = 0,
):
self.project_name = clean_input(project_name) if project_name else "myagent"
self.project_slug = clean_input(project_slug) if project_slug else self.project_name
Expand Down
52 changes: 21 additions & 31 deletions agentstack/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from agentstack import generation
from agentstack.utils import open_json_file, term_color, is_snake_case
from agentstack.update import AGENTSTACK_PACKAGE
from agentstack.proj_templates import TemplateConfig


PREFERRED_MODELS = [
Expand All @@ -57,45 +58,34 @@ def init_project_builder(

template_data = None
if template is not None:
url_start = "https://"
if template[: len(url_start)] == url_start:
# template is a url
response = requests.get(template)
if response.status_code == 200:
template_data = response.json()
else:
print(
term_color(
f"Failed to fetch template data from {template}. Status code: {response.status_code}",
'red',
)
)
if template.startswith("https://"):
try:
template_data = TemplateConfig.from_url(template)
except Exception as e:
print(term_color(f"Failed to fetch template data from {template}", 'red'))
sys.exit(1)
else:
with importlib.resources.path(
'agentstack.templates.proj_templates', f'{template}.json'
) as template_path:
if template_path is None:
print(term_color(f"No such template {template} found", 'red'))
sys.exit(1)
template_data = open_json_file(template_path)
try:
template_data = TemplateConfig.from_template_name(template)
except Exception as e:
print(term_color(f"Failed to load template {template}", 'red'))
sys.exit(1)

if template_data:
project_details = {
"name": slug_name or template_data['name'],
"name": slug_name or template_data.name,
"version": "0.0.1",
"description": template_data['description'],
"description": template_data.description,
"author": "Name <Email>",
"license": "MIT",
}
framework = template_data['framework']
framework = template_data.framework
design = {
'agents': template_data['agents'],
'tasks': template_data['tasks'],
'inputs': template_data['inputs'],
'agents': template_data.agents,
'tasks': template_data.tasks,
'inputs': template_data.inputs,
}

tools = template_data['tools']
tools = template_data.tools

elif use_wizard:
welcome_message()
Expand Down Expand Up @@ -381,7 +371,7 @@ def insert_template(
project_details: dict,
framework_name: str,
design: dict,
template_data: Optional[dict] = None,
template_data: Optional[TemplateConfig] = None,
):
framework = FrameworkData(
name=framework_name.lower(),
Expand All @@ -393,8 +383,8 @@ def insert_template(
version="0.0.1",
license="MIT",
year=datetime.now().year,
template=template_data['name'] if template_data else 'none',
template_version=template_data['template_version'] if template_data else '0',
template=template_data.name if template_data else 'none',
template_version=template_data.template_version if template_data else 0,
)

project_structure = ProjectStructure()
Expand Down
92 changes: 92 additions & 0 deletions agentstack/proj_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Optional
import os, sys
from pathlib import Path
import pydantic
import requests
from agentstack import ValidationError
from agentstack.utils import get_package_path, open_json_file, term_color


class TemplateConfig(pydantic.BaseModel):
"""
Interface for interacting with template configuration files.
Templates are read-only.
Template Schema
-------------
name: str
The name of the project.
description: str
A description of the template.
template_version: str
The version of the template.
framework: str
The framework the template is for.
method: str
The method used by the project. ie. "sequential"
agents: list[dict]
A list of agents used by the project. TODO vaidate this against an agent schema
tasks: list[dict]
A list of tasks used by the project. TODO validate this against a task schema
tools: list[dict]
A list of tools used by the project. TODO validate this against a tool schema
inputs: list[str]
A list of inputs used by the project.
"""

name: str
description: str
template_version: int
framework: str
method: str
agents: list[dict]
tasks: list[dict]
tools: list[dict]
inputs: list[str]

@classmethod
def from_template_name(cls, name: str) -> 'TemplateConfig':
path = get_package_path() / f'templates/proj_templates/{name}.json'
if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli
print(term_color(f'No known agentstack tool: {name}', 'red'))
sys.exit(1)
return cls.from_json(path)

@classmethod
def from_json(cls, path: Path) -> 'TemplateConfig':
data = open_json_file(path)
try:
return cls(**data)
except pydantic.ValidationError as e:
# TODO raise exceptions and handle message/exit in cli
print(term_color(f"Error validating template config JSON: \n{path}", 'red'))
for error in e.errors():
print(f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}")
sys.exit(1)

@classmethod
def from_url(cls, url: str) -> 'TemplateConfig':
if not url.startswith("https://"):
raise ValidationError(f"Invalid URL: {url}")
response = requests.get(url)
if response.status_code != 200:
raise ValidationError(f"Failed to fetch template from {url}")
return cls(**response.json())


def get_all_template_paths() -> list[Path]:
paths = []
templates_dir = get_package_path() / 'templates/proj_templates'
for file in templates_dir.iterdir():
if file.suffix == '.json':
paths.append(file)
return paths


def get_all_template_names() -> list[str]:
return [path.stem for path in get_all_template_paths()]


def get_all_templates() -> list[TemplateConfig]:
return [TemplateConfig.from_json(path) for path in get_all_template_paths()]
2 changes: 1 addition & 1 deletion agentstack/templates/proj_templates/content_creator.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"name": "content_creation",
"name": "content_creator",
"description": "Multi-agent system for creating high-quality content",
"template_version": 1,
"framework": "crewai",
Expand Down
42 changes: 42 additions & 0 deletions tests/test_cli_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import subprocess
import os, sys
import unittest
from parameterized import parameterized
from pathlib import Path
import shutil
from agentstack.proj_templates import get_all_template_names

BASE_PATH = Path(__file__).parent
CLI_ENTRY = [
sys.executable,
"-m",
"agentstack.main",
]


class CLIInitTest(unittest.TestCase):
def setUp(self):
self.project_dir = Path(BASE_PATH / 'tmp/cli_init')
os.makedirs(self.project_dir)

def tearDown(self):
shutil.rmtree(self.project_dir)

def _run_cli(self, *args):
"""Helper method to run the CLI with arguments."""
return subprocess.run([*CLI_ENTRY, *args], capture_output=True, text=True)

def test_init_command(self):
"""Test the 'init' command to create a project directory."""
os.chdir(self.project_dir)
result = self._run_cli('init', str(self.project_dir))
self.assertEqual(result.returncode, 0)
self.assertTrue(self.project_dir.exists())

@parameterized.expand([(x,) for x in get_all_template_names()])
def test_init_command_for_template(self, template_name):
"""Test the 'init' command to create a project directory with a template."""
os.chdir(self.project_dir)
result = self._run_cli('init', str(self.project_dir), '--template', template_name)
self.assertEqual(result.returncode, 0)
self.assertTrue(self.project_dir.exists())
18 changes: 0 additions & 18 deletions tests/test_cli_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class TestAgentStackCLI(unittest.TestCase):
# Replace with your actual CLI entry point if different
CLI_ENTRY = [
sys.executable,
"-m",
Expand All @@ -32,23 +31,6 @@ def test_invalid_command(self):
self.assertNotEqual(result.returncode, 0)
self.assertIn("usage:", result.stderr)

def test_init_command(self):
"""Test the 'init' command to create a project directory."""
test_dir = Path(BASE_PATH / 'tmp/test_project')

# Ensure the directory doesn't exist from previous runs
if test_dir.exists():
shutil.rmtree(test_dir)
os.makedirs(test_dir)

os.chdir(test_dir)
result = self.run_cli("init", str(test_dir))
self.assertEqual(result.returncode, 0)
self.assertTrue(test_dir.exists())

# Clean up
shutil.rmtree(test_dir)

def test_run_command_invalid_project(self):
"""Test the 'run' command on an invalid project."""
test_dir = Path(BASE_PATH / 'tmp/test_project')
Expand Down
20 changes: 20 additions & 0 deletions tests/test_templates_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json
import unittest
from pathlib import Path
from agentstack.proj_templates import TemplateConfig, get_all_template_names, get_all_template_paths

BASE_PATH = Path(__file__).parent


class TemplateConfigTest(unittest.TestCase):
def test_all_configs_from_template_name(self):
for template_name in get_all_template_names():
config = TemplateConfig.from_template_name(template_name)
assert config.name == template_name
# We can assume that pydantic validation caught any other issues

def test_all_configs_from_template_path(self):
for path in get_all_template_paths():
config = TemplateConfig.from_json(path)
assert config.name == path.stem
# We can assume that pydantic validation caught any other issues

0 comments on commit 981cded

Please sign in to comment.