diff --git a/iroh-base/src/base32.rs b/iroh-base/src/base32.rs index 082e539056..a87153a166 100644 --- a/iroh-base/src/base32.rs +++ b/iroh-base/src/base32.rs @@ -1,4 +1,5 @@ pub use data_encoding::{DecodeError, DecodeKind}; +use hex::FromHexError; /// Convert to a base32 string pub fn fmt(bytes: impl AsRef<[u8]>) -> String { @@ -38,3 +39,29 @@ pub fn parse_array(input: &str) -> Result<[u8; N], DecodeError> pub fn parse_vec(input: &str) -> Result, DecodeError> { data_encoding::BASE32_NOPAD.decode(input.to_ascii_uppercase().as_bytes()) } + +/// Error when parsing a hex or base32 string. +#[derive(thiserror::Error, Debug)] +pub enum HexOrBase32ParseError { + /// Error when decoding the base32. + #[error("base32: {0}")] + Base32(#[from] data_encoding::DecodeError), + /// Error when decoding the public key. + #[error("hex: {0}")] + Hex(#[from] FromHexError), +} + +/// Parse a fixed length hex or base32 string into a byte array +/// +/// For fixed length we can know the encoding by the length of the string. +pub fn parse_array_hex_or_base32( + input: &str, +) -> std::result::Result<[u8; LEN], HexOrBase32ParseError> { + let mut bytes = [0u8; LEN]; + if input.len() == LEN * 2 { + hex::decode_to_slice(input, &mut bytes)?; + Ok(bytes) + } else { + Ok(parse_array(input)?) + } +} diff --git a/iroh-base/src/hash.rs b/iroh-base/src/hash.rs index f4bf231e52..423be42722 100644 --- a/iroh-base/src/hash.rs +++ b/iroh-base/src/hash.rs @@ -7,6 +7,8 @@ use bao_tree::blake3; use postcard::experimental::max_size::MaxSize; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use crate::base32::{parse_array_hex_or_base32, HexOrBase32ParseError}; + /// Hash type used throughout. #[derive(PartialEq, Eq, Copy, Clone, Hash)] pub struct Hash(blake3::Hash); @@ -118,32 +120,10 @@ impl fmt::Display for Hash { } impl FromStr for Hash { - type Err = anyhow::Error; + type Err = HexOrBase32ParseError; fn from_str(s: &str) -> Result { - let sb = s.as_bytes(); - if sb.len() == 64 { - // this is most likely a hex encoded hash - // try to decode it as hex - let mut bytes = [0u8; 32]; - if hex::decode_to_slice(sb, &mut bytes).is_ok() { - return Ok(Self::from(bytes)); - } - } - anyhow::ensure!(sb.len() == 52, "invalid base32 length"); - // this is a base32 encoded hash, we can decode it directly - let mut t = [0u8; 52]; - t.copy_from_slice(sb); - // hack since data_encoding doesn't have BASE32LOWER_NOPAD as a const - std::str::from_utf8_mut(t.as_mut()) - .unwrap() - .make_ascii_uppercase(); - // decode the bytes - let mut res = [0u8; 32]; - data_encoding::BASE32_NOPAD - .decode_mut(&t, &mut res) - .map_err(|_e| anyhow::anyhow!("invalid base32"))?; - Ok(Self::from(res)) + parse_array_hex_or_base32(s).map(Hash::from) } } diff --git a/iroh-net/src/key.rs b/iroh-net/src/key.rs index 854b0d3be1..d7744e3a29 100644 --- a/iroh-net/src/key.rs +++ b/iroh-net/src/key.rs @@ -12,7 +12,7 @@ use std::{ pub use ed25519_dalek::{Signature, PUBLIC_KEY_LENGTH}; use ed25519_dalek::{SignatureError, SigningKey, VerifyingKey}; -use iroh_base::base32; +use iroh_base::base32::{self, HexOrBase32ParseError}; use once_cell::sync::OnceCell; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -236,7 +236,7 @@ impl Display for PublicKey { pub enum KeyParsingError { /// Error when decoding the base32. #[error("decoding: {0}")] - Base32(#[from] data_encoding::DecodeError), + Base32(#[from] HexOrBase32ParseError), /// Error when decoding the public key. #[error("key: {0}")] Key(#[from] ed25519_dalek::SignatureError), @@ -249,9 +249,8 @@ impl FromStr for PublicKey { type Err = KeyParsingError; fn from_str(s: &str) -> Result { - let bytes = data_encoding::BASE32_NOPAD.decode(s.to_ascii_uppercase().as_bytes())?; - let key = PublicKey::try_from(&bytes[..])?; - Ok(key) + let bytes = base32::parse_array_hex_or_base32::<32>(s)?; + Ok(Self::try_from(bytes.as_ref())?) } } @@ -278,7 +277,7 @@ impl FromStr for SecretKey { type Err = KeyParsingError; fn from_str(s: &str) -> Result { - Ok(SecretKey::from(base32::parse_array::<32>(s)?)) + Ok(SecretKey::from(base32::parse_array_hex_or_base32::<32>(s)?)) } } @@ -402,12 +401,9 @@ mod tests { #[test] fn test_public_key_postcard() { - let public_key = PublicKey::try_from( - hex::decode("ae58ff8833241ac82d6ff7611046ed67b5072d142c588d0063e942d9a75502b6") - .unwrap() - .as_slice(), - ) - .unwrap(); + let public_key = + PublicKey::from_str("ae58ff8833241ac82d6ff7611046ed67b5072d142c588d0063e942d9a75502b6") + .unwrap(); let bytes = postcard::to_stdvec(&public_key).unwrap(); let expected = parse_hexdump("ae58ff8833241ac82d6ff7611046ed67b5072d142c588d0063e942d9a75502b6")