Skip to content

Commit

Permalink
decode: use exact decoded length rather than estimation
Browse files Browse the repository at this point in the history
Fixes: #210
Fixes: #212
  • Loading branch information
mina86 committed Feb 10, 2023
1 parent f766bc6 commit 9e1d415
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 319 deletions.
86 changes: 82 additions & 4 deletions src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::engine::{general_purpose::STANDARD, DecodeEstimate, Engine};
use crate::engine::{general_purpose::STANDARD, Engine};
#[cfg(any(feature = "alloc", feature = "std", test))]
use alloc::vec::Vec;
use core::fmt;
Expand Down Expand Up @@ -130,6 +130,73 @@ pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>(
engine.decode_slice(input, output)
}

/// Returns the decoded size of the `encoded` input assuming the input is valid
/// base64 string.
///
/// Assumes input is a valid base64-encoded string. Result is unspecified if it
/// isn’t.
///
/// If you don’t need a precise length of the decoded string, you can use
/// [`decoded_len_estimate`] function instead. It’s faster and provides an
/// estimate which is only at most two bytes off from the real length.
///
/// # Examples
///
/// ```
/// use base64::decoded_len;
///
/// assert_eq!(0, decoded_len(b""));
/// assert_eq!(1, decoded_len(b"AA"));
/// assert_eq!(2, decoded_len(b"AAA"));
/// assert_eq!(3, decoded_len(b"AAAA"));
/// assert_eq!(1, decoded_len(b"AA=="));
/// assert_eq!(2, decoded_len(b"AAA="));
/// ```
pub fn decoded_len(encoded: impl AsRef<[u8]>) -> usize {
let encoded = encoded.as_ref();
if encoded.len() < 2 {
return 0;
}
let is_pad = |idx| (encoded[encoded.len() - idx] == b'=') as usize;
let len = encoded.len() - is_pad(1) - is_pad(2);
match len % 4 {
0 => len / 4 * 3,
remainder => len / 4 * 3 + remainder - 1,
}
}

#[test]
fn test_decoded_len() {
for chunks in 0..25 {
let mut input = vec![b'A'; chunks * 4 + 4];
assert_eq!(chunks * 3 + 0, decoded_len(&input[..chunks * 4]));
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2]));
assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 3]));
assert_eq!(chunks * 3 + 3, decoded_len(&input[..chunks * 4 + 4]));

input[chunks * 4 + 3] = b'=';
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2]));
assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 3]));
assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 4]));
input[chunks * 4 + 2] = b'=';
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2]));
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 3]));
assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 4]));
}

// Mustn’t panic or overflow if given bogus input.
for len in 1..100 {
let mut input = vec![b'A'; len];
let got = decoded_len(&input);
debug_assert!(got <= len);
for padding in 1..=len.min(10) {
input[len - padding] = b'=';
let got = decoded_len(&input);
debug_assert!(got <= len);
}
}
}

/// Returns a conservative estimate of the decoded size of `encoded_len` base64 symbols (rounded up
/// to the next group of 3 decoded bytes).
///
Expand All @@ -141,6 +208,7 @@ pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>(
/// ```
/// use base64::decoded_len_estimate;
///
/// assert_eq!(0, decoded_len_estimate(0));
/// assert_eq!(3, decoded_len_estimate(1));
/// assert_eq!(3, decoded_len_estimate(2));
/// assert_eq!(3, decoded_len_estimate(3));
Expand All @@ -149,9 +217,19 @@ pub fn decode_engine_slice<E: Engine, T: AsRef<[u8]>>(
/// assert_eq!(6, decoded_len_estimate(5));
/// ```
pub fn decoded_len_estimate(encoded_len: usize) -> usize {
STANDARD
.internal_decoded_len_estimate(encoded_len)
.decoded_len_estimate()
(encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3
}

#[test]
fn test_decode_len_estimate() {
for chunks in 0..250 {
assert_eq!(chunks * 3, decoded_len_estimate(chunks * 4));
assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 1));
assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 2));
assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 3));
}
// Mustn’t panic or overflow.
assert_eq!(usize::MAX / 4 * 3 + 3, decoded_len_estimate(usize::MAX));
}

#[cfg(test)]
Expand Down
25 changes: 8 additions & 17 deletions src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,15 @@ pub(crate) fn encode_with_padding<E: Engine + ?Sized>(
/// input lengths in approximately the top quarter of the range of `usize`.
pub fn encoded_len(bytes_len: usize, padding: bool) -> Option<usize> {
let rem = bytes_len % 3;

let complete_input_chunks = bytes_len / 3;
let complete_chunk_output = complete_input_chunks.checked_mul(4);

if rem > 0 {
if padding {
complete_chunk_output.and_then(|c| c.checked_add(4))
} else {
let encoded_rem = match rem {
1 => 2,
2 => 3,
_ => unreachable!("Impossible remainder"),
};
complete_chunk_output.and_then(|c| c.checked_add(encoded_rem))
}
let chunks = bytes_len / 3 + (rem > 0 && padding) as usize;
let encoded_len = chunks.checked_mul(4)?;
Some(if !padding && rem > 0 {
// This doesn’t overflow. encoded_len is divisible by four thus it’s at
// most usize::MAX - 3. rem ≤ 2 so we’re adding at most three.
encoded_len + rem + 1
} else {
complete_chunk_output
}
encoded_len
})
}

/// Write padding characters.
Expand Down
72 changes: 4 additions & 68 deletions src/engine/general_purpose/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode},
engine::{general_purpose::INVALID_VALUE, DecodePaddingMode},
DecodeError, PAD_BYTE,
};

