Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LWE: add evaluation key type and technique #1067

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions lib/Dialect/LWE/IR/NewLWEAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -281,20 +281,25 @@ def LWE_KeyAttr : AttrDef<LWE_Dialect, "Key"> {
$1$ for LWE instances. A ciphertext encrypted with a `key_size` of $k$ will
have size $k+1$.

The key basis describes the inner product used in the phase calculation in
decryption. This attribute is only supported for RLWE ciphertexts whose
The key basis/power describes the inner product used in the phase calculation
in decryption. This attribute is only supported for RLWE ciphertexts whose
`key_size` is $1$. An RLWE ciphertext is canonically encrypted against key
basis `(1, s)`. After a multiplication, its size will increase and the basis
will be `(1, s, s^2)`. The array that represents the key basis is
constructed by listing the powers of `s` at each position of the array. For
example, `(1, s, s^2)` corresponds to `[0, 1, 2]`, while `(1, s^2)`
corresponds to `[0, 2]`.
corresponds to `[0, 2]`. The array that represents the key rotate is constructed
by listing the powers of `X` at each position of the array. For example,
`(1, s, s(X^2))` corresponds to `[0, 1, 2]`. Combining the basis/rotate array
together, we can express `(1, s, s^2, s(X^2), s^2(X^2))` as `[0, 1, 2, 1, 2]`
and `[0, 1, 1, 2, 2]`.
}];

let parameters = (ins
"::mlir::StringAttr":$id,
DefaultValuedParameter<"unsigned", "1">:$size,
OptionalArrayRefParameter<"unsigned int">:$basis
OptionalArrayRefParameter<"unsigned int">:$basis,
OptionalArrayRefParameter<"unsigned int">:$rotate
);

let assemblyFormat = "`<` struct(params) `>`";
Expand Down Expand Up @@ -348,4 +353,60 @@ def LWE_ModulusChainAttr : AttrDef<LWE_Dialect, "ModulusChain"> {
// let genVerifyDecl = 1; // Verify index into list
}

def LWE_BVKeySwitchAttr : AttrDef<LWE_Dialect, "BVKeySwitch"> {
let mnemonic = "bv_keyswitch_technique";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: just remove the _technique here and below

let description = [{
An attribute describing the BV technique for keyswitch.

`base` is the radix base used in decomposition of the coefficient modulus
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`base` is the radix base used in decomposition of the coefficient modulus
`base` is the radix base used in the digit decomposition of the coefficient modulus

`Q` (non-RNS case) / `qi` (RNS case).

`dnum` is the number of large digits for the RNS case. It takes effect
only when `base` equals 0.

Check Appendix A of https://eprint.iacr.org/2021/204.pdf for more detail.
}];

let parameters = (ins
"IntegerAttr":$base,
"IntegerAttr":$dnum
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for considering the RNS variants! I understand it's important to know the dnums so that the type converter knows the type / shape of the key switching key. For the lowering, we'll need the actual RNS moduli, which we'll pull from the RNS modulus of the ring in the keyswitching key, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OptionalParameter? (can also update the assembly format for an optional print)

);

let assemblyFormat = "`<` struct(params) `>`";

// let genVerifyDecl = 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could implement the type constraint here.

}


def LWE_GHSKeySwitchAttr : AttrDef<LWE_Dialect, "GHSKeySwitch"> {
let mnemonic = "ghs_keyswitch_technique";
let description = [{
An attribute describing the GHS technique for keyswitch.

`extra_modulus` is the extra modulus `P` needed by the technique.
In RNS case, it is a chain of modulus.

Check Appendix A of https://eprint.iacr.org/2021/204.pdf for more detail.
}];

let parameters = (ins
"ModulusChainAttr":$extra_modulus
);

let assemblyFormat = "`<` struct(params) `>`";

// let genVerifyDecl = 1;
}

def LWE_AnyKeySwitchAttr : AttrDef<LWE_Dialect, "KeySwitch"> {
let mnemonic = "keyswitch_technique";
let returnType = "Attribute";
let convertFromStorage = "$_self";
string cppType = "Attribute";
let predicate = Or<[
LWE_BVKeySwitchAttr.predicate,
LWE_GHSKeySwitchAttr.predicate
]>;
}

