Skip to content

Commit

Permalink
Integrated new synthesis with polymorphism
Browse files Browse the repository at this point in the history
  • Loading branch information
alcides committed Nov 27, 2024
1 parent e72c42f commit a071516
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 188 deletions.
85 changes: 42 additions & 43 deletions aeon/core/substitutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def substitution_in_liquid(
"""Substitutes name in the term t with the new replacement term rep."""
assert isinstance(rep, LiquidTerm)
if isinstance(
t,
t,
(
LiquidLiteralInt,
LiquidLiteralBool,
Expand All @@ -131,10 +131,7 @@ def substitution_in_liquid(
else:
return LiquidHornApplication(
t.name,
[
(substitution_in_liquid(a, rep, name), t)
for (a, t) in t.argtypes
],
[(substitution_in_liquid(a, rep, name), t) for (a, t) in t.argtypes],
)
else:
assert False
Expand Down Expand Up @@ -198,44 +195,45 @@ def substitution(t: Term, rep: Term, name: str) -> Term:
def rec(x: Term):
return substitution(x, rep, name)

if isinstance(t, Literal):
return t
elif isinstance(t, Var):
if t.name == name:
return rep
return t
elif isinstance(t, Hole):
if t.name == name:
return rep
return t
elif isinstance(t, Application):
return Application(fun=rec(t.fun), arg=rec(t.arg))
elif isinstance(t, Abstraction):
if t.var_name == name:
match t:
case Literal(_):
return t
else:
return Abstraction(t.var_name, rec(t.body))
elif isinstance(t, Let):
if t.var_name == name:
n_value = t.var_value
n_body = t.body
else:
n_value = rec(t.var_value)
n_body = rec(t.body)
return Let(t.var_name, n_value, n_body)
elif isinstance(t, Rec):
if t.var_name == name:
n_value = t.var_value
n_body = t.body
else:
n_value = rec(t.var_value)
n_body = rec(t.body)
return Rec(t.var_name, t.var_type, n_value, n_body)
elif isinstance(t, Annotation):
return Annotation(rec(t.expr), t.type)
elif isinstance(t, If):
return If(rec(t.cond), rec(t.then), rec(t.otherwise))
assert False
case Var(tname) | Hole(tname):
if tname == name:
return rep
else:
return t
case Application(fun, arg):
return Application(fun=rec(fun), arg=rec(arg))
case Abstraction(vname, body):
if vname == name:
return t
else:
return Abstraction(vname, rec(t.body))
case Let(tname, val, body):
if tname == name:
n_value = val
n_body = body
else:
n_value = rec(val)
n_body = rec(body)
return Let(tname, n_value, n_body)
case Rec(tname, ty, val, body):
if tname == name:
n_value = val
n_body = body
else:
n_value = rec(val)
n_body = rec(body)
return Rec(tname, ty, n_value, n_body)
case Annotation(body, ty):
return Annotation(rec(body), ty)
case If(cond, then, otherwise):
return If(rec(cond), rec(then), rec(otherwise))
case TypeApplication(expr, ty):
return TypeApplication(rec(expr), ty)
case _:
assert False


def liquefy_app(app: Application) -> LiquidApp | None:
Expand Down Expand Up @@ -264,7 +262,8 @@ def liquefy_app(app: Application) -> LiquidApp | None:
fun.var_name,
),
app.arg,
), )
),
)
else:
raise Exception(f"{app} is not a valid predicate.")

Expand Down
136 changes: 82 additions & 54 deletions aeon/synthesis_grammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from aeon.core.terms import If
from aeon.core.terms import Literal
from aeon.core.terms import Var
from aeon.core.types import AbstractionType, Type
from aeon.core.types import AbstractionType, Type, TypePolymorphism, TypeVar
from aeon.core.types import BaseType
from aeon.core.types import Bottom
from aeon.core.types import RefinedType
Expand Down Expand Up @@ -58,7 +58,7 @@ def extract_class_name(class_name: str) -> str:
]
for prefix in prefixes:
if class_name.startswith(prefix):
return class_name[len(prefix):]
return class_name[len(prefix) :]
return class_name


Expand All @@ -69,8 +69,7 @@ class GrammarError(Exception):
# Protocol for classes that can have a get_core method
class HasGetCore(Protocol):

def get_core(self):
...
def get_core(self): ...