Expand All @@ -21,30 +21,6 @@ const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
const DECODED_BLOCK_LEN: usize =
CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;

#[doc(hidden)]
pub struct GeneralPurposeEstimate {
/// Total number of decode chunks, including a possibly partial last chunk
num_chunks: usize,
decoded_len_estimate: usize,
}

impl GeneralPurposeEstimate {
pub(crate) fn new(encoded_len: usize) -> Self {
// Formulas that won't overflow
Self {
num_chunks: encoded_len / INPUT_CHUNK_LEN
+ (encoded_len % INPUT_CHUNK_LEN > 0) as usize,
decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3,
}
}
}

impl DecodeEstimate for GeneralPurposeEstimate {
fn decoded_len_estimate(&self) -> usize {
self.decoded_len_estimate
}
}

/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
Expand All @@ -53,12 +29,11 @@ impl DecodeEstimate for GeneralPurposeEstimate {
#[inline]
pub(crate) fn decode_helper(
input: &[u8],
estimate: GeneralPurposeEstimate,
output: &mut [u8],
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
) -> Result<(), DecodeError> {
let remainder_len = input.len() % INPUT_CHUNK_LEN;

// Because the fast decode loop writes in groups of 8 bytes (unrolled to
Expand Down Expand Up @@ -99,7 +74,8 @@ pub(crate) fn decode_helper(
};

// rounded up to include partial chunks
let mut remaining_chunks = estimate.num_chunks;
let mut remaining_chunks =
input.len() / INPUT_CHUNK_LEN + (input.len() % INPUT_CHUNK_LEN > 0) as usize;

let mut input_index = 0;
let mut output_index = 0;
Expand Down Expand Up @@ -340,44 +316,4 @@ mod tests {
decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
}

#[test]
fn estimate_short_lengths() {
for (range, (num_chunks, decoded_len_estimate)) in [
(0..=0, (0, 0)),
(1..=4, (1, 3)),
(5..=8, (1, 6)),
(9..=12, (2, 9)),
(13..=16, (2, 12)),
(17..=20, (3, 15)),
] {
for encoded_len in range {
let estimate = GeneralPurposeEstimate::new(encoded_len);
assert_eq!(num_chunks, estimate.num_chunks);
assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate);
}
}
}

#[test]
fn estimate_via_u128_inflation() {
// cover both ends of usize
(0..1000)
.chain(usize::MAX - 1000..=usize::MAX)
.for_each(|encoded_len| {
// inflate to 128 bit type to be able to safely use the easy formulas
let len_128 = encoded_len as u128;

let estimate = GeneralPurposeEstimate::new(encoded_len);
assert_eq!(
((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128))
as usize,
estimate.num_chunks
);
assert_eq!(
((len_128 + 3) / 4 * 3) as usize,
estimate.decoded_len_estimate
);
})
}
}
10 changes: 6 additions & 4 deletions src/engine/general_purpose/decode_suffix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use crate::{
/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided
/// parameters.
///
/// Returns the total number of bytes decoded, including the ones indicated as already written by
/// `output_index`.
/// Expects output to be large enough to fit decoded data exactly without any
/// unused space. In debug builds panics if final output length (`output_index`
/// plus any bytes written by this function) doesn’t equal length of the output.
pub(crate) fn decode_suffix(
input: &[u8],
input_index: usize,
Expand All @@ -16,7 +17,7 @@ pub(crate) fn decode_suffix(
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
) -> Result<(), DecodeError> {
// Decode any leftovers that aren't a complete input block of 8 bytes.
// Use a u64 as a stack-resident 8 byte buffer.
let mut leftover_bits: u64 = 0;
Expand Down Expand Up @@ -157,5 +158,6 @@ pub(crate) fn decode_suffix(
leftover_bits_appended_to_buf += 8;
}

Ok(output_index)
debug_assert_eq!(output.len(), output_index);
Ok(())
}
14 changes: 1 addition & 13 deletions src/engine/general_purpose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use core::convert::TryInto;

mod decode;
pub(crate) mod decode_suffix;
pub use decode::GeneralPurposeEstimate;

pub(crate) const INVALID_VALUE: u8 = 255;

Expand Down Expand Up @@ -40,7 +39,6 @@ impl GeneralPurpose {

impl super::Engine for GeneralPurpose {
type Config = GeneralPurposeConfig;
type DecodeEstimate = GeneralPurposeEstimate;

fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize {
let mut input_index: usize = 0;
Expand Down Expand Up @@ -161,19 +159,9 @@ impl super::Engine for GeneralPurpose {
output_index
}

fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate {
GeneralPurposeEstimate::new(input_len)
}

fn internal_decode(
&self,
input: &[u8],
output: &mut [u8],
estimate: Self::DecodeEstimate,
) -> Result<usize, DecodeError> {
fn internal_decode(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecodeError> {
decode::decode_helper(
input,
estimate,
output,
&self.decode_table,
self.config.decode_allow_trailing_bits,
Expand Down
Loading

0 comments on commit 9e1d415

Please sign in to comment.