#endif // LIB_DIALECT_LWE_IR_NEWLWEATTRIBUTES_TD_
11 changes: 11 additions & 0 deletions lib/Dialect/LWE/IR/NewLWETypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def NewLWEPublicKey : LWE_Type<"NewLWEPublicKey", "new_lwe_public_key"> {
);
}

def NewLWEEvaluationKey : LWE_Type<"NewLWEEvaluationKey", "new_lwe_evaluation_key"> {
let summary = "A evaluation key for LWE";
let parameters = (ins
"KeyAttr":$from_key,
"KeyAttr":$to_key,
"::mlir::polynomial::RingAttr":$ring,
// can not be ArrayRefParameter<"KeySwitchAttr">
ArrayRefParameter<"Attribute">:$keyswitch_techniques
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this an array ref parameter? Wouldn't only a sinlgle technique apply to a given key?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intentionally left the HYBRID keyswitching technique as a combination of BV attr and GHS attr, so that we do not have to define another LWE_HYBRIDKeySwitchAttr where all base/dnum/extra_modulus will be included. If we accept such elaboration, then we can define it and here can be a single technique.

);
}

def NewLWESecretOrPublicKey : AnyTypeOf<[NewLWESecretKey, NewLWEPublicKey]>;

def NewLWEPlaintext : LWE_Type<"NewLWEPlaintext", "new_lwe_plaintext"> {
Expand Down
16 changes: 15 additions & 1 deletion tests/Dialect/LWE/IR/attributes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ func.func @test_fn() {
// -----

#key = #lwe.key<id = "1234">
#key_rlwe_rotate = #lwe.key<id = "1234", basis = 0, 2>
#key_rlwe_mult = #lwe.key<id = "1234", basis = 0, 2>
#key_rlwe_2 = #lwe.key<id = "1234", size = 2>
#key_rlwe_rotate = #lwe.key<id = "1234", rotate = 2>
// not allowed! mlir cannot parse this
// #key_rlwe_mult_rotate = #lwe.key<id = "1234", basis = 0, 2, rotate = 0, 2>

// CHECK-LABEL: test_fn
func.func @test_fn() {
Expand All @@ -174,3 +177,14 @@ func.func @test_fn() {
func.func @test_fn() {
return
}

// -----

#keyswitch_bv_base = #lwe.bv_keyswitch_technique<base = 65536, dnum = 0>
#keyswitch_bv = #lwe.bv_keyswitch_technique<base = 0, dnum = 3>
#keyswitch_ghs = #lwe.ghs_keyswitch_technique<extra_modulus=<elements = <65537 : i32>, current = 0>>

// CHECK-LABEL: test_fn
func.func @test_fn() {
return
}
21 changes: 21 additions & 0 deletions tests/Dialect/LWE/IR/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ func.func @test_new_lwe_public_key(%arg0 : !public_key) -> !public_key {
return %arg0 : !public_key
}

#key_mult = #lwe.key<id = "1234", size = 1, basis = 2>
#key_rotate = #lwe.key<id = "1234", size = 1, rotate = 2>

#keyswitch_bv_base = #lwe.bv_keyswitch_technique<base = 65536, dnum = 0>
#keyswitch_bv = #lwe.bv_keyswitch_technique<base = 0, dnum = 3>
#keyswitch_ghs = #lwe.ghs_keyswitch_technique<extra_modulus=<elements = <65537 : i32>, current = 0>>

!evaluation_key_mult = !lwe.new_lwe_evaluation_key<from_key=#key_mult, to_key=#key, ring=#ring, keyswitch_techniques= #keyswitch_bv, #keyswitch_ghs>
!evaluation_key_rotate = !lwe.new_lwe_evaluation_key<from_key=#key_rotate, to_key=#key, ring=#ring, keyswitch_techniques= #keyswitch_bv_base>

// CHECK-LABEL test_new_lwe_evaluation_key_mult

func.func @test_new_lwe_evaluation_key_mult(%arg0 : !evaluation_key_mult) -> !evaluation_key_mult {
return %arg0 : !evaluation_key_mult
}

// CHECK-LABEL test_new_lwe_evaluation_key_rotate

func.func @test_new_lwe_evaluation_key_rotate(%arg0 : !evaluation_key_rotate) -> !evaluation_key_rotate {
return %arg0 : !evaluation_key_rotate
}

#preserve_overflow = #lwe.preserve_overflow<>
#application_data = #lwe.application_data<message_type = i1, overflow = #preserve_overflow>
Expand Down
Loading