diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 3959b09..b06c536 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -20,10 +20,9 @@ jobs: fail-fast: false matrix: hpke: - - nss - rust-hpke rust: - - 1.63.0 # MSRV + - 1.75.0 - stable steps: diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 8c89536..8bcd7cb 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -16,10 +16,13 @@ read-bhttp = [] write-bhttp = [] read-http = ["url"] write-http = [] +stream = [] [dependencies] thiserror = "1" url = {version = "2", optional = true} +tracing = "0.1" +backtrace = "0.3" [dev-dependencies] hex = "0.4" diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 19d5455..07e5f7b 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -15,12 +15,16 @@ pub enum Error { ExpectedResponse, #[error("a field contained an integer value that was out of range: {0}")] IntRange(#[from] std::num::TryFromIntError), + #[error("Invalid end of chunk. Expected zero-sized chunk")] + InvalidChunkEnd, #[error("the mode of the message was invalid")] InvalidMode, #[error("the status code of a response needs to be in 100..=599")] InvalidStatus, #[error("IO error {0}")] Io(#[from] std::io::Error), + #[error("Invalid uint")] + InvalidUint, #[error("a field or line was missing a necessary character 0x{0:x}")] Missing(u8), #[error("a URL was missing a key component")] @@ -31,11 +35,15 @@ pub enum Error { ParseInt(#[from] std::num::ParseIntError), #[error("a field was truncated")] Truncated, + #[error("Unreachable")] + Unreachable, #[error("a message included the Upgrade field")] UpgradeUnsupported, #[error("a URL could not be parsed into components: {0}")] #[cfg(feature = "read-http")] UrlParse(#[from] url::ParseError), + #[error("Varint value too large")] + VariantTooLarge, } #[cfg(any( diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index 3c8fbde..a13dae4 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -681,7 +681,9 @@ impl Message { } let mut buf = vec![0; count]; r.borrow_mut().read_exact(&mut buf)?; - assert!(read_line(r)?.is_empty()); + if !read_line(r)?.is_empty() { + return Err(Error::InvalidChunkEnd); + } content.append(&mut buf); } } @@ -781,7 +783,9 @@ impl Message { let mode = match t { 0 | 1 => Mode::KnownLength, 2 | 3 => Mode::IndeterminateLength, - _ => return Err(Error::InvalidMode), + _ => { + return Err(Error::InvalidMode); + } }; let mut control = ControlData::read_bhttp(request, r)?; diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index 92009ed..d61d2c2 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -10,7 +10,9 @@ use crate::{err::Error, ReadSeek}; #[allow(clippy::cast_possible_truncation)] fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into(); - assert!(n > 0 && usize::from(n) < std::mem::size_of::()); + if !(n > 0 && usize::from(n) < std::mem::size_of::()) { + return Err(Error::InvalidUint); + } for i in 0..n { w.write_all(&[((v >> (8 * (n - i - 1))) & 0xff) as u8])?; } @@ -25,7 +27,7 @@ pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { () if v < (1 << 14) => write_uint(2, v | (1 << 14), w), () if v < (1 << 30) => write_uint(4, v | (2 << 30), w), () if v < (1 << 62) => write_uint(8, v | (3 << 62), w), - () => panic!("Varint value too large"), + () => Err(Error::VariantTooLarge), } } @@ -74,7 +76,9 @@ where 1 => ((b1 & 0x3f) << 8) | read_uint(1, r)?.ok_or(Error::Truncated)?, 2 => ((b1 & 0x3f) << 24) | read_uint(3, r)?.ok_or(Error::Truncated)?, 3 => ((b1 & 0x3f) << 56) | read_uint(7, r)?.ok_or(Error::Truncated)?, - _ => unreachable!(), + _ => { + return Err(Error::Unreachable); + } })) } else { Ok(None) diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index ffaabd5..9b76adf 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -28,7 +28,7 @@ byteorder = "1.4" chacha20poly1305 = {version = "0.8", optional = true} hex = "0.4" hkdf = {version = "0.11", optional = true} -hpke = {version = "0.11.0", optional = true, default-features = false, features = ["std", "x25519"]} +hpke = {version = "0.12.0", optional = true, default-features = false, features = ["std", "x25519", "p384"]} lazy_static = "1.4" log = {version = "0.4", default-features = false} rand = {version = "0.8", optional = true} @@ -39,6 +39,13 @@ regex-automata = {version = "~0.3", optional = true} regex-syntax = {version = "~0.7", optional = true} sha2 = {version = "0.9", optional = true} thiserror = "1" +futures-util = "0.3.30" +futures = "0.3.30" +bytes = "1.7.2" +async-stream = "0.3.5" +tokio = { version = "1.40.0", features = ["full"] } +tracing = "0.1" +backtrace = "0.3" [dependencies.hpke-pq] package = "hpke_pq" diff --git a/ohttp/src/config.rs b/ohttp/src/config.rs index 9b55755..4c170df 100644 --- a/ohttp/src/config.rs +++ b/ohttp/src/config.rs @@ -63,11 +63,12 @@ impl KeyConfig { } /// Construct a configuration for the server side. - /// # Panics /// If the configurations don't include a supported configuration. pub fn new(key_id: u8, kem: Kem, mut symmetric: Vec) -> Res { Self::strip_unsupported(&mut symmetric, kem); - assert!(!symmetric.is_empty()); + if symmetric.is_empty() { + return Err(Error::SymmetricKeyEmpty); + } let (sk, pk) = generate_key_pair(kem)?; Ok(Self { key_id, @@ -78,6 +79,29 @@ impl KeyConfig { }) } + /// Construct a configuration from an existing private key + /// # Panics + /// If the configurations don't include a supported configuration. + pub fn import_p384( + key_id: u8, + kem: Kem, + sk: ::PrivateKey, + pk: ::PublicKey, + mut symmetric: Vec, + ) -> Res { + Self::strip_unsupported(&mut symmetric, kem); + if symmetric.is_empty() { + return Err(Error::SymmetricKeyEmpty); + } + Ok(Self { + key_id, + kem, + symmetric, + sk: Some(crate::rh::hpke::PrivateKey::P384(sk)), + pk: crate::rh::hpke::PublicKey::P384(pk), + }) + } + /// Derive a configuration for the server side from input keying material, /// using the `DeriveKeyPair` functionality of the HPKE KEM defined here: /// @@ -93,7 +117,9 @@ impl KeyConfig { #[cfg(feature = "rust-hpke")] { Self::strip_unsupported(&mut symmetric, kem); - assert!(!symmetric.is_empty()); + if symmetric.is_empty() { + return Err(Error::SymmetricKeyEmpty); + } let (sk, pk) = derive_key_pair(kem, ikm)?; Ok(Self { key_id, diff --git a/ohttp/src/err.rs b/ohttp/src/err.rs index 3c6ebd2..31b7486 100644 --- a/ohttp/src/err.rs +++ b/ohttp/src/err.rs @@ -5,6 +5,8 @@ pub enum Error { #[cfg(feature = "rust-hpke")] #[error("a problem occurred with the AEAD")] Aead(#[from] aead::Error), + #[error("AEAD mode mismatch")] + AeadMode, #[cfg(feature = "nss")] #[error("a problem occurred during cryptographic processing: {0}")] Crypto(#[from] crate::nss::Error), @@ -22,16 +24,24 @@ pub enum Error { InvalidKeyType, #[error("the wrong KEM was specified")] InvalidKem, + #[error("Invalid private key")] + InvalidPrivateKey, #[error("io error: {0}")] Io(#[from] std::io::Error), #[error("the key ID was invalid")] KeyId, + #[error("Returned a different key ID from the one requested : {0} {1}")] + KeyIdMismatch(u8, u8), + #[error("Symmetric key is empty")] + SymmetricKeyEmpty, + #[error("the configuration contained too many symmetric suites")] + TooManySymmetricSuites, #[error("a field was truncated")] Truncated, + #[error("the two lengths are not equal : {0} {1}")] + UnequalLength(usize, usize), #[error("the configuration was not supported")] Unsupported, - #[error("the configuration contained too many symmetric suites")] - TooManySymmetricSuites, } impl From for Error { diff --git a/ohttp/src/hpke.rs b/ohttp/src/hpke.rs index 6865777..f8e1b97 100644 --- a/ohttp/src/hpke.rs +++ b/ohttp/src/hpke.rs @@ -31,6 +31,8 @@ macro_rules! convert_enum { convert_enum! { pub enum Kem { + P384Sha384 = 17, + X25519Sha256 = 32, #[cfg(feature = "pq")] @@ -42,6 +44,8 @@ impl Kem { #[must_use] pub fn n_enc(self) -> usize { match self { + Kem::P384Sha384 => 97, + Kem::X25519Sha256 => 32, #[cfg(feature = "pq")] @@ -52,6 +56,8 @@ impl Kem { #[must_use] pub fn n_pk(self) -> usize { match self { + Kem::P384Sha384 => 97, + Kem::X25519Sha256 => 32, #[cfg(feature = "pq")] diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 38e3666..0c4f227 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -1,4 +1,3 @@ -#![deny(warnings, clippy::pedantic)] #![allow(clippy::missing_errors_doc)] // I'm too lazy #![cfg_attr( not(all(feature = "client", feature = "server")), @@ -15,6 +14,10 @@ mod rand; #[cfg(feature = "rust-hpke")] mod rh; +use async_stream::stream; +use futures::{stream::Stream, StreamExt}; +use futures_util::stream::once; + pub use crate::{ config::{KeyConfig, SymmetricSuite}, err::Error, @@ -25,13 +28,13 @@ use crate::{ hpke::{Aead as AeadId, Kdf, Kem}, }; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; -use log::trace; use std::{ cmp::max, convert::TryFrom, io::{BufReader, Read}, mem::size_of, }; +use tracing::{info, trace}; #[cfg(feature = "nss")] use crate::nss::random; @@ -101,7 +104,10 @@ impl ClientRequest { let hpke = HpkeS::new(selected, &mut config.pk, &info)?; let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); - debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + let header_len = header.len(); + if header_len != REQUEST_HEADER_LEN { + return Err(Error::UnequalLength(header_len, REQUEST_HEADER_LEN)); + } Ok(Self { hpke, header }) } @@ -140,7 +146,10 @@ impl ClientRequest { let mut ct = self.hpke.seal(&[], request)?; enc_request.append(&mut ct); - debug_assert_eq!(expected_len, enc_request.len()); + let enc_request_len = enc_request.len(); + if expected_len != enc_request_len { + return Err(Error::UnequalLength(expected_len, enc_request_len)); + } Ok((enc_request, ClientResponse::new(self.hpke, enc))) } } @@ -157,10 +166,11 @@ pub struct Server { #[cfg(feature = "server")] impl Server { /// Create a new server configuration. - /// # Panics /// If the configuration doesn't include a private key. pub fn new(config: KeyConfig) -> Res { - assert!(config.sk.is_some()); + if config.sk.is_none() { + return Err(Error::InvalidPrivateKey); + } Ok(Self { config }) } @@ -171,7 +181,6 @@ impl Server { } /// Remove encapsulation on a message. - /// # Panics /// Not as a consequence of this code, but Rust won't know that for sure. #[allow(clippy::similar_names)] // for kem_id and key_id pub fn decapsulate(&self, enc_request: &[u8]) -> Res<(Vec, ServerResponse)> { @@ -181,7 +190,7 @@ impl Server { let mut r = BufReader::new(enc_request); let key_id = r.read_u8()?; if key_id != self.config.key_id { - return Err(Error::KeyId); + return Err(Error::KeyIdMismatch(key_id, self.config.key_id)); } let kem_id = Kem::try_from(r.read_u16::()?)?; if kem_id != self.config.kem { @@ -259,6 +268,24 @@ impl ServerResponse { }) } + // Variable length encoding of an integer + fn variant_encode(&mut self, mut val: usize) -> Vec { + let mut bytes = Vec::new(); + loop { + #[allow(clippy::cast_possible_truncation)] + let mut byte = (val & 0x7F) as u8; // Take the last 7 bits + val >>= 7; // Shift right by 7 bits + if val != 0 { + byte |= 0x80; // Set the MSB if there's more to encode + } + bytes.push(byte); + if val == 0 { + break; + } + } + bytes + } + /// Consume this object by encapsulating a response. pub fn encapsulate(mut self, response: &[u8]) -> Res> { let mut enc_response = self.response_nonce; @@ -266,6 +293,92 @@ impl ServerResponse { enc_response.append(&mut ct); Ok(enc_response) } + + // Consume this object by encapsulating a stream + // https://www.ietf.org/archive/id/draft-ohai-chunked-ohttp-01.html#name-response-format + // Chunked Encapsulated Response { + // Response Nonce (Nk), + // Chunked Response Chunks (..), + // } + + // Chunked Response Chunks { + // Non-Final Response Chunk (..), + // Final Response Chunk Indicator (i) = 0, + // AEAD-Protected Final Response Chunk (..), + // } + + // Non-Final Response Chunk { + // Length (i) = 1.., + // AEAD-Protected Chunk (..), + // } + pub fn encapsulate_stream( + mut self, + input: S, + ) -> std::pin::Pin>> + Send + 'static>> + where + S: Stream, E>> + Send + 'static, + E: std::fmt::Debug + Send, + { + // Response Nonce (Nk) + let response_nonce = Ok(self.response_nonce.clone()); + info!( + "Response nonce {}({})", + hex::encode(self.response_nonce.clone()), + self.response_nonce.len() + ); + let nonce_stream = once(async { response_nonce }); + + let mut input = Box::pin(input); + let output_stream = stream! { + let current = input.next().await; + let Some(current) = current else { return }; + let Ok(mut current) = current else { return }; + + loop { + //info!("Processing chunk {}", std::str::from_utf8(¤t).unwrap()); + if let Some(next) = input.next().await { + let mut enc_response = Vec::new(); + + // Non-Final Response Chunk (..), + let aad = ""; + let mut ct = self.aead.seal(aad.as_bytes(), ¤t).unwrap(); + let mut enc_length = self.variant_encode(ct.len()); + // Length (i) = 1.., + enc_response.append(&mut enc_length); + + // AEAD-Protected Chunk (..), + enc_response.append(&mut ct); + + info!("Encapsulated chunk ({},{})", ct.len(), enc_response.len()); + trace!("{}", hex::encode(&enc_response)); + + yield Ok(enc_response); + current = next.unwrap(); + } else { + let mut enc_response = Vec::new(); + + // Final Response Chunk Indicator (i) = 0, + let mut final_chunk_indicator = self.variant_encode(0); + enc_response.append(&mut final_chunk_indicator); + + // AEAD-Protected Final Response Chunk (..), + let aad = "final"; + let mut ct = self.aead.seal(aad.as_bytes(), ¤t).unwrap(); + let mut enc_length = self.variant_encode(ct.len()); + enc_response.append(&mut enc_length); + enc_response.append(&mut ct); + + info!("Encapsulated final chunk ({},{})", ct.len(), enc_response.len()); + trace!("{}", hex::encode(&enc_response)); + yield Ok(enc_response); + return; + } + } + }; + + let stream = nonce_stream.chain(output_stream); + Box::pin(stream) + } } #[cfg(feature = "server")] @@ -281,6 +394,8 @@ impl std::fmt::Debug for ServerResponse { pub struct ClientResponse { hpke: HpkeS, enc: Vec, + seq: u64, + aead: Option, } #[cfg(feature = "client")] @@ -289,7 +404,14 @@ impl ClientResponse { /// Doesn't do anything because we don't have the nonce yet, so /// the work that can be done is limited. fn new(hpke: HpkeS, enc: Vec) -> Self { - Self { hpke, enc } + let seq = 0; + let aead = None; + Self { + hpke, + enc, + seq, + aead, + } } /// Consume this object by decapsulating a response. @@ -308,6 +430,104 @@ impl ClientResponse { )?; aead.open(&[], 0, ct) // 0 is the sequence number } + + fn set_response_nonce(&mut self, enc_response: &[u8]) -> Res<()> { + let mid = entropy(self.hpke.config()); + if mid != enc_response.len() { + return Err(Error::Truncated); + } + let aead = make_aead( + Mode::Decrypt, + self.hpke.config(), + &self.hpke, + self.enc.clone(), + enc_response, + )?; + self.aead = Some(aead); + Ok(()) + } + + fn variant_decode(&mut self, bytes: &[u8]) -> Result<(u64, usize), String> { + let mut value: u64 = 0; + let mut shift = 0; + let mut bytes_read = 0; + + for &byte in bytes { + let byte_value = (byte & 0x7F) as u64; + value |= byte_value << shift; + bytes_read += 1; + if byte & 0x80 == 0 { + // Continuation bit is not set, end of the VLQ-encoded integer + return Ok((value, bytes_read)); + } + shift += 7; + if shift >= 64 { + return Err("VLQ-encoded integer is too large".to_string()); + } + } + Err("Incomplete VLQ-encoded integer".to_string()) + } + + pub async fn decapsulate_stream( + mut self, + mut stream: S, + ) -> std::pin::Pin>> + Send + 'static>> + where + S: Stream>> + Send + 'static + Unpin, + { + let mut nonce_received = false; + let mut aad = ""; + let nonce_size = entropy(self.hpke.config()); + let mut buffer: Vec = Vec::new(); + let output_stream = stream! { + while let Some(next) = stream.next().await { + let mut enc_response = next.unwrap(); + info!("Received chunk: ({})", enc_response.len()); + trace!("{}", hex::encode(&enc_response)); + buffer.append(&mut enc_response); + info!("Buffer size {}", buffer.len()); + + // Response Nonce (Nk) + if !nonce_received && buffer.len() >= nonce_size { + nonce_received = true; + let nonce: Vec<_> = buffer.drain(0..nonce_size).collect(); + info!("Setting response nonce: {}({})", hex::encode(&nonce), nonce.len()); + self.set_response_nonce(&nonce).unwrap(); + } + + while nonce_received && !buffer.is_empty() { + let (mut len, mut bytes_read) = self.variant_decode(&buffer).unwrap(); + info!("Buffer state: {}, {}({})", buffer.len(), len, bytes_read); + + // Final Response Chunk Indicator (i) = 0, + if len == 0 { + buffer.drain(0..bytes_read); + info!("Processing final chunk"); + aad = "final"; + let (length, bytes) = self.variant_decode(&buffer).unwrap(); + info!("Buffer state: {}({})", length, bytes_read); + len = length; + bytes_read = bytes; + } + + // Decapsulate chunk if received + let len = usize::try_from(len).unwrap(); + if buffer.len() >= len { + buffer.drain(0..bytes_read); + let ct: Vec<_> = buffer.drain(0..len).collect(); + info!("Decapsulating chunk ({})", len); + trace!("{}", hex::encode(&ct)); + self.seq += 1; + yield self.aead.as_mut().unwrap().open(aad.as_bytes(), self.seq - 1, &ct); + } else { + break; + } + } + } + }; + + Box::pin(output_stream) + } } #[cfg(all(test, feature = "client", feature = "server"))] @@ -318,9 +538,12 @@ mod test { hpke::{Aead, Kdf, Kem}, ClientRequest, Error, KeyConfig, KeyId, Server, }; - use log::trace; + + use futures::StreamExt; use std::{fmt::Debug, io::ErrorKind}; + use tracing::trace; + use async_stream::stream; const KEY_ID: KeyId = 1; const KEM: Kem = Kem::X25519Sha256; const SYMMETRIC: &[SymmetricSuite] = &[ @@ -336,7 +559,6 @@ mod test { fn init() { crate::init(); - _ = env_logger::try_init(); // ignore errors here } #[test] @@ -520,4 +742,196 @@ mod test { let response = client_response.decapsulate(&enc_response).unwrap(); assert_eq!(&response[..], RESPONSE); } + + #[tokio::test] + async fn response_stream() { + init(); + + let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (enc_request, client_response) = client.encapsulate(REQUEST).unwrap(); + trace!("Request: {}", hex::encode(REQUEST)); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + let (request, server_response) = server.decapsulate(&enc_request).unwrap(); + assert_eq!(&request[..], REQUEST); + + let stream = stream! { yield Ok::, Error>(RESPONSE.to_vec()); }; + let enc_response = server_response.encapsulate_stream(stream); + + let mut response = client_response.decapsulate_stream(enc_response).await; + let next = response.next().await; + assert!(next.is_some_and(|x| x.is_ok_and(|x| x.eq_ignore_ascii_case(RESPONSE)))); + } + + #[tokio::test] + async fn two_response_stream() { + init(); + + let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (enc_request, client_response) = client.encapsulate(REQUEST).unwrap(); + trace!("Request: {}", hex::encode(REQUEST)); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + let (request, server_response) = server.decapsulate(&enc_request).unwrap(); + assert_eq!(&request[..], REQUEST); + + let stream = stream! { + yield Ok::, Error>(RESPONSE.to_vec()); + yield Ok::, Error>(RESPONSE.to_vec()); + }; + let enc_response = server_response.encapsulate_stream(stream); + + let mut response = client_response.decapsulate_stream(enc_response).await; + let next = response.next().await; + assert!(next.is_some_and(|x| x.is_ok_and(|x| x.eq_ignore_ascii_case(RESPONSE)))); + + let next = response.next().await; + assert!(next.is_some_and(|x| x.is_ok_and(|x| x.eq_ignore_ascii_case(RESPONSE)))); + } + + #[tokio::test] + async fn two_response_stream_merged() { + init(); + + let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (enc_request, client_response) = client.encapsulate(REQUEST).unwrap(); + trace!("Request: {}", hex::encode(REQUEST)); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + let (request, server_response) = server.decapsulate(&enc_request).unwrap(); + assert_eq!(&request[..], REQUEST); + + let stream = stream! { + yield Ok::, Error>(RESPONSE.to_vec()); + yield Ok::, Error>(RESPONSE.to_vec()); + }; + let enc_response = server_response.encapsulate_stream(stream); + + let merged_response = enc_response.chunks(2).map(|chunk| { + if chunk.len() == 2 { + println!("Found too elements"); + let mut first = chunk[0].as_ref().unwrap().clone(); + let second = chunk[1].as_ref().unwrap(); + first.append(&mut second.clone()); + Ok::, Error>(first.clone()) + } else { + Ok::, Error>(chunk[0].as_ref().unwrap().clone()) + } + }); + + let mut response = client_response.decapsulate_stream(merged_response).await; + + let mut count = 0; + while let Some(next) = response.next().await { + count += 1; + assert!(next.is_ok_and(|x| x.eq_ignore_ascii_case(RESPONSE))); + } + assert_eq!(count, 2); + } + + #[tokio::test] + async fn three_response_stream_merged() { + init(); + + let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (enc_request, client_response) = client.encapsulate(REQUEST).unwrap(); + trace!("Request: {}", hex::encode(REQUEST)); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + let (request, server_response) = server.decapsulate(&enc_request).unwrap(); + assert_eq!(&request[..], REQUEST); + + let stream = stream! { + yield Ok::, Error>(RESPONSE.to_vec()); + yield Ok::, Error>(RESPONSE.to_vec()); + yield Ok::, Error>(RESPONSE.to_vec()); + }; + let enc_response = server_response.encapsulate_stream(stream); + + let merged_response = enc_response.chunks(2).map(|chunk| { + if chunk.len() == 2 { + let mut first = chunk[0].as_ref().unwrap().clone(); + let second = chunk[1].as_ref().unwrap(); + first.append(&mut second.clone()); + Ok::, Error>(first.clone()) + } else { + Ok::, Error>(chunk[0].as_ref().unwrap().clone()) + } + }); + + let mut response = client_response.decapsulate_stream(merged_response).await; + let mut count = 0; + while let Some(next) = response.next().await { + count += 1; + assert!(next.is_ok_and(|x| x.eq_ignore_ascii_case(RESPONSE))); + } + assert_eq!(count, 3); + } + + #[tokio::test] + async fn response_stream_fragment() { + init(); + + let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (enc_request, client_response) = client.encapsulate(REQUEST).unwrap(); + trace!("Request: {}", hex::encode(REQUEST)); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + let (request, server_response) = server.decapsulate(&enc_request).unwrap(); + assert_eq!(&request[..], REQUEST); + + let stream = stream! { yield Ok::, Error>(RESPONSE.to_vec()); }; + let enc_response = server_response.encapsulate_stream(stream); + + let fragmented_response = enc_response.flat_map(|chunk| { + let c = chunk.unwrap(); + if c.len() % 4 == 0 { + let chunks: Vec<_> = c + .chunks(4) + .map(|c| Ok::, Error>(c.to_vec())) + .collect(); + futures_util::stream::iter(chunks) + } else if c.len() % 3 == 0 { + let chunks: Vec<_> = c + .chunks(3) + .map(|c| Ok::, Error>(c.to_vec())) + .collect(); + futures_util::stream::iter(chunks) + } else { + let vec = vec![Ok::, Error>(c)]; + futures_util::stream::iter(vec) + } + }); + + let mut response = client_response + .decapsulate_stream(fragmented_response) + .await; + let next = response.next().await; + assert!(next.is_some_and(|x| x.is_ok_and(|x| x.eq_ignore_ascii_case(RESPONSE)))); + } } diff --git a/ohttp/src/nss/aead.rs b/ohttp/src/nss/aead.rs index 18f0b66..8191c1d 100644 --- a/ohttp/src/nss/aead.rs +++ b/ohttp/src/nss/aead.rs @@ -13,12 +13,12 @@ use crate::{ err::{Error, Res}, hpke::Aead as AeadId, }; -use log::trace; use std::{ convert::{TryFrom, TryInto}, mem, os::raw::c_int, }; +use tracing::trace; /// All the nonces are the same length. Exploit that. pub const NONCE_LEN: usize = 12; @@ -124,7 +124,9 @@ impl Aead { } pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { - assert_eq!(self.mode, Mode::Encrypt); + if self.mode != Mode::Encrypt { + return Err(Error::AeadMode); + } // A copy for the nonce generator to write into. But we don't use the value. let mut nonce = self.nonce_base; // Ciphertext with enough space for the tag. @@ -152,13 +154,20 @@ impl Aead { ) })?; ct.truncate(usize::try_from(ct_len).unwrap()); - debug_assert_eq!(ct.len(), pt.len()); + + ct_len = ct.len(); + let pt_len = pt.len(); + if ct_len != pt_len { + return Err(Error::UnequalLength(ct_len, pt_len)); + } ct.append(&mut tag); Ok(ct) } pub fn open(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { - assert_eq!(self.mode, Mode::Decrypt); + if self.mode != Mode::Decrypt { + return Err(Error::AeadMode); + } let mut nonce = self.nonce_base; for (i, n) in nonce.iter_mut().rev().take(COUNTER_LEN).enumerate() { *n ^= u8::try_from((seq >> (8 * i)) & 0xff).unwrap(); @@ -185,7 +194,9 @@ impl Aead { ) })?; let len = usize::try_from(pt_len).unwrap(); - debug_assert_eq!(len, pt_expected); + if len != pt_expected { + return Err(Error::UnequalLength(len, pt_expected)); + } pt.truncate(len); Ok(pt) } diff --git a/ohttp/src/nss/hkdf.rs b/ohttp/src/nss/hkdf.rs index 470b1dd..8921e60 100644 --- a/ohttp/src/nss/hkdf.rs +++ b/ohttp/src/nss/hkdf.rs @@ -10,8 +10,8 @@ use super::{ }, }; use crate::err::Res; -use log::trace; use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; +use tracing::trace; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/nss/hpke.rs b/ohttp/src/nss/hpke.rs index b7ef845..a82f205 100644 --- a/ohttp/src/nss/hpke.rs +++ b/ohttp/src/nss/hpke.rs @@ -4,13 +4,13 @@ use super::{ p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, }; use crate::err::Res; -use log::{log_enabled, trace}; use std::{ convert::TryFrom, ops::Deref, os::raw::c_uint, ptr::{addr_of_mut, null, null_mut}, }; +use tracing::{log_enabled, trace}; pub use sys::{HpkeAeadId as AeadId, HpkeKdfId as KdfId, HpkeKemId as KemId}; @@ -234,7 +234,9 @@ impl Deref for HpkeR { /// Generate a key pair for the identified KEM. pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { - assert_eq!(kem, Kem::X25519Sha256); + if kem != Kem::X25519Sha256 { + return Err(Error::InvalidKem); + } let slot = Slot::internal()?; let oid_data = unsafe { sys::SECOID_FindOIDByTag(sys::SECOidTag::SEC_OID_CURVE25519) }; @@ -250,7 +252,7 @@ pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { let mut wrapped = Item::wrap(¶ms); // Try to make an insensitive key so that we can read the key data for tracing. - let insensitive_secret_ptr = if log_enabled!(log::Level::Trace) { + let insensitive_secret_ptr = if log_enabled!(tracing::Level::Trace) { unsafe { sys::PK11_GenerateKeyPairWithOpFlags( *slot, @@ -266,7 +268,10 @@ pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { } else { null_mut() }; - assert_eq!(insensitive_secret_ptr.is_null(), public_ptr.is_null()); + if insensitive_secret_ptr.is_null() == public_ptr.is_null() { + return Error::unexpected; + } + let secret_ptr = if insensitive_secret_ptr.is_null() { unsafe { sys::PK11_GenerateKeyPairWithOpFlags( @@ -283,7 +288,10 @@ pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { } else { insensitive_secret_ptr }; - assert_eq!(secret_ptr.is_null(), public_ptr.is_null()); + if secret_ptr.is_null() == public_ptr.is_null() { + return Error::unexpected; + } + let sk = PrivateKey::from_ptr(secret_ptr)?; let pk = PublicKey::from_ptr(public_ptr)?; trace!("Generated key pair: sk={:?} pk={:?}", sk, pk); diff --git a/ohttp/src/nss/p11.rs b/ohttp/src/nss/p11.rs index 9bddef6..a63fb7c 100644 --- a/ohttp/src/nss/p11.rs +++ b/ohttp/src/nss/p11.rs @@ -109,7 +109,9 @@ impl Clone for PrivateKey { #[must_use] fn clone(&self) -> Self { let ptr = unsafe { sys::SECKEY_CopyPrivateKey(self.ptr) }; - assert!(!ptr.is_null()); + if ptr.is_null() { + return Error::unexpected; + } Self { ptr } } } @@ -150,7 +152,9 @@ impl Clone for PublicKey { #[must_use] fn clone(&self) -> Self { let ptr = unsafe { sys::SECKEY_CopyPublicKey(self.ptr) }; - assert!(!ptr.is_null()); + if ptr.is_null() { + return Error::unexpected; + } Self { ptr } } } @@ -198,7 +202,9 @@ impl Clone for SymKey { #[must_use] fn clone(&self) -> Self { let ptr = unsafe { PK11_ReferenceSymKey(self.ptr) }; - assert!(!ptr.is_null()); + if ptr.is_null() { + return Error::unexpected; + } Self { ptr } } } @@ -274,7 +280,9 @@ impl Item { pub(crate) unsafe fn into_vec(self) -> Vec { let b = self.ptr.as_ref().unwrap(); // Sanity check the type, as some types don't count bytes in `Item::len`. - assert_eq!(b.type_, SECItemType::siBuffer); + if b.type_ != SECItemType::siBuffer { + return Error::unexpected; + } let slc = std::slice::from_raw_parts(b.data, usize::try_from(b.len).unwrap()); Vec::from(slc) } diff --git a/ohttp/src/rh/hkdf.rs b/ohttp/src/rh/hkdf.rs index aeb3a8d..dab808d 100644 --- a/ohttp/src/rh/hkdf.rs +++ b/ohttp/src/rh/hkdf.rs @@ -6,8 +6,8 @@ use crate::{ hpke::{Aead, Kdf}, }; use hkdf::Hkdf as HkdfImpl; -use log::trace; use sha2::{Sha256, Sha384, Sha512}; +use tracing::trace; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index 4b81152..0759263 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -11,9 +11,9 @@ use ::hpke as rust_hpke; use ::hpke_pq as rust_hpke; use rust_hpke::{ - aead::{AeadCtxR, AeadCtxS, AeadTag, AesGcm128, ChaCha20Poly1305}, - kdf::HkdfSha256, - kem::{Kem as KemTrait, X25519HkdfSha256}, + aead::{AeadCtxR, AeadCtxS, AeadTag, AesGcm128, AesGcm256, ChaCha20Poly1305}, + kdf::{HkdfSha256, HkdfSha384}, + kem::{DhP384HkdfSha384, Kem as KemTrait, X25519HkdfSha256}, setup_receiver, setup_sender, Deserializable, OpModeR, OpModeS, Serializable, }; @@ -21,8 +21,8 @@ use rust_hpke::{ use rust_hpke::kem::X25519Kyber768Draft00; use ::rand::thread_rng; -use log::trace; use std::ops::Deref; +use tracing::trace; /// Configuration for `Hpke`. #[derive(Clone, Copy)] @@ -51,7 +51,11 @@ impl Config { pub fn supported(self) -> bool { // TODO support more options - self.kdf == Kdf::HkdfSha256 && matches!(self.aead, Aead::Aes128Gcm | Aead::ChaCha20Poly1305) + matches!(self.kdf, Kdf::HkdfSha256 | Kdf::HkdfSha384) + && matches!( + self.aead, + Aead::Aes128Gcm | Aead::Aes256Gcm | Aead::ChaCha20Poly1305 + ) } } @@ -70,6 +74,8 @@ impl Default for Config { pub enum PublicKey { X25519(::PublicKey), + P384(::PublicKey), + #[cfg(feature = "pq")] X25519Kyber768Draft00(::PublicKey), } @@ -78,6 +84,8 @@ impl PublicKey { #[allow(clippy::unnecessary_wraps)] pub fn key_data(&self) -> Res> { Ok(match self { + Self::P384(k) => Vec::from(k.to_bytes().as_slice()), + Self::X25519(k) => Vec::from(k.to_bytes().as_slice()), #[cfg(feature = "pq")] @@ -99,8 +107,8 @@ impl std::fmt::Debug for PublicKey { #[allow(clippy::large_enum_variant)] #[derive(Clone)] pub enum PrivateKey { + P384(::PrivateKey), X25519(::PrivateKey), - #[cfg(feature = "pq")] X25519Kyber768Draft00(::PrivateKey), } @@ -109,6 +117,7 @@ impl PrivateKey { #[allow(clippy::unnecessary_wraps)] pub fn key_data(&self) -> Res> { Ok(match self { + Self::P384(k) => Vec::from(k.to_bytes().as_slice()), Self::X25519(k) => Vec::from(k.to_bytes().as_slice()), #[cfg(feature = "pq")] @@ -134,6 +143,11 @@ enum SenderContextX25519HkdfSha256HkdfSha256 { ChaCha20Poly1305(Box>), } +enum SenderContextDhP384HkdfSha384HkdfSha384 { + AesGcm128(Box>), + AesGcm256(Box>), +} + #[cfg(feature = "pq")] enum SenderContextX25519Kyber768Draft00HkdfSha256 { AesGcm128(Box>), @@ -143,6 +157,10 @@ enum SenderContextX25519HkdfSha256 { HkdfSha256(SenderContextX25519HkdfSha256HkdfSha256), } +enum SenderContextDhP384HkdfSha384 { + HkdfSha384(SenderContextDhP384HkdfSha384HkdfSha384), +} + #[cfg(feature = "pq")] enum SenderContextX25519Kyber768Draft00 { HkdfSha256(SenderContextX25519Kyber768Draft00HkdfSha256), @@ -151,6 +169,8 @@ enum SenderContextX25519Kyber768Draft00 { enum SenderContext { X25519HkdfSha256(SenderContextX25519HkdfSha256), + DhP384HkdfSha384(SenderContextDhP384HkdfSha384), + #[cfg(feature = "pq")] X25519Kyber768Draft00(SenderContextX25519Kyber768Draft00), } @@ -170,7 +190,18 @@ impl SenderContext { let tag = context.seal_in_place_detached(plaintext, aad)?; Vec::from(tag.to_bytes().as_slice()) } - + Self::DhP384HkdfSha384(SenderContextDhP384HkdfSha384::HkdfSha384( + SenderContextDhP384HkdfSha384HkdfSha384::AesGcm128(context), + )) => { + let tag = context.seal_in_place_detached(plaintext, aad)?; + Vec::from(tag.to_bytes().as_slice()) + } + Self::DhP384HkdfSha384(SenderContextDhP384HkdfSha384::HkdfSha384( + SenderContextDhP384HkdfSha384HkdfSha384::AesGcm256(context), + )) => { + let tag = context.seal_in_place_detached(plaintext, aad)?; + Vec::from(tag.to_bytes().as_slice()) + } #[cfg(feature = "pq")] Self::X25519Kyber768Draft00(SenderContextX25519Kyber768Draft00::HkdfSha256( SenderContextX25519Kyber768Draft00HkdfSha256::AesGcm128(context), @@ -193,7 +224,16 @@ impl SenderContext { )) => { context.export(info, out_buf)?; } - + Self::DhP384HkdfSha384(SenderContextDhP384HkdfSha384::HkdfSha384( + SenderContextDhP384HkdfSha384HkdfSha384::AesGcm128(context), + )) => { + context.export(info, out_buf)?; + } + Self::DhP384HkdfSha384(SenderContextDhP384HkdfSha384::HkdfSha384( + SenderContextDhP384HkdfSha384HkdfSha384::AesGcm256(context), + )) => { + context.export(info, out_buf)?; + } #[cfg(feature = "pq")] Self::X25519Kyber768Draft00(SenderContextX25519Kyber768Draft00::HkdfSha256( SenderContextX25519Kyber768Draft00HkdfSha256::AesGcm128(context), @@ -277,7 +317,24 @@ impl HpkeS { SenderContextX25519HkdfSha256::HkdfSha256, SenderContextX25519HkdfSha256HkdfSha256::ChaCha20Poly1305, }, - + { + Kem::P384Sha384 => DhP384HkdfSha384, + Kdf::HkdfSha384 => HkdfSha384, + Aead::Aes128Gcm => AesGcm128, + PublicKey::P384, + SenderContext::DhP384HkdfSha384, + SenderContextDhP384HkdfSha384::HkdfSha384, + SenderContextDhP384HkdfSha384HkdfSha384::AesGcm128, + }, + { + Kem::P384Sha384 => DhP384HkdfSha384, + Kdf::HkdfSha384 => HkdfSha384, + Aead::Aes256Gcm => AesGcm256, + PublicKey::P384, + SenderContext::DhP384HkdfSha384, + SenderContextDhP384HkdfSha384::HkdfSha384, + SenderContextDhP384HkdfSha384HkdfSha384::AesGcm256, + }, #[cfg(feature = "pq")] { Kem::X25519Kyber768Draft00 => X25519Kyber768Draft00, @@ -335,6 +392,11 @@ enum ReceiverContextX25519HkdfSha256HkdfSha256 { ChaCha20Poly1305(Box>), } +enum ReceiverContextDhP384HkdfSha384HkdfSha384 { + AesGcm128(Box>), + AesGcm256(Box>), +} + #[cfg(feature = "pq")] enum ReceiverContextX25519Kyber768Draft00HkdfSha256 { AesGcm128(Box>), @@ -344,6 +406,10 @@ enum ReceiverContextX25519HkdfSha256 { HkdfSha256(ReceiverContextX25519HkdfSha256HkdfSha256), } +enum ReceiverContextDhP384HkdfSha384 { + HkdfSha384(ReceiverContextDhP384HkdfSha384HkdfSha384), +} + #[cfg(feature = "pq")] enum ReceiverContextX25519Kyber768Draft00 { HkdfSha256(ReceiverContextX25519Kyber768Draft00HkdfSha256), @@ -351,6 +417,7 @@ enum ReceiverContextX25519Kyber768Draft00 { enum ReceiverContext { X25519HkdfSha256(ReceiverContextX25519HkdfSha256), + DhP384HkdfSha384(ReceiverContextDhP384HkdfSha384), #[cfg(feature = "pq")] X25519Kyber768Draft00(ReceiverContextX25519Kyber768Draft00), @@ -383,7 +450,30 @@ impl ReceiverContext { context.open_in_place_detached(ct, aad, &tag)?; ct } - + Self::DhP384HkdfSha384(ReceiverContextDhP384HkdfSha384::HkdfSha384( + ReceiverContextDhP384HkdfSha384HkdfSha384::AesGcm128(context), + )) => { + if ciphertext.len() < AeadTag::::size() { + return Err(Error::Truncated); + } + let (ct, tag_slice) = + ciphertext.split_at_mut(ciphertext.len() - AeadTag::::size()); + let tag = AeadTag::::from_bytes(tag_slice)?; + context.open_in_place_detached(ct, aad, &tag)?; + ct + } + Self::DhP384HkdfSha384(ReceiverContextDhP384HkdfSha384::HkdfSha384( + ReceiverContextDhP384HkdfSha384HkdfSha384::AesGcm256(context), + )) => { + if ciphertext.len() < AeadTag::::size() { + return Err(Error::Truncated); + } + let (ct, tag_slice) = + ciphertext.split_at_mut(ciphertext.len() - AeadTag::::size()); + let tag = AeadTag::::from_bytes(tag_slice)?; + context.open_in_place_detached(ct, aad, &tag)?; + ct + } #[cfg(feature = "pq")] Self::X25519Kyber768Draft00(ReceiverContextX25519Kyber768Draft00::HkdfSha256( ReceiverContextX25519Kyber768Draft00HkdfSha256::AesGcm128(context), @@ -412,6 +502,16 @@ impl ReceiverContext { )) => { context.export(info, out_buf)?; } + Self::DhP384HkdfSha384(ReceiverContextDhP384HkdfSha384::HkdfSha384( + ReceiverContextDhP384HkdfSha384HkdfSha384::AesGcm128(context), + )) => { + context.export(info, out_buf)?; + } + Self::DhP384HkdfSha384(ReceiverContextDhP384HkdfSha384::HkdfSha384( + ReceiverContextDhP384HkdfSha384HkdfSha384::AesGcm256(context), + )) => { + context.export(info, out_buf)?; + } #[cfg(feature = "pq")] Self::X25519Kyber768Draft00(ReceiverContextX25519Kyber768Draft00::HkdfSha256( @@ -493,6 +593,24 @@ impl HpkeR { ReceiverContextX25519HkdfSha256::HkdfSha256, ReceiverContextX25519HkdfSha256HkdfSha256::ChaCha20Poly1305, }, + { + Kem::P384Sha384 => DhP384HkdfSha384, + Kdf::HkdfSha384 => HkdfSha384, + Aead::Aes128Gcm => AesGcm128, + PrivateKey::P384, + ReceiverContext::DhP384HkdfSha384, + ReceiverContextDhP384HkdfSha384::HkdfSha384, + ReceiverContextDhP384HkdfSha384HkdfSha384::AesGcm128, + }, + { + Kem::P384Sha384 => DhP384HkdfSha384, + Kdf::HkdfSha384 => HkdfSha384, + Aead::Aes256Gcm => AesGcm256, + PrivateKey::P384, + ReceiverContext::DhP384HkdfSha384, + ReceiverContextDhP384HkdfSha384::HkdfSha384, + ReceiverContextDhP384HkdfSha384HkdfSha384::AesGcm256, + }, #[cfg(feature = "pq")] { @@ -515,6 +633,10 @@ impl HpkeR { pub fn decode_public_key(kem: Kem, k: &[u8]) -> Res { Ok(match kem { + Kem::P384Sha384 => { + PublicKey::P384(::PublicKey::from_bytes(k)?) + } + Kem::X25519Sha256 => { PublicKey::X25519(::PublicKey::from_bytes(k)?) } @@ -554,6 +676,11 @@ impl Deref for HpkeR { pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { let mut csprng = thread_rng(); let (sk, pk) = match kem { + Kem::P384Sha384 => { + let (sk, pk) = DhP384HkdfSha384::gen_keypair(&mut csprng); + (PrivateKey::P384(sk), PublicKey::P384(pk)) + } + Kem::X25519Sha256 => { let (sk, pk) = X25519HkdfSha256::gen_keypair(&mut csprng); (PrivateKey::X25519(sk), PublicKey::X25519(pk)) @@ -575,6 +702,11 @@ pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { #[allow(clippy::unnecessary_wraps)] pub fn derive_key_pair(kem: Kem, ikm: &[u8]) -> Res<(PrivateKey, PublicKey)> { let (sk, pk) = match kem { + Kem::P384Sha384 => { + let (sk, pk) = DhP384HkdfSha384::derive_keypair(ikm); + (PrivateKey::P384(sk), PublicKey::P384(pk)) + } + Kem::X25519Sha256 => { let (sk, pk) = X25519HkdfSha256::derive_keypair(ikm); (PrivateKey::X25519(sk), PublicKey::X25519(pk))