From a071516a9964e59cfaf022614c77dce460d43c6c Mon Sep 17 00:00:00 2001 From: Alcides Fonseca Date: Wed, 27 Nov 2024 00:41:49 +0000 Subject: [PATCH] Integrated new synthesis with polymorphism --- aeon/core/substitutions.py | 85 ++++++----- aeon/synthesis_grammar/grammar.py | 136 ++++++++++------- aeon/synthesis_grammar/identification.py | 5 +- aeon/synthesis_grammar/synthesizer.py | 178 ++++++++++++----------- tests/synth_fitness_test.py | 9 +- 5 files changed, 225 insertions(+), 188 deletions(-) diff --git a/aeon/core/substitutions.py b/aeon/core/substitutions.py index 9f8e0409..9f053de7 100644 --- a/aeon/core/substitutions.py +++ b/aeon/core/substitutions.py @@ -106,7 +106,7 @@ def substitution_in_liquid( """Substitutes name in the term t with the new replacement term rep.""" assert isinstance(rep, LiquidTerm) if isinstance( - t, + t, ( LiquidLiteralInt, LiquidLiteralBool, @@ -131,10 +131,7 @@ def substitution_in_liquid( else: return LiquidHornApplication( t.name, - [ - (substitution_in_liquid(a, rep, name), t) - for (a, t) in t.argtypes - ], + [(substitution_in_liquid(a, rep, name), t) for (a, t) in t.argtypes], ) else: assert False @@ -198,44 +195,45 @@ def substitution(t: Term, rep: Term, name: str) -> Term: def rec(x: Term): return substitution(x, rep, name) - if isinstance(t, Literal): - return t - elif isinstance(t, Var): - if t.name == name: - return rep - return t - elif isinstance(t, Hole): - if t.name == name: - return rep - return t - elif isinstance(t, Application): - return Application(fun=rec(t.fun), arg=rec(t.arg)) - elif isinstance(t, Abstraction): - if t.var_name == name: + match t: + case Literal(_): return t - else: - return Abstraction(t.var_name, rec(t.body)) - elif isinstance(t, Let): - if t.var_name == name: - n_value = t.var_value - n_body = t.body - else: - n_value = rec(t.var_value) - n_body = rec(t.body) - return Let(t.var_name, n_value, n_body) - elif isinstance(t, Rec): - if t.var_name == name: - n_value = t.var_value - n_body = t.body - else: - n_value = rec(t.var_value) - n_body = rec(t.body) - return Rec(t.var_name, t.var_type, n_value, n_body) - elif isinstance(t, Annotation): - return Annotation(rec(t.expr), t.type) - elif isinstance(t, If): - return If(rec(t.cond), rec(t.then), rec(t.otherwise)) - assert False + case Var(tname) | Hole(tname): + if tname == name: + return rep + else: + return t + case Application(fun, arg): + return Application(fun=rec(fun), arg=rec(arg)) + case Abstraction(vname, body): + if vname == name: + return t + else: + return Abstraction(vname, rec(t.body)) + case Let(tname, val, body): + if tname == name: + n_value = val + n_body = body + else: + n_value = rec(val) + n_body = rec(body) + return Let(tname, n_value, n_body) + case Rec(tname, ty, val, body): + if tname == name: + n_value = val + n_body = body + else: + n_value = rec(val) + n_body = rec(body) + return Rec(tname, ty, n_value, n_body) + case Annotation(body, ty): + return Annotation(rec(body), ty) + case If(cond, then, otherwise): + return If(rec(cond), rec(then), rec(otherwise)) + case TypeApplication(expr, ty): + return TypeApplication(rec(expr), ty) + case _: + assert False def liquefy_app(app: Application) -> LiquidApp | None: @@ -264,7 +262,8 @@ def liquefy_app(app: Application) -> LiquidApp | None: fun.var_name, ), app.arg, - ), ) + ), + ) else: raise Exception(f"{app} is not a valid predicate.") diff --git a/aeon/synthesis_grammar/grammar.py b/aeon/synthesis_grammar/grammar.py index edcb9991..49d4f34e 100644 --- a/aeon/synthesis_grammar/grammar.py +++ b/aeon/synthesis_grammar/grammar.py @@ -18,7 +18,7 @@ from aeon.core.terms import If from aeon.core.terms import Literal from aeon.core.terms import Var -from aeon.core.types import AbstractionType, Type +from aeon.core.types import AbstractionType, Type, TypePolymorphism, TypeVar from aeon.core.types import BaseType from aeon.core.types import Bottom from aeon.core.types import RefinedType @@ -58,7 +58,7 @@ def extract_class_name(class_name: str) -> str: ] for prefix in prefixes: if class_name.startswith(prefix): - return class_name[len(prefix):] + return class_name[len(prefix) :] return class_name @@ -69,8 +69,7 @@ class GrammarError(Exception): # Protocol for classes that can have a get_core method class HasGetCore(Protocol): - def get_core(self): - ... + def get_core(self): ... classType = TypingType[HasGetCore] @@ -125,17 +124,21 @@ def get_core(self): try: if value is not None: if class_name_without_prefix == "Int" or class_name_without_prefix.startswith( - "Int", ): + "Int", + ): base = Literal(int(value), type=t_int) elif class_name_without_prefix == "Float" or class_name_without_prefix.startswith( - "Float", ): + "Float", + ): base = Literal(float(value), type=t_float) elif class_name_without_prefix == "Bool" or class_name_without_prefix.startswith( - "Bool", ): + "Bool", + ): value = str(value) == "true" base = Literal(value, type=t_bool) elif class_name_without_prefix == "String" or class_name_without_prefix.startswith( - "String", ): + "String", + ): v = str(value)[1:-1] base = Literal(str(v), type=t_string) else: @@ -154,13 +157,19 @@ def liquid_term_to_str(ty: RefinedType) -> str: base_type_str: str = ty.type.name refinement: LiquidTerm = ty.refinement if isinstance(refinement, LiquidApp): - refined_type_str = (str(ty.refinement).replace( - var, - base_type_str, - ).replace("(", "").replace( - ")", - "", - ).replace(" ", "_")) + refined_type_str = ( + str(ty.refinement) + .replace( + var, + base_type_str, + ) + .replace("(", "") + .replace( + ")", + "", + ) + .replace(" ", "_") + ) for op, op_str in aeon_prelude_ops_to_text.items(): refined_type_str = refined_type_str.replace(op, op_str) else: @@ -169,17 +178,27 @@ def liquid_term_to_str(ty: RefinedType) -> str: def process_type_name(ty: Type) -> str: - if isinstance(ty, RefinedType): - refinement_str = liquid_term_to_str(ty) - refined_type_name = f"Refined_{refinement_str}" - return refined_type_name - - elif isinstance(ty, Top) or isinstance(ty, Bottom): - return str(ty) - elif isinstance(ty, BaseType): - return str(ty.name) - else: - assert False + match ty: + case Top() | Bottom(): + return str(ty) + case BaseType(name): + return name + case RefinedType(_, _, _): + refinement_str = liquid_term_to_str(ty) + refined_type_name = f"Refined_{refinement_str}" + return refined_type_name + case AbstractionType(vname, vtype, rtype): + r1 = process_type_name(vtype) + r2 = process_type_name(rtype) + return f"{vname}_{r1}_to_{r2}" + case TypeVar(name): + return name + case TypePolymorphism(name, kind, body): + r = process_type_name(body) + return f"{name}_{kind}_{r}" + case _: + print(ty, type(ty)) + assert False def replace_tuples_with_lists(structure): @@ -231,21 +250,28 @@ def intervals_to_metahandlers( for interval in intervals_list: if isinstance(interval, Interval): if isinstance(ref, LiquidApp): - max_range = (max_number if isinstance( - interval.sup, - Infinity, - ) else interval.sup) # or 2 ** 31 - 1 + max_range = ( + max_number + if isinstance( + interval.sup, + Infinity, + ) + else interval.sup + ) # or 2 ** 31 - 1 max_range = max_range - 1 if interval.right_open else max_range - min_range = (min_number if isinstance( - interval.inf, - NegativeInfinity, - ) else interval.inf) # or -2 ** 31 + min_range = ( + min_number + if isinstance( + interval.inf, + NegativeInfinity, + ) + else interval.inf + ) # or -2 ** 31 min_range = min_range + 1 if interval.left_open else min_range metahandler_instance = gengy_metahandler(min_range, max_range) - metahandler_type = Annotated[ - python_type, metahandler_instance] # type: ignore + metahandler_type = Annotated[python_type, metahandler_instance] # type: ignore metahandler_list.append(metahandler_type) else: assert False @@ -266,7 +292,8 @@ def get_metahandler_union( def refined_type_to_metahandler( - ty: RefinedType, ) -> MetaHandlerGenerator | Union[MetaHandlerGenerator]: + ty: RefinedType, +) -> MetaHandlerGenerator | Union[MetaHandlerGenerator]: base_type_str = str(ty.type.name) gengy_metahandler = aeon_to_gengy_metahandlers[base_type_str] name, ref = ty.name, ty.refinement @@ -288,9 +315,8 @@ def refined_type_to_metahandler( def create_abstract_class(class_name: str) -> type: """Create and return a new abstract class with the given name.""" - class_name = "t_" + class_name if not class_name.startswith( - "t_") else class_name - return make_dataclass(class_name, [], bases=(ABC, )) + class_name = "t_" + class_name if not class_name.startswith("t_") else class_name + return make_dataclass(class_name, [], bases=(ABC,)) def create_literal_class( @@ -302,7 +328,7 @@ def create_literal_class( new_class = make_dataclass( "literal_" + class_name, [("value", value_type)], - bases=(base_class, ), + bases=(base_class,), ) return mk_method_core_literal(new_class) @@ -313,8 +339,7 @@ def handle_refined_type( grammar_nodes: list[type], ) -> tuple[list[type], type]: """Handle the creation of classes for refined types and update grammar nodes accordingly.""" - class_name = "t_" + class_name if not class_name.startswith( - "t_") else class_name + class_name = "t_" + class_name if not class_name.startswith("t_") else class_name new_abs_class = create_abstract_class(class_name) grammar_nodes.append(new_abs_class) @@ -368,8 +393,7 @@ def find_class_by_name( return grammar_nodes, new_abs_class - if ty is not None and isinstance(ty, RefinedType) and str( - ty.type.name) in aeon_to_gengy_metahandlers: + if ty is not None and isinstance(ty, RefinedType) and str(ty.type.name) in aeon_to_gengy_metahandlers: return handle_refined_type(class_name, ty, grammar_nodes) new_abs_class = create_abstract_class(class_name) @@ -379,7 +403,8 @@ def find_class_by_name( def is_valid_class_name(class_name: str) -> bool: return class_name not in prelude_ops and not class_name.startswith( - ("_anf_", "target"), ) + ("_anf_", "target"), + ) def get_attribute_type_name( @@ -412,10 +437,14 @@ def generate_class_components( fields = [] parent_name = "" while isinstance(class_type, AbstractionType): - attribute_name = (class_type.var_name.value if isinstance( - class_type.var_name, - Token, - ) else class_type.var_name) + attribute_name = ( + class_type.var_name.value + if isinstance( + class_type.var_name, + Token, + ) + else class_type.var_name + ) attribute_type = class_type.var_type attribute_type_name = get_attribute_type_name(attribute_type) @@ -446,7 +475,7 @@ def create_new_class(class_name: str, parent_class: type, fields=None) -> type: """Creates a new class with the given name, parent class, and fields.""" if fields is None: fields = [] - new_class = make_dataclass(class_name, fields, bases=(parent_class, )) + new_class = make_dataclass(class_name, fields, bases=(parent_class,)) new_class = mk_method_core(new_class) return new_class @@ -578,9 +607,8 @@ def create_if_class( def build_control_flow_grammar_nodes(grammar_nodes: list[type]) -> list[type]: types_names_set = { cls.__name__ - for cls in grammar_nodes if cls.__base__ is ABC and not any( - issubclass(cls, other) and cls is not other - for other in grammar_nodes) + for cls in grammar_nodes + if cls.__base__ is ABC and not any(issubclass(cls, other) and cls is not other for other in grammar_nodes) } for ty_name in types_names_set: grammar_nodes = create_if_class( diff --git a/aeon/synthesis_grammar/identification.py b/aeon/synthesis_grammar/identification.py index e961d282..6f6ea1ab 100644 --- a/aeon/synthesis_grammar/identification.py +++ b/aeon/synthesis_grammar/identification.py @@ -88,9 +88,10 @@ def get_holes_info( hs2 = get_holes_info(ctx, body, ty, targets, refined_types) return hs1 | hs2 case TypeApplication(body=body, type=argty): + _, bty = synth(ctx, body) argty = argty if refined_types else refined_to_unrefined_type(argty) - if isinstance(ty, TypePolymorphism): - ntype = substitute_vartype(ty.body, argty, ty.name) + if isinstance(bty, TypePolymorphism): + ntype = substitute_vartype(bty.body, argty, bty.name) ntype = ntype if refined_types else refined_to_unrefined_type(ntype) return get_holes_info(ctx, body, ntype, targets, refined_types) else: diff --git a/aeon/synthesis_grammar/synthesizer.py b/aeon/synthesis_grammar/synthesizer.py index 27633399..434efdd6 100644 --- a/aeon/synthesis_grammar/synthesizer.py +++ b/aeon/synthesis_grammar/synthesizer.py @@ -10,12 +10,14 @@ from typing import Callable import configparser +from geneticengine.representations.tree.initializations import MaxDepthDecider import multiprocess as mp from geneticengine.algorithms.gp.operators.combinators import ParallelStep, SequenceStep from geneticengine.algorithms.gp.operators.crossover import GenericCrossoverStep from geneticengine.algorithms.gp.operators.elitism import ElitismStep from geneticengine.algorithms.gp.operators.initializers import ( - StandardInitializer, ) + StandardInitializer, +) from geneticengine.algorithms.gp.operators.mutation import GenericMutationStep from geneticengine.algorithms.gp.operators.novelty import NoveltyStep from geneticengine.algorithms.gp.operators.selection import LexicaseSelection @@ -33,10 +35,12 @@ from geneticengine.problems import MultiObjectiveProblem, Problem, SingleObjectiveProblem from geneticengine.random.sources import RandomSource from geneticengine.representations.grammatical_evolution.dynamic_structured_ge import ( - DynamicStructuredGrammaticalEvolutionRepresentation, ) + DynamicStructuredGrammaticalEvolutionRepresentation, +) from geneticengine.representations.grammatical_evolution.ge import GrammaticalEvolutionRepresentation from geneticengine.representations.grammatical_evolution.structured_ge import ( - StructuredGrammaticalEvolutionRepresentation, ) + StructuredGrammaticalEvolutionRepresentation, +) from geneticengine.representations.tree.treebased import TreeBasedRepresentation from geneticengine.solutions import Individual from loguru import logger @@ -111,8 +115,13 @@ def __init__(self, target_fitness: float): def is_done(self, tracker: ProgressTracker): assert isinstance(tracker, MultiObjectiveProgressTracker) - comps = (tracker.get_best_individuals()[0].get_fitness( - tracker.get_problem(), ).fitness_components) + comps = ( + tracker.get_best_individuals()[0] + .get_fitness( + tracker.get_problem(), + ) + .fitness_components + ) return all(abs(c - self.target_fitness) < 0.001 for c in comps) @@ -150,28 +159,23 @@ def register( self.csv_writer = csv.writer(self.csv_file) if self.fields is None: self.fields = { - "Execution Time": - lambda t, i, _: - (time.monotonic_ns() - t.start_time) * 0.000000001, - "Fitness Aggregated": - lambda t, i, p: i.get_fitness(p).maximizing_aggregate, - "Phenotype": - lambda t, i, _: i.get_phenotype(), + "Execution Time": lambda t, i, _: (time.monotonic_ns() - t.start_time) * 0.000000001, + "Fitness Aggregated": lambda t, i, p: i.get_fitness(p).maximizing_aggregate, + "Phenotype": lambda t, i, _: i.get_phenotype(), } for comp in range(problem.number_of_objectives()): - self.fields[ - f"Fitness{comp}"] = lambda t, i, p: i.get_fitness( - p, ).fitness_components[comp] + self.fields[f"Fitness{comp}"] = lambda t, i, p: i.get_fitness( + p, + ).fitness_components[comp] if self.extra_fields is not None: for name in self.extra_fields: self.fields[name] = self.extra_fields[name] self.csv_writer.writerow([name for name in self.fields]) self.csv_file.flush() if not self.only_record_best_individuals or is_best: - self.csv_writer.writerow([ - self.fields[name](tracker, individual, problem) - for name in self.fields - ], ) + self.csv_writer.writerow( + [self.fields[name](tracker, individual, problem) for name in self.fields], + ) self.csv_file.flush() @@ -191,9 +195,12 @@ def parse_config(config_file: str, section: str) -> dict[str, Any]: def is_valid_term_literal(term_literal: Term) -> bool: - return (isinstance(term_literal, Literal) - and term_literal.type == BaseType("Int") - and isinstance(term_literal.value, int) and term_literal.value > 0) + return ( + isinstance(term_literal, Literal) + and term_literal.type == BaseType("Int") + and isinstance(term_literal.value, int) + and term_literal.value > 0 + ) def get_csv_file_path( @@ -258,19 +265,14 @@ def create_evaluator( "minimize_float", "multi_minimize_float", ] - used_decorators = [ - decorator for decorator in fitness_decorators - if decorator in metadata[fun_name] - ] + used_decorators = [decorator for decorator in fitness_decorators if decorator in metadata[fun_name]] assert used_decorators, "No fitness decorators used in metadata for function." objectives_list: list[Definition] = [ - objective for decorator in used_decorators - for objective in metadata[fun_name][decorator] + objective for decorator in used_decorators for objective in metadata[fun_name][decorator] ] programs_to_evaluate: list[Term] = [ - substitution(program, Var(objective.name), "main") - for objective in objectives_list + substitution(program, Var(objective.name), "main") for objective in objectives_list ] def evaluate_individual( @@ -289,10 +291,7 @@ def evaluate_individual( first_hole_name, individual_term, ) - results = [ - eval(substitution(p, individual_term, first_hole_name), ectx) - for p in programs_to_evaluate - ] + results = [eval(substitution(p, individual_term, first_hole_name), ectx) for p in programs_to_evaluate] result = results if len(results) > 1 else results[0] result = filter_nan_values(result) result_queue.put(result) @@ -350,10 +349,7 @@ def problem_for_fitness_function( ] if fun_name in metadata: - used_decorators = [ - decorator for decorator in fitness_decorators - if decorator in metadata[fun_name].keys() - ] + used_decorators = [decorator for decorator in fitness_decorators if decorator in metadata[fun_name].keys()] assert used_decorators, "No valid fitness decorators found." set_error_fitness(used_decorators) @@ -366,8 +362,13 @@ def problem_for_fitness_function( metadata, hole_names, ) - problem_type = (MultiObjectiveProblem if is_multiobjective( - used_decorators, ) else SingleObjectiveProblem) + problem_type = ( + MultiObjectiveProblem + if is_multiobjective( + used_decorators, + ) + else SingleObjectiveProblem + ) target_fitness: float | list[float] = ( 0 if isinstance(problem_type, SingleObjectiveProblem) else 0 ) # TODO: add support to maximize decorators @@ -418,9 +419,12 @@ def create_grammar( fun_name: str, metadata: dict[str, Any], ): - assert (len( - holes, - ) == 1), "More than one hole per function is not supported at the moment." + assert ( + len( + holes, + ) + == 1 + ), "More than one hole per function is not supported at the moment." hole_name = list(holes.keys())[0] ty, ctx = holes[hole_name] @@ -444,10 +448,13 @@ def random_search_synthesis( r = RandomSource(42) population = [rep.create_individual(r, max_depth) for _ in range(budget)] - population_with_score = [( - problem.evaluate(phenotype), - phenotype.get_core(), - ) for phenotype in population] + population_with_score = [ + ( + problem.evaluate(phenotype), + phenotype.get_core(), + ) + for phenotype in population + ] return min(population_with_score, key=lambda x: x[0])[1] @@ -472,20 +479,19 @@ def create_gp_step(problem: Problem, gp_params: dict[str, Any]): weights=[ gp_params["n_elites"], gp_params["novelty"], - gp_params["population_size"] - gp_params["n_elites"] - - gp_params["novelty"], + gp_params["population_size"] - gp_params["n_elites"] - gp_params["novelty"], ], ) def geneticengine_synthesis( - grammar: Grammar, - problem: Problem, - filename: str | None, - hole_name: str, - target_fitness: float | list[float], - gp_params: dict[str, Any] | None = None, - ui: SynthesisUI = SilentSynthesisUI(), + grammar: Grammar, + problem: Problem, + filename: str | None, + hole_name: str, + target_fitness: float | list[float], + gp_params: dict[str, Any] | None = None, + ui: SynthesisUI = SilentSynthesisUI(), ) -> Term: """Performs a synthesis procedure with GeneticEngine.""" # gp_params = gp_params or parse_config("aeon/synthesis_grammar/gpconfig.gengy", "DEFAULT") # TODO @@ -494,12 +500,12 @@ def geneticengine_synthesis( representation_name = gp_params.pop("representation") config_name = gp_params.pop("config_name") seed = gp_params["seed"] + r = NativeRandomSource(seed) assert isinstance(representation_name, str) assert isinstance(config_name, str) assert isinstance(seed, int) representation: type = representations[representation_name]( - grammar, - max_depth=gp_params["max_depth"], + grammar, decider=MaxDepthDecider(r, grammar, gp_params["max_depth"]) ) tracker: ProgressTracker @@ -517,9 +523,9 @@ def geneticengine_synthesis( LazyCSVRecorder( csv_file_path, problem, - only_record_best_individuals=gp_params[ - "only_record_best_inds"], - ), ) + only_record_best_individuals=gp_params["only_record_best_inds"], + ), + ) if isinstance(problem, SingleObjectiveProblem): tracker = SingleObjectiveProgressTracker( problem, @@ -556,12 +562,12 @@ def register( if isinstance(tracker, SingleObjectiveProgressTracker): search_budget = TargetFitness(target_fitness) elif isinstance(tracker, MultiObjectiveProgressTracker) and isinstance( - target_fitness, - list, + target_fitness, + list, ): search_budget = TargetMultiFitness(target_fitness) elif isinstance(tracker, MultiObjectiveProgressTracker) and isinstance( - target_fitness, + target_fitness, (float, int), ): search_budget = TargetMultiSameFitness(target_fitness) @@ -572,7 +578,7 @@ def register( problem=problem, budget=budget, representation=representation, - random=NativeRandomSource(seed), + random=r, tracker=tracker, population_size=gp_params["population_size"], population_initializer=StandardInitializer(), @@ -603,15 +609,15 @@ def set_error_fitness(decorators): def synthesize_single_function( - ctx: TypingContext, - ectx: EvaluationContext, - term: Term, - fun_name: str, - holes: dict[str, tuple[Type, TypingContext]], - metadata: Metadata, - filename: str | None, - synth_config: dict[str, Any] | None = None, - ui: SynthesisUI = SynthesisUI(), + ctx: TypingContext, + ectx: EvaluationContext, + term: Term, + fun_name: str, + holes: dict[str, tuple[Type, TypingContext]], + metadata: Metadata, + filename: str | None, + synth_config: dict[str, Any] | None = None, + ui: SynthesisUI = SynthesisUI(), ) -> Tuple[Term, Term]: # Step 1 Create a Single or Multi-Objective Problem instance. @@ -651,21 +657,22 @@ def synthesize_single_function( def synthesize( - ctx: TypingContext, - ectx: EvaluationContext, - term: Term, - targets: list[tuple[str, list[str]]], - metadata: Metadata, - filename: str | None = None, - synth_config: dict[str, Any] | None = None, - refined_grammar: bool = False, - ui: SynthesisUI = SynthesisUI(), + ctx: TypingContext, + ectx: EvaluationContext, + term: Term, + targets: list[tuple[str, list[str]]], + metadata: Metadata, + filename: str | None = None, + synth_config: dict[str, Any] | None = None, + refined_grammar: bool = False, + ui: SynthesisUI = SynthesisUI(), ) -> Tuple[Term, dict[str, Term]]: """Synthesizes code for multiple functions, each with multiple holes.""" program_holes = get_holes_info(ctx, term, top, targets, refined_grammar) assert len(program_holes) == len( - targets, ), "No support for function with more than one hole" + targets, + ), "No support for function with more than one hole" results = {} @@ -675,8 +682,7 @@ def synthesize( ectx, term, name, - {h: v - for h, v in program_holes.items() if h in holes_names}, + {h: v for h, v in program_holes.items() if h in holes_names}, metadata, filename, synth_config, diff --git a/tests/synth_fitness_test.py b/tests/synth_fitness_test.py index 70ae8320..4f0d45b2 100644 --- a/tests/synth_fitness_test.py +++ b/tests/synth_fitness_test.py @@ -11,7 +11,8 @@ from aeon.sugar.program import Definition from aeon.synthesis_grammar.grammar import mk_method_core_literal from aeon.synthesis_grammar.synthesizer import synthesize -from aeon.typechecking import elaborate_and_check_type_errors +from aeon.typechecking.elaboration import elaborate +from aeon.typechecking.typeinfer import check_type setup_logger() @@ -39,7 +40,8 @@ def synth (i: Int): Int { (?hole: Int) * i} prog = parse_program(code) p, ctx, ectx, _ = desugar(prog) p = ensure_anf(p) - elaborate_and_check_type_errors(ctx, p, top) + p = elaborate(ctx, p, top) + assert check_type(ctx, p, top) internal_minimize = Definition( name="__internal__minimize_int_synth_0", args=[], @@ -72,7 +74,8 @@ def synth (i:Int) : Int {(?hole: Int) * i} prog = parse_program(code) p, ctx, ectx, metadata = desugar(prog) p = ensure_anf(p) - elaborate_and_check_type_errors(ctx, p, top) + p = elaborate(ctx, p, top) + assert check_type(ctx, p, top) term, _ = synthesize(ctx, ectx, p, [("synth", ["hole"])], metadata) assert isinstance(term, Term)