diff --git a/Cargo.toml b/Cargo.toml index 5aa2bc6..f14f6d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,6 +87,9 @@ oidc = ["credentials", "dep:base64-compat"] ## Refer to the rocket module for more information. rocket = ["credentials", "oidc", "dep:rocket"] +## Feature that enables support for the [rocket okapi](https://github.com/GREsau/okapi). +rocket_okapi = ["rocket", "dep:rocket_okapi", "dep:schemars"] + # @@protoc_deletion_point(features) # This section is automatically generated by protoc-gen-prost-crate. # Changes in this area may be lost on regeneration. @@ -167,6 +170,8 @@ tokio = { version = "1.37.0", optional = true, features = [ tonic = { version = "0.12.1", features = [ "tls", ], optional = true } +rocket_okapi = { version = "0.8.0", optional = true, default-features = false } +schemars = {version = "0.8.21", optional = true} tonic-types = { version = "0.12.1", optional = true } [dev-dependencies] diff --git a/src/rocket/introspection/config.rs b/src/rocket/introspection/config.rs index 9dff39f..2ba5946 100644 --- a/src/rocket/introspection/config.rs +++ b/src/rocket/introspection/config.rs @@ -1,4 +1,5 @@ use openidconnect::IntrospectionUrl; +use serde::Deserialize; #[cfg(feature = "introspection_cache")] use crate::oidc::introspection::cache::IntrospectionCache; @@ -18,3 +19,23 @@ pub struct IntrospectionConfig { #[cfg(feature = "introspection_cache")] pub(crate) cache: Option>, } + +#[cfg(feature = "rocket_okapi")] +/// Configuration for OAuth token introspection read from a Rocket.toml file. +/// +/// # Fields +/// - `authority`: A string representing the authority URL used for introspection. This is typically +/// the base URL of the OAuth provider that will validate the tokens. +/// +/// # Example +/// ```toml +/// [default] +/// authority = "https://auth.example.com/" +/// ``` +/// +/// # Features +/// This struct is only available when the `rocket_okapi` feature is enabled. +#[derive(Debug, Deserialize)] +pub struct IntrospectionRocketConfig { + pub(crate) authority: String, +} diff --git a/src/rocket/introspection/guard.rs b/src/rocket/introspection/guard.rs index 11d9c38..42f23c5 100644 --- a/src/rocket/introspection/guard.rs +++ b/src/rocket/introspection/guard.rs @@ -1,13 +1,31 @@ use custom_error::custom_error; use openidconnect::TokenIntrospectionResponse; +use rocket::figment::Figment; use rocket::http::Status; use rocket::request::{FromRequest, Outcome}; use rocket::{async_trait, Request}; +use std::collections::BTreeSet; use std::collections::HashMap; +#[cfg(feature = "rocket_okapi")] +use rocket_okapi::{ + gen::OpenApiGenerator, + okapi::openapi3::{ + Object, Responses, SecurityRequirement, SecurityScheme, SecuritySchemeData, MediaType, + RefOr, Response, + }, + okapi::Map, + request::{OpenApiFromRequest, RequestHeaderInput}, +}; +#[cfg(feature = "rocket_okapi")] +use schemars::schema::{InstanceType, ObjectValidation, Schema, SchemaObject}; +#[cfg(feature = "rocket_okapi")] + use crate::oidc::introspection::{introspect, IntrospectionError, ZitadelIntrospectionResponse}; use crate::rocket::introspection::IntrospectionConfig; +use super::config::IntrospectionRocketConfig; + custom_error! { /// Error type for guard related errors. pub IntrospectionGuardError @@ -147,6 +165,127 @@ impl<'request> FromRequest<'request> for &'request IntrospectedUser { } } +#[cfg(feature = "rocket_okapi")] +impl<'a> OpenApiFromRequest<'a> for &'a IntrospectedUser { + fn from_request_input( + _gen: &mut OpenApiGenerator, + _name: String, + _required: bool, + ) -> rocket_okapi::Result { + let figment: Figment = rocket::Config::figment(); + let config: IntrospectionRocketConfig = figment + .extract() + .expect("authority must be set in Rocket.toml"); + + // Setup global requirement for Security scheme + let security_scheme = SecurityScheme { + description: Some( + "Use OpenID Connect to authenticate. (does not work in RapiDoc at all)".to_owned(), + ), + data: SecuritySchemeData::OpenIdConnect { + open_id_connect_url: format!( + "{}/.well-known/openid-configuration", + config.authority + ), + }, + extensions: Object::default(), + }; + // Add the requirement for this route/endpoint + // This can change between routes. + let mut security_req = SecurityRequirement::new(); + // Each security requirement needs to be met before access is allowed. + security_req.insert("OpenID".to_owned(), Vec::new()); + // These vvvv-------^^^^^^^ values need to match exactly! + Ok(RequestHeaderInput::Security( + "OpenID".to_owned(), + security_scheme, + security_req, + )) + } + + fn get_responses(_gen: &mut OpenApiGenerator) -> rocket_okapi::Result { + let mut res = Responses::default(); + + // Manually defining the error response schema + let error_detail_schema = SchemaObject { + instance_type: Some(InstanceType::Object.into()), + object: Some(Box::new(ObjectValidation { + properties: { + let mut properties = Map::new(); + properties.insert( + "code".to_owned(), + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::Integer.into()), + ..Default::default() + }), + ); + properties.insert( + "reason".to_owned(), + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + ..Default::default() + }), + ); + properties.insert( + "description".to_owned(), + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + ..Default::default() + }), + ); + properties + }, + required: vec!["code".to_owned(), "reason".to_owned(), "description".to_owned()] + .into_iter() + .collect::>(), // Convert Vec to BTreeSet + ..Default::default() + })), + ..Default::default() + }; + + let error_response_schema = SchemaObject { + instance_type: Some(InstanceType::Object.into()), + object: Some(Box::new(ObjectValidation { + properties: { + let mut properties = Map::new(); + properties.insert("error".to_owned(), Schema::Object(error_detail_schema)); + properties + }, + required: vec!["error".to_owned()].into_iter().collect::>(), // Convert Vec to BTreeSet + ..Default::default() + })), + ..Default::default() + }; + + // Create the content for the error response + let mut content = Map::new(); + content.insert( + "application/json".to_owned(), + MediaType { + schema: Some(Schema::Object(error_response_schema).into()), + ..Default::default() + }, + ); + + // Adding 400 BadRequest response + let bad_request_response = Response { + description: "Bad Request - Multiple authorization headers found.".to_owned(), + content: content.clone(), + ..Default::default() + }; + res.responses.insert("400".to_owned(), RefOr::Object(bad_request_response)); + + // Adding 401 Unauthorized response + let unauthorized_response = Response { + description: "Unauthorized - The request requires user authentication.".to_owned(), + content: content.clone(), + ..Default::default() + }; + res.responses.insert("401".to_owned(), RefOr::Object(unauthorized_response)); + + Ok(res) + }} + #[cfg(test)] mod tests { #![allow(clippy::all)]