Skip to content

Commit

Permalink
Merge pull request #10 from wi0lono/question_patches
Browse files Browse the repository at this point in the history
Question patches
  • Loading branch information
sameersegal authored Jun 21, 2024
2 parents 1a2de51 + c150c9b commit afc3352
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion nl2dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from termcolor import cprint
from typing import List, Dict, Callable

from .utils.mini_llm import call_llm_for_json
from .utils.mini_llm import call_llm_for_json, chat_completion_request
from .utils.dsl_utils import (
update_flow,
update_global_variables,
Expand Down
31 changes: 19 additions & 12 deletions nl2dsl/utils/dsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,36 @@ def update_flow(step, plugins={}, flow=[], debug=False):
dsl_list.insert(-1, llm_response)
elif step_type == "edit":
edited_task = llm_response
edited = False
for i, task in enumerate(dsl_list):
if task["name"] == edited_task["name"]:
dsl_list[i] = edited_task
edited = True
break
if not edited:
dsl_list.insert(-1, edited_task)
if edited_task["name"] == "end":
cprint(f"Cannot edit end task.", "red")

elif edited_task["name"] == "start":
for i, task in enumerate(dsl_list):
if task["name"] == "start":
if edited_task.get("goto"):
dsl_list[i]["goto"] = edited_task["goto"]
break
else:
edited = False
for i, task in enumerate(dsl_list):
if task["name"] == edited_task["name"]:
dsl_list[i] = edited_task
edited = True
break
if not edited:
dsl_list.insert(-1, edited_task)
except TypeError as e:
cprint(f"TypeError: {e}")
cprint(
f"It is likely that the assistant failed to return a valid json.", "red"
)
dsl_list = json.loads(flow)

if step_type == "delete":
dsl_list = json.loads(flow)
dsl_list: list = json.loads(flow)
delete_task_plan = step
deleted = False
if delete_task_plan["task_id"] == "start" or delete_task_plan["task_id"] == "end":
cprint(f"Cannot delete start or end task.", "red")
dsl_list = json.loads(flow)
else:
for i, task in enumerate(dsl_list):
if task["name"] == delete_task_plan["task_id"]:
Expand All @@ -84,7 +92,6 @@ def update_flow(step, plugins={}, flow=[], debug=False):
f"Task with id {delete_task_plan['task_id']} does not exist in the flow.",
"red",
)
dsl_list = json.loads(flow)

if debug:
cprint(f"Intermediate DSL: {json.dumps(dsl_list, indent=4)}", "light_red")
Expand Down

0 comments on commit afc3352

Please sign in to comment.