forked from 360trev/ME7Sum
-
Notifications
You must be signed in to change notification settings - Fork 19
/
rsa.c
316 lines (257 loc) · 8.94 KB
/
rsa.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
/**********************************************************************
* *
* Created by Adam Brockett *
* *
* Copyright (c) 2010 *
* *
* Redistribution and use in source and binary forms, with or without *
* modification is allowed. *
* *
* But if you let me know you're using my code, that would be freaking*
* sweet. *
* *
**********************************************************************/
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include "rsa.h"
#define MODULUS_SIZE 1024 /* Number of bits in the RSA modulus n */
#define BLOCK_SIZE (MODULUS_SIZE/8) /* Bytes per en/decrypted block: one modulus-sized integer */
#define BUFFER_SIZE ((MODULUS_SIZE/8) / 2) /* Bytes in each prime p and q (each is half the size of n) */
/*
 * Set p to a random BUFFER_SIZE-byte starting value and advance it to the
 * next (probable) prime with mpz_nextprime.
 *
 * The top two bits of the starting value are forced to 1, so p >= 0.75 *
 * 2^(8*BUFFER_SIZE). The product of two such primes is therefore at least
 * 2*8*BUFFER_SIZE (= MODULUS_SIZE) bits long. The lowest bit is forced to 1
 * so the search starts from an odd number.
 *
 * NOTE(review): rand() is not a cryptographically secure generator; keys
 * derived from it are predictable and must not protect real secrets.
 */
static void generate_random_p(mpz_t p)
{
    /* unsigned char: raw bytes 0x00..0xFF without implementation-defined
     * conversion into a (possibly signed) char. */
    unsigned char buf[BUFFER_SIZE];
    int i;

    /* Fill with random bytes over the full 0x00..0xFF range
     * (the previous "% 0xFF" could never produce 0xFF). */
    for (i = 0; i < BUFFER_SIZE; i++)
        buf[i] = (unsigned char)(rand() & 0xFF);

    /* Set the top two bits so the number (and the final modulus) is large. */
    buf[0] |= 0xC0;
    /* Set the bottom bit so the search for a prime starts from an odd number. */
    buf[BUFFER_SIZE - 1] |= 0x01;

    /* Interpret the buffer as a big-endian integer. */
    mpz_import(p, BUFFER_SIZE, 1, sizeof(buf[0]), 0, 0, buf);

    /* Move to the next prime after that random number. */
    mpz_nextprime(p, p);
}
static void generate_primes(private_key *ku)
{
mpz_t tmp;
mpz_init(tmp);
srand(time(NULL));
/* Select p and q */
/* Start with p */
generate_random_p(ku->p);
/* Make sure this is a good choice*/
/* If p mod e == 1, gcd(phi, e) != 1 */
mpz_mod(tmp, ku->p, ku->e);
while(!mpz_cmp_ui(tmp, 1))
{
/* Nope. Choose the next prime */
mpz_nextprime(ku->p, ku->p);
mpz_mod(tmp, ku->p, ku->e);
}
/* Now select q, where q!=p */
do {
generate_random_p(ku->q);
/* Make sure this is a good choice*/
/* If p mod e == 1, gcd(phi, e) != 1 */
mpz_mod(tmp, ku->q, ku->e);
while(!mpz_cmp_ui(tmp, 1))
{
/* Nope. Choose the next prime */
mpz_nextprime(ku->q, ku->q);
mpz_mod(tmp, ku->q, ku->e);
}
} while(mpz_cmp(ku->p, ku->q) == 0); /* If we have identical primes (unlikely), try again */
mpz_clear(tmp);
}
/*
 * Generate a fresh RSA key pair with public exponent e = 3.
 *
 * Instead of choosing e such that gcd(phi, e) = 1 with 1 < e < phi, e is fixed
 * first. The primes p and q are then picked so that gcd(e, p-1) =
 * gcd(e, q-1) = 1 (see generate_primes).
 *
 * On success:
 *   ku: e, p, q, n = p*q and d = e^-1 mod phi(n) are all set.
 *   kp: e and n are copied from ku.
 *
 * NOTE: Assumes every mpz_t in ku and kp has already been mpz_init'ed.
 * NOTE(review): e = 3 is only safe when combined with strong, unpredictable
 * padding; small public exponents are vulnerable to several known attacks.
 *
 * Returns 0 on success. Returns -1 if e has no inverse modulo phi, after
 * printing gcd(e, phi) to stdout; kp is left unmodified in that case.
 */
