Skip to content

Commit

Permalink
changed files
Browse files Browse the repository at this point in the history
  • Loading branch information
Chhinna committed Jan 16, 2024
1 parent 672d7ae commit 50d8d16
Show file tree
Hide file tree
Showing 36 changed files with 169 additions and 224 deletions.
11 changes: 5 additions & 6 deletions automata/cli/scripts/run_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

# Validation test function
def test_yaml_validation(file_path):
with open(file_path, "r") as file:
with open(file_path, 'r', encoding='utf-8') as file:
yaml_data = yaml.safe_load(file)

try:
Expand All @@ -47,7 +47,7 @@ def test_yaml_validation(file_path):

# Compatibility test function
def test_yaml_compatibility(file_path):
with open(file_path, "r") as file:
with open(file_path, 'r', encoding='utf-8') as file:
yaml_data = yaml.safe_load(file)

# Add compatibility test cases based on your specific requirements
Expand All @@ -71,12 +71,11 @@ def test_yaml_compatibility(file_path):
raise ValidationError(
f"Compatibility test '{test['test_name']}' for {file_path} failed."
)
else:
logger.debug(f"Compatibility test '{test['test_name']}' for {file_path} passed.")
logger.debug(f"Compatibility test '{test['test_name']}' for {file_path} passed.")


def test_action_extraction(file_path):
with open(file_path, "r") as file:
with open(file_path, 'r', encoding='utf-8') as file:
yaml_data = yaml.safe_load(file)
actions = AutomataActionExtractor.extract_actions(yaml_data["system_instruction_template"])
number_of_expected_actions = yaml_data["number_of_expected_actions"]
Expand All @@ -96,4 +95,4 @@ def test_action_extraction(file_path):
logger.debug(f"yaml_file={yaml_file}")
test_yaml_validation(yaml_file)
test_yaml_compatibility(yaml_file)
test_action_extraction(yaml_file)
test_action_extraction(yaml_file)
3 changes: 1 addition & 2 deletions automata/cli/scripts/run_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def run(kwargs):

if not kwargs.get("session_id"):
return agent_manager.run()
else:
agent_manager.replay_messages()
agent_manager.replay_messages()


def main(kwargs):
Expand Down
2 changes: 1 addition & 1 deletion automata/cli/scripts/run_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def main(*args, **kwargs):
symbol_embedding.save(embedding_path, overwrite=True)
return "Success"

elif kwargs.get("query_embedding"):
if kwargs.get("query_embedding"):
symbol_graph = SymbolGraph(scip_path)
symbol_embedding = SymbolEmbeddingMap(
load_embedding_map=True,
Expand Down
2 changes: 1 addition & 1 deletion automata/cli/scripts/run_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def evaluator_decoder(
return EvalAction(
ToolAction(dct["tool_name"], dct["tool_query"], dct["tool_args"]), dct["check_tokens"] # type: ignore
)
elif "result_name" in dct:
if "result_name" in dct:
return EvalAction(
ResultAction(dct["result_name"], dct["result_outputs"]), dct["check_tokens"] # type: ignore
)
Expand Down
10 changes: 5 additions & 5 deletions automata/configs/automata_agent_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def load_automata_yaml_config(cls, config_name: AgentConfigName) -> Dict:
file_dir_path, ConfigCategory.AGENT.value, f"{config_name.value}.yaml"
)

with open(config_abs_path, "r") as file:
with open(config_abs_path, 'r', encoding='utf-8') as file:
loaded_yaml = yaml.safe_load(file)

if "tools" in loaded_yaml:
Expand Down Expand Up @@ -128,9 +128,9 @@ def _add_overview_to_instruction_payload(cls, config: "AutomataAgentConfig") ->
@staticmethod
def _format_prompt(format_variables: AutomataInstructionPayload, input_text: str) -> str:
"""Format expected strings into the config."""
for arg in format_variables.__dict__.keys():
if format_variables.__dict__[arg]:
input_text = input_text.replace(f"{{{arg}}}", format_variables.__dict__[arg])
for (arg, format_variables___dict___arg) in format_variables.__dict__.items():
if format_variables___dict___arg:
input_text = input_text.replace(f"{{{arg}}}", format_variables___dict___arg)
return input_text

def _build_tool_message(self):
Expand All @@ -146,4 +146,4 @@ def _build_tool_message(self):
for toolkit in self.llm_toolkits.values()
for tool in toolkit.tools
]
)
)
2 changes: 1 addition & 1 deletion automata/core/agent/tests/test_automata_agent_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_config_loading_different_versions():
for config_name in AgentConfigName:
if config_name == AgentConfigName.DEFAULT:
continue
elif config_name == AgentConfigName.AUTOMATA_INITIALIZER:
if config_name == AgentConfigName.AUTOMATA_INITIALIZER:
continue
main_config = AutomataAgentConfig.load(config_name)
assert isinstance(main_config, AutomataAgentConfig)
7 changes: 3 additions & 4 deletions automata/core/base/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,18 @@ def _make_tool(func: Callable[[str], str]) -> Tool:
# if the argument is a string, then we use the string as the tool name
# Example usage: @tool("search", return_direct=True)
return _make_with_name(args[0])
elif len(args) == 1 and callable(args[0]):
if len(args) == 1 and callable(args[0]):
# if the argument is a function, then we use the function name as the tool name
# Example usage: @tool
return _make_with_name(args[0].__name__)(args[0])
elif len(args) == 0:
if len(args) == 0:
# if there are no arguments, then we use the function name as the tool name
# Example usage: @tool(return_direct=True)
def _partial(func: Callable[[str], str]) -> BaseTool:
return _make_with_name(func.__name__)(func)

