From bb24637c9349ae4065b6583cb7ac5cf6983a882a Mon Sep 17 00:00:00 2001 From: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Date: Wed, 10 Jul 2024 08:10:07 -0500 Subject: [PATCH] Cache extracted loop body functions for reuse (#131) Also: * Refactor storing of global data for extracted loop functions * Refactor context kind into the EnvRecorder * Remove unnecessary derives * Rename some things for clarity --- ...om => multiuse_func_with_loop_diff.circom} | 0 .../multiuse_func_with_loop_same.circom | 247 +++++++++++++ .../env/unrolled_block_env.rs | 6 +- circuit_passes/src/passes/checks.rs | 1 + .../src/passes/loop_unroll/body_extractor.rs | 342 +++++++++++------- .../passes/loop_unroll/loop_env_recorder.rs | 32 +- circuit_passes/src/passes/loop_unroll/mod.rs | 15 +- circuit_passes/src/passes/mod.rs | 13 +- 8 files changed, 491 insertions(+), 165 deletions(-) rename circom/tests/controlflow/{multiuse_func_with_loop.circom => multiuse_func_with_loop_diff.circom} (100%) create mode 100644 circom/tests/controlflow/multiuse_func_with_loop_same.circom diff --git a/circom/tests/controlflow/multiuse_func_with_loop.circom b/circom/tests/controlflow/multiuse_func_with_loop_diff.circom similarity index 100% rename from circom/tests/controlflow/multiuse_func_with_loop.circom rename to circom/tests/controlflow/multiuse_func_with_loop_diff.circom diff --git a/circom/tests/controlflow/multiuse_func_with_loop_same.circom b/circom/tests/controlflow/multiuse_func_with_loop_same.circom new file mode 100644 index 000000000..c08a40479 --- /dev/null +++ b/circom/tests/controlflow/multiuse_func_with_loop_same.circom @@ -0,0 +1,247 @@ +pragma circom 2.0.0; +// REQUIRES: circom +// RUN: rm -rf %t && mkdir %t && %circom --llvm -o %t %s | sed -n 's/.*Written successfully:.* \(.*\)/\1/p' | xargs cat | FileCheck %s --enable-var-scope + +// %0 = [ s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7], s[8], s[9], n, sum, i ] +function f(s, n) { + var sum = 0; + for (var i = 0; i < n; i++) { + sum += s[i]; + } + return sum; +} + +template MultiUse() { + signal input inp[10]; + signal output outp[3]; + + outp[0] <-- f(inp, 2); + outp[1] <-- f(inp, 2); + outp[2] <-- f(inp, 2); +} + +component main = MultiUse(); + +//CHECK-LABEL: define{{.*}} void @..generated..loop.body. +//CHECK-SAME: [[$F_ID_1:[0-9]+]]([0 x i256]* %lvars, [0 x i256]* %signals, i256* %var_0){{.*}} { +//CHECK-NEXT: ..generated..loop.body.[[$F_ID_1]]: +//CHECK-NEXT: br label %store1 +//CHECK-EMPTY: +//CHECK-NEXT: store1: +//CHECK-NEXT: %[[T00:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 11 +//CHECK-NEXT: %[[T01:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 11 +//CHECK-NEXT: %[[T02:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[T01]], align 4 +//CHECK-NEXT: %[[T03:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %var_0, i32 0 +//CHECK-NEXT: %[[T04:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[T03]], align 4 +//CHECK-NEXT: %[[T08:[0-9a-zA-Z_\.]+]] = call i256 @fr_add(i256 %[[T02]], i256 %[[T04]]) +//CHECK-NEXT: store i256 %[[T08]], i256* %[[T00]], align 4 +//CHECK-NEXT: br label %store2 +//CHECK-EMPTY: +//CHECK-NEXT: store2: +//CHECK-NEXT: %[[T05:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 12 +//CHECK-NEXT: %[[T06:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 12 +//CHECK-NEXT: %[[T07:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[T06]], align 4 +//CHECK-NEXT: %[[T09:[0-9a-zA-Z_\.]+]] = call i256 @fr_add(i256 %[[T07]], i256 1) +//CHECK-NEXT: store i256 %[[T09]], i256* %[[T05]], align 4 +//CHECK-NEXT: br label %return3 +//CHECK-EMPTY: +//CHECK-NEXT: return3: +//CHECK-NEXT: ret void +//CHECK-NEXT: } +// +//CHECK-LABEL: define{{.*}} i256 @f_0.2(i256* %0){{.*}} { +//CHECK-NEXT: f_0.2: +//CHECK-NEXT: br label %store1 +//CHECK-EMPTY: +//CHECK-NEXT: store1: +//CHECK-NEXT: %[[T01:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %0, i32 11 +//CHECK-NEXT: store i256 0, i256* %[[T01]], align 4 +//CHECK-NEXT: br label %store2 +//CHECK-EMPTY: +//CHECK-NEXT: store2: +//CHECK-NEXT: %[[T02:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %0, i32 12 +//CHECK-NEXT: store i256 0, i256* %[[T02]], align 4 +//CHECK-NEXT: br label %unrolled_loop3 +//CHECK-EMPTY: +//CHECK-NEXT: unrolled_loop3: +//CHECK-NEXT: %[[T03:[0-9a-zA-Z_\.]+]] = bitcast i256* %0 to [0 x i256]* +//CHECK-NEXT: %[[T04:[0-9a-zA-Z_\.]+]] = bitcast i256* %0 to [0 x i256]* +//CHECK-NEXT: %[[T05:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %[[T04]], i32 0, i256 0 +//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T03]], [0 x i256]* null, i256* %[[T05]]) +//CHECK-NEXT: %[[T06:[0-9a-zA-Z_\.]+]] = bitcast i256* %0 to [0 x i256]* +//CHECK-NEXT: %[[T07:[0-9a-zA-Z_\.]+]] = bitcast i256* %0 to [0 x i256]* +//CHECK-NEXT: %[[T08:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %[[T07]], i32 0, i256 1 +//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T06]], [0 x i256]* null, i256* %[[T08]]) +//CHECK-NEXT: br label %return4 +//CHECK-EMPTY: +//CHECK-NEXT: return4: +//CHECK-NEXT: %[[T09:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %0, i32 11 +//CHECK-NEXT: %[[T10:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[T09]], align 4 +//CHECK-NEXT: ret i256 %[[T10]] +//CHECK-NEXT: } +// +//CHECK-LABEL: define{{.*}} void @MultiUse_0_run([0 x i256]* %0){{.*}} { +//CHECK-NEXT: prelude: +//CHECK-NEXT: %lvars = alloca [0 x i256], align 8 +//CHECK-NEXT: %subcmps = alloca [0 x { [0 x i256]*, i32 }], align 8 +//CHECK-NEXT: br label %call1 +//CHECK-EMPTY: +//CHECK-NEXT: call1: +//CHECK-NEXT: %[[A01:[0-9a-zA-Z_\.]+]] = alloca [13 x i256], align 8 +//CHECK-NEXT: %[[T01:[0-9a-zA-Z_\.]+]] = getelementptr [13 x i256], [13 x i256]* %[[A01]], i32 0, i32 0 +//CHECK-NEXT: %[[T02:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 3 +//CHECK-NEXT: %[[CPY_SRC_10:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 0 +//CHECK-NEXT: %[[CPY_DST_10:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 0 +//CHECK-NEXT: %[[CPY_VAL_10:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_10]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_10]], i256* %[[CPY_DST_10]], align 4 +//CHECK-NEXT: %[[CPY_SRC_11:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 1 +//CHECK-NEXT: %[[CPY_DST_11:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 1 +//CHECK-NEXT: %[[CPY_VAL_11:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_11]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_11]], i256* %[[CPY_DST_11]], align 4 +//CHECK-NEXT: %[[CPY_SRC_12:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 2 +//CHECK-NEXT: %[[CPY_DST_12:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 2 +//CHECK-NEXT: %[[CPY_VAL_12:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_12]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_12]], i256* %[[CPY_DST_12]], align 4 +//CHECK-NEXT: %[[CPY_SRC_13:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 3 +//CHECK-NEXT: %[[CPY_DST_13:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 3 +//CHECK-NEXT: %[[CPY_VAL_13:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_13]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_13]], i256* %[[CPY_DST_13]], align 4 +//CHECK-NEXT: %[[CPY_SRC_14:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 4 +//CHECK-NEXT: %[[CPY_DST_14:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 4 +//CHECK-NEXT: %[[CPY_VAL_14:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_14]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_14]], i256* %[[CPY_DST_14]], align 4 +//CHECK-NEXT: %[[CPY_SRC_15:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 5 +//CHECK-NEXT: %[[CPY_DST_15:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 5 +//CHECK-NEXT: %[[CPY_VAL_15:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_15]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_15]], i256* %[[CPY_DST_15]], align 4 +//CHECK-NEXT: %[[CPY_SRC_16:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 6 +//CHECK-NEXT: %[[CPY_DST_16:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 6 +//CHECK-NEXT: %[[CPY_VAL_16:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_16]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_16]], i256* %[[CPY_DST_16]], align 4 +//CHECK-NEXT: %[[CPY_SRC_17:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 7 +//CHECK-NEXT: %[[CPY_DST_17:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 7 +//CHECK-NEXT: %[[CPY_VAL_17:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_17]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_17]], i256* %[[CPY_DST_17]], align 4 +//CHECK-NEXT: %[[CPY_SRC_18:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 8 +//CHECK-NEXT: %[[CPY_DST_18:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 8 +//CHECK-NEXT: %[[CPY_VAL_18:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_18]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_18]], i256* %[[CPY_DST_18]], align 4 +//CHECK-NEXT: %[[CPY_SRC_19:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T02]], i32 9 +//CHECK-NEXT: %[[CPY_DST_19:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T01]], i32 9 +//CHECK-NEXT: %[[CPY_VAL_19:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_19]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_19]], i256* %[[CPY_DST_19]], align 4 +//CHECK-NEXT: %[[T03:[0-9a-zA-Z_\.]+]] = getelementptr [13 x i256], [13 x i256]* %[[A01]], i32 0, i32 10 +//CHECK-NEXT: store i256 2, i256* %[[T03]], align 4 +//CHECK-NEXT: %[[T04:[0-9a-zA-Z_\.]+]] = bitcast [13 x i256]* %[[A01]] to i256* +//CHECK-NEXT: %[[T16:[0-9a-zA-Z_\.]+]] = call i256 @f_0.2(i256* %[[T04]]) +//CHECK-NEXT: %[[T05:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 0 +//CHECK-NEXT: store i256 %[[T16]], i256* %[[T05]], align 4 +//CHECK-NEXT: br label %call2 +//CHECK-EMPTY: +//CHECK-NEXT: call2: +//CHECK-NEXT: %[[A02:[0-9a-zA-Z_\.]+]] = alloca [13 x i256], align 8 +//CHECK-NEXT: %[[T06:[0-9a-zA-Z_\.]+]] = getelementptr [13 x i256], [13 x i256]* %[[A02]], i32 0, i32 0 +//CHECK-NEXT: %[[T07:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 3 +//CHECK-NEXT: %[[CPY_SRC_20:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 0 +//CHECK-NEXT: %[[CPY_DST_20:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 0 +//CHECK-NEXT: %[[CPY_VAL_20:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_20]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_20]], i256* %[[CPY_DST_20]], align 4 +//CHECK-NEXT: %[[CPY_SRC_21:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 1 +//CHECK-NEXT: %[[CPY_DST_21:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 1 +//CHECK-NEXT: %[[CPY_VAL_21:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_21]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_21]], i256* %[[CPY_DST_21]], align 4 +//CHECK-NEXT: %[[CPY_SRC_22:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 2 +//CHECK-NEXT: %[[CPY_DST_22:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 2 +//CHECK-NEXT: %[[CPY_VAL_22:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_22]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_22]], i256* %[[CPY_DST_22]], align 4 +//CHECK-NEXT: %[[CPY_SRC_23:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 3 +//CHECK-NEXT: %[[CPY_DST_23:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 3 +//CHECK-NEXT: %[[CPY_VAL_23:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_23]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_23]], i256* %[[CPY_DST_23]], align 4 +//CHECK-NEXT: %[[CPY_SRC_24:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 4 +//CHECK-NEXT: %[[CPY_DST_24:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 4 +//CHECK-NEXT: %[[CPY_VAL_24:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_24]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_24]], i256* %[[CPY_DST_24]], align 4 +//CHECK-NEXT: %[[CPY_SRC_25:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 5 +//CHECK-NEXT: %[[CPY_DST_25:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 5 +//CHECK-NEXT: %[[CPY_VAL_25:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_25]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_25]], i256* %[[CPY_DST_25]], align 4 +//CHECK-NEXT: %[[CPY_SRC_26:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 6 +//CHECK-NEXT: %[[CPY_DST_26:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 6 +//CHECK-NEXT: %[[CPY_VAL_26:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_26]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_26]], i256* %[[CPY_DST_26]], align 4 +//CHECK-NEXT: %[[CPY_SRC_27:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 7 +//CHECK-NEXT: %[[CPY_DST_27:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 7 +//CHECK-NEXT: %[[CPY_VAL_27:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_27]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_27]], i256* %[[CPY_DST_27]], align 4 +//CHECK-NEXT: %[[CPY_SRC_28:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 8 +//CHECK-NEXT: %[[CPY_DST_28:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 8 +//CHECK-NEXT: %[[CPY_VAL_28:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_28]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_28]], i256* %[[CPY_DST_28]], align 4 +//CHECK-NEXT: %[[CPY_SRC_29:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T07]], i32 9 +//CHECK-NEXT: %[[CPY_DST_29:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T06]], i32 9 +//CHECK-NEXT: %[[CPY_VAL_29:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_29]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_29]], i256* %[[CPY_DST_29]], align 4 +//CHECK-NEXT: %[[T08:[0-9a-zA-Z_\.]+]] = getelementptr [13 x i256], [13 x i256]* %[[A02]], i32 0, i32 10 +//CHECK-NEXT: store i256 2, i256* %[[T08]], align 4 +//CHECK-NEXT: %[[T09:[0-9a-zA-Z_\.]+]] = bitcast [13 x i256]* %[[A02]] to i256* +//CHECK-NEXT: %[[T17:[0-9a-zA-Z_\.]+]] = call i256 @f_0.2(i256* %[[T09]]) +//CHECK-NEXT: %[[T10:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 1 +//CHECK-NEXT: store i256 %[[T17]], i256* %[[T10]], align 4 +//CHECK-NEXT: br label %call3 +//CHECK-EMPTY: +//CHECK-NEXT: call3: +//CHECK-NEXT: %[[A03:[0-9a-zA-Z_\.]+]] = alloca [13 x i256], align 8 +//CHECK-NEXT: %[[T11:[0-9a-zA-Z_\.]+]] = getelementptr [13 x i256], [13 x i256]* %[[A03]], i32 0, i32 0 +//CHECK-NEXT: %[[T12:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 3 +//CHECK-NEXT: %[[CPY_SRC_30:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 0 +//CHECK-NEXT: %[[CPY_DST_30:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 0 +//CHECK-NEXT: %[[CPY_VAL_30:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_30]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_30]], i256* %[[CPY_DST_30]], align 4 +//CHECK-NEXT: %[[CPY_SRC_31:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 1 +//CHECK-NEXT: %[[CPY_DST_31:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 1 +//CHECK-NEXT: %[[CPY_VAL_31:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_31]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_31]], i256* %[[CPY_DST_31]], align 4 +//CHECK-NEXT: %[[CPY_SRC_32:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 2 +//CHECK-NEXT: %[[CPY_DST_32:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 2 +//CHECK-NEXT: %[[CPY_VAL_32:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_32]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_32]], i256* %[[CPY_DST_32]], align 4 +//CHECK-NEXT: %[[CPY_SRC_33:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 3 +//CHECK-NEXT: %[[CPY_DST_33:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 3 +//CHECK-NEXT: %[[CPY_VAL_33:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_33]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_33]], i256* %[[CPY_DST_33]], align 4 +//CHECK-NEXT: %[[CPY_SRC_34:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 4 +//CHECK-NEXT: %[[CPY_DST_34:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 4 +//CHECK-NEXT: %[[CPY_VAL_34:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_34]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_34]], i256* %[[CPY_DST_34]], align 4 +//CHECK-NEXT: %[[CPY_SRC_35:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 5 +//CHECK-NEXT: %[[CPY_DST_35:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 5 +//CHECK-NEXT: %[[CPY_VAL_35:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_35]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_35]], i256* %[[CPY_DST_35]], align 4 +//CHECK-NEXT: %[[CPY_SRC_36:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 6 +//CHECK-NEXT: %[[CPY_DST_36:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 6 +//CHECK-NEXT: %[[CPY_VAL_36:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_36]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_36]], i256* %[[CPY_DST_36]], align 4 +//CHECK-NEXT: %[[CPY_SRC_37:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 7 +//CHECK-NEXT: %[[CPY_DST_37:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 7 +//CHECK-NEXT: %[[CPY_VAL_37:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_37]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_37]], i256* %[[CPY_DST_37]], align 4 +//CHECK-NEXT: %[[CPY_SRC_38:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 8 +//CHECK-NEXT: %[[CPY_DST_38:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 8 +//CHECK-NEXT: %[[CPY_VAL_38:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_38]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_38]], i256* %[[CPY_DST_38]], align 4 +//CHECK-NEXT: %[[CPY_SRC_39:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T12]], i32 9 +//CHECK-NEXT: %[[CPY_DST_39:[0-9a-zA-Z_\.]+]] = getelementptr i256, i256* %[[T11]], i32 9 +//CHECK-NEXT: %[[CPY_VAL_39:[0-9a-zA-Z_\.]+]] = load i256, i256* %[[CPY_SRC_39]], align 4 +//CHECK-NEXT: store i256 %[[CPY_VAL_39]], i256* %[[CPY_DST_39]], align 4 +//CHECK-NEXT: %[[T13:[0-9a-zA-Z_\.]+]] = getelementptr [13 x i256], [13 x i256]* %[[A03]], i32 0, i32 10 +//CHECK-NEXT: store i256 2, i256* %[[T13]], align 4 +//CHECK-NEXT: %[[T14:[0-9a-zA-Z_\.]+]] = bitcast [13 x i256]* %[[A03]] to i256* +//CHECK-NEXT: %[[T18:[0-9a-zA-Z_\.]+]] = call i256 @f_0.2(i256* %[[T14]]) +//CHECK-NEXT: %[[T15:[0-9a-zA-Z_\.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i32 2 +//CHECK-NEXT: store i256 %[[T18]], i256* %[[T15]], align 4 +//CHECK-NEXT: br label %prologue +//CHECK-EMPTY: +//CHECK-NEXT: prologue: +//CHECK-NEXT: ret void +//CHECK-NEXT: } diff --git a/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs b/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs index 7a61a32da..75d7d6ff2 100644 --- a/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs +++ b/circuit_passes/src/bucket_interpreter/env/unrolled_block_env.rs @@ -32,11 +32,7 @@ impl Display for UnrolledBlockEnvData<'_> { impl LibraryAccess for UnrolledBlockEnvData<'_> { fn get_function(&self, name: &String) -> Ref { if name.starts_with(LOOP_BODY_FN_PREFIX) { - Ref::map(self.extractor.get_new_functions(), |f| { - f.iter() - .find(|f| f.header.eq(name)) - .expect("Cannot find extracted function definition!") - }) + self.extractor.search_new_functions(name) } else { self.base.get_function(name) } diff --git a/circuit_passes/src/passes/checks.rs b/circuit_passes/src/passes/checks.rs index e4906872d..69de8c718 100644 --- a/circuit_passes/src/passes/checks.rs +++ b/circuit_passes/src/passes/checks.rs @@ -153,6 +153,7 @@ macro_rules! checked_insert { ($map: expr, $key: expr, $val: expr) => {{ let key = $key; let val = $val; + #[allow(unused_mut)] // some callers may already pass a &mut and that causes warning here let mut map = $map; assert!( !map.contains_key(&key) || map[&key] == val, diff --git a/circuit_passes/src/passes/loop_unroll/body_extractor.rs b/circuit_passes/src/passes/loop_unroll/body_extractor.rs index 308cd9e74..67cfc626e 100644 --- a/circuit_passes/src/passes/loop_unroll/body_extractor.rs +++ b/circuit_passes/src/passes/loop_unroll/body_extractor.rs @@ -1,5 +1,7 @@ -use std::cell::{RefCell, Ref}; +use std::cell::{Ref, RefCell}; +use std::cmp::Ordering; use std::collections::{BTreeMap, HashMap, HashSet, BTreeSet}; +use std::rc::Rc; use std::vec; use indexmap::{IndexMap, IndexSet}; use code_producers::llvm_elements::fr::*; @@ -10,16 +12,24 @@ use compiler::intermediate_representation::ir_interface::*; use crate::bucket_interpreter::env::EnvContextKind; use crate::bucket_interpreter::error::BadInterp; use crate::bucket_interpreter::value::Value; -use crate::passes::loop_unroll::{DEBUG_LOOP_UNROLL, LOOP_BODY_FN_PREFIX}; -use crate::passes::loop_unroll::extracted_location_updater::ExtractedFunctionLocationUpdater; -use crate::passes::loop_unroll::loop_env_recorder::EnvRecorder; +use crate::checked_insert; use crate::passes::{builders, checks}; +use super::{DEBUG_LOOP_UNROLL, LOOP_BODY_FN_PREFIX}; +use super::extracted_location_updater::ExtractedFunctionLocationUpdater; +use super::loop_env_recorder::EnvRecorder; pub type FuncArgIdx = usize; pub type AddressOffset = usize; pub type UnrolledIterLvars = BTreeMap; pub type ToOriginalLocation = HashMap; +/// Table structure indexed first by load/store/call BucketId, then by iteration number +/// (i.e. the Vec index), containing the original memory references to use as arguments +/// when calling the extracted body function. +// NOTE: This collection and several intermediate collections that are used to build it +// must use IndexMap/IndexSet to preserve insertion order to stabilize lit test output. +type MemRefsPerIter = IndexMap>>; + #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] pub enum ArgIndex { Signal(FuncArgIdx, &'static str), @@ -39,7 +49,7 @@ impl ArgIndex { /// Also, the input/output stuff doesn't matter since the extra arguments that are added /// based on this are only used to trigger generation of the run function after all of /// the inputs have been assigned. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Debug, Eq, PartialEq)] struct SubcmpSignalCompare { cmp_address_parse_as: ValueType, cmp_address_op_aux_no: usize, @@ -74,95 +84,164 @@ impl SubcmpSignalCompare { } } -struct ExtraArgsResult { - // NOTE: This collection and several intermediate collections that are used to build this - // one must use IndexMap/IndexSet to preserve insertion order to stabilize lit test output. - bucket_to_itr_to_ref: HashMap>>, - bucket_to_args: IndexMap, +#[derive(Debug, Eq, PartialEq)] +struct ArgInfo { + // NOTE: This collection and several intermediate collections that are used to build it + // must use IndexMap/IndexSet to preserve insertion order to stabilize lit test output. + loc_to_args: IndexMap, num_args: usize, } -impl ExtraArgsResult { - fn get_passing_refs_for_itr( +impl PartialOrd for ArgInfo { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ArgInfo { + fn cmp(&self, other: &Self) -> Ordering { + self.num_args + .cmp(&other.num_args) + .then_with(|| self.loc_to_args.iter().cmp(other.loc_to_args.iter())) + } +} + +impl ArgInfo { + #[inline] + fn get_passing_refs_for_itr<'a>( &self, + mem_refs: &'a MemRefsPerIter, iter_num: usize, - ) -> Vec<(&Option<(AddressType, AddressOffset)>, ArgIndex)> { - self.bucket_to_itr_to_ref - .iter() - .map(|(k, v)| (&v[iter_num], self.bucket_to_args[k])) - .collect() + ) -> Vec<(&'a Option<(AddressType, AddressOffset)>, ArgIndex)> { + mem_refs.iter().map(|(k, v)| (&v[iter_num], self.loc_to_args[k])).collect() } fn get_reverse_passing_refs_for_itr( &self, + mem_refs: &MemRefsPerIter, iter_num: usize, ) -> (ToOriginalLocation, HashSet) { - self.bucket_to_itr_to_ref.iter().fold( - (ToOriginalLocation::new(), HashSet::new()), - |mut acc, (k, v)| { - if let Some((addr_ty, addr_offset)) = v[iter_num].as_ref() { - let ai = self.bucket_to_args[k]; - acc.0.insert(ai.get_signal_idx(), (addr_ty.clone(), *addr_offset)); - // If applicable, insert the subcmp counter reference as well - if let ArgIndex::SubCmp { counter, arena, .. } = ai { - match addr_ty { - AddressType::SubcmpSignal { counter_override, cmp_address, .. } => { - assert_eq!(*counter_override, false); //there's no counter for a counter - let counter_addr_ty = AddressType::SubcmpSignal { - cmp_address: cmp_address.clone(), - uniform_parallel_value: None, - is_output: false, - input_information: InputInformation::NoInput, - counter_override: true, - }; - // NOTE: when there's a true subcomponent (indicated by the ArgIndex::SubCmp check above), - // the 'addr_offset' indicates which signal inside the subcomponent is accessed. That - // value is not relevant here because subcomponents have a single counter variable. - acc.0.insert(counter, (counter_addr_ty, 0)); - // - acc.1.insert(arena); - } - _ => unreachable!(), // SubcmpSignal was created for all of these refs + mem_refs.iter().fold((ToOriginalLocation::new(), HashSet::new()), |mut acc, (k, v)| { + if let Some((addr_ty, addr_offset)) = v[iter_num].as_ref() { + let ai = self.loc_to_args[k]; + acc.0.insert(ai.get_signal_idx(), (addr_ty.clone(), *addr_offset)); + // If applicable, insert the index for the subcmp counter and arena parameters + if let ArgIndex::SubCmp { counter, arena, .. } = ai { + match addr_ty { + AddressType::SubcmpSignal { counter_override, cmp_address, .. } => { + assert_eq!(*counter_override, false); //there's no counter for a counter + let counter_addr_ty = AddressType::SubcmpSignal { + cmp_address: cmp_address.clone(), + uniform_parallel_value: None, + is_output: false, + input_information: InputInformation::NoInput, + counter_override: true, + }; + // NOTE: when there's a true subcomponent (indicated by the ArgIndex::SubCmp check above), + // the 'addr_offset' indicates which signal inside the subcomponent is accessed. That + // value is not relevant here because subcomponents have a single counter variable. + acc.0.insert(counter, (counter_addr_ty, 0)); + // + acc.1.insert(arena); } + _ => unreachable!(), // SubcmpSignal was created for all of these refs } } - acc - }, - ) + } + acc + }) } } -#[derive(Clone, Debug, Eq, PartialEq, Default)] +/// Collection of data that determines if a new unique extracted +/// body function should be created for a specific LoopBucket. +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd)] +struct UniqueFuncKey { + loop_bucket_id: BucketId, + extra_arg_info: Rc, +} + +#[derive(Debug, Default)] pub struct LoopBodyExtractor { - new_body_functions: RefCell>, + new_body_functions: RefCell>, + /// Exists only to stabilize function order for lit test output. + func_creation_order: RefCell>, } impl LoopBodyExtractor { - fn new_filled_vec(new_len: usize, value: T) -> Vec { - let mut result = Vec::with_capacity(new_len); - result.resize(new_len, value); - result + #[inline] + pub fn search_new_functions(&self, name: &String) -> Ref { + Ref::map(self.new_body_functions.borrow(), |m| { + m.values() + .find(|f| f.header.eq(name)) + .expect("Cannot find extracted function definition!") + }) } - pub fn get_new_functions(&self) -> Ref> { - self.new_body_functions.borrow() + pub fn take_new_functions(&self) -> impl ExactSizeIterator { + // NOTE: Ordering is only to stabilize lit test output. Otherwise, this could + // be implemented as `self.new_body_functions.take().into_values()`. + let mut ret: Vec = self.new_body_functions.take().into_values().collect(); + let index_map: HashMap = + self.func_creation_order.take().into_iter().enumerate().map(|(i, t)| (t, i)).collect(); + ret.sort_by_key(|f| index_map[&f.header]); + ret.into_iter() } pub fn extract<'a>( &self, bucket: &LoopBucket, recorder: EnvRecorder<'a, '_>, - context_kind: EnvContextKind, unrolled: &mut InstructionList, ) -> Result<(), BadInterp> { assert!(bucket.body.len() > 1); - let extra_arg_info = Self::compute_extra_args(&recorder, context_kind)?; - let name = self.build_new_body( - &recorder.get_current_scope_name(), - bucket, - extra_arg_info.bucket_to_args.clone(), - extra_arg_info.num_args, - ); + + // Check if an applicable function already exists, otherwise create a new extracted loop body function. + let (arg_info, mem_refs, extracted_name) = { + let (arg_info, mem_refs) = Self::compute_extra_args(&recorder)?; + let arg_info = Rc::new(arg_info); + let existing_fun = self.get_extracted_function_name(bucket.id, Rc::clone(&arg_info)); + match existing_fun { + None => { + // Create and store the new function + let new_func = self.build_new_extracted_function( + &recorder.get_current_scope_name(), + bucket, + &arg_info.loc_to_args, + arg_info.num_args, + ); + let new_name = new_func.header.clone(); + self.store_new_extracted_function(bucket.id, Rc::clone(&arg_info), new_func); + + (arg_info, mem_refs, new_name) + } + Some(name) => (arg_info, mem_refs, name.clone()), + } + }; + + // Store the parameter information for the function to GlobalPassData based on current Env + { + let mut gd = recorder.global_data.borrow_mut(); + let extraction_data = + gd.extract_func_orig_loc.entry(extracted_name.clone()).or_default(); + for iter_num in 0..recorder.get_iter() { + let iter_env = recorder.take_header_vars_for_iter(&iter_num); + let mapping = arg_info.get_reverse_passing_refs_for_itr(&mem_refs, iter_num); + if DEBUG_LOOP_UNROLL { + println!( + "[extract] storing orig loc data for: {}+{:?} -> {:?}", + extracted_name, iter_env, mapping + ); + } + // NOTE: Encountering different iteration counts for the same loop will produce + // different Env at loop header (i.e. `iter_env`) and the same Env at the loop + // header will produce the same reverse mapping. Thus, no key conflicts. + checked_insert!(&mut *extraction_data, iter_env, mapping); + } + } + + // Generate the list of unrolled calls to the extracted loop body function. for iter_num in 0..recorder.get_iter() { // NOTE: CallBucket arguments must use a LoadBucket to reference the necessary pointers // within the current body. However, it doesn't actually need to generate a load @@ -170,19 +249,19 @@ impl LoopBodyExtractor { // `bounded_fn` field of the LoadBucket to specify the identity function to perform // the "loading" (but really it just returns the pointer that was passed in). let mut args = Self::new_filled_vec( - extra_arg_info.num_args, + arg_info.num_args, NopBucket { id: 0 }.allocate(), // garbage fill ); // Parameter for local vars args[0] = builders::build_storage_ptr_ref(bucket, AddressType::Variable); // Parameter for signals/arena, not needed when unrolling w/in a circom source function - args[1] = if context_kind == EnvContextKind::SourceFunction { + args[1] = if recorder.ctx_kind == EnvContextKind::SourceFunction { builders::build_null_ptr(bucket, FR_NULL_I256_ARR_PTR) } else { builders::build_storage_ptr_ref(bucket, AddressType::Signal) }; // Additional parameters for subcmps and variant array indexing within the loop - let mut passing_refs = extra_arg_info.get_passing_refs_for_itr(iter_num); + let mut passing_refs = arg_info.get_passing_refs_for_itr(&mem_refs, iter_num); // Sort by the Option to ensure None comes first so that the value for a Some entry that uses the same // 'arena' and 'counter' as a None entry will be preserved, replacing the 'null' for the None entry. passing_refs.sort_by(|(a, _), (b, _)| a.cmp(b)); @@ -223,25 +302,22 @@ impl LoopBodyExtractor { }, } } - unrolled.push(builders::build_call(bucket, &name, args)); - - recorder.record_reverse_arg_mapping( - name.clone(), - recorder.take_header_vars_for_iter(&iter_num), - extra_arg_info.get_reverse_passing_refs_for_itr(iter_num), - ); + unrolled.push(builders::build_call(bucket, &extracted_name, args)); } + Ok(()) } - fn build_new_body( + /// The new function's name is in `FunctionCode.header` + fn build_new_extracted_function( &self, - source_body_name: &String, - bucket: &LoopBucket, - mut bucket_to_args: IndexMap, + source_fun_name: &String, + source_bucket: &LoopBucket, + loc_to_args: &IndexMap, num_args: usize, - ) -> String { - // NOTE: must create parameter list before 'bucket_to_args' is modified + ) -> FunctionCode { + let mut loc_to_args = loc_to_args.clone(); + // NOTE: must create parameter list before 'loc_to_args' is modified // Since the ArgIndex instances could have indices in any random order, // create the vector of required size and then set elements by index. let mut params = Self::new_filled_vec( @@ -250,7 +326,7 @@ impl LoopBodyExtractor { ); params[0] = Param { name: String::from("lvars"), length: vec![0] }; params[1] = Param { name: String::from("signals"), length: vec![0] }; - for (i, arg_index) in bucket_to_args.values().enumerate() { + for (i, arg_index) in loc_to_args.values().enumerate() { match arg_index { ArgIndex::Signal(signal, prefix) => { //Single signal uses scalar pointer @@ -267,16 +343,16 @@ impl LoopBodyExtractor { // Copy loop body and add a "return void" at the end let mut new_body = vec![]; - for s in &bucket.body { + for s in &source_bucket.body { let mut copy = vec![s.clone()]; //Traverse each cloned statement before calling `update_id()` and replace the // old location reference with reference to the proper argument. Mappings are // removed as they are processed so no change is needed once the map is empty. // Also retrieve the list of statements that were generated to be inserted // after the current statement and insert them after the updated statement. - //NOTE: nothing will be updated or added if 'bucket_to_args' is empty so skip. - if !bucket_to_args.is_empty() { - let mut upd = ExtractedFunctionLocationUpdater::new(&mut bucket_to_args); + //NOTE: nothing will be updated or added if 'loc_to_args' is empty so skip. + if !loc_to_args.is_empty() { + let mut upd = ExtractedFunctionLocationUpdater::new(&mut loc_to_args); upd.check_instructions(&mut copy); } for mut s in copy.drain(..) { @@ -284,32 +360,31 @@ impl LoopBodyExtractor { new_body.push(s); } } - assert!(bucket_to_args.is_empty()); - new_body.push(builders::build_void_return(bucket)); + assert!(loc_to_args.is_empty()); + new_body.push(builders::build_void_return(source_bucket)); // Create new function to hold the copied body // NOTE: This name must start with `GENERATED_FN_PREFIX` (which is the prefix // of `LOOP_BODY_FN_PREFIX`) so that `ExtractedFunctionCtx` will be used. let func_name = format!("{}{}", LOOP_BODY_FN_PREFIX, new_id()); + if DEBUG_LOOP_UNROLL { + println!("[BodyExtractor] created function {}", func_name); + } let new_func = Box::new(FunctionCodeInfo { - source_file_id: bucket.source_file_id, - line: bucket.line, - name: source_body_name.clone(), - header: func_name.clone(), + source_file_id: source_bucket.source_file_id, + line: source_bucket.line, + name: source_fun_name.clone(), + header: func_name, body: new_body, params, returns: vec![], // void return type on the function ..FunctionCodeInfo::default() }); - // Store the function to be transformed and added to circuit later - self.new_body_functions.borrow_mut().push(new_func); - if DEBUG_LOOP_UNROLL { - println!("[BodyExtractor] created function {}", func_name); - } - func_name + new_func } /// Create an Iterator containing the results of applying the given /// function to only the `Some` entries in the given vector. + #[inline] fn filter_map<'a, A, B, C>( column: &'a Vec>, f: impl FnMut(&(A, B)) -> C + 'a, @@ -319,6 +394,7 @@ impl LoopBodyExtractor { /// Create an Iterator containing the results of applying the given /// function to only the `Some` entries in the given vector. + #[inline] fn filter_map_any(column: &Vec>, f: impl FnMut(&(A, B)) -> bool) -> bool { column.iter().filter_map(|x| x.as_ref()).any(f) } @@ -347,6 +423,7 @@ impl LoopBodyExtractor { /// Test for true equality when both parameters are Some, otherwise return true when either is None. /// The first value in the tuple is the result and the second is `true` for a "fuzzy" None equality. + #[inline] fn fuzzy_equals(a: &Option, b: &Option) -> (bool, bool) { match (a, b) { (Some(x), Some(y)) => (x == y, false), @@ -476,8 +553,7 @@ impl LoopBodyExtractor { /// extra arguments that will be needed. fn compute_extra_args<'a>( recorder: &EnvRecorder<'a, '_>, - context_kind: EnvContextKind, - ) -> Result { + ) -> Result<(ArgInfo, MemRefsPerIter), BadInterp> { // Table structure indexed first by load/store/call BucketId, then by iteration number. // View the first (BucketId) as columns and the second (iteration number) as rows. // The data reference is wrapped in Option to allow for some iterations that don't @@ -488,20 +564,17 @@ impl LoopBodyExtractor { // that do not execute that specific bucket. This is the reason it was important to // store Unknown values in the `loadstore_to_index` index as well, so they are not // confused with values that simply don't exist. - let mut bucket_to_itr_to_ref: IndexMap< - BucketId, - Vec>, - > = Default::default(); + let mut loc_to_itr_to_ref: MemRefsPerIter = Default::default(); // - let mut bucket_to_args: IndexMap = Default::default(); + let mut loc_to_args: IndexMap = Default::default(); let mut vpi = recorder.take_loadstore_to_index_map(); // NOTE: starts at 2 because the current component's signal arena and lvars are first. let mut next_idx: FuncArgIdx = 2; - // First step is to collect all location references into the 'bucket_to_itr_to_ref' table. + // First step is to collect all location references into the 'loc_to_itr_to_ref' table. let all_loadstore_buckets: IndexSet = vpi.values().flat_map(|x| x.keys().cloned()).collect(); for id in all_loadstore_buckets.iter() { - let column = bucket_to_itr_to_ref.entry(*id).or_default(); + let column = loc_to_itr_to_ref.entry(*id).or_default(); for iter_num in 0..recorder.get_iter() { column.push(match vpi.get_mut(&iter_num).unwrap().shift_remove(id) { None => None, @@ -536,22 +609,22 @@ impl LoopBodyExtractor { AddressType::Signal => "sig", AddressType::SubcmpSignal { .. } => "subsig", }; - bucket_to_args.insert(*id, ArgIndex::Signal(next_idx, prefix)); + loc_to_args.insert(*id, ArgIndex::Signal(next_idx, prefix)); next_idx += 1; } } } if DEBUG_LOOP_UNROLL { - println!("bucket_to_args = {:?}", bucket_to_args); - println!("bucket_to_itr_to_ref = {:?}", bucket_to_itr_to_ref); + println!("loc_to_args = {:?}", loc_to_args); + println!("loc_to_itr_to_ref = {:?}", loc_to_itr_to_ref); println!("all_loadstore_bucket_ids = {:?}", all_loadstore_buckets); } //ASSERT: All columns have the same length (i.e. the number of iterations) - assert!(bucket_to_itr_to_ref.values().all(|c| c.len() == recorder.get_iter())); - //ASSERT: 'bucket_to_itr_to_ref.keys' is equal to 'all_loadstore_bucket_ids' - assert!(checks::contains_same(&all_loadstore_buckets, bucket_to_itr_to_ref.keys())); - //ASSERT: 'bucket_to_args.keys' is a subset of 'all_loadstore_bucket_ids' - assert!(checks::contains_all(&all_loadstore_buckets, bucket_to_args.keys())); + assert!(loc_to_itr_to_ref.values().all(|c| c.len() == recorder.get_iter())); + //ASSERT: 'loc_to_itr_to_ref.keys' is equal to 'all_loadstore_bucket_ids' + assert!(checks::contains_same(&all_loadstore_buckets, loc_to_itr_to_ref.keys())); + //ASSERT: 'loc_to_args.keys' is a subset of 'all_loadstore_bucket_ids' + assert!(checks::contains_all(&all_loadstore_buckets, loc_to_args.keys())); // Also, if it's a subcomponent reference, then extra arguments are needed for it's // signal arena and counter (because the entire subcomponent storage pointer is not @@ -563,8 +636,8 @@ impl LoopBodyExtractor { // are unused in some iteration(s) (indicated by None), they can be included in a // group if they are part of that same group in all iterations where present. // However, this should not be done if already within an extracted function body. - if context_kind != EnvContextKind::ExtractedFunction { - let x: Vec<(BucketId, Vec>)> = bucket_to_itr_to_ref + if recorder.ctx_kind != EnvContextKind::ExtractedFunction { + let x: Vec<(BucketId, Vec>)> = loc_to_itr_to_ref .iter() .filter_map(|(b, col)| { //if the iteration does not contain any Some(SubCmp), then we return None @@ -596,7 +669,7 @@ impl LoopBodyExtractor { println!("subcmp_arg_groups = {:?}", subcmp_arg_groups); } //ASSERT: Every bucket mapped to a Some(SubcmpSignal) in any iteration is present in exactly one group. - assert!(bucket_to_itr_to_ref + debug_assert!(loc_to_itr_to_ref .iter() .filter_map(|(k, v)| { if v.iter().any(|e| matches!(e, Some((AddressType::SubcmpSignal { .. }, _)))) { @@ -613,7 +686,7 @@ impl LoopBodyExtractor { let counter_idx: FuncArgIdx = next_idx + 1; next_idx += 2; for b in buckets { - bucket_to_args.entry(*b).and_modify(|e| { + loc_to_args.entry(*b).and_modify(|e| { if let ArgIndex::Signal(sig, _) = e { *e = ArgIndex::SubCmp { signal: *sig, @@ -631,12 +704,37 @@ impl LoopBodyExtractor { } //Keep only the table columns where extra parameters are necessary - bucket_to_itr_to_ref.retain(|k, _| bucket_to_args.contains_key(k)); - Ok(ExtraArgsResult { - bucket_to_itr_to_ref: bucket_to_itr_to_ref.into_iter().collect(), - bucket_to_args, - num_args: next_idx, - }) + loc_to_itr_to_ref.retain(|k, _| loc_to_args.contains_key(k)); + Ok((ArgInfo { loc_to_args, num_args: next_idx }, loc_to_itr_to_ref.into_iter().collect())) + } + + #[inline] + fn new_filled_vec(new_len: usize, value: T) -> Vec { + let mut result = Vec::with_capacity(new_len); + result.resize(new_len, value); + result + } + + /// Store the function to be transformed and added to the circuit later. + fn store_new_extracted_function( + &self, + loop_bucket_id: BucketId, + extra_arg_info: Rc, + new_func: FunctionCode, + ) { + self.func_creation_order.borrow_mut().push(new_func.header.clone()); + let key = UniqueFuncKey { loop_bucket_id, extra_arg_info }; + checked_insert!(self.new_body_functions.borrow_mut(), key, new_func); + } + + /// Check if an extracted function exists for the given key information and return its header name. + fn get_extracted_function_name( + &self, + loop_bucket_id: BucketId, + extra_arg_info: Rc, + ) -> Option> { + let key = UniqueFuncKey { loop_bucket_id, extra_arg_info }; + Ref::filter_map(self.new_body_functions.borrow(), |m| m.get(&key).map(|f| &f.header)).ok() } } diff --git a/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs b/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs index 2c0cf8658..29ebcbcc5 100644 --- a/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs +++ b/circuit_passes/src/passes/loop_unroll/loop_env_recorder.rs @@ -1,21 +1,21 @@ use std::cell::{RefCell, Ref}; -use std::collections::{BTreeMap, HashSet, HashMap}; +use std::collections::{BTreeMap, HashMap}; use std::fmt::{Debug, Formatter}; use indexmap::IndexMap; use compiler::intermediate_representation::BucketId; use compiler::intermediate_representation::ir_interface::*; -use crate::bucket_interpreter::env::Env; +use crate::bucket_interpreter::env::{Env, EnvContextKind}; use crate::bucket_interpreter::error::BadInterp; use crate::bucket_interpreter::memory::PassMemory; use crate::bucket_interpreter::observer::Observer; use crate::bucket_interpreter::value::Value; use crate::passes::GlobalPassData; use super::DEBUG_LOOP_UNROLL; -use super::body_extractor::{UnrolledIterLvars, ToOriginalLocation, FuncArgIdx}; pub struct EnvRecorder<'a, 'd> { - global_data: &'d RefCell, + pub global_data: &'d RefCell, mem: &'a PassMemory, + pub(crate) ctx_kind: EnvContextKind, // NOTE: RefCell is needed here because the instance of this struct is borrowed by // the main interpreter while we also need to mutate these internal structures. current_iter_num: RefCell, @@ -55,10 +55,15 @@ impl Debug for EnvRecorder<'_, '_> { } impl<'a, 'd> EnvRecorder<'a, 'd> { - pub fn new(global_data: &'d RefCell, mem: &'a PassMemory) -> Self { + pub fn new( + global_data: &'d RefCell, + mem: &'a PassMemory, + ctx_kind: EnvContextKind, + ) -> Self { EnvRecorder { global_data, mem, + ctx_kind, current_iter_num: RefCell::new(0), safe_to_move: RefCell::new(true), loadstore_to_index_per_iter: Default::default(), @@ -110,23 +115,6 @@ impl<'a, 'd> EnvRecorder<'a, 'd> { self.env_at_header.replace(None); } - pub fn record_reverse_arg_mapping( - &self, - extract_func: String, - iter_env: UnrolledIterLvars, - value: (ToOriginalLocation, HashSet), - ) { - if DEBUG_LOOP_UNROLL { - println!("[EnvRecorder] stored data {:?} -> {:?}", iter_env, value); - } - self.global_data - .borrow_mut() - .extract_func_orig_loc - .entry(extract_func) - .or_default() - .insert(iter_env, value); - } - #[inline] fn default_return(&self) -> Result { Ok(self.is_safe_to_move()) //continue observing unless something unsafe has been found diff --git a/circuit_passes/src/passes/loop_unroll/mod.rs b/circuit_passes/src/passes/loop_unroll/mod.rs index 20f33201a..e2acf8f08 100644 --- a/circuit_passes/src/passes/loop_unroll/mod.rs +++ b/circuit_passes/src/passes/loop_unroll/mod.rs @@ -91,7 +91,7 @@ impl<'d> LoopUnrollPass<'d> { println!("[UNROLL] LOOP ENTRY env {}", env); } // Compute loop iteration count. If unknown, return immediately. - let recorder = EnvRecorder::new(self.global_data, &self.memory); + let recorder = EnvRecorder::new(self.global_data, &self.memory, env.get_context_kind()); { let interpreter = self.memory.build_interpreter(self.global_data, &recorder); let mut inner_env = env.clone(); @@ -151,12 +151,7 @@ impl<'d> LoopUnrollPass<'d> { if DEBUG_LOOP_UNROLL { println!("[UNROLL][try_unroll_loop] OUTCOME: safe to move, extracting"); } - self.extractor.extract( - bucket, - recorder, - env.get_context_kind(), - &mut block_body, - )?; + self.extractor.extract(bucket, recorder, &mut block_body)?; } } } else { @@ -193,7 +188,7 @@ impl Observer> for LoopUnrollPass<'_> { if DEBUG_LOOP_UNROLL { println!("[UNROLL][try_unroll_loop] result = {:?}", result); } - // Add the loop bucket to the ordering for the before visiting within via continue_inside() + // Add the loop bucket to the ordering before visiting within via continue_inside() // so that outer loop iteration counts appear first in the new function name self.loop_bucket_order.borrow_mut().insert(bucket.id); // @@ -251,9 +246,9 @@ impl CircuitTransformationPass for LoopUnrollPass<'_> { fn post_hook_circuit(&self, cir: &mut Circuit) -> Result<(), BadInterp> { // Transform and add the new body functions from the extractor - let new_funcs = self.extractor.get_new_functions(); + let new_funcs = self.extractor.take_new_functions(); cir.functions.reserve_exact(new_funcs.len()); - for f in new_funcs.iter() { + for f in new_funcs { cir.functions.insert(0, self.transform_function(&f)?); } // Add the duplicated versions of functions created by transform_call_bucket() diff --git a/circuit_passes/src/passes/mod.rs b/circuit_passes/src/passes/mod.rs index ec5dbdb1e..b70b985e2 100644 --- a/circuit_passes/src/passes/mod.rs +++ b/circuit_passes/src/passes/mod.rs @@ -766,15 +766,16 @@ pub enum PassKind { MappedToIndexed, UnknownIndexSanitization, } +/// Maps UnrolledIterLvars (from Env::get_vars_sort) to a pair containing: +/// (1) location references from the original function, used by ExtractedFuncEnvData to +/// access the original function's Env via the extracted function's parameter references +/// (2) the set of parameters that contain subcomponent arenas +pub type ExtractedFuncData = BTreeMap)>; #[derive(Debug)] pub struct GlobalPassData { - /// Created during loop unrolling, maps generated function name + UnrolledIterLvars - /// (from Env::get_vars_sort) to location reference in the original function. Used - /// by ExtractedFuncEnvData to access the original function's Env via the extracted - /// function's parameter references. - extract_func_orig_loc: - HashMap)>>, + /// Created during loop unrolling, maps generated function name to ExtractedFuncData for it. + extract_func_orig_loc: HashMap, } impl GlobalPassData {