int generate_keys(private_key* ku, public_key* kp)
{
    int ret = -1;
    mpz_t phi;
    mpz_t tmp1;
    mpz_t tmp2;

    mpz_init(phi);
    mpz_init(tmp1);
    mpz_init(tmp2);

    /* Commonly suggested exponents are 3, 17 or 65537, because they make the
     * modular exponentiations cheap. This code uses 3. */
    mpz_set_ui(ku->e, 3);
    generate_primes(ku);

    /* n = p * q */
    mpz_mul(ku->n, ku->p, ku->q);

    /* phi(n) = (p-1)(q-1) */
    mpz_sub_ui(tmp1, ku->p, 1);
    mpz_sub_ui(tmp2, ku->q, 1);
    mpz_mul(phi, tmp1, tmp2);

    /* d = multiplicative inverse of e mod phi */
    if (mpz_invert(ku->d, ku->e, phi) == 0) {
        mpz_gcd(tmp1, ku->e, phi);
        /* mpz_out_str writes directly to the stream; the previous
         * mpz_get_str(NULL, ...) allocated a string that was never freed. */
        printf("gcd(e, phi) = [");
        mpz_out_str(stdout, 16, tmp1);
        printf("]\n");
        printf("Invert failed\n");
        goto out;
    }

    /* Publish the public half of the key. */
    mpz_set(kp->e, ku->e);
    mpz_set(kp->n, ku->n);
    ret = 0;

out:
    mpz_clear(tmp2);
    mpz_clear(tmp1);
    mpz_clear(phi);
    return ret;
}
/*
 * Raw RSA public-key operation on one integer block.
 * Computes C = M^e mod n using kp.e and kp.n. No padding is applied here.
 */
void block_encrypt(mpz_t C, mpz_t M, public_key kp)
{
    mpz_powm(C, M, kp.e, kp.n); /* C = M^e mod n */
}
/*
 * Encrypt `length` bytes of `message` with the public key kp. Each chunk is
 * padded with PKCS#1 v1.5 block type 02 padding.
 *
 * Each encoded message block has the form
 *     EMB = 00 || 02 || PS || 00 || D
 * where:
 *   - || is concatenation;
 *   - D is at most BLOCK_SIZE - 11 bytes of the message;
 *   - PS is (BLOCK_SIZE - |D| - 3) non-zero random bytes, so PS is always at
 *     least 8 bytes long.
 *
 * Each ciphertext block is written big-endian and left-padded with zero bytes
 * to exactly BLOCK_SIZE bytes. `cipher` must therefore have room for
 * ceil(length / (BLOCK_SIZE - 11)) * BLOCK_SIZE bytes.
 *
 * Returns:
 *   - the number of ciphertext bytes written (a multiple of BLOCK_SIZE);
 *   - 0 if length <= 0;
 *   - -1 if an encrypted block does not fit in BLOCK_SIZE bytes. That can only
 *     happen when kp.n is wider than MODULUS_SIZE bits.
 *
 * NOTE(review): the padding bytes come from rand(), which is not a
 * cryptographically secure generator.
 */
