Skip to content

Commit

Permalink
uadk/cipher: isa_ce - support SM4 cbc_cts mode
Browse files Browse the repository at this point in the history
This patch implements the SM4 CBC_CTS modes using CE instructions.

Signed-off-by: Yang Shen <[email protected]>
Signed-off-by: Qi Tao <[email protected]>
  • Loading branch information
Yang Shen authored and Liulongfang committed Apr 3, 2024
1 parent 6e66b44 commit 8c23969
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 10 deletions.
91 changes: 89 additions & 2 deletions drv/isa_ce_sm4.c
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,82 @@ static void sm4_cbc_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rke
sm4_v8_cbc_encrypt(msg->in, msg->out, msg->in_bytes, rkey_dec, msg->iv, SM4_DECRYPT);
}

/*
 * In some situations the CTS modes can use plain CBC instead, to improve
 * performance.
 *
 * Returns non-zero when the request in @msg can be handled as plain CBC:
 *  - the input is exactly one block long, or
 *  - the input is a whole number of blocks and the mode is not
 *    WD_CIPHER_CBC_CS3 (presumably because CS3 still swaps the last two
 *    blocks for aligned input - see the NIST SP 800-38A addendum).
 * Returns zero otherwise.
 */
static int sm4_cts_cbc_instead(struct wd_cipher_msg *msg)
{
	/* A lone block is handled as plain CBC regardless of the mode. */
	if (msg->in_bytes == SM4_BLOCK_SIZE)
		return true;

	/* CS3 never takes the block-aligned shortcut. */
	if (msg->mode == WD_CIPHER_CBC_CS3)
		return false;

	return (msg->in_bytes % SM4_BLOCK_SIZE) == 0;
}

/*
 * Reorder the trailing CTS region for WD_CIPHER_CBC_CS1 requests.
 *
 * With tail_len = cts_bytes % SM4_BLOCK_SIZE:
 *  - encrypt: rewrites @cts_out in place from [full block | tail_len bytes]
 *    to [tail_len bytes | full block]; @cts_in is not touched.
 *  - decrypt: rewrites @cts_in in place from [tail_len bytes | full block]
 *    to [full block | tail_len bytes]; @cts_out is not touched.
 *
 * Note that the decrypt direction modifies the caller's input buffer.
 *
 * @cts_in:    start of the CTS region in the input buffer.
 * @cts_out:   start of the CTS region in the output buffer.
 * @cts_bytes: length of the CTS region in bytes.
 * @enc:       SM4_ENCRYPT selects the encrypt reordering; any other value
 *             selects the decrypt reordering.
 */
static void sm4_cts_cs1_mode_adapt(__u8 *cts_in, __u8 *cts_out,
				   const __u32 cts_bytes, const int enc)
{
	__u8 full_block[SM4_BLOCK_SIZE] = {0};
	__u32 tail_len = cts_bytes % SM4_BLOCK_SIZE;

	if (enc != SM4_ENCRYPT) {
		/* Save the full block that follows the short tail ... */
		memcpy(full_block, cts_in + tail_len, SM4_BLOCK_SIZE);
		/* ... move the tail behind where the full block will sit ... */
		memcpy(cts_in + SM4_BLOCK_SIZE, cts_in, tail_len);
		/* ... and put the full block first. */
		memcpy(cts_in, full_block, SM4_BLOCK_SIZE);
		return;
	}

	/* Save the leading full block ... */
	memcpy(full_block, cts_out, SM4_BLOCK_SIZE);
	/* ... pull the short tail to the front ... */
	memcpy(cts_out, cts_out + SM4_BLOCK_SIZE, tail_len);
	/* ... and append the full block after it. */
	memcpy(cts_out + tail_len, full_block, SM4_BLOCK_SIZE);
}

/*
 * Run SM4 CBC with ciphertext stealing (CS1/CS2/CS3) from msg->in to msg->out.
 *
 * @msg:      request; in, out, in_bytes, iv and mode are read here.
 * @rkey_enc: round keys, passed unchanged to the sm4_v8_* helpers.
 *            NOTE(review): the decrypt wrapper also passes its key through
 *            this parameter, so despite the name it presumably has to be the
 *            decryption key schedule when @enc is SM4_DECRYPT - confirm
 *            against the caller.
 * @enc:      SM4_ENCRYPT for encryption; any other value takes the decrypt
 *            path (callers in this file pass SM4_DECRYPT).
 *
 * Requests accepted by sm4_cts_cbc_instead() are processed as plain CBC.
 * Otherwise the last cts_bytes of the input (one full block plus the
 * remainder, or two full blocks when the length is block aligned) go through
 * the CTS assembly helpers, and everything before them through plain CBC
 * using the same msg->iv buffer.
 *
 * For WD_CIPHER_CBC_CS1 the CTS region is reordered by
 * sm4_cts_cs1_mode_adapt(): in the output after encrypting, and in the
 * INPUT buffer (in place) before decrypting.
 *
 * NOTE(review): assumes msg->in_bytes >= SM4_BLOCK_SIZE; a shorter request
 * would make "in_bytes -= cts_bytes" wrap around. Presumably validated by
 * the caller - confirm.
 */
