Skip to content

Commit

Permalink
feat: Add Optional Support for Rocket-Okapi in IntrospectedUser (#559)
Browse files Browse the repository at this point in the history
Allow Rocket-Okapi to be used with a specialized feature.
  • Loading branch information
NewtTheWolf authored Sep 26, 2024
1 parent cff3c30 commit 2fb5786
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ oidc = ["credentials", "dep:base64-compat"]
## Refer to the rocket module for more information.
rocket = ["credentials", "oidc", "dep:rocket"]

## Feature that enables support for the [rocket okapi](https://github.com/GREsau/okapi).
rocket_okapi = ["rocket", "dep:rocket_okapi", "dep:schemars"]

# @@protoc_deletion_point(features)
# This section is automatically generated by protoc-gen-prost-crate.
# Changes in this area may be lost on regeneration.
Expand Down Expand Up @@ -167,6 +170,8 @@ tokio = { version = "1.37.0", optional = true, features = [
tonic = { version = "0.12.1", features = [
"tls",
], optional = true }
rocket_okapi = { version = "0.8.0", optional = true, default-features = false }
schemars = {version = "0.8.21", optional = true}
tonic-types = { version = "0.12.1", optional = true }

[dev-dependencies]
Expand Down
21 changes: 21 additions & 0 deletions src/rocket/introspection/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use openidconnect::IntrospectionUrl;
use serde::Deserialize;

#[cfg(feature = "introspection_cache")]
use crate::oidc::introspection::cache::IntrospectionCache;
Expand All @@ -18,3 +19,23 @@ pub struct IntrospectionConfig {
#[cfg(feature = "introspection_cache")]
pub(crate) cache: Option<Box<dyn IntrospectionCache>>,
}

#[cfg(feature = "rocket_okapi")]
/// Configuration for OAuth token introspection read from a Rocket.toml file.
///
/// # Fields
/// - `authority`: A string representing the authority URL used for introspection. This is typically
/// the base URL of the OAuth provider that will validate the tokens.
///
/// # Example
/// ```toml
/// [default]
/// authority = "https://auth.example.com/"
/// ```
///
/// # Features
/// This struct is only available when the `rocket_okapi` feature is enabled.
#[derive(Debug, Deserialize)]
pub struct IntrospectionRocketConfig {
pub(crate) authority: String,
}
139 changes: 139 additions & 0 deletions src/rocket/introspection/guard.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
use custom_error::custom_error;
use openidconnect::TokenIntrospectionResponse;
use rocket::figment::Figment;
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket::{async_trait, Request};
use std::collections::BTreeSet;
use std::collections::HashMap;

#[cfg(feature = "rocket_okapi")]
use rocket_okapi::{
gen::OpenApiGenerator,
okapi::openapi3::{
Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData, MediaType,
RefOr, Response,
},
okapi::Map,
request::{OpenApiFromRequest, RequestHeaderInput},
};
#[cfg(feature = "rocket_okapi")]
use schemars::schema::{InstanceType, ObjectValidation, Schema, SchemaObject};
#[cfg(feature = "rocket_okapi")]

use crate::oidc::introspection::{introspect, IntrospectionError, ZitadelIntrospectionResponse};
use crate::rocket::introspection::IntrospectionConfig;

use super::config::IntrospectionRocketConfig;

custom_error! {
/// Error type for guard related errors.
pub IntrospectionGuardError
Expand Down Expand Up @@ -147,6 +165,127 @@ impl<'request> FromRequest<'request> for &'request IntrospectedUser {
}
}

#[cfg(feature = "rocket_okapi")]
impl<'a> OpenApiFromRequest<'a> for &'a IntrospectedUser {
fn from_request_input(
_gen: &mut OpenApiGenerator,
_name: String,
_required: bool,
) -> rocket_okapi::Result<RequestHeaderInput> {
let figment: Figment = rocket::Config::figment();
let config: IntrospectionRocketConfig = figment
.extract()
.expect("authority must be set in Rocket.toml");

// Setup global requirement for Security scheme
let security_scheme = SecurityScheme {
description: Some(
"Use OpenID Connect to authenticate. (does not work in RapiDoc at all)".to_owned(),
),
data: SecuritySchemeData::OpenIdConnect {
open_id_connect_url: format!(
"{}/.well-known/openid-configuration",
config.authority
),
},
extensions: Object::default(),
};
// Add the requirement for this route/endpoint
// This can change between routes.
let mut security_req = SecurityRequirement::new();
// Each security requirement needs to be met before access is allowed.
security_req.insert("OpenID".to_owned(), Vec::new());
// These vvvv-------^^^^^^^ values need to match exactly!
Ok(RequestHeaderInput::Security(
"OpenID".to_owned(),
security_scheme,
security_req,
))
}

fn get_responses(_gen: &mut OpenApiGenerator) -> rocket_okapi::Result<Responses> {
let mut res = Responses::default();

// Manually defining the error response schema
let error_detail_schema = SchemaObject {
instance_type: Some(InstanceType::Object.into()),
object: Some(Box::new(ObjectValidation {
properties: {
let mut properties = Map::new();
properties.insert(
"code".to_owned(),
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::Integer.into()),
..Default::default()
}),
);
properties.insert(
"reason".to_owned(),
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
..Default::default()
}),
);
properties.insert(
"description".to_owned(),
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
..Default::default()
}),
);
properties
},
required: vec!["code".to_owned(), "reason".to_owned(), "description".to_owned()]
.into_iter()
.collect::<BTreeSet<_>>(), // Convert Vec to BTreeSet
..Default::default()
})),
..Default::default()
};

let error_response_schema = SchemaObject {
instance_type: Some(InstanceType::Object.into()),
object: Some(Box::new(ObjectValidation {
properties: {
let mut properties = Map::new();
properties.insert("error".to_owned(), Schema::Object(error_detail_schema));
properties
},
required: vec!["error".to_owned()].into_iter().collect::<BTreeSet<_>>(), // Convert Vec to BTreeSet
..Default::default()
})),
..Default::default()
};

// Create the content for the error response
let mut content = Map::new();
content.insert(
"application/json".to_owned(),
MediaType {
schema: Some(Schema::Object(error_response_schema).into()),
..Default::default()
},
);

// Adding 400 BadRequest response
let bad_request_response = Response {
description: "Bad Request - Multiple authorization headers found.".to_owned(),
content: content.clone(),
..Default::default()
};
res.responses.insert("400".to_owned(), RefOr::Object(bad_request_response));

// Adding 401 Unauthorized response
let unauthorized_response = Response {
description: "Unauthorized - The request requires user authentication.".to_owned(),
content: content.clone(),
..Default::default()
};
res.responses.insert("401".to_owned(), RefOr::Object(unauthorized_response));

Ok(res)
}}

#[cfg(test)]
mod tests {
#![allow(clippy::all)]
Expand Down

0 comments on commit 2fb5786

Please sign in to comment.