int rsa_encrypt(char *cipher, const char *message, int length, public_key kp)
{
    int block_count = 0;
    int prog = length;
    int result;
    unsigned char mess_block[BLOCK_SIZE];
    mpz_t m;
    mpz_t c;

    mpz_init(m);
    mpz_init(c);

    while (prog > 0) {
        int i = 0;
        int d_len = (prog >= (BLOCK_SIZE - 11)) ? BLOCK_SIZE - 11 : prog;
        char *out = cipher + (size_t)block_count * BLOCK_SIZE;
        size_t c_bytes;

        /* Construct the header: 00 || 02 || PS || 00 */
        mess_block[i++] = 0x00;
        mess_block[i++] = 0x02;
        while (i < (BLOCK_SIZE - d_len - 1))
            mess_block[i++] = (unsigned char)((rand() % 0xFF) + 1); /* 1..255, never 0 */
        mess_block[i++] = 0x00;

        /* Copy in the message chunk. */
        memcpy(mess_block + i, message + (length - prog), (size_t)d_len);

        /* Convert the byte stream to an integer and encrypt it. */
        mpz_import(m, BLOCK_SIZE, 1, sizeof(mess_block[0]), 0, 0, mess_block);
        block_encrypt(c, m, kp);

        /* Byte length of c as mpz_export will write it (see the GMP manual). */
        c_bytes = (mpz_sizeinbase(c, 2) + 8 - 1) / 8;
        if (c_bytes > (size_t)BLOCK_SIZE) {
            /* Would need a negative offset, i.e. an out-of-bounds write. */
            result = -1;
            goto out;
        }

        /* Zero the whole block first so that a short c is left-padded with
         * zeros rather than with whatever the caller's buffer contained. */
        memset(out, 0, BLOCK_SIZE);
        mpz_export(out + (BLOCK_SIZE - c_bytes), NULL, 1, sizeof(char), 0, 0, c);

        block_count++;
        prog -= d_len;
    }
    result = block_count * BLOCK_SIZE;

out:
    mpz_clear(m);
    mpz_clear(c);
    return result;
}
/*
 * Raw RSA private-key operation on one integer block.
 * Computes M = C^d mod n using ku.d and ku.n. Padding is not removed here.
 */
void block_decrypt(mpz_t M, mpz_t C, private_key ku)
{
    mpz_powm(M, C, ku.d, ku.n); /* M = C^d mod n */
}
/*
 * Decrypt `length` bytes of `cipher` with the private key ku and strip the
 * PKCS#1 v1.5 padding from each block.
 *
 * Input and output:
 *   - `cipher` is processed in whole BLOCK_SIZE chunks; any trailing partial
 *     block is ignored.
 *   - Message bytes are appended to `message`, which must have room for up to
 *     (length / BLOCK_SIZE) * (BLOCK_SIZE - 3) bytes.
 *
 * Padding handling:
 *   - Each decrypted block is expected to look like 00 || xx || PS || 00 || D.
 *   - The two header bytes are skipped without being validated.
 *   - The payload D is everything after the first zero byte at index >= 2.
 *
 * Returns the total number of message bytes written. Returns -1 in either of
 * these cases:
 *   - a decrypted block is wider than BLOCK_SIZE bytes;
 *   - a block has no zero separator after the header.
 */
