Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVX-512 support for RSA Signing #1273

Merged
merged 36 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b9088fc
Use IFMA_AVX512 when possible for modular exponentiation.
pittma Aug 7, 2023
e6269ff
Add test coverage for consttime_x2 mod exp function
pittma Oct 23, 2023
6d2ece9
Add fuzzer coverage for BN_mod_exp_mont_consttime_x2
pittma Oct 23, 2023
e0ad9da
prevent empty translation units for compilers that don't like them
pittma Oct 30, 2023
024a9ec
properly handle AVX-512 build conditions
pittma Oct 31, 2023
cd2a3d1
fips builds require subsections
pittma Oct 31, 2023
d4d89fc
fix disallowed interaction with `OPENSSL_ia32_cap_P` in fips mode
pittma Nov 2, 2023
a0f3737
reset sections when they change for variable declaration
pittma Nov 2, 2023
8e55af5
include avx512ifma flag
pittma Nov 3, 2023
7d1ea20
handle AVX-512 mask register usage in fips delocation process
pittma Nov 15, 2023
407df8d
address review comments
pittma Jan 30, 2024
e67bbda
regen generated source
pittma Feb 1, 2024
b33709e
regenerate delocate parser
pittma Feb 1, 2024
0e7c607
AVX-512 RSA Signing: address first PR review
pittma Apr 10, 2024
b2d1327
Merge remote-tracking branch 'origin/main'
pittma Apr 10, 2024
14fefe0
Still export the parallel mod_exp implementation
pittma Apr 12, 2024
5e1c7ee
second set of review comments and documentation
pittma Apr 24, 2024
73d389d
fix generated source conflict
pittma Apr 24, 2024
087bf5c
Merge branch 'main' of github.com:aws/aws-lc into pmain
pittma Jul 25, 2024
c439bf0
address review 3 comments
pittma Jul 25, 2024
abe1124
Merge branch 'main' of github.com:aws/aws-lc
pittma Aug 7, 2024
37b4a4a
Merge branch 'main' of github.com:aws/aws-lc into pmain
pittma Sep 5, 2024
e06d8d0
further review comments
pittma Sep 4, 2024
bf9fc29
add ABI tests for new RSA AVX-512 assmebly routines
pittma Sep 5, 2024
e626c2c
add dispatch tests for AVX-512 enabled RSA signing
pittma Sep 5, 2024
92b9e3f
fix dispatch test
pittma Sep 6, 2024
1055b42
Merge remote-tracking branch 'origin/main'
pittma Sep 6, 2024
58af762
Merge branch 'main' of github.com:aws/aws-lc
pittma Sep 9, 2024
56d8fd6
fix conditional build logic in dispatch test
pittma Sep 9, 2024
f925e7c
generated asm should properly exclude when using old assembler
pittma Sep 9, 2024
2473469
Merge branch 'main' of github.com:aws/aws-lc
pittma Sep 10, 2024
ef26ced
in ninja-based build, old assembler logic is already handled
pittma Sep 10, 2024
73b7b8f
Merge branch 'main' of github.com:aws/aws-lc
pittma Sep 10, 2024
506dced
Increasing the capacity of ubuntu2004_android_fips_static_release.
nebeid Sep 11, 2024
0dd53a1
Merge branch 'main' into main
nebeid Sep 11, 2024
f3715bb
Merge branch 'main' of github.com:aws/aws-lc
pittma Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions crypto/fipsmodule/bn/rsaz_exp_x2.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1,
const uint64_t *b, const uint64_t *m, uint64_t k0);
int ret = 0;

// Number of word-size (uint64_t) digits to store in redundant
// representation.
// Number of word-size (uint64_t) digits to store values in
// redundant representation.
int red_digits = number_of_digits(modlen + 2, DIGIT_SIZE);

// n = modlen, d = DIGIT_SIZE, s = d * ceil((n+2)/d) > n
Expand Down Expand Up @@ -124,7 +124,7 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1,
uint64_t *storage = NULL;
uint64_t *storage_aligned = NULL;
int storage_len_bytes = 7 * regs_capacity * sizeof(uint64_t)
+ 64;
+ 64; // alignment