static void sm4_cts_cbc_crypt(struct wd_cipher_msg *msg,
			      const struct SM4_KEY *rkey_enc, const int enc)
{
	enum wd_cipher_mode mode = msg->mode;
	__u32 in_bytes = msg->in_bytes;
	__u8 *cts_in, *cts_out;
	__u32 cts_bytes;

	if (sm4_cts_cbc_instead(msg)) {
		/*
		 * Call, then return: ISO C does not allow "return expr;" in a
		 * function returning void, even when expr itself is void.
		 */
		sm4_v8_cbc_encrypt(msg->in, msg->out, in_bytes, rkey_enc, msg->iv, enc);
		return;
	}

	/* The CTS region is the partial block plus the full block before it ... */
	cts_bytes = in_bytes % SM4_BLOCK_SIZE + SM4_BLOCK_SIZE;
	/* ... or the last two full blocks when the input is block aligned. */
	if (cts_bytes == SM4_BLOCK_SIZE)
		cts_bytes += SM4_BLOCK_SIZE;

	/* Everything in front of the CTS region is ordinary CBC. */
	in_bytes -= cts_bytes;
	if (in_bytes)
		sm4_v8_cbc_encrypt(msg->in, msg->out, in_bytes, rkey_enc, msg->iv, enc);

	cts_in = msg->in + in_bytes;
	cts_out = msg->out + in_bytes;

	if (enc == SM4_ENCRYPT) {
		sm4_v8_cbc_cts_encrypt(cts_in, cts_out, cts_bytes, rkey_enc, msg->iv);

		/* Convert the helper's output layout to the CS1 layout. */
		if (mode == WD_CIPHER_CBC_CS1)
			sm4_cts_cs1_mode_adapt(cts_in, cts_out, cts_bytes, enc);
	} else {
		/* Convert CS1 input to the layout the helper expects (edits msg->in). */
		if (mode == WD_CIPHER_CBC_CS1)
			sm4_cts_cs1_mode_adapt(cts_in, cts_out, cts_bytes, enc);

		sm4_v8_cbc_cts_decrypt(cts_in, cts_out, cts_bytes, rkey_enc, msg->iv);
	}
}

/*
 * Encrypt entry point for the CBC-CTS modes: forwards @msg and @rkey_enc to
 * sm4_cts_cbc_crypt() with the direction fixed to SM4_ENCRYPT.
 */
static void sm4_cbc_cts_encrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
{
sm4_cts_cbc_crypt(msg, rkey_enc, SM4_ENCRYPT);
}

/*
 * Decrypt entry point for the CBC-CTS modes: forwards @msg and @rkey_enc to
 * sm4_cts_cbc_crypt() with the direction fixed to SM4_DECRYPT.
 * NOTE(review): the parameter is named rkey_enc, but the decrypt path
 * presumably needs the decryption key schedule - confirm against the caller.
 */
static void sm4_cbc_cts_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
{
sm4_cts_cbc_crypt(msg, rkey_enc, SM4_DECRYPT);
}