int rsa_decrypt(char* message, const char* cipher, int length, private_key ku)
{
    int msg_idx = 0;
    int result;
    unsigned char buf[BLOCK_SIZE];
    int i;
    mpz_t c;
    mpz_t m;

    mpz_init(c);
    mpz_init(m);

    for (i = 0; i < (length / BLOCK_SIZE); i++) {
        size_t m_bytes;
        int j;
        int d_len;

        /* Pull the block into an mpz_t and decrypt it. */
        mpz_import(c, BLOCK_SIZE, 1, sizeof(char), 0, 0, cipher + (size_t)i * BLOCK_SIZE);
        block_decrypt(m, c, ku);

        /* Byte length of m as mpz_export will write it (see the GMP manual). */
        m_bytes = (mpz_sizeinbase(m, 2) + 8 - 1) / 8;
        if (m_bytes > (size_t)BLOCK_SIZE) {
            /* Would need a negative offset, i.e. an out-of-bounds write. */
            result = -1;
            goto out;
        }

        /* Re-zero for every block so a short m is left-padded with zeros
         * and never with bytes left over from the previous block. */
        memset(buf, 0, sizeof buf);
        mpz_export(buf + (BLOCK_SIZE - m_bytes), NULL, 1, sizeof(char), 0, 0, m);

        /* Skip the two header bytes, then scan for the 00 separator. The bound
         * is checked BEFORE reading buf[j] to avoid reading past the end. */
        for (j = 2; j < BLOCK_SIZE && buf[j] != 0; j++)
            ;
        if (j >= BLOCK_SIZE) {
            /* No separator: malformed block. Previously this led to
             * memcpy(..., -1), a huge out-of-bounds copy. */
            result = -1;
            goto out;
        }
        j++; /* Skip the 00 byte. */

        /* Copy the message part of the plaintext to the output. */
        d_len = BLOCK_SIZE - j;
        memcpy(message + msg_idx, buf + j, (size_t)d_len);
        msg_idx += d_len;
    }
    result = msg_idx;

out:
    mpz_clear(c);
    mpz_clear(m);
    return result;
}
#ifdef TEST
/* Print "<label> is [<hex value>]\n" without allocating a temporary string. */
static void print_hex(const char *label, const mpz_t v)
{
    printf("%s is [", label);
    mpz_out_str(stdout, 16, v);
    printf("]\n");
}

/*
 * Self-test:
 *   1. Generate a key pair and print both halves.
 *   2. Encrypt and then decrypt one random block with the raw block functions.
 *   3. Report whether the round trip reproduced the original value.
 *
 * Returns 0 on success and 1 if key generation or the round trip fails.
 */
int main()
{
    int i;
    int rc;
    mpz_t M;
    mpz_t C;
    mpz_t DC;
    private_key ku;
    public_key kp;
    unsigned char buf[BLOCK_SIZE];

    mpz_init(M);
    mpz_init(C);
    mpz_init(DC);
    // Initialize public key
    mpz_init(kp.n);
    mpz_init(kp.e);
    // Initialize private key
    mpz_init(ku.n);
    mpz_init(ku.e);
    mpz_init(ku.d);
    mpz_init(ku.p);
    mpz_init(ku.q);

    if (generate_keys(&ku, &kp) != 0) {
        fprintf(stderr, "key generation failed\n");
        rc = 1;
        goto out;
    }

    /* kp is the public key and ku the private key (headers were swapped). */
    printf("---------------Public Key------------------\n");
    print_hex("kp.n", kp.n);
    print_hex("kp.e", kp.e);
    printf("---------------Private Key-----------------\n");
    print_hex("ku.n", ku.n);
    print_hex("ku.e", ku.e);
    print_hex("ku.d", ku.d);
    print_hex("ku.p", ku.p);
    print_hex("ku.q", ku.q);

    for (i = 0; i < BLOCK_SIZE; i++)
        buf[i] = (unsigned char)(rand() & 0xFF);
    /* Force the top byte to zero so that M < 2^(MODULUS_SIZE-8) < n. Raw RSA
     * only round-trips values smaller than the modulus. */
    buf[0] = 0;
    mpz_import(M, BLOCK_SIZE, 1, sizeof(buf[0]), 0, 0, buf);
    print_hex("original", M);

    block_encrypt(C, M, kp);
    print_hex("encrypted", C);

    block_decrypt(DC, C, ku);
    print_hex("decrypted", DC);

    rc = (mpz_cmp(M, DC) == 0) ? 0 : 1;
    printf("round trip %s\n", rc == 0 ? "OK" : "FAILED");

out:
    mpz_clear(M);
    mpz_clear(C);
    mpz_clear(DC);
    mpz_clear(kp.n);
    mpz_clear(kp.e);
    mpz_clear(ku.n);
    mpz_clear(ku.e);
    mpz_clear(ku.d);
    mpz_clear(ku.p);
    mpz_clear(ku.q);
    return rc;
}
#endif