const uint64_t *exp[2] = {0};
uint64_t k0[2] = {0};
Expand Down Expand Up @@ -177,17 +177,19 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1,
// - We have AMM(t, 2^k) = R^4 * 2^{4*(s-n)} / R'^2 mod m -- (2)
// = R'^4 / R'^2 mod m
// = R'^2 mod m
// For example, for n = 1024, s = 1040, k = 64,
// RR = 2^2048 mod m, RR' = 2^2080 mod m

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// For example, for n = 1024, s = 1040, k = 64,
// RR = 2^2048 mod m, RR' = 2^2080 mod m

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason the example wasn't added back? It was in the original commit, I just reworded it. I think it was a helpful illustration.

OPENSSL_memset(coeff_red, 0, red_digits * sizeof(uint64_t));
// coeff_red = 2^k = 1 << bitlen_diff taking into account the
// redundant representation in digits of DIGIT_SIZE bits
set_bit(coeff_red, 64 * (int)(bitlen_diff / DIGIT_SIZE) + bitlen_diff % DIGIT_SIZE);
dkostic marked this conversation as resolved.
Show resolved Hide resolved

amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1);
amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1);
amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1); // (1) for m1
amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1); // (2) for m1

amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2);
amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2);
amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2); // (1) for m2
amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2); // (2) for m2

exp[0] = exp1;
exp[1] = exp2;
Expand Down Expand Up @@ -316,6 +318,11 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
red_table = red_X + 2 * red_digits;
expz = red_table + 2 * red_digits * two_to_exp_win_size;

// Compute table of powers base^i mod m,
// i = 0, ..., (2^EXP_WIN_SIZE) - 1
// using the dual multiplication. Each table entry contains
// base1^i mod m1, then base2^i mod m2.

red_X[0 * red_digits] = 1;
red_X[1 * red_digits] = 1;
damm(&red_table[0 * 2 * red_digits], (const uint64_t*)red_X, rr, m, k0);
Expand Down Expand Up @@ -367,9 +374,17 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
// `rem` is { 1024, 1536, 2048 } % 5 which is { 4, 1, 3 }
// respectively.
//
// If this assertion ever fails the fix above is easy.
// If this assertion ever fails then we should set this easy
// fix exp_bit_no = modlen - exp_win_size
assert(rem == 4 || rem == 1 || rem == 3);


// Find the location of the 5-bit window in the exponent which
// is stored in 64-bit digits. Left pad it with 0s to form a
// 64-bit digit to become an index in the precomputed table.
// The window location in the exponent is identified by its
// least significant bit `exp_bit_no`.

#define EXP_CHUNK(i) (exp_chunk_no) + ((i) * (exp_digits + 1))
#define EXP_CHUNK1(i) (exp_chunk_no) + 1 + ((i) * (exp_digits + 1))

Expand All @@ -395,7 +410,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
exp_chunk_no = exp_bit_no / 64;
exp_chunk_shift = exp_bit_no % 64;
{
red_table_idx_1 = expz[exp_chunk_no + 0 * (exp_digits + 1)];
red_table_idx_1 = expz[EXP_CHUNK(0)];
T = expz[EXP_CHUNK1(0)];

red_table_idx_1 >>= exp_chunk_shift;
Expand All @@ -408,7 +423,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
red_table_idx_1 &= table_idx_mask;
}
{
red_table_idx_2 = expz[exp_chunk_no + 1 * (exp_digits + 1)];
red_table_idx_2 = expz[EXP_CHUNK(1)];
T = expz[EXP_CHUNK1(1)];

red_table_idx_2 >>= exp_chunk_shift;
Expand All @@ -425,7 +440,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
(int)red_table_idx_1, (int)red_table_idx_2);
}

// Series of squaring
// The number of squarings is equal to the window size.
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0);
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0);
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0);
Expand Down
101 changes: 77 additions & 24 deletions crypto/impl_dispatch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,36 +247,89 @@ TEST_F(ImplDispatchTest, SHA512) {
}
#endif // OPENSSL_AARCH64


#if defined(OPENSSL_X86_64) && !defined(MY_ASSEMBLER_IS_TOO_OLD_512AVX) && \
defined(RSAZ_512_ENABLED)

#include "test/file_test.h"

static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attr);

static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attr) {
std::string hex;
if (!t->GetAttribute(&hex, attr)) {
return nullptr;
}

BIGNUM *raw = NULL;
int size = BN_hex2bn(&raw, hex.c_str());
if (size != static_cast<int>(hex.size())) {
t->PrintLine("Could not decode '%s'.", hex.c_str());
return nullptr;
}

bssl::UniquePtr<BIGNUM> ret;
(&ret)->reset(raw);
return ret;
}

