Skip to content

Commit

Permalink
Try declaring a regular recursive type for exp as opposed to the recu…
Browse files Browse the repository at this point in the history
…rsive modules (#1391)
  • Loading branch information
cyrus- authored Nov 5, 2024
2 parents 1b6b475 + 12b8ddc commit 297b04b
Show file tree
Hide file tree
Showing 23 changed files with 452 additions and 560 deletions.
21 changes: 10 additions & 11 deletions src/haz3lcore/dynamics/Builtins.re
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ let fn =
module Pervasives = {
module Impls = {
/* constants */
let infinity = DHExp.Float(Float.infinity) |> fresh;
let neg_infinity = DHExp.Float(Float.neg_infinity) |> fresh;
let nan = DHExp.Float(Float.nan) |> fresh;
let epsilon_float = DHExp.Float(epsilon_float) |> fresh;
let pi = DHExp.Float(Float.pi) |> fresh;
let max_int = DHExp.Int(Int.max_int) |> fresh;
let min_int = DHExp.Int(Int.min_int) |> fresh;
let infinity = Float(Float.infinity) |> fresh;
let neg_infinity = Float(Float.neg_infinity) |> fresh;
let nan = Float(Float.nan) |> fresh;
let epsilon_float = Float(epsilon_float) |> fresh;
let pi = Float(Float.pi) |> fresh;
let max_int = Int(Int.max_int) |> fresh;
let min_int = Int(Int.min_int) |> fresh;

let unary = (f: DHExp.t => result, d: DHExp.t) => {
switch (f(d)) {
Expand Down Expand Up @@ -180,8 +180,8 @@ module Pervasives = {
switch (convert(s)) {
| Some(n) => Ok(wrap(n))
| None =>
let d' = DHExp.BuiltinFun(name) |> DHExp.fresh;
let d' = DHExp.Ap(Forward, d', d) |> DHExp.fresh;
let d' = BuiltinFun(name) |> DHExp.fresh;
let d' = Ap(Forward, d', d) |> DHExp.fresh;
let d' = DynamicErrorHole(d', InvalidOfString) |> DHExp.fresh;
Ok(d');
}
Expand All @@ -204,8 +204,7 @@ module Pervasives = {
Ok(
fresh(
DynamicErrorHole(
DHExp.Ap(Forward, DHExp.BuiltinFun(name) |> fresh, d1)
|> fresh,
Ap(Forward, BuiltinFun(name) |> fresh, d1) |> fresh,
DivideByZero,
),
),
Expand Down
18 changes: 9 additions & 9 deletions src/haz3lcore/dynamics/Casts.re
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ let grounded_Forall =
);
let grounded_Prod = length =>
NotGroundOrHole(
Prod(ListUtil.replicate(length, Typ.Unknown(Internal) |> Typ.temp))
Prod(ListUtil.replicate(length, Unknown(Internal) |> Typ.temp))
|> Typ.temp,
);
let grounded_Sum: unit => Typ.sum_map =
Expand All @@ -50,7 +50,7 @@ let grounded_List =
let rec ground_cases_of = (ty: Typ.t): ground_cases => {
let is_hole: Typ.t => bool =
fun
| {term: Typ.Unknown(_), _} => true
| {term: Unknown(_), _} => true
| _ => false;
switch (Typ.term_of(ty)) {
| Unknown(_) => Hole
Expand All @@ -67,7 +67,7 @@ let rec ground_cases_of = (ty: Typ.t): ground_cases => {
| Prod(tys) =>
if (List.for_all(
fun
| ({term: Typ.Unknown(_), _}: Typ.t) => true
| ({term: Unknown(_), _}: Typ.t) => true
| _ => false,
tys,
)) {
Expand Down Expand Up @@ -132,12 +132,12 @@ let rec transition = (~recursive=false, d: DHExp.t): option(DHExp.t) => {
| Some(d1) => d1
| None => inner_cast
};
Some(DHExp.Cast(inner_cast, t2_grounded, t2) |> DHExp.fresh);
Some(Cast(inner_cast, t2_grounded, t2) |> DHExp.fresh);

| (NotGroundOrHole(t1_grounded), Hole) =>
/* ITGround rule */
Some(
DHExp.Cast(Cast(d1, t1, t1_grounded) |> DHExp.fresh, t1_grounded, t2)
Cast(Cast(d1, t1, t1_grounded) |> DHExp.fresh, t1_grounded, t2)
|> DHExp.fresh,
)

Expand Down Expand Up @@ -187,7 +187,7 @@ let pattern_fixup = (p: DHPat.t): DHPat.t => {
let (p1, d1) = unwrap_casts(p1);
(
p1,
{term: DHExp.Cast(d1, t1, t2), copied: p.copied, ids: p.ids}
{term: Cast(d1, t1, t2), copied: p.copied, ids: p.ids}
|> transition_multiple,
);
| _ => (p, hole)
Expand All @@ -198,13 +198,13 @@ let pattern_fixup = (p: DHPat.t): DHPat.t => {
| EmptyHole => p
| Cast(d1, t1, t2) =>
let p1 = rewrap_casts((p, d1));
{term: DHPat.Cast(p1, t1, t2), copied: d.copied, ids: d.ids};
{term: Cast(p1, t1, t2), copied: d.copied, ids: d.ids};
| FailedCast(d1, t1, t2) =>
let p1 = rewrap_casts((p, d1));
{
term:
DHPat.Cast(
DHPat.Cast(p1, t1, Typ.fresh(Unknown(Internal))) |> DHPat.fresh,
Cast(
Cast(p1, t1, Typ.fresh(Unknown(Internal))) |> DHPat.fresh,
Typ.fresh(Unknown(Internal)),
t2,
),
Expand Down
77 changes: 33 additions & 44 deletions src/haz3lcore/dynamics/Elaborator.re
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ let fresh_cast = (d: DHExp.t, t1: Typ.t, t2: Typ.t): DHExp.t => {
? d
: {
let d' =
DHExp.Cast(d, t1, Typ.temp(Unknown(Internal)))
Cast(d, t1, Typ.temp(Unknown(Internal)))
|> DHExp.fresh
|> Casts.transition_multiple;
DHExp.Cast(d', Typ.temp(Unknown(Internal)), t2)
Cast(d', Typ.temp(Unknown(Internal)), t2)
|> DHExp.fresh
|> Casts.transition_multiple;
};
Expand Down Expand Up @@ -63,11 +63,11 @@ let elaborated_type = (m: Statics.Map.t, uexp: UExp.t): (Typ.t, Ctx.t, 'a) => {
| Syn => self_ty
| SynFun =>
let (ty1, ty2) = Typ.matched_arrow(ctx, self_ty);
Typ.Arrow(ty1, ty2) |> Typ.temp;
Arrow(ty1, ty2) |> Typ.temp;
| SynTypFun =>
let (tpat, ty) = Typ.matched_forall(ctx, self_ty);
let tpat = Option.value(tpat, ~default=TPat.fresh(EmptyHole));
Typ.Forall(tpat, ty) |> Typ.temp;
Forall(tpat, ty) |> Typ.temp;
// We need to remove the synswitches from this type.
| Ana(ana_ty) => Typ.match_synswitch(ana_ty, self_ty)
};
Expand All @@ -90,11 +90,11 @@ let elaborated_pat_type = (m: Statics.Map.t, upat: UPat.t): (Typ.t, Ctx.t) => {
| Syn => self_ty
| SynFun =>
let (ty1, ty2) = Typ.matched_arrow(ctx, self_ty);
Typ.Arrow(ty1, ty2) |> Typ.temp;
Arrow(ty1, ty2) |> Typ.temp;
| SynTypFun =>
let (tpat, ty) = Typ.matched_forall(ctx, self_ty);
let tpat = Option.value(tpat, ~default=TPat.fresh(EmptyHole));
Typ.Forall(tpat, ty) |> Typ.temp;
Forall(tpat, ty) |> Typ.temp;
| Ana(ana_ty) =>
switch (prev_synswitch) {
| None => ana_ty
Expand Down Expand Up @@ -125,9 +125,7 @@ let rec elaborate_pattern =
|> List.map2((p, t) => fresh_pat_cast(p, t, inner_type), _, tys)
|> (
ps' =>
DHPat.ListLit(ps')
|> rewrap
|> cast_from(List(inner_type) |> Typ.temp)
ListLit(ps') |> rewrap |> cast_from(List(inner_type) |> Typ.temp)
);
| Cons(p1, p2) =>
let (p1', ty1) = elaborate_pattern(m, p1);
Expand All @@ -138,19 +136,17 @@ let rec elaborate_pattern =
|> Option.value(~default=Typ.temp(Unknown(Internal)));
let p1'' = fresh_pat_cast(p1', ty1, ty_inner);
let p2'' = fresh_pat_cast(p2', ty2, List(ty_inner) |> Typ.temp);
DHPat.Cons(p1'', p2'')
|> rewrap
|> cast_from(List(ty_inner) |> Typ.temp);
Cons(p1'', p2'') |> rewrap |> cast_from(List(ty_inner) |> Typ.temp);
| Tuple(ps) =>
let (ps', tys) = List.map(elaborate_pattern(m), ps) |> ListUtil.unzip;
DHPat.Tuple(ps') |> rewrap |> cast_from(Typ.Prod(tys) |> Typ.temp);
Tuple(ps') |> rewrap |> cast_from(Prod(tys) |> Typ.temp);
| Ap(p1, p2) =>
let (p1', ty1) = elaborate_pattern(m, p1);
let (p2', ty2) = elaborate_pattern(m, p2);
let (ty1l, ty1r) = Typ.matched_arrow(ctx, ty1);
let p1'' = fresh_pat_cast(p1', ty1, Arrow(ty1l, ty1r) |> Typ.temp);
let p2'' = fresh_pat_cast(p2', ty2, ty1l);
DHPat.Ap(p1'', p2'') |> rewrap |> cast_from(ty1r);
Ap(p1'', p2'') |> rewrap |> cast_from(ty1r);
| Invalid(_)
| EmptyHole
| MultiHole(_)
Expand Down Expand Up @@ -213,7 +209,7 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
switch (term) {
| Invalid(_)
| Undefined
| EmptyHole => uexp |> cast_from(Typ.temp(Typ.Unknown(Internal)))
| EmptyHole => uexp |> cast_from(Typ.temp(Unknown(Internal)))
| MultiHole(stuff) =>
Any.map_term(
~f_exp=(_, exp) => {elaborate(m, exp) |> fst},
Expand All @@ -223,9 +219,9 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
|> List.map(_, stuff)
|> (
stuff =>
DHExp.MultiHole(stuff)
MultiHole(stuff)
|> rewrap
|> cast_from(Typ.temp(Typ.Unknown(Internal)))
|> cast_from(Typ.temp(Unknown(Internal)))
)
| DynamicErrorHole(e, err) =>
let (e', _) = elaborate(m, e);
Expand All @@ -245,10 +241,10 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
| ListLit(es) =>
let (ds, tys) = List.map(elaborate(m), es) |> ListUtil.unzip;
let inner_type =
Typ.join_all(~empty=Typ.Unknown(Internal) |> Typ.temp, ctx, tys)
|> Option.value(~default=Typ.temp(Typ.Unknown(Internal)));
Typ.join_all(~empty=Unknown(Internal) |> Typ.temp, ctx, tys)
|> Option.value(~default=Typ.temp(Unknown(Internal)));
let ds' = List.map2((d, t) => fresh_cast(d, t, inner_type), ds, tys);
Exp.ListLit(ds') |> rewrap |> cast_from(List(inner_type) |> Typ.temp);
ListLit(ds') |> rewrap |> cast_from(List(inner_type) |> Typ.temp);
| Constructor(c, _) =>
let mode =
switch (Id.Map.find_opt(Exp.rep_id(uexp), m)) {
Expand All @@ -266,23 +262,23 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
| Fun(p, e, env, n) =>
let (p', typ) = elaborate_pattern(m, p);
let (e', tye) = elaborate(m, e);
Exp.Fun(p', e', env, n)
Fun(p', e', env, n)
|> rewrap
|> cast_from(Arrow(typ, tye) |> Typ.temp);
| TypFun(tpat, e, name) =>
let (e', tye) = elaborate(m, e);
Exp.TypFun(tpat, e', name)
TypFun(tpat, e', name)
|> rewrap
|> cast_from(Typ.Forall(tpat, tye) |> Typ.temp);
|> cast_from(Forall(tpat, tye) |> Typ.temp);
| Tuple(es) =>
let (ds, tys) = List.map(elaborate(m), es) |> ListUtil.unzip;
Exp.Tuple(ds) |> rewrap |> cast_from(Prod(tys) |> Typ.temp);
Tuple(ds) |> rewrap |> cast_from(Prod(tys) |> Typ.temp);
| Var(v) =>
uexp
|> cast_from(
Ctx.lookup_var(ctx, v)
|> Option.map((x: Ctx.var_entry) => x.typ |> Typ.normalize(ctx))
|> Option.value(~default=Typ.temp(Typ.Unknown(Internal))),
|> Option.value(~default=Typ.temp(Unknown(Internal))),
)
| Let(p, def, body) =>
let add_name: (option(string), DHExp.t) => DHExp.t = (
Expand All @@ -305,24 +301,20 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
let def = add_name(Pat.get_var(p), def);
let (def, ty2) = elaborate(m, def);
let (body, ty) = elaborate(m, body);
Exp.Let(p, fresh_cast(def, ty2, ty1), body)
|> rewrap
|> cast_from(ty);
Let(p, fresh_cast(def, ty2, ty1), body) |> rewrap |> cast_from(ty);
} else {
// TODO: Add names to mutually recursive functions
// TODO: Don't add fixpoint if there already is one
let def = add_name(Option.map(s => s ++ "+", Pat.get_var(p)), def);
let (def, ty2) = elaborate(m, def);
let (body, ty) = elaborate(m, body);
let fixf = FixF(p, fresh_cast(def, ty2, ty1), None) |> DHExp.fresh;
Exp.Let(p, fixf, body) |> rewrap |> cast_from(ty);
Let(p, fixf, body) |> rewrap |> cast_from(ty);
};
| FixF(p, e, env) =>
let (p', typ) = elaborate_pattern(m, p);
let (e', tye) = elaborate(m, e);
Exp.FixF(p', fresh_cast(e', tye, typ), env)
|> rewrap
|> cast_from(typ);
FixF(p', fresh_cast(e', tye, typ), env) |> rewrap |> cast_from(typ);
| TyAlias(_, _, e) =>
let (e', tye) = elaborate(m, e);
e' |> cast_from(tye);
Expand All @@ -332,7 +324,7 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
let (tyf1, tyf2) = Typ.matched_arrow(ctx, tyf);
let f'' = fresh_cast(f', tyf, Arrow(tyf1, tyf2) |> Typ.temp);
let a'' = fresh_cast(a', tya, tyf1);
Exp.Ap(dir, f'', a'') |> rewrap |> cast_from(tyf2);
Ap(dir, f'', a'') |> rewrap |> cast_from(tyf2);
| DeferredAp(f, args) =>
let (f', tyf) = elaborate(m, f);
let (args', tys) = List.map(elaborate(m), args) |> ListUtil.unzip;
Expand Down Expand Up @@ -374,11 +366,11 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
let (f', tyf) = elaborate(m, f);
let ty =
Typ.join(~fix=false, ctx, tyt, tyf)
|> Option.value(~default=Typ.temp(Typ.Unknown(Internal)));
|> Option.value(~default=Typ.temp(Unknown(Internal)));
let c'' = fresh_cast(c', tyc, Bool |> Typ.temp);
let t'' = fresh_cast(t', tyt, ty);
let f'' = fresh_cast(f', tyf, ty);
Exp.If(c'', t'', f'') |> rewrap |> cast_from(ty);
If(c'', t'', f'') |> rewrap |> cast_from(ty);
| Seq(e1, e2) =>
let (e1', _) = elaborate(m, e1);
let (e2', ty2) = elaborate(m, e2);
Expand Down Expand Up @@ -430,10 +422,7 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
Constructor("$e", Unknown(Internal) |> Typ.temp) |> rewrap
| Var("v") =>
Constructor("$v", Unknown(Internal) |> Typ.temp) |> rewrap
| _ =>
DHExp.EmptyHole
|> rewrap
|> cast_from(Typ.temp(Typ.Unknown(Internal)))
| _ => EmptyHole |> rewrap |> cast_from(Typ.temp(Unknown(Internal)))
}
| UnOp(Int(Minus), e) =>
let (e', t) = elaborate(m, e);
Expand Down Expand Up @@ -536,23 +525,23 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
|> cast_from(
Ctx.lookup_var(Builtins.ctx_init, fn)
|> Option.map((x: Ctx.var_entry) => x.typ)
|> Option.value(~default=Typ.temp(Typ.Unknown(Internal))),
|> Option.value(~default=Typ.temp(Unknown(Internal))),
)
| Match(e, cases) =>
let (e', t) = elaborate(m, e);
let (ps, es) = ListUtil.unzip(cases);
let (ps', ptys) =
List.map(elaborate_pattern(m), ps) |> ListUtil.unzip;
let joined_pty =
Typ.join_all(~empty=Typ.Unknown(Internal) |> Typ.temp, ctx, ptys)
|> Option.value(~default=Typ.temp(Typ.Unknown(Internal)));
Typ.join_all(~empty=Unknown(Internal) |> Typ.temp, ctx, ptys)
|> Option.value(~default=Typ.temp(Unknown(Internal)));
let ps'' =
List.map2((p, t) => fresh_pat_cast(p, t, joined_pty), ps', ptys);
let e'' = fresh_cast(e', t, joined_pty);
let (es', etys) = List.map(elaborate(m), es) |> ListUtil.unzip;
let joined_ety =
Typ.join_all(~empty=Typ.Unknown(Internal) |> Typ.temp, ctx, etys)
|> Option.value(~default=Typ.temp(Typ.Unknown(Internal)));
Typ.join_all(~empty=Unknown(Internal) |> Typ.temp, ctx, etys)
|> Option.value(~default=Typ.temp(Unknown(Internal)));
let es'' =
List.map2((e, t) => fresh_cast(e, t, joined_ety), es', etys);
Match(e'', List.combine(ps'', es''))
Expand Down
Loading

0 comments on commit 297b04b

Please sign in to comment.