diff --git a/automata/cli/scripts/run_config_validation.py b/automata/cli/scripts/run_config_validation.py index 4da44be8..ad4facb1 100644 --- a/automata/cli/scripts/run_config_validation.py +++ b/automata/cli/scripts/run_config_validation.py @@ -35,7 +35,7 @@ # Validation test function def test_yaml_validation(file_path): - with open(file_path, "r") as file: + with open(file_path, 'r', encoding='utf-8') as file: yaml_data = yaml.safe_load(file) try: @@ -47,7 +47,7 @@ def test_yaml_validation(file_path): # Compatibility test function def test_yaml_compatibility(file_path): - with open(file_path, "r") as file: + with open(file_path, 'r', encoding='utf-8') as file: yaml_data = yaml.safe_load(file) # Add compatibility test cases based on your specific requirements @@ -71,12 +71,11 @@ def test_yaml_compatibility(file_path): raise ValidationError( f"Compatibility test '{test['test_name']}' for {file_path} failed." ) - else: - logger.debug(f"Compatibility test '{test['test_name']}' for {file_path} passed.") + logger.debug(f"Compatibility test '{test['test_name']}' for {file_path} passed.") def test_action_extraction(file_path): - with open(file_path, "r") as file: + with open(file_path, 'r', encoding='utf-8') as file: yaml_data = yaml.safe_load(file) actions = AutomataActionExtractor.extract_actions(yaml_data["system_instruction_template"]) number_of_expected_actions = yaml_data["number_of_expected_actions"] @@ -96,4 +95,4 @@ def test_action_extraction(file_path): logger.debug(f"yaml_file={yaml_file}") test_yaml_validation(yaml_file) test_yaml_compatibility(yaml_file) - test_action_extraction(yaml_file) + test_action_extraction(yaml_file) \ No newline at end of file diff --git a/automata/cli/scripts/run_coordinator.py b/automata/cli/scripts/run_coordinator.py index 49bac154..7c8f4506 100644 --- a/automata/cli/scripts/run_coordinator.py +++ b/automata/cli/scripts/run_coordinator.py @@ -41,8 +41,7 @@ def run(kwargs): if not kwargs.get("session_id"): return agent_manager.run() - else: - agent_manager.replay_messages() + agent_manager.replay_messages() def main(kwargs): diff --git a/automata/cli/scripts/run_embedding.py b/automata/cli/scripts/run_embedding.py index 64d0f37b..5f0a133b 100644 --- a/automata/cli/scripts/run_embedding.py +++ b/automata/cli/scripts/run_embedding.py @@ -51,7 +51,7 @@ def main(*args, **kwargs): symbol_embedding.save(embedding_path, overwrite=True) return "Success" - elif kwargs.get("query_embedding"): + if kwargs.get("query_embedding"): symbol_graph = SymbolGraph(scip_path) symbol_embedding = SymbolEmbeddingMap( load_embedding_map=True, diff --git a/automata/cli/scripts/run_evaluator.py b/automata/cli/scripts/run_evaluator.py index 3a6a30ce..3d289af6 100644 --- a/automata/cli/scripts/run_evaluator.py +++ b/automata/cli/scripts/run_evaluator.py @@ -20,7 +20,7 @@ def evaluator_decoder( return EvalAction( ToolAction(dct["tool_name"], dct["tool_query"], dct["tool_args"]), dct["check_tokens"] # type: ignore ) - elif "result_name" in dct: + if "result_name" in dct: return EvalAction( ResultAction(dct["result_name"], dct["result_outputs"]), dct["check_tokens"] # type: ignore ) diff --git a/automata/configs/automata_agent_configs.py b/automata/configs/automata_agent_configs.py index 6a4f05e8..c4f04ddc 100644 --- a/automata/configs/automata_agent_configs.py +++ b/automata/configs/automata_agent_configs.py @@ -95,7 +95,7 @@ def load_automata_yaml_config(cls, config_name: AgentConfigName) -> Dict: file_dir_path, ConfigCategory.AGENT.value, f"{config_name.value}.yaml" ) - with open(config_abs_path, "r") as file: + with open(config_abs_path, 'r', encoding='utf-8') as file: loaded_yaml = yaml.safe_load(file) if "tools" in loaded_yaml: @@ -128,9 +128,9 @@ def _add_overview_to_instruction_payload(cls, config: "AutomataAgentConfig") -> @staticmethod def _format_prompt(format_variables: AutomataInstructionPayload, input_text: str) -> str: """Format expected strings into the config.""" - for arg in format_variables.__dict__.keys(): - if format_variables.__dict__[arg]: - input_text = input_text.replace(f"{{{arg}}}", format_variables.__dict__[arg]) + for (arg, format_variables___dict___arg) in format_variables.__dict__.items(): + if format_variables___dict___arg: + input_text = input_text.replace(f"{{{arg}}}", format_variables___dict___arg) return input_text def _build_tool_message(self): @@ -146,4 +146,4 @@ def _build_tool_message(self): for toolkit in self.llm_toolkits.values() for tool in toolkit.tools ] - ) + ) \ No newline at end of file diff --git a/automata/core/agent/tests/test_automata_agent_builder.py b/automata/core/agent/tests/test_automata_agent_builder.py index f986fa90..34dc417d 100644 --- a/automata/core/agent/tests/test_automata_agent_builder.py +++ b/automata/core/agent/tests/test_automata_agent_builder.py @@ -105,7 +105,7 @@ def test_config_loading_different_versions(): for config_name in AgentConfigName: if config_name == AgentConfigName.DEFAULT: continue - elif config_name == AgentConfigName.AUTOMATA_INITIALIZER: + if config_name == AgentConfigName.AUTOMATA_INITIALIZER: continue main_config = AutomataAgentConfig.load(config_name) assert isinstance(main_config, AutomataAgentConfig) diff --git a/automata/core/base/tool.py b/automata/core/base/tool.py index 9127c83e..8ce65cdf 100644 --- a/automata/core/base/tool.py +++ b/automata/core/base/tool.py @@ -93,19 +93,18 @@ def _make_tool(func: Callable[[str], str]) -> Tool: # if the argument is a string, then we use the string as the tool name # Example usage: @tool("search", return_direct=True) return _make_with_name(args[0]) - elif len(args) == 1 and callable(args[0]): + if len(args) == 1 and callable(args[0]): # if the argument is a function, then we use the function name as the tool name # Example usage: @tool return _make_with_name(args[0].__name__)(args[0]) - elif len(args) == 0: + if len(args) == 0: # if there are no arguments, then we use the function name as the tool name # Example usage: @tool(return_direct=True) def _partial(func: Callable[[str], str]) -> BaseTool: return _make_with_name(func.__name__)(func) return _partial - else: - raise ValueError("Too many arguments for tool decorator") + raise ValueError("Too many arguments for tool decorator") class Toolkit: diff --git a/automata/core/code_indexing/python_code_retriever.py b/automata/core/code_indexing/python_code_retriever.py index 00453c74..76944271 100644 --- a/automata/core/code_indexing/python_code_retriever.py +++ b/automata/core/code_indexing/python_code_retriever.py @@ -91,7 +91,7 @@ def _remove_docstrings(node: FSTNode) -> None: if filtered_node and isinstance(filtered_node[0], StringNode): index = filtered_node[0].index_on_parent node.pop(index) - child_nodes = node.find_all(lambda identifier: identifier in ("def", "class")) + child_nodes = node.find_all(lambda identifier: identifier in {"def", "class"}) for child_node in child_nodes: if child_node is not node: _remove_docstrings(child_node) @@ -128,8 +128,7 @@ def get_parent_function_name_by_line(self, module_dotpath: str, line_number: int if node: if node.parent[0].type == "class": return f"{node.parent.name}.{node.name}" - else: - return node.name + return node.name return NO_RESULT_FOUND_STR def get_parent_function_num_code_lines( @@ -181,9 +180,9 @@ def get_parent_code_by_line( # retarget def or class node if node.type not in ("def", "class") and node.parent_find( - lambda identifier: identifier in ("def", "class") + lambda identifier: identifier in {"def", "class"} ): - node = node.parent_find(lambda identifier: identifier in ("def", "class")) + node = node.parent_find(lambda identifier: identifier in {"def", "class"}) path = node.path().to_baron_path() pointer = module @@ -202,7 +201,7 @@ def get_parent_code_by_line( result += self._create_line_number_tuples( pointer[x], start_line, start_col ) - if pointer[x].type in ("def", "class"): + if pointer[x].type in {"def", "class"}: docstring = PythonCodeRetriever._get_docstring(pointer[x]) node_copy = pointer[x].copy() node_copy.value = '"""' + docstring + '"""' @@ -274,7 +273,7 @@ def get_expression_context( node = module.at(lineno) if node.type not in ("def", "class"): - node = node.parent_find(lambda identifier: identifier in ("def", "class")) + node = node.parent_find(lambda identifier: identifier in {"def", "class"}) if node: result += f".{node.name}" diff --git a/automata/core/code_indexing/syntax_tree_navigation.py b/automata/core/code_indexing/syntax_tree_navigation.py index d0c7fcbb..cd9e8057 100644 --- a/automata/core/code_indexing/syntax_tree_navigation.py +++ b/automata/core/code_indexing/syntax_tree_navigation.py @@ -51,7 +51,7 @@ def _find_subnode(code_obj: RedBaron, obj_name: str) -> Optional[Union[DefNode, Returns: Optional[Union[DefNode, ClassNode]]: The found node, or None. """ - return code_obj.find(lambda identifier: identifier in ("def", "class"), name=obj_name) + return code_obj.find(lambda identifier: identifier in {"def", "class"}, name=obj_name) def find_import_syntax_tree_nodes(module: RedBaron) -> Optional[NodeList]: @@ -64,7 +64,7 @@ def find_import_syntax_tree_nodes(module: RedBaron) -> Optional[NodeList]: Returns: Optional[NodeList]: A list of ImportNode and FromImportNode objects. """ - return module.find_all(lambda identifier: identifier in ("import", "from_import")) + return module.find_all(lambda identifier: identifier in {"import", "from_import"}) def find_import_syntax_tree_node_by_name( @@ -81,7 +81,7 @@ def find_import_syntax_tree_node_by_name( Optional[Union[ImportNode, FromImportNode]]: The found import, or None if not found. """ return module.find( - lambda identifier: identifier in ("import", "from_import"), name=import_name + lambda identifier: identifier in {"import", "from_import"}, name=import_name ) @@ -95,4 +95,4 @@ def find_all_function_and_class_syntax_tree_nodes(module: RedBaron) -> NodeList: Returns: NodeList: A list of ClassNode and DefNode objects. """ - return module.find_all(lambda identifier: identifier in ("class", "def")) + return module.find_all(lambda identifier: identifier in {"class", "def"}) diff --git a/automata/core/code_indexing/utils.py b/automata/core/code_indexing/utils.py index 6cc3fe9e..b20405cb 100644 --- a/automata/core/code_indexing/utils.py +++ b/automata/core/code_indexing/utils.py @@ -36,8 +36,8 @@ def build_repository_overview(path: str, skip_test: bool = True) -> str: def _overview_traverse_helper(node, line_items, num_spaces=1): if isinstance(node, ClassDef): line_items.append(" " * num_spaces + " - cls " + node.name) - elif isinstance(node, FunctionDef) or isinstance(node, AsyncFunctionDef): + elif isinstance(node, (AsyncFunctionDef, FunctionDef)): line_items.append(" " * num_spaces + " - func " + node.name) for child in ast.iter_child_nodes(node): - _overview_traverse_helper(child, line_items, num_spaces + 1) + _overview_traverse_helper(child, line_items, num_spaces + 1) \ No newline at end of file diff --git a/automata/core/coordinator/automata_coordinator.py b/automata/core/coordinator/automata_coordinator.py index 287da910..0d35a168 100644 --- a/automata/core/coordinator/automata_coordinator.py +++ b/automata/core/coordinator/automata_coordinator.py @@ -27,7 +27,7 @@ def add_agent_instance(self, agent_instance: AutomataInstance) -> None: ValueError: If an agent with the same config_name already exists in the list. """ # Check agent has not already been added via name field - if agent_instance.config_name in [ele.config_name for ele in self.agent_instances]: + if agent_instance.config_name in (ele.config_name for ele in self.agent_instances): raise ValueError("Agent already exists.") self.agent_instances.append(agent_instance) diff --git a/automata/core/search/scip_pb2.py b/automata/core/search/scip_pb2.py index 387a2e3c..6e98c6df 100644 --- a/automata/core/search/scip_pb2.py +++ b/automata/core/search/scip_pb2.py @@ -2,11 +2,11 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: scip.proto """Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor as _descriptor, descriptor_pool as _descriptor_pool, message as _message, reflection as _reflection, symbol_database as _symbol_database + + + + from google.protobuf.internal import enum_type_wrapper # @@protoc_insertion_point(imports) @@ -319,7 +319,7 @@ ) _sym_db.RegisterMessage(Diagnostic) -if _descriptor._USE_C_DESCRIPTORS == False: +if _descriptor._USE_C_DESCRIPTORS is False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"Z-github.com/sourcegraph/scip/bindings/go/scip/" _SYNTAXKIND._options = None diff --git a/automata/core/search/symbol_graph.py b/automata/core/search/symbol_graph.py index 37187e5e..22a5a0bd 100644 --- a/automata/core/search/symbol_graph.py +++ b/automata/core/search/symbol_graph.py @@ -118,10 +118,7 @@ def get_references_to_module(self, module_name: str) -> List[SymbolReference]: List[SymbolReference]: List of symbol references """ reference_edges_in_module = self._graph.in_edges(module_name, data=True) - result = [] - for _, __, data in reference_edges_in_module: - if data["label"] == "reference": - result.append(data.get("symbol_reference")) + result = [data.get("symbol_reference") for (_, __, data) in reference_edges_in_module if data["label"] == "reference"] return result @@ -302,7 +299,7 @@ def _get_symbol_dependencies(self, symbol: Symbol) -> Set[Symbol]: # TODO: Consider implications of using list instead of set """ references_in_range = self._get_symbol_references_in_scope(symbol) - symbols_in_range = set([ref.symbol for ref in references_in_range]) + symbols_in_range = {ref.symbol for ref in references_in_range} return symbols_in_range def _get_symbol_references_in_scope(self, symbol: Symbol) -> List[SymbolReference]: @@ -345,13 +342,7 @@ def _get_symbol_relationships(self, symbol: Symbol) -> Set[Symbol]: # TODO: Consider implications of using list instead of set """ - related_symbol_nodes = set( - [ - target - for _, target, data in self._graph.out_edges(symbol, data=True) - if data.get("label") == "relationship" - ] - ) + related_symbol_nodes = {target for (_, target, data) in self._graph.out_edges(symbol, data=True) if data.get("label") == "relationship"} return related_symbol_nodes @staticmethod @@ -364,10 +355,7 @@ def _process_symbol_roles(role: int) -> Dict[str, bool]: Returns: Dict[str, bool]: A dictionary of symbol roles """ - result = {} - for role_name, role_value in SymbolRole.items(): - if (role & role_value) > 0: - result[role_name] = (role & role_value) > 0 + result = {role_name: role & role_value > 0 for (role_name, role_value) in SymbolRole.items() if role & role_value > 0} return result @staticmethod diff --git a/automata/core/search/symbol_parser.py b/automata/core/search/symbol_parser.py index e08bf1dd..aa3602ad 100644 --- a/automata/core/search/symbol_parser.py +++ b/automata/core/search/symbol_parser.py @@ -49,36 +49,35 @@ def parse_descriptor(self) -> Descriptor: descriptor = Descriptor(name, Descriptor.ScipSuffix.Parameter) self.accept_character(")", "closing parameter name") return descriptor - elif next_char == "[": + if next_char == "[": self.index += 1 name = self.accept_identifier("type parameter name") descriptor = Descriptor(name, Descriptor.ScipSuffix.TypeParameter) self.accept_character("]", "closing type parameter name") return descriptor - else: - name = self.accept_identifier("descriptor name") - suffix = self.current() - self.index += 1 - if suffix == "(": - disambiguator = "" - if self.current() != ")": - disambiguator = self.accept_identifier("method disambiguator") - descriptor = Descriptor(name, Descriptor.ScipSuffix.Method, disambiguator) - self.accept_character(")", "closing method") - self.accept_character(".", "closing method") - return descriptor - elif suffix == "/": - return Descriptor(name, Descriptor.ScipSuffix.Namespace) - elif suffix == ".": - return Descriptor(name, Descriptor.ScipSuffix.Term) - elif suffix == "#": - return Descriptor(name, Descriptor.ScipSuffix.Type) - elif suffix == ":": - return Descriptor(name, Descriptor.ScipSuffix.Meta) - elif suffix == "!": - return Descriptor(name, Descriptor.ScipSuffix.Macro) - else: - raise self.error("Expected a descriptor suffix") + + name = self.accept_identifier("descriptor name") + suffix = self.current() + self.index += 1 + if suffix == "(": + disambiguator = "" + if self.current() != ")": + disambiguator = self.accept_identifier("method disambiguator") + descriptor = Descriptor(name, Descriptor.ScipSuffix.Method, disambiguator) + self.accept_character(")", "closing method") + self.accept_character(".", "closing method") + return descriptor + if suffix == "/": + return Descriptor(name, Descriptor.ScipSuffix.Namespace) + if suffix == ".": + return Descriptor(name, Descriptor.ScipSuffix.Term) + if suffix == "#": + return Descriptor(name, Descriptor.ScipSuffix.Type) + if suffix == ":": + return Descriptor(name, Descriptor.ScipSuffix.Meta) + if suffix == "!": + return Descriptor(name, Descriptor.ScipSuffix.Macro) + raise self.error("Expected a descriptor suffix") def accept_identifier(self, what: str) -> str: if self.current() == "`": @@ -124,7 +123,7 @@ def accept_character(self, r: str, what: str) -> None: @staticmethod def is_identifier_character(c: str) -> bool: - return c.isalpha() or c.isdigit() or c in ["-", "+", "$", "_"] + return c.isalpha() or c.isdigit() or c in {"-", "+", "$", '_'} def parse_symbol(symbol_uri: str, include_descriptors: bool = True) -> Symbol: @@ -156,7 +155,7 @@ def new_local_symbol(symbol: str, id: str) -> Symbol: symbol, "local", Package("", "", ""), - tuple([Descriptor(id, Descriptor.ScipSuffix.Local)]), + (Descriptor(id, Descriptor.ScipSuffix.Local),), ) diff --git a/automata/core/search/symbol_rank/symbol_embedding_map.py b/automata/core/search/symbol_rank/symbol_embedding_map.py index 8078c69b..6eb5582e 100644 --- a/automata/core/search/symbol_rank/symbol_embedding_map.py +++ b/automata/core/search/symbol_rank/symbol_embedding_map.py @@ -165,7 +165,7 @@ def save(self, output_embedding_path: StrPath, overwrite: bool = False) -> None: # Raise error if the file already exists if os.path.exists(output_embedding_path) and not overwrite: raise ValueError("output_embedding_path must be a path to a non-existing file.") - with open(output_embedding_path, "w") as f: + with open(output_embedding_path, 'w', encoding='utf-8') as f: encoded_embedding = jsonpickle.encode(self.embedding_dict) f.write(encoded_embedding) @@ -181,7 +181,7 @@ def load(cls, input_embedding_path: StrPath) -> Dict[Symbol, SymbolEmbedding]: raise ValueError("input_embedding_path must be a path to an existing file.") embedding_dict = {} - with open(input_embedding_path, "r") as f: + with open(input_embedding_path, 'r', encoding='utf-8') as f: embedding_map_str_keys = jsonpickle.decode(f.read()) embedding_dict = { Symbol.from_string(key): value for key, value in embedding_map_str_keys.items() @@ -213,4 +213,4 @@ def _build_embedding_map(self, defined_symbols: List[Symbol]) -> Dict[Symbol, Sy except Exception as e: logger.error("Building embedding for symbol: %s failed with %s" % (symbol, e)) - return embedding_dict + return embedding_dict \ No newline at end of file diff --git a/automata/core/search/symbol_rank/symbol_rank.py b/automata/core/search/symbol_rank/symbol_rank.py index 709b4a62..50f86a45 100644 --- a/automata/core/search/symbol_rank/symbol_rank.py +++ b/automata/core/search/symbol_rank/symbol_rank.py @@ -125,9 +125,8 @@ def _prepare_initial_ranks( node_count = stochastic_graph.number_of_nodes() if initial_weights is None: return {k: 1.0 / node_count for k in stochastic_graph} - else: - s = sum(initial_weights.values()) - return {k: v / s for k, v in initial_weights.items()} + s = sum(initial_weights.values()) + return {k: v / s for k, v in initial_weights.items()} def _prepare_symbol_similarity( self, @@ -148,15 +147,14 @@ def _prepare_symbol_similarity( """ if symbol_similarity is None: return {k: 1.0 / node_count for k in stochastic_graph} - else: - missing = set(self.graph) - set(symbol_similarity) - if missing: - raise NetworkXError( - "symbol_similarity dictionary must have a value for every node. Missing nodes %s" - % missing - ) - s = sum(symbol_similarity.values()) - return {k: v / s for k, v in symbol_similarity.items()} + missing = set(self.graph) - set(symbol_similarity) + if missing: + raise NetworkXError( + "symbol_similarity dictionary must have a value for every node. Missing nodes %s" + % missing + ) + s = sum(symbol_similarity.values()) + return {k: v / s for k, v in symbol_similarity.items()} def _prepare_dangling_weights( self, dangling: Optional[Dict[str, float]], symbol_similarity: Dict[str, float] @@ -173,15 +171,14 @@ def _prepare_dangling_weights( """ if dangling is None: return symbol_similarity - else: - missing = set(self.graph) - set(dangling) - if missing: - raise NetworkXError( - "Dangling node dictionary must have a value for every node. Missing nodes %s" - % missing - ) - s = sum(dangling.values()) - return {k: v / s for k, v in dangling.items()} + missing = set(self.graph) - set(dangling) + if missing: + raise NetworkXError( + "Dangling node dictionary must have a value for every node. Missing nodes %s" + % missing + ) + s = sum(dangling.values()) + return {k: v / s for k, v in dangling.items()} def _get_dangling_nodes(self, stochastic_graph: nx.DiGraph) -> List[Hashable]: """ diff --git a/automata/core/search/symbol_rank/symbol_similarity.py b/automata/core/search/symbol_rank/symbol_similarity.py index 96d5db99..7f9d883e 100644 --- a/automata/core/search/symbol_rank/symbol_similarity.py +++ b/automata/core/search/symbol_rank/symbol_similarity.py @@ -38,8 +38,8 @@ def __init__( ) self.embedding_provider: EmbeddingsProvider = symbol_embedding_map.embedding_provider self.default_norm_type = norm_type - symbols = sorted(list(self.embedding_dict.keys()), key=lambda x: x.uri) - self.index_to_symbol = {i: symbol for i, symbol in enumerate(symbols)} + symbols = sorted(self.embedding_dict.keys()) + self.index_to_symbol = dict(enumerate(symbols)) self.symbol_to_index = {symbol: i for i, symbol in enumerate(symbols)} def transform_similarity_matrix( @@ -229,13 +229,12 @@ def _normalize_embeddings(embeddings: np.ndarray, norm_type: NormType) -> np.nda if norm_type == NormType.L1: norm = np.sum(np.abs(embeddings), axis=1, keepdims=True) return embeddings / norm - elif norm_type == NormType.L2: + if norm_type == NormType.L2: return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) - elif norm_type == NormType.SOFTMAX: + if norm_type == NormType.SOFTMAX: e_x = np.exp(embeddings - np.max(embeddings, axis=1, keepdims=True)) return e_x / np.sum(e_x, axis=1, keepdims=True) - else: - raise ValueError(f"Invalid normalization type {norm_type}") + raise ValueError(f"Invalid normalization type {norm_type}") @staticmethod def _normalize_matrix(M: np.ndarray) -> np.ndarray: diff --git a/automata/core/search/symbol_rank/tests/test_symbol_embedding_map.py b/automata/core/search/symbol_rank/tests/test_symbol_embedding_map.py index aaf37777..f0a99535 100644 --- a/automata/core/search/symbol_rank/tests/test_symbol_embedding_map.py +++ b/automata/core/search/symbol_rank/tests/test_symbol_embedding_map.py @@ -22,7 +22,7 @@ def test_build_embedding_map( # Verify the results assert len(embedding_dict) == 200 - for _, symbol_embedding in embedding_dict.items(): + for symbol_embedding in embedding_dict.values(): assert symbol_embedding.vector.all() == mock_embedding.all() @@ -40,7 +40,7 @@ def test_save_load_embedding_map( sem.save(temp_output_filename) sem_load = SymbolEmbeddingMap.load(temp_output_filename) for key, val in sem_load.items(): - assert key.uri in [symbol.uri for symbol in sem.embedding_dict.keys()] + assert key.uri in (symbol.uri for symbol in sem.embedding_dict.keys()) def test_get_embedding_sets_correct_result( diff --git a/automata/core/search/symbol_rank/tests/test_symbol_rank.py b/automata/core/search/symbol_rank/tests/test_symbol_rank.py index 7a66f9a0..13744118 100644 --- a/automata/core/search/symbol_rank/tests/test_symbol_rank.py +++ b/automata/core/search/symbol_rank/tests/test_symbol_rank.py @@ -58,7 +58,7 @@ def test_get_ranks(): ranks = pagerank.get_ranks() assert len(ranks) == nodes - assert sum([ele[1] for ele in ranks]) == pytest.approx(1.0) + assert sum(ele[1] for ele in ranks) == pytest.approx(1.0) def test_get_ranks_small_graph(): @@ -71,4 +71,4 @@ def test_get_ranks_small_graph(): ranks = pagerank.get_ranks() assert len(ranks) == 3 - assert sum([ele[1] for ele in ranks]) == pytest.approx(1.0) + assert sum(ele[1] for ele in ranks) == pytest.approx(1.0) \ No newline at end of file diff --git a/automata/core/search/symbol_rank/tests/test_symbol_similarity.py b/automata/core/search/symbol_rank/tests/test_symbol_similarity.py index e963f8c9..9193f3cf 100644 --- a/automata/core/search/symbol_rank/tests/test_symbol_similarity.py +++ b/automata/core/search/symbol_rank/tests/test_symbol_similarity.py @@ -61,10 +61,9 @@ def test_get_nearest_symbols_for_query(monkeypatch, mock_simple_method_symbols): def mock_get_embedding(_, symbol_source): if symbol_source == "symbol1": return embedding1.vector - elif symbol_source == "symbol2": + if symbol_source == "symbol2": return embedding2.vector - else: - return embedding3.vector + return embedding3.vector monkeypatch.setattr(EmbeddingsProvider, "get_embedding", mock_get_embedding) @@ -108,10 +107,9 @@ def test_transform_similarity_matrix(monkeypatch, mock_simple_method_symbols): def mock_get_embedding(_, symbol_source): if symbol_source == "symbol1": return embedding1.vector - elif symbol_source == "symbol2": + if symbol_source == "symbol2": return embedding2.vector - else: - return embedding3.vector + return embedding3.vector monkeypatch.setattr(EmbeddingsProvider, "get_embedding", mock_get_embedding) @@ -154,10 +152,9 @@ def test_generate_unit_normed_query_vector(monkeypatch, mock_simple_method_symbo def mock_get_embedding(_, symbol_source): if symbol_source == "symbol1": return embedding1.vector - elif symbol_source == "symbol2": + if symbol_source == "symbol2": return embedding2.vector - else: - return embedding3.vector + return embedding3.vector monkeypatch.setattr(EmbeddingsProvider, "get_embedding", mock_get_embedding) diff --git a/automata/core/search/symbol_types.py b/automata/core/search/symbol_types.py index ff9c2314..e0847c0f 100644 --- a/automata/core/search/symbol_types.py +++ b/automata/core/search/symbol_types.py @@ -49,20 +49,19 @@ def unparse(self): escaped_name = Descriptor.get_escaped_name(self.name) if self.suffix == Descriptor.ScipSuffix.Namespace: return f"{escaped_name}/" - elif self.suffix == Descriptor.ScipSuffix.Type: + if self.suffix == Descriptor.ScipSuffix.Type: return f"{escaped_name}#" - elif self.suffix == Descriptor.ScipSuffix.Term: + if self.suffix == Descriptor.ScipSuffix.Term: return f"{escaped_name}." - elif self.suffix == Descriptor.ScipSuffix.Meta: + if self.suffix == Descriptor.ScipSuffix.Meta: return f"{escaped_name}:" - elif self.suffix == Descriptor.ScipSuffix.Method: + if self.suffix == Descriptor.ScipSuffix.Method: return f"{escaped_name}({self.disambiguator})." - elif self.suffix == Descriptor.ScipSuffix.Parameter: + if self.suffix == Descriptor.ScipSuffix.Parameter: return f"({escaped_name})" - elif self.suffix == Descriptor.ScipSuffix.TypeParameter: + if self.suffix == Descriptor.ScipSuffix.TypeParameter: return f"[{escaped_name}]" - else: - raise ValueError(f"Invalid descriptor suffix: {self.suffix}") + raise ValueError(f"Invalid descriptor suffix: {self.suffix}") @staticmethod def get_escaped_name(name): @@ -80,29 +79,28 @@ def convert_scip_to_python_suffix(descriptor_suffix: DescriptorProto) -> PythonK if descriptor_suffix == Descriptor.ScipSuffix.Local: return Descriptor.PythonKinds.Local - elif descriptor_suffix == Descriptor.ScipSuffix.Namespace: + if descriptor_suffix == Descriptor.ScipSuffix.Namespace: return Descriptor.PythonKinds.Module - elif descriptor_suffix == Descriptor.ScipSuffix.Type: + if descriptor_suffix == Descriptor.ScipSuffix.Type: return Descriptor.PythonKinds.Class - elif descriptor_suffix == Descriptor.ScipSuffix.Method: + if descriptor_suffix == Descriptor.ScipSuffix.Method: return Descriptor.PythonKinds.Method - elif descriptor_suffix == Descriptor.ScipSuffix.Term: + if descriptor_suffix == Descriptor.ScipSuffix.Term: return Descriptor.PythonKinds.Value - elif descriptor_suffix == Descriptor.ScipSuffix.Macro: + if descriptor_suffix == Descriptor.ScipSuffix.Macro: return Descriptor.PythonKinds.Macro - elif descriptor_suffix == Descriptor.ScipSuffix.Parameter: + if descriptor_suffix == Descriptor.ScipSuffix.Parameter: return Descriptor.PythonKinds.Parameter - elif descriptor_suffix == Descriptor.ScipSuffix.TypeParameter: + if descriptor_suffix == Descriptor.ScipSuffix.TypeParameter: return Descriptor.PythonKinds.TypeParameter - else: - return Descriptor.PythonKinds.Meta + return Descriptor.PythonKinds.Meta @dataclass @@ -135,7 +133,7 @@ def __hash__(self) -> int: def __eq__(self, other): if isinstance(other, Symbol): return self.uri == other.uri - elif isinstance(other, str): + if isinstance(other, str): return self.uri == other return False @@ -147,20 +145,19 @@ def symbol_raw_kind_by_suffix(self) -> DescriptorProto: return Descriptor.ScipSuffix.Local if self.uri.endswith("/"): return Descriptor.ScipSuffix.Namespace - elif self.uri.endswith("#"): + if self.uri.endswith("#"): return Descriptor.ScipSuffix.Type - elif self.uri.endswith(")."): + if self.uri.endswith(")."): return Descriptor.ScipSuffix.Method - elif self.uri.endswith("."): + if self.uri.endswith("."): return Descriptor.ScipSuffix.Term - elif self.uri.endswith(":"): + if self.uri.endswith(":"): return Descriptor.ScipSuffix.Meta - elif self.uri.endswith(")"): + if self.uri.endswith(")"): return Descriptor.ScipSuffix.Parameter - elif self.uri.endswith("]"): + if self.uri.endswith("]"): return Descriptor.ScipSuffix.TypeParameter - else: - raise ValueError(f"Invalid descriptor suffix: {self.uri}") + raise ValueError(f"Invalid descriptor suffix: {self.uri}") def parent(self) -> "Symbol": parent_descriptors = list(self.descriptors)[:-1] @@ -232,6 +229,6 @@ def __hash__(self) -> int: def __eq__(self, other): if isinstance(other, File): return self.path == other.path - elif isinstance(other, str): + if isinstance(other, str): return self.path == other return False diff --git a/automata/core/search/symbol_utils.py b/automata/core/search/symbol_utils.py index f013048d..77fb8bb1 100644 --- a/automata/core/search/symbol_utils.py +++ b/automata/core/search/symbol_utils.py @@ -133,14 +133,14 @@ def sync_graph_and_dict( """ # Use list() to create a copy of the node list, as you can't modify a list while iterating over it - for node in list(graph.nodes()): + for node in graph.nodes(): if node not in dictionary: graph.remove_node(node) # Again, use list() to create a copy of the key list - for key in list(dictionary.keys()): + for (key, dictionary_key) in dictionary.items(): if key not in graph: - del dictionary[key] + del dictionary_key return graph, dictionary @@ -175,7 +175,5 @@ def transform_dict_values(dictionary: Dict[Any, float], func: Callable[[List[flo transformed_values = func([dictionary[key] for key in dictionary]) # Re-distribute the transformed values back into the dictionary - transformed_dict = {} - for i, key in enumerate(dictionary): - transformed_dict[key] = transformed_values[i] + transformed_dict = {key: transformed_values[i] for (i, key) in enumerate(dictionary)} return transformed_dict diff --git a/automata/core/tasks/automata_task_registry.py b/automata/core/tasks/automata_task_registry.py index eff54d98..567c9777 100644 --- a/automata/core/tasks/automata_task_registry.py +++ b/automata/core/tasks/automata_task_registry.py @@ -165,12 +165,11 @@ def get_task_by_id(self, task_id: str) -> Optional[AutomataTask]: ) if not results: return None - else: - if len(results) != 1: - raise Exception(f"Found multiple tasks with id {task_id}") - task = results[0] - task.observer = self.update_task - return task + if len(results) != 1: + raise Exception(f"Found multiple tasks with id {task_id}") + task = results[0] + task.observer = self.update_task + return task def get_all_tasks(self) -> list[AutomataTask]: results = self.db.get_tasks_by(query="SELECT json FROM tasks") diff --git a/automata/core/tasks/task.py b/automata/core/tasks/task.py index 1b9738cd..88550fb5 100644 --- a/automata/core/tasks/task.py +++ b/automata/core/tasks/task.py @@ -141,11 +141,10 @@ def build_agent_manager(self): ) return agent_manager - else: - return AutomataManagerFactory.create_manager( - AutomataAgentFactory.create_agent(instructions=instructions, config=main_config), - {}, - ) + return AutomataManagerFactory.create_manager( + AutomataAgentFactory.create_agent(instructions=instructions, config=main_config), + {}, + ) def validate_initialization(self): """ @@ -167,20 +166,7 @@ def to_partial_json(self) -> Dict[str, str]: Returns a JSON representation of key attributes of the task. """ - result = { - "task_id": str(self.task_id), - "status": self.status.value, - "priority": self.priority, - "max_retries": self.max_retries, - "retry_count": self.retry_count, - "path_to_root_py": self.path_to_root_py, - "result": self.result, - "error": self.error, - } - result["model"] = self.kwargs.get("model", "gpt-4") - result["llm_toolkits"] = self.kwargs.get("llm_toolkits", "").split(",") - result["instructions"] = self.kwargs.get("instructions", None) - result["instruction_payload"] = self.kwargs.get("instruction_payload", None) + result = {"task_id": str(self.task_id), "status": self.status.value, "priority": self.priority, "max_retries": self.max_retries, "retry_count": self.retry_count, "path_to_root_py": self.path_to_root_py, "result": self.result, "error": self.error, "model": self.kwargs.get("model", "gpt-4"), "llm_toolkits": self.kwargs.get("llm_toolkits", "").split(","), "instructions": self.kwargs.get("instructions", None), "instruction_payload": self.kwargs.get("instruction_payload", None)} main_config = self.kwargs.get("main_config", None) if main_config: result["main_config"] = main_config.config_name.value @@ -219,8 +205,7 @@ def get_logs(self) -> str: log_file = os.path.join(log_dir, Task.TASK_LOG_NAME.replace("TASK_ID", str(self.task_id))) if os.path.exists(log_file): - with open(log_file, "r") as f: + with open(log_file, 'r', encoding='utf-8') as f: log_content = f.read() return log_content - else: - raise FileNotFoundError(f"Log file {log_file} not found.") + raise FileNotFoundError(f"Log file {log_file} not found.") \ No newline at end of file diff --git a/automata/core/tasks/test/test_task.py b/automata/core/tasks/test/test_task.py index e62800ff..6ac99601 100644 --- a/automata/core/tasks/test/test_task.py +++ b/automata/core/tasks/test/test_task.py @@ -56,8 +56,7 @@ def registry(task): def mock_get_tasks_by(query, params): if params[0] == task.task_id: return [task] - else: - return [] + return [] db = MagicMock() repo_manager = MockRepositoryManager() diff --git a/automata/core/tests/conftest.py b/automata/core/tests/conftest.py index bb71d53a..c85f6764 100644 --- a/automata/core/tests/conftest.py +++ b/automata/core/tests/conftest.py @@ -65,7 +65,7 @@ def cleanup_and_check(expected_content: str, file_name: str) -> None: assert os.path.isfile(file_path), "File does not exist" # Check if the content of the file is as expected - with open(file_path, "r") as file: + with open(file_path, 'r', encoding='utf-8') as file: content = file.read() # Delete the whole "sample_code" directory after the test @@ -93,4 +93,4 @@ def wrapper(*args, **kwargs): return wrapper - return decorator + return decorator \ No newline at end of file diff --git a/automata/core/utils.py b/automata/core/utils.py index 20d90d22..6b60b627 100644 --- a/automata/core/utils.py +++ b/automata/core/utils.py @@ -54,13 +54,10 @@ def load_config( Returns: Any: The content of the YAML file as a Python object. """ - with open( - os.path.join(config_path(), config_name, f"{file_name}.{config_type}"), - "r", - ) as file: + with open(os.path.join(config_path(), config_name, f'{file_name}.{config_type}'), 'r', encoding='utf-8') as file: if config_type == "yaml": return yaml.safe_load(file) - elif config_type == "json": + if config_type == "json": samples_json_string = file.read() return json.loads(samples_json_string, object_hook=custom_decoder) @@ -154,4 +151,4 @@ def calculate_similarity( dot_product = np.dot(embedding_a, embedding_b) magnitude_a = np.sqrt(np.dot(embedding_a, embedding_a)) magnitude_b = np.sqrt(np.dot(embedding_b, embedding_b)) - return dot_product / (magnitude_a * magnitude_b) + return dot_product / (magnitude_a * magnitude_b) \ No newline at end of file diff --git a/automata/evals/eval_helpers.py b/automata/evals/eval_helpers.py index e4667762..0da910a6 100644 --- a/automata/evals/eval_helpers.py +++ b/automata/evals/eval_helpers.py @@ -30,9 +30,7 @@ def token_match(self, action_str: str) -> bool: """ Performs a relative comparison between two actions """ - for token in self.tokens: - if token not in action_str: - return False + return not any(token not in action_str for token in self.tokens) return True def full_match(self, extracted_action: Action, expected_action: Action) -> bool: @@ -110,4 +108,4 @@ def calc_eval_result( if not has_full_match: all_full_matches = False - return EvalResult(token_match=all_token_matches, full_match=all_full_matches) + return EvalResult(token_match=all_token_matches, full_match=all_full_matches) \ No newline at end of file diff --git a/automata/tool_management/coverage_tool_manager.py b/automata/tool_management/coverage_tool_manager.py index 31f0f927..e57395ab 100644 --- a/automata/tool_management/coverage_tool_manager.py +++ b/automata/tool_management/coverage_tool_manager.py @@ -12,7 +12,7 @@ def __init__(self, **kwargs): self.coverage_processor = CoverageProcessor(coverage_analyzer, do_create_issue=True) self.model = kwargs.get("model") or "gpt-4" self.temperature = kwargs.get("temperature") or 0.7 - self.verbose = kwargs.get("verbose") or False + self.verbose = kwargs.get("verbose") self.stream = kwargs.get("stream") or True def _run_list_coverage_gaps(self, input_module): diff --git a/automata/tool_management/python_code_retriever_tool_manager.py b/automata/tool_management/python_code_retriever_tool_manager.py index b4e12f77..d263d2e2 100644 --- a/automata/tool_management/python_code_retriever_tool_manager.py +++ b/automata/tool_management/python_code_retriever_tool_manager.py @@ -31,7 +31,7 @@ def __init__(self, **kwargs): ) self.model = kwargs.get("model") or "gpt-4" self.temperature = kwargs.get("temperature") or 0.7 - self.verbose = kwargs.get("verbose") or False + self.verbose = kwargs.get("verbose") self.stream = kwargs.get("stream") or True def build_tools(self) -> List[Tool]: diff --git a/automata/tool_management/tool_management_utils.py b/automata/tool_management/tool_management_utils.py index 7129d1bf..89ab7d47 100644 --- a/automata/tool_management/tool_management_utils.py +++ b/automata/tool_management/tool_management_utils.py @@ -30,7 +30,7 @@ def create_tool_manager(toolkit_type: ToolkitType) -> BaseToolManager: return PythonCodeRetrieverToolManager( python_retriever=ToolManagerFactory._retriever_instance ) - elif toolkit_type == ToolkitType.PYTHON_WRITER: + if toolkit_type == ToolkitType.PYTHON_WRITER: if ToolManagerFactory._retriever_instance is None: ToolManagerFactory._retriever_instance = PythonCodeRetriever() @@ -40,12 +40,12 @@ def create_tool_manager(toolkit_type: ToolkitType) -> BaseToolManager: return PythonWriterToolManager( python_writer=PythonWriter(ToolManagerFactory._retriever_instance) ) - elif toolkit_type == ToolkitType.COVERAGE_PROCESSOR: + if toolkit_type == ToolkitType.COVERAGE_PROCESSOR: CoverageToolManager = importlib.import_module( "automata.tool_management.coverage_tool_manager" ).CoverageToolManager return CoverageToolManager() - elif toolkit_type == ToolkitType.SYMBOL_SEARCHER: + if toolkit_type == ToolkitType.SYMBOL_SEARCHER: SymbolSearcherToolManager = importlib.import_module( "automata.tool_management.symbol_searcher_tool_manager" ).SymbolSearcherToolManager @@ -54,8 +54,7 @@ def create_tool_manager(toolkit_type: ToolkitType) -> BaseToolManager: index_name="index.scip", symbol_embedding_name="symbol_embedding.json" ) ) - else: - raise ValueError("Unknown toolkit type: %s" % toolkit_type) + raise ValueError("Unknown toolkit type: %s" % toolkit_type) class ToolkitBuilder: diff --git a/automata/tools/coverage_tools/coverage_processor.py b/automata/tools/coverage_tools/coverage_processor.py index 2e2f5c5d..8a33b5f0 100644 --- a/automata/tools/coverage_tools/coverage_processor.py +++ b/automata/tools/coverage_tools/coverage_processor.py @@ -38,9 +38,7 @@ def process_coverage_gap(self, module: str, object: str): coverage_df = self.get_coverage_df(module_path) # get lines from df by module and object uncovered_line_numbers = sorted( - coverage_df[ - (coverage_df["module"] == module_path) & (coverage_df["object"] == function_name) - ]["line_number"].iloc[0] + coverage_df[(coverage_df["module"] == module_path) & (coverage_df["object"] == function_name)]["line_number"].iat[0] ) uncovered_line_numbers_queue = uncovered_line_numbers[:] diff --git a/automata/tools/python_tools/python_writer.py b/automata/tools/python_tools/python_writer.py index 98b58e41..b747d92b 100644 --- a/automata/tools/python_tools/python_writer.py +++ b/automata/tools/python_tools/python_writer.py @@ -182,7 +182,7 @@ def _write_module_to_disk(self, module_dotpath: str) -> None: f"Module fpath found in module map for dotpath: {module_dotpath}" ) module_fpath = cast(str, module_fpath) - with open(module_fpath, "w") as output_file: + with open(module_fpath, 'w', encoding='utf-8') as output_file: output_file.write(source_code) subprocess.run(["black", module_fpath]) subprocess.run(["isort", module_fpath]) @@ -302,7 +302,7 @@ def replace(match): @staticmethod def _update_imports(module_obj: RedBaron, new_import_statements: NodeList) -> None: """Manage the imports in the module.""" - first_import = module_obj.find(lambda identifier: identifier in ("import", "from_import")) + first_import = module_obj.find(lambda identifier: identifier in {"import", "from_import"}) for new_import_statement in new_import_statements: existing_import_statement = find_import_syntax_tree_node_by_name( @@ -312,4 +312,4 @@ def _update_imports(module_obj: RedBaron, new_import_statements: NodeList) -> No if first_import: first_import.insert_before(new_import_statement) # we will run isort later else: - module_obj.insert(0, new_import_statement) + module_obj.insert(0, new_import_statement) \ No newline at end of file diff --git a/automata/tools/python_tools/tests/test_python_writer.py b/automata/tools/python_tools/tests/test_python_writer.py index 225d5547..68d21f90 100644 --- a/automata/tools/python_tools/tests/test_python_writer.py +++ b/automata/tools/python_tools/tests/test_python_writer.py @@ -232,7 +232,7 @@ def test_create_update_write_module(python_writer): root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sample_modules") fpath = os.path.join(root_dir, "sample_module_write.py") assert os.path.exists(fpath) - with open(fpath, "r") as f: + with open(fpath, 'r', encoding='utf-8') as f: contents = f.read() assert_code_lines_equal(source_code, contents) @@ -246,7 +246,7 @@ def test_create_update_write_module(python_writer): source_code=source_code_2, module_dotpath="sample_module_write", do_write=True ) - with open(fpath, "r") as f: + with open(fpath, 'r', encoding='utf-8') as f: contents = f.read() assert_code_lines_equal("\n".join([source_code, source_code_2]), contents) @@ -374,4 +374,4 @@ def test_write_and_retrieve_mock_code(python_writer): module_map = LazyModuleTreeMap(sample_dir) retriever = PythonCodeRetriever(module_map) module_docstring = retriever.get_docstring("sample_module_2", None) - assert module_docstring == mock_generator.module_docstring + assert module_docstring == mock_generator.module_docstring \ No newline at end of file diff --git a/automata/tools/search/symbol_searcher.py b/automata/tools/search/symbol_searcher.py index 63b95a96..a2760edf 100644 --- a/automata/tools/search/symbol_searcher.py +++ b/automata/tools/search/symbol_searcher.py @@ -127,11 +127,10 @@ def process_query( if search_type == "symbol_references": return self.symbol_references(query_remainder) - elif search_type == "symbol_rank": + if search_type == "symbol_rank": return self.symbol_rank_search(query_remainder) - elif search_type == "exact": + if search_type == "exact": return self.exact_search(query_remainder) - elif search_type == "source": + if search_type == "source": return self.retrieve_source_code_by_symbol(query_remainder) - else: - raise ValueError(f"Unknown search type: {search_type}") + raise ValueError(f"Unknown search type: {search_type}") diff --git a/setup.py b/setup.py index 29e4d58b..5ff2ddcd 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ def read_requirements(): - with open("requirements.txt", "r") as req_file: + with open('requirements.txt', 'r', encoding='utf-8') as req_file: return req_file.readlines() @@ -19,4 +19,4 @@ def read_requirements(): ], }, python_requires=">=3.9", # Adjust this to your desired minimum Python version -) +) \ No newline at end of file