diff --git a/aeon/__main__.py b/aeon/__main__.py index a9168476..377beba0 100644 --- a/aeon/__main__.py +++ b/aeon/__main__.py @@ -6,7 +6,6 @@ from aeon.backend.evaluator import EvaluationContext from aeon.backend.evaluator import eval from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf from aeon.frontend.parser import parse_term from aeon.logger.logger import export_log from aeon.logger.logger import setup_logger @@ -37,27 +36,16 @@ def parse_arguments(): "--log", nargs="+", default="", - help= - """set log level: \nTRACE \nDEBUG \nINFO \nWARNINGS \nTYPECHECKER \nSYNTH_TYPE \nCONSTRAINT \nSYNTHESIZER + help="""set log level: \nTRACE \nDEBUG \nINFO \nWARNINGS \nTYPECHECKER \nSYNTH_TYPE \nCONSTRAINT \nSYNTHESIZER \nERROR \nCRITICAL""", ) - parser.add_argument("-f", - "--logfile", - action="store_true", - help="export log file") + parser.add_argument("-f", "--logfile", action="store_true", help="export log file") - parser.add_argument("-csv", - "--csv-synth", - action="store_true", - help="export synthesis csv file") + parser.add_argument("-csv", "--csv-synth", action="store_true", help="export synthesis csv file") - parser.add_argument("-gp", - "--gp-config", - help="path to the GP configuration file") + parser.add_argument("-gp", "--gp-config", help="path to the GP configuration file") - parser.add_argument("-csec", - "--config-section", - help="section name in the GP configuration file") + parser.add_argument("-csec", "--config-section", help="section name in the GP configuration file") parser.add_argument( "-d", @@ -73,7 +61,7 @@ def read_file(filename: str) -> str: return file.read() -def log_type_errors(errors: list[Exception | str]): +def log_type_errors(errors: list[Exception]): logger.log("TYPECHECKER", "-------------------------------") logger.log("TYPECHECKER", "+ Type Checking Error +") for error in errors: @@ -104,32 +92,30 @@ def log_type_errors(errors: list[Exception | str]): ) = desugar(prog) logger.info(core_ast) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) if type_errors: log_type_errors(type_errors) sys.exit(1) - incomplete_functions: list[tuple[ - str, - list[str]]] = incomplete_functions_and_holes(typing_ctx, core_ast_anf) + incomplete_functions: list[tuple[str, list[str]]] = incomplete_functions_and_holes(typing_ctx, core_ast) if incomplete_functions: filename = args.filename if args.csv_synth else None - synth_config = (parse_config(args.gp_config, args.config_section) - if args.gp_config and args.config_section else None) + synth_config = ( + parse_config(args.gp_config, args.config_section) if args.gp_config and args.config_section else None + ) synthesis_result = synthesize( typing_ctx, evaluation_ctx, - core_ast_anf, + core_ast, incomplete_functions, filename, synth_config, ) print(f"Best solution:{synthesis_result}") # print() - # pretty_print_term(ensure_anf(synthesis_result, 200)) + # pretty_print_term(synthesis_result, 200) sys.exit(1) eval(core_ast, evaluation_ctx) diff --git a/aeon/core/instantiation.py b/aeon/core/instantiation.py index e32314ff..df279ec1 100644 --- a/aeon/core/instantiation.py +++ b/aeon/core/instantiation.py @@ -3,7 +3,7 @@ from aeon.core.liquid import LiquidVar from aeon.core.liquid_ops import mk_liquid_and from aeon.core.substitutions import substitution_in_liquid -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import RefinedType from aeon.core.types import Type @@ -42,6 +42,15 @@ def rec(x): target.kind, type_substitution(target.body, alpha, beta), ) + elif isinstance(t, ExistentialType): + new_name = t.var_name + new_type = t.type + while new_name == alpha: + old_name = new_name + new_name = new_name + "_fresh" + new_type = type_substitution(new_type, old_name, TypeVar(new_name)) + + return ExistentialType(new_name, t.var_type, new_type) else: assert False diff --git a/aeon/core/substitutions.py b/aeon/core/substitutions.py index e259dc36..0142c818 100644 --- a/aeon/core/substitutions.py +++ b/aeon/core/substitutions.py @@ -18,7 +18,7 @@ from aeon.core.terms import Rec from aeon.core.terms import Term from aeon.core.terms import Var -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import Bottom from aeon.core.types import RefinedType @@ -89,12 +89,10 @@ def rec(x: Term): assert False -def substitution_in_liquid(t: LiquidTerm, rep: LiquidTerm, - name: str) -> LiquidTerm: +def substitution_in_liquid(t: LiquidTerm, rep: LiquidTerm, name: str) -> LiquidTerm: """substitutes name in the term t with the new replacement term rep.""" assert isinstance(rep, LiquidTerm) - if isinstance(t, (LiquidLiteralInt, LiquidLiteralBool, LiquidLiteralString, - LiquidLiteralFloat)): + if isinstance(t, (LiquidLiteralInt, LiquidLiteralBool, LiquidLiteralString, LiquidLiteralFloat)): return t elif isinstance(t, LiquidVar): if t.name == name: @@ -102,16 +100,14 @@ def substitution_in_liquid(t: LiquidTerm, rep: LiquidTerm, else: return t elif isinstance(t, LiquidApp): - return LiquidApp( - t.fun, [substitution_in_liquid(a, rep, name) for a in t.args]) + return LiquidApp(t.fun, [substitution_in_liquid(a, rep, name) for a in t.args]) elif isinstance(t, LiquidHole): if t.name == name: return rep else: return LiquidHole( t.name, - [(substitution_in_liquid(a, rep, name), t) - for (a, t) in t.argtypes], + [(substitution_in_liquid(a, rep, name), t) for (a, t) in t.argtypes], ) else: print(t, type(t)) @@ -166,6 +162,14 @@ def rec(t: Type) -> Type: t.type, substitution_in_liquid(t.refinement, replacement, name), ) + elif isinstance(t, ExistentialType): + if name == t.var_name: + new_name = t.var_name + "_" # TODO: Fresh name + nt = ExistentialType(new_name, t.var_type, substitution_in_type(t.type, Var(new_name), t.var_name)) + else: + nt = t + return ExistentialType(nt.var_name, nt.var_type, substitution_in_type(nt.type, rep, name)) + assert False @@ -223,16 +227,15 @@ def liquefy_app(app: Application) -> LiquidApp | None: elif isinstance(app.fun, Application): liquid_pseudo_fun = liquefy_app(app.fun) if liquid_pseudo_fun: - return LiquidApp(liquid_pseudo_fun.fun, - liquid_pseudo_fun.args + [arg]) + return LiquidApp(liquid_pseudo_fun.fun, liquid_pseudo_fun.args + [arg]) return None elif isinstance(app.fun, Let): return liquefy_app( Application( - substitution(app.fun.body, app.fun.var_value, - app.fun.var_name), + substitution(app.fun.body, app.fun.var_value, app.fun.var_name), app.arg, - ), ) + ), + ) assert False diff --git a/aeon/core/types.py b/aeon/core/types.py index 19ff2a39..501f60f9 100644 --- a/aeon/core/types.py +++ b/aeon/core/types.py @@ -129,16 +129,12 @@ def __hash__(self) -> int: return hash(self.var_name) + hash(self.var_type) + hash(self.type) +@dataclass class RefinedType(Type): name: str - type: BaseType | TypeVar + type: BaseType | TypeVar | Bottom | Top refinement: LiquidTerm - def __init__(self, name: str, ty: BaseType | TypeVar, refinement: LiquidTerm): - self.name = name - self.type = ty - self.refinement = refinement - def __repr__(self): return f"{{ {self.name}:{self.type} | {self.refinement} }}" @@ -154,6 +150,16 @@ def __hash__(self) -> int: return hash(self.name) + hash(self.type) + hash(self.refinement) +@dataclass +class ExistentialType(Type): + var_name: str + var_type: Type + type: Type + + def __str__(self) -> str: + return f"∃{self.var_name}:{self.var_type}, {self.type}" + + @dataclass class TypePolymorphism(Type): name: str # alpha @@ -163,7 +169,7 @@ class TypePolymorphism(Type): def extract_parts( t: Type, -) -> tuple[str, BaseType | TypeVar, LiquidTerm]: +) -> tuple[str, BaseType | TypeVar | Top | Bottom, LiquidTerm]: assert isinstance(t, BaseType) or isinstance(t, RefinedType) or isinstance(t, TypeVar) if isinstance(t, RefinedType): return (t.name, t.type, t.refinement) diff --git a/aeon/frontend/anf_converter.py b/aeon/frontend/anf_converter.py deleted file mode 100644 index ecab854f..00000000 --- a/aeon/frontend/anf_converter.py +++ /dev/null @@ -1,109 +0,0 @@ -from aeon.core.terms import ( - Abstraction, - Annotation, - Application, - If, - Let, - Literal, - Rec, - Term, - TypeAbstraction, - TypeApplication, - Var, -) - - -class ANFConverter: - """Recursive visitor that applies ANF transformation.""" - - def __init__(self, starting_counter: int = 0): - self.counter = starting_counter - - def fresh(self): - self.counter += 1 - return f"_anf_{self.counter}" - - def convert(self, t: Term): - """Converts term to ANF form.""" - - match t: - case If(cond=cond, then=then, otherwise=otherwise): - cond = self.convert(cond) - then = self.convert(then) - otherwise = self.convert(otherwise) - if isinstance(cond, Var) or isinstance(cond, Literal): - return If(cond, then, otherwise) - else: - v = self.fresh() - return self.convert(Let(v, cond, If(Var(v), then, otherwise))) - case Application(fun=fun, arg=arg): - fun = self.convert(fun) - - if isinstance(fun, Var) or isinstance(fun, Literal): - pass - elif isinstance(fun, Let): - return Let( - fun.var_name, - fun.var_value, - self.convert(Application(fun.body, arg)), - ) - else: - v = self.fresh() - return self.convert(Let(v, fun, Application(Var(v), arg))) - - arg = self.convert(arg) - if isinstance(arg, Var) or isinstance(arg, Literal): - return Application(fun, arg) - else: - v = self.fresh() - return self.convert(Let(v, arg, Application(fun, Var(v)))) - - case Let(var_name=name, var_value=value, body=body): - value = self.convert(value) - body = self.convert(body) - match value: - case Let(var_name=vname, var_value=vvalue, body=vbody): - assert name != vname - vvalue = self.convert(vvalue) - vbody = self.convert(vbody) - return Let( - vname, - vvalue, - self.convert(Let(name, vbody, body)), - ) - case Rec(var_name=vname, var_type=vtype, var_value=vvalue, body=vbody): - assert name != vname - vvalue = self.convert(vvalue) - vbody = self.convert(vbody) - return Rec( - vname, - vtype, - vvalue, - self.convert(Let(name, vbody, body)), - ) - case _: - return Let(name, value, body) - case Rec(var_name=name, var_type=type, var_value=value, body=body): - value = self.convert(value) - body = self.convert(body) - return Rec(name, type, value, body) - case Abstraction(var_name=name, body=body): - body = self.convert(body) - return Abstraction(var_name=name, body=body) - case Annotation(expr=expr, type=ty): - expr = self.convert(expr) - return Annotation(expr=expr, type=ty) - case TypeAbstraction(name=name, kind=kind, body=body): - body = self.convert(body) - return TypeAbstraction(name, kind, body) - case TypeApplication(body=body, type=type): - body = self.convert(body) - return TypeApplication(body, type) - case _: - return t - - -def ensure_anf(t: Term, starting_counter: int = 0) -> Term: - """Converts a term to ANF form.""" - - return ANFConverter(starting_counter=starting_counter).convert(t) diff --git a/aeon/synthesis_grammar/synthesizer.py b/aeon/synthesis_grammar/synthesizer.py index 1f5315ca..abebe05f 100644 --- a/aeon/synthesis_grammar/synthesizer.py +++ b/aeon/synthesis_grammar/synthesizer.py @@ -15,7 +15,7 @@ from aeon.core.types import BaseType, Top from aeon.core.types import Type from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.synthesis_grammar.grammar import ( gen_grammar_nodes, get_grammar_node, @@ -158,7 +158,6 @@ def evaluate_individual(individual: classType, result_queue: mp.Queue) -> Any: try: first_hole_name = holes[0] individual_term = individual.get_core() # type: ignore - individual_term = ensure_anf(individual_term, 10000000) new_program = substitution(program_template, individual_term, first_hole_name) check_type_errors(ctx, new_program, Top()) result = eval(new_program, ectx) diff --git a/aeon/typechecking/typeinfer.py b/aeon/typechecking/typeinfer.py index 91accf01..7775586e 100644 --- a/aeon/typechecking/typeinfer.py +++ b/aeon/typechecking/typeinfer.py @@ -1,9 +1,10 @@ from __future__ import annotations +from typing import Tuple from loguru import logger from aeon.core.instantiation import type_substitution -from aeon.core.liquid import LiquidApp, LiquidHole +from aeon.core.liquid import LiquidApp, LiquidHole, LiquidTerm from aeon.core.liquid import LiquidLiteralBool from aeon.core.liquid import LiquidLiteralFloat from aeon.core.liquid import LiquidLiteralInt @@ -24,7 +25,7 @@ from aeon.core.terms import TypeAbstraction from aeon.core.terms import TypeApplication from aeon.core.terms import Var -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, Bottom, ExistentialType, Top from aeon.core.types import BaseKind from aeon.core.types import BaseType from aeon.core.types import RefinedType @@ -50,7 +51,6 @@ from aeon.typechecking.entailment import entailment from aeon.verification.helpers import simplify_constraint from aeon.verification.horn import fresh -from aeon.verification.sub import ensure_refined from aeon.verification.sub import implication_constraint from aeon.verification.sub import sub from aeon.verification.vcs import Conjunction @@ -76,6 +76,38 @@ def __str__(self): return f"Constraint violated when checking if {self.t} : {self.ty}: \n {self.ks}" +def eq_ref(var_name: str, type_name: str) -> LiquidTerm: + return LiquidApp( + "==", + [ + LiquidVar(var_name), + LiquidVar(type_name), + ], + ) + + +def and_ref(cond1: LiquidTerm, cond2: LiquidTerm) -> LiquidTerm: + return LiquidApp("&&", [cond1, cond2]) + + +def refine_type(ctx: TypingContext, ty: Type, vname: str): + """The refine function is the selfication with support for existentials""" + match ty: + case BaseType(name=_) | Top() | Bottom(): + name = ctx.fresh_var() + return RefinedType(name, ty, eq_ref(name, vname)) + case RefinedType(name=name, type=ty, refinement=cond): + name = ctx.fresh_var() + assert name != vname + return RefinedType(name, ty, and_ref(cond, eq_ref(name, vname))) + case ExistentialType(var_name=var_name, var_type=var_type, type=ity): + return ExistentialType(var_name, var_type, refine_type(ctx, ity, vname)) + case AbstractionType(var_name=var_name, var_type=var_type, type=type): + return ty + case _: + assert False + + def argument_is_typevar(ty: Type): return ( isinstance(ty, TypeVar) @@ -87,6 +119,15 @@ def argument_is_typevar(ty: Type): ) +def extract_abstraction_type(ty: Type) -> Tuple[AbstractionType, list[Tuple[str, Type]]]: + binders = [] + while isinstance(ty, ExistentialType): + binders.append((ty.var_name, ty.var_type)) + ty = ty.type + assert isinstance(ty, AbstractionType) + return ty, binders + + def prim_litbool(t: bool) -> RefinedType: if t: return RefinedType("v", t_bool, LiquidVar("v")) @@ -220,52 +261,51 @@ def synth(ctx: TypingContext, t: Term) -> tuple[Constraint, Type]: if t.name in ops: return (ctrue, prim_op(t.name)) ty = ctx.type_of(t.name) - if isinstance(ty, BaseType) or isinstance(ty, RefinedType): - ty = ensure_refined(ty) - # assert ty.name != t.name - if ty.name == t.name: - ty = renamed_refined_type(ty) - # Self - ty = RefinedType( - ty.name, - ty.type, - LiquidApp( - "&&", - [ - ty.refinement, - LiquidApp( - "==", - [ - LiquidVar(ty.name), - LiquidVar(t.name), - ], - ), - ], - ), - ) - if not ty: - raise CouldNotGenerateConstraintException( - f"Variable {t.name} not in context", - ) - return (ctrue, ty) + if ty is not None: + return (ctrue, refine_type(ctx, ty, t.name)) + raise CouldNotGenerateConstraintException( + f"Variable {t.name} not in context", + ) elif isinstance(t, Application): - (c, ty) = synth(ctx, t.fun) - if isinstance(ty, AbstractionType): - # This is the solution to handle polymorphic "==" in refinements. - if argument_is_typevar(ty.var_type): - (_, b, _) = extract_parts(ty.var_type) + (c1, ty1) = synth(ctx, t.fun) + (c2, ty2) = synth(ctx, t.arg) + + abstraction_type, binders = extract_abstraction_type(ty1) + + if isinstance(abstraction_type, AbstractionType): + argument_type = abstraction_type.var_type + return_type = abstraction_type.type + + if argument_is_typevar(argument_type): + (_, b, _) = extract_parts(argument_type) assert isinstance(b, TypeVar) - (cp, at) = synth(ctx, t.arg) - if isinstance(at, RefinedType): - at = at.type # This is a hack before inference - return_type = substitute_vartype(ty.type, at, b.name) + argument_type = substitute_vartype(argument_type, ty2, b.name) + return_type = substitute_vartype(return_type, ty2, b.name) + c3 = ctrue else: - cp = check(ctx, t.arg, ty.var_type) - return_type = ty.type - t_subs = substitution_in_type(return_type, t.arg, ty.var_name) - c0 = Conjunction(c, cp) - # vs: list[str] = list(variables_free_in(c0)) - return (c0, t_subs) + c3 = sub(ty2, argument_type) + new_name = ctx.fresh_var() + return_type = type_substitution(return_type, abstraction_type.var_name, new_name) + nt = ExistentialType(var_name=new_name, var_type=argument_type, type=return_type) + for aname, aty in binders[::-1]: + nt = ExistentialType(aname, aty, nt) + return Conjunction(Conjunction(c1, c2), c3), nt + + # This is the solution to handle polymorphic "==" in refinements. + # if argument_is_typevar(ty.var_type): + # (_, b, _) = extract_parts(ty.var_type) + # assert isinstance(b, TypeVar) + # (cp, at) = synth(ctx, t.arg) + # if isinstance(at, RefinedType): + # at = at.type # This is a hack before inference + # return_type = substitute_vartype(ty.type, at, b.name) + # else: + # cp = check(ctx, t.arg, ty.var_type) + # return_type = ty.type + # t_subs = substitution_in_type(return_type, t.arg, ty.var_name) + # c0 = Conjunction(c, cp) + # # vs: list[str] = list(variables_free_in(c0)) + # return (c0, t_subs) else: raise CouldNotGenerateConstraintException( f"Application {t} is not a function.", @@ -342,7 +382,7 @@ def check_(ctx: TypingContext, t: Term, ty: Type) -> Constraint: # patterm matching term -@wrap_checks # DEMO1 +# @wrap_checks # DEMO1 def check(ctx: TypingContext, t: Term, ty: Type) -> Constraint: if isinstance(t, Abstraction) and isinstance( ty, @@ -398,35 +438,45 @@ def check(ctx: TypingContext, t: Term, ty: Type) -> Constraint: return Conjunction(c, cp) + + def check_type(ctx: TypingContext, t: Term, ty: Type) -> bool: """Returns whether expression t has type ty in context ctx.""" try: constraint = check(ctx, t, ty) return entailment(ctx, constraint) - except CouldNotGenerateConstraintException: + except CouldNotGenerateConstraintException as e: + logger.info(f"Could not generate constraint: f{e}") return False - except FailedConstraintException: + except FailedConstraintException as e: + logger.info(f"Could not prove constraint: f{e}") return False +class CouldNotProveTypingRelation(Exception): + def __init__(self, context: TypingContext, term: Term, type: Type): + self.context = context + self.term = term + self.type = type + + def __str__(self): + return f"Could not prove typing relation (Context: {self.context}) (Term: {self.term}) (Type: {self.type})." + + def check_type_errors( ctx: TypingContext, t: Term, ty: Type, -) -> list[Exception | str]: +) -> list[Exception]: """Checks whether t as type ty in ctx, but returns a list of errors.""" try: constraint = check(ctx, t, ty) + print(f"Constraint: {constraint}") r = entailment(ctx, constraint) if r: return [] else: - return [ - "Could not prove typing relation.", - f"Context: {ctx}", - f"Term: {t}", - f"Type: {ty}", - ] + return [CouldNotProveTypingRelation(ctx, t, ty)] except CouldNotGenerateConstraintException as e: return [e] except FailedConstraintException as e: @@ -434,6 +484,7 @@ def check_type_errors( def is_subtype(ctx: TypingContext, subt: Type, supt: Type): + assert not isinstance(supt, ExistentialType) if args_size_of_type(subt) != args_size_of_type(supt): return False if subt == supt: diff --git a/aeon/typechecking/well_formed.py b/aeon/typechecking/well_formed.py index da4f210a..ce3fd176 100644 --- a/aeon/typechecking/well_formed.py +++ b/aeon/typechecking/well_formed.py @@ -3,7 +3,7 @@ from aeon.core.liquid import LiquidLiteralBool from aeon.core.liquid import LiquidVar from aeon.core.substitutions import substitution_in_liquid -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import extract_parts from aeon.core.types import Kind @@ -45,7 +45,8 @@ def wellformed(ctx: TypingContext, t: Type, k: Kind = StarKind()) -> bool: wf_all = ( isinstance(t, TypePolymorphism) and k == StarKind() and wellformed(ctx.with_typevar(t.name, t.kind), t.body) ) - return wf_norefinement or wf_var or wf_base or wf_fun or wf_all + wf_existential = isinstance(t, ExistentialType) and wellformed(ctx.with_var(t.var_name, t.var_type), t.type, k) + return wf_norefinement or wf_var or wf_base or wf_fun or wf_all or wf_existential def inhabited(ctx: TypingContext, ty: Type) -> bool: diff --git a/aeon/verification/sub.py b/aeon/verification/sub.py index 796ab8b9..b4a3673b 100644 --- a/aeon/verification/sub.py +++ b/aeon/verification/sub.py @@ -7,7 +7,7 @@ from aeon.core.substitutions import substitution_in_liquid from aeon.core.substitutions import substitution_in_type from aeon.core.terms import Var -from aeon.core.types import AbstractionType, TypeVar +from aeon.core.types import AbstractionType, ExistentialType, TypeVar from aeon.core.types import BaseType from aeon.core.types import Bottom from aeon.core.types import RefinedType @@ -25,7 +25,7 @@ def ensure_refined(t: Type) -> RefinedType: if isinstance(t, RefinedType): return t - elif isinstance(t, BaseType): + elif isinstance(t, BaseType) or isinstance(t, Top) or isinstance(t, Bottom) or isinstance(t, TypeVar): return RefinedType(f"singleton_{t}", t, LiquidLiteralBool(True)) assert False @@ -51,10 +51,22 @@ def implication_constraint(name: str, t: Type, c: Constraint) -> Constraint: return c elif isinstance(t, Top): return c + elif isinstance(t, ExistentialType): + # TODO: Existential + pass logger.debug(f"{name} : {t} => {c} ({type(t)})") assert False +def ensure_safe_type(t: Type) -> BaseType: + if isinstance(t, Top) or isinstance(t, Bottom): + return BaseType("Bool") + elif isinstance(t, BaseType): + return t + print(f"Unsafe: {t}") + assert False + + def sub(t1: Type, t2: Type) -> Constraint: if isinstance(t2, Top) or isinstance(t1, Bottom): return ctrue @@ -62,7 +74,12 @@ def sub(t1: Type, t2: Type) -> Constraint: t1 = ensure_refined(t1) if isinstance(t2, BaseType): t2 = ensure_refined(t2) - if isinstance(t1, RefinedType) and isinstance(t2, RefinedType): + if isinstance(t2, ExistentialType): + assert False + if isinstance(t1, ExistentialType): + c = sub(t1.type, t2) + return implication_constraint(t1.var_name, t1.var_type, c) + elif isinstance(t1, RefinedType) and isinstance(t2, RefinedType): if isinstance(t1.type, Bottom) or isinstance(t2.type, Top): return ctrue elif t1.type == t2.type: diff --git a/aeon/verification/vcs.py b/aeon/verification/vcs.py index 7e4867e1..1da19dc2 100644 --- a/aeon/verification/vcs.py +++ b/aeon/verification/vcs.py @@ -9,7 +9,7 @@ from aeon.core.liquid import LiquidLiteralString from aeon.core.liquid import LiquidTerm from aeon.core.liquid import LiquidVar -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType @@ -47,7 +47,7 @@ def __repr__(self): @dataclass class Implication(Constraint): name: str - base: BaseType + base: BaseType | ExistentialType pred: LiquidTerm seq: Constraint diff --git a/tests/end_to_end_test.py b/tests/end_to_end_test.py index b472f7f3..6bd0ba8c 100644 --- a/tests/end_to_end_test.py +++ b/tests/end_to_end_test.py @@ -3,7 +3,7 @@ from aeon.backend.evaluator import EvaluationContext from aeon.backend.evaluator import eval from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type from aeon.prelude.prelude import evaluation_vars @@ -17,12 +17,11 @@ def check_compile(source, ty, res): p = parse_term(source) - p = ensure_anf(p) assert check_type(ctx, p, ty) assert eval(p, ectx) == res -def test_anf(): +def test_multiple_applications(): source = r"""let f : (x:Int) -> (y:Int) -> Int = (\x -> (\y -> x)) in let r = f (f 1 2) (f 2 3) in r""" diff --git a/tests/frontend_test.py b/tests/frontend_test.py index 80e6d4ac..4819598d 100644 --- a/tests/frontend_test.py +++ b/tests/frontend_test.py @@ -19,14 +19,13 @@ from aeon.core.types import t_int from aeon.core.types import TypePolymorphism from aeon.core.types import TypeVar -from aeon.frontend.anf_converter import ensure_anf + from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type from aeon.utils.ast_helpers import false from aeon.utils.ast_helpers import i0 from aeon.utils.ast_helpers import i1 from aeon.utils.ast_helpers import i2 -from aeon.utils.ast_helpers import is_anf from aeon.utils.ast_helpers import mk_binop from aeon.utils.ast_helpers import true @@ -126,12 +125,6 @@ def test_operators(): assert parse_term("1 % 1") == mk_binop(lambda: "t", "%", i1, i1) -def test_precedence(): - t1 = parse_term("1 + 2 * 0") - at1 = ensure_anf(t1) - assert is_anf(at1) - - def test_let(): assert parse_term("let x = 1 in x") == Let("x", i1, Var("x")) diff --git a/tests/hole_test.py b/tests/hole_test.py index e5179426..9c089503 100644 --- a/tests/hole_test.py +++ b/tests/hole_test.py @@ -1,5 +1,5 @@ from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.sugar.desugar import desugar, apply_decorators_in_program from aeon.sugar.parser import parse_program from aeon.synthesis_grammar.identification import incomplete_functions_and_holes @@ -10,9 +10,8 @@ def extract_target_functions(source): prog = parse_program(source) prog = apply_decorators_in_program(prog) core, ctx, _ = desugar(prog) - core_anf = ensure_anf(core) - check_type_errors(ctx, core_anf, top) - return incomplete_functions_and_holes(ctx, core_anf) + check_type_errors(ctx, core, top) + return incomplete_functions_and_holes(ctx, core) def test_hole_identification(): diff --git a/tests/infer_test.py b/tests/infer_test.py index 115a3f05..43ef0f5f 100644 --- a/tests/infer_test.py +++ b/tests/infer_test.py @@ -1,7 +1,7 @@ from __future__ import annotations from aeon.core.types import t_int -from aeon.frontend.anf_converter import ensure_anf + from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type from aeon.typechecking.context import EmptyContext @@ -14,7 +14,7 @@ def tt(e: str, t: str, vars: dict[str, str] = {}): ctx = build_context({k: parse_type(v) for (k, v) in vars.items()}) - term = ensure_anf(parse_term(e)) + term = parse_term(e) return check_type(ctx, term, parse_type(t)) diff --git a/tests/optimization_decorators_test.py b/tests/optimization_decorators_test.py index 3d0ef0f9..04036487 100644 --- a/tests/optimization_decorators_test.py +++ b/tests/optimization_decorators_test.py @@ -1,6 +1,6 @@ from aeon.core.terms import Term from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.sugar.desugar import desugar from aeon.sugar.parser import parse_program from aeon.sugar.program import Program @@ -11,9 +11,8 @@ def extract_core(source: str) -> Term: prog = parse_program(source) core, ctx, _ = desugar(prog) - core_anf = ensure_anf(core) - check_type_errors(ctx, core_anf, top) - return core_anf + check_type_errors(ctx, core, top) + return core def test_hole_minimize_int(): @@ -43,6 +42,7 @@ def main(args:Int) : Unit { evaluation_ctx, ) = desugar(prog) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) + for te in type_errors: + print(te) assert len(type_errors) == 0 diff --git a/tests/smt_test.py b/tests/smt_test.py index 6fd671c4..03b3a714 100644 --- a/tests/smt_test.py +++ b/tests/smt_test.py @@ -8,7 +8,7 @@ from aeon.core.types import BaseType from aeon.core.types import t_int from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.sugar.desugar import desugar from aeon.sugar.parser import parse_program from aeon.sugar.program import Program @@ -21,17 +21,15 @@ def extract_core(source: str) -> Term: prog = parse_program(source) core, ctx, _ = desugar(prog) - core_anf = ensure_anf(core) - check_type_errors(ctx, core_anf, top) - return core_anf + check_type_errors(ctx, core, top) + return core example = Implication( "x", t_int, LiquidApp("==", [LiquidVar("x"), LiquidLiteralInt(3)]), - LiquidConstraint(LiquidApp( - "==", [LiquidVar("x"), LiquidLiteralInt(3)])), + LiquidConstraint(LiquidApp("==", [LiquidVar("x"), LiquidLiteralInt(3)])), ) @@ -47,8 +45,7 @@ def test_smt_example3(): "y", BaseType("a"), LiquidApp("==", [LiquidVar("x"), LiquidVar("y")]), - LiquidConstraint(LiquidApp( - "==", [LiquidVar("x"), LiquidVar("y")])), + LiquidConstraint(LiquidApp("==", [LiquidVar("x"), LiquidVar("y")])), ), ) @@ -82,8 +79,7 @@ def main (x:Int) : Unit { evaluation_ctx, ) = desugar(prog) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) assert len(type_errors) == 0 @@ -106,6 +102,5 @@ def main (x:Int) : Unit { evaluation_ctx, ) = desugar(prog) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) assert len(type_errors) == 0 diff --git a/tests/substitutions_test.py b/tests/substitutions_test.py index 96ed4161..08ccbb72 100644 --- a/tests/substitutions_test.py +++ b/tests/substitutions_test.py @@ -2,6 +2,7 @@ from aeon.core.substitutions import substitution from aeon.core.substitutions import substitution_in_type +from aeon.core.types import ExistentialType from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type @@ -66,3 +67,11 @@ def test_substitution_autorename_ref(): assert substitution_in_type(ty, parse_term("y"), "z") == parse_type( r"(y1:Int) -> {x : Int | y1 > y}", ) + + +def test_substitution_type_exist(): + ty = ExistentialType(var_name="z", var_type=parse_type("Int"), type=parse_type(r"(y:Int) -> {x : Int | x > z}")) + subs = substitution_in_type(ty, parse_term("3"), "z") + + assert subs.var_name != "z" # alpha renaming + assert "3" not in str(subs) diff --git a/tests/synth_fitness_test.py b/tests/synth_fitness_test.py index 6c136d50..136ed3ac 100644 --- a/tests/synth_fitness_test.py +++ b/tests/synth_fitness_test.py @@ -4,7 +4,7 @@ from aeon.core.terms import Term from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.logger.logger import setup_logger from aeon.sugar.desugar import desugar from aeon.sugar.parser import parse_program @@ -37,7 +37,6 @@ def __internal__fitness_function_synth : Int = year - synth(7); """ prog = parse_program(code) p, ctx, ectx = desugar(prog) - p = ensure_anf(p) check_type_errors(ctx, p, top) term = synthesize(ctx, ectx, p, [("synth", ["hole"])]) @@ -51,7 +50,6 @@ def synth (i:Int) : Int {(?hole: Int) * i} """ prog = parse_program(code) p, ctx, ectx = desugar(prog) - p = ensure_anf(p) check_type_errors(ctx, p, top) term = synthesize(ctx, ectx, p, [("synth", ["hole"])]) diff --git a/tests/wellformed_test.py b/tests/wellformed_test.py index a3fafcb3..54842391 100644 --- a/tests/wellformed_test.py +++ b/tests/wellformed_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from aeon.core.types import BaseKind +from aeon.core.types import BaseKind, ExistentialType from aeon.core.types import StarKind from aeon.core.types import t_bool from aeon.core.types import t_int @@ -61,3 +61,10 @@ def test_poly(): TypePolymorphism("a", StarKind(), TypeVar("a")), BaseKind(), ) + + +def test_wf_existential(): + assert wellformed( + empty, + TypePolymorphism("a", BaseKind(), ExistentialType(var_name="x", var_type=parse_type("Int"), type=TypeVar("a"))), + )