diff --git a/doc/changelog/06-Ltac2-language/18311-ltac2-ikfprintf.rst b/doc/changelog/06-Ltac2-language/18311-ltac2-ikfprintf.rst new file mode 100644 index 000000000000..16956b2cd045 --- /dev/null +++ b/doc/changelog/06-Ltac2-language/18311-ltac2-ikfprintf.rst @@ -0,0 +1,6 @@ +- **Added:** + `Ltac2.Message.Format.ikfprintf` useful to implement conditional printing + efficiently (i.e. without building an unused message when not printing) + (`#18311 `_, + fixes `#18292 `_, + by Gaƫtan Gilbert). diff --git a/plugins/ltac2/tac2core.ml b/plugins/ltac2/tac2core.ml index ba23f7847590..7e65637f52f6 100644 --- a/plugins/ltac2/tac2core.ml +++ b/plugins/ltac2/tac2core.ml @@ -299,17 +299,21 @@ let () = define "format_alpha" (format @-> ret format) @@ fun s -> Tac2print.FmtAlpha :: s -let () = - define "format_kfprintf" (closure @-> format @-> tac valexpr) @@ fun k fmt -> +let arity_of_format fmt = let open Tac2print in let fold accu = function - | FmtLiteral _ -> accu - | FmtString | FmtInt | FmtConstr | FmtIdent -> 1 + accu - | FmtAlpha -> 2 + accu + | FmtLiteral _ -> accu + | FmtString | FmtInt | FmtConstr | FmtIdent -> 1 + accu + | FmtAlpha -> 2 + accu in + List.fold_left fold 0 fmt + +let () = + define "format_kfprintf" (closure @-> format @-> tac valexpr) @@ fun k fmt -> + let open Tac2print in let pop1 l = match l with [] -> assert false | x :: l -> (x, l) in let pop2 l = match l with [] | [_] -> assert false | x :: y :: l -> (x, y, l) in - let arity = List.fold_left fold 0 fmt in + let arity = arity_of_format fmt in let rec eval accu args fmt = match fmt with | [] -> apply k [of_pp accu] | tag :: fmt -> @@ -344,6 +348,13 @@ let () = if Int.equal arity 0 then eval [] else return (Tac2ffi.of_closure (Tac2ffi.abstract arity eval)) +let () = + define "format_ikfprintf" (closure @-> valexpr @-> format @-> tac valexpr) @@ fun k v fmt -> + let arity = arity_of_format fmt in + let eval _args = apply k [v] in + if Int.equal arity 0 then eval [] + else return (Tac2ffi.of_closure (Tac2ffi.abstract arity eval)) + (** Array *) let () = define "array_empty" (unit @-> ret valexpr) (fun _ -> v_blk 0 [||]) diff --git a/test-suite/ltac2/printf.v b/test-suite/ltac2/printf.v index f96a01a9c987..a81f222ad5fc 100644 --- a/test-suite/ltac2/printf.v +++ b/test-suite/ltac2/printf.v @@ -29,3 +29,20 @@ Fail Ltac2 Eval printf "%I" "foo". Fail Ltac2 Eval printf "%t" "foo". Fail Ltac2 Eval printf "%a" (fun _ _ => ()). Fail Ltac2 Eval printf "%a" (fun _ i => Message.of_int i) "foo". + +Import Message. + +Ltac2 print_if b fmt := + if b then Format.kfprintf Message.to_string fmt + else Format.ikfprintf Message.to_string (Message.of_string "") fmt. + +Ltac2 Notation "print_if" b(tactic(0)) fmt(format) := print_if b fmt. + +Ltac2 Eval Control.assert_true (String.equal "hello friend" (print_if true "hello %s" "friend")). + +Ltac2 Eval Control.assert_true (String.equal "" (print_if false "hello %s" "friend")). + +Fail Ltac2 Eval print_if true "%a" (fun _ => Control.throw Assertion_failure) (). + +(* ikfprintf doesn't run the closure *) +Ltac2 Eval print_if false "%a" (fun _ => Control.throw Assertion_failure) (). diff --git a/user-contrib/Ltac2/Message.v b/user-contrib/Ltac2/Message.v index ffefc482509e..09188a41c9f7 100644 --- a/user-contrib/Ltac2/Message.v +++ b/user-contrib/Ltac2/Message.v @@ -55,4 +55,7 @@ Ltac2 @ external alpha : ('a, 'b, 'c, 'd) format -> Ltac2 @ external kfprintf : (message -> 'r) -> ('a, unit, message, 'r) format -> 'a := "coq-core.plugins.ltac2" "format_kfprintf". +Ltac2 @ external ikfprintf : ('v -> 'r) -> 'v -> ('a, unit, 'v, 'r) format -> 'a := + "coq-core.plugins.ltac2" "format_ikfprintf". + End Format.