From 3289ab9f3b7d631421f6159d2dccf5b67d4de19e Mon Sep 17 00:00:00 2001 From: Simone Cottini Date: Wed, 6 Mar 2024 15:15:29 +0100 Subject: [PATCH] Replace actual jwks impl to use jwks_client --- Cargo.toml | 3 +- src/auth0/cache/inmemory.rs | 25 ------------- src/auth0/cache/mod.rs | 10 ----- src/auth0/cache/redis_impl.rs | 15 -------- src/auth0/errors.rs | 8 ++-- src/auth0/keyset.rs | 70 ----------------------------------- src/auth0/mod.rs | 56 +++++++++------------------- src/auth0/token.rs | 6 +++ 8 files changed, 29 insertions(+), 164 deletions(-) delete mode 100644 src/auth0/keyset.rs diff --git a/Cargo.toml b/Cargo.toml index 00a6b0d..8769b7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ version = "0.15.1-rc.0" rust-version = "1.72" [features] -auth0 = ["rand", "redis", "jsonwebtoken", "chrono", "chrono-tz", "aes", "cbc", "dashmap", "tracing"] +auth0 = ["rand", "redis", "jsonwebtoken", "jwks_client_rs", "chrono", "chrono-tz", "aes", "cbc", "dashmap", "tracing"] default = ["tracing_opentelemetry"] gzip = ["reqwest/gzip"] redis-tls = ["redis/tls", "redis/tokio-native-tls-comp"] @@ -28,6 +28,7 @@ dashmap = {version = "5.1", optional = true} futures = "0.3" futures-util = "0.3" jsonwebtoken = {version = "9.0", optional = true} +jwks_client_rs = {version = "0.5", optional = true} opentelemetry = {version = ">=0.17, <=0.20", optional = true} rand = {version = "0.8", optional = true} redis = {version = "0.23", features = ["tokio-comp"], optional = true} diff --git a/src/auth0/cache/inmemory.rs b/src/auth0/cache/inmemory.rs index f4d296a..22250c5 100644 --- a/src/auth0/cache/inmemory.rs +++ b/src/auth0/cache/inmemory.rs @@ -2,7 +2,6 @@ use dashmap::DashMap; use crate::auth0::cache::{self, crypto}; use crate::auth0::errors::Auth0Error; -use crate::auth0::keyset::JsonWebKeySet; use crate::auth0::token::Token; use crate::auth0::{cache::Cache, Config}; @@ -40,20 +39,6 @@ impl Cache for InMemoryCache { let _ = self.key_value.insert(key, encrypted_value); Ok(()) } - - async fn get_jwks(&self) -> Result, Auth0Error> { - self.key_value - .get(&cache::jwks_key(&self.caller, &self.audience)) - .map(|value| crypto::decrypt(self.encryption_key.as_str(), value.as_slice())) - .transpose() - } - - async fn put_jwks(&self, value_ref: &JsonWebKeySet, _expiration: Option) -> Result<(), Auth0Error> { - let key: String = cache::jwks_key(&self.caller, &self.audience); - let encrypted_value: Vec = crypto::encrypt(value_ref, self.encryption_key.as_str())?; - let _ = self.key_value.insert(key, encrypted_value); - Ok(()) - } } #[cfg(test)] @@ -70,9 +55,6 @@ mod tests { let result: Option = cache.get_token().await.unwrap(); assert!(result.is_none()); - let result: Option = cache.get_jwks().await.unwrap(); - assert!(result.is_none()); - let token_str: &str = "token"; let token: Token = Token::new(token_str.to_string(), Utc::now(), Utc::now()); cache.put_token(&token).await.unwrap(); @@ -80,12 +62,5 @@ mod tests { let result: Option = cache.get_token().await.unwrap(); assert!(result.is_some()); assert_eq!(result.unwrap().as_str(), token_str); - - let string: &str = "{\"keys\": []}"; - let jwks: JsonWebKeySet = serde_json::from_str(string).unwrap(); - cache.put_jwks(&jwks, None).await.unwrap(); - - let result: Option = cache.get_jwks().await.unwrap(); - assert!(result.is_some()); } } diff --git a/src/auth0/cache/mod.rs b/src/auth0/cache/mod.rs index af7c8bf..01d671d 100644 --- a/src/auth0/cache/mod.rs +++ b/src/auth0/cache/mod.rs @@ -2,7 +2,6 @@ pub use inmemory::InMemoryCache; pub use redis_impl::RedisCache; use crate::auth0::errors::Auth0Error; -use crate::auth0::keyset::JsonWebKeySet; use crate::auth0::Token; mod crypto; @@ -10,23 +9,14 @@ mod inmemory; mod redis_impl; const TOKEN_PREFIX: &str = "auth0rs_tokens"; -const JWKS_PREFIX: &str = "auth0rs_jwks"; #[async_trait::async_trait] pub trait Cache: Send + Sync + std::fmt::Debug { async fn get_token(&self) -> Result, Auth0Error>; async fn put_token(&self, value_ref: &Token) -> Result<(), Auth0Error>; - - async fn get_jwks(&self) -> Result, Auth0Error>; - - async fn put_jwks(&self, value_ref: &JsonWebKeySet, expiration: Option) -> Result<(), Auth0Error>; } pub(in crate::auth0::cache) fn token_key(caller: &str, audience: &str) -> String { format!("{}:{}:{}", TOKEN_PREFIX, caller, audience) } - -pub(in crate::auth0::cache) fn jwks_key(caller: &str, audience: &str) -> String { - format!("{}:{}:{}", JWKS_PREFIX, caller, audience) -} diff --git a/src/auth0/cache/redis_impl.rs b/src/auth0/cache/redis_impl.rs index 629ca24..3b13001 100644 --- a/src/auth0/cache/redis_impl.rs +++ b/src/auth0/cache/redis_impl.rs @@ -2,7 +2,6 @@ use redis::AsyncCommands; use serde::Deserialize; use crate::auth0::cache::{self, crypto, Cache}; -use crate::auth0::keyset::JsonWebKeySet; use crate::auth0::token::Token; use crate::auth0::{Auth0Error, Config}; @@ -57,20 +56,6 @@ impl Cache for RedisCache { connection.set_ex(key, encrypted_value, expiration).await?; Ok(()) } - - async fn get_jwks(&self) -> Result, Auth0Error> { - let key: &str = &cache::jwks_key(&self.caller, &self.audience); - self.get(key).await - } - - async fn put_jwks(&self, value_ref: &JsonWebKeySet, expiration: Option) -> Result<(), Auth0Error> { - let key: &str = &cache::jwks_key(&self.caller, &self.audience); - let mut connection = self.client.get_async_connection().await?; - let encrypted_value: Vec = crypto::encrypt(value_ref, self.encryption_key.as_str())?; - let expiration: usize = expiration.unwrap_or(86400); - connection.set_ex(key, encrypted_value, expiration).await?; - Ok(()) - } } // To run this test (it works): diff --git a/src/auth0/errors.rs b/src/auth0/errors.rs index bb79cd1..246f39c 100644 --- a/src/auth0/errors.rs +++ b/src/auth0/errors.rs @@ -11,10 +11,10 @@ pub enum Auth0Error { JwtFetchError(u16, String, reqwest::Error), #[error("failed to deserialize jwt from {0}. {1}")] JwtFetchDeserializationError(String, reqwest::Error), - #[error("failed to fetch jwks from {0}. Status code: {0}; error: {1}")] - JwksFetchError(u16, String, reqwest::Error), - #[error("failed to deserialize jwks from {0}. {1}")] - JwksFetchDeserializationError(String, reqwest::Error), + #[error(transparent)] + JwksClientError(#[from] jwks_client_rs::JwksClientError), + #[error("failed to fetch jwt from {0}. Status code: {0}; error: {1}")] + JwksHttpError(String, reqwest::Error), #[error("redis error: {0}")] RedisError(#[from] redis::RedisError), #[error(transparent)] diff --git a/src/auth0/keyset.rs b/src/auth0/keyset.rs deleted file mode 100644 index 8cfc407..0000000 --- a/src/auth0/keyset.rs +++ /dev/null @@ -1,70 +0,0 @@ -// https://tools.ietf.org/id/draft-ietf-jose-json-web-key-00.html#rfc.section.3.1 - -use reqwest::Client; -use serde::Deserialize; -use serde::Serialize; - -use crate::auth0::errors::Auth0Error; -use crate::auth0::token::Token; -use crate::auth0::Config; - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct JsonWebKeySet { - keys: Vec, -} - -impl JsonWebKeySet { - pub fn is_signed(&self, token_ref: &Token) -> bool { - let kid: String = jsonwebtoken::decode_header(token_ref.as_str()) - .map(|headers| headers.kid) - .unwrap_or(None) - .unwrap_or_else(|| "-".to_string()); - - self.keys.iter().any(|key| key.key_id() == kid) - } - - pub async fn fetch(client_ref: &Client, config_ref: &Config) -> Result { - client_ref - .get(config_ref.jwks_url().clone()) - .send() - .await - .map_err(|e| { - Auth0Error::JwksFetchError( - e.status().map(|v| v.as_u16()).unwrap_or_default(), - config_ref.jwks_url().as_str().to_string(), - e, - ) - })? - .json::() - .await - .map_err(|e| Auth0Error::JwksFetchDeserializationError(config_ref.jwks_url().as_str().to_string(), e)) - } -} - -// https://tools.ietf.org/id/draft-ietf-jose-json-web-key-00.html#rfc.section.3 -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(tag = "kty")] -pub enum JsonWebKey { - #[serde(alias = "RSA")] - Rsa(RsaPublicJwk), -} - -impl JsonWebKey { - fn key_id(&self) -> &str { - match self { - JsonWebKey::Rsa(rsa_pk) => rsa_pk.key_id(), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct RsaPublicJwk { - #[serde(rename = "kid")] - key_id: String, -} - -impl RsaPublicJwk { - fn key_id(&self) -> &str { - &self.key_id - } -} diff --git a/src/auth0/mod.rs b/src/auth0/mod.rs index 793ae09..dae4b0a 100644 --- a/src/auth0/mod.rs +++ b/src/auth0/mod.rs @@ -1,7 +1,10 @@ //! Stuff used to provide JWT authentication via Auth0 use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::time::Duration; +use jwks_client_rs::JwksClient; +use jwks_client_rs::source::WebSource; use reqwest::Client; use tokio::task::JoinHandle; use tokio::time::Interval; @@ -11,17 +14,16 @@ pub use errors::Auth0Error; use util::ResultExt; use crate::auth0::cache::Cache; -use crate::auth0::keyset::JsonWebKeySet; +use crate::auth0::token::Claims; pub use crate::auth0::token::Token; mod cache; mod config; mod errors; -mod keyset; mod token; mod util; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Auth0 { token_lock: Arc>, } @@ -34,20 +36,25 @@ impl Auth0 { Arc::new(cache::RedisCache::new(&config).await?) }; - let jwks: JsonWebKeySet = get_jwks(client_ref, &cache, &config).await?; + let source: WebSource = WebSource::builder() + .with_timeout(Duration::from_secs(5)) + .with_connect_timeout(Duration::from_secs(55)) + .build(config.jwks_url().to_owned()) + .map_err(|err| Auth0Error::JwksHttpError(config.token_url().as_str().to_string(), err))?; + + let jwks_client = JwksClient::builder().build(source); let token: Token = get_token(client_ref, &cache, &config).await?; - let jwks_lock: Arc> = Arc::new(RwLock::new(jwks)); let token_lock: Arc> = Arc::new(RwLock::new(token)); start( - jwks_lock.clone(), token_lock.clone(), + jwks_client.clone(), client_ref.clone(), cache.clone(), config, ) - .await; + .await; Ok(Self { token_lock }) } @@ -58,8 +65,8 @@ impl Auth0 { } async fn start( - jwks_lock: Arc>, token_lock: Arc>, + jwks_client: JwksClient, client: Client, cache: Arc, config: Config, @@ -82,22 +89,10 @@ async fn start( if token.needs_refresh(&config) { tracing::info!("Refreshing JWT and JWKS"); - let jwks_opt = match JsonWebKeySet::fetch(&client, &config).await { - Ok(jwks) => { - let _ = cache.put_jwks(&jwks, None).await.log_err("Error caching JWKS"); - write(&jwks_lock, jwks.clone()); - Some(jwks) - } - Err(error) => { - tracing::error!("Failed to fetch JWKS. Reason: {:?}", error); - None - } - }; - match Token::fetch(&client, &config).await { Ok(token) => { - let is_signed: Option = jwks_opt.map(|j| j.is_signed(&token)); - tracing::info!("is signed: {}", is_signed.unwrap_or_default()); + let is_signed: bool = jwks_client.decode::(token.as_str(), &[config.audience()]).await.is_ok(); + tracing::info!("is signed: {}", is_signed); let _ = cache.put_token(&token).await.log_err("Error caching JWT"); write(&token_lock, token); @@ -111,23 +106,6 @@ async fn start( }) } -// Try to fetch the jwks from cache. If it's found return it; fetch from auth0 and put in cache otherwise -async fn get_jwks( - client_ref: &Client, - cache_ref: &Arc, - config_ref: &Config, -) -> Result { - match cache_ref.get_jwks().await? { - Some(jwks) => Ok(jwks), - None => { - let jwks: JsonWebKeySet = JsonWebKeySet::fetch(client_ref, config_ref).await?; - let _ = cache_ref.put_jwks(&jwks, None).await.log_err("JWKS cache set failed"); - - Ok(jwks) - } - } -} - // Try to fetch the token from cache. If it's found return it; fetch from auth0 and put in cache otherwise async fn get_token(client_ref: &Client, cache_ref: &Arc, config_ref: &Config) -> Result { match cache_ref.get_token().await? { diff --git a/src/auth0/token.rs b/src/auth0/token.rs index f2b663c..3b74197 100644 --- a/src/auth0/token.rs +++ b/src/auth0/token.rs @@ -131,3 +131,9 @@ impl From<&Config> for FetchTokenRequest { } } } + +#[derive(Deserialize, Debug)] +pub struct Claims { + #[serde(default)] + pub permissions: Vec, +} \ No newline at end of file