diff --git a/lib/Dialect/LWE/IR/NewLWEAttributes.td b/lib/Dialect/LWE/IR/NewLWEAttributes.td index 6870f1f7d..e8766d4aa 100644 --- a/lib/Dialect/LWE/IR/NewLWEAttributes.td +++ b/lib/Dialect/LWE/IR/NewLWEAttributes.td @@ -281,20 +281,25 @@ def LWE_KeyAttr : AttrDef { $1$ for LWE instances. A ciphertext encrypted with a `key_size` of $k$ will have size $k+1$. - The key basis describes the inner product used in the phase calculation in - decryption. This attribute is only supported for RLWE ciphertexts whose + The key basis/power describes the inner product used in the phase calculation + in decryption. This attribute is only supported for RLWE ciphertexts whose `key_size` is $1$. An RLWE ciphertext is canonically encrypted against key basis `(1, s)`. After a multiplication, its size will increase and the basis will be `(1, s, s^2)`. The array that represents the key basis is constructed by listing the powers of `s` at each position of the array. For example, `(1, s, s^2)` corresponds to `[0, 1, 2]`, while `(1, s^2)` - corresponds to `[0, 2]`. + corresponds to `[0, 2]`. The array that represents the key rotate is constructed + by listing the powers of `X` at each position of the array. For example, + `(1, s, s(X^2))` corresponds to `[0, 1, 2]`. Combining the basis/rotate array + together, we can express `(1, s, s^2, s(X^2), s^2(X^2))` as `[0, 1, 2, 1, 2]` + and `[0, 1, 1, 2, 2]`. }]; let parameters = (ins "::mlir::StringAttr":$id, DefaultValuedParameter<"unsigned", "1">:$size, - OptionalArrayRefParameter<"unsigned int">:$basis + OptionalArrayRefParameter<"unsigned int">:$basis, + OptionalArrayRefParameter<"unsigned int">:$rotate ); let assemblyFormat = "`<` struct(params) `>`"; @@ -348,4 +353,60 @@ def LWE_ModulusChainAttr : AttrDef { // let genVerifyDecl = 1; // Verify index into list } +def LWE_BVKeySwitchAttr : AttrDef { + let mnemonic = "bv_keyswitch_technique"; + let description = [{ + An attribute describing the BV technique for keyswitch. + + `base` is the radix base used in decomposition of the coefficient modulus + `Q` (non-RNS case) / `qi` (RNS case). + + `dnum` is the number of large digits for the RNS case. It takes effect + only when `base` equals 0. + + Check Appendix A of https://eprint.iacr.org/2021/204.pdf for more detail. + }]; + + let parameters = (ins + "IntegerAttr":$base, + "IntegerAttr":$dnum + ); + + let assemblyFormat = "`<` struct(params) `>`"; + + // let genVerifyDecl = 1; +} + + +def LWE_GHSKeySwitchAttr : AttrDef { + let mnemonic = "ghs_keyswitch_technique"; + let description = [{ + An attribute describing the GHS technique for keyswitch. + + `extra_modulus` is the extra modulus `P` needed by the technique. + In RNS case, it is a chain of modulus. + + Check Appendix A of https://eprint.iacr.org/2021/204.pdf for more detail. + }]; + + let parameters = (ins + "ModulusChainAttr":$extra_modulus + ); + + let assemblyFormat = "`<` struct(params) `>`"; + + // let genVerifyDecl = 1; +} + +def LWE_AnyKeySwitchAttr : AttrDef { + let mnemonic = "keyswitch_technique"; + let returnType = "Attribute"; + let convertFromStorage = "$_self"; + string cppType = "Attribute"; + let predicate = Or<[ + LWE_BVKeySwitchAttr.predicate, + LWE_GHSKeySwitchAttr.predicate + ]>; +} + #endif // LIB_DIALECT_LWE_IR_NEWLWEATTRIBUTES_TD_ diff --git a/lib/Dialect/LWE/IR/NewLWETypes.td b/lib/Dialect/LWE/IR/NewLWETypes.td index ce576d9b6..a5a5cca3f 100644 --- a/lib/Dialect/LWE/IR/NewLWETypes.td +++ b/lib/Dialect/LWE/IR/NewLWETypes.td @@ -34,6 +34,17 @@ def NewLWEPublicKey : LWE_Type<"NewLWEPublicKey", "new_lwe_public_key"> { ); } +def NewLWEEvaluationKey : LWE_Type<"NewLWEEvaluationKey", "new_lwe_evaluation_key"> { + let summary = "A evaluation key for LWE"; + let parameters = (ins + "KeyAttr":$from_key, + "KeyAttr":$to_key, + "::mlir::polynomial::RingAttr":$ring, + // can not be ArrayRefParameter<"KeySwitchAttr"> + ArrayRefParameter<"Attribute">:$keyswitch_techniques + ); +} + def NewLWESecretOrPublicKey : AnyTypeOf<[NewLWESecretKey, NewLWEPublicKey]>; def NewLWEPlaintext : LWE_Type<"NewLWEPlaintext", "new_lwe_plaintext"> { diff --git a/tests/Dialect/LWE/IR/attributes.mlir b/tests/Dialect/LWE/IR/attributes.mlir index 73afee962..2e92a53a3 100644 --- a/tests/Dialect/LWE/IR/attributes.mlir +++ b/tests/Dialect/LWE/IR/attributes.mlir @@ -147,8 +147,11 @@ func.func @test_fn() { // ----- #key = #lwe.key -#key_rlwe_rotate = #lwe.key +#key_rlwe_mult = #lwe.key #key_rlwe_2 = #lwe.key +#key_rlwe_rotate = #lwe.key +// not allowed! mlir cannot parse this +// #key_rlwe_mult_rotate = #lwe.key // CHECK-LABEL: test_fn func.func @test_fn() { @@ -174,3 +177,14 @@ func.func @test_fn() { func.func @test_fn() { return } + +// ----- + +#keyswitch_bv_base = #lwe.bv_keyswitch_technique +#keyswitch_bv = #lwe.bv_keyswitch_technique +#keyswitch_ghs = #lwe.ghs_keyswitch_technique, current = 0>> + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} diff --git a/tests/Dialect/LWE/IR/types.mlir b/tests/Dialect/LWE/IR/types.mlir index 73e1ed6ad..789ea2f75 100644 --- a/tests/Dialect/LWE/IR/types.mlir +++ b/tests/Dialect/LWE/IR/types.mlir @@ -48,6 +48,27 @@ func.func @test_new_lwe_public_key(%arg0 : !public_key) -> !public_key { return %arg0 : !public_key } +#key_mult = #lwe.key +#key_rotate = #lwe.key + +#keyswitch_bv_base = #lwe.bv_keyswitch_technique +#keyswitch_bv = #lwe.bv_keyswitch_technique +#keyswitch_ghs = #lwe.ghs_keyswitch_technique, current = 0>> + +!evaluation_key_mult = !lwe.new_lwe_evaluation_key +!evaluation_key_rotate = !lwe.new_lwe_evaluation_key + +// CHECK-LABEL test_new_lwe_evaluation_key_mult + +func.func @test_new_lwe_evaluation_key_mult(%arg0 : !evaluation_key_mult) -> !evaluation_key_mult { + return %arg0 : !evaluation_key_mult +} + +// CHECK-LABEL test_new_lwe_evaluation_key_rotate + +func.func @test_new_lwe_evaluation_key_rotate(%arg0 : !evaluation_key_rotate) -> !evaluation_key_rotate { + return %arg0 : !evaluation_key_rotate +} #preserve_overflow = #lwe.preserve_overflow<> #application_data = #lwe.application_data