static void sm4_ecb_encrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rkey_enc)
{
sm4_v8_ecb_encrypt(msg->in, msg->out, msg->in_bytes, rkey_enc, SM4_ENCRYPT);
Expand All @@ -138,12 +214,12 @@ static void sm4_ecb_decrypt(struct wd_cipher_msg *msg, const struct SM4_KEY *rke
sm4_v8_ecb_encrypt(msg->in, msg->out, msg->in_bytes, rkey_dec, SM4_DECRYPT);
}

void sm4_set_encrypt_key(const __u8 *userKey, struct SM4_KEY *key)
/*
 * File-local wrapper: passes @userKey and @key straight to
 * sm4_v8_set_encrypt_key() (presumably fills *key with the encryption
 * round keys - confirm against the assembly implementation).
 */
static void sm4_set_encrypt_key(const __u8 *userKey, struct SM4_KEY *key)
{
sm4_v8_set_encrypt_key(userKey, key);
}

void sm4_set_decrypt_key(const __u8 *userKey, struct SM4_KEY *key)
/*
 * File-local wrapper: passes @userKey and @key straight to
 * sm4_v8_set_decrypt_key() (presumably fills *key with the decryption
 * round keys - confirm against the assembly implementation).
 */
static void sm4_set_decrypt_key(const __u8 *userKey, struct SM4_KEY *key)
{
sm4_v8_set_decrypt_key(userKey, key);
}
Expand Down Expand Up @@ -276,6 +352,14 @@ static int isa_ce_cipher_send(struct wd_alg_driver *drv, handle_t ctx, void *wd_
else
sm4_cbc_decrypt(msg, &rkey);
break;
case WD_CIPHER_CBC_CS1:
case WD_CIPHER_CBC_CS2:
case WD_CIPHER_CBC_CS3:
if (msg->op_type == WD_CIPHER_ENCRYPTION)
sm4_cbc_cts_encrypt(msg, &rkey);
else
sm4_cbc_cts_decrypt(msg, &rkey);
break;
case WD_CIPHER_CTR:
sm4_ctr_encrypt(msg, &rkey);
break;
Expand Down Expand Up @@ -330,6 +414,9 @@ static int cipher_recv(struct wd_alg_driver *drv, handle_t ctx, void *msg)

static struct wd_alg_driver cipher_alg_driver[] = {
GEN_CE_ALG_DRIVER("cbc(sm4)", cipher),
GEN_CE_ALG_DRIVER("cbc-cs1(sm4)", cipher),
GEN_CE_ALG_DRIVER("cbc-cs2(sm4)", cipher),
GEN_CE_ALG_DRIVER("cbc-cs3(sm4)", cipher),
GEN_CE_ALG_DRIVER("ctr(sm4)", cipher),
GEN_CE_ALG_DRIVER("cfb(sm4)", cipher),
GEN_CE_ALG_DRIVER("xts(sm4)", cipher),
Expand Down
24 changes: 16 additions & 8 deletions drv/isa_ce_sm4.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,35 @@ struct sm4_ce_drv_ctx {

void sm4_v8_set_encrypt_key(const unsigned char *userKey, struct SM4_KEY *key);
void sm4_v8_set_decrypt_key(const unsigned char *userKey, struct SM4_KEY *key);

void sm4_v8_cbc_encrypt(const unsigned char *in, unsigned char *out,
size_t length, const struct SM4_KEY *key,
unsigned char *ivec, const int enc);
void sm4_v8_cbc_cts_encrypt(const unsigned char *in, unsigned char *out,
size_t len, const void *key, const unsigned char ivec[16]);
void sm4_v8_cbc_cts_decrypt(const unsigned char *in, unsigned char *out,
size_t len, const void *key, const unsigned char ivec[16]);

void sm4_v8_ecb_encrypt(const unsigned char *in, unsigned char *out,
size_t length, const struct SM4_KEY *key, const int enc);

void sm4_v8_ctr32_encrypt_blocks(const unsigned char *in, unsigned char *out,
size_t len, const void *key, const unsigned char ivec[16]);
size_t len, const void *key, const unsigned char ivec[16]);

void sm4_v8_cfb_encrypt_blocks(const unsigned char *in, unsigned char *out,
size_t length, const struct SM4_KEY *key, unsigned char *ivec);
size_t length, const struct SM4_KEY *key, unsigned char *ivec);
void sm4_v8_cfb_decrypt_blocks(const unsigned char *in, unsigned char *out,
size_t length, const struct SM4_KEY *key, unsigned char *ivec);
size_t length, const struct SM4_KEY *key, unsigned char *ivec);

void sm4_v8_crypt_block(const unsigned char *in, unsigned char *out,
const struct SM4_KEY *key);
const struct SM4_KEY *key);

int sm4_v8_xts_encrypt(const unsigned char *in, unsigned char *out, size_t length,
const struct SM4_KEY *key, unsigned char *ivec,
const struct SM4_KEY *key2);
const struct SM4_KEY *key, unsigned char *ivec,
const struct SM4_KEY *key2);
int sm4_v8_xts_decrypt(const unsigned char *in, unsigned char *out, size_t length,
const struct SM4_KEY *key, unsigned char *ivec,
const struct SM4_KEY *key2);
const struct SM4_KEY *key, unsigned char *ivec,
const struct SM4_KEY *key2);

