diff --git a/Cargo.toml b/Cargo.toml index 74f8f33..3a29ecd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ jsonwebtoken = "9.2" lazy_static = "1.4" log = "0.4" rand = "0.8" +serde = { version = "1.0.193", features = ["derive"] } serde_json = { version = "1.0", features = ["preserve_order"] } sha2 = "0.10" thiserror = "1.0.50" diff --git a/README.md b/README.md index 2c2542c..bc421c4 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ TBD ## External Dependencies Dual license (MIT/Apache 2.0) -dependencies: [base64](https://crates.io/crates/base64), [lazy_static](https://crates.io/crates/lazy_static) [log](https://crates.io/crates/log), [serde_json](https://crates.io/crates/serde_json), [sha2](https://crates.io/crates/sha2), [rand](https://crates.io/crates/rand), [hmac](https://crates.io/crates/hmac), [thiserror](https://crates.io/crates/thiserror). +dependencies: [base64](https://crates.io/crates/base64), [lazy_static](https://crates.io/crates/lazy_static) [log](https://crates.io/crates/log), [serde](https://crates.io/crates/serde), [serde_json](https://crates.io/crates/serde_json), [sha2](https://crates.io/crates/sha2), [rand](https://crates.io/crates/rand), [hmac](https://crates.io/crates/hmac), [thiserror](https://crates.io/crates/thiserror). MIT license dependencies: [jsonwebtoken](https://crates.io/crates/jsonwebtoken), [strum](https://crates.io/crates/strum) Note: the list of dependencies may be changed in the future. diff --git a/src/holder.rs b/src/holder.rs index 9221798..b4c9f0f 100644 --- a/src/holder.rs +++ b/src/holder.rs @@ -1,4 +1,4 @@ -use crate::error; +use crate::{error, SDJWTJson}; use error::{Error, Result}; use jsonwebtoken::{Algorithm, EncodingKey, Header}; use serde_json::{Map, Value}; @@ -21,7 +21,7 @@ pub struct SDJWTHolder { serialized_key_binding_jwt: String, sd_jwt_payload: Map, serialized_sd_jwt: String, - sd_jwt: String, + sd_jwt_json: Option, } impl SDJWTHolder { @@ -36,7 +36,7 @@ impl SDJWTHolder { let mut holder = SDJWTHolder { sd_jwt_engine: SDJWTCommon { - serialization_format, + serialization_format: serialization_format.clone(), ..Default::default() }, hs_disclosures: Vec::new(), @@ -45,10 +45,12 @@ impl SDJWTHolder { serialized_key_binding_jwt: "".to_string(), sd_jwt_payload: Map::new(), serialized_sd_jwt: "".to_string(), - sd_jwt: "".to_string(), + sd_jwt_json: None, }; - holder.sd_jwt_engine.parse_sd_jwt(sd_jwt_with_disclosures)?; + holder + .sd_jwt_engine + .parse_sd_jwt(sd_jwt_with_disclosures.clone())?; //TODO Verify signature before accepting the JWT holder.sd_jwt_payload = holder @@ -61,6 +63,7 @@ impl SDJWTHolder { .unverified_sd_jwt .take() .ok_or(Error::InvalidState("Cannot take jwt".to_string()))?; + holder.sd_jwt_json = holder.sd_jwt_engine.unverified_sd_jwt_json.clone(); holder.sd_jwt_engine.create_hash_mappings()?; @@ -100,19 +103,15 @@ impl SDJWTHolder { let joined = combined.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR); joined.to_string() } else { - let mut sd_jwt_parsed: Map = serde_json::from_str(&self.sd_jwt) - .map_err(|e| Error::DeserializationError(e.to_string()))?; - sd_jwt_parsed.insert( - crate::JWS_KEY_DISCLOSURES.to_owned(), - self.hs_disclosures.clone().into(), - ); + let mut sd_jwt_json = self + .sd_jwt_json + .take() + .ok_or(Error::InvalidState("Cannot take SDJWTJson".to_string()))?; + sd_jwt_json.disclosures = self.hs_disclosures.clone(); if !self.serialized_key_binding_jwt.is_empty() { - sd_jwt_parsed.insert( - crate::JWS_KEY_KB_JWT.to_owned(), - self.serialized_key_binding_jwt.clone().into(), - ); + sd_jwt_json.kb_jwt = Some(self.serialized_key_binding_jwt.clone()); } - serde_json::to_string(&sd_jwt_parsed) + serde_json::to_string(&sd_jwt_json) .map_err(|e| Error::DeserializationError(e.to_string()))? }; diff --git a/src/issuer.rs b/src/issuer.rs index 6a96f39..4182da4 100644 --- a/src/issuer.rs +++ b/src/issuer.rs @@ -1,4 +1,4 @@ -use crate::error; +use crate::{error, SDJWTJson}; use error::Result; use std::collections::{HashMap, VecDeque}; use std::str::FromStr; @@ -279,13 +279,6 @@ impl SDJWTIssuer { self.signed_sd_jwt = jsonwebtoken::encode(&header, &self.sd_jwt_payload, &self.issuer_key) .map_err(|e| Error::DeserializationError(e.to_string()))?; - if self.inner.serialization_format == "json" { - unimplemented!("json serialization is not supported for issuance"); - // let jws_content = serde_json::from_str(&self.serialized_sd_jwt).unwrap(); - // jws_content.insert(JWS_KEY_DISCLOSURES.to_string(), self.ii_disclosures.iter().map(|d| d.b64.to_string()).collect()); - // self.serialized_sd_jwt = serde_json::to_string(&jws_content).unwrap(); - } - Ok(()) } @@ -306,7 +299,26 @@ impl SDJWTIssuer { COMBINED_SERIALIZATION_FORMAT_SEPARATOR, ); } else if self.inner.serialization_format == "json" { - self.serialized_sd_jwt = self.signed_sd_jwt.clone(); + let jwt: Vec<&str> = self.signed_sd_jwt.split('.').collect(); + if jwt.len() != 3 { + return Err(Error::InvalidInput(format!( + "Invalid JWT, JWT must contain three parts after splitting with \".\": jwt {}", + self.signed_sd_jwt + ))); + } + let sd_jwt_json = SDJWTJson { + protected: jwt[0].to_owned(), + payload: jwt[1].to_owned(), + signature: jwt[2].to_owned(), + kb_jwt: None, + disclosures: self + .all_disclosures + .iter() + .map(|d| d.raw_b64.to_string()) + .collect(), + }; + self.serialized_sd_jwt = serde_json::to_string(&sd_jwt_json) + .map_err(|e| Error::DeserializationError(e.to_string()))?; } else { return Err(Error::InvalidInput( format!("Unknown serialization format {}, only \"compact\" or \"json\" formats are supported", self.inner.serialization_format) diff --git a/src/lib.rs b/src/lib.rs index b8f704b..4e6644f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ use crate::error::Error; use crate::utils::{base64_hash, base64url_decode, jwt_payload_decode}; + use error::Result; +use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::collections::HashMap; pub use {holder::SDJWTHolder, issuer::SDJWTIssuer, verifier::SDJWTVerifier}; @@ -20,8 +22,6 @@ const SD_LIST_PREFIX: &str = "..."; const _SD_JWT_TYP_HEADER: &str = "sd+jwt"; const KB_JWT_TYP_HEADER: &str = "kb+jwt"; const KB_DIGEST_KEY: &str = "_sd_hash"; -const JWS_KEY_DISCLOSURES: &str = "disclosures"; -const JWS_KEY_KB_JWT: &str = "kb_jwt"; pub const COMBINED_SERIALIZATION_FORMAT_SEPARATOR: &str = "~"; const JWT_SEPARATOR: &str = "."; const CNF_KEY: &str = "cnf"; @@ -38,12 +38,22 @@ pub(crate) struct SDJWTCommon { serialization_format: String, unverified_input_key_binding_jwt: Option, unverified_sd_jwt: Option, + unverified_sd_jwt_json: Option, unverified_input_sd_jwt_payload: Option>, hash_to_decoded_disclosure: HashMap, hash_to_disclosure: HashMap, input_disclosures: Vec, } +#[derive(Default, Serialize, Deserialize, Clone, Eq, PartialEq, Debug)] +pub struct SDJWTJson { + protected: String, + payload: String, + signature: String, + pub disclosures: Vec, + pub kb_jwt: Option, +} + // Define the SDJWTCommon struct to hold common properties. impl SDJWTCommon { fn create_hash_mappings(&mut self) -> Result<()> { @@ -150,25 +160,19 @@ impl SDJWTCommon { self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(jwt_body)?); Ok(()) } else { - // If the SD-JWT is in JSON format, parse the JSON and extract the disclosures. - let unverified_input_sd_jwt_parsed: Value = - serde_json::from_str(&sd_jwt_with_disclosures) - .map_err(|e| Error::DeserializationError(e.to_string()))?; - self.unverified_input_key_binding_jwt = unverified_input_sd_jwt_parsed - .get(JWS_KEY_KB_JWT) - .map(Value::to_string); - self.input_disclosures = unverified_input_sd_jwt_parsed[JWS_KEY_DISCLOSURES] - .as_array() - .ok_or(Error::ConversionError( - "Cannot convert `disclosures` to array".to_string(), - ))? - .iter() - .map(Value::to_string) - .collect(); - let payload = unverified_input_sd_jwt_parsed["payload"] - .as_str() - .ok_or(Error::KeyNotFound("payload".to_string()))?; - self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(payload)?); + let parsed_sd_jwt_json: SDJWTJson = serde_json::from_str(&sd_jwt_with_disclosures) + .map_err(|e| Error::DeserializationError(e.to_string()))?; + self.unverified_sd_jwt_json = Some(parsed_sd_jwt_json.clone()); + self.unverified_input_key_binding_jwt = parsed_sd_jwt_json.kb_jwt; + self.input_disclosures = parsed_sd_jwt_json.disclosures; + self.unverified_input_sd_jwt_payload = + Some(jwt_payload_decode(&parsed_sd_jwt_json.payload)?); + self.unverified_sd_jwt = Some(format!( + "{}.{}.{}", + parsed_sd_jwt_json.protected, + parsed_sd_jwt_json.payload, + parsed_sd_jwt_json.signature + )); Ok(()) } } diff --git a/tests/demos.rs b/tests/demos.rs index b19c53f..662563b 100644 --- a/tests/demos.rs +++ b/tests/demos.rs @@ -8,7 +8,7 @@ use jsonwebtoken::jwk::Jwk; use jsonwebtoken::{DecodingKey, EncodingKey}; use rstest::{fixture, rstest}; use sd_jwt_rs::issuer::SDJWTClaimsStrategy; -use sd_jwt_rs::{SDJWTHolder, SDJWTIssuer, SDJWTVerifier}; +use sd_jwt_rs::{SDJWTHolder, SDJWTIssuer, SDJWTJson, SDJWTVerifier}; use sd_jwt_rs::{COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_SIGNING_ALG}; use serde_json::{json, Map, Value}; use std::collections::HashSet; @@ -287,7 +287,7 @@ fn demo_positive_cases( Option, Option, ), - #[values("compact".to_string())] format: String, + #[values("compact".to_string(), "json".to_string())] format: String, #[values(None, Some(DEFAULT_SIGNING_ALG.to_owned()))] sign_algo: Option, #[values(true, false)] add_decoy: bool, ) { @@ -300,38 +300,62 @@ fn demo_positive_cases( holder_jwk.clone(), add_decoy, format.clone(), - ).unwrap(); + ) + .unwrap(); let issued = sd_jwt.clone(); // Holder creates presentation let mut holder = SDJWTHolder::new(sd_jwt.clone(), format.clone()).unwrap(); - let presentation = holder.create_presentation( - holder_disclosed_claims, - nonce.clone(), - aud.clone(), - holder_key, - sign_algo, - ).unwrap(); - - let mut issued_parts: HashSet<&str> = issued - .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR) - .collect(); - issued_parts.remove(""); - - let mut revealed_parts: HashSet<&str> = presentation - .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR) - .collect(); - revealed_parts.remove(""); - - let intersected_parts: HashSet<_> = issued_parts.intersection(&revealed_parts).collect(); - // Compare that number of disclosed parts are equal - let mut revealed_parts_number = revealed_parts.len(); - if holder_jwk.is_some() { - // Remove KB - revealed_parts_number -= 1; + let presentation = holder + .create_presentation( + holder_disclosed_claims, + nonce.clone(), + aud.clone(), + holder_key, + sign_algo, + ) + .unwrap(); + + if format == "compact" { + let mut issued_parts: HashSet<&str> = issued + .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR) + .collect(); + issued_parts.remove(""); + + let mut revealed_parts: HashSet<&str> = presentation + .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR) + .collect(); + revealed_parts.remove(""); + + let intersected_parts: HashSet<_> = issued_parts.intersection(&revealed_parts).collect(); + // Compare that number of disclosed parts are equal + let mut revealed_parts_number = revealed_parts.len(); + if holder_jwk.is_some() { + // Remove KB + revealed_parts_number -= 1; + } + assert_eq!(intersected_parts.len(), revealed_parts_number); + // here `+1` means adding issued jwt part also + assert_eq!(number_of_revealed_sds + 1, revealed_parts_number); + } else { + let mut issued: SDJWTJson = serde_json::from_str(&issued).unwrap(); + let mut revealed: SDJWTJson = serde_json::from_str(&presentation).unwrap(); + let disclosures: Vec = revealed + .disclosures + .clone() + .into_iter() + .filter(|d| issued.disclosures.contains(d)) + .collect(); + assert_eq!(number_of_revealed_sds, disclosures.len()); + + if holder_jwk.is_some() { + assert!(revealed.kb_jwt.is_some()); + } + + issued.disclosures = disclosures; + revealed.kb_jwt = None; + assert_eq!(revealed, issued); } - assert_eq!(intersected_parts.len(), revealed_parts_number); - // here `+1` means adding issued jwt part also - assert_eq!(number_of_revealed_sds + 1, revealed_parts_number); + // Verify presentation let _verified = SDJWTVerifier::new( presentation.clone(), @@ -342,5 +366,6 @@ fn demo_positive_cases( aud, nonce, format, - ).unwrap(); + ) + .unwrap(); }