Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of enum fields in quantified records #1106

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 78 additions & 4 deletions src/lustre/lustreAstNormalizer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,81 @@ let mk_fresh_node_arg_local info pos is_const expr_type expr =
NodeArgCache.add node_arg_cache expr nexpr;
nexpr, gids

let mk_range_expr ctx node_id expr_type expr =
let mk_range_expr ctx node_id expr_type expr =
let rec mk ctx n expr_type expr =
let expr_type = Chk.expand_type_syn_reftype_history ctx expr_type |> unwrap in
match expr_type with
| A.IntRange (_, l, u) ->
let original_ty, _ = Chk.infer_type_expr ctx node_id expr |> unwrap in
let original_ty = Chk.expand_type_syn_reftype_history ctx original_ty |> unwrap in
let user_prop, is_original = match original_ty with
| A.IntRange (_, l', u') ->
let eval_int_expr_opt expr = match expr with
| Some expr -> Some (AIC.eval_int_expr ctx expr)
| None -> None
in
let is_original =
let (l, u) = eval_int_expr_opt l, eval_int_expr_opt u in
let (l', u') = eval_int_expr_opt l', eval_int_expr_opt u' in
(match (l, u, l', u') with
| Some (Ok l), Some (Ok u), Some (Ok l'), Some (Ok u') -> l = l' && u = u'
| Some (Ok l), None, Some (Ok l'), None -> l = l'
| None, Some (Ok u), None, Some (Ok u') -> u = u'
| None, None, None, None -> true
| _ -> false)
in
let user_prop = if is_original then []
else
match l', u' with
| Some l', Some u' ->
let l' = A.CompOp (dpos, A.Lte, l', expr) in
let u' = A.CompOp (dpos, A.Lte, expr, u') in
[A.BinaryOp (dpos, A.And, l', u'), true]
| Some l', None -> [A.CompOp (dpos, A.Lte, l', expr), true]
| None, Some u' -> [A.CompOp (dpos, A.Lte, expr, u'), true]
| None, None -> [(A.Const (dpos, A.True)), true]
in
user_prop, is_original
| A.Int _ -> [], false
| _ -> assert false
in (match l, u with
| Some l, Some u ->
let l = A.CompOp (dpos, A.Lte, l, expr) in
let u = A.CompOp (dpos, A.Lte, expr, u) in
[A.BinaryOp (dpos, A.And, l, u), is_original] @ user_prop
| Some l, None ->
[A.CompOp (dpos, A.Lte, l, expr), is_original] @ user_prop
| None, Some u ->
[A.CompOp (dpos, A.Lte, expr, u), is_original] @ user_prop
| None, None -> [(A.Const (dpos, A.True)), is_original] @ user_prop
)
| A.ArrayType (_, (ty, upper_bound)) ->
let id_str = HString.concat2 (HString.mk_hstring "x") (HString.mk_hstring (string_of_int n)) in
let id = A.Ident (dpos, id_str) in
let ctx = Ctx.add_ty ctx id_str (A.Int dpos) in
let expr = A.ArrayIndex (dpos, expr, id) in
let rexpr = mk ctx (succ n) ty expr in
let l = A.CompOp (dpos, A.Lte, A.Const (dpos, A.Num (HString.mk_hstring "0")), id) in
let u = A.CompOp (dpos, A.Lt, id, upper_bound) in
let assumption = A.BinaryOp (dpos, A.And, l, u) in
let var = dpos, id_str, (A.Int dpos) in
let body = fun e -> A.BinaryOp (dpos, A.Impl, assumption, e) in
List.map (fun (e, is_original) -> A.Quantifier (dpos, A.Forall, [var], body e), is_original) rexpr
| TupleType (_, tys) ->
let mk_proj i = A.TupleProject (dpos, expr, i) in
let tys = List.filter (fun ty -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.mapi (fun i ty -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| RecordType (_, _, tys) ->
let mk_proj i = A.RecordProject (dpos, expr, i) in
let tys = List.filter (fun (_, _, ty) -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.map (fun (_, i, ty) -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| _ -> []
in
mk ctx 0 expr_type expr

let mk_enum_range_expr ctx node_id expr_type expr =
let rec mk ctx n expr_type expr =
let expr_type = Chk.expand_type_syn_reftype_history ctx expr_type |> unwrap in
match expr_type with
Expand Down Expand Up @@ -473,12 +547,12 @@ let mk_range_expr ctx node_id expr_type expr =
List.map (fun (e, is_original) -> A.Quantifier (dpos, A.Forall, [var], body e), is_original) rexpr
| TupleType (_, tys) ->
let mk_proj i = A.TupleProject (dpos, expr, i) in
let tys = List.filter (fun ty -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.filter (fun ty -> Ctx.type_contains_enum_or_subrange ctx ty) tys in
let tys = List.mapi (fun i ty -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| RecordType (_, _, tys) ->
let mk_proj i = A.RecordProject (dpos, expr, i) in
let tys = List.filter (fun (_, _, ty) -> Ctx.type_contains_subrange ctx ty) tys in
let tys = List.filter (fun (_, _, ty) -> Ctx.type_contains_enum_or_subrange ctx ty) tys in
let tys = List.map (fun (_, i, ty) -> mk ctx n ty (mk_proj i)) tys in
List.fold_left (@) [] tys
| _ -> []
Expand Down Expand Up @@ -1683,7 +1757,7 @@ and normalize_expr ?guard info node_id map =
(fun acc (_, id, ty) ->
let expr = A.Ident(dpos, id) in
let range_exprs =
List.map fst (mk_range_expr info.context (Some node_id) ty expr) @
List.map fst (mk_enum_range_expr info.context (Some node_id) ty expr) @
List.map snd (mk_ref_type_expr info.context expr Local ty)
in
range_exprs :: acc
Expand Down
6 changes: 6 additions & 0 deletions src/lustre/lustreAstNormalizer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ val mk_range_expr : TypeCheckerContext.tc_context ->
LustreAst.expr ->
(LustreAst.expr * bool) list

val mk_enum_range_expr : TypeCheckerContext.tc_context ->
HString.t option ->
LustreAst.lustre_type ->
LustreAst.expr ->
(LustreAst.expr * bool) list

val normalize : TypeCheckerContext.tc_context ->
LustreAbstractInterpretation.context ->
LustreAst.t ->
Expand Down
55 changes: 51 additions & 4 deletions src/lustre/typeCheckerContext.ml
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,35 @@ let rec type_contains_subrange ctx = function
| Some ty -> type_contains_subrange ctx ty
| None -> assert false
)
| _ -> false
| Bool _ | Int _ | Real _ | EnumType _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_enum_or_subrange ctx = function
| LA.IntRange _
| EnumType _ -> true
| RefinementType (_, (_, _, ty), _) -> type_contains_enum_or_subrange ctx ty
| TupleType (_, tys) | GroupType (_, tys) ->
List.fold_left (fun acc ty -> acc || type_contains_enum_or_subrange ctx ty) false tys
| RecordType (_, _, tys) ->
List.fold_left (fun acc (_, _, ty) -> acc || type_contains_enum_or_subrange ctx ty)
false tys
| ArrayType (_, (ty, _)) -> type_contains_enum_or_subrange ctx ty
| TArr (_, ty1, ty2) -> type_contains_enum_or_subrange ctx ty1 || type_contains_enum_or_subrange ctx ty2
| History (_, id) ->
(match lookup_ty ctx id with
| Some ty -> type_contains_enum_or_subrange ctx ty
| _ -> assert false)
| UserType (_, ty_args, id) -> (
match lookup_ty_syn ctx id ty_args with
| Some ty -> type_contains_enum_or_subrange ctx ty
| None -> assert false
)
| Bool _ | Int _ | Real _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_ref ctx = function
| LA.RefinementType _ -> true
Expand All @@ -716,7 +744,15 @@ let rec type_contains_subrange ctx = function
(match lookup_ty ctx id with
| Some ty -> type_contains_ref ctx ty
| _ -> assert false)
| _ -> false
| UserType (_, ty_args, id) -> (
match lookup_ty_syn ctx id ty_args with
| Some ty -> type_contains_ref ctx ty
| None -> false
)
| Bool _ | Int _ | Real _ | EnumType _ | IntRange _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_enum_subrange_reftype ctx = function
| LA.IntRange _
Expand All @@ -733,7 +769,15 @@ let rec type_contains_enum_subrange_reftype ctx = function
(match lookup_ty ctx id with
| Some ty -> type_contains_enum_subrange_reftype ctx ty
| _ -> assert false)
| _ -> false
| UserType (_, ty_args, id) -> (
match lookup_ty_syn ctx id ty_args with
| Some ty -> type_contains_enum_subrange_reftype ctx ty
| None -> assert false
)
| Bool _ | Int _ | Real _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec type_contains_abstract ctx = function
| LA.UserType (_, ty_args, id) ->
Expand All @@ -753,7 +797,10 @@ let rec type_contains_abstract ctx = function
(match lookup_ty ctx id with
| Some ty -> type_contains_abstract ctx ty
| _ -> assert false)
| _ -> false
| Bool _ | Int _ | Real _ | EnumType _ | IntRange _
| UInt8 _| UInt16 _| UInt32 _| UInt64 _
| Int8 _ |Int16 _ |Int32 _ | Int64 _
| AbstractType _ -> false

let rec ty_vars_of_expr ctx node_name expr =
let call = ty_vars_of_expr ctx node_name in match expr with
Expand Down
3 changes: 3 additions & 0 deletions src/lustre/typeCheckerContext.mli
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ val is_machine_type_of_associated_width: tc_context -> (LA.lustre_type * LA.lust
val type_contains_subrange : tc_context -> LA.lustre_type -> bool
(** Returns true if the lustre type expression contains an IntRange or if it is an IntRange *)

val type_contains_enum_or_subrange : tc_context -> LA.lustre_type -> bool
(** Returns true if the lustre type expression contains an EnumType/IntRange or if it is an EnumType/IntRange *)

val type_contains_ref : tc_context -> LA.lustre_type -> bool
(** Returns true if the lustre type expression contains a RefinementType or if it is an RefinementType *)

Expand Down
10 changes: 10 additions & 0 deletions tests/regression/success/forall_enum.lus
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
type E = enum { E1, E2 };

type R = struct {
f: E;
};

node N() returns (y:int);
let
check forall (x: R) (x.f=E1 or x.f=E2);
tel