Skip to content

Commit

Permalink
refactoring: Move common functions from SDJWTCommon to utils.
Browse files Browse the repository at this point in the history
Signed-off-by: Abdulbois <[email protected]>
Signed-off-by: Abdulbois <[email protected]>
  • Loading branch information
Abdulbois committed Dec 15, 2023
1 parent ea656af commit ae77493
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 57 deletions.
Binary file added .DS_Store
Binary file not shown.
12 changes: 6 additions & 6 deletions src/disclosure.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::SDJWTCommon;
use crate::utils::{base64_hash, base64url_encode, generate_salt};

#[derive(Debug)]
pub(crate) struct SDJWTDisclosure {
Expand All @@ -7,21 +7,21 @@ pub(crate) struct SDJWTDisclosure {
}

impl SDJWTDisclosure {
pub(crate) fn new<V>(key: Option<String>, value: V, inner: &SDJWTCommon) -> Self where V: ToString {
let salt = SDJWTCommon::generate_salt(key.clone());
pub(crate) fn new<V>(key: Option<String>, value: V) -> Self where V: ToString {
let salt = generate_salt(key.clone());
let mut value_str = value.to_string();
value_str = value_str.replace(":[", ": [").replace(',', ", ");
let (_data, raw_b64) = if let Some(key) = &key { //TODO remove data?
let data = format!(r#"["{}", "{}", {}]"#, salt, key, value_str);
let raw_b64 = SDJWTCommon::base64url_encode(data.as_bytes());
let raw_b64 = base64url_encode(data.as_bytes());
(data, raw_b64)
} else {
let data = format!(r#"["{}", {}]"#, salt, value_str);
let raw_b64 = SDJWTCommon::base64url_encode(data.as_bytes());
let raw_b64 = base64url_encode(data.as_bytes());
(data, raw_b64)
};

let hash = inner.b64hash(raw_b64.as_bytes());
let hash = base64_hash(raw_b64.as_bytes());

Self {
raw_b64,
Expand Down
4 changes: 2 additions & 2 deletions src/holder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde_json::{Map, Value};

use crate::{COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_SIGNING_ALG, KB_DIGEST_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX};
use crate::SDJWTCommon;
use crate::utils::create_base64_encoded_hash;
use crate::utils::base64_hash;

pub struct SDJWTHolder {
sd_jwt_engine: SDJWTCommon,
Expand Down Expand Up @@ -205,7 +205,7 @@ impl SDJWTHolder {
combined.extend(self.hs_disclosures.iter().map(|s| s.as_str()));
let combined = combined.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR);

let _sd_hash = create_base64_encoded_hash(combined);
let _sd_hash = base64_hash(combined.as_bytes());
let _sd_hash = serde_json::to_value(&_sd_hash).unwrap();
self.key_binding_jwt_payload.insert(KB_DIGEST_KEY.to_owned(), _sd_hash);
}
Expand Down
7 changes: 4 additions & 3 deletions src/issuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde_json::Value;

use crate::{COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_DIGEST_ALG, DIGEST_ALG_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX, DEFAULT_SIGNING_ALG, SDJWTCommon, SDJWTHasSDClaimException, CNF_KEY, JWK_KEY};
use crate::disclosure::SDJWTDisclosure;
use crate::utils::{base64_hash, generate_salt};

pub struct SDJWTIssuer {
// parameters
Expand Down Expand Up @@ -179,7 +180,7 @@ impl SDJWTIssuer {
let subtree = self.create_sd_claims(object, strategy_for_child);

if sd_strategy.sd_for_key(&key) {
let disclosure = SDJWTDisclosure::new(None, subtree, &self.inner);
let disclosure = SDJWTDisclosure::new(None, subtree);
claims.push(json!({ SD_LIST_PREFIX: disclosure.hash}));
self.all_disclosures.push(disclosure);
} else {
Expand All @@ -198,7 +199,7 @@ impl SDJWTIssuer {
let subtree_from_here = self.create_sd_claims(value, strategy_for_child);

if sd_strategy.sd_for_key(key) {
let disclosure = SDJWTDisclosure::new(Some(key.to_owned()), subtree_from_here, &self.inner);
let disclosure = SDJWTDisclosure::new(Some(key.to_owned()), subtree_from_here);
sd_claims.push(disclosure.hash.clone());
self.all_disclosures.push(disclosure);
} else {
Expand Down Expand Up @@ -262,7 +263,7 @@ impl SDJWTIssuer {
}

fn create_decoy_claim_entry(&mut self) -> String {
let digest = self.inner.b64hash(SDJWTCommon::generate_salt(None).as_bytes()).to_string();
let digest = base64_hash(generate_salt(None).as_bytes()).to_string();
digest
}
}
Expand Down
50 changes: 8 additions & 42 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use std::collections::HashMap;
use base64::{engine::general_purpose, Engine};
use crate::utils::{base64_hash, base64url_decode, jwt_payload_decode};
use lazy_static::lazy_static;
use rand::prelude::ThreadRng;
use rand::RngCore;
use serde_json::{Map, Value};
use sha2::Digest;
use std::collections::HashMap;
use std::sync::Mutex;
pub use {holder::SDJWTHolder, issuer::SDJWTIssuer, verifier::SDJWTVerifier};

mod disclosure;
pub mod holder;
pub mod issuer;
pub mod verifier;
pub mod utils;
pub mod verifier;

pub const DEFAULT_SIGNING_ALG: &str = "ES256";
const SD_DIGESTS_KEY: &str = "_sd";
Expand Down Expand Up @@ -47,26 +45,19 @@ pub(crate) struct SDJWTCommon {

// Define the SDJWTCommon struct to hold common properties.
impl SDJWTCommon {
fn b64hash(&self, data: &[u8]) -> String {
let mut hasher = sha2::Sha256::new(); // TODO dynamic type
hasher.update(data);
let hash = hasher.finalize();
SDJWTCommon::base64url_encode(&hash)
}

fn create_hash_mappings(&mut self) -> Result<(), String> {
self.hash_to_decoded_disclosure = HashMap::new();
self.hash_to_disclosure = HashMap::new();

for disclosure in &self.input_disclosures {
let decoded_disclosure = SDJWTCommon::base64url_decode(disclosure).map_err(
let decoded_disclosure = base64url_decode(disclosure).map_err(
|err| format!("Error decoding disclosure {}: {}", disclosure, err)
)?;
let decoded_disclosure: Value = serde_json::from_slice(&decoded_disclosure).map_err(
|err| format!("Error parsing disclosure {}: {}", disclosure, err)
)?;

let hash = self.b64hash(disclosure.as_bytes());
let hash = base64_hash(disclosure.as_bytes());
if self.hash_to_decoded_disclosure.contains_key(&hash) {
return Err(format!("Duplicate disclosure hash {} for disclosure {:?}", hash, decoded_disclosure));
}
Expand Down Expand Up @@ -115,47 +106,22 @@ impl SDJWTCommon {
let mut sd_jwt = sd_jwt.split(JWT_SEPARATOR);
sd_jwt.next();
let jwt_body = sd_jwt.next().unwrap();
self.unverified_input_sd_jwt_payload = Some(SDJWTCommon::jwt_payload_decode(jwt_body).unwrap());
self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(jwt_body).unwrap());
Ok(())
} else {
// If the SD-JWT is in JSON format, parse the JSON and extract the disclosures.
let unverified_input_sd_jwt_parsed: Value = serde_json::from_str(&sd_jwt_with_disclosures).unwrap();
self.unverified_input_key_binding_jwt = unverified_input_sd_jwt_parsed.get(JWS_KEY_KB_JWT).map(Value::to_string);
self.input_disclosures = unverified_input_sd_jwt_parsed[JWS_KEY_DISCLOSURES].as_array().unwrap().iter().map(Value::to_string).collect();
let payload = unverified_input_sd_jwt_parsed["payload"].as_str().unwrap();
self.unverified_input_sd_jwt_payload = Some(SDJWTCommon::jwt_payload_decode(payload).unwrap());
self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(payload).unwrap());
Ok(())
}
}

fn get_serialization_format(&self) -> &str {
&self.serialization_format
}

fn base64url_encode(data: &[u8]) -> String {
general_purpose::URL_SAFE_NO_PAD.encode(data)
}

fn base64url_decode(b64data: &str) -> Result<Vec<u8>, base64::DecodeError> {
general_purpose::URL_SAFE_NO_PAD.decode(b64data)
}

fn jwt_payload_decode(b64data: &str) -> Result<serde_json::Map<String, Value>, SDJWTHasSDClaimException> {
Ok(serde_json::from_str(&String::from_utf8(Self::base64url_decode(b64data).unwrap()).unwrap()).unwrap())
}


fn generate_salt(key_for_predefined_salt: Option<String>) -> String {
let map = SALTS.lock().unwrap();

if let Some(salt) = key_for_predefined_salt.and_then(|key|map.get(&key)) { //FIXME better mock approach
salt.clone()
} else {
let mut buf = [0u8; 16];
ThreadRng::default().fill_bytes(&mut buf);
Self::base64url_encode(&buf)
}
}
}

lazy_static! {
Expand Down
32 changes: 30 additions & 2 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
use base64::Engine;
use base64::engine::general_purpose;
use rand::prelude::ThreadRng;
use rand::RngCore;
use serde_json::Value;
use sha2::Digest;
use crate::{SALTS, SDJWTHasSDClaimException};

pub fn create_base64_encoded_hash(data: String) -> String {
pub(crate) fn base64_hash(data: &[u8]) -> String {
let mut hasher = sha2::Sha256::new();
hasher.update(data);
let hash = hasher.finalize();

general_purpose::URL_SAFE_NO_PAD.encode(&hash)
}
}

pub(crate) fn base64url_encode(data: &[u8]) -> String {
general_purpose::URL_SAFE_NO_PAD.encode(data)
}

pub(crate) fn base64url_decode(b64data: &str) -> Result<Vec<u8>, base64::DecodeError> {
general_purpose::URL_SAFE_NO_PAD.decode(b64data)
}

pub(crate) fn generate_salt(key_for_predefined_salt: Option<String>) -> String {
let map = SALTS.lock().unwrap();

if let Some(salt) = key_for_predefined_salt.and_then(|key| map.get(&key)) { //FIXME better mock approach
salt.clone()
} else {
let mut buf = [0u8; 16];
ThreadRng::default().fill_bytes(&mut buf);
base64url_encode(&buf)
}
}

pub(crate) fn jwt_payload_decode(b64data: &str) -> Result<serde_json::Map<String, Value>, SDJWTHasSDClaimException> {
Ok(serde_json::from_str(&String::from_utf8(base64url_decode(b64data).unwrap()).unwrap()).unwrap())
}
4 changes: 2 additions & 2 deletions src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use log::debug;
use serde_json::{Map, Value};

use crate::{CNF_KEY, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_DIGEST_ALG, DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, JWK_KEY, KB_DIGEST_KEY, KB_JWT_TYP_HEADER, SD_DIGESTS_KEY, SDJWTCommon};
use crate::utils::create_base64_encoded_hash;
use crate::utils::base64_hash;

type KeyResolver = dyn Fn(&str, &Header) -> DecodingKey;

Expand Down Expand Up @@ -148,7 +148,7 @@ impl SDJWTVerifier {
combined.extend(self.sd_jwt_engine.input_disclosures.iter().map(|s| s.as_str()));
let combined = combined.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR);

create_base64_encoded_hash(combined)
base64_hash(combined.as_bytes())
}

fn extract_sd_claims(&mut self) -> Result<Value, String> {
Expand Down

0 comments on commit ae77493

Please sign in to comment.