classType = TypingType[HasGetCore]
Expand Down Expand Up @@ -125,17 +124,21 @@ def get_core(self):
try:
if value is not None:
if class_name_without_prefix == "Int" or class_name_without_prefix.startswith(
"Int", ):
"Int",
):
base = Literal(int(value), type=t_int)
elif class_name_without_prefix == "Float" or class_name_without_prefix.startswith(
"Float", ):
"Float",
):
base = Literal(float(value), type=t_float)
elif class_name_without_prefix == "Bool" or class_name_without_prefix.startswith(
"Bool", ):
"Bool",
):
value = str(value) == "true"
base = Literal(value, type=t_bool)
elif class_name_without_prefix == "String" or class_name_without_prefix.startswith(
"String", ):
"String",
):
v = str(value)[1:-1]
base = Literal(str(v), type=t_string)
else:
Expand All @@ -154,13 +157,19 @@ def liquid_term_to_str(ty: RefinedType) -> str:
base_type_str: str = ty.type.name
refinement: LiquidTerm = ty.refinement
if isinstance(refinement, LiquidApp):
refined_type_str = (str(ty.refinement).replace(
var,
base_type_str,
).replace("(", "").replace(
")",
"",
).replace(" ", "_"))
refined_type_str = (
str(ty.refinement)
.replace(
var,
base_type_str,
)
.replace("(", "")
.replace(
")",
"",
)
.replace(" ", "_")
)
for op, op_str in aeon_prelude_ops_to_text.items():
refined_type_str = refined_type_str.replace(op, op_str)
else:
Expand All @@ -169,17 +178,27 @@ def liquid_term_to_str(ty: RefinedType) -> str:


def process_type_name(ty: Type) -> str:
if isinstance(ty, RefinedType):
refinement_str = liquid_term_to_str(ty)
refined_type_name = f"Refined_{refinement_str}"
return refined_type_name

elif isinstance(ty, Top) or isinstance(ty, Bottom):
return str(ty)
elif isinstance(ty, BaseType):
return str(ty.name)
else:
assert False
match ty:
case Top() | Bottom():
return str(ty)
case BaseType(name):
return name
case RefinedType(_, _, _):
refinement_str = liquid_term_to_str(ty)
refined_type_name = f"Refined_{refinement_str}"
return refined_type_name
case AbstractionType(vname, vtype, rtype):
r1 = process_type_name(vtype)
r2 = process_type_name(rtype)
return f"{vname}_{r1}_to_{r2}"
case TypeVar(name):
return name
case TypePolymorphism(name, kind, body):
r = process_type_name(body)
return f"{name}_{kind}_{r}"
case _:
print(ty, type(ty))
assert False


