From 4d5457eebb9e6a7f2ee56a25fc2528d88f2bb4c3 Mon Sep 17 00:00:00 2001
From: Alexander 'z33ky' Hirsch <1zeeky@gmail.com>
Date: Mon, 25 Dec 2023 18:00:52 +0100
Subject: [PATCH] Fix panic when trying to compare big numbers

The previous version tried to convert numeric strings to u64 to compare
the values numerically. This results in a panic if the number exceeds
the bounds of u64.

To fix this, numbers are now compared digit-by-digit instead.

A number that is followed by a non-digit keeps its numeric ordering when
the other string ends right after its own number, so that e.g. "29a"
still sorts before "30".

Fixes surrealdb/lexicmp#1.
---
 src/cmp.rs | 165 ++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 124 insertions(+), 41 deletions(-)

diff --git a/src/cmp.rs b/src/cmp.rs
index 49d37cb..a7c1dc7 100644
--- a/src/cmp.rs
+++ b/src/cmp.rs
@@ -1,38 +1,110 @@
 use crate::iter::{iterate_lexical, iterate_lexical_only_alnum};
-use core::cmp::Ordering;
-
-macro_rules! cmp_ascii_digits {
-    (first_digits($lhs:ident, $rhs:ident), iterators($iter1:ident, $iter2:ident)) => {
-        let mut n1 = ascii_to_u64($lhs);
-        let mut n2 = ascii_to_u64($rhs);
-        loop {
-            match (
-                $iter1.peek().copied().filter(|c| c.is_ascii_digit()),
-                $iter2.peek().copied().filter(|c| c.is_ascii_digit()),
-            ) {
-                (Some(lhs), Some(rhs)) => {
-                    n1 = n1 * 10 + ascii_to_u64(lhs);
-                    n2 = n2 * 10 + ascii_to_u64(rhs);
-                    let _ = $iter1.next();
-                    let _ = $iter2.next();
-                }
-                (Some(_), None) => return Ordering::Greater,
-                (None, Some(_)) => return Ordering::Less,
-                (None, None) => {
-                    if n1 != n2 {
-                        return n1.cmp(&n2);
+use core::{
+    cmp::Ordering,
+    iter::Peekable,
+};
+
+/// Compares the leading runs of ASCII digits of `lhs` and `rhs` as numbers, digit by digit.
+///
+/// `None` (no digits were compared) and `Some(Ordering::Equal)` (both runs are identical) leave
+/// the characters after the digits in the iterators, and the caller has to continue with them.
+/// Every other result is final.
+fn cmp_ascii_digits(lhs: &mut Peekable<impl Iterator<Item = char>>, rhs: &mut Peekable<impl Iterator<Item = char>>) -> Option<Ordering> {
+    #[derive(PartialEq)]
+    enum Origin {
+        Lhs,
+        Rhs,
+    }
+
+    // The loop below iterates through both iterators at once and handles ascii digits for comparison.
+    // If one iterator runs out of ascii digits, it is stored in this struct together with the
+    // information where it originated from.
+    struct NonDigit {
+        c: char,
+        origin: Origin,
+    }
+
+    impl core::ops::Deref for NonDigit {
+        type Target = char;
+
+        fn deref(&self) -> &Self::Target {
+            &self.c
+        }
+    }
+
+    impl NonDigit {
+        #[allow(dead_code)]
+
+        fn is_lhs(&self) -> bool {
+            self.origin == Origin::Lhs
+        }
+
+        fn is_rhs(&self) -> bool {
+            self.origin == Origin::Rhs
+        }
+    }
+
+    fn ok_if_ascii_digit(c: char) -> Result<char, char> {
+        Some(c).filter(char::is_ascii_digit).ok_or(c)
+    }
+
+    let mut current_cmp = None;
+    loop {
+        match (lhs.peek(), rhs.peek()) {
+            (Some(&a), Some(&b)) => {
+                let non_digit = match (ok_if_ascii_digit(a), ok_if_ascii_digit(b)) {
+                    (Ok(a), Ok(b)) => {
+                        // Only update current_cmp if the current comparison is yet undecided.
+                        // current_cmp is returned later when at least one iterator has hit a non-digit.
+                        if current_cmp.is_none() || current_cmp == Some(Ordering::Equal) {
+                            current_cmp = Some(a.cmp(&b));
+                        }
+                        None
+                    },
+                    (Err(c), Ok(_)) => Some(NonDigit{ c, origin: Origin::Lhs }),
+                    (Ok(_), Err(c)) => Some(NonDigit{ c, origin: Origin::Rhs }),
+                    (Err(_), Err(_)) => break current_cmp,
+                };
+
+                // Advance underlying iterators, since we only peek and break early if no iterator
+                // has any digits left, keeping these characters in the iterators for the caller to
+                // deal with in case current_cmp.is_none() or current_cmp == Some(Ordering::Equal).
+                let _ = lhs.next();
+                let _ = rhs.next();
+
+                // Return the appropriate ordering of a number versus non-digit characters.
+                if let Some(c) = non_digit {
+                    let mut ord = if current_cmp.is_none() && c.is_alphanumeric() {
+                        Ordering::Greater
                     } else {
-                        break;
+                        Ordering::Less
+                    };
+                    if c.is_rhs() {
+                        ord = ord.reverse();
                     }
+                    break Some(ord);
                 }
             }
+            (Some(&c), None) => {
+                // rhs is exhausted. Further digits on lhs make its number longer and thus
+                // greater; a non-digit ends both numbers, so a decided comparison wins.
+                break match current_cmp {
+                    Some(ord) if ord != Ordering::Equal && !c.is_ascii_digit() => Some(ord),
+                    _ => Some(Ordering::Greater),
+                };
+            }
+            (None, Some(&c)) => {
+                // Mirror image of the arm above with lhs exhausted.
+                break match current_cmp {
+                    Some(ord) if ord != Ordering::Equal && !c.is_ascii_digit() => Some(ord),
+                    _ => Some(Ordering::Less),
+                };
+            }
+            (None, None) => {
+                break current_cmp;
+            }
         }
-    };
-}
-
-#[inline]
-fn ascii_to_u64(c: char) -> u64 {
-    (c as u64) - (b'0' as u64)
+    }
 }
 
 #[inline]
@@ -100,11 +172,13 @@ pub fn natural_lexical_cmp(s1: &str, s2: &str) -> Ordering {
     let mut iter2 = iterate_lexical(s2).peekable();
 
     loop {
+        match cmp_ascii_digits(&mut iter1, &mut iter2) {
+            None | Some(Ordering::Equal) => (),
+            Some(result) => return result,
+        }
         match (iter1.next(), iter2.next()) {
             (Some(lhs), Some(rhs)) => {
-                if lhs.is_ascii_digit() && rhs.is_ascii_digit() {
-                    cmp_ascii_digits!(first_digits(lhs, rhs), iterators(iter1, iter2));
-                } else if lhs != rhs {
+                if lhs != rhs {
                     return ret_ordering(lhs, rhs);
                 }
             }
@@ -123,11 +197,13 @@ pub fn natural_lexical_only_alnum_cmp(s1: &str, s2: &str) -> Ordering {
     let mut iter2 = iterate_lexical_only_alnum(s2).peekable();
 
     loop {
+        match cmp_ascii_digits(&mut iter1, &mut iter2) {
+            None | Some(Ordering::Equal) => (),
+            Some(result) => return result,
+        }
         match (iter1.next(), iter2.next()) {
             (Some(lhs), Some(rhs)) => {
-                if lhs.is_ascii_digit() && rhs.is_ascii_digit() {
-                    cmp_ascii_digits!(first_digits(lhs, rhs), iterators(iter1, iter2));
-                } else if lhs != rhs {
+                if lhs != rhs {
                     return lhs.cmp(&rhs);
                 }
             }
@@ -146,11 +222,13 @@ pub fn natural_cmp(s1: &str, s2: &str) -> Ordering {
     let mut iter2 = s2.chars().peekable();
 
     loop {
+        match cmp_ascii_digits(&mut iter1, &mut iter2) {
+            None | Some(Ordering::Equal) => (),
+            Some(result) => return result,
+        }
         match (iter1.next(), iter2.next()) {
             (Some(lhs), Some(rhs)) => {
-                if lhs.is_ascii_digit() && rhs.is_ascii_digit() {
-                    cmp_ascii_digits!(first_digits(lhs, rhs), iterators(iter1, iter2));
-                } else if lhs != rhs {
+                if lhs != rhs {
                     return lhs.cmp(&rhs);
                 }
             }
@@ -169,11 +247,13 @@ pub fn natural_only_alnum_cmp(s1: &str, s2: &str) -> Ordering {
     let mut iter2 = s2.chars().filter(|c| c.is_alphanumeric()).peekable();
 
     loop {
+        match cmp_ascii_digits(&mut iter1, &mut iter2) {
+            None | Some(Ordering::Equal) => (),
+            Some(result) => return result,
+        }
         match (iter1.next(), iter2.next()) {
             (Some(lhs), Some(rhs)) => {
-                if lhs.is_ascii_digit() && rhs.is_ascii_digit() {
-                    cmp_ascii_digits!(first_digits(lhs, rhs), iterators(iter1, iter2));
-                } else if lhs != rhs {
+                if lhs != rhs {
                     return lhs.cmp(&rhs);
                 }
             }
@@ -329,6 +409,9 @@ mod tests {
         ordered("T-27", "Ŧ-5");
         ordered("T-5", "Ŧ-27");
         ordered("T-5", "Ŧ-5");
+
+        ordered("00000000000000000000", "18446744073709551616");
+        ordered("29a", "30");
     }
 
     #[test]