diff --git a/pyproject.toml b/pyproject.toml index 97b7d74..efe735a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ pydantic-xml = "^2.9.0" numpy = "^1.26.3" pandas = "^2.1.4" jinja2 = "^3.1.3" -black = "^23.12.1" typer = "^0.9.0" pyyaml = "^6.0.1" toml = "^0.10.2" @@ -31,12 +30,13 @@ markdown-it-py = "^3.0.0" rich = "^13.7.0" lxml = "^5.1.0" python-frontmatter = "^1.1.0" +ruff = "^0.4.1" +black = "23.12.1" [tool.poetry.group.dev.dependencies] coverage = "^7.4.0" pytest-cov = "^4.1.0" - [tool.poetry.group.hdf5.dependencies] h5py = "^3.10.0" diff --git a/sdRDM/generator/codegen.py b/sdRDM/generator/codegen.py index 9ddfeac..f9d871d 100644 --- a/sdRDM/generator/codegen.py +++ b/sdRDM/generator/codegen.py @@ -163,9 +163,13 @@ def save_rendered_to_file(rendered: str, path: str, use_formatter: bool) -> None [ sys.executable, "-m", - "autoflake", - "--in-place", - "--remove-all-unused-imports", + "ruff", + "check", + "--fix", + "--select", + "I", + "--select", + "F", path, ] ) @@ -209,3 +213,14 @@ def _write_json_schemes(libpath: str, lib_name: str): with open(path, "w") as f: json.dump(cls.model_json_schema(), f, indent=2) + + +if __name__ == "__main__": + # Test the function + from sdRDM.generator import generate_python_api + + generate_python_api( + path="/Users/max/Documents/GitHub/pyeed/specifications/sequence_record.md", + dirpath="/Users/max/Documents/GitHub/pyeed/pyeed_ontology", + libname="pyeedd", + ) diff --git a/sdRDM/generator/updater.py b/sdRDM/generator/updater.py index 3da52c9..fb043ca 100644 --- a/sdRDM/generator/updater.py +++ b/sdRDM/generator/updater.py @@ -1,43 +1,19 @@ import ast import re - -from typing import List -from enum import Enum, auto +from typing import List, Tuple # Constants REFERENCE_PATTERN = r"get_[A-Za-z0-9\_]*_reference" ADDER_PATTERN = r"add_to_[A-Za-z0-9\_]*" INHERITANCE_PATTERN = r"class [A-Za-z0-9\_\.]*\(([A-Za-z0-9\_\.]*)\)\:" ATTRIBURE_PATTERN = r"description=(\"|\')[A-Za-z0-9\_\.]*" -FUNCTION_PATTERN = r"def ([a-zA-Z0-9_]+)\(" +FUNCTION_PATTERN = r"(async\s+)?def\s+([a-zA-Z0-9_]+)\(" FUNCTION_NAME_PATTERN = r"def ([a-zA-Z0-9_]+)\(" XML_PARSER_PATTERN = r"_parse_raw_xml_data" ANNOTATION_PATTERN = r"_validate_annotation" -class ModuleOrder(Enum): - IMPORT_SDRDM = auto() - IMPORT_MISC = auto() - - FROM_TYPING = auto() - FROM_PYDANTIC = auto() - FROM_SDRDM = auto() - FROM_MISC = auto() - FROM_LOCAL = auto() - - CLASSES = auto() - FUNCTIONS = auto() - MISC = auto() - - -class ClassOrder(Enum): - DOCSTRING = auto() - ATTRIBUTES = auto() - PRIV_ATTRIBUTES = auto() - METHODS = auto() - - -def preserve_custom_functions(rendered_class: str, path: str) -> str: +def preserve_custom_functions(rendered: str, existing_path: str) -> str: """When a class has already been written and modified, this function will read the original script and accordingly change only attributes and data model related methods, while preserving custom methods. @@ -47,96 +23,71 @@ def preserve_custom_functions(rendered_class: str, path: str) -> str: path (str): Path to the previous file """ - custom_methods = extract_custom_methods(rendered_class, path) + with open(existing_path, "r") as f: + existing = f.readlines() - # Turn the rendered class into an Abstract Syntax Tree and get the class - new_module = ast.parse(rendered_class) - previous_module = ast.parse(open(path).read()) + imports = concatinate_imports(rendered, existing) - # Format and merge imports - _format_imports(new_module, previous_module) + constants = extract_constants(existing) - # Get the class body - new_class = _stylize_class(ast.unparse(new_module)) - - # Merge the previous custom methods with the new class - return "\n".join([new_class, "\n", custom_methods]) + # Extract custom methods + custom_method_positions = get_custom_method_position_slices(existing) + custom_methods = truncate_custom_methods(existing, custom_method_positions) + # newly generated code + rendered_class = remove_imports(rendered) -def extract_custom_methods(rendered_class: str, path: str) -> List[str]: - with open(path, "r") as file: - previous_class = file.read().split("\n") - - # Identify lines where functions start and end - method_starts = [] - for line_count, line in enumerate(previous_class): - if not re.findall(FUNCTION_PATTERN, line): - continue - - # Ignore adder, annotation, and xml parser functions - if re.findall(ADDER_PATTERN, line): - continue - elif re.findall(ANNOTATION_PATTERN, line): - continue - elif re.findall(XML_PARSER_PATTERN, line): - continue - - # Account for decorators - if previous_class[line_count - 1].strip().startswith("@"): - method_starts.append(line_count - 1) - else: - method_starts.append(line_count) - - # Deduct the end of each function - method_ends = [fun_start - 1 for fun_start in method_starts[1:]] - method_ends.append(len(previous_class)) - - # Extract each custom method - methods = [] - for start, end in zip(method_starts, method_ends): - methods.append("\n".join(previous_class[start:end])) + # Merge the previous custom methods with the new class + combined = ( + imports + "\n\n" + constants + "\n\n" + rendered_class + "\n\n" + custom_methods + ) - return "\n".join(methods) + return combined -def _stylize_class(rendered: str): - """Inserts newlines to render the code more readable""" +def extract_constants(custom_class: List[str]) -> str: + """Extracts constants from a custom class. This includes all statements except + import statements before the class definition - if "Enum" in rendered: - return rendered + Args: + custom_class (List[str]): List of strings representing the generated code - nu_render = [] - for line in rendered.split("\n"): - if bool(re.findall(r"[Field|PrivateAttr]", line)) and ": " in line: - # restore formatting for attributes - if bool(re.findall(ATTRIBURE_PATTERN, line)): - line = line[:-1] + "," + line[-1] - nu_render.append("\n") + Returns: + str: Import statements as a string + """ - nu_render.append(line) + constants = [] + for line in custom_class: + line = line.strip() + if line.startswith("@") or line.startswith("class"): + constants.append("\n") + return "".join(constants) + elif line.startswith("from") or line.startswith("import"): + continue else: - nu_render.append(line) - - return _insert_new_lines("\n".join(nu_render)) + constants.append(line) + return "" -def _insert_new_lines(rendered: str): - """Inserts new lines for imports""" - rendered = rendered.replace("import sdRDM", "import sdRDM\n") - return rendered +def concatinate_imports(new_code: List[str], previous_code: List[str]) -> str: + """_summary_ + Args: + new_code (List[str]): Code for the new module + previous_code (List[str]): Code for the previous module -def _format_imports(new_module, previous_module): - """Formats given inputs and merges previous imports to the new ones""" + Returns: + str: Unique, concatenated import statements + """ # Get all imports imports = [] # import ... from_modules = [] # from ... import ... # Get all imports from the new module - for node in ast.walk(new_module): + for node in ast.walk(ast.parse(new_code)): if not isinstance(node, (ast.Import, ast.ImportFrom)): continue @@ -151,8 +102,7 @@ def _format_imports(new_module, previous_module): if node.module not in [imp.module for imp in from_modules]: from_modules.append(node) - # Get all imports from the previous module - for node in ast.walk(previous_module): + for node in ast.walk(ast.parse("\n".join(previous_code))): if not isinstance(node, (ast.Import, ast.ImportFrom)): continue @@ -161,55 +111,147 @@ def _format_imports(new_module, previous_module): ]: imports.append(node) - if isinstance(node, ast.Import): + if not isinstance(node, ast.ImportFrom): continue - if node.module not in [imp.module for imp in from_modules]: + if node.module not in [from_import.module for from_import in from_modules]: from_modules.append(node) + else: - # Add submodules to existing imports - for imp in from_modules: - if imp.module != node.module: + # Add classes to existing imports of modules + for from_import in from_modules: + if from_import.module != node.module: continue - for sub_module in node.names: - if sub_module.name not in [submod.name for submod in imp.names]: - imp.names.append(sub_module) - - # Add modified imports to the new module - nu_body = [] - for element in new_module.body: - if isinstance(element, (ast.Import, ast.ImportFrom)): - continue + for module_class in node.names: + if module_class.name not in [ + submod.name for submod in from_import.names + ]: + from_import.names.append(module_class) - nu_body.append(element) + all_imports = imports + from_modules - nu_body += imports + from_modules + imports = [ast.unparse(node) for node in all_imports] - new_module.body = sorted(nu_body, key=_sort_module) + return "\n".join(imports) -def _sort_module(element): - """Sorts module Imports > Classes > Functions""" +def get_custom_method_position_slices(custom: List[str]) -> List[Tuple[int, int]]: + """Extracts the start and end positions of custom methods in a class - if isinstance(element, ast.Import): - if "sdRDM" in ast.unparse(element): - return ModuleOrder.IMPORT_SDRDM.value - else: - return ModuleOrder.IMPORT_MISC.value - elif isinstance(element, ast.ImportFrom): - if "from ." in ast.unparse(element): - return ModuleOrder.FROM_LOCAL.value - elif element.module == "typing": - return ModuleOrder.FROM_TYPING.value - elif element.module == "pydantic": - return ModuleOrder.FROM_PYDANTIC.value - elif element.module == "sdRDM": - return ModuleOrder.FROM_SDRDM.value + Args: + custom (List[str]): List of strings representing the existing code + + Returns: + List[Tuple[int, int]]: List of tuples representing the start and + end positions of custom methods of the existing code. + """ + + # Identify lines where functions start + custom_method_starts = [] + for line_count, line in enumerate(custom): + if not re.findall(FUNCTION_PATTERN, line): + continue + + # Ignore sdrdm-generated methods + if re.findall(ADDER_PATTERN, line): + continue + elif re.findall(ANNOTATION_PATTERN, line): + continue + elif re.findall(XML_PARSER_PATTERN, line): + continue + + # Account for decorators + if custom[line_count - 1].strip().startswith("@"): + if custom[line_count - 2].strip().startswith("@"): + custom_method_starts.append(line_count - 2) + custom_method_starts.append(line_count - 1) else: - return ModuleOrder.FROM_MISC.value - elif isinstance(element, ast.ClassDef): - return ModuleOrder.CLASSES.value - elif isinstance(element, ast.FunctionDef): - return ModuleOrder.FUNCTIONS.value - else: - return ModuleOrder.MISC.value + custom_method_starts.append(line_count) + + if not custom_method_starts: + return [] + + custom_method_ends = [start - 1 for start in custom_method_starts[1:]] + custom_method_ends.append(len(custom)) + + assert len(custom_method_starts) == len( + custom_method_ends + ), "The number of method starts and ends do not match." + + return list(zip(custom_method_starts, custom_method_ends)) + + +def truncate_custom_methods( + custom: List[str], method_slices: List[Tuple[int, int]] +) -> List[str]: + + if not method_slices: + return "" + + return "\n".join( + ["".join(custom[slice(*start_end)]) for start_end in method_slices] + ) + + +def remove_imports(rendered: List[str]) -> List[str]: + """Removes all import statements from the rendered code + + Args: + rendered (List[str]): List of strings representing the rendered code + + Returns: + List[str]: List of strings representing the rendered code without import statements + """ + + rendered = rendered.split("\n") + + try: + start = [ + line_id + for line_id, line in enumerate(rendered) + if line.strip().startswith("class") + ][0] + except IndexError: + return "\n" + + res = "\n".join(rendered[start:]) + + return res + + +# def extract_custom_methods(existing_path: str) -> List[str]: +# with open(existing_path, "r") as file: +# previous_class = file.read().split("\n") + +# # Identify lines where functions start and end +# method_starts = [] +# for line_count, line in enumerate(previous_class): +# if not re.findall(FUNCTION_PATTERN, line): +# continue + +# # Ignore adder, annotation, and xml parser functions +# if re.findall(ADDER_PATTERN, line): +# continue +# elif re.findall(ANNOTATION_PATTERN, line): +# continue +# elif re.findall(XML_PARSER_PATTERN, line): +# continue + +# # Account for decorators +# if previous_class[line_count - 1].strip().startswith("@"): +# if previous_class[line_count - 2].strip().startswith("@"): +# method_starts.append(line_count - 2) +# method_starts.append(line_count - 1) +# else: +# method_starts.append(line_count) + +# # Deduct the end of each function +# method_ends = [fun_start - 1 for fun_start in method_starts[1:]] +# method_ends.append(len(previous_class)) + +# # Extract each custom method +# methods = [] +# for start, end in zip(method_starts, method_ends): +# methods.append("".join(previous_class[start:end])) + +# return "".join(methods)