def replace_tuples_with_lists(structure):
Expand Down Expand Up @@ -231,21 +250,28 @@ def intervals_to_metahandlers(
for interval in intervals_list:
if isinstance(interval, Interval):
if isinstance(ref, LiquidApp):
max_range = (max_number if isinstance(
interval.sup,
Infinity,
) else interval.sup) # or 2 ** 31 - 1
max_range = (
max_number
if isinstance(
interval.sup,
Infinity,
)
else interval.sup
) # or 2 ** 31 - 1
max_range = max_range - 1 if interval.right_open else max_range

min_range = (min_number if isinstance(
interval.inf,
NegativeInfinity,
) else interval.inf) # or -2 ** 31
min_range = (
min_number
if isinstance(
interval.inf,
NegativeInfinity,
)
else interval.inf
) # or -2 ** 31
min_range = min_range + 1 if interval.left_open else min_range

metahandler_instance = gengy_metahandler(min_range, max_range)
metahandler_type = Annotated[
python_type, metahandler_instance] # type: ignore
metahandler_type = Annotated[python_type, metahandler_instance] # type: ignore
metahandler_list.append(metahandler_type)
else:
assert False
Expand All @@ -266,7 +292,8 @@ def get_metahandler_union(


def refined_type_to_metahandler(
ty: RefinedType, ) -> MetaHandlerGenerator | Union[MetaHandlerGenerator]:
ty: RefinedType,
) -> MetaHandlerGenerator | Union[MetaHandlerGenerator]:
base_type_str = str(ty.type.name)
gengy_metahandler = aeon_to_gengy_metahandlers[base_type_str]
name, ref = ty.name, ty.refinement
Expand All @@ -288,9 +315,8 @@ def refined_type_to_metahandler(

def create_abstract_class(class_name: str) -> type:
"""Create and return a new abstract class with the given name."""
class_name = "t_" + class_name if not class_name.startswith(
"t_") else class_name
return make_dataclass(class_name, [], bases=(ABC, ))
class_name = "t_" + class_name if not class_name.startswith("t_") else class_name
return make_dataclass(class_name, [], bases=(ABC,))


def create_literal_class(
Expand All @@ -302,7 +328,7 @@ def create_literal_class(
new_class = make_dataclass(
"literal_" + class_name,
[("value", value_type)],
bases=(base_class, ),
bases=(base_class,),
)
return mk_method_core_literal(new_class)

Expand All @@ -313,8 +339,7 @@ def handle_refined_type(
grammar_nodes: list[type],
) -> tuple[list[type], type]:
"""Handle the creation of classes for refined types and update grammar nodes accordingly."""
class_name = "t_" + class_name if not class_name.startswith(
"t_") else class_name
class_name = "t_" + class_name if not class_name.startswith("t_") else class_name
new_abs_class = create_abstract_class(class_name)
grammar_nodes.append(new_abs_class)

Expand Down Expand Up @@ -368,8 +393,7 @@ def find_class_by_name(

return grammar_nodes, new_abs_class

if ty is not None and isinstance(ty, RefinedType) and str(
ty.type.name) in aeon_to_gengy_metahandlers:
if ty is not None and isinstance(ty, RefinedType) and str(ty.type.name) in aeon_to_gengy_metahandlers:
return handle_refined_type(class_name, ty, grammar_nodes)

new_abs_class = create_abstract_class(class_name)
Expand All @@ -379,7 +403,8 @@ def find_class_by_name(

def is_valid_class_name(class_name: str) -> bool:
return class_name not in prelude_ops and not class_name.startswith(
("_anf_", "target"), )
("_anf_", "target"),
)


def get_attribute_type_name(
Expand Down Expand Up @@ -412,10 +437,14 @@ def generate_class_components(
fields = []
parent_name = ""
while isinstance(class_type, AbstractionType):
attribute_name = (class_type.var_name.value if isinstance(
class_type.var_name,
Token,
) else class_type.var_name)
attribute_name = (
class_type.var_name.value
if isinstance(
class_type.var_name,
Token,
)
else class_type.var_name
)
attribute_type = class_type.var_type

attribute_type_name = get_attribute_type_name(attribute_type)
Expand Down Expand Up @@ -446,7 +475,7 @@ def create_new_class(class_name: str, parent_class: type, fields=None) -> type:
"""Creates a new class with the given name, parent class, and fields."""
if fields is None:
fields = []
new_class = make_dataclass(class_name, fields, bases=(parent_class, ))
new_class = make_dataclass(class_name, fields, bases=(parent_class,))
new_class = mk_method_core(new_class)

return new_class
Expand Down Expand Up @@ -578,9 +607,8 @@ def create_if_class(
def build_control_flow_grammar_nodes(grammar_nodes: list[type]) -> list[type]:
types_names_set = {
cls.__name__
for cls in grammar_nodes if cls.__base__ is ABC and not any(
issubclass(cls, other) and cls is not other
for other in grammar_nodes)
for cls in grammar_nodes
if cls.__base__ is ABC and not any(issubclass(cls, other) and cls is not other for other in grammar_nodes)
}
for ty_name in types_names_set:
grammar_nodes = create_if_class(
Expand Down
5 changes: 3 additions & 2 deletions aeon/synthesis_grammar/identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ def get_holes_info(
hs2 = get_holes_info(ctx, body, ty, targets, refined_types)
return hs1 | hs2
case TypeApplication(body=body, type=argty):
_, bty = synth(ctx, body)
argty = argty if refined_types else refined_to_unrefined_type(argty)
if isinstance(ty, TypePolymorphism):
ntype = substitute_vartype(ty.body, argty, ty.name)
if isinstance(bty, TypePolymorphism):
ntype = substitute_vartype(bty.body, argty, bty.name)
ntype = ntype if refined_types else refined_to_unrefined_type(ntype)
return get_holes_info(ctx, body, ntype, targets, refined_types)
else:
Expand Down
Loading

0 comments on commit a071516

Please sign in to comment.