diff --git a/src/lib.rs b/src/lib.rs index cf47f081..01a08c53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -205,7 +205,7 @@ impl Protocol { self.process(&l_tag, Operation::AuthCrypt); // Check the tag against the counterfactual short tag. - if ct_eq(s_tag, &s_tag_p) == 1 { + if ct_eq(s_tag, s_tag_p) == 1 { // If the tag is verified, then the ciphertext is authentic. Return the slice of the // input which contains the plaintext. Some(in_out) @@ -337,13 +337,18 @@ enum Operation { Chain = 0x07, } -/// A constant-time comparison using CMOV/CSEL instructions. Returns `1` if the two slices are -/// equal, `0` otherwise. +/// A constant-time comparison of two `u8` slices using conditional move instructions. Returns `1` +/// iff the two slices are equal, `0` otherwise. #[inline(never)] // don't inline to avoid getting optimized into vartime -fn ct_eq(a: &[u8], b: &[u8]) -> u8 { - debug_assert_eq!(a.len(), b.len(), "both slices should be the same length"); +pub fn ct_eq(a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> u8 { + // Compare slice lengths in variable time, since there's no other way to do that. + if a.as_ref().len() != b.as_ref().len() { + return 0; + } + + // Iterate through the value pairs, checking for inequality. let mut res = 1; - for (x, y) in a.iter().zip(b.iter()) { + for (x, y) in a.as_ref().iter().zip(b.as_ref().iter()) { x.cmovne(y, 0, &mut res); } res