return _partial
else:
raise ValueError("Too many arguments for tool decorator")
raise ValueError("Too many arguments for tool decorator")


class Toolkit:
Expand Down
13 changes: 6 additions & 7 deletions automata/core/code_indexing/python_code_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _remove_docstrings(node: FSTNode) -> None:
if filtered_node and isinstance(filtered_node[0], StringNode):
index = filtered_node[0].index_on_parent
node.pop(index)
child_nodes = node.find_all(lambda identifier: identifier in ("def", "class"))
child_nodes = node.find_all(lambda identifier: identifier in {"def", "class"})
for child_node in child_nodes:
if child_node is not node:
_remove_docstrings(child_node)
Expand Down Expand Up @@ -128,8 +128,7 @@ def get_parent_function_name_by_line(self, module_dotpath: str, line_number: int
if node:
if node.parent[0].type == "class":
return f"{node.parent.name}.{node.name}"
else:
return node.name
return node.name
return NO_RESULT_FOUND_STR

def get_parent_function_num_code_lines(
Expand Down Expand Up @@ -181,9 +180,9 @@ def get_parent_code_by_line(

# retarget def or class node
if node.type not in ("def", "class") and node.parent_find(
lambda identifier: identifier in ("def", "class")
lambda identifier: identifier in {"def", "class"}
):
node = node.parent_find(lambda identifier: identifier in ("def", "class"))
node = node.parent_find(lambda identifier: identifier in {"def", "class"})

path = node.path().to_baron_path()
pointer = module
Expand All @@ -202,7 +201,7 @@ def get_parent_code_by_line(
result += self._create_line_number_tuples(
pointer[x], start_line, start_col
)
if pointer[x].type in ("def", "class"):
if pointer[x].type in {"def", "class"}:
docstring = PythonCodeRetriever._get_docstring(pointer[x])
node_copy = pointer[x].copy()
node_copy.value = '"""' + docstring + '"""'
Expand Down Expand Up @@ -274,7 +273,7 @@ def get_expression_context(

node = module.at(lineno)
if node.type not in ("def", "class"):
node = node.parent_find(lambda identifier: identifier in ("def", "class"))
node = node.parent_find(lambda identifier: identifier in {"def", "class"})

if node:
result += f".{node.name}"
Expand Down
8 changes: 4 additions & 4 deletions automata/core/code_indexing/syntax_tree_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _find_subnode(code_obj: RedBaron, obj_name: str) -> Optional[Union[DefNode,
Returns:
Optional[Union[DefNode, ClassNode]]: The found node, or None.
"""
return code_obj.find(lambda identifier: identifier in ("def", "class"), name=obj_name)
return code_obj.find(lambda identifier: identifier in {"def", "class"}, name=obj_name)


def find_import_syntax_tree_nodes(module: RedBaron) -> Optional[NodeList]:
Expand All @@ -64,7 +64,7 @@ def find_import_syntax_tree_nodes(module: RedBaron) -> Optional[NodeList]:
Returns:
Optional[NodeList]: A list of ImportNode and FromImportNode objects.
"""
return module.find_all(lambda identifier: identifier in ("import", "from_import"))
return module.find_all(lambda identifier: identifier in {"import", "from_import"})


def find_import_syntax_tree_node_by_name(
Expand All @@ -81,7 +81,7 @@ def find_import_syntax_tree_node_by_name(
Optional[Union[ImportNode, FromImportNode]]: The found import, or None if not found.
"""
return module.find(
lambda identifier: identifier in ("import", "from_import"), name=import_name
lambda identifier: identifier in {"import", "from_import"}, name=import_name
)


Expand All @@ -95,4 +95,4 @@ def find_all_function_and_class_syntax_tree_nodes(module: RedBaron) -> NodeList:
Returns:
NodeList: A list of ClassNode and DefNode objects.
"""
return module.find_all(lambda identifier: identifier in ("class", "def"))
return module.find_all(lambda identifier: identifier in {"class", "def"})
4 changes: 2 additions & 2 deletions automata/core/code_indexing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def build_repository_overview(path: str, skip_test: bool = True) -> str:
def _overview_traverse_helper(node, line_items, num_spaces=1):
if isinstance(node, ClassDef):
line_items.append(" " * num_spaces + " - cls " + node.name)
elif isinstance(node, FunctionDef) or isinstance(node, AsyncFunctionDef):
elif isinstance(node, (AsyncFunctionDef, FunctionDef)):
line_items.append(" " * num_spaces + " - func " + node.name)

for child in ast.iter_child_nodes(node):
_overview_traverse_helper(child, line_items, num_spaces + 1)
_overview_traverse_helper(child, line_items, num_spaces + 1)
2 changes: 1 addition & 1 deletion automata/core/coordinator/automata_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def add_agent_instance(self, agent_instance: AutomataInstance) -> None:
ValueError: If an agent with the same config_name already exists in the list.
"""
# Check agent has not already been added via name field
if agent_instance.config_name in [ele.config_name for ele in self.agent_instances]:
if agent_instance.config_name in (ele.config_name for ele in self.agent_instances):
raise ValueError("Agent already exists.")
self.agent_instances.append(agent_instance)

Expand Down
12 changes: 6 additions & 6 deletions automata/core/search/scip_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 4 additions & 16 deletions automata/core/search/symbol_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,7 @@ def get_references_to_module(self, module_name: str) -> List[SymbolReference]:
List[SymbolReference]: List of symbol references
"""
reference_edges_in_module = self._graph.in_edges(module_name, data=True)
result = []
for _, __, data in reference_edges_in_module:
if data["label"] == "reference":
result.append(data.get("symbol_reference"))
result = [data.get("symbol_reference") for (_, __, data) in reference_edges_in_module if data["label"] == "reference"]

return result

Expand Down Expand Up @@ -302,7 +299,7 @@ def _get_symbol_dependencies(self, symbol: Symbol) -> Set[Symbol]:
# TODO: Consider implications of using list instead of set
"""
references_in_range = self._get_symbol_references_in_scope(symbol)
symbols_in_range = set([ref.symbol for ref in references_in_range])
symbols_in_range = {ref.symbol for ref in references_in_range}
return symbols_in_range

def _get_symbol_references_in_scope(self, symbol: Symbol) -> List[SymbolReference]:
Expand Down Expand Up @@ -345,13 +342,7 @@ def _get_symbol_relationships(self, symbol: Symbol) -> Set[Symbol]:
# TODO: Consider implications of using list instead of set
"""
related_symbol_nodes = set(
[
target
for _, target, data in self._graph.out_edges(symbol, data=True)
if data.get("label") == "relationship"
]
)
related_symbol_nodes = {target for (_, target, data) in self._graph.out_edges(symbol, data=True) if data.get("label") == "relationship"}
return related_symbol_nodes

@staticmethod
Expand All @@ -364,10 +355,7 @@ def _process_symbol_roles(role: int) -> Dict[str, bool]:
Returns:
Dict[str, bool]: A dictionary of symbol roles
"""
result = {}
for role_name, role_value in SymbolRole.items():
if (role & role_value) > 0:
result[role_name] = (role & role_value) > 0
result = {role_name: role & role_value > 0 for (role_name, role_value) in SymbolRole.items() if role & role_value > 0}
return result

@staticmethod
Expand Down
53 changes: 26 additions & 27 deletions automata/core/search/symbol_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,35 @@ def parse_descriptor(self) -> Descriptor:
descriptor = Descriptor(name, Descriptor.ScipSuffix.Parameter)
self.accept_character(")", "closing parameter name")
return descriptor
elif next_char == "[":
if next_char == "[":
self.index += 1
name = self.accept_identifier("type parameter name")
descriptor = Descriptor(name, Descriptor.ScipSuffix.TypeParameter)
self.accept_character("]", "closing type parameter name")
return descriptor
else:
name = self.accept_identifier("descriptor name")
suffix = self.current()
self.index += 1
if suffix == "(":
disambiguator = ""
if self.current() != ")":
disambiguator = self.accept_identifier("method disambiguator")
descriptor = Descriptor(name, Descriptor.ScipSuffix.Method, disambiguator)
self.accept_character(")", "closing method")
self.accept_character(".", "closing method")
return descriptor
elif suffix == "/":
return Descriptor(name, Descriptor.ScipSuffix.Namespace)
elif suffix == ".":
return Descriptor(name, Descriptor.ScipSuffix.Term)
elif suffix == "#":
return Descriptor(name, Descriptor.ScipSuffix.Type)
elif suffix == ":":
return Descriptor(name, Descriptor.ScipSuffix.Meta)
elif suffix == "!":
return Descriptor(name, Descriptor.ScipSuffix.Macro)
else:
raise self.error("Expected a descriptor suffix")

name = self.accept_identifier("descriptor name")
suffix = self.current()
self.index += 1
if suffix == "(":
disambiguator = ""
if self.current() != ")":
disambiguator = self.accept_identifier("method disambiguator")
descriptor = Descriptor(name, Descriptor.ScipSuffix.Method, disambiguator)
self.accept_character(")", "closing method")
self.accept_character(".", "closing method")
return descriptor
if suffix == "/":
return Descriptor(name, Descriptor.ScipSuffix.Namespace)
if suffix == ".":
return Descriptor(name, Descriptor.ScipSuffix.Term)
if suffix == "#":
return Descriptor(name, Descriptor.ScipSuffix.Type)
if suffix == ":":
return Descriptor(name, Descriptor.ScipSuffix.Meta)
if suffix == "!":
return Descriptor(name, Descriptor.ScipSuffix.Macro)
raise self.error("Expected a descriptor suffix")

def accept_identifier(self, what: str) -> str:
if self.current() == "`":
Expand Down Expand Up @@ -124,7 +123,7 @@ def accept_character(self, r: str, what: str) -> None:

@staticmethod
def is_identifier_character(c: str) -> bool:
return c.isalpha() or c.isdigit() or c in ["-", "+", "$", "_"]
return c.isalpha() or c.isdigit() or c in {"-", "+", "$", '_'}


def parse_symbol(symbol_uri: str, include_descriptors: bool = True) -> Symbol:
Expand Down Expand Up @@ -156,7 +155,7 @@ def new_local_symbol(symbol: str, id: str) -> Symbol:
symbol,
"local",
Package("", "", ""),
tuple([Descriptor(id, Descriptor.ScipSuffix.Local)]),
(Descriptor(id, Descriptor.ScipSuffix.Local),),
)


Expand Down
Loading

0 comments on commit 50d8d16

Please sign in to comment.