TEST_F(ImplDispatchTest, BN_mod_exp_mont_consttime_x2) {
AssertFunctionsHit(
FileTestGTest(
"crypto/fipsmodule/bn/test/mod_exp_x2_tests.txt",
[&](FileTest *t) {
AssertFunctionsHit(
{
{kFlag_RSAZ_mod_exp_avx512_x2,
is_x86_64_ &&
!is_assembler_too_old_avx512 &&
ifma_avx512},
{kFlag_RSAZ_mod_exp_avx512_x2,
is_x86_64_ &&
!is_assembler_too_old_avx512 &&
ifma_avx512},
},
[] {
uint64_t res1 = 0;
uint64_t base1 = 0;
uint64_t exp1 = 0;
uint64_t m1 = 0;
uint64_t rr1 = 0;
uint64_t k0_1 = 0;
uint64_t res2 = 0;
uint64_t base2 = 0;
uint64_t exp2 = 0;
uint64_t m2 = 0;
uint64_t rr2 = 0;
uint64_t k0_2 = 0;
int modlen = 0;

RSAZ_mod_exp_avx512_x2(&res1, &base1, &exp1, &m1, &rr1, k0_1,
&res2, &base2, &exp2, &m2, &rr2, k0_2,
modlen);
[&]() {
BN_CTX *ctx = BN_CTX_new();
BN_CTX_start(ctx);
bssl::UniquePtr<BIGNUM> a1 = GetBIGNUM(t, "A1");
bssl::UniquePtr<BIGNUM> e1 = GetBIGNUM(t, "E1");
bssl::UniquePtr<BIGNUM> m1 = GetBIGNUM(t, "M1");
bssl::UniquePtr<BIGNUM> mod_exp1 = GetBIGNUM(t, "ModExp1");
ASSERT_TRUE(a1);
ASSERT_TRUE(e1);
ASSERT_TRUE(m1);
ASSERT_TRUE(mod_exp1);

bssl::UniquePtr<BIGNUM> a2 = GetBIGNUM(t, "A2");
bssl::UniquePtr<BIGNUM> e2 = GetBIGNUM(t, "E2");
bssl::UniquePtr<BIGNUM> m2 = GetBIGNUM(t, "M2");
bssl::UniquePtr<BIGNUM> mod_exp2 = GetBIGNUM(t, "ModExp2");
ASSERT_TRUE(a2);
ASSERT_TRUE(e2);
ASSERT_TRUE(m2);
ASSERT_TRUE(mod_exp2);

bssl::UniquePtr<BIGNUM> ret1(BN_new());
ASSERT_TRUE(ret1);

bssl::UniquePtr<BIGNUM> ret2(BN_new());
ASSERT_TRUE(ret2);

ASSERT_TRUE(BN_nnmod(a1.get(), a1.get(), m1.get(), ctx));
ASSERT_TRUE(BN_nnmod(a2.get(), a2.get(), m2.get(), ctx));

BN_MONT_CTX *mont1 = NULL;
BN_MONT_CTX *mont2 = NULL;

ASSERT_TRUE(mont1 = BN_MONT_CTX_new());
ASSERT_TRUE(BN_MONT_CTX_set(mont1, m1.get(), ctx));
ASSERT_TRUE(mont2 = BN_MONT_CTX_new());
ASSERT_TRUE(BN_MONT_CTX_set(mont2, m2.get(), ctx));

BN_mod_exp_mont_consttime_x2(ret1.get(), a1.get(), e1.get(), m1.get(), mont1,
ret2.get(), a2.get(), e2.get(), m2.get(), mont2,
ctx);

BN_CTX_end(ctx);
BN_MONT_CTX_free(mont1);
BN_MONT_CTX_free(mont2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
BN_CTX_end(ctx);
BN_MONT_CTX_free(mont1);
BN_MONT_CTX_free(mont2);
BN_MONT_CTX_free(mont1);
BN_MONT_CTX_free(mont2);
BN_CTX_end(ctx);
BN_CTX_free(ctx);

});
});
}
#endif // OPENSSL_X86_64 && !MY_ASSEMBLER_IS_TOO_OLD_512AVX && RSAZ_512_ENABLED

Expand Down