#ifdef __cplusplus
}
Expand Down
133 changes: 133 additions & 0 deletions drv/isa_ce_sm4_armv8.S
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,139 @@ sm4_v8_cbc_encrypt:
ldp d8,d9,[sp],#16
ret
.size sm4_v8_cbc_encrypt,.-sm4_v8_cbc_encrypt

/*
 * void sm4_v8_cbc_cts_encrypt(const unsigned char *in, unsigned char *out,
 *                             size_t len, const void *key,
 *                             const unsigned char ivec[16]);
 *
 * Encrypts the trailing CBC ciphertext-stealing region.
 *   x0 = in, x1 = out, x2 = len, x3 = key (round keys), x4 = ivec
 * The only caller visible in this change passes len in the range 17..32.
 * The buffer at ivec is only read.
 *
 * Output layout: a full ciphertext block at out[0..15], followed by the
 * leading (len - 16) bytes of the first block's CBC ciphertext.
 *
 * Only caller-saved SIMD registers are used as scratch (v0-v7, v16-v21):
 * AAPCS64 requires a callee to preserve the low 64 bits of v8-v15, and this
 * routine does not set up a stack frame to save them.
 */
.globl	sm4_v8_cbc_cts_encrypt
.type	sm4_v8_cbc_cts_encrypt,%function
.align	5
sm4_v8_cbc_cts_encrypt:
	AARCH64_VALID_CALL_TARGET
	/* v0-v7 = the 32 round keys (8 vectors x 4 words) */
	ld1	{v0.4s,v1.4s,v2.4s,v3.4s}, [x3], #64
	ld1	{v4.4s,v5.4s,v6.4s,v7.4s}, [x3]
	/* x5 = len - 16 = number of bytes in the final (possibly short) block */
	sub	x5, x2, #16

	/* v16 = IV */
	ld1	{v16.4s}, [x4]

	/* first block: v16 = SM4(IV ^ in[0..15]); each sm4e runs 4 rounds */
	ld1	{v17.4s}, [x0]
	eor	v16.16b, v16.16b, v17.16b
	rev32	v16.16b, v16.16b
	sm4e	v16.4s, v0.4s
	sm4e	v16.4s, v1.4s
	sm4e	v16.4s, v2.4s
	sm4e	v16.4s, v3.4s
	sm4e	v16.4s, v4.4s
	sm4e	v16.4s, v5.4s
	sm4e	v16.4s, v6.4s
	sm4e	v16.4s, v7.4s
	/* reverse the word order, then restore the byte order */
	rev64	v16.4s, v16.4s
	ext	v16.16b, v16.16b, v16.16b, #8
	rev32	v16.16b, v16.16b

	/*
	 * load permute table: two 16-byte index vectors taken at
	 * table + (len - 16) and table + 32 - (len - 16)
	 * (.cts_permute_table is defined elsewhere in this file)
	 */
	adr	x6, .cts_permute_table
	add	x7, x6, #32
	add	x6, x6, x5
	sub	x7, x7, x5
	ld1	{v20.4s}, [x6]
	ld1	{v21.4s}, [x7]

	/* overlapping loads: v18 = the last 16 bytes of the input */
	add	x0, x0, x5
	ld1	{v18.4s}, [x0]

	/* create Cn from En-1 */
	tbl	v17.16b, {v16.16b}, v20.16b
	/* padding Pn with zeros */
	tbl	v18.16b, {v18.16b}, v21.16b

	/* second block: v18 = SM4(padded Pn ^ En-1) */
	eor	v18.16b, v18.16b, v16.16b
	rev32	v18.16b, v18.16b
	sm4e	v18.4s, v0.4s
	sm4e	v18.4s, v1.4s
	sm4e	v18.4s, v2.4s
	sm4e	v18.4s, v3.4s
	sm4e	v18.4s, v4.4s
	sm4e	v18.4s, v5.4s
	sm4e	v18.4s, v6.4s
	sm4e	v18.4s, v7.4s
	rev64	v18.4s, v18.4s
	ext	v18.16b, v18.16b, v18.16b, #8
	rev32	v18.16b, v18.16b

	/*
	 * overlapping stores: the short block goes to out + (len - 16) first,
	 * then the full block is written at out and overwrites the overlap
	 */
	add	x5, x1, x5
	st1	{v17.16b}, [x5]
	st1	{v18.16b}, [x1]

	ret
