diff --git a/src/lustre/lustreNodeGen.ml b/src/lustre/lustreNodeGen.ml index 7d77fd1a9..e8f15331c 100644 --- a/src/lustre/lustreNodeGen.ml +++ b/src/lustre/lustreNodeGen.ml @@ -665,7 +665,7 @@ and compile_ast_type (* Old code does flattening here, but that flattening is only ever used once! And it is for a check, in lustreDeclarations line 423 *) if expand then - let upper = E.numeral_of_expr array_size in + let upper = Numeral.(max zero (E.numeral_of_expr array_size)) in let result = ref X.empty in for ix = 0 to (Numeral.to_int upper - 1) do result := X.fold @@ -681,8 +681,11 @@ and compile_ast_type else let over_element_type j t a = X.add (j @ [X.ArrayVarIndex array_size]) - (Type.mk_array t (if E.is_numeral array_size - then Type.mk_int_range (Some Numeral.zero) (Some (E.numeral_of_expr array_size)) + (Type.mk_array t ( + if E.is_numeral array_size + then + let array_size = Numeral.(max zero (E.numeral_of_expr array_size)) in + Type.mk_int_range (Some Numeral.zero) (Some array_size) else Type.t_int)) a in diff --git a/src/lustre/lustreSyntaxChecks.ml b/src/lustre/lustreSyntaxChecks.ml index e765299ed..0f33fb360 100644 --- a/src/lustre/lustreSyntaxChecks.ml +++ b/src/lustre/lustreSyntaxChecks.ml @@ -101,7 +101,7 @@ let error_message kind = match kind with | NodeCallInRefinableContract (kind, node) -> "Illegal call to " ^ kind ^ " '" ^ HString.string_of_hstring node ^ "' in the cone of influence of this contract: " ^ kind ^ " " ^ HString.string_of_hstring node ^ " has a refinable contract" - | NodeCallInConstant id -> "Illegal node call or choose operator in definition of constant " ^ HString.string_of_hstring id + | NodeCallInConstant id -> "Illegal node call or choose operator in definition of constant '" ^ HString.string_of_hstring id ^ "'" | IllegalTemporalOperator (kind, variant) -> "Illegal " ^ kind ^ " in " ^ variant ^ " definition, " ^ variant ^ "s cannot have state" | IllegalImportOfStatefulContract contract -> "Illegal import of stateful contract '" @@ -524,8 +524,10 @@ let no_dangling_identifiers ctx = function no_a_dangling_identifier ctx pos i | _ -> Ok () -let no_node_calls_in_constant pos i e = - if LAH.expr_contains_call e then syntax_error pos (NodeCallInConstant i) else Ok () +let no_node_calls_in_constant i e = + if LAH.expr_contains_call e + then syntax_error (LAH.pos_of_expr e) (NodeCallInConstant i) + else Ok () let no_quant_var_or_symbolic_index_in_node_call ctx = function | LA.Call (pos, i, args) -> @@ -686,8 +688,8 @@ and check_declaration ctx = function | ConstDecl (span, decl) -> let check = match decl with | LA.FreeConst _ -> Ok () - | UntypedConst (pos, i, e) - | TypedConst (pos, i, e, _) -> check_const_expr_decl pos i ctx e + | UntypedConst (_, i, e) + | TypedConst (_, i, e, _) -> check_const_expr_decl i ctx e in check >> Ok (LA.ConstDecl (span, decl)) | NodeDecl (span, decl) -> check_node_decl ctx span decl @@ -695,13 +697,12 @@ and check_declaration ctx = function | ContractNodeDecl (span, decl) -> check_contract_node_decl ctx span decl | NodeParamInst (span, _) -> syntax_error span.start_pos UnsupportedParametricDeclaration -and check_const_expr_decl pos i ctx expr = - let composed_checks pos i ctx e = - (no_temporal_operator "constant" e) - >> (no_dangling_identifiers ctx e) - >> (no_node_calls_in_constant pos i e) +and check_const_expr_decl i ctx expr = + let composed_checks i ctx e = + (no_dangling_identifiers ctx e) + >> (no_node_calls_in_constant i e) in - check_expr ctx (composed_checks pos i) expr + check_expr ctx (composed_checks i) expr and common_node_equations_checks ctx e = (no_dangling_calls ctx e) @@ -728,8 +729,8 @@ and check_output_items (pos, _id, _ty, clock) = and check_local_items ctx local = match local with | LA.NodeConstDecl (_, FreeConst _) -> Ok () - | LA.NodeConstDecl (pos, UntypedConst (_, i, e)) -> check_const_expr_decl pos i ctx e - | LA.NodeConstDecl (pos, TypedConst (_, i, e, _)) -> check_const_expr_decl pos i ctx e + | LA.NodeConstDecl (_, UntypedConst (_, i, e)) -> check_const_expr_decl i ctx e + | LA.NodeConstDecl (_, TypedConst (_, i, e, _)) -> check_const_expr_decl i ctx e | NodeVarDecl (_, (_, _, _, LA.ClockTrue)) -> Ok () | NodeVarDecl (_, (pos, i, _, _)) -> syntax_error pos (UnsupportedClockedLocal i) @@ -884,8 +885,8 @@ and check_contract is_contract_node ctx f contract = | GhostConst decl -> ( let check = match decl with | LA.FreeConst _ -> Ok () - | UntypedConst (pos, i, e) - | TypedConst (pos, i, e, _) -> check_const_expr_decl pos i ctx e + | UntypedConst (_, i, e) + | TypedConst (_, i, e, _) -> check_const_expr_decl i ctx e in check >> Ok () ) diff --git a/src/lustre/lustreTypeChecker.ml b/src/lustre/lustreTypeChecker.ml index eeea5bf7e..ee2ad141c 100644 --- a/src/lustre/lustreTypeChecker.ml +++ b/src/lustre/lustreTypeChecker.ml @@ -44,6 +44,7 @@ type error_kind = Unknown of string | MergeCaseNotUnique of HString.t | UnboundIdentifier of HString.t | UnboundModeReference of HString.t + | UnboundNodeName of HString.t | NotAFieldOfRecord of HString.t | NoValueForRecordField of HString.t | IlltypedRecordProjection of tc_type @@ -51,9 +52,9 @@ type error_kind = Unknown of string | IlltypedTupleProjection of tc_type | UnequalIteBranchTypes of tc_type * tc_type | ExpectedBooleanExpression of tc_type + | ExpectedIntegerExpression of tc_type | Unsupported of string | UnequalArrayExpressionType - | ExpectedNumeralArrayBound | TypeMismatchOfRecordLabel of HString.t * tc_type * tc_type | IlltypedRecordUpdate of tc_type | ExpectedLabel of LA.expr @@ -89,8 +90,7 @@ type error_kind = Unknown of string | DisallowedSubrangeInContractReturn of bool * HString.t * tc_type | AssumptionMustBeInputOrOutput of HString.t | Redeclaration of HString.t - | ExpectedConstant of LA.expr - | ArrayBoundsInvalidExpression + | ExpectedConstant of string * string | UndeclaredType of HString.t | EmptySubrange of int * int | SubrangeArgumentMustBeConstantInteger of LA.expr @@ -111,6 +111,7 @@ let error_message kind = match kind with | MergeCaseNotUnique case -> "Merge case " ^ HString.string_of_hstring case ^ " must be unique" | UnboundIdentifier id -> "Unbound identifier '" ^ HString.string_of_hstring id ^ "'" | UnboundModeReference id -> "Unbound mode reference '" ^ HString.string_of_hstring id ^ "'" + | UnboundNodeName id -> "Unbound node identifier '" ^ HString.string_of_hstring id ^ "'" | NotAFieldOfRecord id -> "No field name '" ^ HString.string_of_hstring id ^ "' in record type" | NoValueForRecordField id -> "No value given for field '" ^ HString.string_of_hstring id ^ "'" | IlltypedRecordProjection ty -> "Cannot project field out of non record expression type " ^ string_of_tc_type ty @@ -118,10 +119,10 @@ let error_message kind = match kind with | IlltypedTupleProjection ty -> "Cannot project field out of non tuple type " ^ string_of_tc_type ty | UnequalIteBranchTypes (ty1, ty2) -> "Expected equal types of each if-then-else branch but found: " ^ string_of_tc_type ty1 ^ " on the then-branch and " ^ string_of_tc_type ty2 ^ " on the the else-branch" - | ExpectedBooleanExpression ty -> "Expected a boolean expression but bound " ^ string_of_tc_type ty + | ExpectedBooleanExpression ty -> "Expected a boolean expression but found expression of type " ^ string_of_tc_type ty + | ExpectedIntegerExpression ty -> "Expected an integer expression but found expression of type " ^ string_of_tc_type ty | Unsupported s -> "Unsupported: " ^ s | UnequalArrayExpressionType -> "All expressions must be of the same type in an Array" - | ExpectedNumeralArrayBound -> "Array cannot have non numeral type as its bounds" | TypeMismatchOfRecordLabel (label, ty1, ty2) -> "Type mismatch. Type of record label '" ^ (HString.string_of_hstring label) ^ "' is of type " ^ string_of_tc_type ty1 ^ " but the type of the expression is " ^ string_of_tc_type ty2 | IlltypedRecordUpdate ty -> "Cannot do an update on non-record type " ^ string_of_tc_type ty @@ -183,8 +184,7 @@ let error_message kind = match kind with | AssumptionMustBeInputOrOutput id -> "Assumption variable must be either an input or an output variable, " ^ "but found '" ^ HString.string_of_hstring id ^ "'" | Redeclaration id -> HString.string_of_hstring id ^ " is already declared" - | ExpectedConstant e -> "Expression " ^ LA.string_of_expr e ^ " is not a constant expression" - | ArrayBoundsInvalidExpression -> "Invalid expression in array bounds" + | ExpectedConstant (where, what) -> "Illegal " ^ what ^ " in " ^ where | UndeclaredType id -> "Type '" ^ HString.string_of_hstring id ^ "' is undeclared" | EmptySubrange (v1, v2) -> "Range can not be empty, but found range: [" ^ string_of_int v1 ^ ", " ^ string_of_int v2 ^ "]" @@ -250,6 +250,126 @@ let check_merge_exhaustive: tc_context -> Lib.position -> LA.lustre_type -> HStr | _ -> type_error pos (Impossible ("Type " ^ string_of_tc_type ty ^ " must be an abstract type")) +let rec infer_const_attr ctx exp = + let r = infer_const_attr ctx in + let combine l1 l2 = List.map2 (fun r1 r2 -> r1 >> r2) l1 l2 in + let error exp what = + let pos = LH.pos_of_expr exp in + Error (pos, fun w -> ExpectedConstant (w, what)) + in + match exp with + | LA.Ident (_, i) -> + let res = + if member_val ctx i then R.ok () + else error exp ("variable '" ^ HString.string_of_hstring i ^ "'") + in + [res] + | ModeRef _ -> [error exp "mode reference"] + | RecordProject (_, e, _) -> r e + | TupleProject (_, e, _) -> r e + (* Values *) + | Const _ -> [R.ok ()] + (* Operators *) + | UnaryOp (_,_,e) -> r e + | BinaryOp (_,_, e1, e2) -> combine (r e1) (r e2) + | TernaryOp (_, Ite, e1, e2, e3) -> ( + let r_e2 = r e2 in + match r e1 with + | [Ok _] -> combine r_e2 (r e3) + | [err] -> List.map (fun _ -> err) r_e2 + | _ -> assert false + ) + | ConvOp (_,_,e) -> r e + | CompOp (_,_, e1, e2) -> combine (r e1) (r e2) + (* Structured expressions *) + | RecordExpr (_, _, flds) -> + List.fold_left + (fun l1 l2 -> combine l1 l2) + [R.ok ()] + (List.map r (snd (List.split flds))) + | GroupExpr (_, ArrayExpr, es) | GroupExpr (_, TupleExpr, es) -> + List.fold_left + (fun l1 l2 -> combine l1 l2) + [R.ok ()] + (List.map r es) + | GroupExpr (_, ExprList, es) -> List.flatten (List.map r es) + (* Update of structured expressions *) + | StructUpdate (_, e1, _, e2) -> combine (r e1) (r e2) + | ArrayConstr (_, e1, e2) -> combine (r e1) (r e2) + | ArrayIndex (_, e1, e2) -> combine (r e1) (r e2) + (* Quantified expressions *) + | Quantifier (_, _, _, _) -> + [error exp "quantified expression"] + (* Clock operators *) + | When (_, e, _) -> + List.map (fun _ -> error exp "when operator") (r e) + | Merge (_, _, es) -> + List.map (fun _ -> error exp "merge operator") + (r (List.hd (snd (List.split es)))) + (* Temporal operators *) + | Pre (_, e) -> + List.map (fun _ -> error exp "pre operator") (r e) + | Arrow (_, e1, _) -> + List.map (fun _ -> error exp "arrow operator") (r e1) + (* Node calls *) + | ChooseOp _ -> assert false + | Condact (_, _, _, i, _, _) + | Activate (_, i, _, _, _) + | RestartEvery (_, i, _, _) + | Call (_, i, _) -> ( + let err = error exp "node call or choose operator" in + match lookup_node_ty ctx i with + | Some (TArr (_, _, exp_ret_tys)) -> ( + match exp_ret_tys with + | GroupType (_, tys) -> List.map (fun _ -> err) tys + | _ -> [err] + ) + | _ -> [err] + ) + +let check_expr_is_constant ctx kind e = + match R.seq_ (infer_const_attr ctx e) with + | Ok _ -> R.ok () + | Error (pos, exn_fn) -> type_error pos (exn_fn kind) + +let check_and_add_constant_definition ctx i e ty = + match R.seq_ (infer_const_attr ctx e) with + | Ok _ -> R.ok (add_ty (add_const ctx i e ty) i ty) + | Error (pos, exn_fn) -> + let where = + "definition of constant '" ^ HString.string_of_hstring i ^ "'" + in + type_error pos (exn_fn where) + +let check_constant_args ctx i arg_exprs = + let check param_attr = + let arg_attr = + List.map (infer_const_attr ctx) arg_exprs + |> List.flatten + in + R.seq_chain + (fun _ ((id, is_const_param), res) -> + match is_const_param, res with + | true, Error (pos, exn_fn) -> ( + let where = + "argument for constant parameter '" ^ HString.string_of_hstring id ^ "'" + in + type_error pos (exn_fn where) + ) + | _ -> R.ok () + ) + () + (List.combine param_attr arg_attr) + in + match lookup_node_param_attr ctx i with + | None -> assert false + | Some param_attr -> ( + if List.exists (fun (_, is_const) -> is_const) param_attr then ( + check param_attr + ) + else R.ok () + ) + let rec infer_type_expr: tc_context -> LA.expr -> (tc_type, [> error]) result = fun ctx -> function (* Identifiers *) @@ -375,12 +495,11 @@ let rec infer_type_expr: tc_context -> LA.expr -> (tc_type, [> error]) result (type_error pos UnequalArrayExpressionType))) (* Update structured expressions *) - | LA.ArrayConstr (pos, b_expr, sup_expr) -> - infer_type_expr ctx b_expr - >>= (fun b_ty -> - if is_expr_int_type ctx sup_expr - then R.ok (LA.ArrayType (pos, (b_ty, sup_expr))) - else type_error pos ExpectedNumeralArrayBound) + | LA.ArrayConstr (pos, b_expr, sup_expr) -> ( + let* b_ty = infer_type_expr ctx b_expr in + check_array_size_expr ctx sup_expr + >> R.ok (LA.ArrayType (pos, (b_ty, sup_expr))) + ) | LA.StructUpdate (pos, r, i_or_ls, e) -> if List.length i_or_ls != 1 then type_error pos (Unsupported ("List of labels or indices for structure update is not supported")) @@ -461,26 +580,26 @@ let rec infer_type_expr: tc_context -> LA.expr -> (tc_type, [> error]) result (type_error pos (IlltypedArrow (ty1, ty2))) (* Node calls *) - | LA.Call (pos, i, arg_exprs) -> - Debug.parse "Inferring type for node call %a" LA.pp_print_ident i - ; let infer_type_node_args: tc_context -> LA.expr list -> (tc_type, [> error]) result - = fun ctx args -> - R.seq (List.map (infer_type_expr ctx) args) - >>= (fun arg_tys -> - if List.length arg_tys = 1 then R.ok (List.hd arg_tys) - else R.ok (LA.GroupType (pos, arg_tys))) in - (match (lookup_node_ty ctx i) with - | Some (TArr (_, exp_arg_tys, exp_ret_tys)) -> - infer_type_node_args ctx arg_exprs - >>= (fun given_arg_tys -> - R.ifM (eq_lustre_type ctx exp_arg_tys given_arg_tys) - (R.ok exp_ret_tys) - (type_error pos (IlltypedCall (exp_arg_tys, given_arg_tys)))) - | Some ty -> type_error pos (ExpectedFunctionType ty) - | None -> assert false) -(* | None -> type_error pos ("No node with name: " - ^ (HString.string_of_hstring i) - ^ " found")) *) + | LA.Call (pos, i, arg_exprs) -> ( + Debug.parse "Inferring type for node call %a" LA.pp_print_ident i ; + let infer_type_node_args: tc_context -> LA.expr list -> (tc_type, [> error]) result = + fun ctx args -> + let* arg_tys = R.seq (List.map (infer_type_expr ctx) args) in + if List.length arg_tys = 1 then R.ok (List.hd arg_tys) + else R.ok (LA.GroupType (pos, arg_tys)) + in + match (lookup_node_ty ctx i) with + | Some (TArr (_, exp_arg_tys, exp_ret_tys)) -> ( + let* given_arg_tys = infer_type_node_args ctx arg_exprs in + let* are_equal = eq_lustre_type ctx exp_arg_tys given_arg_tys in + if are_equal then + (check_constant_args ctx i arg_exprs >> (R.ok exp_ret_tys)) + else + (type_error pos (IlltypedCall (exp_arg_tys, given_arg_tys))) + ) + | Some ty -> type_error pos (ExpectedFunctionType ty) + | None -> type_error pos (UnboundNodeName i) + ) (** Infer the type of a [LA.expr] with the types of free variables given in [tc_context] *) and check_type_expr: tc_context -> LA.expr -> tc_type -> (unit, [> error]) result @@ -652,9 +771,7 @@ and check_type_expr: tc_context -> LA.expr -> tc_type -> (unit, [> error]) resul let arg_ty = if List.length arg_tys = 1 then List.hd arg_tys else GroupType (pos, arg_tys) in (match (lookup_node_ty ctx i) with - | None -> type_error pos (Impossible ("No node/function with name " - ^ (HString.string_of_hstring i) - ^ " found.")) + | None -> type_error pos (UnboundNodeName i) | Some ty -> R.guard_with (eq_lustre_type ctx ty (LA.TArr (pos, arg_ty, exp_ty))) (type_error pos (MismatchedNodeType (i, (TArr (pos, arg_ty, exp_ty)), ty)))) @@ -1193,21 +1310,17 @@ and tc_ctx_const_decl: tc_context -> LA.const_decl -> (tc_context, [> error]) re | LA.UntypedConst (pos, i, e) -> if member_ty ctx i then type_error pos (Redeclaration i) - else + else ( let* ty = infer_type_expr ctx e in - if (is_expr_of_consts ctx e) then - R.ok (add_ty (add_const ctx i e ty) i ty) - else - type_error pos (ExpectedConstant e) - + check_and_add_constant_definition ctx i e ty + ) | LA.TypedConst (pos, i, e, exp_ty) -> + check_type_well_formed ctx exp_ty >> if member_ty ctx i then type_error pos (Redeclaration i) else check_type_expr (add_ty ctx i exp_ty) e exp_ty - >> if (is_expr_of_consts ctx e) - then R.ok (add_ty (add_const ctx i e exp_ty) i exp_ty) - else type_error pos (ExpectedConstant e) + >> check_and_add_constant_definition ctx i e exp_ty (** Fail if a duplicate constant is detected *) and tc_ctx_contract_vars: tc_context -> LA.contract_ghost_vars -> (tc_context, [> error]) result @@ -1265,10 +1378,11 @@ and tc_ctx_of_node_decl: Lib.position -> tc_context -> LA.node_decl -> (tc_conte Debug.parse "Extracting type of node declaration: %a" LA.pp_print_ident nname - ; if (member_node ctx nname) - then type_error pos (Redeclaration nname) - else build_node_fun_ty pos ctx ip op >>= fun fun_ty -> - R.ok (add_ty_node ctx nname fun_ty) + ; + if (member_node ctx nname) + then type_error pos (Redeclaration nname) + else build_node_fun_ty pos ctx ip op >>= fun fun_ty -> + R.ok (let ctx = add_ty_node ctx nname fun_ty in add_node_param_attr ctx nname ip) (** computes the type signature of node or a function and its node summary*) and tc_ctx_contract_node_eqn ?(ignore_modes = false) ctx = @@ -1368,7 +1482,26 @@ and build_type_and_const_context (* : tc_context -> LA.t -> (tc_context, [> erro | _ :: rest -> build_type_and_const_context ctx rest (** Process top level type declarations and make a type context with * user types, enums populated *) - + +and check_const_integer_expr ctx kind e = + match infer_type_expr ctx e with + | Error (`LustreTypeCheckerError (pos, UnboundNodeName _)) -> + type_error pos + (ExpectedConstant (kind, "node call or choose operator")) + | Ok ty -> + let* eq = eq_lustre_type ctx ty (LA.Int (LH.pos_of_expr e)) in + if eq then + check_expr_is_constant ctx kind e + else + type_error (LH.pos_of_expr e) (ExpectedIntegerExpression ty) + | Error err -> Error err + +and check_array_size_expr ctx e = + check_const_integer_expr ctx "array size expression" e + +and check_range_bound ctx e = + check_const_integer_expr ctx "subrange bound" e + and check_type_well_formed: tc_context -> tc_type -> (unit, [> error]) result = fun ctx -> function @@ -1378,10 +1511,10 @@ and check_type_well_formed: tc_context -> tc_type -> (unit, [> error]) result | LA.RecordType (_, _, idTys) -> (R.seq_ (List.map (fun (_, _, ty) -> check_type_well_formed ctx ty) idTys)) - | LA.ArrayType (pos, (b_ty, s)) -> - if is_expr_int_type ctx s && is_expr_of_consts ctx s - then check_type_well_formed ctx b_ty - else type_error pos ArrayBoundsInvalidExpression + | LA.ArrayType (_, (b_ty, s)) -> ( + check_array_size_expr ctx s + >> check_type_well_formed ctx b_ty + ) | LA.TupleType (_, tys) -> R.seq_ (List.map (check_type_well_formed ctx) tys) | LA.GroupType (_, tys) -> @@ -1391,20 +1524,14 @@ and check_type_well_formed: tc_context -> tc_type -> (unit, [> error]) result then R.ok () else type_error pos (UndeclaredType i) | LA.IntRange (pos, e1, e2) -> ( match e1, e2 with - | Some e1, Some e2 -> - if is_expr_int_type ctx e1 && is_expr_of_consts ctx e1 then - if is_expr_int_type ctx e2 && is_expr_of_consts ctx e2 then - let v1 = IC.eval_int_expr ctx e1 in - let v2 = IC.eval_int_expr ctx e2 in - v1 >>= fun v1 -> v2 >>= fun v2 -> - if v1 > v2 then - type_error pos (EmptySubrange (v1, v2)) - else Ok () - else type_error pos (SubrangeArgumentMustBeConstantInteger e2) - else type_error pos (SubrangeArgumentMustBeConstantInteger e1) - | Some e1, None -> if is_expr_int_type ctx e1 && is_expr_of_consts ctx e1 then Ok () else type_error pos (SubrangeArgumentMustBeConstantInteger e1) - | None, Some e2 -> if is_expr_int_type ctx e2 && is_expr_of_consts ctx e2 then Ok () else type_error pos (SubrangeArgumentMustBeConstantInteger e2) | None, None -> type_error pos IntervalMustHaveBound + | Some e, None | None, Some e -> + check_range_bound ctx e >> IC.eval_int_expr ctx e >> Ok () + | Some e1, Some e2 -> + check_range_bound ctx e1 >> check_range_bound ctx e2 >> + let* v1 = IC.eval_int_expr ctx e1 in + let* v2 = IC.eval_int_expr ctx e2 in + if v1 > v2 then type_error pos (EmptySubrange (v1, v2)) else Ok () ) | _ -> R.ok () (** Does it make sense to have this type i.e. is it inhabited? @@ -1514,12 +1641,6 @@ and is_expr_int_type: tc_context -> LA.expr -> bool = fun ctx e -> (** Checks if the expr is of type Int. This will be useful * in evaluating array sizes that we need to have as constant integers * while declaring the array type *) - -and is_expr_of_consts: tc_context -> LA.expr -> bool = fun ctx e -> - not (LH.expr_contains_call e) && - List.map (member_val ctx) (LA.SI.elements (LH.vars_without_node_call_ids e)) - |> List.fold_left (&&) true -(** checks if the expression only contains constant variables *) and eq_typed_ident: tc_context -> LA.typed_ident -> LA.typed_ident -> (bool, [> error]) result = fun ctx (_, _, ty1) (_, _, ty2) -> eq_lustre_type ctx ty1 ty2 diff --git a/src/lustre/lustreTypeChecker.mli b/src/lustre/lustreTypeChecker.mli index caca9adb8..1149be7ea 100644 --- a/src/lustre/lustreTypeChecker.mli +++ b/src/lustre/lustreTypeChecker.mli @@ -29,6 +29,7 @@ type error_kind = Unknown of string | MergeCaseNotUnique of HString.t | UnboundIdentifier of HString.t | UnboundModeReference of HString.t + | UnboundNodeName of HString.t | NotAFieldOfRecord of HString.t | NoValueForRecordField of HString.t | IlltypedRecordProjection of tc_type @@ -36,9 +37,9 @@ type error_kind = Unknown of string | IlltypedTupleProjection of tc_type | UnequalIteBranchTypes of tc_type * tc_type | ExpectedBooleanExpression of tc_type + | ExpectedIntegerExpression of tc_type | Unsupported of string | UnequalArrayExpressionType - | ExpectedNumeralArrayBound | TypeMismatchOfRecordLabel of HString.t * tc_type * tc_type | IlltypedRecordUpdate of tc_type | ExpectedLabel of LA.expr @@ -74,8 +75,7 @@ type error_kind = Unknown of string | DisallowedSubrangeInContractReturn of bool * HString.t * tc_type | AssumptionMustBeInputOrOutput of HString.t | Redeclaration of HString.t - | ExpectedConstant of LA.expr - | ArrayBoundsInvalidExpression + | ExpectedConstant of string * string | UndeclaredType of HString.t | EmptySubrange of int * int | SubrangeArgumentMustBeConstantInteger of LA.expr diff --git a/src/lustre/typeCheckerContext.ml b/src/lustre/typeCheckerContext.ml index 68abeb0b8..b6123361f 100644 --- a/src/lustre/typeCheckerContext.ml +++ b/src/lustre/typeCheckerContext.ml @@ -58,14 +58,18 @@ type ty_set = SI.t type contract_exports = (ty_store) IMap.t (** Mapping for all the exports of the contract, modes and contract ghost const and vars *) -type tc_context = { ty_syns: ty_alias_store (* store of the type alias mappings *) - ; ty_ctx: ty_store (* store of the types of identifiers and nodes *) - ; contract_ctx: ty_store (* store of the types of contracts *) - ; node_ctx: ty_store (* store of the types of nodes *) - ; vl_ctx: const_store (* store of typed constants to its value *) - ; u_types: ty_set (* store of all declared user types, - this is poor mans kind (type of type) context *) - ; contract_export_ctx: (* stores all the export variables of the contract *) +type param_store = (HString.t * bool) list IMap.t +(** A store of parameter names and flags indicating if the argument is constant *) + +type tc_context = { ty_syns: ty_alias_store (* store of the type alias mappings *) + ; ty_ctx: ty_store (* store of the types of identifiers and nodes *) + ; contract_ctx: ty_store (* store of the types of contracts *) + ; node_ctx: ty_store (* store of the types of nodes *) + ; node_param_attr: param_store (* store of the parameter attributes of nodes *) + ; vl_ctx: const_store (* store of typed constants to its value *) + ; u_types: ty_set (* store of all declared user types, + this is poor mans kind (type of type) context *) + ; contract_export_ctx: (* stores all the export variables of the contract *) contract_exports ; enum_vars:enum_variants } @@ -76,6 +80,7 @@ let empty_tc_context: tc_context = ; ty_ctx = IMap.empty ; contract_ctx = IMap.empty ; node_ctx = IMap.empty + ; node_param_attr = IMap.empty ; vl_ctx = IMap.empty ; u_types = SI.empty ; contract_export_ctx = IMap.empty @@ -169,7 +174,10 @@ let lookup_contract_ty: tc_context -> LA.ident -> tc_type option let lookup_node_ty: tc_context -> LA.ident -> tc_type option = fun ctx i -> IMap.find_opt i (ctx.node_ctx) (** Lookup a node type *) - + +let lookup_node_param_attr: tc_context -> LA.ident -> (HString.t * bool) list option + = fun ctx i -> IMap.find_opt i (ctx.node_param_attr) + let lookup_const: tc_context -> LA.ident -> (LA.expr * tc_type option) option = fun ctx i -> IMap.find_opt i (ctx.vl_ctx) (** Lookup a constant identifier *) @@ -192,8 +200,15 @@ let add_ty_contract: tc_context -> LA.ident -> tc_type -> tc_context let add_ty_node: tc_context -> LA.ident -> tc_type -> tc_context = fun ctx i ty -> {ctx with node_ctx = IMap.add i ty (ctx.node_ctx)} -(** Add the type of the contract *) - +(** Add the type of the node *) + +let add_node_param_attr : tc_context -> LA.ident -> LA.const_clocked_typed_decl list -> tc_context + = fun ctx i args -> + let v = + List.map (function (_, id, _, _, is_const) -> (id, is_const)) args + in + {ctx with node_param_attr = IMap.add i v (ctx.node_param_attr)} + let add_ty_decl: tc_context -> LA.ident -> tc_context = fun ctx i -> {ctx with u_types = SI.add i (ctx.u_types)} (** Add a user declared type in the typing context *) @@ -226,6 +241,9 @@ let union: tc_context -> tc_context -> tc_context ; node_ctx = (IMap.union (fun _ _ v2 -> Some v2) (ctx1.node_ctx) (ctx2.node_ctx)) + ; node_param_attr = (IMap.union (fun _ _ v2 -> Some v2) + (ctx1.node_param_attr) + (ctx2.node_param_attr)) ; vl_ctx = (IMap.union (fun _ _ v2 -> Some v2) (ctx1.vl_ctx) (ctx2.vl_ctx)) @@ -377,4 +395,4 @@ let rec traverse_group_expr_list f ctx proj es = if a<=i then traverse_group_expr_list f ctx (i-a) es else f proj e ) - | _ -> assert false \ No newline at end of file + | _ -> assert false diff --git a/src/lustre/typeCheckerContext.mli b/src/lustre/typeCheckerContext.mli index 5b424b344..05fd96590 100644 --- a/src/lustre/typeCheckerContext.mli +++ b/src/lustre/typeCheckerContext.mli @@ -51,7 +51,10 @@ type ty_set = SI.t type contract_exports = (ty_store) IMap.t (** Mapping for all the exports of the contract, modes and contract ghost const and vars *) - + +type param_store = (HString.t * bool) list IMap.t +(** A store of parameter names and flags indicating if the argument is constant *) + type tc_context val empty_tc_context: tc_context @@ -100,6 +103,8 @@ val lookup_contract_ty: tc_context -> LA.ident -> tc_type option val lookup_node_ty: tc_context -> LA.ident -> tc_type option (** Lookup a node type *) +val lookup_node_param_attr: tc_context -> LA.ident -> (HString.t * bool) list option + val lookup_const: tc_context -> LA.ident -> (LA.expr * tc_type option) option (** Lookup a constant identifier *) @@ -114,7 +119,10 @@ val add_ty: tc_context -> LA.ident -> tc_type -> tc_context val add_ty_node: tc_context -> LA.ident -> tc_type -> tc_context (** Add node/function type binding into the typing context *) - + +val add_node_param_attr: tc_context -> LA.ident -> LA.const_clocked_typed_decl list -> tc_context +(** Track whether node parameters are constant or not *) + val add_ty_contract: tc_context -> LA.ident -> tc_type -> tc_context (** Add the type of the contract *) diff --git a/tests/ounit/lustre/lustreTypeChecker/bad_array_size_1.lus b/tests/ounit/lustre/lustreTypeChecker/bad_array_size_1.lus new file mode 100644 index 000000000..229ed90bc --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/bad_array_size_1.lus @@ -0,0 +1,7 @@ + +node N() returns (y:int); +let + y = 4; +tel + +const A: int^(N()); diff --git a/tests/ounit/lustre/lustreTypeChecker/bad_array_size_2.lus b/tests/ounit/lustre/lustreTypeChecker/bad_array_size_2.lus new file mode 100644 index 000000000..593b792ec --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/bad_array_size_2.lus @@ -0,0 +1,4 @@ + +node M() returns (A: int^(0 -> 1)); +let +tel \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/bad_array_size_3.lus b/tests/ounit/lustre/lustreTypeChecker/bad_array_size_3.lus new file mode 100644 index 000000000..664bbfee1 --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/bad_array_size_3.lus @@ -0,0 +1,5 @@ + + +node M() returns (A: int^(pre 0)); +let +tel \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/bad_subrange_bound_1.lus b/tests/ounit/lustre/lustreTypeChecker/bad_subrange_bound_1.lus new file mode 100644 index 000000000..2b9998994 --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/bad_subrange_bound_1.lus @@ -0,0 +1,7 @@ + + +node N() returns (ok:bool); +var s: subrange [0,X()] of int; +let + s = 0; +tel \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/const_param_1.lus b/tests/ounit/lustre/lustreTypeChecker/const_param_1.lus new file mode 100644 index 000000000..fc3dacec0 --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/const_param_1.lus @@ -0,0 +1,9 @@ + +node imported F() returns (y:int); + +node imported N(const n: int) returns (y:int); + +node M() returns (z:int); +let + z=N(F()); +tel \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/const_param_2.lus b/tests/ounit/lustre/lustreTypeChecker/const_param_2.lus new file mode 100644 index 000000000..0d970ff7b --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/const_param_2.lus @@ -0,0 +1,9 @@ + +node imported F() returns (y:int); + +node imported N(const n: int) returns (y:int); + +node M() returns (z:int); +let + z=N(0 -> 1); +tel \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/const_param_3.lus b/tests/ounit/lustre/lustreTypeChecker/const_param_3.lus new file mode 100644 index 000000000..f08405bd6 --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/const_param_3.lus @@ -0,0 +1,8 @@ + + +node imported N(const n: int) returns (y:int); + +node M(x: int) returns (y:int); +let + y = N(x); +tel diff --git a/tests/ounit/lustre/lustreTypeChecker/node_call_in_array_size_expr.lus b/tests/ounit/lustre/lustreTypeChecker/node_call_in_array_size_expr.lus new file mode 100644 index 000000000..b6a469148 --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/node_call_in_array_size_expr.lus @@ -0,0 +1,8 @@ + + +node N() returns (y:int); +let + y = 4; +tel + +const A: int^(N()); \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/symbolic_subrange_bound.lus b/tests/ounit/lustre/lustreTypeChecker/symbolic_subrange_bound.lus new file mode 100644 index 000000000..0742a14aa --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/symbolic_subrange_bound.lus @@ -0,0 +1,5 @@ + +const X: int; +const M: subrange[0, X] of int = 1; + +node imported N() returns (ok:bool); \ No newline at end of file diff --git a/tests/ounit/lustre/lustreTypeChecker/symbolic_subrange_bound_2.lus b/tests/ounit/lustre/lustreTypeChecker/symbolic_subrange_bound_2.lus new file mode 100644 index 000000000..c0690b648 --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/symbolic_subrange_bound_2.lus @@ -0,0 +1,8 @@ + +const X: int; + +node N() returns (ok:bool); +var s: subrange [X,*] of int; +let + s = 0; +tel \ No newline at end of file diff --git a/tests/ounit/lustre/testLustreFrontend.ml b/tests/ounit/lustre/testLustreFrontend.ml index a70ec8002..4b7b18a5d 100644 --- a/tests/ounit/lustre/testLustreFrontend.ml +++ b/tests/ounit/lustre/testLustreFrontend.ml @@ -52,6 +52,14 @@ let _ = run_test_tt_main ("frontend LustreAstInlineConstants error tests" >::: [ match load_file "./lustreAstInlineConstants/test_access_out_of_bounds.lus" with | Error (`LustreAstInlineConstantsError (_, OutOfBounds _)) -> true | _ -> false); + mk_test "test symbolic subrange bound 1" (fun () -> + match load_file "./lustreTypeChecker/symbolic_subrange_bound.lus" with + | Error (`LustreAstInlineConstantsError (_, FreeIntIdentifier _)) -> true + | _ -> false); + mk_test "test symbolic subrange bound 2" (fun () -> + match load_file "./lustreTypeChecker/symbolic_subrange_bound_2.lus" with + | Error (`LustreAstInlineConstantsError (_, FreeIntIdentifier _)) -> true + | _ -> false); ]) (* *************************************************************************** *) @@ -66,14 +74,6 @@ let _ = run_test_tt_main ("frontend LustreSyntaxChecks error tests" >::: [ match load_file "./lustreSyntaxChecks/unsupported_when.lus" with | Error (`LustreSyntaxChecksError (_, UnsupportedWhen _)) -> true | _ -> false); - mk_test "test temporal op in const" (fun () -> - match load_file "./lustreSyntaxChecks/const_not_const.lus" with - | Error (`LustreSyntaxChecksError (_, IllegalTemporalOperator _)) -> true - | _ -> false); - mk_test "test temporal op in ghost const" (fun () -> - match load_file "./lustreSyntaxChecks/ghost_const_not_const.lus" with - | Error (`LustreSyntaxChecksError (_, IllegalTemporalOperator _)) -> true - | _ -> false); mk_test "test undefined node" (fun () -> match load_file "./lustreSyntaxChecks/dangling_call_in_ghost_var.lus" with | Error (`LustreSyntaxChecksError (_, UndefinedNode _)) -> true @@ -487,13 +487,13 @@ let _ = run_test_tt_main ("frontend LustreTypeChecker error tests" >::: [ match load_file "./lustreTypeChecker/test_array_sizes.lus" with | Error (`LustreTypeCheckerError (_, ExpectedType _)) -> true | _ -> false); - mk_test "test invalid array bounds" (fun () -> + mk_test "test invalid type for array size 1" (fun () -> match load_file "./lustreTypeChecker/test_const_bool_in_array_type.lus" with - | Error (`LustreTypeCheckerError (_, ArrayBoundsInvalidExpression)) -> true + | Error (`LustreTypeCheckerError (_, ExpectedIntegerExpression _)) -> true | _ -> false); mk_test "test range type integer arguments" (fun () -> match load_file "./lustreTypeChecker/test_const_decls_tyalias.lus" with - | Error (`LustreTypeCheckerError (_, SubrangeArgumentMustBeConstantInteger _)) -> true + | Error (`LustreTypeCheckerError (_, UnboundIdentifier _)) -> true | _ -> false); mk_test "test unification failure 1" (fun () -> match load_file "./lustreTypeChecker/test_homeomorphic_exn_array.lus" with @@ -531,9 +531,13 @@ let _ = run_test_tt_main ("frontend LustreTypeChecker error tests" >::: [ match load_file "./lustreTypeChecker/test-type.lus" with | Error (`LustreTypeCheckerError (_, UnequalArrayExpressionType)) -> true | _ -> false); - mk_test "test invalid array bounds 2" (fun () -> + mk_test "test invalid type for array size 2" (fun () -> match load_file "./lustreTypeChecker/type-grammer.lus" with - | Error (`LustreTypeCheckerError (_, ArrayBoundsInvalidExpression)) -> true + | Error (`LustreTypeCheckerError (_, ExpectedIntegerExpression _)) -> true + | _ -> false); + mk_test "test invalid expression for array size 1" (fun () -> + match load_file "./lustreTypeChecker/node_call_in_array_size_expr.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true | _ -> false); mk_test "test undeclared 1" (fun () -> match load_file "./lustreTypeChecker/undeclared_type_01.lus" with @@ -599,6 +603,42 @@ let _ = run_test_tt_main ("frontend LustreTypeChecker error tests" >::: [ match load_file "./lustreTypeChecker/nondeterministic_choice_2.lus" with | Error (`LustreTypeCheckerError (_, UnificationFailed _)) -> true | _ -> false); + mk_test "test temporal op in const" (fun () -> + match load_file "./lustreSyntaxChecks/const_not_const.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test temporal op in ghost const" (fun () -> + match load_file "./lustreSyntaxChecks/ghost_const_not_const.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal node call in array size expression" (fun () -> + match load_file "./lustreTypeChecker/bad_array_size_1.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal arrow operator in array size expression" (fun () -> + match load_file "./lustreTypeChecker/bad_array_size_2.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal pre operator in array size expression" (fun () -> + match load_file "./lustreTypeChecker/bad_array_size_3.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal node call in argument for constant parameter" (fun () -> + match load_file "./lustreTypeChecker/const_param_1.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal arrow operator in argument for constant parameter" (fun () -> + match load_file "./lustreTypeChecker/const_param_2.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal variable in argument for constant parameter" (fun () -> + match load_file "./lustreTypeChecker/const_param_3.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); + mk_test "test illegal node call in subrange bound" (fun () -> + match load_file "./lustreTypeChecker/bad_subrange_bound_1.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true + | _ -> false); ]) (* *************************************************************************** *) diff --git a/tests/regression/success/const_parameters.lus b/tests/regression/success/const_parameters.lus new file mode 100644 index 000000000..a69df038b --- /dev/null +++ b/tests/regression/success/const_parameters.lus @@ -0,0 +1,18 @@ + +node F() returns (x,y:int) +let + x = 1; y = 2; +tel + +node N(x,y:int;const n: int;m:int) returns (z:int); +let + z = if C>0 then y+m else x+n; +tel + +const C, W:int; + +node M(t:int) returns (z:int); +let + z=N(F(), if C>0 then (W+3, t+1) else (2*C, C)); + check z=t+3 or z=2*C+1; +tel diff --git a/tests/regression/success/test_activate_every.lus b/tests/regression/success/test_activate_every.lus index 60b15a13d..0c62a1d15 100644 --- a/tests/regression/success/test_activate_every.lus +++ b/tests/regression/success/test_activate_every.lus @@ -1,4 +1,4 @@ -node sum_ge_10 (const in: int) returns (out: bool) ; +node sum_ge_10 (in: int) returns (out: bool) ; var sum: int ; let sum = in + (0 -> pre sum) ;