Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rollback functionality to setup wizard #75

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 94 additions & 77 deletions agentstack/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,86 +20,100 @@
from .. import generation
from ..utils import open_json_file, term_color, is_snake_case

created_files = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These variables should be created inside the function scope, not at the module level.

created_dirs = []

def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False):
if slug_name and not is_snake_case(slug_name):
print(term_color("Project name must be snake case", 'red'))
return
def rollback_actions():
for file in created_files:
if os.path.exists(file):
os.remove(file)
for dir in created_dirs:
if os.path.exists(dir):
shutil.rmtree(dir)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be very careful about deleting files when any exception occurs. Possible that running this command on an already initialized project would permanently erase user data.


if template is not None and use_wizard:
print(term_color("Template and wizard flags cannot be used together", 'red'))
return

template_data = None
if template is not None:
url_start = "https://"
if template[:len(url_start)] == url_start:
# template is a url
response = requests.get(template)
if response.status_code == 200:
template_data = response.json()
else:
print(term_color(f"Failed to fetch template data from {template}. Status code: {response.status_code}", 'red'))
sys.exit(1)
else:
with importlib.resources.path('agentstack.templates.proj_templates', f'{template}.json') as template_path:
if template_path is None:
print(term_color(f"No such template {template} found", 'red'))
def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False):
try:
if slug_name and not is_snake_case(slug_name):
print(term_color("Project name must be snake case", 'red'))
return

if template is not None and use_wizard:
print(term_color("Template and wizard flags cannot be used together", 'red'))
return

template_data = None
if template is not None:
url_start = "https://"
if template[:len(url_start)] == url_start:
# template is a url
response = requests.get(template)
if response.status_code == 200:
template_data = response.json()
else:
print(term_color(f"Failed to fetch template data from {template}. Status code: {response.status_code}", 'red'))
sys.exit(1)
template_data = open_json_file(template_path)

if template_data:
project_details = {
"name": slug_name or template_data['name'],
"version": "0.0.1",
"description": template_data['description'],
"author": "Name <Email>",
"license": "MIT"
}
framework = template_data['framework']
design = {
'agents': template_data['agents'],
'tasks': template_data['tasks']
}

tools = template_data['tools']

elif use_wizard:
welcome_message()
project_details = ask_project_details(slug_name)
welcome_message()
framework = ask_framework()
design = ask_design()
tools = ask_tools()

else:
welcome_message()
project_details = {
"name": slug_name or "agentstack_project",
"version": "0.0.1",
"description": "New agentstack project",
"author": "Name <Email>",
"license": "MIT"
}

framework = "CrewAI" # TODO: if --no-wizard, require a framework flag

design = {
'agents': [],
'tasks': []
}

tools = []

log.debug(
f"project_details: {project_details}"
f"framework: {framework}"
f"design: {design}"
)
insert_template(project_details, framework, design, template_data)
for tool_data in tools:
generation.add_tool(tool_data['name'], agents=tool_data['agents'], path=project_details['name'])
else:
with importlib.resources.path('agentstack.templates.proj_templates', f'{template}.json') as template_path:
if template_path is None:
print(term_color(f"No such template {template} found", 'red'))
sys.exit(1)
template_data = open_json_file(template_path)

if template_data:
project_details = {
"name": slug_name or template_data['name'],
"version": "0.0.1",
"description": template_data['description'],
"author": "Name <Email>",
"license": "MIT"
}
framework = template_data['framework']
design = {
'agents': template_data['agents'],
'tasks': template_data['tasks']
}

tools = template_data['tools']

elif use_wizard:
welcome_message()
project_details = ask_project_details(slug_name)
welcome_message()
framework = ask_framework()
design = ask_design()
tools = ask_tools()

else:
welcome_message()
project_details = {
"name": slug_name or "agentstack_project",
"version": "0.0.1",
"description": "New agentstack project",
"author": "Name <Email>",
"license": "MIT"
}

framework = "CrewAI" # TODO: if --no-wizard, require a framework flag

design = {
'agents': [],
'tasks': []
}

tools = []

log.debug(
f"project_details: {project_details}"
f"framework: {framework}"
f"design: {design}"
)
insert_template(project_details, framework, design, template_data)
for tool_data in tools:
generation.add_tool(tool_data['name'], agents=tool_data['agents'], path=project_details['name'])
except Exception as e:
print(term_color(f"An error occurred: {e}", 'red'))
rollback_actions()
sys.exit(1)

