diff --git a/drv/isa_ce_sm4.c b/drv/isa_ce_sm4.c
index ccab8fb6..69614712 100644
--- a/drv/isa_ce_sm4.c
+++ b/drv/isa_ce_sm4.c
@@ -128,6 +128,82 @@ static void sm4_cbc_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rke
 	sm4_v8_cbc_encrypt(msg->in, msg->out, msg->in_bytes, rkey_dec, msg->iv, SM4_DECRYPT);
 }
 
+/*
+ * In some situations, the cts mode can use cbc mode instead to improve performance.
+ */
+static int sm4_cts_cbc_instead(struct wd_cipher_msg *msg)
+{
+	if (msg->in_bytes == SM4_BLOCK_SIZE)
+		return true;
+
+	if (!(msg->in_bytes % SM4_BLOCK_SIZE) && msg->mode != WD_CIPHER_CBC_CS3)
+		return true;
+
+	return false;
+}
+
+static void sm4_cts_cs1_mode_adapt(__u8 *cts_in, __u8 *cts_out,
+				   const __u32 cts_bytes, const int enc)
+{
+	__u32 rsv_bytes = cts_bytes % SM4_BLOCK_SIZE;
+	__u8 blocks[SM4_BLOCK_SIZE] = {0};
+
+	if (enc == SM4_ENCRYPT) {
+		memcpy(blocks, cts_out, SM4_BLOCK_SIZE);
+		memcpy(cts_out, cts_out + SM4_BLOCK_SIZE, rsv_bytes);
+		memcpy(cts_out + rsv_bytes, blocks, SM4_BLOCK_SIZE);
+	} else {
+		memcpy(blocks, cts_in + rsv_bytes, SM4_BLOCK_SIZE);
+		memcpy(cts_in + SM4_BLOCK_SIZE, cts_in, rsv_bytes);
+		memcpy(cts_in, blocks, SM4_BLOCK_SIZE);
+	}
+}
+
+static void sm4_cts_cbc_crypt(struct wd_cipher_msg *msg,
+			      const struct SM4_KEY *rkey_enc, const int enc)
+{
+	enum wd_cipher_mode mode = msg->mode;
+	__u32 in_bytes = msg->in_bytes;
+	__u8 *cts_in, *cts_out;
+	__u32 cts_bytes;
+
+	if (sm4_cts_cbc_instead(msg))
+		return sm4_v8_cbc_encrypt(msg->in, msg->out, in_bytes, rkey_enc, msg->iv, enc);
+
+	cts_bytes = in_bytes % SM4_BLOCK_SIZE + SM4_BLOCK_SIZE;
+	if (cts_bytes == SM4_BLOCK_SIZE)
+		cts_bytes += SM4_BLOCK_SIZE;
+
+	in_bytes -= cts_bytes;
+	if (in_bytes)
+		sm4_v8_cbc_encrypt(msg->in, msg->out, in_bytes, rkey_enc, msg->iv, enc);
+
+	cts_in = msg->in + in_bytes;
+	cts_out = msg->out + in_bytes;
+
+	if (enc == SM4_ENCRYPT) {
+		sm4_v8_cbc_cts_encrypt(cts_in, cts_out, cts_bytes, rkey_enc, msg->iv);
+
+		if (mode == WD_CIPHER_CBC_CS1)
+			sm4_cts_cs1_mode_adapt(cts_in, cts_out, cts_bytes, enc);
+	} else {
+		if (mode == WD_CIPHER_CBC_CS1)
+			sm4_cts_cs1_mode_adapt(cts_in, cts_out, cts_bytes, enc);
+
+		sm4_v8_cbc_cts_decrypt(cts_in, cts_out, cts_bytes, rkey_enc, msg->iv);
+	}
+}
+
+static void sm4_cbc_cts_encrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
+{
+	sm4_cts_cbc_crypt(msg, rkey_enc, SM4_ENCRYPT);
+}
+
+static void sm4_cbc_cts_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
+{
+	sm4_cts_cbc_crypt(msg, rkey_enc, SM4_DECRYPT);
+}
+
 static void sm4_ecb_encrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
 {
 	sm4_v8_ecb_encrypt(msg->in, msg->out, msg->in_bytes, rkey_enc, SM4_ENCRYPT);
@@ -138,12 +214,12 @@ static void sm4_ecb_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rke
 	sm4_v8_ecb_encrypt(msg->in, msg->out, msg->in_bytes, rkey_dec, SM4_DECRYPT);
 }
 
-void sm4_set_encrypt_key(const __u8 *userKey, struct SM4_KEY *key)
+static void sm4_set_encrypt_key(const __u8 *userKey, struct SM4_KEY *key)
 {
 	sm4_v8_set_encrypt_key(userKey, key);
 }
 
-void sm4_set_decrypt_key(const __u8 *userKey, struct SM4_KEY *key)
+static void sm4_set_decrypt_key(const __u8 *userKey, struct SM4_KEY *key)
 {
 	sm4_v8_set_decrypt_key(userKey, key);
 }
@@ -276,6 +352,14 @@ static int isa_ce_cipher_send(struct wd_alg_driver *drv, handle_t ctx, void *wd_
 		else
 			sm4_cbc_decrypt(msg, &rkey);
 		break;
+	case WD_CIPHER_CBC_CS1:
+	case WD_CIPHER_CBC_CS2:
+	case WD_CIPHER_CBC_CS3:
+		if (msg->op_type == WD_CIPHER_ENCRYPTION)
+			sm4_cbc_cts_encrypt(msg, &rkey);
+		else
+			sm4_cbc_cts_decrypt(msg, &rkey);
+		break;
 	case WD_CIPHER_CTR:
 		sm4_ctr_encrypt(msg, &rkey);
 		break;
@@ -330,6 +414,9 @@ static int cipher_recv(struct wd_alg_driver *drv, handle_t ctx, void *msg)
 
 static struct wd_alg_driver cipher_alg_driver[] = {
 	GEN_CE_ALG_DRIVER("cbc(sm4)", cipher),
+	GEN_CE_ALG_DRIVER("cbc-cs1(sm4)", cipher),
+	GEN_CE_ALG_DRIVER("cbc-cs2(sm4)", cipher),
+	GEN_CE_ALG_DRIVER("cbc-cs3(sm4)", cipher),
 	GEN_CE_ALG_DRIVER("ctr(sm4)", cipher),
 	GEN_CE_ALG_DRIVER("cfb(sm4)", cipher),
 	GEN_CE_ALG_DRIVER("xts(sm4)", cipher),
diff --git a/drv/isa_ce_sm4.h b/drv/isa_ce_sm4.h
index d10b0af7..308619e7 100644
--- a/drv/isa_ce_sm4.h
+++ b/drv/isa_ce_sm4.h
@@ -25,27 +25,35 @@ struct sm4_ce_drv_ctx {
 
 void sm4_v8_set_encrypt_key(const unsigned char *userKey, struct SM4_KEY *key);
 void sm4_v8_set_decrypt_key(const unsigned char *userKey, struct SM4_KEY *key);
+
 void sm4_v8_cbc_encrypt(const unsigned char *in, unsigned char *out, size_t length,
 			const struct SM4_KEY *key, unsigned char *ivec, const int enc);
+void sm4_v8_cbc_cts_encrypt(const unsigned char *in, unsigned char *out,
+			    size_t len, const void *key, const unsigned char ivec[16]);
+void sm4_v8_cbc_cts_decrypt(const unsigned char *in, unsigned char *out,
+			    size_t len, const void *key, const unsigned char ivec[16]);
+
 void sm4_v8_ecb_encrypt(const unsigned char *in, unsigned char *out, size_t length,
 			const struct SM4_KEY *key, const int enc);
+
 void sm4_v8_ctr32_encrypt_blocks(const unsigned char *in, unsigned char *out,
-		size_t len, const void *key, const unsigned char ivec[16]);
+				 size_t len, const void *key, const unsigned char ivec[16]);
 
 void sm4_v8_cfb_encrypt_blocks(const unsigned char *in, unsigned char *out,
-		size_t length, const struct SM4_KEY *key, unsigned char *ivec);
+			       size_t length, const struct SM4_KEY *key, unsigned char *ivec);
 void sm4_v8_cfb_decrypt_blocks(const unsigned char *in, unsigned char *out,
-		size_t length, const struct SM4_KEY *key, unsigned char *ivec);
+			       size_t length, const struct SM4_KEY *key, unsigned char *ivec);
+
 void sm4_v8_crypt_block(const unsigned char *in, unsigned char *out,
-		const struct SM4_KEY *key);
+			const struct SM4_KEY *key);
 
 int sm4_v8_xts_encrypt(const unsigned char *in, unsigned char *out, size_t length,
-		const struct SM4_KEY *key, unsigned char *ivec,
-		const struct SM4_KEY *key2);
+		       const struct SM4_KEY *key, unsigned char *ivec,
+		       const struct SM4_KEY *key2);
 
 int sm4_v8_xts_decrypt(const unsigned char *in, unsigned char *out, size_t length,
-		const struct SM4_KEY *key, unsigned char *ivec,
-		const struct SM4_KEY *key2);
+		       const struct SM4_KEY *key, unsigned char *ivec,
+		       const struct SM4_KEY *key2);
 
 #ifdef __cplusplus
 }
diff --git a/drv/isa_ce_sm4_armv8.S b/drv/isa_ce_sm4_armv8.S
index 7d844969..6ebf39b3 100644
--- a/drv/isa_ce_sm4_armv8.S
+++ b/drv/isa_ce_sm4_armv8.S
@@ -506,6 +506,139 @@ sm4_v8_cbc_encrypt:
 	ldp	d8,d9,[sp],#16
 	ret
 .size	sm4_v8_cbc_encrypt,.-sm4_v8_cbc_encrypt
+
+.globl	sm4_v8_cbc_cts_encrypt
+.type	sm4_v8_cbc_cts_encrypt,%function
+.align	5
+sm4_v8_cbc_cts_encrypt:
+	AARCH64_VALID_CALL_TARGET
+	ld1	{v0.4s,v1.4s,v2.4s,v3.4s}, [x3], #64
+	ld1	{v4.4s,v5.4s,v6.4s,v7.4s}, [x3]
+	sub	x5, x2, #16
+
+	ld1	{v8.4s}, [x4]
+
+	ld1	{v10.4s}, [x0]
+	eor	v8.16b, v8.16b, v10.16b
+	rev32	v8.16b, v8.16b;
+	sm4e	v8.4s, v0.4s;
+	sm4e	v8.4s, v1.4s;
+	sm4e	v8.4s, v2.4s;
+	sm4e	v8.4s, v3.4s;
+	sm4e	v8.4s, v4.4s;
+	sm4e	v8.4s, v5.4s;
+	sm4e	v8.4s, v6.4s;
+	sm4e	v8.4s, v7.4s;
+	rev64	v8.4s, v8.4s;
+	ext	v8.16b, v8.16b, v8.16b, #8;
+	rev32	v8.16b, v8.16b;
+
+	/* load permute table */
+	adr	x6, .cts_permute_table
+	add	x7, x6, #32
+	add	x6, x6, x5
+	sub	x7, x7, x5
+	ld1	{v13.4s}, [x6]
+	ld1	{v14.4s}, [x7]
+
+	/* overlapping loads */
+	add	x0, x0, x5
+	ld1	{v11.4s}, [x0]
+
+	/* create Cn from En-1 */
+	tbl	v10.16b, {v8.16b}, v13.16b
+	/* padding Pn with zeros */
+	tbl	v11.16b, {v11.16b}, v14.16b
+
+	eor	v11.16b, v11.16b, v8.16b
+	rev32	v11.16b, v11.16b;
+	sm4e	v11.4s, v0.4s;
+	sm4e	v11.4s, v1.4s;
+	sm4e	v11.4s, v2.4s;
+	sm4e	v11.4s, v3.4s;
+	sm4e	v11.4s, v4.4s;
+	sm4e	v11.4s, v5.4s;
+	sm4e	v11.4s, v6.4s;
+	sm4e	v11.4s, v7.4s;
+	rev64	v11.4s, v11.4s;
+	ext	v11.16b, v11.16b, v11.16b, #8;
+	rev32	v11.16b, v11.16b;
+
+	/* overlapping stores */
+	add	x5, x1, x5
+	st1	{v10.16b}, [x5]
+	st1	{v11.16b}, [x1]
+
+	ret
+.size	sm4_v8_cbc_cts_encrypt,.-sm4_v8_cbc_cts_encrypt
+
+.globl	sm4_v8_cbc_cts_decrypt
+.type	sm4_v8_cbc_cts_decrypt,%function
+.align	5
+sm4_v8_cbc_cts_decrypt:
+	AARCH64_VALID_CALL_TARGET
+	ld1	{v0.4s,v1.4s,v2.4s,v3.4s}, [x3], #64
+	ld1	{v4.4s,v5.4s,v6.4s,v7.4s}, [x3]
+
+	sub	x5, x2, #16
+
+	ld1	{v8.4s}, [x4]
+
+	/* load permute table */
+	adr	x6, .cts_permute_table
+	add	x7, x6, #32
+	add	x6, x6, x5
+	sub	x7, x7, x5
+	ld1	{v13.4s}, [x6]
+	ld1	{v14.4s}, [x7]
+
+	/* overlapping loads */
+	ld1	{v10.16b}, [x0], x5
+	ld1	{v11.16b}, [x0]
+
+	rev32	v10.16b, v10.16b;
+	sm4e	v10.4s, v0.4s;
+	sm4e	v10.4s, v1.4s;
+	sm4e	v10.4s, v2.4s;
+	sm4e	v10.4s, v3.4s;
+	sm4e	v10.4s, v4.4s;
+	sm4e	v10.4s, v5.4s;
+	sm4e	v10.4s, v6.4s;
+	sm4e	v10.4s, v7.4s;
+	rev64	v10.4s, v10.4s;
+	ext	v10.16b, v10.16b, v10.16b, #8;
+	rev32	v10.16b, v10.16b;
+
+	/* select the first Ln bytes of Xn to create Pn */
+	tbl	v12.16b, {v10.16b}, v13.16b
+	eor	v12.16b, v12.16b, v11.16b
+
+	/* overwrite the first Ln bytes with Cn to create En-1 */
+	tbx	v10.16b, {v11.16b}, v14.16b
+
+	rev32	v10.16b, v10.16b;
+	sm4e	v10.4s, v0.4s;
+	sm4e	v10.4s, v1.4s;
+	sm4e	v10.4s, v2.4s;
+	sm4e	v10.4s, v3.4s;
+	sm4e	v10.4s, v4.4s;
+	sm4e	v10.4s, v5.4s;
+	sm4e	v10.4s, v6.4s;
+	sm4e	v10.4s, v7.4s;
+	rev64	v10.4s, v10.4s;
+	ext	v10.16b, v10.16b, v10.16b, #8;
+	rev32	v10.16b, v10.16b;
+
+	eor	v10.16b, v10.16b, v8.16b
+
+	/* overlapping stores */
+	add	x5, x1, x5
+	st1	{v12.16b}, [x5]
+	st1	{v10.16b}, [x1]
+
+	ret
+.size	sm4_v8_cbc_cts_decrypt,.-sm4_v8_cbc_cts_decrypt
+
 .globl	sm4_v8_ecb_encrypt
 .type	sm4_v8_ecb_encrypt,%function
 .align	5
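Background for the tail handling added above: a CBC-CS message is split into a body that goes through plain CBC and a ciphertext-stealing tail made of the last partial block plus one full block, so cts_bytes is always between 17 and 32. The sm4_v8_cbc_cts_encrypt/decrypt routines produce and consume that tail in the swapped order (full final block first, truncated block last), which is the CS3 layout; CS1 therefore only needs the last two chunks exchanged, which sm4_cts_cs1_mode_adapt() performs on the output after encryption and on the input before decryption. Block-aligned CS1/CS2 messages and single-block messages fall back to plain CBC via sm4_cts_cbc_instead(). The sketch below restates the split and the CS3-to-CS1 reordering in driver-independent C; it assumes a 16-byte block, and the helper names cts_tail_bytes() and cs3_tail_to_cs1() are illustrative only, not part of the patch.

#include <stdint.h>
#include <string.h>

#define BLK 16	/* SM4 block size in bytes */

/* Sketch: length of the ciphertext-stealing tail, mirroring sm4_cts_cbc_crypt(). */
static uint32_t cts_tail_bytes(uint32_t in_bytes)
{
	uint32_t tail = in_bytes % BLK + BLK;

	/* block-aligned input (CS3 only): steal a whole block, tail is two blocks */
	if (tail == BLK)
		tail += BLK;

	return tail;	/* 17..32 bytes; the leading in_bytes - tail bytes use plain CBC */
}

/*
 * Sketch: turn a CS3-ordered tail (one full block, then a truncated block of
 * tail_bytes % BLK bytes) into CS1 order (truncated block first), the same
 * move sm4_cts_cs1_mode_adapt() makes on the encrypt path. Only reached for
 * non-aligned input, so the truncated length is 1..15.
 */
static void cs3_tail_to_cs1(uint8_t *tail, uint32_t tail_bytes)
{
	uint32_t trunc = tail_bytes % BLK;
	uint8_t full[BLK];

	memcpy(full, tail, BLK);		/* save the full final block */
	memmove(tail, tail + BLK, trunc);	/* move the truncated block to the front */
	memcpy(tail + trunc, full, BLK);	/* append the full block behind it */
}

For a 33-byte message, for example, cts_tail_bytes() returns 17: the first 16 bytes go through sm4_v8_cbc_encrypt() and the last 17 bytes through the CTS routine.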
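The new GEN_CE_ALG_DRIVER entries expose the modes to the framework as "cbc-cs1(sm4)", "cbc-cs2(sm4)" and "cbc-cs3(sm4)". Below is a rough sketch of driving one of them through UADK's wd_cipher session API; it assumes initialization has already been done (e.g. via wd_cipher_init2()) and that the struct wd_cipher_sess_setup and struct wd_cipher_req field names match include/wd_cipher.h, so treat the details as assumptions to verify rather than as part of this patch.

#include <errno.h>
#include "wd_cipher.h"

/* Hypothetical helper: one-shot SM4-CBC-CS3 encryption of len bytes. */
static int sm4_cbc_cs3_encrypt(__u8 *key, __u8 *iv, __u8 *in, __u8 *out, __u32 len)
{
	struct wd_cipher_sess_setup setup = {
		.alg = WD_CIPHER_SM4,
		.mode = WD_CIPHER_CBC_CS3,
	};
	struct wd_cipher_req req = {
		.op_type = WD_CIPHER_ENCRYPTION,
		.src = in,
		.dst = out,
		.in_bytes = len,
		.out_bytes = len,
		.iv = iv,
		.iv_bytes = 16,
	};
	handle_t sess;
	int ret;

	sess = wd_cipher_alloc_sess(&setup);
	if (!sess)
		return -EINVAL;

	ret = wd_cipher_set_key(sess, key, 16);
	if (!ret)
		ret = wd_do_cipher_sync(sess, &req);

	wd_cipher_free_sess(sess);
	return ret;
}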