.size	sm4_v8_cbc_cts_encrypt,.-sm4_v8_cbc_cts_encrypt

/*
 * void sm4_v8_cbc_cts_decrypt(const unsigned char *in, unsigned char *out,
 *                             size_t len, const void *key,
 *                             const unsigned char ivec[16]);
 *
 * Decrypts the trailing CBC ciphertext-stealing region.
 *   x0 = in, x1 = out, x2 = len, x3 = key (round keys), x4 = ivec
 * The only caller visible in this change passes len in the range 17..32.
 * The buffer at ivec is only read.
 *
 * Input layout: a full ciphertext block at in[0..15], followed by
 * (len - 16) further ciphertext bytes.
 *
 * Only caller-saved SIMD registers are used as scratch (v0-v7, v16-v21):
 * AAPCS64 requires a callee to preserve the low 64 bits of v8-v15, and this
 * routine does not set up a stack frame to save them.
 */
.globl	sm4_v8_cbc_cts_decrypt
.type	sm4_v8_cbc_cts_decrypt,%function
.align	5
sm4_v8_cbc_cts_decrypt:
	AARCH64_VALID_CALL_TARGET
	/* v0-v7 = the 32 round keys (8 vectors x 4 words) */
	ld1	{v0.4s,v1.4s,v2.4s,v3.4s}, [x3], #64
	ld1	{v4.4s,v5.4s,v6.4s,v7.4s}, [x3]

	/* x5 = len - 16 = number of bytes in the final (possibly short) block */
	sub	x5, x2, #16

	/* v16 = IV */
	ld1	{v16.4s}, [x4]

	/*
	 * load permute table: two 16-byte index vectors taken at
	 * table + (len - 16) and table + 32 - (len - 16)
	 * (.cts_permute_table is defined elsewhere in this file)
	 */
	adr	x6, .cts_permute_table
	add	x7, x6, #32
	add	x6, x6, x5
	sub	x7, x7, x5
	ld1	{v20.4s}, [x6]
	ld1	{v21.4s}, [x7]

	/* overlapping loads: v17 = in[0..15], v18 = the last 16 bytes of input */
	ld1	{v17.16b}, [x0], x5
	ld1	{v18.16b}, [x0]

	/* run the cipher over the first block; each sm4e runs 4 rounds */
	rev32	v17.16b, v17.16b
	sm4e	v17.4s, v0.4s
	sm4e	v17.4s, v1.4s
	sm4e	v17.4s, v2.4s
	sm4e	v17.4s, v3.4s
	sm4e	v17.4s, v4.4s
	sm4e	v17.4s, v5.4s
	sm4e	v17.4s, v6.4s
	sm4e	v17.4s, v7.4s
	/* reverse the word order, then restore the byte order */
	rev64	v17.4s, v17.4s
	ext	v17.16b, v17.16b, v17.16b, #8
	rev32	v17.16b, v17.16b

	/* select the first Ln bytes of Xn to create Pn */
	tbl	v19.16b, {v17.16b}, v20.16b
	eor	v19.16b, v19.16b, v18.16b

	/* overwrite the first Ln bytes with Cn to create En-1 */
	tbx	v17.16b, {v18.16b}, v21.16b

	/* run the cipher over the rebuilt block */
	rev32	v17.16b, v17.16b
	sm4e	v17.4s, v0.4s
	sm4e	v17.4s, v1.4s
	sm4e	v17.4s, v2.4s
	sm4e	v17.4s, v3.4s
	sm4e	v17.4s, v4.4s
	sm4e	v17.4s, v5.4s
	sm4e	v17.4s, v6.4s
	sm4e	v17.4s, v7.4s
	rev64	v17.4s, v17.4s
	ext	v17.16b, v17.16b, v17.16b, #8
	rev32	v17.16b, v17.16b

	/* CBC chaining: xor with the IV to obtain the first plaintext block */
	eor	v17.16b, v17.16b, v16.16b

	/*
	 * overlapping stores: the short block goes to out + (len - 16) first,
	 * then the full block is written at out and overwrites the overlap
	 */
	add	x5, x1, x5
	st1	{v19.16b}, [x5]
	st1	{v17.16b}, [x1]

	ret
.size	sm4_v8_cbc_cts_decrypt,.-sm4_v8_cbc_cts_decrypt

.globl sm4_v8_ecb_encrypt
.type sm4_v8_ecb_encrypt,%function
.align 5
Expand Down

0 comments on commit 8c23969

Please sign in to comment.