def welcome_message():
os.system("cls" if os.name == "nt" else "clear")
Expand Down Expand Up @@ -323,17 +337,20 @@ def insert_template(project_details: dict, framework_name: str, design: dict, te
template_path = get_package_path() / f'templates/{framework.name}'
with open(f"{template_path}/cookiecutter.json", "w") as json_file:
json.dump(cookiecutter_data.to_dict(), json_file)
created_files.append(f"{template_path}/cookiecutter.json")

# copy .env.example to .env
shutil.copy(
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env.example',
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env')
created_files.append(f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env')

if os.path.isdir(project_details['name']):
print(term_color(f"Directory {template_path} already exists. Please check this and try again", "red"))
return

cookiecutter(str(template_path), no_input=True, extra_context=None)
created_dirs.append(project_details['name'])

# TODO: inits a git repo in the directory the command was run in
# TODO: not where the project is generated. Fix this
Expand Down Expand Up @@ -378,4 +395,4 @@ def list_tools():
print(f": {tool.url if tool.url else 'AgentStack default tool'}")

print("\n\n✨ Add a tool with: agentstack tools add <tool_name>")
print(" https://docs.agentstack.sh/tools/core")
print(" https://docs.agentstack.sh/tools/core")
36 changes: 24 additions & 12 deletions agentstack/generation/agent_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from ruamel.yaml import YAML
from ruamel.yaml.scalarstring import FoldedScalarString

created_files = []
created_dirs = []

def rollback_actions():
for file in created_files:
if os.path.exists(file):
os.remove(file)
for dir in created_dirs:
if os.path.exists(dir):
shutil.rmtree(dir)

def generate_agent(
name,
Expand All @@ -27,18 +37,19 @@ def generate_agent(

framework = get_framework()

if framework == 'crewai':
generate_crew_agent(name, role, goal, backstory, llm)
print(" > Added to src/config/agents.yaml")
else:
print(f"This function is not yet implemented for {framework}")
return

print(f"Added agent \"{name}\" to your AgentStack project successfully!")



try:
if framework == 'crewai':
generate_crew_agent(name, role, goal, backstory, llm)
print(" > Added to src/config/agents.yaml")
else:
print(f"This function is not yet implemented for {framework}")
return

print(f"Added agent \"{name}\" to your AgentStack project successfully!")
except Exception as e:
print(f"An error occurred: {e}")
rollback_actions()
sys.exit(1)

def generate_crew_agent(
name,
Expand Down Expand Up @@ -83,6 +94,7 @@ def generate_crew_agent(
# Write back to the file without altering existing content
with open(config_path, 'w') as file:
yaml.dump(data, file)
created_files.append(config_path)

# Now lets add the agent to crew.py
file_path = 'src/crew.py'
Expand All @@ -103,4 +115,4 @@ def generate_crew_agent(

def get_agent_names(framework: str = 'crewai', path: str = '') -> List[str]:
"""Get only agent names from the crew file"""
return get_crew_components(framework, CrewComponent.AGENT, path)['agents']
return get_crew_components(framework, CrewComponent.AGENT, path)['agents']
33 changes: 24 additions & 9 deletions agentstack/generation/task_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from ruamel.yaml import YAML
from ruamel.yaml.scalarstring import FoldedScalarString

created_files = []
created_dirs = []

def rollback_actions():
for file in created_files:
if os.path.exists(file):
os.remove(file)
for dir in created_dirs:
if os.path.exists(dir):
shutil.rmtree(dir)

def generate_task(
name,
Expand All @@ -24,15 +34,19 @@ def generate_task(

framework = get_framework()

if framework == 'crewai':
generate_crew_task(name, description, expected_output, agent)
print(" > Added to src/config/tasks.yaml")
else:
print(f"This function is not yet implemented for {framework}")
return

print(f"Added task \"{name}\" to your AgentStack project successfully!")
try:
if framework == 'crewai':
generate_crew_task(name, description, expected_output, agent)
print(" > Added to src/config/tasks.yaml")
else:
print(f"This function is not yet implemented for {framework}")
return

print(f"Added task \"{name}\" to your AgentStack project successfully!")
except Exception as e:
print(f"An error occurred: {e}")
rollback_actions()
sys.exit(1)

def generate_crew_task(
name,
Expand Down Expand Up @@ -74,6 +88,7 @@ def generate_crew_task(
# Write back to the file without altering existing content
with open(config_path, 'w') as file:
yaml.dump(data, file)
created_files.append(config_path)

# Add task to crew.py
file_path = 'src/crew.py'
Expand All @@ -91,4 +106,4 @@ def generate_crew_task(

def get_task_names(framework: str, path: str = '') -> List[str]:
"""Get only task names from the crew file"""
return get_crew_components(framework, CrewComponent.TASK, path)['tasks']
return get_crew_components(framework, CrewComponent.TASK, path)['tasks']
20 changes: 20 additions & 0 deletions tests/test_cli_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
from pathlib import Path
import shutil
import os


class TestAgentStackCLI(unittest.TestCase):
Expand Down Expand Up @@ -44,6 +45,25 @@ def test_init_command(self):
# Clean up
shutil.rmtree(test_dir)

def test_rollback_on_error(self):
"""Test rollback functionality when an error occurs during project initialization."""
test_dir = Path("test_project_with_error")

# Ensure the directory doesn't exist from previous runs
if test_dir.exists():
shutil.rmtree(test_dir)

# Simulate an error by creating a directory that will cause a failure
os.makedirs(test_dir / "src")

result = self.run_cli("init", str(test_dir))
self.assertNotEqual(result.returncode, 0)
self.assertFalse(test_dir.exists()) # Directory should be removed on rollback

# Clean up
if test_dir.exists():
shutil.rmtree(test_dir)


if __name__ == "__main__":
unittest.main()