diff --git a/Cargo.toml b/Cargo.toml index aea4bfe..75539f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ lazy_static = { version = "1.4", optional = true } log = "0.4" rand = "0.8" serde = { version = "1.0.193", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0.113", features = ["preserve_order"] } sha2 = "0.10" thiserror = "1.0.51" strum = { version = "0.25", default-features = false, features = ["std", "derive"] } diff --git a/README.md b/README.md index 067c7d0..1ab78c1 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ cargo test ``` ### Interoperability testing tool -Coming soon (planned for v0.0.7) +See [Generate tool README](./generate/README.md) document. ## External Dependencies diff --git a/generate/Cargo.toml b/generate/Cargo.toml index f76864e..76bdffd 100644 --- a/generate/Cargo.toml +++ b/generate/Cargo.toml @@ -2,6 +2,7 @@ name = "sd-jwt-generate" version = "0.1.0" edition = "2021" +authors = ["Abdulbois Tursunov ", "Alexander Sukhachev "] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -9,6 +10,6 @@ edition = "2021" clap = { version = "4.4.10", features = ["derive"] } serde = { version = "1.0.193", features = ["derive"] } serde_yaml = "0.9.27" -serde_json = "1.0.108" +serde_json = { version = "1.0.113", features = ["preserve_order"] } jsonwebtoken = "9.1" -sd-jwt-rs = {path = "./.."} \ No newline at end of file +sd-jwt-rs = {path = "./..", features = ["mock_salts"]} \ No newline at end of file diff --git a/generate/README.md b/generate/README.md new file mode 100644 index 0000000..3dd3bfa --- /dev/null +++ b/generate/README.md @@ -0,0 +1,117 @@ +# SD-JWT Interop tool + +This tool is used to verify interoperability between the `sd-jwt-rust` and `sd-jwt-python` implementations of the [IETF SD-JWT specification](https://datatracker.ietf.org/doc/draft-ietf-oauth-selective-disclosure-jwt/). + +## How does the Interop tool work? + +The main idea is to generate data structures (SDJWT/presentation/verified claims) using both implementations and compare them. + +The `sd-jwt-python` is used to generate artifacts based on input data (`specification.yml`) and store them as files. +The interop tool (based on `sd-jwt-rust`) is used to generate artifacts using the same specification file, load artifacts stored in files by `sd-jwt-python` and compare them. The interop tool doesn't store any files on filesystem. + +There are some factors that make impossible to compare data due to non-equivalence data generated by different implementations: + +- Using random 'salt' in each run that make results different even though they are generated by the same implementation. +- Not equivalent json-serialized strings (different number of spaces) generated under the hood of the different implementations. +- Using 'decoy' digests in the SD-JWT payload. + +In order to reach reproducibility and equivalence of the values generated by both implementations it is required to use the same input data (issuer private key, user claims, etc.) and to get rid of some non-deterministic values during data generating (values of 'salt', for example). + +### Deterministic 'salt' + +In order to make it possible to get reproducible result each run it's required to use deterministic values of 'salt' used in internal algorithms. The `sd-jwt-python` project implements such behavior for test purposes. + +In order to use the same set of 'salt' values by the `sd-jwt-rust` project Python-implementation stores values in the `claims_vs_salts.json` file as artifact. The Interop tool loads values from the file and use it instead of random generated values (see the `mock_salts` feature). + + +### Similar json serialization + +In order to have the same json-strings used under the hood of the both implementations there is some code that gets rid of different number of spaces: + +```rust + value_str = value_str + .replace(":[", ": [") + .replace(',', ", ") + .replace("\":", "\": ") + .replace("\": ", "\": "); +``` + +### 'Decoy' SD items + +In order to make it possible to compare `SD-JWT` payloads that contains [decoy](https://www.ietf.org/archive/id/draft-ietf-oauth-selective-disclosure-jwt-07.html#name-decoy-digests) it was decided to detect and remove all `decoy` items from payloads and then compare them. + + +## How to use the interop tool? + +1. Install the prerequisites +2. Clone and build the `sd-jwt-rust` project +3. Clone and build the `sd-jwt-python` project +4. Generate artifacts using the `sd-jwt-python` project +5. Run the interop tool + +### Install the prerequisites + +In order to be able to build both implementations it is required to setup following tools: + +- `Rust`/`cargo` +- `poetry` + + +### Clone and build the `sd-jwt-rust` project + +```shell +git clone git@github.com:openwallet-foundation-labs/sd-jwt-rust.git +cd sd-jwt-rust/generate +cargo build +``` + + +### Clone and build the `sd-jwt-python` project + +Once the project repo is cloned to local directory it is necessary to apply special patch. +This patch is required to have some additional files as artifacts generated by the `sd-jwt-python` project. + +Files: + +- `claims_vs_salts.json` file contains values of so called 'salt' that have been used during `SDJWT` issuance. +- `issuer_key.pem` file contains the issuer's private key. +- `issuer_public_key.pem` file contains the issuer's public key. +- `holder_key.pem` file contains the holder's private key. + +The files are used to make it possible for this tool to generate the same values of artifacts (SDJWT payload/SDJWT claims/presentation/verified claims) that are generated by `sd-jwt-python`. + + +```shell +git clone git@github.com:openwallet-foundation-labs/sd-jwt-python.git +cd sd-jwt-python + +# apply the patch +git apply ../sd-jwt-rust/generate/sd_jwt_python.patch + +# build +poetry install && poetry build +``` + + + +### Generate artifacts using the `sd-jwt-python` project + +```shell +pushd sd-jwt-python/tests/testcases && poetry run ../../src/sd_jwt/bin/generate.py -- example && popd +pushd sd-jwt-python/examples && poetry run ../src/sd_jwt/bin/generate.py -- example && popd +``` + + +### Run the interop tool + +```shell +cd sd-jwt-rust/generate +sd_jwt_py="../../sd-jwt-python" +for cases_dir in $sd_jwt_py/examples $sd_jwt_py/tests/testcases; do + for test_case_dir in $(ls $cases_dir); do + if [[ -d $cases_dir/$test_case_dir ]]; then + ./target/debug/sd-jwt-generate -p $cases_dir/$test_case_dir + fi + done +done +``` diff --git a/generate/sd_jwt_python.patch b/generate/sd_jwt_python.patch new file mode 100644 index 0000000..5692bb0 --- /dev/null +++ b/generate/sd_jwt_python.patch @@ -0,0 +1,97 @@ +diff --git a/.gitignore b/.gitignore +index 1874e26..72ff453 100644 +--- a/.gitignore ++++ b/.gitignore +@@ -157,7 +157,7 @@ cython_debug/ + # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore + # and can be added to the global gitignore or merged into this file. For a more nuclear + # option (not recommended) you can uncomment the following to ignore the entire idea folder. +-#.idea/ ++.idea/ + + + # Ignore output of test cases except for specification.yml +diff --git a/pyproject.toml b/pyproject.toml +index 4294e64..47c9281 100644 +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -12,7 +12,7 @@ jwcrypto = ">=1.3.1" + pyyaml = ">=5.4" + + [tool.poetry.group.dev.dependencies] +-flake8 = "^6.0.0" ++# flake8 = "^6.0.0" + black = "^23.3.0" + + [build-system] +diff --git a/src/sd_jwt/bin/generate.py b/src/sd_jwt/bin/generate.py +index ad00641..d0299ea 100755 +--- a/src/sd_jwt/bin/generate.py ++++ b/src/sd_jwt/bin/generate.py +@@ -105,12 +105,36 @@ def generate_test_case_data(settings: Dict, testcase_path: Path, type: str): + + # Write the test case data to the directory of the test case + ++ claims_vs_salts = [] ++ for disclosure in sdjwt_at_issuer.ii_disclosures: ++ claims_vs_salts.append(disclosure.salt) ++ + _artifacts = { + "user_claims": ( + remove_sdobj_wrappers(testcase["user_claims"]), + "User Claims", + "json", + ), ++ "issuer_key": ( ++ demo_keys["issuer_key"].export_to_pem(True, None).decode("utf-8"), ++ "Issuer private key", ++ "pem", ++ ), ++ "issuer_public_key": ( ++ demo_keys["issuer_public_key"].export_to_pem(False, None).decode("utf-8"), ++ "Issuer public key", ++ "pem", ++ ), ++ "holder_key": ( ++ demo_keys["holder_key"].export_to_pem(True, None).decode("utf-8"), ++ "Issuer private key", ++ "pem", ++ ), ++ "claims_vs_salts": ( ++ claims_vs_salts, ++ "Claims with Salts", ++ "json", ++ ), + "sd_jwt_payload": ( + sdjwt_at_issuer.sd_jwt_payload, + "Payload of the SD-JWT", +diff --git a/src/sd_jwt/disclosure.py b/src/sd_jwt/disclosure.py +index a9727c4..d1f983a 100644 +--- a/src/sd_jwt/disclosure.py ++++ b/src/sd_jwt/disclosure.py +@@ -15,11 +15,11 @@ class SDJWTDisclosure: + self._hash() + + def _hash(self): +- salt = self.issuer._generate_salt() ++ self._salt = self.issuer._generate_salt() + if self.key is None: +- data = [salt, self.value] ++ data = [self._salt, self.value] + else: +- data = [salt, self.key, self.value] ++ data = [self._salt, self.key, self.value] + + self._json = dumps(data).encode("utf-8") + +@@ -30,6 +30,10 @@ class SDJWTDisclosure: + def hash(self): + return self._hash + ++ @property ++ def salt(self): ++ return self._salt ++ + @property + def b64(self): + return self._raw_b64 diff --git a/generate/src/error.rs b/generate/src/error.rs index 9905d78..615f23c 100644 --- a/generate/src/error.rs +++ b/generate/src/error.rs @@ -3,8 +3,6 @@ use std::error::Error as StdError; use std::fmt::{self, Display, Formatter}; use std::result::Result as StdResult; -use serde_json; -use serde_yaml; pub type Result = std::result::Result; @@ -12,6 +10,7 @@ pub type Result = std::result::Result; pub enum ErrorKind { Input, IOError, + DataNotEqual, } impl ErrorKind { @@ -19,7 +18,8 @@ impl ErrorKind { pub const fn as_str(&self) -> &'static str { match self { Self::Input => "Input error", - Self::IOError => "IO error" + Self::IOError => "IO error", + Self::DataNotEqual => "Data not equal error", } } } diff --git a/generate/src/main.rs b/generate/src/main.rs index 5e0674f..80f617c 100644 --- a/generate/src/main.rs +++ b/generate/src/main.rs @@ -2,25 +2,30 @@ mod error; mod types; mod utils; +use jsonwebtoken::jwk::Jwk; + use crate::error::{Error, ErrorKind, Result}; +use crate::utils::funcs::{parse_sdjwt_paylod, load_salts}; use clap::Parser; -use jsonwebtoken::EncodingKey; -use sd_jwt_rs::issuer::{SDJWTClaimsStrategy, SDJWTIssuer}; -use sd_jwt_rs::SALTS; -use serde_json::Value; -use std::collections::HashMap; +use jsonwebtoken::{EncodingKey, DecodingKey}; +use sd_jwt_rs::issuer::{ClaimsForSelectiveDisclosureStrategy, SDJWTIssuer}; +use sd_jwt_rs::holder::SDJWTHolder; +use sd_jwt_rs::verifier::SDJWTVerifier; +use sd_jwt_rs::SDJWTSerializationFormat; +use serde_json::{Number, Value}; use std::path::PathBuf; use types::cli::{Cli, GenerateType}; use types::settings::Settings; use types::specification::Specification; const ISSUER_KEY_PEM_FILE_NAME: &str = "issuer_key.pem"; +const ISSUER_PUBLIC_KEY_PEM_FILE_NAME: &str = "issuer_public_key.pem"; // const HOLDER_KEY_PEM_FILE_NAME: &str = "holder_key.pem"; -const SERIALIZATION_FORMAT: &str = "compact"; const SETTINGS_FILE_NAME: &str = "settings.yml"; const SPECIFICATION_FILE_NAME: &str = "specification.yml"; const SALTS_FILE_NAME: &str = "claims_vs_salts.json"; -const SD_JWT_PAYLOAD_FILE_NAME: &str = "sd_jwt_payload.json"; +const SD_JWT_FILE_NAME_TEMPLATE: &str = "sd_jwt_issuance"; +const VERIFIED_CLAIMS_FILE_NAME: &str = "verified_contents.json"; fn main() { let args = Cli::parse(); @@ -28,14 +33,13 @@ fn main() { println!("type_: {:?}, paths: {:?}", args.type_.clone(), args.paths); let basedir = std::env::current_dir().expect("Unable to get current directory"); - - let settings = get_settings(&basedir.join(SETTINGS_FILE_NAME)); - let spec_directories = get_specification_paths(&args, basedir).unwrap(); for mut directory in spec_directories { println!("Generating data for '{:?}'", directory); + let settings = get_settings(&directory.parent().unwrap().join("..").join(SETTINGS_FILE_NAME)); let specs = Specification::from(&directory); + // Remove specification.yaml from path directory.pop(); @@ -45,63 +49,181 @@ fn main() { fn generate_and_check( directory: &PathBuf, - _: &Settings, + settings: &Settings, specs: Specification, _: GenerateType, ) -> Result<()> { - // let seed = settings.random_seed.unwrap_or(0); + let decoy = specs.add_decoy_claims.unwrap_or(false); + let serialization_format; + let stored_sd_jwt_file_path; + + match &specs.serialization_format { + Some(format) if format == "json" => { + serialization_format = SDJWTSerializationFormat::JSON; + stored_sd_jwt_file_path = directory.join(format!("{SD_JWT_FILE_NAME_TEMPLATE}.json")); + }, + Some(format) if format == "compact" => { + serialization_format = SDJWTSerializationFormat::Compact; + stored_sd_jwt_file_path = directory.join(format!("{SD_JWT_FILE_NAME_TEMPLATE}.txt")); + }, + None => { + println!("using default serialization format: Compact"); + serialization_format = SDJWTSerializationFormat::Compact; + stored_sd_jwt_file_path = directory.join(format!("{SD_JWT_FILE_NAME_TEMPLATE}.txt")); + }, + Some(format) => { + panic!("unsupported format: {format}"); + }, + }; + + let sd_jwt = issue_sd_jwt(directory, &specs, settings, serialization_format.clone(), decoy)?; + let presentation = create_presentation(&sd_jwt, serialization_format.clone(), &specs.holder_disclosed_claims)?; + + // Verify presentation + let verified_claims = verify_presentation(directory, &presentation, serialization_format.clone())?; + + let loaded_sd_jwt = load_sd_jwt(&stored_sd_jwt_file_path)?; + + let loaded_sdjwt_paylod = parse_sdjwt_paylod(&loaded_sd_jwt.replace('\n', ""), &serialization_format, decoy)?; + let issued_sdjwt_paylod = parse_sdjwt_paylod(&sd_jwt, &serialization_format, decoy)?; + + compare_jwt_payloads(&loaded_sdjwt_paylod, &issued_sdjwt_paylod)?; + + let loaded_verified_claims_content = load_sd_jwt(&directory.join(VERIFIED_CLAIMS_FILE_NAME))?; + let loaded_verified_claims = parse_verified_claims(&loaded_verified_claims_content)?; + + compare_verified_claims(&loaded_verified_claims, &verified_claims)?; - // Get keys from .pem files + Ok(()) +} + +fn issue_sd_jwt( + directory: &PathBuf, + specs: &Specification, + settings: &Settings, + serialization_format: SDJWTSerializationFormat, + decoy: bool +) -> Result { let issuer_key = get_key(&directory.join(ISSUER_KEY_PEM_FILE_NAME)); - // let holder_key = get_key(key_path.join(HOLDER_KEY_PEM_FILE_NAME)); - let user_claims = specs.user_claims.claims_to_json_value()?; - let decoy = specs.add_decoy_claims.unwrap_or(false); + let mut user_claims = specs.user_claims.claims_to_json_value()?; + let claims_obj = user_claims.as_object_mut().expect("must be an object"); + + if !claims_obj.contains_key("iss") { + claims_obj.insert(String::from("iss"), Value::String(settings.identifiers.issuer.clone())); + } + + if !claims_obj.contains_key("iat") { + let iat = settings.iat.expect("'iat' value must be provided by settings.yml"); + claims_obj.insert(String::from("iat"), Value::Number(Number::from(iat))); + } + + if !claims_obj.contains_key("exp") { + let exp = settings.exp.expect("'expt' value must be provided by settings.yml"); + claims_obj.insert(String::from("exp"), Value::Number(Number::from(exp))); + } + let sd_claims_jsonpaths = specs.user_claims.sd_claims_to_jsonpath()?; let strategy = - SDJWTClaimsStrategy::Partial(sd_claims_jsonpaths.iter().map(String::as_str).collect()); + ClaimsForSelectiveDisclosureStrategy::Custom(sd_claims_jsonpaths.iter().map(String::as_str).collect()); + + let jwk: Option = if specs.key_binding.unwrap_or(false) { + let jwk: Jwk = serde_yaml::from_value(settings.key_settings.holder_key.clone()).unwrap(); + Some(jwk) + } else { + None + }; + + let mut issuer = SDJWTIssuer::new(issuer_key, Some(String::from("ES256"))); + let sd_jwt = issuer.issue_sd_jwt( + user_claims, + strategy, + jwk, + decoy, + serialization_format) + .unwrap(); + + Ok(sd_jwt) +} + +fn create_presentation( + sd_jwt: &str, + serialization_format: SDJWTSerializationFormat, + disclosed_claims: &serde_json::Map +) -> Result { + let mut holder = SDJWTHolder::new(sd_jwt.to_string(), serialization_format).unwrap(); + + let presentation = holder + .create_presentation( + disclosed_claims.clone(), + None, + None, + None, + None + ).unwrap(); + + Ok(presentation) +} - let issuer = SDJWTIssuer::issue_sd_jwt( - user_claims, - strategy, - issuer_key, +fn verify_presentation( + directory: &PathBuf, + presentation: &str, + serialization_format: SDJWTSerializationFormat +) -> Result { + let pub_key_path = directory.clone().join(ISSUER_PUBLIC_KEY_PEM_FILE_NAME); + + let _verified = SDJWTVerifier::new( + presentation.to_string(), + Box::new(move |_, _| { + let key = std::fs::read(&pub_key_path).expect("Failed to read file"); + DecodingKey::from_ec_pem(&key).expect("Unable to create EncodingKey") + }), None, None, - decoy, - SERIALIZATION_FORMAT.to_string(), - ); - println!("Issued SD-JWT \n {:#?}", issuer.sd_jwt_payload); - - compare_jwt_payloads( - &directory.join(SD_JWT_PAYLOAD_FILE_NAME), - &issuer.sd_jwt_payload, - ) - - // let mut holder = SDJWTHolder::new( - // issuer.serialized_sd_jwt.clone(), - // SERIALIZATION_FORMAT.to_string(), - // ); - // holder.create_presentation(Some(vec!["address".to_string()]), None, None, None, None); - // println!("Created presentation \n {:?}", holder.sd_jwt_presentation) + serialization_format, + ).unwrap(); + + Ok(_verified.verified_claims) } -fn compare_jwt_payloads(path: &PathBuf, compare: &serde_json::Map) -> Result<()> { - let contents = std::fs::read_to_string(path)?; +fn parse_verified_claims(content: &str) -> Result { + let json_value: Value = serde_json::from_str(content)?; + + // TODO: check if the json_value is json object + Ok(json_value) +} - let json_value: serde_json::Map = serde_json::from_str(&contents) - .expect(&format!("Failed to parse to serde_json::Value {:?}", path)); +fn load_sd_jwt(path: &PathBuf) -> Result { + let content = std::fs::read_to_string(path)?; + Ok(content) +} - if json_value.eq(compare) { - println!("Issued JWT payload is the same as payload of {:?}", path); +fn compare_jwt_payloads(loaded_payload: &Value, issued_payload: &Value) -> Result<()> { + if issued_payload.eq(loaded_payload) { + println!("\nJWT payloads are equal"); } else { - eprintln!( - "Issued JWT payload is NOT the same as payload of {:?}", - path - ); + eprintln!("\nJWT payloads are NOT equal"); + + println!("Issued SD-JWT \n {:#?}", issued_payload); + println!("Loaded SD-JWT \n {:#?}", loaded_payload); - println!("Issued SD-JWT \n {:#?}", compare); - println!("Loaded SD-JWT \n {:#?}", json_value); + return Err(Error::from_msg(ErrorKind::DataNotEqual, "JWT payloads are different")); + } + + Ok(()) +} + +fn compare_verified_claims(loaded_claims: &Value, verified_claims: &Value) -> Result<()> { + if loaded_claims.eq(verified_claims) { + println!("Verified claims are equal",); + } else { + eprintln!("Verified claims are NOT equal"); + + println!("Issued verified claims \n {:#?}", verified_claims); + println!("Loaded verified claims \n {:#?}", loaded_claims); + + return Err(Error::from_msg(ErrorKind::DataNotEqual, "verified claims are different")); } Ok(()) @@ -116,10 +238,7 @@ fn get_key(path: &PathBuf) -> EncodingKey { fn get_settings(path: &PathBuf) -> Settings { println!("settings.yaml - {:?}", path); - let settings = Settings::from(path); - println!("{:#?}", settings); - - settings + Settings::from(path) } fn get_specification_paths(args: &Cli, basedir: PathBuf) -> Result> { @@ -132,7 +251,7 @@ fn get_specification_paths(args: &Cli, basedir: PathBuf) -> Result> let path = entry.path(); if path.is_dir() && path.join(SPECIFICATION_FILE_NAME).exists() { // load_salts(&path).map_err(|err| Error::from_msg(ErrorKind::IOError, err.to_string()))?; - load_salts(&path).unwrap(); + load_salts(&path.join(SALTS_FILE_NAME)).unwrap(); return Some(path.join(SPECIFICATION_FILE_NAME)); } } @@ -145,7 +264,7 @@ fn get_specification_paths(args: &Cli, basedir: PathBuf) -> Result> .iter() .map(|d| { // load_salts(&path).map_err(|err| Error::from_msg(ErrorKind::IOError, err.to_string()))?; - load_salts(&d).unwrap(); + load_salts(&d.join(SALTS_FILE_NAME)).unwrap(); basedir.join(d).join(SPECIFICATION_FILE_NAME) }) .collect(); @@ -155,17 +274,3 @@ fn get_specification_paths(args: &Cli, basedir: PathBuf) -> Result> Ok(glob) } - -fn load_salts(path: &PathBuf) -> Result<()> { - let salts_path = path.join(SALTS_FILE_NAME); - let json_data = std::fs::read_to_string(salts_path) - .map_err(|e| Error::from_msg(ErrorKind::IOError, e.to_string()))?; - let salts: HashMap = serde_json::from_str(&json_data)?; - - { - let mut map = SALTS.lock().unwrap(); - map.extend(salts.into_iter()); - } - - Ok(()) -} diff --git a/generate/src/types/settings.rs b/generate/src/types/settings.rs index 08ce014..dfd9f1e 100644 --- a/generate/src/types/settings.rs +++ b/generate/src/types/settings.rs @@ -1,12 +1,13 @@ use std::path::PathBuf; use serde::{Deserialize, Serialize}; +use serde_yaml::Value; #[derive(Serialize, Deserialize, PartialEq, Debug)] pub struct KeySettings { pub key_size: i32, pub kty: String, pub issuer_key: Key, - pub holder_key: Key, + pub holder_key: Value, } #[derive(Serialize, Deserialize, PartialEq, Debug)] diff --git a/generate/src/types/specification.rs b/generate/src/types/specification.rs index a210a4c..88d36fd 100644 --- a/generate/src/types/specification.rs +++ b/generate/src/types/specification.rs @@ -1,7 +1,6 @@ use crate::utils::generate::generate_jsonpath_from_tagged_values; use serde::{Deserialize, Serialize}; use serde_yaml::Value; -use std::collections::HashMap; use std::path::PathBuf; use crate::error::Result; @@ -10,14 +9,61 @@ const SD_TAG: &str = "!sd"; #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)] pub struct Specification { pub user_claims: UserClaims, - pub holder_disclosed_claims: HashMap, + pub holder_disclosed_claims: serde_json::Map, pub add_decoy_claims: Option, pub key_binding: Option, + pub serialization_format: Option, +} + +impl Specification { + fn update_disclosed_claims(&mut self) { + // not to transform top-level empty object + if self.holder_disclosed_claims.is_empty() { + return; + } + + let res = replace_empty_items(&serde_json::Value::Object(self.holder_disclosed_claims.clone())); + self.holder_disclosed_claims = res.as_object().unwrap().clone(); + } +} + +fn replace_empty_items(m: &serde_json::Value) -> serde_json::Value { + match m { + serde_json::Value::Array(arr) if (arr.is_empty()) => { + serde_json::Value::Bool(false) + } + serde_json::Value::Object(obj) if (obj.is_empty()) => { + serde_json::Value::Bool(false) + } + serde_json::Value::Array(arr) => { + let mut result = Vec::new(); + + for value in arr { + result.push(replace_empty_items(value)); + } + + serde_json::Value::Array(result) + } + serde_json::Value::Object(obj) => { + let mut result = serde_json::Map::new(); + + for (key, value) in obj { + result.insert(key.clone(), replace_empty_items(value)); + } + + serde_json::Value::Object(result) + } + _ => { + m.clone() + } + } } impl From<&str> for Specification { fn from(value: &str) -> Self { - serde_yaml::from_str(value).unwrap_or(Specification::default()) + let mut result = serde_yaml::from_str(value).unwrap_or(Specification::default()); + result.update_disclosed_claims(); + result } } @@ -25,20 +71,20 @@ impl From<&PathBuf> for Specification { fn from(path: &PathBuf) -> Self { let contents = std::fs::read_to_string(path).expect("Failed to read specification file"); - let spec: Specification = serde_yaml::from_str(&contents).expect("Failed to parse YAML"); + let mut spec: Specification = serde_yaml::from_str(&contents).expect("Failed to parse YAML"); + + spec.update_disclosed_claims(); spec } } #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)] -pub struct UserClaims(HashMap); +pub struct UserClaims(Value); impl UserClaims { pub fn claims_to_json_value(&self) -> Result { - let value = serde_yaml::to_value(&self.0) - .expect("Failed to convert user-claims into serde_yaml::Value"); - let filtered_value = _remove_tags(&value); + let filtered_value = _remove_tags(&self.0); let json_value: serde_json::Value = serde_yaml::from_value(filtered_value).expect("Failed to convert serde_json::Value"); @@ -46,11 +92,11 @@ impl UserClaims { } pub fn sd_claims_to_jsonpath(&self) -> Result> { - let mut path = "".to_string(); + let path = "".to_string(); let mut paths = Vec::new(); - let mut claims = serde_yaml::to_value(&self.0)?; + let mut claims = self.0.clone(); - let _ = generate_jsonpath_from_tagged_values(&mut claims, &mut path, &mut paths); + let _ = generate_jsonpath_from_tagged_values(&mut claims, path, &mut paths); Ok(paths) } @@ -60,7 +106,7 @@ fn _validate(value: &Value) -> Result<()> { match value { Value::String(_) | Value::Bool(_) | Value::Number(_) => Ok(()), Value::Tagged(tag) => { - if tag.tag.to_string() == SD_TAG { + if tag.tag == SD_TAG { _validate(&tag.value) } else { panic!( @@ -115,7 +161,7 @@ fn _remove_tags(original: &Value) -> Value { Value::Mapping(filtered_map) } Value::Sequence(seq) => { - let filtered_seq: Vec = seq.iter().map(|v| _remove_tags(v)).collect(); + let filtered_seq: Vec = seq.iter().map(_remove_tags).collect(); Value::Sequence(filtered_seq) } diff --git a/generate/src/utils/funcs.rs b/generate/src/utils/funcs.rs new file mode 100644 index 0000000..3dff777 --- /dev/null +++ b/generate/src/utils/funcs.rs @@ -0,0 +1,117 @@ +use std::collections::HashSet; +use std::path::PathBuf; + +use serde_json::Value; +use sd_jwt_rs::SDJWTSerializationFormat; +use sd_jwt_rs::utils::{base64_hash, base64url_decode}; +use sd_jwt_rs::utils::SALTS; +use crate::error::{Error, ErrorKind, Result}; + + +pub fn parse_sdjwt_paylod( + sd_jwt: &str, + serialization_format: &SDJWTSerializationFormat, + remove_decoy: bool +) -> Result { + + match serialization_format { + SDJWTSerializationFormat::JSON => { + parse_payload_json(sd_jwt, remove_decoy) + }, + SDJWTSerializationFormat::Compact => { + parse_payload_compact(sd_jwt, remove_decoy) + } + } +} + +fn parse_payload_json(sd_jwt: &str, remove_decoy: bool) -> Result { + let v: serde_json::Value = serde_json::from_str(sd_jwt).unwrap(); + + let disclosures = v.as_object().unwrap().get("disclosures").unwrap(); + + let mut hashes: HashSet = HashSet::new(); + + for disclosure in disclosures.as_array().unwrap() { + let hash = base64_hash(disclosure.as_str().unwrap().replace(' ', "").as_bytes()); + hashes.insert(hash.clone()); + } + + let ddd = v.as_object().unwrap().get("payload").unwrap().as_str().unwrap().replace(' ', ""); + let payload = base64url_decode(&ddd).unwrap(); + + let payload: serde_json::Value = serde_json::from_slice(&payload).unwrap(); + + if remove_decoy { + return Ok(remove_decoy_items(&payload, &hashes)); + } + + Ok(payload) +} + +fn parse_payload_compact(sd_jwt: &str, remove_decoy: bool) -> Result { + let mut disclosures: Vec = sd_jwt + .split('~') + .filter(|s| !s.is_empty()) + .map(String::from) + .collect(); + + let payload = disclosures.remove(0); + + let payload: Vec<_> = payload.split('.').collect(); + let payload = String::from(payload[1]); + + let mut hashes: HashSet = HashSet::new(); + + for disclosure in disclosures { + let hash = base64_hash(disclosure.as_bytes()); + hashes.insert(hash.clone()); + } + + let payload = base64url_decode(&payload).unwrap(); + + let payload: serde_json::Value = serde_json::from_slice(&payload).unwrap(); + + if remove_decoy { + return Ok(remove_decoy_items(&payload, &hashes)); + } + + Ok(payload) +} + +fn remove_decoy_items(payload: &Value, hashes: &HashSet) -> Value { + let mut map: serde_json::Map = serde_json::Map::new(); + + for (key, val) in payload.as_object().unwrap() { + if key == "_sd" { + let v1: Vec<_> = val.as_array().unwrap().iter() + .filter(|item| hashes.contains(item.as_str().unwrap())).cloned() + .collect(); + + let filtered_array = serde_json::Value::Array(v1); + map.insert(key.clone(), filtered_array); + } else if val.is_object() { + let filtered_object = remove_decoy_items(val, hashes); + map.insert(key.clone(), filtered_object); + } else { + map.insert(key.clone(), val.clone()); + } + } + + Value::Object(map) +} + +pub fn load_salts(path: &PathBuf) -> Result<()> { + let json_data = std::fs::read_to_string(path) + .map_err(|e| Error::from_msg(ErrorKind::IOError, e.to_string()))?; + let salts: Vec = serde_json::from_str(&json_data)?; + + { + let mut s = SALTS.lock().unwrap(); + + for salt in salts.iter() { + s.push_back(salt.clone()); + } + } + + Ok(()) +} diff --git a/generate/src/utils/generate.rs b/generate/src/utils/generate.rs index 831f318..a98cecc 100644 --- a/generate/src/utils/generate.rs +++ b/generate/src/utils/generate.rs @@ -4,60 +4,59 @@ use crate::error::Result; #[allow(unused)] pub fn generate_jsonpath_from_tagged_values( yaml: &Value, - path: &mut String, + mut path: String, paths: &mut Vec, ) -> Result<()> { + + if path.is_empty() { + path.push('$'); + } + match yaml { Value::Mapping(map) => { for (key, value) in map { - let len = path.len(); - - if path.is_empty() { - path.push_str("$."); - } // Handle nested - match key { - Value::Tagged(tagged) => { - path.push_str(tagged.value.as_str().unwrap()); - match value { - Value::Mapping(_) => { - path.push('.'); - generate_jsonpath_from_tagged_values(value, path, paths); - } - Value::Sequence(_) => { - generate_jsonpath_from_tagged_values(value, path, paths); - } - _ => {}, - } + let mut subpath: String; - if path.ends_with('.') { - path.pop().unwrap(); - } - - paths.push(path.clone()); + match key { + Value::Tagged(tagged) => { + subpath = format!("{}.{}", &path, tagged.value.as_str().unwrap()); + paths.push(subpath.clone()); + generate_jsonpath_from_tagged_values(value, subpath, paths); } Value::String(s) => { - path.push_str(s); - path.push('.'); - - generate_jsonpath_from_tagged_values(value, path, paths); + subpath = format!("{}.{}", &path, &s); + generate_jsonpath_from_tagged_values(value, subpath, paths); } _ => {} } - - path.truncate(len); } } Value::Sequence(seq) => { for (idx, value) in seq.iter().enumerate() { - let len = path.len(); - path.push_str(&format!("[{}].", idx)); - generate_jsonpath_from_tagged_values(value, path, paths); + let mut subpath = format!("{}.[{}]", &path, idx); + generate_jsonpath_from_tagged_values(value, subpath, paths); + } + } + Value::Tagged(tagged) => { + // TODO: handle other value types (int/bool/etc) - path.truncate(len); + match &tagged.value { + Value::Mapping(m) => { + paths.push(path.clone()); + generate_jsonpath_from_tagged_values(&tagged.value, path.clone(), paths); + } + Value::Sequence(s) => { + paths.push(path.clone()); + generate_jsonpath_from_tagged_values(&tagged.value, path.clone(), paths); + } + _ => { + paths.push(path.clone()); + } } + } _ => {} } diff --git a/generate/src/utils/mod.rs b/generate/src/utils/mod.rs index 118c66d..6b1d26e 100644 --- a/generate/src/utils/mod.rs +++ b/generate/src/utils/mod.rs @@ -1 +1,2 @@ -pub mod generate; \ No newline at end of file +pub mod generate; +pub mod funcs; \ No newline at end of file diff --git a/src/disclosure.rs b/src/disclosure.rs index 079fbfa..5c0ea65 100644 --- a/src/disclosure.rs +++ b/src/disclosure.rs @@ -1,4 +1,10 @@ -use crate::utils::{base64_hash, base64url_encode, generate_salt}; +use crate::utils::{base64_hash, base64url_encode}; +#[cfg(not(feature = "mock_salts"))] +use crate::utils::generate_salt; +#[cfg(feature = "mock_salts")] +use crate::utils::generate_salt_mock; +use serde_json::Value; + #[derive(Debug)] pub(crate) struct SDJWTDisclosure { @@ -8,19 +14,31 @@ pub(crate) struct SDJWTDisclosure { impl SDJWTDisclosure { pub(crate) fn new(key: Option, value: V) -> Self where V: ToString { - let salt = generate_salt(key.clone()); + #[cfg(not(feature = "mock_salts"))] + let salt = generate_salt(); let mut value_str = value.to_string(); - value_str = value_str.replace(":[", ": [").replace(',', ", "); - let (_data, raw_b64) = if let Some(key) = &key { //TODO remove data? - let data = format!(r#"["{}", "{}", {}]"#, salt, key, value_str); - let raw_b64 = base64url_encode(data.as_bytes()); - (data, raw_b64) + + #[cfg(feature = "mock_salts")] + let salt = { + value_str = value_str + .replace(":[", ": [") + .replace(',', ", ") + .replace("\":", "\": ") + .replace("\": ", "\": "); + generate_salt_mock() + }; + + if !value_str.is_ascii() { + value_str = escape_unicode_chars(&value_str); + } + + let data = if let Some(key) = &key { + format!(r#"["{}", {}, {}]"#, salt, escape_json(key), value_str) } else { - let data = format!(r#"["{}", {}]"#, salt, value_str); - let raw_b64 = base64url_encode(data.as_bytes()); - (data, raw_b64) + format!(r#"["{}", {}]"#, salt, value_str) }; + let raw_b64 = base64url_encode(data.as_bytes()); let hash = base64_hash(raw_b64.as_bytes()); Self { @@ -30,6 +48,33 @@ impl SDJWTDisclosure { } } +fn escape_unicode_chars(s: &str) -> String { + let mut result = String::new(); + + for c in s.chars() { + if c.is_ascii() { + result.push(c); + } else { + let esc_c = c.escape_unicode().to_string(); + + let esc_c_new = match esc_c.chars().count() { + 6 => esc_c.replace("\\u{", "\\u00").replace('}', ""), // example: \u{de} + 7 => esc_c.replace("\\u{", "\\u0").replace('}', ""), // example: \u{980} + 8 => esc_c.replace("\\u{", "\\u").replace('}', ""), // example: \u{23f0} + _ => {panic!("unexpected value")} + }; + + result.push_str(&esc_c_new); + } + } + + result +} + +fn escape_json(s: &str) -> String { + Value::String(String::from(s)).to_string() +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/holder.rs b/src/holder.rs index 0f93ebb..3ddb2c6 100644 --- a/src/holder.rs +++ b/src/holder.rs @@ -433,10 +433,10 @@ mod tests { let mut parts: Vec<&str> = issued .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR) .collect(); + parts.remove(5); parts.remove(4); parts.remove(3); parts.remove(2); - parts.remove(1); let expected = parts.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR); assert_eq!(expected, presentation); } diff --git a/src/issuer.rs b/src/issuer.rs index 3ac501a..c156259 100644 --- a/src/issuer.rs +++ b/src/issuer.rs @@ -32,9 +32,9 @@ pub struct SDJWTIssuer { // internal fields inner: SDJWTCommon, all_disclosures: Vec, - pub sd_jwt_payload: SJMap, - pub signed_sd_jwt: String, - pub serialized_sd_jwt: String, + sd_jwt_payload: SJMap, + signed_sd_jwt: String, + serialized_sd_jwt: String, } /// ClaimsForSelectiveDisclosureStrategy is used to determine which claims can be selectively disclosed later by the holder. @@ -196,7 +196,7 @@ impl SDJWTIssuer { let always_revealed_root_keys = vec!["iss", "iat", "exp"]; let mut always_revealed_claims: Map = always_revealed_root_keys .into_iter() - .filter_map(|key| claims_obj_ref.remove_entry(key)) + .filter_map(|key| claims_obj_ref.shift_remove_entry(key)) .collect(); self.sd_jwt_payload = self @@ -252,6 +252,10 @@ impl SDJWTIssuer { sd_strategy: ClaimsForSelectiveDisclosureStrategy, ) -> Value { let mut claims = SJMap::new(); + + // to have the first key "_sd" in the ordered map + claims.insert(SD_DIGESTS_KEY.to_owned(), Value::Null); + let mut sd_claims = Vec::new(); for (key, value) in user_claims.iter() { @@ -281,6 +285,8 @@ impl SDJWTIssuer { SD_DIGESTS_KEY.to_owned(), Value::Array(sd_claims.into_iter().map(Value::String).collect()), ); + } else { + claims.shift_remove(SD_DIGESTS_KEY); } Value::Object(claims) @@ -353,7 +359,7 @@ impl SDJWTIssuer { } fn create_decoy_claim_entry(&mut self) -> String { - let digest = base64_hash(generate_salt(None).as_bytes()).to_string(); + let digest = base64_hash(generate_salt().as_bytes()).to_string(); digest } } diff --git a/src/utils.rs b/src/utils.rs index 4cd14fc..76880a1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -12,14 +12,15 @@ use rand::RngCore; use serde_json::Value; use sha2::Digest; #[cfg(feature = "mock_salts")] -use std::{collections::HashMap, sync::Mutex}; +use std::{collections::VecDeque, sync::Mutex}; #[cfg(feature = "mock_salts")] lazy_static! { - pub static ref SALTS: Mutex> = Mutex::new(HashMap::new()); + pub static ref SALTS: Mutex> = Mutex::new(VecDeque::new()); } -pub(crate) fn base64_hash(data: &[u8]) -> String { +#[doc(hidden)] +pub fn base64_hash(data: &[u8]) -> String { let mut hasher = sha2::Sha256::new(); hasher.update(data); let hash = hasher.finalize(); @@ -31,28 +32,25 @@ pub(crate) fn base64url_encode(data: &[u8]) -> String { general_purpose::URL_SAFE_NO_PAD.encode(data) } -pub(crate) fn base64url_decode(b64data: &str) -> Result> { +#[doc(hidden)] +pub fn base64url_decode(b64data: &str) -> Result> { general_purpose::URL_SAFE_NO_PAD .decode(b64data) .map_err(|e| Error::DeserializationError(e.to_string())) } -pub(crate) fn generate_salt(_key_for_predefined_salt: Option) -> String { - - #[cfg(feature = "mock_salts")] - { - let map = SALTS.lock().unwrap(); - if let Some(salt) = _key_for_predefined_salt.and_then(|key| map.get(&key)) { - //FIXME better mock approach - return salt.clone() - } - } - +pub(crate) fn generate_salt() -> String { let mut buf = [0u8; 16]; ThreadRng::default().fill_bytes(&mut buf); base64url_encode(&buf) } +#[cfg(feature = "mock_salts")] +pub(crate) fn generate_salt_mock() -> String { + let mut salts = SALTS.lock().unwrap(); + return salts.pop_front().expect("SALTS is empty"); +} + pub(crate) fn jwt_payload_decode(b64data: &str) -> Result> { serde_json::from_str( &String::from_utf8( diff --git a/src/verifier.rs b/src/verifier.rs index 748ae53..7b27752 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -231,15 +231,15 @@ impl SDJWTVerifier { fn unpack_disclosed_claims(&mut self, sd_jwt_claims: &Value) -> Result { match sd_jwt_claims { Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => { - return Ok(sd_jwt_claims.to_owned()); + Ok(sd_jwt_claims.to_owned()) } Value::Array(arr) => { - return self.unpack_disclosed_claims_in_array(arr); + self.unpack_disclosed_claims_in_array(arr) } Value::Object(obj) => { - return self.unpack_disclosed_claims_in_object(obj); + self.unpack_disclosed_claims_in_object(obj) } - }; + } } fn unpack_disclosed_claims_in_array(&mut self, arr: &Vec) -> Result { @@ -263,8 +263,8 @@ impl SDJWTVerifier { let digest = obj.get(SD_LIST_PREFIX).unwrap(); let disclosed_claim = self.unpack_from_digest(digest)?; - if disclosed_claim.is_some() { - claims.push(disclosed_claim.unwrap()); + if let Some(disclosed_claim) = disclosed_claim { + claims.push(disclosed_claim); } }, _ => { @@ -273,7 +273,7 @@ impl SDJWTVerifier { }, } } - return Ok(Value::Array(claims)); + Ok(Value::Array(claims)) } fn unpack_disclosed_claims_in_object(&mut self, nested_sd_jwt_claims: &Map) -> Result {