
heir-simd-vectorizer: dot product example for ckks incorrectly transformed #1115

ZenithalHourlyRate opened this issue Nov 23, 2024 · 6 comments


@ZenithalHourlyRate
Contributor

ZenithalHourlyRate commented Nov 23, 2024

I tried to migrate the bgv dot product example to CKKS, and the result is incorrect. After inspection, it seems that the transformation performed by heir-simd-vectorizer is incorrect.

The input to the pipeline is below; it is just the bgv example with every i16 substituted with f16.

func.func @dot_product(%arg0: tensor<8xf16>, %arg1: tensor<8xf16>) -> f16 {
  %c0 = arith.constant 0 : index
  %c0_sf16 = arith.constant 0.0 : f16
  %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_sf16) -> (f16) {
    %1 = tensor.extract %arg0[%arg2] : tensor<8xf16>
    %2 = tensor.extract %arg1[%arg2] : tensor<8xf16>
    %3 = arith.mulf %1, %2 : f16
    %4 = arith.addf %iter, %3 : f16
    affine.yield %4 : f16
  }
  return %0 : f16
}
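For reference, a quick plaintext sketch of what the loop above computes (an ordinary dot product), using the test input that appears later in this issue:

```python
# Plaintext semantics of the affine.for loop above: an ordinary dot product.
x = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
y = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
acc = 0.0                     # %c0_sf16, the iter_args initial value
for i in range(8):            # affine.for %arg2 = 0 to 8
    acc += x[i] * y[i]        # tensor.extract, arith.mulf, arith.addf
print(acc)  # 60.0
```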

After running --mlir-to-secret-arithmetic="entry-function=dot_product", we get

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xf16>>, %arg1: !secret.secret<tensor<8xf16>>) -> !secret.secret<f16> {
    %c6 = arith.constant 6 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<8xf16>
    %c7 = arith.constant 7 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xf16>>, !secret.secret<tensor<8xf16>>) {
    ^bb0(%arg2: tensor<8xf16>, %arg3: tensor<8xf16>):
      %1 = arith.mulf %arg2, %arg3 : tensor<8xf16>
      %2 = arith.addf %1, %cst : tensor<8xf16>
      %3 = tensor_ext.rotate %2, %c6 : tensor<8xf16>, index
      %4 = tensor_ext.rotate %1, %c7 : tensor<8xf16>, index
      %5 = arith.addf %3, %4 : tensor<8xf16>
      %6 = arith.addf %5, %1 : tensor<8xf16>
      %7 = tensor_ext.rotate %6, %c6 : tensor<8xf16>, index
      %8 = arith.addf %7, %4 : tensor<8xf16>
      %9 = arith.addf %8, %1 : tensor<8xf16>
      %10 = tensor_ext.rotate %9, %c6 : tensor<8xf16>, index
      %11 = arith.addf %10, %4 : tensor<8xf16>
      %12 = arith.addf %11, %1 : tensor<8xf16>
      %13 = tensor_ext.rotate %12, %c7 : tensor<8xf16>, index
      %14 = arith.addf %13, %1 : tensor<8xf16>
      %extracted = tensor.extract %14[%c7] : tensor<8xf16>
      secret.yield %extracted : f16
    } -> !secret.secret<f16>
    return %0 : !secret.secret<f16>
  }
}

This is clearly different from the bgv dot product result, where the rotate-and-reduce pattern applies.

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xi16>>, %arg1: !secret.secret<tensor<8xi16>>) -> !secret.secret<i16> {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c7 = arith.constant 7 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xi16>>, !secret.secret<tensor<8xi16>>) {
    ^bb0(%arg2: tensor<8xi16>, %arg3: tensor<8xi16>):
      %1 = arith.muli %arg2, %arg3 : tensor<8xi16>
      %2 = tensor_ext.rotate %1, %c4 : tensor<8xi16>, index
      %3 = arith.addi %1, %2 : tensor<8xi16>
      %4 = tensor_ext.rotate %3, %c2 : tensor<8xi16>, index
      %5 = arith.addi %3, %4 : tensor<8xi16>
      %6 = tensor_ext.rotate %5, %c1 : tensor<8xi16>, index
      %7 = arith.addi %5, %6 : tensor<8xi16>
      %extracted = tensor.extract %7[%c7] : tensor<8xi16>
      secret.yield %extracted : i16
    } -> !secret.secret<i16>
    return %0 : !secret.secret<i16>
  }
}
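The rotate-and-reduce tree above can be checked with a small plaintext sketch (ignoring encryption and noise, and assuming tensor_ext.rotate is a cyclic left rotation):

```python
# Plaintext sketch of the rotate-and-reduce tree above (no encryption/noise);
# assumes tensor_ext.rotate is a cyclic left rotation.
def rot(v, k):
    return v[k:] + v[:k]

x = [1, 2, 3, 4, 1, 2, 3, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
t = [p * q for p, q in zip(x, y)]       # %1 = arith.muli
for shift in (4, 2, 1):                 # %2-%7: log-depth reduction
    t = [p + q for p, q in zip(t, rot(t, shift))]
print(t[7])  # every slot now holds the full dot product: 60
```

After the three rotate-add rounds, all eight slots contain the sum, so extracting at index 7 yields the dot product.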

If we execute the emitted ckks code with input (1, 2, 3, 4, 1, 2, 3, 4), we get an incorrect result, with traces like this:

v5=EvalMultNoRelin v1, v2
result decrypted: (1, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v6=Relinearize v5
result decrypted: (1, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v8=EvalAdd v6, v7
result decrypted: (1, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v9=EvalRotate v8, 6
result decrypted: (9, 16, -2.35721e-15, 1.85306e-15, 3.1336e-15, 5.86141e-16, -2.08615e-15, 2.21339e-15,  ... ); Estimated precision: 47 bits

v10=EvalRotate v6, 7
result decrypted: (16, -5.25744e-15, 6.75882e-15, 5.41846e-16, 7.02148e-15, -4.51887e-15, 4.41042e-15, -3.54915e-15,  ... ); Estimated precision: 47 bits

v11=EvalAdd v9, v10
result decrypted: (25, 16, 3.80997e-16, 3.49924e-15, -1.00012e-16, -6.83104e-15, -1.55568e-15, -3.20682e-15,  ... ); Estimated precision: 47 bits

v12=EvalAdd v11, v6
result decrypted: (26, 20, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v13=EvalRotate v12, 6
result decrypted: (9, 16, -6.82527e-15, -7.59157e-16, 1.27657e-15, -1.48354e-15, 1.24327e-15, 9.49766e-16,  ... ); Estimated precision: 47 bits

v14=EvalAdd v13, v10
result decrypted: (25, 16, 3.2116e-15, 1.26548e-15, -4.74639e-15, -5.8531e-15, 1.28463e-15, 6.89365e-15,  ... ); Estimated precision: 47 bits

v15=EvalAdd v14, v6
result decrypted: (26, 20, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v16=EvalRotate v15, 6
result decrypted: (9, 16, -3.58021e-15, 3.47007e-15, -2.34613e-15, -1.81569e-15, -5.64447e-16, -2.51926e-17,  ... ); Estimated precision: 47 bits

v17=EvalAdd v16, v10
result decrypted: (25, 16, 3.26906e-15, -8.31767e-16, 2.92677e-15, -2.5431e-15, -1.33646e-15, -2.1228e-15,  ... ); Estimated precision: 47 bits

v18=EvalAdd v17, v6
result decrypted: (26, 20, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v19=EvalRotate v18, 7
result decrypted: (16, 3.52226e-16, 2.59192e-15, 1.11261e-15, -3.50525e-15, 2.17398e-15, -4.26154e-15, -8.52218e-16,  ... ); Estimated precision: 47 bits

v20=EvalAdd v19, v6
result decrypted: (17, 4, 9, 16, 1, 4, 9, 16,  ... ); Estimated precision: 47 bits

v22=EvalMult v20, v21
result decrypted: (-4.38183e-16, -2.34426e-15, 3.01672e-15, 2.99825e-15, -2.78687e-15, -1.29908e-15, -3.58669e-15, 16,  ... ); Estimated precision: 47 bits

v23=EvalRotate v22, 7
result decrypted: (16, -1.8405e-15, -3.56376e-15, -2.69961e-15, -4.74316e-15, -4.06186e-15, 5.39773e-15, -3.67044e-15,  ... ); Estimated precision: 47 bits

Expected: 60
Actual: (16,  ... ); Estimated precision: 47 bits
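The failure can be reproduced in a plaintext model of the emitted rotation chain (a sketch ignoring noise, assuming EvalRotate is a cyclic left rotation over the full slot vector and that MakeCKKSPackedPlaintext zero-pads; the slot count 64 is an arbitrary stand-in, chosen large enough that rotation wrap-around does not reach the first slots):

```python
# Plaintext model of the emitted CKKS chain above (values only, no noise).
def rot(v, k):
    return v[k % len(v):] + v[:k % len(v)]

def add(a, b):
    return [p + q for p, q in zip(a, b)]

def chain(x, y):
    """Replays the secret.generic body above on plaintext slot vectors."""
    t1 = [p * q for p, q in zip(x, y)]   # %1 = arith.mulf
    t2 = t1[:]                           # %2 = %1 + dense<0.0>
    t4 = rot(t1, 7)                      # %4
    v = add(add(rot(t2, 6), t4), t1)     # %3, %5, %6
    for _ in range(2):                   # %7-%9 and %10-%12
        v = add(add(rot(v, 6), t4), t1)
    v = add(rot(v, 7), t1)               # %13, %14
    return v[7]                          # tensor.extract %14[%c7]

data = [1, 2, 3, 4, 1, 2, 3, 4]
n = 64  # stand-in slot count; the real ring dimension is much larger
padded = data + [0] * (n - len(data))     # zero-padded packing
cyclic = [data[i % 8] for i in range(n)]  # cyclic repetition
print(chain(padded, padded), chain(cyclic, cyclic))  # 16 60
```

With zero padding the chain reproduces the wrong 16; with cyclic repetition the same chain yields the expected 60, so the emitted rotations are arithmetically fine under the cyclic-packing assumption.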
@j2kun
Collaborator

j2kun commented Nov 23, 2024

How concerning! I don't immediately see why the rotate-and-reduce pass isn't properly handling the mulf/addf ops. Maybe you could run with --mlir-print-ir-after-all and --mlir-print-ir-tree-dir to print out the IR after each pass in the pipeline, and then you could compare the bgv/ckks versions to see the first pass at which they meaningfully differ.

I suspect the use of addf/mulf is triggering some incorrect match that causes some pattern or pass to not be applied. However, that shouldn't (absent bugs) cause the output to produce an incorrect result, it should just produce a less efficient program.

So looking for other reasons the output might be incorrect: I see the line above after rotating by 6

result decrypted: (9, 16, -2.35721e-15, 1.85306e-15, 3.1336e-15, 5.86141e-16, -2.08615e-15, 2.21339e-15,  ... )

I believe those zero values should be nonzero. In particular, I recall the encoding used by the simd-vectorizer passes expects 1D tensors to be repeated to fill up the available ciphertext space, since the rotations analyzed are cyclic mod 8 (in your example, because it's a tensor<8xf16>) but the openfhe backend uses larger ciphertext sizes. Cf.

// TODO(#645): support cyclic repetition in add-client-interface
// I want to do this, but MakePackedPlaintext does not repeat the values.
// It zero pads, and rotating the zero-padded values will not achieve the
// rotate-and-reduce trick required for simple_sum
//
// = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
// 23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
input.reserve(n);
and #645

Could that be causing the incorrectness?

@ZenithalHourlyRate
Contributor Author

ZenithalHourlyRate commented Nov 23, 2024

I recall the encoding used by the simd-vectorizer passes expects 1D tensors to be repeated to fill up the available ciphertext space

Oh, it is exactly this reason: after copying the cyclic filling code from test/Examples, the result is correct.

  std::vector<double> x1 = {1, 2, 3, 4, 1, 2, 3, 4};

  int32_t n =
      cryptoContext->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2;
  std::vector<double> outputs;
  outputs.reserve(n);
  for (int i = 0; i < n; ++i) {
    outputs.push_back(x1[i % 8]);
  }
  const auto& ptxt1 = cryptoContext->MakeCKKSPackedPlaintext(outputs);
  const auto& c1 = cryptoContext->Encrypt(keyPair.publicKey, ptxt1);

@ZenithalHourlyRate
Contributor Author

run with --mlir-print-ir-after-all and --mlir-print-ir-tree-dir

The difference is that the first apply-folders after full-loop-unroll eliminates the constant 0 (i.e. %c0_si16 = arith.constant 0 : i16) for BGV, since that is an equivalent transform, but it does not eliminate the constant 0.0 for CKKS; rotate-and-reduce then cannot recognize the pattern (I thought I had supported this pattern in past PRs?).

If I manually delete the constant after full-loop-unroll and rerun the passes, a correct rotate-and-reduce version comes out.

@ZenithalHourlyRate
Contributor Author

I thought I had supported this pattern in past PRs?

The current pattern is

%cst = arith.constant dense<0.000000e+00> : tensor<8xf16>
%1 = arith.mulf %arg2, %arg3 : tensor<8xf16>
%2 = arith.addf %1, %cst : tensor<8xf16>
%3 = tensor_ext.rotate %2, %c6 : tensor<8xf16>, index
%4 = tensor_ext.rotate %1, %c7 : tensor<8xf16>, index
// mixed using of %1 and %2 afterwards.

This means %3 and %4 do not have the same root in rotation-analysis.

The constant tensor could be saved so that %3/%4 share a root, as mentioned in #522, but that is hard to handle because the constant tensor gets rotated later; I avoided saving tensors in earlier PRs. For this specific case, I think apply-folders should simply eliminate the constant tensor.

// Only support saving scalar value.
// If the saved rhs is a tensor, it might get rotated alongside
// the reduction tree later.
//
// TODO(#522): if no rotation later then a tensor can be saved.
// This can be implemented via checking in a canRotate method.
//

Full IR below

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xf16>>, %arg1: !secret.secret<tensor<8xf16>>) -> !secret.secret<f16> {
    %c7 = arith.constant 7 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<8xf16>
    %c6 = arith.constant 6 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xf16>>, !secret.secret<tensor<8xf16>>) {
    ^bb0(%arg2: tensor<8xf16>, %arg3: tensor<8xf16>):
      %1 = arith.mulf %arg2, %arg3 : tensor<8xf16>
      %2 = arith.addf %1, %cst : tensor<8xf16>
      %3 = tensor_ext.rotate %2, %c6 : tensor<8xf16>, index
      %4 = tensor_ext.rotate %1, %c7 : tensor<8xf16>, index
      %5 = arith.addf %3, %4 : tensor<8xf16>
      %6 = arith.addf %5, %1 : tensor<8xf16>
      %7 = tensor_ext.rotate %6, %c6 : tensor<8xf16>, index
      %8 = arith.addf %7, %4 : tensor<8xf16>
      %9 = arith.addf %8, %1 : tensor<8xf16>
      %10 = tensor_ext.rotate %9, %c6 : tensor<8xf16>, index
      %11 = arith.addf %10, %4 : tensor<8xf16>
      %12 = arith.addf %11, %1 : tensor<8xf16>
      %13 = tensor_ext.rotate %12, %c7 : tensor<8xf16>, index
      %14 = arith.addf %13, %1 : tensor<8xf16>
      %extracted = tensor.extract %14[%c7] : tensor<8xf16>
      secret.yield %extracted : f16
    } -> !secret.secret<f16>
    return %0 : !secret.secret<f16>
  }
}
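The fold that apply-folders is expected to perform can be sketched on a toy SSA representation (a hypothetical sketch, not HEIR's actual implementation; one plausible reason the float constant survives is that x + 0.0 is not an IEEE-754 identity for x = -0.0, so generic float folders tend to be conservative). Once the add-zero is dropped, %3 and %4 rotate the same root %1:

```python
# Toy SSA rewrite: drop add-zero so rotations share a root. Hypothetical
# sketch, not HEIR's apply-folders implementation.
def fold_add_zero(ops, zero_consts):
    """Rewrite v = addf(x, c) so uses of v become x, when c is all-zero."""
    replaced = {}
    out = []
    for name, op, args in ops:
        args = tuple(replaced.get(a, a) for a in args)
        if op == "addf" and args[1] in zero_consts:
            replaced[name] = args[0]   # forward uses of %2 to %1
            continue
        out.append((name, op, args))
    return out

ops = [
    ("%1", "mulf", ("%arg2", "%arg3")),
    ("%2", "addf", ("%1", "%cst")),    # %cst = dense<0.0>
    ("%3", "rotate", ("%2", "6")),
    ("%4", "rotate", ("%1", "7")),
]
folded = fold_add_zero(ops, {"%cst"})
roots = {name: args[0] for name, op, args in folded if op == "rotate"}
print(roots)  # {'%3': '%1', '%4': '%1'}
```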

@ZenithalHourlyRate
Contributor Author

Or consider this case migrated back to BGV: a dot product with a nonzero initial sum, which rotate-and-reduce also cannot handle correctly. Should we handle this in insert-rotate so that the inserted rotations share the same root? Right now the mixed use of %1 and %2 is not friendly to the analysis.

func.func @dot_product(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> i16 {
  %c0 = arith.constant 0 : index
  %c0_si16 = arith.constant 10 : i16
  %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) {
    %1 = tensor.extract %arg0[%arg2] : tensor<8xi16>
    %2 = tensor.extract %arg1[%arg2] : tensor<8xi16>
    %3 = arith.muli %1, %2 : i16
    %4 = arith.addi %iter, %3 : i16
    affine.yield %4 : i16
  }
  return %0 : i16
}

We get

module {
  func.func @dot_product(%arg0: !secret.secret<tensor<8xi16>>, %arg1: !secret.secret<tensor<8xi16>>) -> !secret.secret<i16> {
    %c6 = arith.constant 6 : index
    %cst = arith.constant dense<10> : tensor<8xi16>
    %c7 = arith.constant 7 : index
    %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xi16>>, !secret.secret<tensor<8xi16>>) {
    ^bb0(%arg2: tensor<8xi16>, %arg3: tensor<8xi16>):
      %1 = arith.muli %arg2, %arg3 : tensor<8xi16>
      %2 = arith.addi %1, %cst : tensor<8xi16>
      %3 = tensor_ext.rotate %2, %c6 : tensor<8xi16>, index
      %4 = tensor_ext.rotate %1, %c7 : tensor<8xi16>, index
      %5 = arith.addi %3, %4 : tensor<8xi16>
      %6 = arith.addi %5, %1 : tensor<8xi16>
      %7 = tensor_ext.rotate %6, %c6 : tensor<8xi16>, index
      %8 = arith.addi %7, %4 : tensor<8xi16>
      %9 = arith.addi %8, %1 : tensor<8xi16>
      %10 = tensor_ext.rotate %9, %c6 : tensor<8xi16>, index
      %11 = arith.addi %10, %4 : tensor<8xi16>
      %12 = arith.addi %11, %1 : tensor<8xi16>
      %13 = tensor_ext.rotate %12, %c7 : tensor<8xi16>, index
      %14 = arith.addi %13, %1 : tensor<8xi16>
      %extracted = tensor.extract %14[%c7] : tensor<8xi16>
      secret.yield %extracted : i16
    } -> !secret.secret<i16>
    return %0 : !secret.secret<i16>
  }
}
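Replaying this IR in a plaintext model (a sketch assuming cyclic packing and cyclic left rotations) shows that the chain is arithmetically correct — it computes 10 + dot(x, y) — it just uses a linear chain of five rotations instead of the log-depth three-rotation rotate-and-reduce:

```python
# Plaintext replay of the emitted BGV IR above (cyclic 8-slot model):
# arithmetically correct, but not in rotate-and-reduce form.
def rot(v, k):
    return v[k:] + v[:k]

def add(a, b):
    return [p + q for p, q in zip(a, b)]

x = [1, 2, 3, 4, 1, 2, 3, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
t1 = [p * q for p, q in zip(x, y)]      # %1
t2 = [p + 10 for p in t1]               # %2 = %1 + dense<10>
t4 = rot(t1, 7)                         # %4
v = add(add(rot(t2, 6), t4), t1)        # %3, %5, %6
v = add(add(rot(v, 6), t4), t1)         # %7, %8, %9
v = add(add(rot(v, 6), t4), t1)         # %10, %11, %12
v = add(rot(v, 7), t1)                  # %13, %14
print(v[7])  # 10 + dot(x, y) = 70
```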

@j2kun
Collaborator

j2kun commented Nov 23, 2024

I'm not sure I fully understand what you're suggesting, but let me try to repeat:

does not eliminate the constant 0.0

If we could figure out why the floating-point constant is not folded away, we could solve the immediate problem, but I think in your example with a non-zero initial value for the dot product, the problem would persist in another form. It could be fixed by changing insert-rotate so that it aligns things properly, or rotate-and-reduce could be changed to recognize a reduction in a smarter way than following a single linear chain (#522).

I support either of those improvements. I also think #521 might allow a workaround wherein rotate-and-reduce works for the entire vector except the first element of the chain, which would be nearly optimal and would give other side benefits to IRs that don't do complete reductions.
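The shape of that workaround can be sketched in plaintext (a hypothetical sketch, assuming cyclic left rotations): run the log-depth rotate-and-reduce on the products alone, then fold the initial value in with one final scalar add:

```python
# Hypothetical workaround sketch: log-depth reduction over the products,
# with the nonzero initial value added once at the end.
def rot(v, k):
    return v[k:] + v[:k]

def add(a, b):
    return [p + q for p, q in zip(a, b)]

x = [1, 2, 3, 4, 1, 2, 3, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
init = 10
t = [p * q for p, q in zip(x, y)]
for shift in (4, 2, 1):                 # three rotations instead of five
    t = add(t, rot(t, shift))
result = t[7] + init                    # one extra add for the initial value
print(result)  # 70
```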
