From 09784ab77b3aff319e97176c99f707aff584e0eb Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 13 Oct 2023 18:36:51 -0500 Subject: [PATCH 01/72] WIP end-of-day push --- src/auth/flows.rs | 136 +++++++++++++++++- src/auth/mod.rs | 134 ++++++++++++++++- src/database/models/flow_item.rs | 15 +- src/database/models/ids.rs | 13 ++ src/database/models/mod.rs | 2 + .../models/oauth_client_authorization_item.rs | 36 +++++ src/database/models/oauth_client_item.rs | 47 ++++++ src/models/pats.rs | 66 +++++++++ 8 files changed, 444 insertions(+), 5 deletions(-) create mode 100644 src/database/models/oauth_client_authorization_item.rs create mode 100644 src/database/models/oauth_client_item.rs diff --git a/src/auth/flows.rs b/src/auth/flows.rs index 7b54f5b1..daae0d8a 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -1,8 +1,11 @@ use crate::auth::email::send_email; use crate::auth::session::issue_session; use crate::auth::validate::get_user_record_from_bearer_token; -use crate::auth::{get_user_from_headers, AuthenticationError}; +use crate::auth::{get_user_from_headers, AuthenticationError, OAuthError, OAuthErrorType}; use crate::database::models::flow_item::Flow; +use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::database::models::OAuthClientId; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::base62_impl::{parse_base62, to_base62}; @@ -53,7 +56,8 @@ pub fn config(cfg: &mut ServiceConfig) { .service(set_email) .service(verify_email) .service(subscribe_newsletter) - .service(link_trolley), + .service(link_trolley) + .service(init_oauth), ); } @@ -2247,7 +2251,11 @@ pub async fn link_trolley( } if let Some(email) = user.email { - let id = payouts_queue.lock().await.register_recipient(&email, body.0).await?; + let id = payouts_queue + .lock() + .await + .register_recipient(&email, body.0) + .await?; let mut transaction = pool.begin().await?; @@ -2274,3 +2282,125 @@ pub async fn link_trolley( )) } } + +#[derive(Serialize, Deserialize)] +pub struct OAuthInit { + pub client_id: OAuthClientId, + pub redirect_uri: Option, + pub scope: Option, + pub state: Option, +} + +#[get("oauth/authorize")] +pub async fn init_oauth( + req: HttpRequest, + Query(oauth_info): Query, // callback url + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + let user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) + .await + .map_err(|e| OAuthError::error(e))? + .1; + + let client = DBOAuthClient::get(oauth_info.client_id.0, &**pool, &redis) + .await + .map_err(|e| OAuthError::error(e))?; + + if let Some(client) = client { + let redirect_uri = ValidatedRedirectUri::validate(&oauth_info.redirect_uri, &client)?; + + let requested_scopes = oauth_info + .scope + .as_ref() + .map_or(Ok(client.max_scopes), |s| { + Scopes::parse_from_oauth_scopes(s).map_err(|e| { + OAuthError::redirect( + OAuthErrorType::FailedScopeParse(e), + &oauth_info, + &redirect_uri, + ) + }) + })?; + + if !client.max_scopes.contains(requested_scopes) { + return Err(OAuthError::redirect( + OAuthErrorType::ScopesTooBroad, + &oauth_info, + &redirect_uri, + )); + } + + let existing_authorization = + OAuthClientAuthorization::get(client.id, user.id.into(), &**pool, &redis) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info, &redirect_uri))?; + let has_already_accepted = + existing_authorization.is_some_and(|auth| auth.scopes.contains(requested_scopes)); + if has_already_accepted { + Flow::OAuthAuthorizationCodeSupplied { + user_id: user.id.into(), + client_id: client.id, + scopes: requested_scopes, + original_redirect_uri: oauth_info.redirect_uri.clone(), + } + .insert(Duration::minutes(10), &redis) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info, &redirect_uri))?; + } else { + let flow_id = Flow::InitOAuthAppApproval { + user_id: user.id.into(), + client_id: client.id, + scopes: requested_scopes, + redirect_uri: todo!(), + state: oauth_info.state, + } + .insert(Duration::minutes(30), &redis) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info, &redirect_uri))?; + } + } else { + return Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { + client_id: oauth_info.client_id, + })); + } + + todo!() +} + +#[derive(Clone, Debug)] +pub struct ValidatedRedirectUri(pub String); + +impl ValidatedRedirectUri { + pub fn validate( + to_validate: &Option, + client: &DBOAuthClient, + ) -> Result { + if let Some(first_client_redirect_uri) = client.redirect_uris.first() { + if let Some(to_validate) = to_validate { + if client.redirect_uris.iter().any(|uri| { + same_uri_except_query_components(&first_client_redirect_uri.uri, to_validate) + }) { + return Ok(ValidatedRedirectUri(to_validate.clone())); + } else { + return Err(OAuthError::error(OAuthErrorType::InvalidRedirectUri( + to_validate.clone(), + ))); + } + } else { + return Ok(ValidatedRedirectUri(first_client_redirect_uri.uri.clone())); + } + } else { + return Err(OAuthError::error( + OAuthErrorType::ClientMissingRedirectURI { + client_id: client.id, + }, + )); + } + } +} + +fn same_uri_except_query_components(a: &str, b: &str) -> bool { + true +} diff --git a/src/auth/mod.rs b/src/auth/mod.rs index eec82ec5..2dea6c9d 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -5,7 +5,6 @@ pub mod pats; pub mod session; mod templates; pub mod validate; - pub use checks::{ filter_authorized_projects, filter_authorized_versions, is_authorized, is_authorized_version, }; @@ -18,6 +17,8 @@ use actix_web::http::StatusCode; use actix_web::HttpResponse; use thiserror::Error; +use self::flows::{OAuthInit, ValidatedRedirectUri}; + #[derive(Error, Debug)] pub enum AuthenticationError { #[error("Environment Error")] @@ -98,3 +99,134 @@ impl AuthenticationError { } } } + +#[derive(thiserror::Error, Debug)] +#[error("{}", .error_type)] +pub struct OAuthError { + #[source] + pub error_type: OAuthErrorType, + + pub state: Option, + pub valid_redirect_uri: Option, +} + +impl OAuthError { + /// The OAuth request failed either because of an invalid redirection URI + /// or before we could validate the one we were given, so return an error + /// directly to the caller + /// + /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) + pub fn error(error_type: impl Into) -> Self { + Self { + error_type: error_type.into(), + valid_redirect_uri: None, + state: None, + } + } + + /// The OAuth request failed for a reason other than an invalid redirection URI + /// So send the error in url-encoded form to the redirect URI + /// + /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) + pub fn redirect( + err: impl Into, + oauth_info: &OAuthInit, + valid_redirect_uri: &ValidatedRedirectUri, + ) -> Self { + Self { + error_type: err.into(), + state: oauth_info.state.clone(), + valid_redirect_uri: Some(valid_redirect_uri.clone()), + } + } +} + +impl actix_web::ResponseError for OAuthError { + fn status_code(&self) -> StatusCode { + match self.error_type { + OAuthErrorType::AuthenticationError(_) + | OAuthErrorType::UnrecognizedClient { client_id: _ } + | OAuthErrorType::FailedScopeParse(_) + | OAuthErrorType::ScopesTooBroad => { + if self.valid_redirect_uri.is_some() { + StatusCode::FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + } + OAuthErrorType::InvalidUri + | OAuthErrorType::InvalidRedirectUri(_) + | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => StatusCode::BAD_REQUEST, + } + } + + fn error_response(&self) -> HttpResponse { + if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() { + redirect_uri = format!( + "{}?error={}&error_description={}", + redirect_uri.to_string(), + self.error_type.error_name(), + self.error_type.to_string(), + ); + + if let Some(state) = self.state.as_ref() { + redirect_uri = format!("{}&state={}", redirect_uri, state); + } + + redirect_uri = urlencoding::encode(&redirect_uri).to_string(); + HttpResponse::Found() + .append_header(("Location".to_string(), redirect_uri)) + .finish() + } else { + HttpResponse::build(self.status_code()).json(ApiError { + error: &self.error_type.error_name(), + description: &self.error_type.to_string(), + }) + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum OAuthErrorType { + #[error("The provided URI was not valid")] + InvalidUri, + #[error(transparent)] + AuthenticationError(#[from] AuthenticationError), + #[error("Client {} not recognized", .client_id.0)] + UnrecognizedClient { + client_id: crate::database::models::OAuthClientId, + }, + #[error("Client {} has no redirect URIs specified", .client_id.0)] + ClientMissingRedirectURI { + client_id: crate::database::models::OAuthClientId, + }, + #[error("The provided redirect URI did not match any configured in the client")] + InvalidRedirectUri(String), + #[error("The provided scope was malformed or did not correspond to known scopes ({0})")] + FailedScopeParse(bitflags::parser::ParseError), + #[error( + "The provided scope requested scopes broader than the developer app is configured with" + )] + ScopesTooBroad, +} + +impl From for OAuthErrorType { + fn from(value: crate::database::models::DatabaseError) -> Self { + OAuthErrorType::AuthenticationError(value.into()) + } +} + +impl OAuthErrorType { + pub fn error_name(&self) -> String { + // IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38) + match self { + OAuthErrorType::InvalidUri + | OAuthErrorType::InvalidRedirectUri(_) + | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => "invalid_uri", + OAuthErrorType::AuthenticationError(_) => "server_error", + OAuthErrorType::UnrecognizedClient { client_id: _ } => "invalid_request", + OAuthErrorType::FailedScopeParse(_) | OAuthErrorType::ScopesTooBroad => "invalid_scope", + } + .to_string() + } +} diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 5bcd2671..20c6b8ea 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -1,7 +1,7 @@ use super::ids::*; -use crate::auth::flows::AuthProvider; use crate::database::models::DatabaseError; use crate::database::redis::RedisPool; +use crate::{auth::flows::AuthProvider, models::pats::Scopes}; use chrono::Duration; use rand::distributions::Alphanumeric; use rand::Rng; @@ -34,6 +34,19 @@ pub enum Flow { confirm_email: String, }, MinecraftAuth, + InitOAuthAppApproval { + user_id: UserId, + client_id: OAuthClientId, + scopes: Scopes, + redirect_uri: String, + state: Option, + }, + OAuthAuthorizationCodeSupplied { + user_id: UserId, + client_id: OAuthClientId, + scopes: Scopes, + original_redirect_uri: Option, + }, } impl Flow { diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index 1d1af665..31e246c9 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -3,6 +3,7 @@ use crate::models::ids::base62_impl::to_base62; use crate::models::ids::random_base62_rng; use censor::Censor; use serde::{Deserialize, Serialize}; +use serde_with::SerializeDisplay; use sqlx::sqlx_macros::Type; const ID_RETRY_COUNT: usize = 20; @@ -238,6 +239,18 @@ pub struct SessionId(pub i64); #[sqlx(transparent)] pub struct ImageId(pub i64); +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthClientId(pub i64); + +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthClientAuthorizationId(pub i64); + +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthRedirectUriId(pub i64); + use crate::models::ids; impl From for ProjectId { diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index bfd6e781..a1e6c75f 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -6,6 +6,8 @@ pub mod flow_item; pub mod ids; pub mod image_item; pub mod notification_item; +pub mod oauth_client_authorization_item; +pub mod oauth_client_item; pub mod organization_item; pub mod pat_item; pub mod project_item; diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs new file mode 100644 index 00000000..496793cc --- /dev/null +++ b/src/database/models/oauth_client_authorization_item.rs @@ -0,0 +1,36 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::{database::redis::RedisPool, models::pats::Scopes}; + +use super::{DatabaseError, OAuthClientAuthorizationId, OAuthClientId, UserId}; + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthClientAuthorization { + pub id: OAuthClientAuthorizationId, + pub client_id: OAuthClientId, + pub user_id: UserId, + pub scopes: Scopes, + pub created: DateTime, + // last_used? +} + +impl OAuthClientAuthorization { + pub async fn get<'a, E>( + client_id: OAuthClientId, + user_id: UserId, + exec: E, + redis: &RedisPool, + ) -> Result, DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres>, + { + return Ok(Some(OAuthClientAuthorization { + id: OAuthClientAuthorizationId(1), + client_id: OAuthClientId(1), + user_id: UserId(1), + scopes: Scopes::all(), + created: Utc::now(), + })); + } +} diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs new file mode 100644 index 00000000..4e109ee3 --- /dev/null +++ b/src/database/models/oauth_client_item.rs @@ -0,0 +1,47 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::{database::redis::RedisPool, models::pats::Scopes}; + +use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId}; + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthRedirectUri { + pub id: OAuthRedirectUriId, + pub client_id: OAuthClientId, + pub uri: String, +} + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthClient { + pub id: OAuthClientId, + pub name: String, + pub icon: Option, + pub max_scopes: Scopes, + pub secret: String, + pub redirect_uris: Vec, + pub created: DateTime, + pub created_by: UserId, +} + +impl OAuthClient { + pub async fn get<'a, E, T: ToString>( + id: T, + exec: E, + redis: &RedisPool, + ) -> Result, DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres>, + { + return Ok(Some(OAuthClient { + id: OAuthClientId(1), + name: "Test Client".to_string(), + icon: None, + max_scopes: Scopes::all(), + redirect_uris: vec![], + secret: "hashed secret".to_string(), + created: Utc::now(), + created_by: UserId(1), + })); + } +} diff --git a/src/models/pats.rs b/src/models/pats.rs index 6ac2afc8..b66cd8f2 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -126,6 +126,11 @@ impl Scopes { pub fn is_restricted(&self) -> bool { self.intersects(Self::restricted()) } + + pub fn parse_from_oauth_scopes(scopes: &str) -> Result { + let scopes = scopes.replace(' ', "|").replace("%20", "|"); + bitflags::parser::from_str(&scopes) + } } #[derive(Serialize, Deserialize)] @@ -161,3 +166,64 @@ impl PersonalAccessToken { } } } + +#[cfg(test)] +mod test { + use super::*; + use itertools::Itertools; + + #[test] + fn test_parse_from_oauth_scopes_well_formed() { + let raw = "USER_READ_EMAIL SESSION_READ ORGANIZATION_CREATE"; + let expected = Scopes::USER_READ_EMAIL | Scopes::SESSION_READ | Scopes::ORGANIZATION_CREATE; + + let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap(); + + assert_same_flags(expected, parsed); + } + + #[test] + fn test_parse_from_oauth_scopes_empty() { + let raw = ""; + let expected = Scopes::empty(); + + let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap(); + + assert_same_flags(expected, parsed); + } + + #[test] + fn test_parse_from_oauth_scopes_invalid_scopes() { + let raw = "notascope"; + + let parsed = Scopes::parse_from_oauth_scopes(raw); + + assert!(parsed.is_err()); + } + + #[test] + fn test_parse_from_oauth_scopes_invalid_separator() { + let raw = "USER_READ_EMAIL & SESSION_READ"; + + let parsed = Scopes::parse_from_oauth_scopes(raw); + + assert!(parsed.is_err()); + } + + #[test] + fn test_parse_from_oauth_scopes_url_encoded() { + let raw = urlencoding::encode("PAT_WRITE COLLECTION_DELETE").to_string(); + let expected = Scopes::PAT_WRITE | Scopes::COLLECTION_DELETE; + + let parsed = Scopes::parse_from_oauth_scopes(&raw).unwrap(); + + assert_same_flags(expected, parsed); + } + + fn assert_same_flags(expected: Scopes, actual: Scopes) { + assert_eq!( + expected.iter_names().map(|(name, _)| name).collect_vec(), + actual.iter_names().map(|(name, _)| name).collect_vec() + ); + } +} From 053086489f2f03fef46d9d2ce6d3e55706546617 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 16 Oct 2023 16:32:18 -0500 Subject: [PATCH 02/72] Authorize endpoint, accept endpoints, DB stuff for oauth clients, their redirects, and client authorizations --- migrations/20231016190056_oauth_provider.sql | 21 ++ src/auth/flows.rs | 249 +++++++++++++++--- src/auth/mod.rs | 27 +- src/database/models/flow_item.rs | 7 +- src/database/models/ids.rs | 9 +- .../models/oauth_client_authorization_item.rs | 52 +++- src/database/models/oauth_client_item.rs | 123 +++++++-- 7 files changed, 415 insertions(+), 73 deletions(-) create mode 100644 migrations/20231016190056_oauth_provider.sql diff --git a/migrations/20231016190056_oauth_provider.sql b/migrations/20231016190056_oauth_provider.sql new file mode 100644 index 00000000..a47704e8 --- /dev/null +++ b/migrations/20231016190056_oauth_provider.sql @@ -0,0 +1,21 @@ +CREATE TABLE oauth_clients ( + id bigint PRIMARY KEY, + name varchar(255) NOT NULL, + icon_url varchar(255) NULL, + max_scopes bigint NOT NULL, + secret_hash char(512) NOT NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by bigint NOT NULL REFERENCES users(id) +); +CREATE TABLE oauth_client_redirect_uris ( + id bigint PRIMARY KEY, + client_id bigint REFERENCES oauth_clients(id) NOT NULL, + uri varchar(255) +); +CREATE TABLE oauth_client_authorizations ( + id bigint PRIMARY KEY, + client_id bigint NOT NULL REFERENCES oauth_clients(id), + user_id bigint NOT NULL REFERENCES users(id), + scopes bigint NOT NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/src/auth/flows.rs b/src/auth/flows.rs index daae0d8a..f5477071 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -5,7 +5,7 @@ use crate::auth::{get_user_from_headers, AuthenticationError, OAuthError, OAuthE use crate::database::models::flow_item::Flow; use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; -use crate::database::models::OAuthClientId; +use crate::database::models::{generate_oauth_client_authorization_id, OAuthClientId}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::base62_impl::{parse_base62, to_base62}; @@ -57,7 +57,8 @@ pub fn config(cfg: &mut ServiceConfig) { .service(verify_email) .service(subscribe_newsletter) .service(link_trolley) - .service(init_oauth), + .service(init_oauth) + .service(accept_client_scopes), ); } @@ -2291,6 +2292,15 @@ pub struct OAuthInit { pub state: Option, } +#[derive(Serialize, Deserialize)] +pub struct OAuthClientAccessRequest { + pub flow_id: String, + pub client_id: OAuthClientId, + pub client_name: String, + pub client_icon: Option, + pub requested_scopes: Scopes, +} + #[get("oauth/authorize")] pub async fn init_oauth( req: HttpRequest, @@ -2301,15 +2311,19 @@ pub async fn init_oauth( ) -> Result { let user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) .await - .map_err(|e| OAuthError::error(e))? + .map_err(OAuthError::error)? .1; - let client = DBOAuthClient::get(oauth_info.client_id.0, &**pool, &redis) + let client = DBOAuthClient::get(oauth_info.client_id, &**pool) .await - .map_err(|e| OAuthError::error(e))?; + .map_err(OAuthError::error)?; if let Some(client) = client { - let redirect_uri = ValidatedRedirectUri::validate(&oauth_info.redirect_uri, &client)?; + let redirect_uri = ValidatedRedirectUri::validate( + &oauth_info.redirect_uri, + client.redirect_uris.iter().map(|r| r.uri.as_ref()), + client.id, + )?; let requested_scopes = oauth_info .scope @@ -2318,7 +2332,7 @@ pub async fn init_oauth( Scopes::parse_from_oauth_scopes(s).map_err(|e| { OAuthError::redirect( OAuthErrorType::FailedScopeParse(e), - &oauth_info, + &oauth_info.state, &redirect_uri, ) }) @@ -2327,61 +2341,166 @@ pub async fn init_oauth( if !client.max_scopes.contains(requested_scopes) { return Err(OAuthError::redirect( OAuthErrorType::ScopesTooBroad, - &oauth_info, + &oauth_info.state, &redirect_uri, )); } let existing_authorization = - OAuthClientAuthorization::get(client.id, user.id.into(), &**pool, &redis) + OAuthClientAuthorization::get(client.id, user.id.into(), &**pool) .await - .map_err(|e| OAuthError::redirect(e, &oauth_info, &redirect_uri))?; + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; let has_already_accepted = existing_authorization.is_some_and(|auth| auth.scopes.contains(requested_scopes)); if has_already_accepted { - Flow::OAuthAuthorizationCodeSupplied { - user_id: user.id.into(), - client_id: client.id, - scopes: requested_scopes, - original_redirect_uri: oauth_info.redirect_uri.clone(), - } - .insert(Duration::minutes(10), &redis) + get_init_oauth_code_flow_response( + user.id.into(), + client.id, + requested_scopes, + redirect_uri, + oauth_info.redirect_uri, + oauth_info.state, + &redis, + ) .await - .map_err(|e| OAuthError::redirect(e, &oauth_info, &redirect_uri))?; } else { let flow_id = Flow::InitOAuthAppApproval { user_id: user.id.into(), client_id: client.id, scopes: requested_scopes, - redirect_uri: todo!(), - state: oauth_info.state, + validated_redirect_uri: redirect_uri.clone(), + original_redirect_uri: oauth_info.redirect_uri.clone(), + state: oauth_info.state.clone(), } .insert(Duration::minutes(30), &redis) .await - .map_err(|e| OAuthError::redirect(e, &oauth_info, &redirect_uri))?; + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + + let access_request = OAuthClientAccessRequest { + client_id: client.id, + client_name: client.name, + client_icon: client.icon_url, + flow_id, + requested_scopes, + }; + Ok(HttpResponse::Ok().json(access_request)) } } else { - return Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { + Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { client_id: oauth_info.client_id, - })); + })) + } +} + +#[derive(Deserialize)] +pub struct AcceptOAuthClientScopes { + pub flow: String, +} + +#[get("oauth/accept")] +pub async fn accept_client_scopes( + req: HttpRequest, + accept_body: web::Json, // callback url + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + let user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) + .await + .map_err(|e| OAuthError::error(e))? + .1; + + let flow = Flow::get(&accept_body.flow, &redis) + .await + .map_err(|e| OAuthError::error(e))?; + if let Some(Flow::InitOAuthAppApproval { + user_id, + client_id, + scopes, + validated_redirect_uri, + original_redirect_uri, + state, + }) = flow + { + let mut transaction = pool.begin().await.map_err(OAuthError::error)?; + + let id = generate_oauth_client_authorization_id(&mut transaction) + .await + .map_err(OAuthError::error)?; + OAuthClientAuthorization { + id, + client_id, + user_id, + scopes, + created: Utc::now(), + } + .insert(&mut transaction) + .await + .map_err(OAuthError::error)?; + + get_init_oauth_code_flow_response( + user_id, + client_id, + scopes, + validated_redirect_uri, + original_redirect_uri, + state, + &redis, + ) + .await + } else { + Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId)) } +} + +async fn get_init_oauth_code_flow_response( + user_id: crate::database::models::UserId, + client_id: OAuthClientId, + scopes: Scopes, + validated_redirect_uri: ValidatedRedirectUri, + original_redirect_uri: Option, + state: Option, + redis: &RedisPool, +) -> Result { + let code = Flow::OAuthAuthorizationCodeSupplied { + user_id, + client_id, + scopes, + validated_redirect_uri: validated_redirect_uri.clone(), + original_redirect_uri, + } + .insert(Duration::minutes(10), redis) + .await + .map_err(|e| OAuthError::redirect(e, &state, &validated_redirect_uri.clone()))?; + + let mut redirect_params = vec![format!("code={code}")]; + if let Some(state) = state { + redirect_params.push(format!("state={state}")); + } + + let redirect_uri = append_params_to_uri(&validated_redirect_uri.0, &redirect_params); - todo!() + // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) + Ok(HttpResponse::Found() + .append_header(("Location", redirect_uri)) + .finish()) } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ValidatedRedirectUri(pub String); impl ValidatedRedirectUri { - pub fn validate( + pub fn validate<'a>( to_validate: &Option, - client: &DBOAuthClient, + validate_against: impl IntoIterator + Clone, + client_id: OAuthClientId, ) -> Result { - if let Some(first_client_redirect_uri) = client.redirect_uris.first() { + if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() { if let Some(to_validate) = to_validate { - if client.redirect_uris.iter().any(|uri| { - same_uri_except_query_components(&first_client_redirect_uri.uri, to_validate) - }) { + if validate_against + .into_iter() + .any(|uri| same_uri_except_query_components(&uri, to_validate)) + { return Ok(ValidatedRedirectUri(to_validate.clone())); } else { return Err(OAuthError::error(OAuthErrorType::InvalidRedirectUri( @@ -2389,18 +2508,76 @@ impl ValidatedRedirectUri { ))); } } else { - return Ok(ValidatedRedirectUri(first_client_redirect_uri.uri.clone())); + return Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())); } } else { return Err(OAuthError::error( - OAuthErrorType::ClientMissingRedirectURI { - client_id: client.id, - }, + OAuthErrorType::ClientMissingRedirectURI { client_id }, )); } } } fn same_uri_except_query_components(a: &str, b: &str) -> bool { - true + let mut a_components = a.split('?'); + let mut b_components = b.split('?'); + a_components.next() == b_components.next() +} + +fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { + let mut uri = uri.to_string(); + let mut connector = if uri.contains('?') { "&" } else { "?" }; + for param in params { + uri.push_str(&format!("{}{}", connector, param.as_ref())); + connector = "&"; + } + + uri +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_for_none_returns_first_valid_uri() { + let validate_against = vec!["https://modrinth.com/a"]; + + let validated = + ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0)) + .unwrap(); + + assert_eq!(validate_against[0], validated.0); + } + + #[test] + fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() { + let validate_against = vec![ + "https://modrinth.com/a?q3=p3&q4=p4", + "https://modrinth.com/a/b/c?q1=p1&q2=p2", + ]; + let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string(); + + let validated = ValidatedRedirectUri::validate( + &Some(to_validate.clone()), + validate_against, + OAuthClientId(0), + ) + .unwrap(); + + assert_eq!(to_validate, validated.0); + } + + #[test] + fn validate_for_invalid_uri_returns_err() { + let validate_against = vec!["https://modrinth.com/a"]; + let to_validate = "https://modrinth.com/a/b".to_string(); + + let validated = + ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); + + assert!( + validated.is_err_and(|e| matches!(e.error_type, OAuthErrorType::InvalidRedirectUri(_))) + ); + } } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 2dea6c9d..b633f308 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -130,12 +130,12 @@ impl OAuthError { /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) pub fn redirect( err: impl Into, - oauth_info: &OAuthInit, + state: &Option, valid_redirect_uri: &ValidatedRedirectUri, ) -> Self { Self { error_type: err.into(), - state: oauth_info.state.clone(), + state: state.clone(), valid_redirect_uri: Some(valid_redirect_uri.clone()), } } @@ -154,9 +154,9 @@ impl actix_web::ResponseError for OAuthError { StatusCode::INTERNAL_SERVER_ERROR } } - OAuthErrorType::InvalidUri - | OAuthErrorType::InvalidRedirectUri(_) - | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => StatusCode::BAD_REQUEST, + OAuthErrorType::InvalidRedirectUri(_) + | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } + | OAuthErrorType::InvalidAcceptFlowId => StatusCode::BAD_REQUEST, } } @@ -188,8 +188,6 @@ impl actix_web::ResponseError for OAuthError { #[derive(thiserror::Error, Debug)] pub enum OAuthErrorType { - #[error("The provided URI was not valid")] - InvalidUri, #[error(transparent)] AuthenticationError(#[from] AuthenticationError), #[error("Client {} not recognized", .client_id.0)] @@ -208,6 +206,8 @@ pub enum OAuthErrorType { "The provided scope requested scopes broader than the developer app is configured with" )] ScopesTooBroad, + #[error("The provided flow id was invalid")] + InvalidAcceptFlowId, } impl From for OAuthErrorType { @@ -216,14 +216,21 @@ impl From for OAuthErrorType { } } +impl From for OAuthErrorType { + fn from(value: sqlx::Error) -> Self { + OAuthErrorType::AuthenticationError(value.into()) + } +} + impl OAuthErrorType { pub fn error_name(&self) -> String { // IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38) match self { - OAuthErrorType::InvalidUri - | OAuthErrorType::InvalidRedirectUri(_) + OAuthErrorType::InvalidRedirectUri(_) | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => "invalid_uri", - OAuthErrorType::AuthenticationError(_) => "server_error", + OAuthErrorType::AuthenticationError(_) | OAuthErrorType::InvalidAcceptFlowId => { + "server_error" + } OAuthErrorType::UnrecognizedClient { client_id: _ } => "invalid_request", OAuthErrorType::FailedScopeParse(_) | OAuthErrorType::ScopesTooBroad => "invalid_scope", } diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 20c6b8ea..efa52b82 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -1,4 +1,5 @@ use super::ids::*; +use crate::auth::flows::ValidatedRedirectUri; use crate::database::models::DatabaseError; use crate::database::redis::RedisPool; use crate::{auth::flows::AuthProvider, models::pats::Scopes}; @@ -38,14 +39,16 @@ pub enum Flow { user_id: UserId, client_id: OAuthClientId, scopes: Scopes, - redirect_uri: String, + validated_redirect_uri: ValidatedRedirectUri, + original_redirect_uri: Option, state: Option, }, OAuthAuthorizationCodeSupplied { user_id: UserId, client_id: OAuthClientId, scopes: Scopes, - original_redirect_uri: Option, + validated_redirect_uri: ValidatedRedirectUri, + original_redirect_uri: Option, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 }, } diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index 31e246c9..ff12ae8d 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -3,7 +3,6 @@ use crate::models::ids::base62_impl::to_base62; use crate::models::ids::random_base62_rng; use censor::Censor; use serde::{Deserialize, Serialize}; -use serde_with::SerializeDisplay; use sqlx::sqlx_macros::Type; const ID_RETRY_COUNT: usize = 20; @@ -153,6 +152,14 @@ generate_ids!( ImageId ); +generate_ids!( + pub generate_oauth_client_authorization_id, + OAuthClientAuthorizationId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)", + OAuthClientAuthorizationId +); + #[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)] #[sqlx(transparent)] pub struct UserId(pub i64); diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs index 496793cc..cc8dbc18 100644 --- a/src/database/models/oauth_client_authorization_item.rs +++ b/src/database/models/oauth_client_authorization_item.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use crate::{database::redis::RedisPool, models::pats::Scopes}; +use crate::models::pats::Scopes; use super::{DatabaseError, OAuthClientAuthorizationId, OAuthClientId, UserId}; @@ -12,7 +12,6 @@ pub struct OAuthClientAuthorization { pub user_id: UserId, pub scopes: Scopes, pub created: DateTime, - // last_used? } impl OAuthClientAuthorization { @@ -20,17 +19,52 @@ impl OAuthClientAuthorization { client_id: OAuthClientId, user_id: UserId, exec: E, - redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - return Ok(Some(OAuthClientAuthorization { - id: OAuthClientAuthorizationId(1), - client_id: OAuthClientId(1), - user_id: UserId(1), - scopes: Scopes::all(), - created: Utc::now(), + let value = sqlx::query!( + " + SELECT id, client_id, user_id, scopes, created + FROM oauth_client_authorizations + WHERE client_id=$1 AND user_id=$2 + ", + client_id.0, + user_id.0, + ) + .fetch_optional(exec) + .await?; + + return Ok(value.map(|r| OAuthClientAuthorization { + id: OAuthClientAuthorizationId(r.id), + client_id: OAuthClientId(r.client_id), + user_id: UserId(r.user_id), + scopes: Scopes::from_bits(r.scopes as u64).unwrap_or(Scopes::NONE), + created: r.created, })); } + + pub async fn insert( + &self, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + INSERT INTO oauth_client_authorizations ( + id, client_id, user_id, scopes + ) + VALUES ( + $1, $2, $3, $4 + ) + ", + self.id.0, + self.client_id.0, + self.user_id.0, + self.scopes.bits() as i64, + ) + .execute(&mut *transaction) + .await?; + + Ok(()) + } } diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 4e109ee3..5c4d7038 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -1,7 +1,8 @@ use chrono::{DateTime, Utc}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::{database::redis::RedisPool, models::pats::Scopes}; +use crate::models::pats::Scopes; use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId}; @@ -16,32 +17,124 @@ pub struct OAuthRedirectUri { pub struct OAuthClient { pub id: OAuthClientId, pub name: String, - pub icon: Option, + pub icon_url: Option, pub max_scopes: Scopes, - pub secret: String, + pub secret_hash: String, pub redirect_uris: Vec, pub created: DateTime, pub created_by: UserId, } impl OAuthClient { - pub async fn get<'a, E, T: ToString>( - id: T, + pub async fn get<'a, E>( + id: OAuthClientId, exec: E, - redis: &RedisPool, ) -> Result, DatabaseError> where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - return Ok(Some(OAuthClient { - id: OAuthClientId(1), - name: "Test Client".to_string(), - icon: None, - max_scopes: Scopes::all(), - redirect_uris: vec![], - secret: "hashed secret".to_string(), - created: Utc::now(), - created_by: UserId(1), + let value = sqlx::query!( + " + SELECT + clients.id, + clients.name, + clients.icon_url, + clients.max_scopes, + clients.secret_hash, + clients.created, + clients.created_by, + uris.uri_ids, + uris.uri_vals + FROM oauth_clients clients + LEFT JOIN ( + SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals + FROM oauth_client_redirect_uris + GROUP BY client_id + ) uris ON clients.id = uris.client_id + WHERE clients.id = $1 + ", + id.0 + ) + .fetch_optional(exec) + .await?; + + return Ok(value.map(|r| { + let redirects = + if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) { + ids.iter() + .zip(uris.iter()) + .map(|(id, uri)| OAuthRedirectUri { + id: OAuthRedirectUriId(*id), + client_id: OAuthClientId(r.id.clone()), + uri: uri.to_string(), + }) + .collect() + } else { + vec![] + }; + + OAuthClient { + id: OAuthClientId(r.id), + name: r.name, + icon_url: r.icon_url, + max_scopes: Scopes::from_bits(r.max_scopes as u64).unwrap_or(Scopes::NONE), + secret_hash: r.secret_hash, + redirect_uris: redirects, + created: r.created, + created_by: UserId(r.created_by), + } })); } + + pub async fn insert( + &self, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + INSERT INTO oauth_clients ( + id, name, icon_url, max_scopes, secret_hash, created_by + ) + VALUES ( + $1, $2, $3, $4, $5, $6 + ) + ", + self.id.0, + self.name, + self.icon_url, + self.max_scopes.bits() as i64, + self.secret_hash, + self.created_by.0 + ) + .execute(&mut *transaction) + .await?; + + self.insert_redirect_uris(transaction).await?; + + Ok(()) + } + + async fn insert_redirect_uris( + &self, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = self + .redirect_uris + .iter() + .map(|r| (r.id.0, r.client_id.0, r.uri.clone())) + .multiunzip(); + sqlx::query!( + " + INSERT INTO oauth_client_redirect_uris (id, client_id, uri) + SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[]) + ", + &ids[..], + &client_ids[..], + &uris[..], + ) + .execute(&mut *transaction) + .await?; + + Ok(()) + } } From 4163335826f2394bdc3b4e297527790480a29611 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 16 Oct 2023 18:36:22 -0500 Subject: [PATCH 03/72] OAuth Client create route --- src/auth/flows.rs | 2 + src/database/models/ids.rs | 26 +++++++ src/models/ids.rs | 3 + src/models/mod.rs | 1 + src/models/oauth_clients.rs | 68 ++++++++++++++++++ src/models/pats.rs | 7 ++ src/routes/v3/mod.rs | 2 + src/routes/v3/oauth_clients.rs | 121 +++++++++++++++++++++++++++++++++ 8 files changed, 230 insertions(+) create mode 100644 src/models/oauth_clients.rs create mode 100644 src/routes/v3/oauth_clients.rs diff --git a/src/auth/flows.rs b/src/auth/flows.rs index f5477071..4acc18dc 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -2438,6 +2438,8 @@ pub async fn accept_client_scopes( .await .map_err(OAuthError::error)?; + transaction.commit().await.map_err(OAuthError::error)?; + get_init_oauth_code_flow_response( user_id, client_id, diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index ff12ae8d..93ba7b01 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -160,6 +160,22 @@ generate_ids!( OAuthClientAuthorizationId ); +generate_ids!( + pub generate_oauth_client_id, + OAuthClientId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)", + OAuthClientId +); + +generate_ids!( + pub generate_oauth_redirect_id, + OAuthRedirectUriId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)", + OAuthRedirectUriId +); + #[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)] #[sqlx(transparent)] pub struct UserId(pub i64); @@ -380,3 +396,13 @@ impl From for ids::PatId { ids::PatId(id.0 as u64) } } +impl From for ids::OAuthClientId { + fn from(id: OAuthClientId) -> Self { + ids::OAuthClientId(id.0 as u64) + } +} +impl From for ids::OAuthRedirectUriId { + fn from(id: OAuthRedirectUriId) -> Self { + ids::OAuthRedirectUriId(id.0 as u64) + } +} diff --git a/src/models/ids.rs b/src/models/ids.rs index 20166b79..63fcdeb1 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -3,6 +3,7 @@ use thiserror::Error; pub use super::collections::CollectionId; pub use super::images::ImageId; pub use super::notifications::NotificationId; +pub use super::oauth_clients::{OAuthClientId, OAuthRedirectUriId}; pub use super::organizations::OrganizationId; pub use super::pats::PatId; pub use super::projects::{ProjectId, VersionId}; @@ -122,6 +123,8 @@ base62_id_impl!(ThreadMessageId, ThreadMessageId); base62_id_impl!(SessionId, SessionId); base62_id_impl!(PatId, PatId); base62_id_impl!(ImageId, ImageId); +base62_id_impl!(OAuthClientId, OAuthClientId); +base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId); pub mod base62_impl { use serde::de::{self, Deserializer, Visitor}; diff --git a/src/models/mod.rs b/src/models/mod.rs index e1d4ace9..7c97ad31 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -4,6 +4,7 @@ pub mod error; pub mod ids; pub mod images; pub mod notifications; +pub mod oauth_clients; pub mod organizations; pub mod pack; pub mod pats; diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs new file mode 100644 index 00000000..758a5169 --- /dev/null +++ b/src/models/oauth_clients.rs @@ -0,0 +1,68 @@ +use super::{ + ids::{Base62Id, TeamId, UserId}, + pats::Scopes, + teams::TeamMember, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct OAuthClientId(pub u64); + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct OAuthRedirectUriId(pub u64); + +#[derive(Deserialize, Serialize)] +pub struct OAuthRedirectUri { + pub id: OAuthRedirectUriId, + pub client_id: OAuthClientId, + pub uri: String, +} + +#[derive(Serialize)] +pub struct OAuthClientCreationResult { + pub client_secret: String, + pub client: OAuthClient, +} + +#[derive(Deserialize, Serialize)] +pub struct OAuthClient { + pub id: OAuthClientId, + pub name: String, + pub icon_url: Option, + + // The maximum scopes the client can request for OAuth + pub max_scopes: Scopes, + + // The valid URIs that can be redirected to during an authorization request + pub redirect_uris: Vec, + + // The user that created (and thus controls) this client + pub created_by: UserId, +} + +impl From for OAuthClient { + fn from(value: crate::database::models::oauth_client_item::OAuthClient) -> Self { + Self { + id: value.id.into(), + name: value.name, + icon_url: value.icon_url, + max_scopes: value.max_scopes, + redirect_uris: todo!(), + created_by: value.created_by.into(), + } + } +} + +impl From for OAuthRedirectUri { + fn from(value: crate::database::models::oauth_client_item::OAuthRedirectUri) -> Self { + Self { + id: value.id.into(), + client_id: value.client_id.into(), + uri: value.uri, + } + } +} diff --git a/src/models/pats.rs b/src/models/pats.rs index b66cd8f2..d3e6801c 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -103,6 +103,13 @@ bitflags::bitflags! { // delete an organization const ORGANIZATION_DELETE = 1 << 38; + // create an OAuth client + const OAUTH_CLIENT_CREATE = 1 << 39; + // modify an OAuth client + const OAUTH_CLIENT_WRITE = 1 << 40; + // delete an OAuth client + const OAUTH_CLIENT_DELETE = 1 << 41; + const NONE = 0b0; } } diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index ddfb05e5..530a01a6 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -3,6 +3,8 @@ use crate::util::cors::default_cors; use actix_web::{web, HttpResponse}; use serde_json::json; +pub mod oauth_clients; + pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("v3") diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs new file mode 100644 index 00000000..ccf039b3 --- /dev/null +++ b/src/routes/v3/oauth_clients.rs @@ -0,0 +1,121 @@ +use actix_web::{ + post, + web::{self}, + HttpRequest, HttpResponse, +}; +use chrono::Utc; +use rand::{distributions::Alphanumeric, Rng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use serde::Deserialize; +use sha2::Digest; +use sqlx::PgPool; +use validator::Validate; + +use crate::{ + auth::get_user_from_headers, + database::{ + models::{ + generate_oauth_client_id, generate_oauth_redirect_id, + oauth_client_item::{OAuthClient, OAuthRedirectUri}, + }, + redis::RedisPool, + }, + models::{self, oauth_clients::OAuthClientCreationResult, pats::Scopes}, + queue::session::AuthQueue, + routes::v2::project_creation::CreateError, + util::validate::validation_errors_to_string, +}; + +pub fn config(cfg: &mut web::ServiceConfig) { + cfg.service(oauth_client_create); +} + +#[derive(Deserialize, Validate)] +pub struct NewOAuthApp { + #[validate( + custom(function = "crate::util::validate::validate_name"), + length(min = 3, max = 255) + )] + pub name: String, + + #[validate( + custom(function = "crate::util::validate::validate_url"), + length(max = 255) + )] + pub icon_url: Option, + + pub max_scopes: Scopes, + + #[validate(length(min = 1))] + pub redirect_uris: Vec, +} + +#[post("oauth_app")] +pub async fn oauth_client_create<'a>( + req: HttpRequest, + new_oauth_app: web::Json, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + //TODO: Figure out a better error type + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::OAUTH_CLIENT_CREATE]), + ) + .await? + .1; + + new_oauth_app + .validate() + .map_err(|e| CreateError::ValidationError(validation_errors_to_string(e, None)))?; + + let mut transaction = pool.begin().await?; + + let client_id = generate_oauth_client_id(&mut transaction).await?; + + let client_secret = generate_oauth_client_secret(); + let client_secret_hash = format!("{:x}", sha2::Sha512::digest(client_secret.as_bytes())); + + let mut redirect_uris = vec![]; + for uri in new_oauth_app.redirect_uris.iter() { + let id = generate_oauth_redirect_id(&mut transaction).await?; + redirect_uris.push(OAuthRedirectUri { + id, + client_id, + uri: uri.to_string(), + }); + } + + let client = OAuthClient { + id: client_id, + icon_url: new_oauth_app.icon_url.clone(), + max_scopes: new_oauth_app.max_scopes, + name: new_oauth_app.name.clone(), + redirect_uris, + created: Utc::now(), + created_by: current_user.id.into(), + secret_hash: client_secret_hash, + }; + client.clone().insert(&mut transaction).await?; + + transaction.commit().await?; + + let client = models::oauth_clients::OAuthClient::from(client); + + Ok(HttpResponse::Ok().json(OAuthClientCreationResult { + client, + client_secret, + })) +} + +fn generate_oauth_client_secret() -> String { + ChaCha20Rng::from_entropy() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect::() +} From 4d1b0bfae325f2de88a1ce21a9ca482dcb07c19a Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 17 Oct 2023 11:35:38 -0500 Subject: [PATCH 04/72] Get user clients --- src/auth/flows.rs | 32 ++++-- src/database/models/oauth_client_item.rs | 118 +++++++++++++++-------- src/models/pats.rs | 10 +- src/routes/v3/oauth_clients.rs | 48 ++++++++- 4 files changed, 154 insertions(+), 54 deletions(-) diff --git a/src/auth/flows.rs b/src/auth/flows.rs index 4acc18dc..7e29330d 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -2304,15 +2304,21 @@ pub struct OAuthClientAccessRequest { #[get("oauth/authorize")] pub async fn init_oauth( req: HttpRequest, - Query(oauth_info): Query, // callback url + Query(oauth_info): Query, pool: Data, redis: Data, session_queue: Data, ) -> Result { - let user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) - .await - .map_err(OAuthError::error)? - .1; + let user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_AUTH_WRITE]), + ) + .await + .map_err(OAuthError::error)? + .1; let client = DBOAuthClient::get(oauth_info.client_id, &**pool) .await @@ -2400,15 +2406,21 @@ pub struct AcceptOAuthClientScopes { #[get("oauth/accept")] pub async fn accept_client_scopes( req: HttpRequest, - accept_body: web::Json, // callback url + accept_body: web::Json, pool: Data, redis: Data, session_queue: Data, ) -> Result { - let user = get_user_from_headers(&req, &**pool, &redis, &session_queue, None) - .await - .map_err(|e| OAuthError::error(e))? - .1; + let user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_AUTH_WRITE]), + ) + .await + .map_err(|e| OAuthError::error(e))? + .1; let flow = Flow::get(&accept_body.flow, &redis) .await diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 5c4d7038..0458cd99 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -25,15 +25,22 @@ pub struct OAuthClient { pub created_by: UserId, } -impl OAuthClient { - pub async fn get<'a, E>( - id: OAuthClientId, - exec: E, - ) -> Result, DatabaseError> - where - E: sqlx::Executor<'a, Database = sqlx::Postgres>, - { - let value = sqlx::query!( +struct ClientQueryResult { + id: i64, + name: String, + icon_url: Option, + max_scopes: i64, + secret_hash: String, + created: DateTime, + created_by: i64, + uri_ids: Option>, + uri_vals: Option>, +} + +macro_rules! select_clients_with_predicate { + ($predicate:tt, $param:ident) => { + sqlx::query_as!( + ClientQueryResult, " SELECT clients.id, @@ -51,39 +58,42 @@ impl OAuthClient { FROM oauth_client_redirect_uris GROUP BY client_id ) uris ON clients.id = uris.client_id - WHERE clients.id = $1 - ", - id.0 + " + + $predicate, + $param ) - .fetch_optional(exec) - .await?; + }; +} + +impl OAuthClient { + pub async fn get<'a, E>( + id: OAuthClientId, + exec: E, + ) -> Result, DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres>, + { + let client_id_param = id.0; + let value = select_clients_with_predicate!("WHERE clients.id = $1", client_id_param) + .fetch_optional(exec) + .await?; - return Ok(value.map(|r| { - let redirects = - if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) { - ids.iter() - .zip(uris.iter()) - .map(|(id, uri)| OAuthRedirectUri { - id: OAuthRedirectUriId(*id), - client_id: OAuthClientId(r.id.clone()), - uri: uri.to_string(), - }) - .collect() - } else { - vec![] - }; - - OAuthClient { - id: OAuthClientId(r.id), - name: r.name, - icon_url: r.icon_url, - max_scopes: Scopes::from_bits(r.max_scopes as u64).unwrap_or(Scopes::NONE), - secret_hash: r.secret_hash, - redirect_uris: redirects, - created: r.created, - created_by: UserId(r.created_by), - } - })); + return Ok(value.map(|r| r.into())); + } + + pub async fn get_all_user_clients<'a, E>( + user_id: UserId, + exec: E, + ) -> Result, DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres>, + { + let user_id_param = user_id.0; + let clients = select_clients_with_predicate!("WHERE created_by = $1", user_id_param) + .fetch_all(exec) + .await?; + + return Ok(clients.into_iter().map(|r| r.into()).collect()); } pub async fn insert( @@ -138,3 +148,31 @@ impl OAuthClient { Ok(()) } } + +impl From for OAuthClient { + fn from(r: ClientQueryResult) -> Self { + let redirects = if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) { + ids.iter() + .zip(uris.iter()) + .map(|(id, uri)| OAuthRedirectUri { + id: OAuthRedirectUriId(*id), + client_id: OAuthClientId(r.id.clone()), + uri: uri.to_string(), + }) + .collect() + } else { + vec![] + }; + + OAuthClient { + id: OAuthClientId(r.id), + name: r.name, + icon_url: r.icon_url, + max_scopes: Scopes::from_bits(r.max_scopes as u64).unwrap_or(Scopes::NONE), + secret_hash: r.secret_hash, + redirect_uris: redirects, + created: r.created, + created_by: UserId(r.created_by), + } + } +} diff --git a/src/models/pats.rs b/src/models/pats.rs index d3e6801c..6035eeef 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -105,10 +105,12 @@ bitflags::bitflags! { // create an OAuth client const OAUTH_CLIENT_CREATE = 1 << 39; + // get oauth clients + const OAUTH_CLIENT_READ = 1 << 40; // modify an OAuth client - const OAUTH_CLIENT_WRITE = 1 << 40; + const OAUTH_CLIENT_WRITE = 1 << 41; // delete an OAuth client - const OAUTH_CLIENT_DELETE = 1 << 41; + const OAUTH_CLIENT_DELETE = 1 << 42; const NONE = 0b0; } @@ -128,6 +130,10 @@ impl Scopes { | Scopes::USER_AUTH_WRITE | Scopes::USER_DELETE | Scopes::PERFORM_ANALYTICS + | Scopes::OAUTH_CLIENT_CREATE + | Scopes::OAUTH_CLIENT_READ + | Scopes::OAUTH_CLIENT_WRITE + | Scopes::OAUTH_CLIENT_DELETE } pub fn is_restricted(&self) -> bool { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index ccf039b3..7322421a 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -1,9 +1,10 @@ use actix_web::{ - post, + get, post, web::{self}, HttpRequest, HttpResponse, }; use chrono::Utc; +use itertools::Itertools; use rand::{distributions::Alphanumeric, Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use serde::Deserialize; @@ -11,12 +12,14 @@ use sha2::Digest; use sqlx::PgPool; use validator::Validate; +use super::ApiError; use crate::{ auth::get_user_from_headers, database::{ models::{ generate_oauth_client_id, generate_oauth_redirect_id, oauth_client_item::{OAuthClient, OAuthRedirectUri}, + User, UserId, }, redis::RedisPool, }, @@ -28,6 +31,7 @@ use crate::{ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_create); + cfg.service(get_user_clients); } #[derive(Deserialize, Validate)] @@ -50,6 +54,46 @@ pub struct NewOAuthApp { pub redirect_uris: Vec, } +#[get("user/{user_id}/oauth_apps")] +pub async fn get_user_clients( + req: HttpRequest, + info: web::Path<(String,)>, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::OAUTH_CLIENT_READ]), + ) + .await? + .1; + + let target_user = User::get(&info.into_inner().0, &**pool, &redis).await?; + + if let Some(target_user) = target_user { + let target_user_id: models::ids::UserId = target_user.id.into(); + + if current_user.role.is_mod() || current_user.id == target_user_id { + let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?; + + let response = clients + .into_iter() + .map(|c| models::oauth_clients::OAuthClient::from(c)) + .collect_vec(); + + Ok(HttpResponse::Ok().json(response)) + } else { + todo!() + } + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + #[post("oauth_app")] pub async fn oauth_client_create<'a>( req: HttpRequest, @@ -64,7 +108,7 @@ pub async fn oauth_client_create<'a>( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_CREATE]), + Some(&[Scopes::OAUTH_CLIENT_WRITE]), ) .await? .1; From d4b37dee33b63062ca49fe3bb57140cb6ab64ee2 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 17 Oct 2023 12:57:23 -0500 Subject: [PATCH 05/72] Client delete --- migrations/20231016190056_oauth_provider.sql | 4 +- src/database/models/oauth_client_item.rs | 18 +++++++ src/models/users.rs | 6 +++ src/routes/v3/oauth_clients.rs | 55 +++++++++++++++++--- 4 files changed, 74 insertions(+), 9 deletions(-) diff --git a/migrations/20231016190056_oauth_provider.sql b/migrations/20231016190056_oauth_provider.sql index a47704e8..081b15a8 100644 --- a/migrations/20231016190056_oauth_provider.sql +++ b/migrations/20231016190056_oauth_provider.sql @@ -9,12 +9,12 @@ CREATE TABLE oauth_clients ( ); CREATE TABLE oauth_client_redirect_uris ( id bigint PRIMARY KEY, - client_id bigint REFERENCES oauth_clients(id) NOT NULL, + client_id bigint REFERENCES oauth_clients(id) NOT NULL ON DELETE CASCADE, uri varchar(255) ); CREATE TABLE oauth_client_authorizations ( id bigint PRIMARY KEY, - client_id bigint NOT NULL REFERENCES oauth_clients(id), + client_id bigint NOT NULL REFERENCES oauth_clients(id) ON DELETE CASCADE, user_id bigint NOT NULL REFERENCES users(id), scopes bigint NOT NULL, created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 0458cd99..18686311 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -96,6 +96,24 @@ impl OAuthClient { return Ok(clients.into_iter().map(|r| r.into()).collect()); } + pub async fn remove<'a, E>(id: OAuthClientId, exec: E) -> Result<(), DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres>, + { + // Cascades to oauth_client_redirect_uris, oauth_client_authorizations + sqlx::query!( + " + DELETE FROM oauth_clients + WHERE id = $1 + ", + id.0 + ) + .execute(exec) + .await?; + + Ok(()) + } + pub async fn insert( &self, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, diff --git a/src/models/users.rs b/src/models/users.rs index 8a012bf1..cba77618 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -89,6 +89,12 @@ impl From for User { } } +impl User { + pub fn can_interact_with_oauth_client(&self, client_owner_id: UserId) -> bool { + self.role.is_mod() || self.id == client_owner_id + } +} + #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] #[serde(rename_all = "lowercase")] pub enum Role { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 7322421a..bcfd4c94 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -1,5 +1,5 @@ use actix_web::{ - get, post, + delete, get, post, web::{self}, HttpRequest, HttpResponse, }; @@ -19,11 +19,14 @@ use crate::{ models::{ generate_oauth_client_id, generate_oauth_redirect_id, oauth_client_item::{OAuthClient, OAuthRedirectUri}, - User, UserId, + OAuthClientId, User, UserId, }, redis::RedisPool, }, - models::{self, oauth_clients::OAuthClientCreationResult, pats::Scopes}, + models::{ + self, ids::base62_impl::parse_base62, oauth_clients::OAuthClientCreationResult, + pats::Scopes, + }, queue::session::AuthQueue, routes::v2::project_creation::CreateError, util::validate::validation_errors_to_string, @@ -32,6 +35,7 @@ use crate::{ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_create); cfg.service(get_user_clients); + cfg.service(oauth_client_delete); } #[derive(Deserialize, Validate)] @@ -57,7 +61,7 @@ pub struct NewOAuthApp { #[get("user/{user_id}/oauth_apps")] pub async fn get_user_clients( req: HttpRequest, - info: web::Path<(String,)>, + info: web::Path, pool: web::Data, redis: web::Data, session_queue: web::Data, @@ -72,12 +76,12 @@ pub async fn get_user_clients( .await? .1; - let target_user = User::get(&info.into_inner().0, &**pool, &redis).await?; + let target_user = User::get(&info.into_inner(), &**pool, &redis).await?; if let Some(target_user) = target_user { let target_user_id: models::ids::UserId = target_user.id.into(); - if current_user.role.is_mod() || current_user.id == target_user_id { + if current_user.can_interact_with_oauth_client(target_user_id) { let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?; let response = clients @@ -87,7 +91,9 @@ pub async fn get_user_clients( Ok(HttpResponse::Ok().json(response)) } else { - todo!() + Err(ApiError::CustomAuthentication( + "You don't have permission to view this user's OAuth Applications".to_string(), + )) } } else { Ok(HttpResponse::NotFound().body("")) @@ -156,6 +162,41 @@ pub async fn oauth_client_create<'a>( })) } +#[delete("oauth_app/{id}")] +pub async fn oauth_client_delete<'a>( + req: HttpRequest, + client_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::OAUTH_CLIENT_DELETE]), + ) + .await? + .1; + + let client_id = OAuthClientId(parse_base62(&client_id.into_inner())? as i64); + let client = OAuthClient::get(client_id, &**pool).await?; + if let Some(client) = client { + if current_user.can_interact_with_oauth_client(client.created_by.into()) { + OAuthClient::remove(client_id, &**pool).await?; + + Ok(HttpResponse::NoContent().body("")) + } else { + Err(ApiError::CustomAuthentication( + "You don't have permission to delete this client".to_string(), + )) + } + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + fn generate_oauth_client_secret() -> String { ChaCha20Rng::from_entropy() .sample_iter(&Alphanumeric) From 631e95331482b9e74353f4fae54a43a419def5a5 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 17 Oct 2023 19:59:52 -0500 Subject: [PATCH 06/72] Edit oauth client --- src/database/models/oauth_client_item.rs | 23 +++- src/models/pats.rs | 4 + src/models/users.rs | 14 +- src/routes/v3/oauth_clients.rs | 161 ++++++++++++++++------- 4 files changed, 155 insertions(+), 47 deletions(-) diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 18686311..cfe01a39 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -130,7 +130,7 @@ impl OAuthClient { self.id.0, self.name, self.icon_url, - self.max_scopes.bits() as i64, + self.max_scopes.to_postgres(), self.secret_hash, self.created_by.0 ) @@ -142,6 +142,27 @@ impl OAuthClient { Ok(()) } + pub async fn update_editable_fields( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + UPDATE oauth_clients + SET name = $1, icon_url = $2, max_scopes = $3 + WHERE (id = $4) + ", + self.name, + self.icon_url, + self.max_scopes.to_postgres(), + self.id.0, + ) + .execute(exec) + .await?; + + Ok(()) + } + async fn insert_redirect_uris( &self, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, diff --git a/src/models/pats.rs b/src/models/pats.rs index 6035eeef..755115bf 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -144,6 +144,10 @@ impl Scopes { let scopes = scopes.replace(' ', "|").replace("%20", "|"); bitflags::parser::from_str(&scopes) } + + pub fn to_postgres(&self) -> i64 { + self.bits() as i64 + } } #[derive(Serialize, Deserialize)] diff --git a/src/models/users.rs b/src/models/users.rs index cba77618..19aa16d5 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -90,8 +90,18 @@ impl From for User { } impl User { - pub fn can_interact_with_oauth_client(&self, client_owner_id: UserId) -> bool { - self.role.is_mod() || self.id == client_owner_id + pub fn validate_can_interact_with_oauth_client( + &self, + client_owner_id: UserId, + ) -> Result<(), crate::routes::ApiError> { + if self.role.is_mod() || self.id == client_owner_id { + Ok(()) + } else { + Err(crate::routes::ApiError::CustomAuthentication( + "You don't have sufficient permissions to interact with this OAuth application" + .to_string(), + )) + } } } diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index bcfd4c94..87aa02b2 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -1,5 +1,5 @@ use actix_web::{ - delete, get, post, + delete, get, patch, post, web::{self}, HttpRequest, HttpResponse, }; @@ -38,26 +38,6 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_delete); } -#[derive(Deserialize, Validate)] -pub struct NewOAuthApp { - #[validate( - custom(function = "crate::util::validate::validate_name"), - length(min = 3, max = 255) - )] - pub name: String, - - #[validate( - custom(function = "crate::util::validate::validate_url"), - length(max = 255) - )] - pub icon_url: Option, - - pub max_scopes: Scopes, - - #[validate(length(min = 1))] - pub redirect_uris: Vec, -} - #[get("user/{user_id}/oauth_apps")] pub async fn get_user_clients( req: HttpRequest, @@ -80,26 +60,41 @@ pub async fn get_user_clients( if let Some(target_user) = target_user { let target_user_id: models::ids::UserId = target_user.id.into(); + current_user.validate_can_interact_with_oauth_client(target_user_id)?; - if current_user.can_interact_with_oauth_client(target_user_id) { - let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?; + let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?; - let response = clients - .into_iter() - .map(|c| models::oauth_clients::OAuthClient::from(c)) - .collect_vec(); + let response = clients + .into_iter() + .map(|c| models::oauth_clients::OAuthClient::from(c)) + .collect_vec(); - Ok(HttpResponse::Ok().json(response)) - } else { - Err(ApiError::CustomAuthentication( - "You don't have permission to view this user's OAuth Applications".to_string(), - )) - } + Ok(HttpResponse::Ok().json(response)) } else { Ok(HttpResponse::NotFound().body("")) } } +#[derive(Deserialize, Validate)] +pub struct NewOAuthApp { + #[validate( + custom(function = "crate::util::validate::validate_name"), + length(min = 3, max = 255) + )] + pub name: String, + + #[validate( + custom(function = "crate::util::validate::validate_url"), + length(max = 255) + )] + pub icon_url: Option, + + pub max_scopes: Scopes, + + #[validate(length(min = 1))] + pub redirect_uris: Vec, +} + #[post("oauth_app")] pub async fn oauth_client_create<'a>( req: HttpRequest, @@ -108,7 +103,6 @@ pub async fn oauth_client_create<'a>( redis: web::Data, session_queue: web::Data, ) -> Result { - //TODO: Figure out a better error type let current_user = get_user_from_headers( &req, &**pool, @@ -180,18 +174,89 @@ pub async fn oauth_client_delete<'a>( .await? .1; - let client_id = OAuthClientId(parse_base62(&client_id.into_inner())? as i64); - let client = OAuthClient::get(client_id, &**pool).await?; + let client = get_oauth_client_from_str_id(client_id.into_inner(), &**pool).await?; if let Some(client) = client { - if current_user.can_interact_with_oauth_client(client.created_by.into()) { - OAuthClient::remove(client_id, &**pool).await?; - - Ok(HttpResponse::NoContent().body("")) - } else { - Err(ApiError::CustomAuthentication( - "You don't have permission to delete this client".to_string(), - )) + current_user.validate_can_interact_with_oauth_client(client.created_by.into())?; + OAuthClient::remove(client.id, &**pool).await?; + + Ok(HttpResponse::NoContent().body("")) + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + +#[derive(Deserialize, Validate)] +pub struct OAuthClientEdit { + #[validate( + custom(function = "crate::util::validate::validate_name"), + length(min = 3, max = 255) + )] + pub name: Option, + + #[validate( + custom(function = "crate::util::validate::validate_url"), + length(max = 255) + )] + pub icon_url: Option>, + + pub max_scopes: Option, +} + +#[patch("oauth_app/{id}")] +pub async fn oauth_client_edit( + req: HttpRequest, + client_id: web::Path, + client_updates: web::Json, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::OAUTH_CLIENT_DELETE]), + ) + .await? + .1; + + client_updates + .validate() + .map_err(|e| ApiError::Validation(validation_errors_to_string(e, None)))?; + + if client_updates.icon_url.is_none() + && client_updates.name.is_none() + && client_updates.max_scopes.is_none() + { + return Err(ApiError::InvalidInput("No changes provided".to_string())); + } + + if let Some(existing_client) = + get_oauth_client_from_str_id(client_id.into_inner(), &**pool).await? + { + current_user.validate_can_interact_with_oauth_client(existing_client.created_by.into())?; + + let mut updated_client = existing_client.clone(); + let OAuthClientEdit { + name, + icon_url, + max_scopes, + } = client_updates.into_inner(); + if let Some(name) = name { + updated_client.name = name; + } + + if let Some(icon_url) = icon_url { + updated_client.icon_url = icon_url; } + + if let Some(max_scopes) = max_scopes { + updated_client.max_scopes = max_scopes; + } + + updated_client.update_editable_fields(&**pool).await?; + Ok(HttpResponse::Ok().body("")) } else { Ok(HttpResponse::NotFound().body("")) } @@ -204,3 +269,11 @@ fn generate_oauth_client_secret() -> String { .map(char::from) .collect::() } + +async fn get_oauth_client_from_str_id( + client_id: String, + pool: impl sqlx::Executor<'_, Database = sqlx::Postgres>, +) -> Result, ApiError> { + let client_id = OAuthClientId(parse_base62(&client_id)? as i64); + Ok(OAuthClient::get(client_id, pool).await?) +} From 107a53bea021f16fb81b1d58210b1ef0acd3cddc Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 17 Oct 2023 20:53:30 -0500 Subject: [PATCH 07/72] Include redirects in edit client route --- src/database/models/oauth_client_item.rs | 32 ++++++++-- src/routes/v3/oauth_clients.rs | 81 ++++++++++++++++++++---- 2 files changed, 93 insertions(+), 20 deletions(-) diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index cfe01a39..6512409b 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -137,7 +137,7 @@ impl OAuthClient { .execute(&mut *transaction) .await?; - self.insert_redirect_uris(transaction).await?; + Self::insert_redirect_uris(&self.redirect_uris, transaction).await?; Ok(()) } @@ -163,12 +163,30 @@ impl OAuthClient { Ok(()) } - async fn insert_redirect_uris( - &self, - transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + pub async fn remove_redirect_uris( + ids: impl IntoIterator, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + let ids = ids.into_iter().map(|id| id.0).collect_vec(); + sqlx::query!( + " + DELETE FROM oauth_clients + WHERE id IN + (SELECT * FROM UNNEST($1::bigint[])) + ", + &ids[..] + ) + .execute(exec) + .await?; + + Ok(()) + } + + pub async fn insert_redirect_uris( + uris: &[OAuthRedirectUri], + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result<(), DatabaseError> { - let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = self - .redirect_uris + let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = uris .iter() .map(|r| (r.id.0, r.client_id.0, r.uri.clone())) .multiunzip(); @@ -181,7 +199,7 @@ impl OAuthClient { &client_ids[..], &uris[..], ) - .execute(&mut *transaction) + .execute(exec) .await?; Ok(()) diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 87aa02b2..b4a1a47b 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -1,3 +1,5 @@ +use std::{collections::HashSet, fmt::Display, iter::FromIterator}; + use actix_web::{ delete, get, patch, post, web::{self}, @@ -19,7 +21,7 @@ use crate::{ models::{ generate_oauth_client_id, generate_oauth_redirect_id, oauth_client_item::{OAuthClient, OAuthRedirectUri}, - OAuthClientId, User, UserId, + DatabaseError, OAuthClientId, User, UserId, }, redis::RedisPool, }, @@ -124,15 +126,8 @@ pub async fn oauth_client_create<'a>( let client_secret = generate_oauth_client_secret(); let client_secret_hash = format!("{:x}", sha2::Sha512::digest(client_secret.as_bytes())); - let mut redirect_uris = vec![]; - for uri in new_oauth_app.redirect_uris.iter() { - let id = generate_oauth_redirect_id(&mut transaction).await?; - redirect_uris.push(OAuthRedirectUri { - id, - client_id, - uri: uri.to_string(), - }); - } + let redirect_uris = + create_redirect_uris(&new_oauth_app.redirect_uris, client_id, &mut transaction).await?; let client = OAuthClient { id: client_id, @@ -200,6 +195,9 @@ pub struct OAuthClientEdit { pub icon_url: Option>, pub max_scopes: Option, + + #[validate(length(min = 1))] + pub redirect_uris: Option>, } #[patch("oauth_app/{id}")] @@ -242,6 +240,7 @@ pub async fn oauth_client_edit( name, icon_url, max_scopes, + redirect_uris, } = client_updates.into_inner(); if let Some(name) = name { updated_client.name = name; @@ -255,7 +254,17 @@ pub async fn oauth_client_edit( updated_client.max_scopes = max_scopes; } - updated_client.update_editable_fields(&**pool).await?; + let mut transaction = pool.begin().await?; + updated_client + .update_editable_fields(&mut transaction) + .await?; + + if let Some(redirects) = redirect_uris { + edit_redirects(redirects, &existing_client, &mut transaction).await?; + } + + transaction.commit().await?; + Ok(HttpResponse::Ok().body("")) } else { Ok(HttpResponse::NotFound().body("")) @@ -272,8 +281,54 @@ fn generate_oauth_client_secret() -> String { async fn get_oauth_client_from_str_id( client_id: String, - pool: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, ApiError> { let client_id = OAuthClientId(parse_base62(&client_id)? as i64); - Ok(OAuthClient::get(client_id, pool).await?) + Ok(OAuthClient::get(client_id, exec).await?) +} + +async fn create_redirect_uris( + uri_strings: impl IntoIterator, + client_id: OAuthClientId, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result, DatabaseError> { + let mut redirect_uris = vec![]; + for uri in uri_strings.into_iter() { + let id = generate_oauth_redirect_id(transaction).await?; + redirect_uris.push(OAuthRedirectUri { + id, + client_id, + uri: uri.to_string(), + }); + } + + Ok(redirect_uris) +} + +async fn edit_redirects( + redirects: Vec, + existing_client: &OAuthClient, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), DatabaseError> { + let updated_redirects: HashSet = redirects.into_iter().collect(); + let original_redirects: HashSet = existing_client + .redirect_uris + .iter() + .map(|r| r.uri.to_string()) + .collect(); + + let redirects_to_add = create_redirect_uris( + updated_redirects.difference(&original_redirects), + existing_client.id, + &mut *transaction, + ) + .await?; + OAuthClient::insert_redirect_uris(&redirects_to_add, &mut *transaction).await?; + + let mut redirects_to_remove = existing_client.redirect_uris.clone(); + redirects_to_remove.retain(|r| !updated_redirects.contains(&r.uri)); + OAuthClient::remove_redirect_uris(redirects_to_remove.iter().map(|r| r.id), &mut *transaction) + .await?; + + Ok(()) } From 6c8bbf8610bc5711acd6e50b118c6451d23841ab Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 12:04:28 -0500 Subject: [PATCH 08/72] Database stuff for tokens --- migrations/20231016190056_oauth_provider.sql | 29 +++++-- src/auth/flows.rs | 85 ++++++++++--------- src/database/models/flow_item.rs | 1 + src/database/models/ids.rs | 12 +++ src/database/models/mod.rs | 1 + src/database/models/oauth_client_item.rs | 28 +++---- src/database/models/oauth_token_item.rs | 86 ++++++++++++++++++++ src/models/oauth_clients.rs | 5 +- src/models/pats.rs | 4 + src/routes/v3/oauth_clients.rs | 4 +- 10 files changed, 187 insertions(+), 68 deletions(-) create mode 100644 src/database/models/oauth_token_item.rs diff --git a/migrations/20231016190056_oauth_provider.sql b/migrations/20231016190056_oauth_provider.sql index 081b15a8..3e290d48 100644 --- a/migrations/20231016190056_oauth_provider.sql +++ b/migrations/20231016190056_oauth_provider.sql @@ -1,21 +1,34 @@ CREATE TABLE oauth_clients ( id bigint PRIMARY KEY, - name varchar(255) NOT NULL, - icon_url varchar(255) NULL, + name text NOT NULL, + icon_url text NULL, max_scopes bigint NOT NULL, - secret_hash char(512) NOT NULL, + secret_hash text NOT NULL, created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, created_by bigint NOT NULL REFERENCES users(id) ); CREATE TABLE oauth_client_redirect_uris ( id bigint PRIMARY KEY, - client_id bigint REFERENCES oauth_clients(id) NOT NULL ON DELETE CASCADE, - uri varchar(255) + client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE, + uri text ); CREATE TABLE oauth_client_authorizations ( id bigint PRIMARY KEY, - client_id bigint NOT NULL REFERENCES oauth_clients(id) ON DELETE CASCADE, - user_id bigint NOT NULL REFERENCES users(id), + client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE, + user_id bigint NOT NULL REFERENCES users (id) ON DELETE CASCADE, scopes bigint NOT NULL, created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP -); \ No newline at end of file +); +CREATE TABLE oauth_access_tokens ( + id bigint PRIMARY KEY, + authorization_id bigint NOT NULL REFERENCES oauth_client_authorizations(id) ON DELETE CASCADE, + token_hash text NOT NULL UNIQUE, + scopes bigint NOT NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires timestamptz NOT NULL, + last_used timestamptz NULL +); +CREATE INDEX oauth_client_creator ON oauth_clients(created_by); +CREATE INDEX oauth_redirect_client ON oauth_client_redirect_uris(client_id); +CREATE INDEX oauth_authorization_client_user ON oauth_client_authorizations(client_id, user_id); +CREATE INDEX oauth_access_token_hash ON oauth_access_tokens(token_hash); \ No newline at end of file diff --git a/src/auth/flows.rs b/src/auth/flows.rs index 7e29330d..c1f5bcd6 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -5,7 +5,9 @@ use crate::auth::{get_user_from_headers, AuthenticationError, OAuthError, OAuthE use crate::database::models::flow_item::Flow; use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; -use crate::database::models::{generate_oauth_client_authorization_id, OAuthClientId}; +use crate::database::models::{ + generate_oauth_client_authorization_id, OAuthClientAuthorizationId, OAuthClientId, +}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::base62_impl::{parse_base62, to_base62}; @@ -2356,40 +2358,44 @@ pub async fn init_oauth( OAuthClientAuthorization::get(client.id, user.id.into(), &**pool) .await .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; - let has_already_accepted = - existing_authorization.is_some_and(|auth| auth.scopes.contains(requested_scopes)); - if has_already_accepted { - get_init_oauth_code_flow_response( - user.id.into(), - client.id, - requested_scopes, - redirect_uri, - oauth_info.redirect_uri, - oauth_info.state, - &redis, - ) - .await - } else { - let flow_id = Flow::InitOAuthAppApproval { - user_id: user.id.into(), - client_id: client.id, - scopes: requested_scopes, - validated_redirect_uri: redirect_uri.clone(), - original_redirect_uri: oauth_info.redirect_uri.clone(), - state: oauth_info.state.clone(), + match existing_authorization { + Some(existing_authorization) + if existing_authorization.scopes.contains(requested_scopes) => + { + init_oauth_code_flow( + user.id.into(), + client.id, + existing_authorization.id, + requested_scopes, + redirect_uri, + oauth_info.redirect_uri, + oauth_info.state, + &redis, + ) + .await + } + _ => { + let flow_id = Flow::InitOAuthAppApproval { + user_id: user.id.into(), + client_id: client.id, + scopes: requested_scopes, + validated_redirect_uri: redirect_uri.clone(), + original_redirect_uri: oauth_info.redirect_uri.clone(), + state: oauth_info.state.clone(), + } + .insert(Duration::minutes(30), &redis) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + + let access_request = OAuthClientAccessRequest { + client_id: client.id, + client_name: client.name, + client_icon: client.icon_url, + flow_id, + requested_scopes, + }; + Ok(HttpResponse::Ok().json(access_request)) } - .insert(Duration::minutes(30), &redis) - .await - .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; - - let access_request = OAuthClientAccessRequest { - client_id: client.id, - client_name: client.name, - client_icon: client.icon_url, - flow_id, - requested_scopes, - }; - Ok(HttpResponse::Ok().json(access_request)) } } else { Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { @@ -2436,11 +2442,11 @@ pub async fn accept_client_scopes( { let mut transaction = pool.begin().await.map_err(OAuthError::error)?; - let id = generate_oauth_client_authorization_id(&mut transaction) + let auth_id = generate_oauth_client_authorization_id(&mut transaction) .await .map_err(OAuthError::error)?; OAuthClientAuthorization { - id, + id: auth_id, client_id, user_id, scopes, @@ -2452,9 +2458,10 @@ pub async fn accept_client_scopes( transaction.commit().await.map_err(OAuthError::error)?; - get_init_oauth_code_flow_response( + init_oauth_code_flow( user_id, client_id, + auth_id, scopes, validated_redirect_uri, original_redirect_uri, @@ -2467,9 +2474,10 @@ pub async fn accept_client_scopes( } } -async fn get_init_oauth_code_flow_response( +async fn init_oauth_code_flow( user_id: crate::database::models::UserId, client_id: OAuthClientId, + authorization_id: OAuthClientAuthorizationId, scopes: Scopes, validated_redirect_uri: ValidatedRedirectUri, original_redirect_uri: Option, @@ -2479,6 +2487,7 @@ async fn get_init_oauth_code_flow_response( let code = Flow::OAuthAuthorizationCodeSupplied { user_id, client_id, + authorization_id, scopes, validated_redirect_uri: validated_redirect_uri.clone(), original_redirect_uri, diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index efa52b82..719ef6e3 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -46,6 +46,7 @@ pub enum Flow { OAuthAuthorizationCodeSupplied { user_id: UserId, client_id: OAuthClientId, + authorization_id: OAuthClientAuthorizationId, scopes: Scopes, validated_redirect_uri: ValidatedRedirectUri, original_redirect_uri: Option, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index 93ba7b01..1daef422 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -176,6 +176,14 @@ generate_ids!( OAuthRedirectUriId ); +generate_ids!( + pub generate_oauth_access_token_id, + OAuthAccessTokenId, + 8, + "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)", + OAuthAccessTokenId +); + #[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)] #[sqlx(transparent)] pub struct UserId(pub i64); @@ -274,6 +282,10 @@ pub struct OAuthClientAuthorizationId(pub i64); #[sqlx(transparent)] pub struct OAuthRedirectUriId(pub i64); +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct OAuthAccessTokenId(pub i64); + use crate::models::ids; impl From for ProjectId { diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index a1e6c75f..ffb2179d 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -8,6 +8,7 @@ pub mod image_item; pub mod notification_item; pub mod oauth_client_authorization_item; pub mod oauth_client_item; +pub mod oauth_token_item; pub mod organization_item; pub mod pat_item; pub mod project_item; diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 6512409b..14bd65c4 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -66,13 +66,10 @@ macro_rules! select_clients_with_predicate { } impl OAuthClient { - pub async fn get<'a, E>( + pub async fn get( id: OAuthClientId, - exec: E, - ) -> Result, DatabaseError> - where - E: sqlx::Executor<'a, Database = sqlx::Postgres>, - { + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { let client_id_param = id.0; let value = select_clients_with_predicate!("WHERE clients.id = $1", client_id_param) .fetch_optional(exec) @@ -81,13 +78,10 @@ impl OAuthClient { return Ok(value.map(|r| r.into())); } - pub async fn get_all_user_clients<'a, E>( + pub async fn get_all_user_clients( user_id: UserId, - exec: E, - ) -> Result, DatabaseError> - where - E: sqlx::Executor<'a, Database = sqlx::Postgres>, - { + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { let user_id_param = user_id.0; let clients = select_clients_with_predicate!("WHERE created_by = $1", user_id_param) .fetch_all(exec) @@ -96,10 +90,10 @@ impl OAuthClient { return Ok(clients.into_iter().map(|r| r.into()).collect()); } - pub async fn remove<'a, E>(id: OAuthClientId, exec: E) -> Result<(), DatabaseError> - where - E: sqlx::Executor<'a, Database = sqlx::Postgres>, - { + pub async fn remove( + id: OAuthClientId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { // Cascades to oauth_client_redirect_uris, oauth_client_authorizations sqlx::query!( " @@ -225,7 +219,7 @@ impl From for OAuthClient { id: OAuthClientId(r.id), name: r.name, icon_url: r.icon_url, - max_scopes: Scopes::from_bits(r.max_scopes as u64).unwrap_or(Scopes::NONE), + max_scopes: Scopes::from_postgres(r.max_scopes), secret_hash: r.secret_hash, redirect_uris: redirects, created: r.created, diff --git a/src/database/models/oauth_token_item.rs b/src/database/models/oauth_token_item.rs new file mode 100644 index 00000000..dbafa786 --- /dev/null +++ b/src/database/models/oauth_token_item.rs @@ -0,0 +1,86 @@ +use super::{DatabaseError, OAuthAccessTokenId, OAuthClientAuthorizationId, OAuthClientId, UserId}; +use crate::models::pats::Scopes; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct OAuthAccessToken { + pub id: OAuthAccessTokenId, + pub authorization_id: OAuthClientAuthorizationId, + pub token_hash: String, + pub scopes: Scopes, + pub created: DateTime, + pub expires: DateTime, + pub last_used: Option>, + + // Stored separately inside oauth_client_authorizations table + pub client_id: OAuthClientId, + pub user_id: UserId, +} + +impl OAuthAccessToken { + pub async fn get( + token_hash: String, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let value = sqlx::query!( + " + SELECT + tokens.id, + tokens.authorization_id, + tokens.token_hash, + tokens.scopes, + tokens.created, + tokens.expires, + tokens.last_used, + auths.client_id, + auths.user_id + FROM oauth_access_tokens tokens + JOIN oauth_client_authorizations auths + ON tokens.authorization_id = auths.id + WHERE tokens.token_hash = $1 + ", + token_hash + ) + .fetch_optional(exec) + .await?; + + return Ok(value.map(|r| OAuthAccessToken { + id: OAuthAccessTokenId(r.id), + authorization_id: OAuthClientAuthorizationId(r.authorization_id), + token_hash: r.token_hash, + scopes: Scopes::from_postgres(r.scopes), + created: r.created, + expires: r.expires, + last_used: r.last_used, + client_id: OAuthClientId(r.client_id), + user_id: UserId(r.user_id), + })); + } + + pub async fn insert( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + INSERT INTO oauth_access_tokens ( + id, authorization_id, token_hash, scopes, expires, last_used + ) + Values ( + $1, $2, $3, $4, $5, $6 + ) + ", + self.id.0, + self.authorization_id.0, + self.token_hash, + self.scopes.to_postgres(), + self.expires, + Option::>::None + ) + .execute(exec) + .await?; + + Ok(()) + } +} diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index 758a5169..7defa59e 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -1,7 +1,6 @@ use super::{ - ids::{Base62Id, TeamId, UserId}, + ids::{Base62Id, UserId}, pats::Scopes, - teams::TeamMember, }; use serde::{Deserialize, Serialize}; @@ -51,7 +50,7 @@ impl From for OAuthClie name: value.name, icon_url: value.icon_url, max_scopes: value.max_scopes, - redirect_uris: todo!(), + redirect_uris: value.redirect_uris.into_iter().map(|r| r.into()).collect(), created_by: value.created_by.into(), } } diff --git a/src/models/pats.rs b/src/models/pats.rs index 755115bf..038847f3 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -148,6 +148,10 @@ impl Scopes { pub fn to_postgres(&self) -> i64 { self.bits() as i64 } + + pub fn from_postgres(value: i64) -> Self { + Self::from_bits(value as u64).unwrap_or(Scopes::NONE) + } } #[derive(Serialize, Deserialize)] diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index b4a1a47b..636d5398 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, fmt::Display, iter::FromIterator}; +use std::{collections::HashSet, fmt::Display}; use actix_web::{ delete, get, patch, post, @@ -21,7 +21,7 @@ use crate::{ models::{ generate_oauth_client_id, generate_oauth_redirect_id, oauth_client_item::{OAuthClient, OAuthRedirectUri}, - DatabaseError, OAuthClientId, User, UserId, + DatabaseError, OAuthClientId, User, }, redis::RedisPool, }, From 083f7540ad756e4d4cfd1d19f5da911f8781a26e Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 12:20:34 -0500 Subject: [PATCH 09/72] Reorg oauth stuff out of auth/flows and into its own module --- src/auth/flows.rs | 330 +---------------------------- src/auth/mod.rs | 141 +------------ src/auth/oauth/errors.rs | 143 +++++++++++++ src/auth/oauth/mod.rs | 347 +++++++++++++++++++++++++++++++ src/database/models/flow_item.rs | 2 +- src/routes/v2/mod.rs | 1 + 6 files changed, 495 insertions(+), 469 deletions(-) create mode 100644 src/auth/oauth/errors.rs create mode 100644 src/auth/oauth/mod.rs diff --git a/src/auth/flows.rs b/src/auth/flows.rs index c1f5bcd6..8572c29a 100644 --- a/src/auth/flows.rs +++ b/src/auth/flows.rs @@ -1,13 +1,8 @@ use crate::auth::email::send_email; use crate::auth::session::issue_session; use crate::auth::validate::get_user_record_from_bearer_token; -use crate::auth::{get_user_from_headers, AuthenticationError, OAuthError, OAuthErrorType}; +use crate::auth::{get_user_from_headers, AuthenticationError}; use crate::database::models::flow_item::Flow; -use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; -use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; -use crate::database::models::{ - generate_oauth_client_authorization_id, OAuthClientAuthorizationId, OAuthClientId, -}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::ids::base62_impl::{parse_base62, to_base62}; @@ -58,9 +53,7 @@ pub fn config(cfg: &mut ServiceConfig) { .service(set_email) .service(verify_email) .service(subscribe_newsletter) - .service(link_trolley) - .service(init_oauth) - .service(accept_client_scopes), + .service(link_trolley), ); } @@ -2285,322 +2278,3 @@ pub async fn link_trolley( )) } } - -#[derive(Serialize, Deserialize)] -pub struct OAuthInit { - pub client_id: OAuthClientId, - pub redirect_uri: Option, - pub scope: Option, - pub state: Option, -} - -#[derive(Serialize, Deserialize)] -pub struct OAuthClientAccessRequest { - pub flow_id: String, - pub client_id: OAuthClientId, - pub client_name: String, - pub client_icon: Option, - pub requested_scopes: Scopes, -} - -#[get("oauth/authorize")] -pub async fn init_oauth( - req: HttpRequest, - Query(oauth_info): Query, - pool: Data, - redis: Data, - session_queue: Data, -) -> Result { - let user = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_AUTH_WRITE]), - ) - .await - .map_err(OAuthError::error)? - .1; - - let client = DBOAuthClient::get(oauth_info.client_id, &**pool) - .await - .map_err(OAuthError::error)?; - - if let Some(client) = client { - let redirect_uri = ValidatedRedirectUri::validate( - &oauth_info.redirect_uri, - client.redirect_uris.iter().map(|r| r.uri.as_ref()), - client.id, - )?; - - let requested_scopes = oauth_info - .scope - .as_ref() - .map_or(Ok(client.max_scopes), |s| { - Scopes::parse_from_oauth_scopes(s).map_err(|e| { - OAuthError::redirect( - OAuthErrorType::FailedScopeParse(e), - &oauth_info.state, - &redirect_uri, - ) - }) - })?; - - if !client.max_scopes.contains(requested_scopes) { - return Err(OAuthError::redirect( - OAuthErrorType::ScopesTooBroad, - &oauth_info.state, - &redirect_uri, - )); - } - - let existing_authorization = - OAuthClientAuthorization::get(client.id, user.id.into(), &**pool) - .await - .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; - match existing_authorization { - Some(existing_authorization) - if existing_authorization.scopes.contains(requested_scopes) => - { - init_oauth_code_flow( - user.id.into(), - client.id, - existing_authorization.id, - requested_scopes, - redirect_uri, - oauth_info.redirect_uri, - oauth_info.state, - &redis, - ) - .await - } - _ => { - let flow_id = Flow::InitOAuthAppApproval { - user_id: user.id.into(), - client_id: client.id, - scopes: requested_scopes, - validated_redirect_uri: redirect_uri.clone(), - original_redirect_uri: oauth_info.redirect_uri.clone(), - state: oauth_info.state.clone(), - } - .insert(Duration::minutes(30), &redis) - .await - .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; - - let access_request = OAuthClientAccessRequest { - client_id: client.id, - client_name: client.name, - client_icon: client.icon_url, - flow_id, - requested_scopes, - }; - Ok(HttpResponse::Ok().json(access_request)) - } - } - } else { - Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { - client_id: oauth_info.client_id, - })) - } -} - -#[derive(Deserialize)] -pub struct AcceptOAuthClientScopes { - pub flow: String, -} - -#[get("oauth/accept")] -pub async fn accept_client_scopes( - req: HttpRequest, - accept_body: web::Json, - pool: Data, - redis: Data, - session_queue: Data, -) -> Result { - let user = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_AUTH_WRITE]), - ) - .await - .map_err(|e| OAuthError::error(e))? - .1; - - let flow = Flow::get(&accept_body.flow, &redis) - .await - .map_err(|e| OAuthError::error(e))?; - if let Some(Flow::InitOAuthAppApproval { - user_id, - client_id, - scopes, - validated_redirect_uri, - original_redirect_uri, - state, - }) = flow - { - let mut transaction = pool.begin().await.map_err(OAuthError::error)?; - - let auth_id = generate_oauth_client_authorization_id(&mut transaction) - .await - .map_err(OAuthError::error)?; - OAuthClientAuthorization { - id: auth_id, - client_id, - user_id, - scopes, - created: Utc::now(), - } - .insert(&mut transaction) - .await - .map_err(OAuthError::error)?; - - transaction.commit().await.map_err(OAuthError::error)?; - - init_oauth_code_flow( - user_id, - client_id, - auth_id, - scopes, - validated_redirect_uri, - original_redirect_uri, - state, - &redis, - ) - .await - } else { - Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId)) - } -} - -async fn init_oauth_code_flow( - user_id: crate::database::models::UserId, - client_id: OAuthClientId, - authorization_id: OAuthClientAuthorizationId, - scopes: Scopes, - validated_redirect_uri: ValidatedRedirectUri, - original_redirect_uri: Option, - state: Option, - redis: &RedisPool, -) -> Result { - let code = Flow::OAuthAuthorizationCodeSupplied { - user_id, - client_id, - authorization_id, - scopes, - validated_redirect_uri: validated_redirect_uri.clone(), - original_redirect_uri, - } - .insert(Duration::minutes(10), redis) - .await - .map_err(|e| OAuthError::redirect(e, &state, &validated_redirect_uri.clone()))?; - - let mut redirect_params = vec![format!("code={code}")]; - if let Some(state) = state { - redirect_params.push(format!("state={state}")); - } - - let redirect_uri = append_params_to_uri(&validated_redirect_uri.0, &redirect_params); - - // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) - Ok(HttpResponse::Found() - .append_header(("Location", redirect_uri)) - .finish()) -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ValidatedRedirectUri(pub String); - -impl ValidatedRedirectUri { - pub fn validate<'a>( - to_validate: &Option, - validate_against: impl IntoIterator + Clone, - client_id: OAuthClientId, - ) -> Result { - if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() { - if let Some(to_validate) = to_validate { - if validate_against - .into_iter() - .any(|uri| same_uri_except_query_components(&uri, to_validate)) - { - return Ok(ValidatedRedirectUri(to_validate.clone())); - } else { - return Err(OAuthError::error(OAuthErrorType::InvalidRedirectUri( - to_validate.clone(), - ))); - } - } else { - return Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())); - } - } else { - return Err(OAuthError::error( - OAuthErrorType::ClientMissingRedirectURI { client_id }, - )); - } - } -} - -fn same_uri_except_query_components(a: &str, b: &str) -> bool { - let mut a_components = a.split('?'); - let mut b_components = b.split('?'); - a_components.next() == b_components.next() -} - -fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { - let mut uri = uri.to_string(); - let mut connector = if uri.contains('?') { "&" } else { "?" }; - for param in params { - uri.push_str(&format!("{}{}", connector, param.as_ref())); - connector = "&"; - } - - uri -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_for_none_returns_first_valid_uri() { - let validate_against = vec!["https://modrinth.com/a"]; - - let validated = - ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0)) - .unwrap(); - - assert_eq!(validate_against[0], validated.0); - } - - #[test] - fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() { - let validate_against = vec![ - "https://modrinth.com/a?q3=p3&q4=p4", - "https://modrinth.com/a/b/c?q1=p1&q2=p2", - ]; - let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string(); - - let validated = ValidatedRedirectUri::validate( - &Some(to_validate.clone()), - validate_against, - OAuthClientId(0), - ) - .unwrap(); - - assert_eq!(to_validate, validated.0); - } - - #[test] - fn validate_for_invalid_uri_returns_err() { - let validate_against = vec!["https://modrinth.com/a"]; - let to_validate = "https://modrinth.com/a/b".to_string(); - - let validated = - ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); - - assert!( - validated.is_err_and(|e| matches!(e.error_type, OAuthErrorType::InvalidRedirectUri(_))) - ); - } -} diff --git a/src/auth/mod.rs b/src/auth/mod.rs index b633f308..9c65c914 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,6 +1,7 @@ pub mod checks; pub mod email; pub mod flows; +pub mod oauth; pub mod pats; pub mod session; mod templates; @@ -17,8 +18,6 @@ use actix_web::http::StatusCode; use actix_web::HttpResponse; use thiserror::Error; -use self::flows::{OAuthInit, ValidatedRedirectUri}; - #[derive(Error, Debug)] pub enum AuthenticationError { #[error("Environment Error")] @@ -99,141 +98,3 @@ impl AuthenticationError { } } } - -#[derive(thiserror::Error, Debug)] -#[error("{}", .error_type)] -pub struct OAuthError { - #[source] - pub error_type: OAuthErrorType, - - pub state: Option, - pub valid_redirect_uri: Option, -} - -impl OAuthError { - /// The OAuth request failed either because of an invalid redirection URI - /// or before we could validate the one we were given, so return an error - /// directly to the caller - /// - /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) - pub fn error(error_type: impl Into) -> Self { - Self { - error_type: error_type.into(), - valid_redirect_uri: None, - state: None, - } - } - - /// The OAuth request failed for a reason other than an invalid redirection URI - /// So send the error in url-encoded form to the redirect URI - /// - /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) - pub fn redirect( - err: impl Into, - state: &Option, - valid_redirect_uri: &ValidatedRedirectUri, - ) -> Self { - Self { - error_type: err.into(), - state: state.clone(), - valid_redirect_uri: Some(valid_redirect_uri.clone()), - } - } -} - -impl actix_web::ResponseError for OAuthError { - fn status_code(&self) -> StatusCode { - match self.error_type { - OAuthErrorType::AuthenticationError(_) - | OAuthErrorType::UnrecognizedClient { client_id: _ } - | OAuthErrorType::FailedScopeParse(_) - | OAuthErrorType::ScopesTooBroad => { - if self.valid_redirect_uri.is_some() { - StatusCode::FOUND - } else { - StatusCode::INTERNAL_SERVER_ERROR - } - } - OAuthErrorType::InvalidRedirectUri(_) - | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } - | OAuthErrorType::InvalidAcceptFlowId => StatusCode::BAD_REQUEST, - } - } - - fn error_response(&self) -> HttpResponse { - if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() { - redirect_uri = format!( - "{}?error={}&error_description={}", - redirect_uri.to_string(), - self.error_type.error_name(), - self.error_type.to_string(), - ); - - if let Some(state) = self.state.as_ref() { - redirect_uri = format!("{}&state={}", redirect_uri, state); - } - - redirect_uri = urlencoding::encode(&redirect_uri).to_string(); - HttpResponse::Found() - .append_header(("Location".to_string(), redirect_uri)) - .finish() - } else { - HttpResponse::build(self.status_code()).json(ApiError { - error: &self.error_type.error_name(), - description: &self.error_type.to_string(), - }) - } - } -} - -#[derive(thiserror::Error, Debug)] -pub enum OAuthErrorType { - #[error(transparent)] - AuthenticationError(#[from] AuthenticationError), - #[error("Client {} not recognized", .client_id.0)] - UnrecognizedClient { - client_id: crate::database::models::OAuthClientId, - }, - #[error("Client {} has no redirect URIs specified", .client_id.0)] - ClientMissingRedirectURI { - client_id: crate::database::models::OAuthClientId, - }, - #[error("The provided redirect URI did not match any configured in the client")] - InvalidRedirectUri(String), - #[error("The provided scope was malformed or did not correspond to known scopes ({0})")] - FailedScopeParse(bitflags::parser::ParseError), - #[error( - "The provided scope requested scopes broader than the developer app is configured with" - )] - ScopesTooBroad, - #[error("The provided flow id was invalid")] - InvalidAcceptFlowId, -} - -impl From for OAuthErrorType { - fn from(value: crate::database::models::DatabaseError) -> Self { - OAuthErrorType::AuthenticationError(value.into()) - } -} - -impl From for OAuthErrorType { - fn from(value: sqlx::Error) -> Self { - OAuthErrorType::AuthenticationError(value.into()) - } -} - -impl OAuthErrorType { - pub fn error_name(&self) -> String { - // IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38) - match self { - OAuthErrorType::InvalidRedirectUri(_) - | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => "invalid_uri", - OAuthErrorType::AuthenticationError(_) | OAuthErrorType::InvalidAcceptFlowId => { - "server_error" - } - OAuthErrorType::UnrecognizedClient { client_id: _ } => "invalid_request", - OAuthErrorType::FailedScopeParse(_) | OAuthErrorType::ScopesTooBroad => "invalid_scope", - } - .to_string() - } -} diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs new file mode 100644 index 00000000..1da533d3 --- /dev/null +++ b/src/auth/oauth/errors.rs @@ -0,0 +1,143 @@ +use super::ValidatedRedirectUri; +use crate::auth::AuthenticationError; +use crate::models::error::ApiError; +use actix_web::http::StatusCode; +use actix_web::HttpResponse; + +#[derive(thiserror::Error, Debug)] +#[error("{}", .error_type)] +pub struct OAuthError { + #[source] + pub error_type: OAuthErrorType, + + pub state: Option, + pub valid_redirect_uri: Option, +} + +impl OAuthError { + /// The OAuth request failed either because of an invalid redirection URI + /// or before we could validate the one we were given, so return an error + /// directly to the caller + /// + /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) + pub fn error(error_type: impl Into) -> Self { + Self { + error_type: error_type.into(), + valid_redirect_uri: None, + state: None, + } + } + + /// The OAuth request failed for a reason other than an invalid redirection URI + /// So send the error in url-encoded form to the redirect URI + /// + /// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1) + pub fn redirect( + err: impl Into, + state: &Option, + valid_redirect_uri: &ValidatedRedirectUri, + ) -> Self { + Self { + error_type: err.into(), + state: state.clone(), + valid_redirect_uri: Some(valid_redirect_uri.clone()), + } + } +} + +impl actix_web::ResponseError for OAuthError { + fn status_code(&self) -> StatusCode { + match self.error_type { + OAuthErrorType::AuthenticationError(_) + | OAuthErrorType::UnrecognizedClient { client_id: _ } + | OAuthErrorType::FailedScopeParse(_) + | OAuthErrorType::ScopesTooBroad => { + if self.valid_redirect_uri.is_some() { + StatusCode::FOUND + } else { + StatusCode::INTERNAL_SERVER_ERROR + } + } + OAuthErrorType::InvalidRedirectUri(_) + | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } + | OAuthErrorType::InvalidAcceptFlowId => StatusCode::BAD_REQUEST, + } + } + + fn error_response(&self) -> HttpResponse { + if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() { + redirect_uri = format!( + "{}?error={}&error_description={}", + redirect_uri.to_string(), + self.error_type.error_name(), + self.error_type.to_string(), + ); + + if let Some(state) = self.state.as_ref() { + redirect_uri = format!("{}&state={}", redirect_uri, state); + } + + redirect_uri = urlencoding::encode(&redirect_uri).to_string(); + HttpResponse::Found() + .append_header(("Location".to_string(), redirect_uri)) + .finish() + } else { + HttpResponse::build(self.status_code()).json(ApiError { + error: &self.error_type.error_name(), + description: &self.error_type.to_string(), + }) + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum OAuthErrorType { + #[error(transparent)] + AuthenticationError(#[from] AuthenticationError), + #[error("Client {} not recognized", .client_id.0)] + UnrecognizedClient { + client_id: crate::database::models::OAuthClientId, + }, + #[error("Client {} has no redirect URIs specified", .client_id.0)] + ClientMissingRedirectURI { + client_id: crate::database::models::OAuthClientId, + }, + #[error("The provided redirect URI did not match any configured in the client")] + InvalidRedirectUri(String), + #[error("The provided scope was malformed or did not correspond to known scopes ({0})")] + FailedScopeParse(bitflags::parser::ParseError), + #[error( + "The provided scope requested scopes broader than the developer app is configured with" + )] + ScopesTooBroad, + #[error("The provided flow id was invalid")] + InvalidAcceptFlowId, +} + +impl From for OAuthErrorType { + fn from(value: crate::database::models::DatabaseError) -> Self { + OAuthErrorType::AuthenticationError(value.into()) + } +} + +impl From for OAuthErrorType { + fn from(value: sqlx::Error) -> Self { + OAuthErrorType::AuthenticationError(value.into()) + } +} + +impl OAuthErrorType { + pub fn error_name(&self) -> String { + // IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38) + match self { + OAuthErrorType::InvalidRedirectUri(_) + | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => "invalid_uri", + OAuthErrorType::AuthenticationError(_) | OAuthErrorType::InvalidAcceptFlowId => { + "server_error" + } + OAuthErrorType::UnrecognizedClient { client_id: _ } => "invalid_request", + OAuthErrorType::FailedScopeParse(_) | OAuthErrorType::ScopesTooBroad => "invalid_scope", + } + .to_string() + } +} diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs new file mode 100644 index 00000000..059ba731 --- /dev/null +++ b/src/auth/oauth/mod.rs @@ -0,0 +1,347 @@ +use crate::auth::get_user_from_headers; +use crate::database::models::flow_item::Flow; +use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::database::models::{ + generate_oauth_client_authorization_id, OAuthClientAuthorizationId, OAuthClientId, +}; +use crate::database::redis::RedisPool; +use crate::models::pats::Scopes; +use crate::queue::session::AuthQueue; +use actix_web::web::{scope, Data, Query, ServiceConfig}; +use actix_web::{get, web, HttpRequest, HttpResponse}; +use chrono::{Duration, Utc}; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgPool; + +use self::errors::{OAuthError, OAuthErrorType}; + +pub mod errors; + +pub fn config(cfg: &mut ServiceConfig) { + cfg.service( + scope("auth/oauth") + .service(init_oauth) + .service(accept_client_scopes), + ); +} + +#[derive(Serialize, Deserialize)] +pub struct OAuthInit { + pub client_id: OAuthClientId, + pub redirect_uri: Option, + pub scope: Option, + pub state: Option, +} + +#[derive(Serialize, Deserialize)] +pub struct OAuthClientAccessRequest { + pub flow_id: String, + pub client_id: OAuthClientId, + pub client_name: String, + pub client_icon: Option, + pub requested_scopes: Scopes, +} + +#[get("authorize")] +pub async fn init_oauth( + req: HttpRequest, + Query(oauth_info): Query, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + let user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_AUTH_WRITE]), + ) + .await + .map_err(OAuthError::error)? + .1; + + let client = DBOAuthClient::get(oauth_info.client_id, &**pool) + .await + .map_err(OAuthError::error)?; + + if let Some(client) = client { + let redirect_uri = ValidatedRedirectUri::validate( + &oauth_info.redirect_uri, + client.redirect_uris.iter().map(|r| r.uri.as_ref()), + client.id, + )?; + + let requested_scopes = oauth_info + .scope + .as_ref() + .map_or(Ok(client.max_scopes), |s| { + Scopes::parse_from_oauth_scopes(s).map_err(|e| { + OAuthError::redirect( + OAuthErrorType::FailedScopeParse(e), + &oauth_info.state, + &redirect_uri, + ) + }) + })?; + + if !client.max_scopes.contains(requested_scopes) { + return Err(OAuthError::redirect( + OAuthErrorType::ScopesTooBroad, + &oauth_info.state, + &redirect_uri, + )); + } + + let existing_authorization = + OAuthClientAuthorization::get(client.id, user.id.into(), &**pool) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + match existing_authorization { + Some(existing_authorization) + if existing_authorization.scopes.contains(requested_scopes) => + { + init_oauth_code_flow( + user.id.into(), + client.id, + existing_authorization.id, + requested_scopes, + redirect_uri, + oauth_info.redirect_uri, + oauth_info.state, + &redis, + ) + .await + } + _ => { + let flow_id = Flow::InitOAuthAppApproval { + user_id: user.id.into(), + client_id: client.id, + scopes: requested_scopes, + validated_redirect_uri: redirect_uri.clone(), + original_redirect_uri: oauth_info.redirect_uri.clone(), + state: oauth_info.state.clone(), + } + .insert(Duration::minutes(30), &redis) + .await + .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + + let access_request = OAuthClientAccessRequest { + client_id: client.id, + client_name: client.name, + client_icon: client.icon_url, + flow_id, + requested_scopes, + }; + Ok(HttpResponse::Ok().json(access_request)) + } + } + } else { + Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { + client_id: oauth_info.client_id, + })) + } +} + +#[derive(Deserialize)] +pub struct AcceptOAuthClientScopes { + pub flow: String, +} + +#[get("accept")] +pub async fn accept_client_scopes( + req: HttpRequest, + accept_body: web::Json, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + //TODO: Any way to do this without getting the user? + let user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_AUTH_WRITE]), + ) + .await + .map_err(|e| OAuthError::error(e))? + .1; + + let flow = Flow::get(&accept_body.flow, &redis) + .await + .map_err(|e| OAuthError::error(e))?; + if let Some(Flow::InitOAuthAppApproval { + user_id, + client_id, + scopes, + validated_redirect_uri, + original_redirect_uri, + state, + }) = flow + { + let mut transaction = pool.begin().await.map_err(OAuthError::error)?; + + let auth_id = generate_oauth_client_authorization_id(&mut transaction) + .await + .map_err(OAuthError::error)?; + OAuthClientAuthorization { + id: auth_id, + client_id, + user_id, + scopes, + created: Utc::now(), + } + .insert(&mut transaction) + .await + .map_err(OAuthError::error)?; + + transaction.commit().await.map_err(OAuthError::error)?; + + init_oauth_code_flow( + user_id, + client_id, + auth_id, + scopes, + validated_redirect_uri, + original_redirect_uri, + state, + &redis, + ) + .await + } else { + Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId)) + } +} + +async fn init_oauth_code_flow( + user_id: crate::database::models::UserId, + client_id: OAuthClientId, + authorization_id: OAuthClientAuthorizationId, + scopes: Scopes, + validated_redirect_uri: ValidatedRedirectUri, + original_redirect_uri: Option, + state: Option, + redis: &RedisPool, +) -> Result { + let code = Flow::OAuthAuthorizationCodeSupplied { + user_id, + client_id, + authorization_id, + scopes, + validated_redirect_uri: validated_redirect_uri.clone(), + original_redirect_uri, + } + .insert(Duration::minutes(10), redis) + .await + .map_err(|e| OAuthError::redirect(e, &state, &validated_redirect_uri.clone()))?; + + let mut redirect_params = vec![format!("code={code}")]; + if let Some(state) = state { + redirect_params.push(format!("state={state}")); + } + + let redirect_uri = append_params_to_uri(&validated_redirect_uri.0, &redirect_params); + + // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) + Ok(HttpResponse::Found() + .append_header(("Location", redirect_uri)) + .finish()) +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ValidatedRedirectUri(pub String); + +impl ValidatedRedirectUri { + pub fn validate<'a>( + to_validate: &Option, + validate_against: impl IntoIterator + Clone, + client_id: OAuthClientId, + ) -> Result { + if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() { + if let Some(to_validate) = to_validate { + if validate_against + .into_iter() + .any(|uri| same_uri_except_query_components(&uri, to_validate)) + { + return Ok(ValidatedRedirectUri(to_validate.clone())); + } else { + return Err(OAuthError::error(OAuthErrorType::InvalidRedirectUri( + to_validate.clone(), + ))); + } + } else { + return Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())); + } + } else { + return Err(OAuthError::error( + OAuthErrorType::ClientMissingRedirectURI { client_id }, + )); + } + } +} + +fn same_uri_except_query_components(a: &str, b: &str) -> bool { + let mut a_components = a.split('?'); + let mut b_components = b.split('?'); + a_components.next() == b_components.next() +} + +fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { + let mut uri = uri.to_string(); + let mut connector = if uri.contains('?') { "&" } else { "?" }; + for param in params { + uri.push_str(&format!("{}{}", connector, param.as_ref())); + connector = "&"; + } + + uri +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_for_none_returns_first_valid_uri() { + let validate_against = vec!["https://modrinth.com/a"]; + + let validated = + ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0)) + .unwrap(); + + assert_eq!(validate_against[0], validated.0); + } + + #[test] + fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() { + let validate_against = vec![ + "https://modrinth.com/a?q3=p3&q4=p4", + "https://modrinth.com/a/b/c?q1=p1&q2=p2", + ]; + let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string(); + + let validated = ValidatedRedirectUri::validate( + &Some(to_validate.clone()), + validate_against, + OAuthClientId(0), + ) + .unwrap(); + + assert_eq!(to_validate, validated.0); + } + + #[test] + fn validate_for_invalid_uri_returns_err() { + let validate_against = vec!["https://modrinth.com/a"]; + let to_validate = "https://modrinth.com/a/b".to_string(); + + let validated = + ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); + + assert!( + validated.is_err_and(|e| matches!(e.error_type, OAuthErrorType::InvalidRedirectUri(_))) + ); + } +} diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 719ef6e3..b6168abe 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -1,5 +1,5 @@ use super::ids::*; -use crate::auth::flows::ValidatedRedirectUri; +use crate::auth::oauth::ValidatedRedirectUri; use crate::database::models::DatabaseError; use crate::database::redis::RedisPool; use crate::{auth::flows::AuthProvider, models::pats::Scopes}; diff --git a/src/routes/v2/mod.rs b/src/routes/v2/mod.rs index 3f95ae6d..ddac66cf 100644 --- a/src/routes/v2/mod.rs +++ b/src/routes/v2/mod.rs @@ -29,6 +29,7 @@ pub fn config(cfg: &mut actix_web::web::ServiceConfig) { .configure(crate::auth::session::config) .configure(crate::auth::flows::config) .configure(crate::auth::pats::config) + .configure(crate::auth::oauth::config) .configure(moderation::config) .configure(notifications::config) .configure(organizations::config) From aeb45e08834365c46cb92629598fef27f87d1d9a Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 14:55:08 -0500 Subject: [PATCH 10/72] Impl OAuth get access token endpoint --- migrations/20231016190056_oauth_provider.sql | 2 +- src/auth/oauth/errors.rs | 48 ++++-- src/auth/oauth/mod.rs | 169 +++++++++++++++++-- src/auth/validate.rs | 16 +- src/database/models/flow_item.rs | 1 - src/database/models/ids.rs | 5 + src/database/models/oauth_client_item.rs | 5 + src/database/models/oauth_token_item.rs | 20 ++- src/models/ids.rs | 11 ++ src/routes/v3/oauth_clients.rs | 6 +- 10 files changed, 242 insertions(+), 41 deletions(-) diff --git a/migrations/20231016190056_oauth_provider.sql b/migrations/20231016190056_oauth_provider.sql index 3e290d48..2b82384c 100644 --- a/migrations/20231016190056_oauth_provider.sql +++ b/migrations/20231016190056_oauth_provider.sql @@ -25,7 +25,7 @@ CREATE TABLE oauth_access_tokens ( token_hash text NOT NULL UNIQUE, scopes bigint NOT NULL, created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, - expires timestamptz NOT NULL, + expires timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP + interval '14 days', last_used timestamptz NULL ); CREATE INDEX oauth_client_creator ON oauth_clients(created_by); diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs index 1da533d3..7ba42bdd 100644 --- a/src/auth/oauth/errors.rs +++ b/src/auth/oauth/errors.rs @@ -1,6 +1,7 @@ use super::ValidatedRedirectUri; use crate::auth::AuthenticationError; use crate::models::error::ApiError; +use crate::models::ids::DecodingError; use actix_web::http::StatusCode; use actix_web::HttpResponse; @@ -49,7 +50,6 @@ impl actix_web::ResponseError for OAuthError { fn status_code(&self) -> StatusCode { match self.error_type { OAuthErrorType::AuthenticationError(_) - | OAuthErrorType::UnrecognizedClient { client_id: _ } | OAuthErrorType::FailedScopeParse(_) | OAuthErrorType::ScopesTooBroad => { if self.valid_redirect_uri.is_some() { @@ -58,9 +58,16 @@ impl actix_web::ResponseError for OAuthError { StatusCode::INTERNAL_SERVER_ERROR } } - OAuthErrorType::InvalidRedirectUri(_) + OAuthErrorType::RedirectUriNotConfigured(_) | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } - | OAuthErrorType::InvalidAcceptFlowId => StatusCode::BAD_REQUEST, + | OAuthErrorType::InvalidAcceptFlowId + | OAuthErrorType::MalformedId(_) + | OAuthErrorType::InvalidClientId(_) + | OAuthErrorType::InvalidAuthCode + | OAuthErrorType::OnlySupportsAuthorizationCodeGrant(_) + | OAuthErrorType::RedirectUriChanged(_) + | OAuthErrorType::UnauthorizedClient => StatusCode::BAD_REQUEST, + OAuthErrorType::ClientAuthenticationFailed => StatusCode::UNAUTHORIZED, } } @@ -94,16 +101,12 @@ impl actix_web::ResponseError for OAuthError { pub enum OAuthErrorType { #[error(transparent)] AuthenticationError(#[from] AuthenticationError), - #[error("Client {} not recognized", .client_id.0)] - UnrecognizedClient { - client_id: crate::database::models::OAuthClientId, - }, #[error("Client {} has no redirect URIs specified", .client_id.0)] ClientMissingRedirectURI { client_id: crate::database::models::OAuthClientId, }, #[error("The provided redirect URI did not match any configured in the client")] - InvalidRedirectUri(String), + RedirectUriNotConfigured(String), #[error("The provided scope was malformed or did not correspond to known scopes ({0})")] FailedScopeParse(bitflags::parser::ParseError), #[error( @@ -112,6 +115,20 @@ pub enum OAuthErrorType { ScopesTooBroad, #[error("The provided flow id was invalid")] InvalidAcceptFlowId, + #[error("The provided client id was invalid")] + InvalidClientId(crate::database::models::OAuthClientId), + #[error("The provided ID could not be decoded: {0}")] + MalformedId(#[from] DecodingError), + #[error("Failed to authenticate client")] + ClientAuthenticationFailed, + #[error("The provided authorization grant code was invalid")] + InvalidAuthCode, + #[error("The provided client id did not match the id this authorization code was granted to")] + UnauthorizedClient, + #[error("The provided redirect URI did not exactly match the uri originally provided when this flow began")] + RedirectUriChanged(Option), + #[error("The provided grant type ({0}) must be \"authorization_code\"")] + OnlySupportsAuthorizationCodeGrant(String), } impl From for OAuthErrorType { @@ -129,14 +146,17 @@ impl From for OAuthErrorType { impl OAuthErrorType { pub fn error_name(&self) -> String { // IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38) + // And 5.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2) match self { - OAuthErrorType::InvalidRedirectUri(_) - | OAuthErrorType::ClientMissingRedirectURI { client_id: _ } => "invalid_uri", - OAuthErrorType::AuthenticationError(_) | OAuthErrorType::InvalidAcceptFlowId => { - "server_error" + Self::RedirectUriNotConfigured(_) | Self::ClientMissingRedirectURI { client_id: _ } => { + "invalid_uri" } - OAuthErrorType::UnrecognizedClient { client_id: _ } => "invalid_request", - OAuthErrorType::FailedScopeParse(_) | OAuthErrorType::ScopesTooBroad => "invalid_scope", + Self::AuthenticationError(_) | Self::InvalidAcceptFlowId => "server_error", + Self::RedirectUriChanged(_) | Self::MalformedId(_) => "invalid_request", + Self::FailedScopeParse(_) | Self::ScopesTooBroad => "invalid_scope", + Self::InvalidClientId(_) | Self::ClientAuthenticationFailed => "invalid_client", + Self::InvalidAuthCode | Self::OnlySupportsAuthorizationCodeGrant(_) => "invalid_grant", + Self::UnauthorizedClient => "unauthorized_client", } .to_string() } diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 059ba731..10397e45 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -1,16 +1,24 @@ use crate::auth::get_user_from_headers; +use crate::auth::validate::extract_authorization_header; use crate::database::models::flow_item::Flow; use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::database::models::oauth_token_item::OAuthAccessToken; use crate::database::models::{ - generate_oauth_client_authorization_id, OAuthClientAuthorizationId, OAuthClientId, + generate_oauth_access_token_id, generate_oauth_client_authorization_id, + OAuthClientAuthorizationId, OAuthClientId, }; use crate::database::redis::RedisPool; +use crate::models; use crate::models::pats::Scopes; use crate::queue::session::AuthQueue; use actix_web::web::{scope, Data, Query, ServiceConfig}; -use actix_web::{get, web, HttpRequest, HttpResponse}; +use actix_web::{get, post, web, HttpRequest, HttpResponse}; use chrono::{Duration, Utc}; +use rand::distributions::Alphanumeric; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use reqwest::header::{CACHE_CONTROL, PRAGMA}; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPool; @@ -28,7 +36,7 @@ pub fn config(cfg: &mut ServiceConfig) { #[derive(Serialize, Deserialize)] pub struct OAuthInit { - pub client_id: OAuthClientId, + pub client_id: String, pub redirect_uri: Option, pub scope: Option, pub state: Option, @@ -62,7 +70,10 @@ pub async fn init_oauth( .map_err(OAuthError::error)? .1; - let client = DBOAuthClient::get(oauth_info.client_id, &**pool) + let client_id: OAuthClientId = models::ids::OAuthClientId::parse(&oauth_info.client_id) + .map_err(OAuthError::error)? + .into(); + let client = DBOAuthClient::get(client_id, &**pool) .await .map_err(OAuthError::error)?; @@ -138,9 +149,9 @@ pub async fn init_oauth( } } } else { - Err(OAuthError::error(OAuthErrorType::UnrecognizedClient { - client_id: oauth_info.client_id, - })) + Err(OAuthError::error(OAuthErrorType::InvalidClientId( + client_id, + ))) } } @@ -215,6 +226,142 @@ pub async fn accept_client_scopes( } } +#[derive(Deserialize)] +pub struct TokenRequest { + grant_type: String, + code: String, + redirect_uri: Option, + client_id: String, +} + +#[derive(Serialize)] +pub struct TokenResponse { + access_token: String, + token_type: String, + expires_in: i64, +} + +#[post("token")] +/// Params should be in the urlencoded request body +/// And client secret should be in the HTTP basic authorization header +/// Per IETF RFC6749 Section 4.1.3 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3) +pub async fn request_token( + req: HttpRequest, + req_params: web::Form, + pool: Data, + redis: Data, +) -> Result { + let client_id = models::ids::OAuthClientId::parse(&req_params.client_id) + .map_err(OAuthError::error)? + .into(); + let client = DBOAuthClient::get(client_id, &**pool) + .await + .map_err(OAuthError::error)?; + if let Some(client) = client { + authenticate_client_token_request(&req, &client)?; + + let flow = Flow::get(&req_params.code, &redis) + .await + .map_err(OAuthError::error)?; + if let Some(Flow::OAuthAuthorizationCodeSupplied { + user_id, + client_id, + authorization_id, + scopes, + original_redirect_uri, + }) = flow + { + // Ensure auth code is single use + // per IETF RFC6749 Section 10.5 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.5) + Flow::remove(&req_params.code, &redis) + .await + .map_err(OAuthError::error)?; + + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 + if client_id != client_id { + return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient)); + } + + if original_redirect_uri != req_params.redirect_uri { + return Err(OAuthError::error(OAuthErrorType::RedirectUriChanged( + req_params.redirect_uri.clone(), + ))); + } + + if req_params.grant_type != "authorization_code" { + return Err(OAuthError::error( + OAuthErrorType::OnlySupportsAuthorizationCodeGrant( + req_params.grant_type.clone(), + ), + )); + } + + let mut transaction = pool.begin().await.map_err(OAuthError::error)?; + let token_id = generate_oauth_access_token_id(&mut transaction) + .await + .map_err(OAuthError::error)?; + let token = generate_access_token(); + let token_hash = OAuthAccessToken::hash_token(&token); + let time_until_expiration = OAuthAccessToken { + id: token_id, + authorization_id, + token_hash, + scopes, + created: Default::default(), + expires: Default::default(), + last_used: None, + client_id, + user_id, + } + .insert(&mut *transaction) + .await + .map_err(OAuthError::error)?; + + transaction.commit().await.map_err(OAuthError::error)?; + + // IETF RFC6749 Section 5.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.1) + Ok(HttpResponse::Ok() + .append_header((CACHE_CONTROL, "no-store")) + .append_header((PRAGMA, "no-cache")) + .json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: time_until_expiration.num_seconds(), + })) + } else { + Err(OAuthError::error(OAuthErrorType::InvalidAuthCode)) + } + } else { + Err(OAuthError::error(OAuthErrorType::InvalidClientId( + client_id, + ))) + } +} + +fn authenticate_client_token_request( + req: &HttpRequest, + client: &DBOAuthClient, +) -> Result<(), OAuthError> { + let client_secret = extract_authorization_header(&req).map_err(OAuthError::error)?; + let hashed_client_secret = DBOAuthClient::hash_secret(client_secret); + if client.secret_hash != hashed_client_secret { + Err(OAuthError::error( + OAuthErrorType::ClientAuthenticationFailed, + )) + } else { + Ok(()) + } +} + +fn generate_access_token() -> String { + let random = ChaCha20Rng::from_entropy() + .sample_iter(&Alphanumeric) + .take(60) + .map(char::from) + .collect::(); + format!("mro_{}", random) +} + async fn init_oauth_code_flow( user_id: crate::database::models::UserId, client_id: OAuthClientId, @@ -230,7 +377,6 @@ async fn init_oauth_code_flow( client_id, authorization_id, scopes, - validated_redirect_uri: validated_redirect_uri.clone(), original_redirect_uri, } .insert(Duration::minutes(10), redis) @@ -267,7 +413,7 @@ impl ValidatedRedirectUri { { return Ok(ValidatedRedirectUri(to_validate.clone())); } else { - return Err(OAuthError::error(OAuthErrorType::InvalidRedirectUri( + return Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured( to_validate.clone(), ))); } @@ -340,8 +486,7 @@ mod tests { let validated = ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); - assert!( - validated.is_err_and(|e| matches!(e.error_type, OAuthErrorType::InvalidRedirectUri(_))) - ); + assert!(validated + .is_err_and(|e| matches!(e.error_type, OAuthErrorType::RedirectUriNotConfigured(_)))); } } diff --git a/src/auth/validate.rs b/src/auth/validate.rs index cb979354..d5c254cb 100644 --- a/src/auth/validate.rs +++ b/src/auth/validate.rs @@ -91,12 +91,7 @@ where let token = if let Some(token) = token { token } else { - let headers = req.headers(); - let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION); - token_val - .ok_or_else(|| AuthenticationError::InvalidAuthMethod)? - .to_str() - .map_err(|_| AuthenticationError::InvalidCredentials)? + extract_authorization_header(req)? }; let possible_user = match token.split_once('_') { @@ -160,6 +155,15 @@ where Ok(possible_user) } +pub fn extract_authorization_header(req: &HttpRequest) -> Result<&str, AuthenticationError> { + let headers = req.headers(); + let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION); + Ok(token_val + .ok_or_else(|| AuthenticationError::InvalidAuthMethod)? + .to_str() + .map_err(|_| AuthenticationError::InvalidCredentials)?) +} + pub async fn check_is_moderator_from_headers<'a, 'b, E>( req: &HttpRequest, executor: E, diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index b6168abe..34ac12d9 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -48,7 +48,6 @@ pub enum Flow { client_id: OAuthClientId, authorization_id: OAuthClientAuthorizationId, scopes: Scopes, - validated_redirect_uri: ValidatedRedirectUri, original_redirect_uri: Option, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 }, } diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index 1daef422..bcd399fc 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -413,6 +413,11 @@ impl From for ids::OAuthClientId { ids::OAuthClientId(id.0 as u64) } } +impl From for OAuthClientId { + fn from(id: ids::OAuthClientId) -> Self { + Self(id.0 as i64) + } +} impl From for ids::OAuthRedirectUriId { fn from(id: OAuthRedirectUriId) -> Self { ids::OAuthRedirectUriId(id.0 as u64) diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 14bd65c4..499bac05 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -1,6 +1,7 @@ use chrono::{DateTime, Utc}; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use sha2::Digest; use crate::models::pats::Scopes; @@ -198,6 +199,10 @@ impl OAuthClient { Ok(()) } + + pub fn hash_secret(secret: &str) -> String { + format!("{:x}", sha2::Sha512::digest(secret.as_bytes())) + } } impl From for OAuthClient { diff --git a/src/database/models/oauth_token_item.rs b/src/database/models/oauth_token_item.rs index dbafa786..099b511f 100644 --- a/src/database/models/oauth_token_item.rs +++ b/src/database/models/oauth_token_item.rs @@ -2,6 +2,7 @@ use super::{DatabaseError, OAuthAccessTokenId, OAuthClientAuthorizationId, OAuth use crate::models::pats::Scopes; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use sha2::Digest; #[derive(Deserialize, Serialize, Clone, Debug)] pub struct OAuthAccessToken { @@ -58,18 +59,20 @@ impl OAuthAccessToken { })); } + /// Inserts and returns the time until the token expires pub async fn insert( &self, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result<(), DatabaseError> { - sqlx::query!( + ) -> Result { + let r = sqlx::query!( " INSERT INTO oauth_access_tokens ( id, authorization_id, token_hash, scopes, expires, last_used ) - Values ( + VALUES ( $1, $2, $3, $4, $5, $6 ) + RETURNING created, expires ", self.id.0, self.authorization_id.0, @@ -78,9 +81,16 @@ impl OAuthAccessToken { self.expires, Option::>::None ) - .execute(exec) + .fetch_one(exec) .await?; - Ok(()) + let (created, expires) = (r.created, r.expires); + let time_until_expiration = expires - created; + + Ok(time_until_expiration) + } + + pub fn hash_token(token: &str) -> String { + format!("{:x}", sha2::Sha512::digest(token.as_bytes())) } } diff --git a/src/models/ids.rs b/src/models/ids.rs index 63fcdeb1..f7fee884 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -104,6 +104,16 @@ macro_rules! impl_base62_display { } impl_base62_display!(Base62Id); +macro_rules! impl_base62_parse { + ($struct:ty) => { + impl $struct { + pub fn parse(value: &str) -> Result { + Ok(Self(base62_impl::parse_base62(value)?)) + } + } + }; +} + macro_rules! base62_id_impl { ($struct:ty, $cons:expr) => { from_base62id!($struct, $cons;); @@ -124,6 +134,7 @@ base62_id_impl!(SessionId, SessionId); base62_id_impl!(PatId, PatId); base62_id_impl!(ImageId, ImageId); base62_id_impl!(OAuthClientId, OAuthClientId); +impl_base62_parse!(OAuthClientId); base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId); pub mod base62_impl { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 636d5398..c40aa7aa 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -34,6 +34,8 @@ use crate::{ util::validate::validation_errors_to_string, }; +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; + pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_create); cfg.service(get_user_clients); @@ -124,7 +126,7 @@ pub async fn oauth_client_create<'a>( let client_id = generate_oauth_client_id(&mut transaction).await?; let client_secret = generate_oauth_client_secret(); - let client_secret_hash = format!("{:x}", sha2::Sha512::digest(client_secret.as_bytes())); + let client_secret_hash = DBOAuthClient::hash_secret(&client_secret); let redirect_uris = create_redirect_uris(&new_oauth_app.redirect_uris, client_id, &mut transaction).await?; @@ -283,7 +285,7 @@ async fn get_oauth_client_from_str_id( client_id: String, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, ApiError> { - let client_id = OAuthClientId(parse_base62(&client_id)? as i64); + let client_id = models::ids::OAuthClientId::parse(&client_id)?.into(); Ok(OAuthClient::get(client_id, exec).await?) } From 6450d044f2193f8722c33d6070fbf476859bbac3 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 15:24:54 -0500 Subject: [PATCH 11/72] Accept oauth access tokens as auth and update through AuthQueue --- src/auth/validate.rs | 19 +++++++++++ src/queue/session.rs | 77 +++++++++++++++++++++++++++++++------------- 2 files changed, 74 insertions(+), 22 deletions(-) diff --git a/src/auth/validate.rs b/src/auth/validate.rs index d5c254cb..6f987c21 100644 --- a/src/auth/validate.rs +++ b/src/auth/validate.rs @@ -137,6 +137,25 @@ where user.map(|x| (Scopes::all(), x)) } + Some(("mro", _)) => { + use crate::database::models::oauth_token_item::OAuthAccessToken; + + let hash = OAuthAccessToken::hash_token(token); + let access_token = + crate::database::models::oauth_token_item::OAuthAccessToken::get(hash, executor) + .await? + .ok_or(AuthenticationError::InvalidCredentials)?; + + if access_token.expires < Utc::now() { + return Err(AuthenticationError::InvalidCredentials); + } + + let user = user_item::User::get_id(access_token.user_id, executor, redis).await?; + + session_queue.add_oauth_access_token(access_token.id).await; + + user.map(|u| (access_token.scopes, u)) + } Some(("github", _)) | Some(("gho", _)) | Some(("ghp", _)) => { let user = AuthProvider::GitHub.get_user(token).await?; let id = AuthProvider::GitHub.get_user_id(&user.id, executor).await?; diff --git a/src/queue/session.rs b/src/queue/session.rs index 8948810d..0b922b9c 100644 --- a/src/queue/session.rs +++ b/src/queue/session.rs @@ -1,9 +1,10 @@ use crate::auth::session::SessionMetadata; use crate::database::models::pat_item::PersonalAccessToken; use crate::database::models::session_item::Session; -use crate::database::models::{DatabaseError, PatId, SessionId, UserId}; +use crate::database::models::{DatabaseError, OAuthAccessTokenId, PatId, SessionId, UserId}; use crate::database::redis::RedisPool; use chrono::Utc; +use itertools::Itertools; use sqlx::PgPool; use std::collections::{HashMap, HashSet}; use tokio::sync::Mutex; @@ -11,6 +12,7 @@ use tokio::sync::Mutex; pub struct AuthQueue { session_queue: Mutex>, pat_queue: Mutex>, + oauth_access_token_queue: Mutex>, } impl Default for AuthQueue { @@ -25,6 +27,7 @@ impl AuthQueue { AuthQueue { session_queue: Mutex::new(HashMap::with_capacity(1000)), pat_queue: Mutex::new(HashSet::with_capacity(1000)), + oauth_access_token_queue: Mutex::new(HashSet::with_capacity(1000)), } } pub async fn add_session(&self, id: SessionId, metadata: SessionMetadata) { @@ -35,6 +38,10 @@ impl AuthQueue { self.pat_queue.lock().await.insert(id); } + pub async fn add_oauth_access_token(&self, id: crate::database::models::OAuthAccessTokenId) { + self.oauth_access_token_queue.lock().await.insert(id); + } + pub async fn take_sessions(&self) -> HashMap { let mut queue = self.session_queue.lock().await; let len = queue.len(); @@ -42,8 +49,8 @@ impl AuthQueue { std::mem::replace(&mut queue, HashMap::with_capacity(len)) } - pub async fn take_pats(&self) -> HashSet { - let mut queue = self.pat_queue.lock().await; + pub async fn take_set(queue: &Mutex>) -> HashSet { + let mut queue = queue.lock().await; let len = queue.len(); std::mem::replace(&mut queue, HashSet::with_capacity(len)) @@ -51,9 +58,13 @@ impl AuthQueue { pub async fn index(&self, pool: &PgPool, redis: &RedisPool) -> Result<(), DatabaseError> { let session_queue = self.take_sessions().await; - let pat_queue = self.take_pats().await; + let pat_queue = Self::take_set(&self.pat_queue).await; + let oauth_access_token_queue = Self::take_set(&self.oauth_access_token_queue).await; - if !session_queue.is_empty() || !pat_queue.is_empty() { + if !session_queue.is_empty() + || !pat_queue.is_empty() + || !oauth_access_token_queue.is_empty() + { let mut transaction = pool.begin().await?; let mut clear_cache_sessions = Vec::new(); @@ -102,29 +113,51 @@ impl AuthQueue { Session::clear_cache(clear_cache_sessions, redis).await?; - let mut clear_cache_pats = Vec::new(); - - for id in pat_queue { - clear_cache_pats.push((Some(id), None, None)); - - sqlx::query!( - " - UPDATE pats - SET last_used = $2 - WHERE (id = $1) - ", - id as PatId, - Utc::now(), - ) - .execute(&mut *transaction) - .await?; - } + let ids = pat_queue.iter().map(|id| id.0).collect_vec(); + let clear_cache_pats = pat_queue + .into_iter() + .map(|id| (Some(id), None, None)) + .collect_vec(); + sqlx::query!( + " + UPDATE pats + SET last_used = $2 + WHERE id IN + (SELECT * FROM UNNEST($1::bigint[])) + ", + &ids[..], + Utc::now(), + ) + .execute(&mut *transaction) + .await?; PersonalAccessToken::clear_cache(clear_cache_pats, redis).await?; + process_oauth_access_tokens(oauth_access_token_queue, &mut transaction).await?; + transaction.commit().await?; } Ok(()) } } + +async fn process_oauth_access_tokens( + oauth_access_token_queue: HashSet, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), DatabaseError> { + let ids = oauth_access_token_queue.iter().map(|id| id.0).collect_vec(); + sqlx::query!( + " + UPDATE oauth_access_tokens + SET last_used = $2 + WHERE id IN + (SELECT * FROM UNNEST($1::bigint[])) + ", + &ids[..], + Utc::now() + ) + .execute(&mut *transaction) + .await?; + Ok(()) +} From 93cd620c3ee93ffe4b6a1e12d21716f7418d1c7f Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 17:38:53 -0500 Subject: [PATCH 12/72] User OAuth authorization management routes --- migrations/20231016190056_oauth_provider.sql | 4 +- src/auth/oauth/mod.rs | 22 ++--- src/database/models/flow_item.rs | 1 + src/database/models/ids.rs | 5 + .../models/oauth_client_authorization_item.rs | 98 ++++++++++++++++--- src/models/ids.rs | 2 + src/models/oauth_clients.rs | 33 +++++++ src/models/pats.rs | 9 +- src/routes/v3/oauth_clients.rs | 58 ++++++++++- 9 files changed, 199 insertions(+), 33 deletions(-) diff --git a/migrations/20231016190056_oauth_provider.sql b/migrations/20231016190056_oauth_provider.sql index 2b82384c..a6a9406e 100644 --- a/migrations/20231016190056_oauth_provider.sql +++ b/migrations/20231016190056_oauth_provider.sql @@ -17,7 +17,8 @@ CREATE TABLE oauth_client_authorizations ( client_id bigint NOT NULL REFERENCES oauth_clients (id) ON DELETE CASCADE, user_id bigint NOT NULL REFERENCES users (id) ON DELETE CASCADE, scopes bigint NOT NULL, - created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE (client_id, user_id) ); CREATE TABLE oauth_access_tokens ( id bigint PRIMARY KEY, @@ -30,5 +31,4 @@ CREATE TABLE oauth_access_tokens ( ); CREATE INDEX oauth_client_creator ON oauth_clients(created_by); CREATE INDEX oauth_redirect_client ON oauth_client_redirect_uris(client_id); -CREATE INDEX oauth_authorization_client_user ON oauth_client_authorizations(client_id, user_id); CREATE INDEX oauth_access_token_hash ON oauth_access_tokens(token_hash); \ No newline at end of file diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 10397e45..c25f05e2 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -129,6 +129,7 @@ pub async fn init_oauth( let flow_id = Flow::InitOAuthAppApproval { user_id: user.id.into(), client_id: client.id, + existing_authorization_id: existing_authorization.map(|a| a.id), scopes: requested_scopes, validated_redirect_uri: redirect_uri.clone(), original_redirect_uri: oauth_info.redirect_uri.clone(), @@ -174,7 +175,7 @@ pub async fn accept_client_scopes( &**pool, &redis, &session_queue, - Some(&[Scopes::USER_AUTH_WRITE]), + Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), ) .await .map_err(|e| OAuthError::error(e))? @@ -186,6 +187,7 @@ pub async fn accept_client_scopes( if let Some(Flow::InitOAuthAppApproval { user_id, client_id, + existing_authorization_id, scopes, validated_redirect_uri, original_redirect_uri, @@ -194,19 +196,15 @@ pub async fn accept_client_scopes( { let mut transaction = pool.begin().await.map_err(OAuthError::error)?; - let auth_id = generate_oauth_client_authorization_id(&mut transaction) + let auth_id = match existing_authorization_id { + Some(id) => id, + None => generate_oauth_client_authorization_id(&mut transaction) + .await + .map_err(OAuthError::error)?, + }; + OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction) .await .map_err(OAuthError::error)?; - OAuthClientAuthorization { - id: auth_id, - client_id, - user_id, - scopes, - created: Utc::now(), - } - .insert(&mut transaction) - .await - .map_err(OAuthError::error)?; transaction.commit().await.map_err(OAuthError::error)?; diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 34ac12d9..291c0e16 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -38,6 +38,7 @@ pub enum Flow { InitOAuthAppApproval { user_id: UserId, client_id: OAuthClientId, + existing_authorization_id: Option, scopes: Scopes, validated_redirect_uri: ValidatedRedirectUri, original_redirect_uri: Option, diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index bcd399fc..68d11a6c 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -423,3 +423,8 @@ impl From for ids::OAuthRedirectUriId { ids::OAuthRedirectUriId(id.0 as u64) } } +impl From for ids::OAuthClientAuthorizationId { + fn from(id: OAuthClientAuthorizationId) -> Self { + ids::OAuthClientAuthorizationId(id.0 as u64) + } +} diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs index cc8dbc18..01eb48a3 100644 --- a/src/database/models/oauth_client_authorization_item.rs +++ b/src/database/models/oauth_client_authorization_item.rs @@ -1,4 +1,5 @@ use chrono::{DateTime, Utc}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use crate::models::pats::Scopes; @@ -14,15 +15,23 @@ pub struct OAuthClientAuthorization { pub created: DateTime, } +pub struct OAuthClientAuthorizationWithClientInfo { + pub id: OAuthClientAuthorizationId, + pub client_id: OAuthClientId, + pub user_id: UserId, + pub scopes: Scopes, + pub created: DateTime, + pub client_name: String, + pub client_icon_url: Option, + pub client_created_by: UserId, +} + impl OAuthClientAuthorization { - pub async fn get<'a, E>( + pub async fn get( client_id: OAuthClientId, user_id: UserId, - exec: E, - ) -> Result, DatabaseError> - where - E: sqlx::Executor<'a, Database = sqlx::Postgres>, - { + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { let value = sqlx::query!( " SELECT id, client_id, user_id, scopes, created @@ -39,13 +48,55 @@ impl OAuthClientAuthorization { id: OAuthClientAuthorizationId(r.id), client_id: OAuthClientId(r.client_id), user_id: UserId(r.user_id), - scopes: Scopes::from_bits(r.scopes as u64).unwrap_or(Scopes::NONE), + scopes: Scopes::from_postgres(r.scopes), created: r.created, })); } - pub async fn insert( - &self, + pub async fn get_all_for_user( + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let results = sqlx::query!( + " + SELECT + auths.id, + auths.client_id, + auths.user_id, + auths.scopes, + auths.created, + clients.name as client_name, + clients.icon_url as client_icon_url, + clients.created_by as client_created_by + FROM oauth_client_authorizations auths + JOIN oauth_clients clients ON clients.id = auths.client_id + WHERE user_id=$1 + ", + user_id.0 + ) + .fetch_all(exec) + .await?; + + return Ok(results + .into_iter() + .map(|r| OAuthClientAuthorizationWithClientInfo { + id: OAuthClientAuthorizationId(r.id), + client_id: OAuthClientId(r.client_id), + user_id: UserId(r.user_id), + scopes: Scopes::from_postgres(r.scopes), + created: r.created, + client_name: r.client_name, + client_icon_url: r.client_icon_url, + client_created_by: UserId(r.client_created_by), + }) + .collect_vec()); + } + + pub async fn upsert( + id: OAuthClientAuthorizationId, + client_id: OAuthClientId, + user_id: UserId, + scopes: Scopes, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result<(), DatabaseError> { sqlx::query!( @@ -56,15 +107,36 @@ impl OAuthClientAuthorization { VALUES ( $1, $2, $3, $4 ) + ON CONFLICT (id, client_id, user_id) + DO UPDATE SET scopes = EXCLUDED.scopes ", - self.id.0, - self.client_id.0, - self.user_id.0, - self.scopes.bits() as i64, + id.0, + client_id.0, + user_id.0, + scopes.bits() as i64, ) .execute(&mut *transaction) .await?; Ok(()) } + + pub async fn remove( + client_id: OAuthClientId, + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + DELETE FROM oauth_client_authorizations + WHERE client_id=$1 AND user_id=$2 + ", + client_id.0, + user_id.0 + ) + .execute(exec) + .await?; + + Ok(()) + } } diff --git a/src/models/ids.rs b/src/models/ids.rs index f7fee884..6c8bf79a 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -3,6 +3,7 @@ use thiserror::Error; pub use super::collections::CollectionId; pub use super::images::ImageId; pub use super::notifications::NotificationId; +pub use super::oauth_clients::OAuthClientAuthorizationId; pub use super::oauth_clients::{OAuthClientId, OAuthRedirectUriId}; pub use super::organizations::OrganizationId; pub use super::pats::PatId; @@ -136,6 +137,7 @@ base62_id_impl!(ImageId, ImageId); base62_id_impl!(OAuthClientId, OAuthClientId); impl_base62_parse!(OAuthClientId); base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId); +base62_id_impl!(OAuthClientAuthorizationId, OAuthClientAuthorizationId); pub mod base62_impl { use serde::de::{self, Deserializer, Visitor}; diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index 7defa59e..bd02c217 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -2,6 +2,7 @@ use super::{ ids::{Base62Id, UserId}, pats::Scopes, }; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -9,6 +10,11 @@ use serde::{Deserialize, Serialize}; #[serde(into = "Base62Id")] pub struct OAuthClientId(pub u64); +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct OAuthClientAuthorizationId(pub u64); + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(from = "Base62Id")] #[serde(into = "Base62Id")] @@ -43,6 +49,18 @@ pub struct OAuthClient { pub created_by: UserId, } +#[derive(Deserialize, Serialize)] +pub struct OAuthClientAuthorizationInfo { + pub id: OAuthClientAuthorizationId, + pub client_id: OAuthClientId, + pub user_id: UserId, + pub scopes: Scopes, + pub created: DateTime, + pub client_name: String, + pub client_icon_url: Option, + pub client_created_by: UserId, +} + impl From for OAuthClient { fn from(value: crate::database::models::oauth_client_item::OAuthClient) -> Self { Self { @@ -65,3 +83,18 @@ impl From for OAut } } } + +impl From for OAuthClientAuthorizationInfo { + fn from(value: crate::database::models::oauth_client_authorization_item::OAuthClientAuthorizationWithClientInfo) -> Self { + Self { + id: value.id.into(), + client_id: value.client_id.into(), + user_id: value.user_id.into(), + scopes: value.scopes, + created: value.created, + client_name: value.client_name, + client_icon_url: value.client_icon_url, + client_created_by: value.client_created_by.into(), + } + } +} diff --git a/src/models/pats.rs b/src/models/pats.rs index 038847f3..ff63e807 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -105,13 +105,18 @@ bitflags::bitflags! { // create an OAuth client const OAUTH_CLIENT_CREATE = 1 << 39; - // get oauth clients + // get OAuth clients const OAUTH_CLIENT_READ = 1 << 40; // modify an OAuth client const OAUTH_CLIENT_WRITE = 1 << 41; // delete an OAuth client const OAUTH_CLIENT_DELETE = 1 << 42; + // read a user's OAuth authorizations + const USER_OAUTH_AUTHORIZATIONS_READ = 1 << 43; + // write a user's OAuth authorizations + const USER_OAUTH_AUTHORIZATIONS_WRITE = 1 << 44; + const NONE = 0b0; } } @@ -134,6 +139,8 @@ impl Scopes { | Scopes::OAUTH_CLIENT_READ | Scopes::OAUTH_CLIENT_WRITE | Scopes::OAUTH_CLIENT_DELETE + | Scopes::USER_OAUTH_AUTHORIZATIONS_READ + | Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE } pub fn is_restricted(&self) -> bool { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index c40aa7aa..2d04a68a 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -10,7 +10,6 @@ use itertools::Itertools; use rand::{distributions::Alphanumeric, Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use serde::Deserialize; -use sha2::Digest; use sqlx::PgPool; use validator::Validate; @@ -20,15 +19,13 @@ use crate::{ database::{ models::{ generate_oauth_client_id, generate_oauth_redirect_id, + oauth_client_authorization_item::OAuthClientAuthorization, oauth_client_item::{OAuthClient, OAuthRedirectUri}, DatabaseError, OAuthClientId, User, }, redis::RedisPool, }, - models::{ - self, ids::base62_impl::parse_base62, oauth_clients::OAuthClientCreationResult, - pats::Scopes, - }, + models::{self, oauth_clients::OAuthClientCreationResult, pats::Scopes}, queue::session::AuthQueue, routes::v2::project_creation::CreateError, util::validate::validation_errors_to_string, @@ -273,6 +270,57 @@ pub async fn oauth_client_edit( } } +#[get("user/oauth_authorizations")] +pub async fn get_user_oauth_authorizations( + req: HttpRequest, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_READ]), + ) + .await? + .1; + + let authorizations = + OAuthClientAuthorization::get_all_for_user(current_user.id.into(), &**pool).await?; + + let mapped: Vec = + authorizations.into_iter().map(|a| a.into()).collect_vec(); + + Ok(HttpResponse::Ok().json(mapped)) +} + +#[delete("user/oauth_authorizations/{client_id}")] +pub async fn revoke_oauth_authorization( + req: HttpRequest, + info: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), + ) + .await? + .1; + + let client_id = models::ids::OAuthClientId::parse(&info.into_inner())?; + + OAuthClientAuthorization::remove(client_id.into(), current_user.id.into(), &**pool).await?; + + Ok(HttpResponse::Ok().body("")) +} + fn generate_oauth_client_secret() -> String { ChaCha20Rng::from_entropy() .sample_iter(&Alphanumeric) From 2e4dbb1f0a80b4960dc1fdd30b3f3cea7574d174 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 17:40:57 -0500 Subject: [PATCH 13/72] Forgot to actually add the routes lol --- src/auth/oauth/mod.rs | 3 ++- src/routes/v3/oauth_clients.rs | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index c25f05e2..11862731 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -30,7 +30,8 @@ pub fn config(cfg: &mut ServiceConfig) { cfg.service( scope("auth/oauth") .service(init_oauth) - .service(accept_client_scopes), + .service(accept_client_scopes) + .service(request_token), ); } diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 2d04a68a..35475f21 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -37,6 +37,8 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_create); cfg.service(get_user_clients); cfg.service(oauth_client_delete); + cfg.service(get_user_oauth_authorizations); + cfg.service(revoke_oauth_authorization); } #[get("user/{user_id}/oauth_apps")] From 8e6d1ff9eb0a1d2e4e9c4b8d0c63ec2645d6b313 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 18 Oct 2023 18:10:29 -0500 Subject: [PATCH 14/72] Bit o cleanup --- src/auth/oauth/mod.rs | 11 ++++++++--- src/routes/v3/oauth_clients.rs | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 11862731..c4d943fe 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -14,7 +14,7 @@ use crate::models::pats::Scopes; use crate::queue::session::AuthQueue; use actix_web::web::{scope, Data, Query, ServiceConfig}; use actix_web::{get, post, web, HttpRequest, HttpResponse}; -use chrono::{Duration, Utc}; +use chrono::Duration; use rand::distributions::Alphanumeric; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; @@ -24,6 +24,8 @@ use sqlx::postgres::PgPool; use self::errors::{OAuthError, OAuthErrorType}; +use super::AuthenticationError; + pub mod errors; pub fn config(cfg: &mut ServiceConfig) { @@ -170,8 +172,7 @@ pub async fn accept_client_scopes( redis: Data, session_queue: Data, ) -> Result { - //TODO: Any way to do this without getting the user? - let user = get_user_from_headers( + let current_user = get_user_from_headers( &req, &**pool, &redis, @@ -195,6 +196,10 @@ pub async fn accept_client_scopes( state, }) = flow { + if current_user.id != user_id.into() { + return Err(OAuthError::error(AuthenticationError::InvalidCredentials)); + } + let mut transaction = pool.begin().await.map_err(OAuthError::error)?; let auth_id = match existing_authorization_id { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 35475f21..fde1fad2 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -111,7 +111,7 @@ pub async fn oauth_client_create<'a>( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_WRITE]), + Some(&[Scopes::OAUTH_CLIENT_CREATE]), ) .await? .1; @@ -215,7 +215,7 @@ pub async fn oauth_client_edit( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_DELETE]), + Some(&[Scopes::OAUTH_CLIENT_WRITE]), ) .await? .1; From 229f505a38f75c5a5016a6fcacae0e7995132a5d Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Thu, 19 Oct 2023 14:15:08 -0500 Subject: [PATCH 15/72] Happy path test for OAuth and minor fixes for things it found --- src/auth/oauth/mod.rs | 26 ++--- .../models/oauth_client_authorization_item.rs | 2 +- src/database/models/oauth_token_item.rs | 5 +- src/models/oauth_clients.rs | 2 +- src/routes/v2/mod.rs | 1 - src/routes/v3/mod.rs | 6 +- tests/common/api_v2/project.rs | 4 +- tests/common/api_v2/team.rs | 4 + tests/common/api_v3/mod.rs | 18 +++ tests/common/api_v3/oauth_clients.rs | 35 ++++++ tests/common/asserts.rs | 2 +- tests/common/environment.rs | 7 +- tests/common/mod.rs | 1 + tests/oauth.rs | 110 ++++++++++++++++++ 14 files changed, 198 insertions(+), 25 deletions(-) create mode 100644 tests/common/api_v3/mod.rs create mode 100644 tests/common/api_v3/oauth_clients.rs create mode 100644 tests/oauth.rs diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index c4d943fe..bbffb1d9 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -18,7 +18,7 @@ use chrono::Duration; use rand::distributions::Alphanumeric; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; -use reqwest::header::{CACHE_CONTROL, PRAGMA}; +use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA}; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPool; @@ -159,12 +159,12 @@ pub async fn init_oauth( } } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] pub struct AcceptOAuthClientScopes { pub flow: String, } -#[get("accept")] +#[post("accept")] pub async fn accept_client_scopes( req: HttpRequest, accept_body: web::Json, @@ -230,19 +230,19 @@ pub async fn accept_client_scopes( } } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] pub struct TokenRequest { - grant_type: String, - code: String, - redirect_uri: Option, - client_id: String, + pub grant_type: String, + pub code: String, + pub redirect_uri: Option, + pub client_id: String, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct TokenResponse { - access_token: String, - token_type: String, - expires_in: i64, + pub access_token: String, + pub token_type: String, + pub expires_in: i64, } #[post("token")] @@ -396,7 +396,7 @@ async fn init_oauth_code_flow( // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) Ok(HttpResponse::Found() - .append_header(("Location", redirect_uri)) + .append_header((LOCATION, redirect_uri)) .finish()) } diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs index 01eb48a3..0972caf0 100644 --- a/src/database/models/oauth_client_authorization_item.rs +++ b/src/database/models/oauth_client_authorization_item.rs @@ -107,7 +107,7 @@ impl OAuthClientAuthorization { VALUES ( $1, $2, $3, $4 ) - ON CONFLICT (id, client_id, user_id) + ON CONFLICT (id) DO UPDATE SET scopes = EXCLUDED.scopes ", id.0, diff --git a/src/database/models/oauth_token_item.rs b/src/database/models/oauth_token_item.rs index 099b511f..c2e95501 100644 --- a/src/database/models/oauth_token_item.rs +++ b/src/database/models/oauth_token_item.rs @@ -67,10 +67,10 @@ impl OAuthAccessToken { let r = sqlx::query!( " INSERT INTO oauth_access_tokens ( - id, authorization_id, token_hash, scopes, expires, last_used + id, authorization_id, token_hash, scopes, last_used ) VALUES ( - $1, $2, $3, $4, $5, $6 + $1, $2, $3, $4, $5 ) RETURNING created, expires ", @@ -78,7 +78,6 @@ impl OAuthAccessToken { self.authorization_id.0, self.token_hash, self.scopes.to_postgres(), - self.expires, Option::>::None ) .fetch_one(exec) diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index bd02c217..a22eb16d 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -27,7 +27,7 @@ pub struct OAuthRedirectUri { pub uri: String, } -#[derive(Serialize)] +#[derive(Serialize, Deserialize)] pub struct OAuthClientCreationResult { pub client_secret: String, pub client: OAuthClient, diff --git a/src/routes/v2/mod.rs b/src/routes/v2/mod.rs index ddac66cf..3f95ae6d 100644 --- a/src/routes/v2/mod.rs +++ b/src/routes/v2/mod.rs @@ -29,7 +29,6 @@ pub fn config(cfg: &mut actix_web::web::ServiceConfig) { .configure(crate::auth::session::config) .configure(crate::auth::flows::config) .configure(crate::auth::pats::config) - .configure(crate::auth::oauth::config) .configure(moderation::config) .configure(notifications::config) .configure(organizations::config) diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index 530a01a6..d90429c2 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -1,5 +1,5 @@ pub use super::ApiError; -use crate::util::cors::default_cors; +use crate::{auth::oauth, util::cors::default_cors}; use actix_web::{web, HttpResponse}; use serde_json::json; @@ -9,7 +9,9 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("v3") .wrap(default_cors()) - .route("", web::get().to(hello_world)), + .route("", web::get().to(hello_world)) + .configure(oauth::config) + .configure(oauth_clients::config), ); } diff --git a/tests/common/api_v2/project.rs b/tests/common/api_v2/project.rs index d8f5f858..dd44692e 100644 --- a/tests/common/api_v2/project.rs +++ b/tests/common/api_v2/project.rs @@ -29,7 +29,7 @@ impl ApiV2 { .set_multipart(creation_data.segment_data) .to_request(); let resp = self.call(req).await; - assert_status(resp, StatusCode::OK); + assert_status(&resp, StatusCode::OK); // Approve as a moderator. let req = TestRequest::patch() @@ -42,7 +42,7 @@ impl ApiV2 { )) .to_request(); let resp = self.call(req).await; - assert_status(resp, StatusCode::NO_CONTENT); + assert_status(&resp, StatusCode::NO_CONTENT); let project = self .get_project_deserialized(&creation_data.slug, pat) diff --git a/tests/common/api_v2/team.rs b/tests/common/api_v2/team.rs index f1d6ef73..a4209ece 100644 --- a/tests/common/api_v2/team.rs +++ b/tests/common/api_v2/team.rs @@ -1,3 +1,4 @@ +use actix_http::StatusCode; use actix_web::{dev::ServiceResponse, test}; use labrinth::models::{ notifications::Notification, @@ -5,6 +6,8 @@ use labrinth::models::{ }; use serde_json::json; +use crate::common::asserts::assert_status; + use super::ApiV2; impl ApiV2 { @@ -124,6 +127,7 @@ impl ApiV2 { .append_header(("Authorization", pat)) .to_request(); let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); test::read_body_json(resp).await } diff --git a/tests/common/api_v3/mod.rs b/tests/common/api_v3/mod.rs new file mode 100644 index 00000000..68a23197 --- /dev/null +++ b/tests/common/api_v3/mod.rs @@ -0,0 +1,18 @@ +#![allow(dead_code)] + +use super::environment::LocalService; +use actix_web::dev::ServiceResponse; +use std::rc::Rc; + +pub mod oauth_clients; + +#[derive(Clone)] +pub struct ApiV3 { + pub test_app: Rc, +} + +impl ApiV3 { + pub async fn call(&self, req: actix_http::Request) -> ServiceResponse { + self.test_app.call(req).await.unwrap() + } +} diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs new file mode 100644 index 00000000..aff9c1eb --- /dev/null +++ b/tests/common/api_v3/oauth_clients.rs @@ -0,0 +1,35 @@ +use actix_http::StatusCode; +use actix_web::test::{self, TestRequest}; +use labrinth::models::{oauth_clients::OAuthClientCreationResult, pats::Scopes}; +use reqwest::header::AUTHORIZATION; +use serde_json::json; + +use crate::common::asserts::assert_status; + +use super::ApiV3; + +impl ApiV3 { + pub async fn add_oauth_client( + &self, + name: String, + max_scopes: Scopes, + redirect_uris: Vec, + pat: &str, + ) -> OAuthClientCreationResult { + let max_scopes = max_scopes.bits(); + let req = TestRequest::post() + .uri("/v3/oauth_app") + .append_header((AUTHORIZATION, pat)) + .set_json(json!({ + "name": name, + "max_scopes": max_scopes, + "redirect_uris": redirect_uris + })) + .to_request(); + + let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); + + test::read_body_json(resp).await + } +} diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index c98dbd39..e15f177e 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -1,3 +1,3 @@ -pub fn assert_status(response: actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { +pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { assert_eq!(response.status(), status, "{:#?}", response.response()); } diff --git a/tests/common/environment.rs b/tests/common/environment.rs index abeaf730..08e72d8b 100644 --- a/tests/common/environment.rs +++ b/tests/common/environment.rs @@ -4,6 +4,7 @@ use std::{rc::Rc, sync::Arc}; use super::{ api_v2::ApiV2, + api_v3::ApiV3, asserts::assert_status, database::{TemporaryDatabase, FRIEND_USER_ID, USER_USER_PAT}, dummy_data, @@ -34,6 +35,7 @@ pub struct TestEnvironment { test_app: Rc, // Rc as it's not Send pub db: TemporaryDatabase, pub v2: ApiV2, + pub v3: ApiV3, pub dummy: Option>, } @@ -56,6 +58,9 @@ impl TestEnvironment { v2: ApiV2 { test_app: test_app.clone(), }, + v3: ApiV3 { + test_app: test_app.clone(), + }, test_app, db, dummy: None, @@ -81,7 +86,7 @@ impl TestEnvironment { USER_USER_PAT, ) .await; - assert_status(resp, StatusCode::NO_CONTENT); + assert_status(&resp, StatusCode::NO_CONTENT); } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 39b3305a..5806c5db 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -6,6 +6,7 @@ use self::database::TemporaryDatabase; pub mod actix; pub mod api_v2; +pub mod api_v3; pub mod asserts; pub mod database; pub mod dummy_data; diff --git a/tests/oauth.rs b/tests/oauth.rs new file mode 100644 index 00000000..aea2616f --- /dev/null +++ b/tests/oauth.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; + +use actix_http::StatusCode; +use actix_web::test::{self, TestRequest}; +use common::{ + asserts::assert_status, + database::{FRIEND_USER_PAT, USER_USER_PAT}, + environment::with_test_environment, +}; +use labrinth::{ + auth::oauth::{AcceptOAuthClientScopes, OAuthClientAccessRequest, TokenRequest, TokenResponse}, + models::pats::Scopes, +}; +use reqwest::header::{AUTHORIZATION, CACHE_CONTROL, LOCATION, PRAGMA}; + +use crate::common::database::FRIEND_USER_ID; + +mod common; + +#[actix_rt::test] +async fn oauth_flow() { + with_test_environment(|env| async move { + let base_redirect_uri = "https://modrinth.com/test".to_string(); + let client_creation = env + .v3 + .add_oauth_client( + "test_client".to_string(), + Scopes::all(), + vec![base_redirect_uri.clone()], + USER_USER_PAT, + ) + .await; + let client_id_str = serde_json::to_value(client_creation.client.id) + .unwrap() + .as_str() + .unwrap() + .to_string(); + + // Initiate authorization + let redirect_uri = format!("{}?foo=bar", base_redirect_uri); + let original_state = "1234"; + let uri = &format!( + "/v3/auth/oauth/authorize?client_id={}&redirect_uri={}\ + &scope={}&state={original_state}", + urlencoding::encode(&client_id_str), + urlencoding::encode(&redirect_uri), + urlencoding::encode("USER_READ NOTIFICATION_READ") + ) + .to_string(); + let req = TestRequest::get() + .uri(&uri) + .append_header((AUTHORIZATION, FRIEND_USER_PAT)) + .to_request(); + let resp = env.call(req).await; + assert_status(&resp, StatusCode::OK); + let access_request: OAuthClientAccessRequest = test::read_body_json(resp).await; + + // Accept the authorization request + let resp = env + .call( + TestRequest::post() + .uri("/v3/auth/oauth/accept") + .append_header((AUTHORIZATION, FRIEND_USER_PAT)) + .set_json(AcceptOAuthClientScopes { + flow: access_request.flow_id, + }) + .to_request(), + ) + .await; + assert_status(&resp, StatusCode::FOUND); + let redirect_location = resp.headers().get(LOCATION).unwrap().to_str().unwrap(); + let query = actix_web::web::Query::>::from_query( + redirect_location.split_once('?').unwrap().1, + ) + .unwrap(); + + println!("redirect location: {}, {:#?}", redirect_location, query.0); + let auth_code = query.get("code").unwrap(); + let state = query.get("state").unwrap(); + let foo = query.get("foo").unwrap(); + assert_eq!(state, original_state); + assert_eq!(foo, "bar"); + + // Get the token + let resp = env + .call( + TestRequest::post() + .uri("/v3/auth/oauth/token") + .append_header((AUTHORIZATION, client_creation.client_secret)) + .set_form(TokenRequest { + grant_type: "authorization_code".to_string(), + code: auth_code.to_string(), + redirect_uri: Some(redirect_uri.clone()), + client_id: client_id_str.to_string(), + }) + .to_request(), + ) + .await; + assert_status(&resp, StatusCode::OK); + assert_eq!(resp.headers().get(CACHE_CONTROL).unwrap(), "no-store"); + assert_eq!(resp.headers().get(PRAGMA).unwrap(), "no-cache"); + let token_resp: TokenResponse = test::read_body_json(resp).await; + + // Validate that the token works + env.v2 + .get_user_notifications_deserialized(FRIEND_USER_ID, &token_resp.access_token) + .await; + }) + .await; +} From 16e7eb21be77909d8d88e2fdd4d84d5a5b449d58 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Thu, 19 Oct 2023 17:46:06 -0500 Subject: [PATCH 16/72] Add dummy data oauth client (and detect/handle dummy data version changes) --- tests/common/api_v3/mod.rs | 1 + tests/common/api_v3/oauth.rs | 75 ++++++++++++ tests/common/api_v3/oauth_clients.rs | 16 ++- tests/common/database.rs | 51 ++++++-- tests/common/dummy_data.rs | 172 +++++++++++++++++---------- tests/common/mod.rs | 8 ++ tests/files/dummy_data.sql | 19 +++ tests/oauth.rs | 112 +++++++---------- 8 files changed, 313 insertions(+), 141 deletions(-) create mode 100644 tests/common/api_v3/oauth.rs diff --git a/tests/common/api_v3/mod.rs b/tests/common/api_v3/mod.rs index 68a23197..2155aa3c 100644 --- a/tests/common/api_v3/mod.rs +++ b/tests/common/api_v3/mod.rs @@ -4,6 +4,7 @@ use super::environment::LocalService; use actix_web::dev::ServiceResponse; use std::rc::Rc; +pub mod oauth; pub mod oauth_clients; #[derive(Clone)] diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs new file mode 100644 index 00000000..92b44892 --- /dev/null +++ b/tests/common/api_v3/oauth.rs @@ -0,0 +1,75 @@ +use actix_web::{dev::ServiceResponse, test::TestRequest}; +use labrinth::auth::oauth::{AcceptOAuthClientScopes, TokenRequest}; +use reqwest::header::AUTHORIZATION; + +use super::ApiV3; + +impl ApiV3 { + pub async fn oauth_authorize( + &self, + client_id: &str, + scope: &str, + redirect_uri: &str, + state: &str, + pat: &str, + ) -> ServiceResponse { + let uri = generate_authorize_uri(client_id, scope, redirect_uri, state); + let req = TestRequest::get() + .uri(&uri) + .append_header((AUTHORIZATION, pat)) + .to_request(); + self.call(req).await + } + + pub async fn oauth_accept(&self, flow: &str, pat: &str) -> ServiceResponse { + self.call( + TestRequest::post() + .uri("/v3/auth/oauth/accept") + .append_header((AUTHORIZATION, pat)) + .set_json(AcceptOAuthClientScopes { + flow: flow.to_string(), + }) + .to_request(), + ) + .await + } + + pub async fn oauth_token( + &self, + auth_code: String, + original_redirect_uri: Option, + client_id: String, + client_secret: &str, + ) -> ServiceResponse { + self.call( + TestRequest::post() + .uri("/v3/auth/oauth/token") + .append_header((AUTHORIZATION, client_secret)) + .set_form(TokenRequest { + grant_type: "authorization_code".to_string(), + code: auth_code, + redirect_uri: original_redirect_uri, + client_id, + }) + .to_request(), + ) + .await + } +} + +pub fn generate_authorize_uri( + client_id: &str, + scope: &str, + redirect_uri: &str, + state: &str, +) -> String { + format!( + "/v3/auth/oauth/authorize?client_id={}&redirect_uri={}\ + &scope={}&state={}", + urlencoding::encode(client_id), + urlencoding::encode(redirect_uri), + urlencoding::encode(scope), + urlencoding::encode(state), + ) + .to_string() +} diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index aff9c1eb..291bd160 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -1,6 +1,9 @@ use actix_http::StatusCode; use actix_web::test::{self, TestRequest}; -use labrinth::models::{oauth_clients::OAuthClientCreationResult, pats::Scopes}; +use labrinth::models::{ + oauth_clients::{OAuthClient, OAuthClientCreationResult}, + pats::Scopes, +}; use reqwest::header::AUTHORIZATION; use serde_json::json; @@ -32,4 +35,15 @@ impl ApiV3 { test::read_body_json(resp).await } + + pub async fn get_user_oauth_clients(&self, user_id: &str, pat: &str) -> Vec { + let req = TestRequest::get() + .uri(&format!("/v3/user/{}/oauth_apps", user_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); + + test::read_body_json(resp).await + } } diff --git a/tests/common/database.rs b/tests/common/database.rs index e30bd077..1071c88a 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -7,6 +7,8 @@ use url::Url; use crate::common::{dummy_data, environment::TestEnvironment}; +use super::dummy_data::DUMMY_DATA_UPDATE; + // The dummy test database adds a fair bit of 'dummy' data to test with. // Some constants are used to refer to that data, and are described here. // The rest can be accessed in the TestEnvironment 'dummy' field. @@ -119,11 +121,7 @@ impl TemporaryDatabase { .await .unwrap(); if db_exists.is_none() { - let create_db_query = format!("CREATE DATABASE {TEMPLATE_DATABASE_NAME}"); - sqlx::query(&create_db_query) - .execute(&main_pool) - .await - .expect("Database creation failed"); + create_template_database(&main_pool).await; } // Switch to template @@ -135,22 +133,45 @@ impl TemporaryDatabase { .await .expect("Connection to database failed"); - // Run migrations on the template - let migrations = sqlx::migrate!("./migrations"); - migrations.run(&pool).await.expect("Migrations failed"); - // Check if dummy data exists- a fake 'dummy_data' table is created if it does - let dummy_data_exists: bool = + let mut dummy_data_exists: bool = sqlx::query_scalar("SELECT to_regclass('dummy_data') IS NOT NULL") .fetch_one(&pool) .await .unwrap(); + if dummy_data_exists { + // Check if the dummy data needs to be updated + let dummy_data_update = + sqlx::query_scalar::<_, i64>("SELECT update_id FROM dummy_data") + .fetch_optional(&pool) + .await + .unwrap(); + let needs_update = !dummy_data_update.is_some_and(|d| d == DUMMY_DATA_UPDATE); + if needs_update { + println!("Dummy data updated, so template DB tables will be dropped and re-created"); + // Drop all tables in the database so they can be re-created and later filled with updated dummy data + sqlx::query!("DROP SCHEMA public CASCADE;") + .execute(&pool) + .await + .unwrap(); + sqlx::query!("CREATE SCHEMA public;") + .execute(&pool) + .await + .unwrap(); + dummy_data_exists = false; + } + } + + // Run migrations on the template + let migrations = sqlx::migrate!("./migrations"); + migrations.run(&pool).await.expect("Migrations failed"); + if !dummy_data_exists { // Add dummy data let temporary_test_env = TestEnvironment::build_with_db(TemporaryDatabase { pool: pool.clone(), database_name: TEMPLATE_DATABASE_NAME.to_string(), - redis_pool: RedisPool::new(None), + redis_pool: RedisPool::new(Some(generate_random_name("test_template_"))), }) .await; dummy_data::add_dummy_data(&temporary_test_env).await; @@ -216,6 +237,14 @@ impl TemporaryDatabase { } } +async fn create_template_database(pool: &sqlx::Pool) { + let create_db_query = format!("CREATE DATABASE {TEMPLATE_DATABASE_NAME}"); + sqlx::query(&create_db_query) + .execute(pool) + .await + .expect("Database creation failed"); +} + // Appends a random 8-digit number to the end of the str pub fn generate_random_name(str: &str) -> String { let mut str = String::from(str); diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs index 7f7e10a4..2ed669df 100644 --- a/tests/common/dummy_data.rs +++ b/tests/common/dummy_data.rs @@ -2,7 +2,9 @@ use actix_web::test::{self, TestRequest}; use labrinth::{ models::projects::Project, - models::{organizations::Organization, pats::Scopes, projects::Version}, + models::{ + oauth_clients::OAuthClient, organizations::Organization, pats::Scopes, projects::Version, + }, }; use serde_json::json; use sqlx::Executor; @@ -11,11 +13,13 @@ use crate::common::{actix::AppendsMultipart, database::USER_USER_PAT}; use super::{ actix::{MultipartSegment, MultipartSegmentData}, + database::USER_USER_ID, environment::TestEnvironment, + get_json_val_str, request_data::get_public_project_creation_data, }; -pub const DUMMY_DATA_UPDATE: i64 = 1; +pub const DUMMY_DATA_UPDATE: i64 = 2; #[allow(dead_code)] pub const DUMMY_CATEGORIES: &[&str] = &[ @@ -28,6 +32,8 @@ pub const DUMMY_CATEGORIES: &[&str] = &[ "optimization", ]; +pub const DUMMY_OAUTH_CLIENT_ALPHA_SECRET: &str = "abcdefghijklmnopqrstuvwxyz"; + #[allow(dead_code)] pub enum DummyJarFile { DummyProjectAlpha, @@ -43,16 +49,80 @@ pub enum DummyImage { #[derive(Clone)] pub struct DummyData { + /// Alpha project: + /// This is a dummy project created by USER user. + /// It's approved, listed, and visible to the public. pub project_alpha: DummyProjectAlpha, + + /// Beta project: + /// This is a dummy project created by USER user. + /// It's not approved, unlisted, and not visible to the public. pub project_beta: DummyProjectBeta, + + /// Zeta organization: + /// This is a dummy organization created by USER user. + /// There are no projects in it. pub organization_zeta: DummyOrganizationZeta, + + /// Alpha OAuth Client: + /// This is a dummy OAuth client created by USER user. + /// + /// All scopes are included in its max scopes + /// + /// It has one valid redirect URI + pub oauth_client_alpha: DummyOAuthClientAlpha, +} + +impl DummyData { + pub fn new( + project_alpha: Project, + project_alpha_version: Version, + project_beta: Project, + project_beta_version: Version, + organization_zeta: Organization, + oauth_client_alpha: OAuthClient, + ) -> Self { + DummyData { + project_alpha: DummyProjectAlpha { + team_id: project_alpha.team.to_string(), + project_id: project_alpha.id.to_string(), + project_slug: project_alpha.slug.unwrap(), + version_id: project_alpha_version.id.to_string(), + thread_id: project_alpha.thread_id.to_string(), + file_hash: project_alpha_version.files[0].hashes["sha1"].clone(), + }, + + project_beta: DummyProjectBeta { + team_id: project_beta.team.to_string(), + project_id: project_beta.id.to_string(), + project_slug: project_beta.slug.unwrap(), + version_id: project_beta_version.id.to_string(), + thread_id: project_beta.thread_id.to_string(), + file_hash: project_beta_version.files[0].hashes["sha1"].clone(), + }, + + organization_zeta: DummyOrganizationZeta { + organization_id: organization_zeta.id.to_string(), + team_id: organization_zeta.team_id.to_string(), + organization_title: organization_zeta.title, + }, + + oauth_client_alpha: DummyOAuthClientAlpha { + client_id: get_json_val_str(oauth_client_alpha.id), + client_secret: DUMMY_OAUTH_CLIENT_ALPHA_SECRET.to_string(), + valid_redirect_uri: oauth_client_alpha + .redirect_uris + .first() + .unwrap() + .uri + .clone(), + }, + } + } } #[derive(Clone)] pub struct DummyProjectAlpha { - // Alpha project: - // This is a dummy project created by USER user. - // It's approved, listed, and visible to the public. pub project_id: String, pub project_slug: String, pub version_id: String, @@ -63,9 +133,6 @@ pub struct DummyProjectAlpha { #[derive(Clone)] pub struct DummyProjectBeta { - // Beta project: - // This is a dummy project created by USER user. - // It's not approved, unlisted, and not visible to the public. pub project_id: String, pub project_slug: String, pub version_id: String, @@ -76,14 +143,18 @@ pub struct DummyProjectBeta { #[derive(Clone)] pub struct DummyOrganizationZeta { - // Zeta organization: - // This is a dummy organization created by USER user. - // There are no projects in it. pub organization_id: String, pub organization_title: String, pub team_id: String, } +#[derive(Clone)] +pub struct DummyOAuthClientAlpha { + pub client_id: String, + pub client_secret: String, + pub valid_redirect_uri: String, +} + pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData { // Adds basic dummy data to the database directly with sql (user, pats) let pool = &test_env.db.pool.clone(); @@ -101,37 +172,22 @@ pub async fn add_dummy_data(test_env: &TestEnvironment) -> DummyData { let zeta_organization = add_organization_zeta(test_env).await; + let oauth_client_alpha = get_oauth_client_alpha(test_env).await; + sqlx::query("INSERT INTO dummy_data (update_id) VALUES ($1)") .bind(DUMMY_DATA_UPDATE) .execute(pool) .await .unwrap(); - DummyData { - project_alpha: DummyProjectAlpha { - team_id: alpha_project.team.to_string(), - project_id: alpha_project.id.to_string(), - project_slug: alpha_project.slug.unwrap(), - version_id: alpha_version.id.to_string(), - thread_id: alpha_project.thread_id.to_string(), - file_hash: alpha_version.files[0].hashes["sha1"].clone(), - }, - - project_beta: DummyProjectBeta { - team_id: beta_project.team.to_string(), - project_id: beta_project.id.to_string(), - project_slug: beta_project.slug.unwrap(), - version_id: beta_version.id.to_string(), - thread_id: beta_project.thread_id.to_string(), - file_hash: beta_version.files[0].hashes["sha1"].clone(), - }, - - organization_zeta: DummyOrganizationZeta { - organization_id: zeta_organization.id.to_string(), - team_id: zeta_organization.team_id.to_string(), - organization_title: zeta_organization.title, - }, - } + DummyData::new( + alpha_project, + alpha_version, + beta_project, + beta_version, + zeta_organization, + oauth_client_alpha, + ) } pub async fn get_dummy_data(test_env: &TestEnvironment) -> DummyData { @@ -139,31 +195,17 @@ pub async fn get_dummy_data(test_env: &TestEnvironment) -> DummyData { let (beta_project, beta_version) = get_project_beta(test_env).await; let zeta_organization = get_organization_zeta(test_env).await; - DummyData { - project_alpha: DummyProjectAlpha { - team_id: alpha_project.team.to_string(), - project_id: alpha_project.id.to_string(), - project_slug: alpha_project.slug.unwrap(), - version_id: alpha_version.id.to_string(), - thread_id: alpha_project.thread_id.to_string(), - file_hash: alpha_version.files[0].hashes["sha1"].clone(), - }, - - project_beta: DummyProjectBeta { - team_id: beta_project.team.to_string(), - project_id: beta_project.id.to_string(), - project_slug: beta_project.slug.unwrap(), - version_id: beta_version.id.to_string(), - thread_id: beta_project.thread_id.to_string(), - file_hash: beta_version.files[0].hashes["sha1"].clone(), - }, - - organization_zeta: DummyOrganizationZeta { - organization_id: zeta_organization.id.to_string(), - team_id: zeta_organization.team_id.to_string(), - organization_title: zeta_organization.title, - }, - } + + let oauth_client_alpha = get_oauth_client_alpha(test_env).await; + + DummyData::new( + alpha_project, + alpha_version, + beta_project, + beta_version, + zeta_organization, + oauth_client_alpha, + ) } pub async fn add_project_alpha(test_env: &TestEnvironment) -> (Project, Version) { @@ -308,6 +350,14 @@ pub async fn get_organization_zeta(test_env: &TestEnvironment) -> Organization { organization } +pub async fn get_oauth_client_alpha(test_env: &TestEnvironment) -> OAuthClient { + let oauth_clients = test_env + .v3 + .get_user_oauth_clients(USER_USER_ID, USER_USER_PAT) + .await; + oauth_clients.into_iter().next().unwrap() +} + impl DummyJarFile { pub fn filename(&self) -> String { match self { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 5806c5db..b2a317bb 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -43,3 +43,11 @@ pub async fn setup(db: &TemporaryDatabase) -> LabrinthConfig { maxmind_reader.clone(), ) } + +pub fn get_json_val_str(val: impl serde::Serialize) -> String { + serde_json::to_value(val) + .unwrap() + .as_str() + .unwrap() + .to_string() +} diff --git a/tests/files/dummy_data.sql b/tests/files/dummy_data.sql index 487397a5..aaa8c1b7 100644 --- a/tests/files/dummy_data.sql +++ b/tests/files/dummy_data.sql @@ -45,6 +45,25 @@ INSERT INTO categories (id, category, project_type) VALUES (106, 'mobs', 2), (107, 'optimization', 2); +-- Create dummy oauth client, secret_hash is SHA512 hash of full lowercase alphabet +INSERT INTO oauth_clients ( + id, + name, + icon_url, + max_scopes, + secret_hash, + created_by + ) +VALUES ( + 1, + 'oauth_client_alpha', + NULL, + $1, + '4dbff86cc2ca1bae1e16468a05cb9881c97f1753bce3619034898faa1aabe429955a1bf8ec483d7421fe3c1646613a59ed5441fb0f321389f77f48a879c7b1f1', + 3 + ); +INSERT INTO oauth_client_redirect_uris (id, client_id, uri) VALUES (1, 1, 'https://modrinth.com/oauth_callback'); + -- Create dummy data table to mark that this file has been run CREATE TABLE dummy_data ( update_id bigint PRIMARY KEY diff --git a/tests/oauth.rs b/tests/oauth.rs index aea2616f..43081a2c 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -1,80 +1,51 @@ -use std::collections::HashMap; - +use crate::common::{database::FRIEND_USER_ID, dummy_data::DummyOAuthClientAlpha}; use actix_http::StatusCode; -use actix_web::test::{self, TestRequest}; -use common::{ - asserts::assert_status, - database::{FRIEND_USER_PAT, USER_USER_PAT}, - environment::with_test_environment, +use actix_web::{ + dev::ServiceResponse, + test::{self}, }; -use labrinth::{ - auth::oauth::{AcceptOAuthClientScopes, OAuthClientAccessRequest, TokenRequest, TokenResponse}, - models::pats::Scopes, +use common::{ + asserts::assert_status, database::FRIEND_USER_PAT, environment::with_test_environment, }; -use reqwest::header::{AUTHORIZATION, CACHE_CONTROL, LOCATION, PRAGMA}; - -use crate::common::database::FRIEND_USER_ID; +use labrinth::auth::oauth::{OAuthClientAccessRequest, TokenResponse}; +use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA}; +use std::collections::HashMap; mod common; #[actix_rt::test] -async fn oauth_flow() { +async fn oauth_flow_happy_path() { with_test_environment(|env| async move { - let base_redirect_uri = "https://modrinth.com/test".to_string(); - let client_creation = env - .v3 - .add_oauth_client( - "test_client".to_string(), - Scopes::all(), - vec![base_redirect_uri.clone()], - USER_USER_PAT, - ) - .await; - let client_id_str = serde_json::to_value(client_creation.client.id) - .unwrap() - .as_str() - .unwrap() - .to_string(); + let DummyOAuthClientAlpha { + valid_redirect_uri: base_redirect_uri, + client_id, + client_secret, + } = env.dummy.unwrap().oauth_client_alpha.clone(); // Initiate authorization let redirect_uri = format!("{}?foo=bar", base_redirect_uri); let original_state = "1234"; - let uri = &format!( - "/v3/auth/oauth/authorize?client_id={}&redirect_uri={}\ - &scope={}&state={original_state}", - urlencoding::encode(&client_id_str), - urlencoding::encode(&redirect_uri), - urlencoding::encode("USER_READ NOTIFICATION_READ") - ) - .to_string(); - let req = TestRequest::get() - .uri(&uri) - .append_header((AUTHORIZATION, FRIEND_USER_PAT)) - .to_request(); - let resp = env.call(req).await; + let resp = env + .v3 + .oauth_authorize( + &client_id, + "USER_READ NOTIFICATION_READ", + &redirect_uri, + original_state, + FRIEND_USER_PAT, + ) + .await; assert_status(&resp, StatusCode::OK); let access_request: OAuthClientAccessRequest = test::read_body_json(resp).await; // Accept the authorization request let resp = env - .call( - TestRequest::post() - .uri("/v3/auth/oauth/accept") - .append_header((AUTHORIZATION, FRIEND_USER_PAT)) - .set_json(AcceptOAuthClientScopes { - flow: access_request.flow_id, - }) - .to_request(), - ) + .v3 + .oauth_accept(&access_request.flow_id, FRIEND_USER_PAT) .await; assert_status(&resp, StatusCode::FOUND); - let redirect_location = resp.headers().get(LOCATION).unwrap().to_str().unwrap(); - let query = actix_web::web::Query::>::from_query( - redirect_location.split_once('?').unwrap().1, - ) - .unwrap(); + let query = get_redirect_location_query_params(&resp); - println!("redirect location: {}, {:#?}", redirect_location, query.0); let auth_code = query.get("code").unwrap(); let state = query.get("state").unwrap(); let foo = query.get("foo").unwrap(); @@ -83,17 +54,12 @@ async fn oauth_flow() { // Get the token let resp = env - .call( - TestRequest::post() - .uri("/v3/auth/oauth/token") - .append_header((AUTHORIZATION, client_creation.client_secret)) - .set_form(TokenRequest { - grant_type: "authorization_code".to_string(), - code: auth_code.to_string(), - redirect_uri: Some(redirect_uri.clone()), - client_id: client_id_str.to_string(), - }) - .to_request(), + .v3 + .oauth_token( + auth_code.to_string(), + Some(redirect_uri.clone()), + client_id.to_string(), + &client_secret, ) .await; assert_status(&resp, StatusCode::OK); @@ -108,3 +74,13 @@ async fn oauth_flow() { }) .await; } + +fn get_redirect_location_query_params( + response: &ServiceResponse, +) -> actix_web::web::Query> { + let redirect_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); + actix_web::web::Query::>::from_query( + redirect_location.split_once('?').unwrap().1, + ) + .unwrap() +} From 1c1636deb52d0ae490cb48b58259464ddc4e8c3c Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Thu, 19 Oct 2023 18:23:47 -0500 Subject: [PATCH 17/72] More tests --- tests/common/api_v3/oauth.rs | 21 ++++++- tests/oauth.rs | 108 +++++++++++++++++++++++++++++++++-- 2 files changed, 121 insertions(+), 8 deletions(-) diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index 92b44892..2dea1422 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -1,5 +1,8 @@ -use actix_web::{dev::ServiceResponse, test::TestRequest}; -use labrinth::auth::oauth::{AcceptOAuthClientScopes, TokenRequest}; +use actix_web::{ + dev::ServiceResponse, + test::{self, TestRequest}, +}; +use labrinth::auth::oauth::{AcceptOAuthClientScopes, TokenRequest, TokenResponse}; use reqwest::header::AUTHORIZATION; use super::ApiV3; @@ -55,6 +58,20 @@ impl ApiV3 { ) .await } + + pub async fn get_oauth_access_token( + &self, + auth_code: String, + original_redirect_uri: Option, + client_id: String, + client_secret: &str, + ) -> String { + let response = self + .oauth_token(auth_code, original_redirect_uri, client_id, client_secret) + .await; + let token_resp: TokenResponse = test::read_body_json(response).await; + token_resp.access_token + } } pub fn generate_authorize_uri( diff --git a/tests/oauth.rs b/tests/oauth.rs index 43081a2c..37128450 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -5,7 +5,9 @@ use actix_web::{ test::{self}, }; use common::{ - asserts::assert_status, database::FRIEND_USER_PAT, environment::with_test_environment, + asserts::assert_status, + database::{FRIEND_USER_PAT, USER_USER_PAT}, + environment::with_test_environment, }; use labrinth::auth::oauth::{OAuthClientAccessRequest, TokenResponse}; use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA}; @@ -36,13 +38,10 @@ async fn oauth_flow_happy_path() { ) .await; assert_status(&resp, StatusCode::OK); - let access_request: OAuthClientAccessRequest = test::read_body_json(resp).await; + let flow_id = get_authorize_accept_flow_id(resp).await; // Accept the authorization request - let resp = env - .v3 - .oauth_accept(&access_request.flow_id, FRIEND_USER_PAT) - .await; + let resp = env.v3.oauth_accept(&flow_id, FRIEND_USER_PAT).await; assert_status(&resp, StatusCode::FOUND); let query = get_redirect_location_query_params(&resp); @@ -75,6 +74,103 @@ async fn oauth_flow_happy_path() { .await; } +#[actix_rt::test] +async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { + with_test_environment(|env| async { + let DummyOAuthClientAlpha { + valid_redirect_uri, + client_id, + .. + } = env.dummy.unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .oauth_authorize( + &client_id, + "USER_READ NOTIFICATION_READ", + &valid_redirect_uri, + "1234", + USER_USER_PAT, + ) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + + let resp = env + .v3 + .oauth_authorize( + &client_id, + "USER_READ", + &valid_redirect_uri, + "5678", + USER_USER_PAT, + ) + .await; + assert_status(&resp, StatusCode::FOUND); + }) + .await; +} + +#[actix_rt::test] +async fn get_oauth_token_with_already_used_auth_code_fails() { + with_test_environment(|env| async { + let DummyOAuthClientAlpha { + valid_redirect_uri, + client_id, + client_secret, + } = env.dummy.unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .oauth_authorize( + &client_id, + "USER_READ", + &valid_redirect_uri, + "1234", + USER_USER_PAT, + ) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + let auth_code = get_auth_code_from_redirect_params(&resp).await; + + let resp = env + .v3 + .oauth_token( + auth_code.clone(), + Some(valid_redirect_uri.clone()), + client_id.clone(), + &client_secret, + ) + .await; + assert_status(&resp, StatusCode::OK); + + let resp = env + .v3 + .oauth_token( + auth_code, + Some(valid_redirect_uri), + client_id, + &client_secret, + ) + .await; + assert_status(&resp, StatusCode::BAD_REQUEST); + }) + .await; +} + +async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { + test::read_body_json::(response) + .await + .flow_id +} + +async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { + let query_params = get_redirect_location_query_params(response); + query_params.get("code").unwrap().to_string() +} + fn get_redirect_location_query_params( response: &ServiceResponse, ) -> actix_web::web::Query> { From 32213813da359aede3a2bf33be51f25bd1b8a9aa Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Thu, 19 Oct 2023 19:45:57 -0500 Subject: [PATCH 18/72] Another test --- tests/common/api_v3/oauth.rs | 12 +++++++---- tests/oauth.rs | 42 ++++++++++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index 2dea1422..6d06a1d3 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -13,7 +13,7 @@ impl ApiV3 { client_id: &str, scope: &str, redirect_uri: &str, - state: &str, + state: Option<&str>, pat: &str, ) -> ServiceResponse { let uri = generate_authorize_uri(client_id, scope, redirect_uri, state); @@ -78,15 +78,19 @@ pub fn generate_authorize_uri( client_id: &str, scope: &str, redirect_uri: &str, - state: &str, + state: Option<&str>, ) -> String { format!( "/v3/auth/oauth/authorize?client_id={}&redirect_uri={}\ - &scope={}&state={}", + &scope={}{}", urlencoding::encode(client_id), urlencoding::encode(redirect_uri), urlencoding::encode(scope), - urlencoding::encode(state), + if let Some(state) = state { + urlencoding::encode(state).to_string() + } else { + "".to_string() + }, ) .to_string() } diff --git a/tests/oauth.rs b/tests/oauth.rs index 37128450..e4ede7fd 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -33,7 +33,7 @@ async fn oauth_flow_happy_path() { &client_id, "USER_READ NOTIFICATION_READ", &redirect_uri, - original_state, + Some(original_state), FRIEND_USER_PAT, ) .await; @@ -89,7 +89,7 @@ async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { &client_id, "USER_READ NOTIFICATION_READ", &valid_redirect_uri, - "1234", + Some("1234"), USER_USER_PAT, ) .await; @@ -102,7 +102,7 @@ async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { &client_id, "USER_READ", &valid_redirect_uri, - "5678", + Some("5678"), USER_USER_PAT, ) .await; @@ -126,7 +126,7 @@ async fn get_oauth_token_with_already_used_auth_code_fails() { &client_id, "USER_READ", &valid_redirect_uri, - "1234", + None, USER_USER_PAT, ) .await; @@ -160,6 +160,40 @@ async fn get_oauth_token_with_already_used_auth_code_fails() { .await; } +#[actix_rt::test] +async fn oauth_authorize_with_broader_scopes_requires_user_accept() { + with_test_environment(|env| async { + let DummyOAuthClientAlpha { + valid_redirect_uri: redirect_uri, + client_id, + .. + } = env.dummy.unwrap().oauth_client_alpha.clone(); + let resp = env + .v3 + .oauth_authorize(&client_id, "USER_READ", &redirect_uri, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + + let resp = env + .v3 + .oauth_authorize( + &client_id, + "USER_READ NOTIFICATION_READ", + &redirect_uri, + None, + USER_USER_PAT, + ) + .await; + + assert_status(&resp, StatusCode::OK); + get_authorize_accept_flow_id(resp).await; // ensure we can deser this without error to really confirm + }) + .await; +} + +//TODO: What about rejecting??? + async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { test::read_body_json::(response) .await From 43a64adea4392ef2366fd929fe65d1823d6ea779 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 12:56:34 -0500 Subject: [PATCH 19/72] More tests and reject endpoint --- src/auth/oauth/errors.rs | 6 +- src/auth/oauth/mod.rs | 167 ++++++++++++--------- src/database/models/flow_item.rs | 16 ++ src/routes/v3/oauth_clients.rs | 4 +- tests/common/api_v2/project.rs | 14 +- tests/common/api_v2/team.rs | 14 +- tests/common/api_v3/oauth.rs | 100 +++++++++++-- tests/common/api_v3/oauth_clients.rs | 51 ++++++- tests/common/asserts.rs | 7 + tests/common/environment.rs | 20 +++ tests/oauth.rs | 214 ++++++++++++++++++--------- 11 files changed, 445 insertions(+), 168 deletions(-) diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs index 7ba42bdd..c8cafd8c 100644 --- a/src/auth/oauth/errors.rs +++ b/src/auth/oauth/errors.rs @@ -51,7 +51,8 @@ impl actix_web::ResponseError for OAuthError { match self.error_type { OAuthErrorType::AuthenticationError(_) | OAuthErrorType::FailedScopeParse(_) - | OAuthErrorType::ScopesTooBroad => { + | OAuthErrorType::ScopesTooBroad + | OAuthErrorType::AccessDenied => { if self.valid_redirect_uri.is_some() { StatusCode::FOUND } else { @@ -129,6 +130,8 @@ pub enum OAuthErrorType { RedirectUriChanged(Option), #[error("The provided grant type ({0}) must be \"authorization_code\"")] OnlySupportsAuthorizationCodeGrant(String), + #[error("The resource owner denied the request")] + AccessDenied, } impl From for OAuthErrorType { @@ -157,6 +160,7 @@ impl OAuthErrorType { Self::InvalidClientId(_) | Self::ClientAuthenticationFailed => "invalid_client", Self::InvalidAuthCode | Self::OnlySupportsAuthorizationCodeGrant(_) => "invalid_grant", Self::UnauthorizedClient => "unauthorized_client", + Self::AccessDenied => "access_denied", } .to_string() } diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index bbffb1d9..4e651c28 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -33,6 +33,7 @@ pub fn config(cfg: &mut ServiceConfig) { scope("auth/oauth") .service(init_oauth) .service(accept_client_scopes) + .service(reject_client_scopes) .service(request_token), ); } @@ -160,74 +161,30 @@ pub async fn init_oauth( } #[derive(Serialize, Deserialize)] -pub struct AcceptOAuthClientScopes { +pub struct RespondToOAuthClientScopes { pub flow: String, } #[post("accept")] pub async fn accept_client_scopes( req: HttpRequest, - accept_body: web::Json, + accept_body: web::Json, pool: Data, redis: Data, session_queue: Data, ) -> Result { - let current_user = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), - ) - .await - .map_err(|e| OAuthError::error(e))? - .1; - - let flow = Flow::get(&accept_body.flow, &redis) - .await - .map_err(|e| OAuthError::error(e))?; - if let Some(Flow::InitOAuthAppApproval { - user_id, - client_id, - existing_authorization_id, - scopes, - validated_redirect_uri, - original_redirect_uri, - state, - }) = flow - { - if current_user.id != user_id.into() { - return Err(OAuthError::error(AuthenticationError::InvalidCredentials)); - } - - let mut transaction = pool.begin().await.map_err(OAuthError::error)?; - - let auth_id = match existing_authorization_id { - Some(id) => id, - None => generate_oauth_client_authorization_id(&mut transaction) - .await - .map_err(OAuthError::error)?, - }; - OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction) - .await - .map_err(OAuthError::error)?; - - transaction.commit().await.map_err(OAuthError::error)?; + accept_or_reject_client_scopes(true, req, accept_body, pool, redis, session_queue).await +} - init_oauth_code_flow( - user_id, - client_id, - auth_id, - scopes, - validated_redirect_uri, - original_redirect_uri, - state, - &redis, - ) - .await - } else { - Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId)) - } +#[post("reject")] +pub async fn reject_client_scopes( + req: HttpRequest, + body: web::Json, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + accept_or_reject_client_scopes(false, req, body, pool, redis, session_queue).await } #[derive(Serialize, Deserialize)] @@ -264,9 +221,15 @@ pub async fn request_token( if let Some(client) = client { authenticate_client_token_request(&req, &client)?; - let flow = Flow::get(&req_params.code, &redis) - .await - .map_err(OAuthError::error)?; + // Ensure auth code is single use + // per IETF RFC6749 Section 10.5 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.5) + let flow = Flow::take_if( + &req_params.code, + |f| matches!(f, Flow::OAuthAuthorizationCodeSupplied { .. }), + &redis, + ) + .await + .map_err(OAuthError::error)?; if let Some(Flow::OAuthAuthorizationCodeSupplied { user_id, client_id, @@ -275,12 +238,6 @@ pub async fn request_token( original_redirect_uri, }) = flow { - // Ensure auth code is single use - // per IETF RFC6749 Section 10.5 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.5) - Flow::remove(&req_params.code, &redis) - .await - .map_err(OAuthError::error)?; - // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 if client_id != client_id { return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient)); @@ -342,6 +299,84 @@ pub async fn request_token( } } +pub async fn accept_or_reject_client_scopes( + accept: bool, + req: HttpRequest, + body: web::Json, + pool: Data, + redis: Data, + session_queue: Data, +) -> Result { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), + ) + .await + .map_err(|e| OAuthError::error(e))? + .1; + + let flow = Flow::take_if( + &body.flow, + |f| matches!(f, Flow::InitOAuthAppApproval { .. }), + &redis, + ) + .await + .map_err(OAuthError::error)?; + if let Some(Flow::InitOAuthAppApproval { + user_id, + client_id, + existing_authorization_id, + scopes, + validated_redirect_uri, + original_redirect_uri, + state, + }) = flow + { + if current_user.id != user_id.into() { + return Err(OAuthError::error(AuthenticationError::InvalidCredentials)); + } + + if accept { + let mut transaction = pool.begin().await.map_err(OAuthError::error)?; + + let auth_id = match existing_authorization_id { + Some(id) => id, + None => generate_oauth_client_authorization_id(&mut transaction) + .await + .map_err(OAuthError::error)?, + }; + OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction) + .await + .map_err(OAuthError::error)?; + + transaction.commit().await.map_err(OAuthError::error)?; + + init_oauth_code_flow( + user_id, + client_id, + auth_id, + scopes, + validated_redirect_uri, + original_redirect_uri, + state, + &redis, + ) + .await + } else { + Err(OAuthError::redirect( + OAuthErrorType::AccessDenied, + &state, + &validated_redirect_uri, + )) + } + } else { + Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId)) + } +} + fn authenticate_client_token_request( req: &HttpRequest, client: &DBOAuthClient, diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 291c0e16..1676f1fd 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -75,6 +75,22 @@ impl Flow { redis.get_deserialized_from_json(FLOWS_NAMESPACE, id).await } + /// Gets the flow and removes it from the cache, but only removes if the flow was present and the predicate returned true + /// The predicate should validate that the flow being removed is the correct one, as a security measure + pub async fn take_if( + id: &str, + predicate: impl FnOnce(&Flow) -> bool, + redis: &RedisPool, + ) -> Result, DatabaseError> { + let flow = Self::get(id, redis).await?; + if let Some(flow) = flow.as_ref() { + if predicate(flow) { + Self::remove(id, redis).await?; + } + } + Ok(flow) + } + pub async fn remove(id: &str, redis: &RedisPool) -> Result, DatabaseError> { redis.delete(FLOWS_NAMESPACE, id).await?; Ok(Some(())) diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index fde1fad2..2cbf12fd 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -9,7 +9,7 @@ use chrono::Utc; use itertools::Itertools; use rand::{distributions::Alphanumeric, Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use sqlx::PgPool; use validator::Validate; @@ -181,7 +181,7 @@ pub async fn oauth_client_delete<'a>( } } -#[derive(Deserialize, Validate)] +#[derive(Serialize, Deserialize, Validate)] pub struct OAuthClientEdit { #[validate( custom(function = "crate::util::validate::validate_name"), diff --git a/tests/common/api_v2/project.rs b/tests/common/api_v2/project.rs index dd44692e..7b3af132 100644 --- a/tests/common/api_v2/project.rs +++ b/tests/common/api_v2/project.rs @@ -82,16 +82,20 @@ impl ApiV2 { test::read_body_json(resp).await } + pub async fn get_user_projects(&self, user_id_or_username: &str, pat: &str) -> ServiceResponse { + let req = test::TestRequest::get() + .uri(&format!("/v2/user/{}/projects", user_id_or_username)) + .append_header(("Authorization", pat)) + .to_request(); + self.call(req).await + } + pub async fn get_user_projects_deserialized( &self, user_id_or_username: &str, pat: &str, ) -> Vec { - let req = test::TestRequest::get() - .uri(&format!("/v2/user/{}/projects", user_id_or_username)) - .append_header(("Authorization", pat)) - .to_request(); - let resp = self.call(req).await; + let resp = self.get_user_projects(user_id_or_username, pat).await; assert_eq!(resp.status(), 200); test::read_body_json(resp).await } diff --git a/tests/common/api_v2/team.rs b/tests/common/api_v2/team.rs index a4209ece..1a772053 100644 --- a/tests/common/api_v2/team.rs +++ b/tests/common/api_v2/team.rs @@ -117,16 +117,20 @@ impl ApiV2 { self.call(req).await } + pub async fn get_user_notifications(&self, user_id: &str, pat: &str) -> ServiceResponse { + let req = test::TestRequest::get() + .uri(&format!("/v2/user/{user_id}/notifications")) + .append_header(("Authorization", pat)) + .to_request(); + self.call(req).await + } + pub async fn get_user_notifications_deserialized( &self, user_id: &str, pat: &str, ) -> Vec { - let req = test::TestRequest::get() - .uri(&format!("/v2/user/{user_id}/notifications")) - .append_header(("Authorization", pat)) - .to_request(); - let resp = self.call(req).await; + let resp = self.get_user_notifications(user_id, pat).await; assert_status(&resp, StatusCode::OK); test::read_body_json(resp).await } diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index 6d06a1d3..e7c6e7dc 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -1,18 +1,43 @@ +use std::collections::HashMap; + use actix_web::{ dev::ServiceResponse, test::{self, TestRequest}, }; -use labrinth::auth::oauth::{AcceptOAuthClientScopes, TokenRequest, TokenResponse}; -use reqwest::header::AUTHORIZATION; +use labrinth::auth::oauth::{ + OAuthClientAccessRequest, RespondToOAuthClientScopes, TokenRequest, TokenResponse, +}; +use reqwest::header::{AUTHORIZATION, LOCATION}; use super::ApiV3; impl ApiV3 { + pub async fn complete_full_authorize_flow( + &self, + client_id: &str, + client_secret: &str, + scope: Option<&str>, + redirect_uri: Option<&str>, + state: Option<&str>, + user_pat: &str, + ) -> String { + let auth_resp = self + .oauth_authorize(&client_id, scope, redirect_uri, state, user_pat) + .await; + let flow_id = get_authorize_accept_flow_id(auth_resp).await; + let redirect_resp = self.oauth_accept(&flow_id, user_pat).await; + let auth_code = get_auth_code_from_redirect_params(&redirect_resp).await; + let token_resp = self + .oauth_token(auth_code, None, client_id.to_string(), &client_secret) + .await; + get_access_token(token_resp).await + } + pub async fn oauth_authorize( &self, client_id: &str, - scope: &str, - redirect_uri: &str, + scope: Option<&str>, + redirect_uri: Option<&str>, state: Option<&str>, pat: &str, ) -> ServiceResponse { @@ -29,7 +54,20 @@ impl ApiV3 { TestRequest::post() .uri("/v3/auth/oauth/accept") .append_header((AUTHORIZATION, pat)) - .set_json(AcceptOAuthClientScopes { + .set_json(RespondToOAuthClientScopes { + flow: flow.to_string(), + }) + .to_request(), + ) + .await + } + + pub async fn oauth_reject(&self, flow: &str, pat: &str) -> ServiceResponse { + self.call( + TestRequest::post() + .uri("/v3/auth/oauth/reject") + .append_header((AUTHORIZATION, pat)) + .set_json(RespondToOAuthClientScopes { flow: flow.to_string(), }) .to_request(), @@ -76,21 +114,51 @@ impl ApiV3 { pub fn generate_authorize_uri( client_id: &str, - scope: &str, - redirect_uri: &str, + scope: Option<&str>, + redirect_uri: Option<&str>, state: Option<&str>, ) -> String { format!( - "/v3/auth/oauth/authorize?client_id={}&redirect_uri={}\ - &scope={}{}", + "/v3/auth/oauth/authorize?client_id={}{}{}{}", urlencoding::encode(client_id), - urlencoding::encode(redirect_uri), - urlencoding::encode(scope), - if let Some(state) = state { - urlencoding::encode(state).to_string() - } else { - "".to_string() - }, + optional_query_param("redirect_uri", redirect_uri), + optional_query_param("scope", scope), + optional_query_param("state", state), ) .to_string() } + +pub async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { + test::read_body_json::(response) + .await + .flow_id +} + +pub async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { + let query_params = get_redirect_location_query_params(response); + query_params.get("code").unwrap().to_string() +} + +pub async fn get_access_token(response: ServiceResponse) -> String { + test::read_body_json::(response) + .await + .access_token +} + +pub fn get_redirect_location_query_params( + response: &ServiceResponse, +) -> actix_web::web::Query> { + let redirect_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); + actix_web::web::Query::>::from_query( + redirect_location.split_once('?').unwrap().1, + ) + .unwrap() +} + +fn optional_query_param(key: &str, value: Option<&str>) -> String { + if let Some(val) = value { + format!("&{key}={}", urlencoding::encode(val)) + } else { + "".to_string() + } +} diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index 291bd160..3987087a 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -1,8 +1,14 @@ use actix_http::StatusCode; -use actix_web::test::{self, TestRequest}; -use labrinth::models::{ - oauth_clients::{OAuthClient, OAuthClientCreationResult}, - pats::Scopes, +use actix_web::{ + dev::ServiceResponse, + test::{self, TestRequest}, +}; +use labrinth::{ + models::{ + oauth_clients::{OAuthClient, OAuthClientAuthorizationInfo, OAuthClientCreationResult}, + pats::Scopes, + }, + routes::v3::oauth_clients::OAuthClientEdit, }; use reqwest::header::AUTHORIZATION; use serde_json::json; @@ -46,4 +52,41 @@ impl ApiV3 { test::read_body_json(resp).await } + + pub async fn edit_oauth_client( + &self, + client_id: &str, + edit: OAuthClientEdit, + pat: &str, + ) -> ServiceResponse { + let req = TestRequest::patch() + .uri(&format!("/v3/oauth_app/{}", client_id)) + .set_json(edit) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + + pub async fn revoke_oauth_authorization(&self, client_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::delete() + .uri(&format!("/v3/user/oauth_authorizations/{}", client_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + self.call(req).await + } + + pub async fn get_user_oauth_authorizations( + &self, + pat: &str, + ) -> Vec { + let req = TestRequest::get() + .uri("/v3/user/oauth_authorizations") + .append_header((AUTHORIZATION, pat)) + .to_request(); + let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); + + test::read_body_json(resp).await + } } diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index e15f177e..a72e4ae2 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -1,3 +1,10 @@ pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { assert_eq!(response.status(), status, "{:#?}", response.response()); } + +pub fn assert_any_status_except( + response: &actix_web::dev::ServiceResponse, + status: actix_http::StatusCode, +) { + assert_ne!(response.status(), status, "{:#?}", response.response()); +} diff --git a/tests/common/environment.rs b/tests/common/environment.rs index 08e72d8b..55fd82fa 100644 --- a/tests/common/environment.rs +++ b/tests/common/environment.rs @@ -88,6 +88,26 @@ impl TestEnvironment { .await; assert_status(&resp, StatusCode::NO_CONTENT); } + + pub async fn assert_read_notifications_status( + &self, + user_id: &str, + pat: &str, + status_code: StatusCode, + ) { + let resp = self.v2.get_user_notifications(user_id, pat).await; + assert_status(&resp, status_code); + } + + pub async fn assert_read_user_projects_status( + &self, + user_id: &str, + pat: &str, + status_code: StatusCode, + ) { + let resp = self.v2.get_user_projects(user_id, pat).await; + assert_status(&resp, status_code); + } } pub trait LocalService { diff --git a/tests/oauth.rs b/tests/oauth.rs index e4ede7fd..e4713a90 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -1,17 +1,17 @@ -use crate::common::{database::FRIEND_USER_ID, dummy_data::DummyOAuthClientAlpha}; -use actix_http::StatusCode; -use actix_web::{ - dev::ServiceResponse, - test::{self}, +use crate::common::{ + api_v3::oauth::get_redirect_location_query_params, database::FRIEND_USER_ID, + dummy_data::DummyOAuthClientAlpha, }; +use actix_http::StatusCode; +use actix_web::test::{self}; use common::{ - asserts::assert_status, - database::{FRIEND_USER_PAT, USER_USER_PAT}, + api_v3::oauth::{get_auth_code_from_redirect_params, get_authorize_accept_flow_id}, + asserts::{assert_any_status_except, assert_status}, + database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, environment::with_test_environment, }; -use labrinth::auth::oauth::{OAuthClientAccessRequest, TokenResponse}; -use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA}; -use std::collections::HashMap; +use labrinth::auth::oauth::TokenResponse; +use reqwest::header::{CACHE_CONTROL, PRAGMA}; mod common; @@ -22,7 +22,7 @@ async fn oauth_flow_happy_path() { valid_redirect_uri: base_redirect_uri, client_id, client_secret, - } = env.dummy.unwrap().oauth_client_alpha.clone(); + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); // Initiate authorization let redirect_uri = format!("{}?foo=bar", base_redirect_uri); @@ -31,8 +31,8 @@ async fn oauth_flow_happy_path() { .v3 .oauth_authorize( &client_id, - "USER_READ NOTIFICATION_READ", - &redirect_uri, + Some("USER_READ NOTIFICATION_READ"), + Some(&redirect_uri), Some(original_state), FRIEND_USER_PAT, ) @@ -66,10 +66,13 @@ async fn oauth_flow_happy_path() { assert_eq!(resp.headers().get(PRAGMA).unwrap(), "no-cache"); let token_resp: TokenResponse = test::read_body_json(resp).await; - // Validate that the token works - env.v2 - .get_user_notifications_deserialized(FRIEND_USER_ID, &token_resp.access_token) - .await; + // Validate the token works + env.assert_read_notifications_status( + FRIEND_USER_ID, + &token_resp.access_token, + StatusCode::OK, + ) + .await; }) .await; } @@ -77,18 +80,14 @@ async fn oauth_flow_happy_path() { #[actix_rt::test] async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { with_test_environment(|env| async { - let DummyOAuthClientAlpha { - valid_redirect_uri, - client_id, - .. - } = env.dummy.unwrap().oauth_client_alpha.clone(); + let DummyOAuthClientAlpha { client_id, .. } = env.dummy.unwrap().oauth_client_alpha.clone(); let resp = env .v3 .oauth_authorize( &client_id, - "USER_READ NOTIFICATION_READ", - &valid_redirect_uri, + Some("USER_READ NOTIFICATION_READ"), + None, Some("1234"), USER_USER_PAT, ) @@ -100,8 +99,8 @@ async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { .v3 .oauth_authorize( &client_id, - "USER_READ", - &valid_redirect_uri, + Some("USER_READ"), + None, Some("5678"), USER_USER_PAT, ) @@ -115,20 +114,14 @@ async fn oauth_authorize_for_already_authorized_scopes_returns_auth_code() { async fn get_oauth_token_with_already_used_auth_code_fails() { with_test_environment(|env| async { let DummyOAuthClientAlpha { - valid_redirect_uri, client_id, client_secret, + .. } = env.dummy.unwrap().oauth_client_alpha.clone(); let resp = env .v3 - .oauth_authorize( - &client_id, - "USER_READ", - &valid_redirect_uri, - None, - USER_USER_PAT, - ) + .oauth_authorize(&client_id, None, None, None, USER_USER_PAT) .await; let flow_id = get_authorize_accept_flow_id(resp).await; @@ -137,23 +130,13 @@ async fn get_oauth_token_with_already_used_auth_code_fails() { let resp = env .v3 - .oauth_token( - auth_code.clone(), - Some(valid_redirect_uri.clone()), - client_id.clone(), - &client_secret, - ) + .oauth_token(auth_code.clone(), None, client_id.clone(), &client_secret) .await; assert_status(&resp, StatusCode::OK); let resp = env .v3 - .oauth_token( - auth_code, - Some(valid_redirect_uri), - client_id, - &client_secret, - ) + .oauth_token(auth_code, None, client_id, &client_secret) .await; assert_status(&resp, StatusCode::BAD_REQUEST); }) @@ -161,16 +144,61 @@ async fn get_oauth_token_with_already_used_auth_code_fails() { } #[actix_rt::test] -async fn oauth_authorize_with_broader_scopes_requires_user_accept() { - with_test_environment(|env| async { +async fn authorize_with_broader_scopes_can_complete_flow() { + with_test_environment(|env| async move { let DummyOAuthClientAlpha { - valid_redirect_uri: redirect_uri, client_id, + client_secret, .. - } = env.dummy.unwrap().oauth_client_alpha.clone(); + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + let first_access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("PROJECT_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + let second_access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("PROJECT_READ NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + + env.assert_read_notifications_status( + USER_USER_ID, + &first_access_token, + StatusCode::UNAUTHORIZED, + ) + .await; + env.assert_read_user_projects_status(USER_USER_ID, &first_access_token, StatusCode::OK) + .await; + + env.assert_read_notifications_status(USER_USER_ID, &second_access_token, StatusCode::OK) + .await; + env.assert_read_user_projects_status(USER_USER_ID, &second_access_token, StatusCode::OK) + .await; + }) + .await; +} + +#[actix_rt::test] +async fn oauth_authorize_with_broader_scopes_requires_user_accept() { + with_test_environment(|env| async { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); let resp = env .v3 - .oauth_authorize(&client_id, "USER_READ", &redirect_uri, None, USER_USER_PAT) + .oauth_authorize(&client_id, Some("USER_READ"), None, None, USER_USER_PAT) .await; let flow_id = get_authorize_accept_flow_id(resp).await; env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; @@ -179,8 +207,8 @@ async fn oauth_authorize_with_broader_scopes_requires_user_accept() { .v3 .oauth_authorize( &client_id, - "USER_READ NOTIFICATION_READ", - &redirect_uri, + Some("USER_READ NOTIFICATION_READ"), + None, None, USER_USER_PAT, ) @@ -192,25 +220,73 @@ async fn oauth_authorize_with_broader_scopes_requires_user_accept() { .await; } -//TODO: What about rejecting??? +#[actix_rt::test] +async fn reject_authorize_ends_authorize_flow() { + with_test_environment(|env| async move { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env + .v3 + .oauth_authorize(&client_id, None, None, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; -async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { - test::read_body_json::(response) - .await - .flow_id + let resp = env.v3.oauth_reject(&flow_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::FOUND); + + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + assert_any_status_except(&resp, StatusCode::FOUND); + }) + .await; } -async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { - let query_params = get_redirect_location_query_params(response); - query_params.get("code").unwrap().to_string() +#[actix_rt::test] +async fn accept_authorize_after_already_accepting_fails() { + with_test_environment(|env| async move { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env + .v3 + .oauth_authorize(&client_id, None, None, None, USER_USER_PAT) + .await; + let flow_id = get_authorize_accept_flow_id(resp).await; + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::FOUND); + + let resp = env.v3.oauth_accept(&flow_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::BAD_REQUEST); + }) + .await; } -fn get_redirect_location_query_params( - response: &ServiceResponse, -) -> actix_web::web::Query> { - let redirect_location = response.headers().get(LOCATION).unwrap().to_str().unwrap(); - actix_web::web::Query::>::from_query( - redirect_location.split_once('?').unwrap().1, - ) - .unwrap() +#[actix_rt::test] +async fn revoke_authorization_after_issuing_token_revokes_token() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + let access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::OK) + .await; + + let resp = env + .v3 + .revoke_oauth_authorization(&client_id, USER_USER_PAT) + .await; + assert_status(&resp, StatusCode::OK); + + env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::UNAUTHORIZED) + .await; + }) + .await; } From 6e783068e6e91b79dffdc949618a5c3be9ae406d Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:05:17 -0500 Subject: [PATCH 20/72] Test oauth client and authorization management routes --- src/database/models/oauth_client_item.rs | 2 +- src/routes/v3/oauth_clients.rs | 3 +- tests/common/api_v3/oauth.rs | 6 ++ tests/common/api_v3/oauth_clients.rs | 9 ++ tests/oauth_clients.rs | 132 +++++++++++++++++++++++ 5 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 tests/oauth_clients.rs diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 499bac05..e890b205 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -165,7 +165,7 @@ impl OAuthClient { let ids = ids.into_iter().map(|id| id.0).collect_vec(); sqlx::query!( " - DELETE FROM oauth_clients + DELETE FROM oauth_client_redirect_uris WHERE id IN (SELECT * FROM UNNEST($1::bigint[])) ", diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 2cbf12fd..3a83baf4 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -35,8 +35,9 @@ use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_create); - cfg.service(get_user_clients); + cfg.service(oauth_client_edit); cfg.service(oauth_client_delete); + cfg.service(get_user_clients); cfg.service(get_user_oauth_authorizations); cfg.service(revoke_oauth_authorization); } diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index e7c6e7dc..1c626987 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use actix_http::StatusCode; use actix_web::{ dev::ServiceResponse, test::{self, TestRequest}, @@ -9,6 +10,8 @@ use labrinth::auth::oauth::{ }; use reqwest::header::{AUTHORIZATION, LOCATION}; +use crate::common::asserts::assert_status; + use super::ApiV3; impl ApiV3 { @@ -129,17 +132,20 @@ pub fn generate_authorize_uri( } pub async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { + assert_status(&response, StatusCode::OK); test::read_body_json::(response) .await .flow_id } pub async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { + assert_status(&response, StatusCode::FOUND); let query_params = get_redirect_location_query_params(response); query_params.get("code").unwrap().to_string() } pub async fn get_access_token(response: ServiceResponse) -> String { + assert_status(&response, StatusCode::OK); test::read_body_json::(response) .await .access_token diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index 3987087a..c014fd14 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -68,6 +68,15 @@ impl ApiV3 { self.call(req).await } + pub async fn delete_oauth_client(&self, client_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::delete() + .uri(&format!("/v3/oauth_app/{}", client_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + pub async fn revoke_oauth_authorization(&self, client_id: &str, pat: &str) -> ServiceResponse { let req = TestRequest::delete() .uri(&format!("/v3/user/oauth_authorizations/{}", client_id)) diff --git a/tests/oauth_clients.rs b/tests/oauth_clients.rs new file mode 100644 index 00000000..e277b267 --- /dev/null +++ b/tests/oauth_clients.rs @@ -0,0 +1,132 @@ +use actix_http::StatusCode; +use common::{ + database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, + dummy_data::DummyOAuthClientAlpha, + environment::with_test_environment, + get_json_val_str, +}; +use labrinth::{models::pats::Scopes, routes::v3::oauth_clients::OAuthClientEdit}; + +use crate::common::{asserts::assert_status, database::USER_USER_ID_PARSED}; + +mod common; + +#[actix_rt::test] +async fn can_create_edit_get_oauth_client() { + with_test_environment(|env| async move { + let client_name = "test_client".to_string(); + let redirect_uris = vec![ + "https://modrinth.com".to_string(), + "https://modrinth.com/a".to_string(), + ]; + let creation_result = env + .v3 + .add_oauth_client( + client_name.clone(), + Scopes::all(), + redirect_uris.clone(), + FRIEND_USER_PAT, + ) + .await; + let client_id = get_json_val_str(creation_result.client.id); + + let icon_url = Some("https://modrinth.com/icon".to_string()); + let edited_redirect_uris = vec![ + redirect_uris[0].clone(), + "https://modrinth.com/b".to_string(), + ]; + let edit = OAuthClientEdit { + name: None, + icon_url: Some(icon_url.clone()), + max_scopes: None, + redirect_uris: Some(edited_redirect_uris.clone()), + }; + let resp = env + .v3 + .edit_oauth_client(&client_id, edit, FRIEND_USER_PAT) + .await; + assert_status(&resp, StatusCode::OK); + + let clients = env + .v3 + .get_user_oauth_clients(FRIEND_USER_ID, FRIEND_USER_PAT) + .await; + assert_eq!(1, clients.len()); + assert_eq!(icon_url, clients[0].icon_url); + assert_eq!(client_name, clients[0].name); + assert_eq!(2, clients[0].redirect_uris.len()); + assert_eq!(edited_redirect_uris[0], clients[0].redirect_uris[0].uri); + assert_eq!(edited_redirect_uris[1], clients[0].redirect_uris[1].uri); + }) + .await; +} + +#[actix_rt::test] +async fn can_delete_oauth_client() { + with_test_environment(|env| async move { + let client_id = env.dummy.unwrap().oauth_client_alpha.client_id.clone(); + let resp = env.v3.delete_oauth_client(&client_id, USER_USER_PAT).await; + assert_status(&resp, StatusCode::NO_CONTENT); + + let clients = env + .v3 + .get_user_oauth_clients(USER_USER_ID, USER_USER_PAT) + .await; + assert_eq!(0, clients.len()); + }) + .await; +} + +#[actix_rt::test] +async fn delete_oauth_client_after_issuing_access_tokens_revokes_tokens() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + let access_token = env + .v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + Some("NOTIFICATION_READ"), + None, + None, + USER_USER_PAT, + ) + .await; + + env.v3.delete_oauth_client(&client_id, USER_USER_PAT).await; + + env.assert_read_notifications_status(USER_USER_ID, &access_token, StatusCode::UNAUTHORIZED) + .await; + }) + .await; +} + +#[actix_rt::test] +async fn can_list_user_oauth_authorizations() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { + client_id, + client_secret, + .. + } = env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + env.v3 + .complete_full_authorize_flow( + &client_id, + &client_secret, + None, + None, + None, + USER_USER_PAT, + ) + .await; + + let authorizations = env.v3.get_user_oauth_authorizations(USER_USER_PAT).await; + assert_eq!(1, authorizations.len()); + assert_eq!(USER_USER_ID_PARSED, authorizations[0].user_id.0 as i64); + }) + .await; +} From 0999d255519bafbb29b73f071c867238361428ab Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:06:54 -0500 Subject: [PATCH 21/72] cargo sqlx prepare --- sqlx-data.json | 558 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 545 insertions(+), 13 deletions(-) diff --git a/sqlx-data.json b/sqlx-data.json index 8b02bea3..58c97aab 100644 --- a/sqlx-data.json +++ b/sqlx-data.json @@ -189,19 +189,6 @@ }, "query": "\n DELETE FROM reports\n WHERE user_id = $1 OR reporter = $1\n " }, - "0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Left": [ - "Int8", - "Timestamptz" - ] - } - }, - "query": "\n UPDATE pats\n SET last_used = $2\n WHERE (id = $1)\n " - }, "04e5ecb14c526000e9098efb65861f6125e6fcc88f39d6ad811ac8504d229de1": { "describe": { "columns": [], @@ -1077,6 +1064,19 @@ }, "query": "\n UPDATE threads_messages\n SET body = $2\n WHERE id = $1\n " }, + "2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8Array", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE pats\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n " + }, "21d20e5f09cb0729dc16c8609c35cec5a913f3172b53b8ae05da0096a33b4b64": { "describe": { "columns": [ @@ -1805,6 +1805,26 @@ }, "query": "\n UPDATE users\n SET trolley_id = $1, trolley_account_status = $2\n WHERE id = $3\n " }, + "3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9": { + "describe": { + "columns": [ + { + "name": "exists", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)" + }, "40f7c5bec98fe3503d6bd6db2eae5a4edb8d5d6efda9b9dc124f344ae5c60e08": { "describe": { "columns": [], @@ -2108,6 +2128,26 @@ }, "query": "UPDATE mods\n SET downloads = downloads + 1\n WHERE (id = $1)" }, + "49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9": { + "describe": { + "columns": [ + { + "name": "exists", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)" + }, "4a54d350b4695c32a802675506e85b0506fc62a63ca0ee5f38890824301d6515": { "describe": { "columns": [], @@ -2198,6 +2238,68 @@ }, "query": "\n UPDATE mods\n SET description = $1\n WHERE (id = $2)\n " }, + "4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "client_id", + "ordinal": 1, + "type_info": "Int8" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Int8" + }, + { + "name": "scopes", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "created", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "client_name", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "client_icon_url", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "client_created_by", + "ordinal": 7, + "type_info": "Int8" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "\n SELECT \n auths.id,\n auths.client_id,\n auths.user_id,\n auths.scopes,\n auths.created,\n clients.name as client_name,\n clients.icon_url as client_icon_url,\n clients.created_by as client_created_by\n FROM oauth_client_authorizations auths\n JOIN oauth_clients clients ON clients.id = auths.client_id\n WHERE user_id=$1\n " + }, "4fb5bd341369b4beb6b4a88de296b608ea5441a96db9f7360fbdccceb4628202": { "describe": { "columns": [], @@ -2846,6 +2948,64 @@ }, "query": "\n SELECT tm.id, tm.team_id, tm.user_id, tm.role, tm.permissions, tm.organization_permissions, tm.accepted, tm.payouts_split, tm.ordering\n FROM organizations o\n INNER JOIN team_members tm ON tm.team_id = o.team_id AND user_id = $2 AND accepted = TRUE\n WHERE o.id = $1\n " }, + "65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8Array", + "Timestamptz" + ] + } + }, + "query": "\n UPDATE oauth_access_tokens\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n " + }, + "65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "client_id", + "ordinal": 1, + "type_info": "Int8" + }, + { + "name": "user_id", + "ordinal": 2, + "type_info": "Int8" + }, + { + "name": "scopes", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "created", + "ordinal": 4, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false + ], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + } + }, + "query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n " + }, "665e294e9737fd0299fc4639127d56811485dc8a5a4e08a4e7292044d8a2fb7a": { "describe": { "columns": [], @@ -2924,6 +3084,21 @@ }, "query": "\n DELETE FROM notifications\n WHERE user_id = $1\n " }, + "68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Int8", + "Int8" + ] + } + }, + "query": "\n INSERT INTO oauth_client_authorizations (\n id, client_id, user_id, scopes\n )\n VALUES (\n $1, $2, $3, $4\n )\n ON CONFLICT (id)\n DO UPDATE SET scopes = EXCLUDED.scopes\n " + }, "69b093cad9109ccf4779bfd969897f6b9ebc9d0d4230c958de4fa07435776349": { "describe": { "columns": [], @@ -3031,6 +3206,36 @@ }, "query": "\n SELECT r.id, rt.name, r.mod_id, r.version_id, r.user_id, r.body, r.reporter, r.created, t.id thread_id, r.closed\n FROM reports r\n INNER JOIN report_types rt ON rt.id = r.report_type_id\n INNER JOIN threads t ON t.report_id = r.id\n WHERE r.id = ANY($1)\n ORDER BY r.created DESC\n " }, + "6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911": { + "describe": { + "columns": [ + { + "name": "created", + "ordinal": 0, + "type_info": "Timestamptz" + }, + { + "name": "expires", + "ordinal": 1, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false + ], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Text", + "Int8", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO oauth_access_tokens (\n id, authorization_id, token_hash, scopes, last_used\n )\n VALUES (\n $1, $2, $3, $4, $5\n )\n RETURNING created, expires\n " + }, "6b89c2b2557e304c2a3a02d7824327685f9be696254bf2370d0c995aafc6a2d8": { "describe": { "columns": [], @@ -3235,6 +3440,74 @@ }, "query": "\n DELETE FROM mods_categories\n WHERE joining_mod_id = $1 AND is_additional = $2\n " }, + "7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "authorization_id", + "ordinal": 1, + "type_info": "Int8" + }, + { + "name": "token_hash", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "scopes", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "created", + "ordinal": 4, + "type_info": "Timestamptz" + }, + { + "name": "expires", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "last_used", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "client_id", + "ordinal": 7, + "type_info": "Int8" + }, + { + "name": "user_id", + "ordinal": 8, + "type_info": "Int8" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false, + false + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n tokens.id,\n tokens.authorization_id,\n tokens.token_hash,\n tokens.scopes,\n tokens.created,\n tokens.expires,\n tokens.last_used,\n auths.client_id,\n auths.user_id\n FROM oauth_access_tokens tokens\n JOIN oauth_client_authorizations auths\n ON tokens.authorization_id = auths.id\n WHERE tokens.token_hash = $1\n " + }, "71abd207410d123f9a50345ddcddee335fea0d0cc6f28762713ee01a36aee8a0": { "describe": { "columns": [ @@ -3998,6 +4271,23 @@ }, "query": "\n SELECT t.id, t.thread_type, t.mod_id, t.report_id, t.show_in_mod_inbox,\n ARRAY_AGG(DISTINCT tm.user_id) filter (where tm.user_id is not null) members,\n JSONB_AGG(DISTINCT jsonb_build_object('id', tmsg.id, 'author_id', tmsg.author_id, 'thread_id', tmsg.thread_id, 'body', tmsg.body, 'created', tmsg.created)) filter (where tmsg.id is not null) messages\n FROM threads t\n LEFT OUTER JOIN threads_messages tmsg ON tmsg.thread_id = t.id\n LEFT OUTER JOIN threads_members tm ON tm.thread_id = t.id\n WHERE t.id = ANY($1)\n GROUP BY t.id\n " }, + "93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8", + "Text", + "Text", + "Int8", + "Text", + "Int8" + ] + } + }, + "query": "\n INSERT INTO oauth_clients (\n id, name, icon_url, max_scopes, secret_hash, created_by\n )\n VALUES (\n $1, $2, $3, $4, $5, $6\n )\n " + }, "9544cea57095a94109be5fef9a4737626a9003d58680943cdbffc7c9ada7877b": { "describe": { "columns": [], @@ -4127,6 +4417,74 @@ }, "query": "\n UPDATE historical_payouts\n SET status = $1\n WHERE payment_id = $2\n " }, + "97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "icon_url", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "max_scopes", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "secret_hash", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "created", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "created_by", + "ordinal": 6, + "type_info": "Int8" + }, + { + "name": "uri_ids", + "ordinal": 7, + "type_info": "Int8Array" + }, + { + "name": "uri_vals", + "ordinal": 8, + "type_info": "TextArray" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + false, + false, + null, + null + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "\n SELECT\n clients.id,\n clients.name,\n clients.icon_url,\n clients.max_scopes,\n clients.secret_hash,\n clients.created,\n clients.created_by,\n uris.uri_ids,\n uris.uri_vals\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE created_by = $1" + }, "99a1eac69d7f5a5139703df431e6a5c3012a90143a8c635f93632f04d0bc41d4": { "describe": { "columns": [], @@ -4203,6 +4561,20 @@ }, "query": "\n DELETE FROM collections_mods\n WHERE collection_id = $1\n " }, + "9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8Array", + "Int8Array", + "VarcharArray" + ] + } + }, + "query": "\n INSERT INTO oauth_client_redirect_uris (id, client_id, uri)\n SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])\n " + }, "a0148ff25855202e7bb220b6a2bc9220a95e309fb0dae41d9a05afa86e6b33af": { "describe": { "columns": [], @@ -4532,6 +4904,74 @@ }, "query": "\n DELETE FROM mod_follows\n WHERE mod_id = $1\n " }, + "b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956": { + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "icon_url", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "max_scopes", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "secret_hash", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "created", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "created_by", + "ordinal": 6, + "type_info": "Int8" + }, + { + "name": "uri_ids", + "ordinal": 7, + "type_info": "Int8Array" + }, + { + "name": "uri_vals", + "ordinal": 8, + "type_info": "TextArray" + } + ], + "nullable": [ + false, + false, + true, + false, + false, + false, + false, + null, + null + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "\n SELECT\n clients.id,\n clients.name,\n clients.icon_url,\n clients.max_scopes,\n clients.secret_hash,\n clients.created,\n clients.created_by,\n uris.uri_ids,\n uris.uri_vals\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE clients.id = $1" + }, "b0c29c51bd3ae5b93d487471a98ee9bbb43a4df468ba781852b137dd315b9608": { "describe": { "columns": [], @@ -5393,6 +5833,18 @@ }, "query": "\n UPDATE team_members\n SET role = $1\n WHERE (team_id = $2 AND user_id = $3)\n " }, + "cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8Array" + ] + } + }, + "query": "\n DELETE FROM oauth_client_redirect_uris\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n " + }, "cfcc6970c0b469c4afd37bedfd386def7980f6b7006030d4783723861d0e3a38": { "describe": { "columns": [ @@ -5438,6 +5890,26 @@ }, "query": "\n UPDATE users\n SET gitlab_id = $2\n WHERE (id = $1)\n " }, + "d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a": { + "describe": { + "columns": [ + { + "name": "exists", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)" + }, "d1566672369ea22cb1f638f073f8e3fb467b354351ae71c67941323749ec9bcd": { "describe": { "columns": [ @@ -5757,6 +6229,19 @@ }, "query": "\n SELECT id, session, user_id\n FROM sessions\n WHERE refresh_expires <= NOW()\n " }, + "db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + } + }, + "query": "\n DELETE FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n " + }, "dc6aa2e7bfd5d5004620ddd4cd6a47ecc56159e1489054e0652d56df802fb5e5": { "describe": { "columns": [], @@ -5875,6 +6360,26 @@ }, "query": "\n INSERT INTO versions (\n id, mod_id, author_id, name, version_number,\n changelog, date_published, downloads,\n version_type, featured, status\n )\n VALUES (\n $1, $2, $3, $4, $5,\n $6, $7, $8,\n $9, $10, $11\n )\n " }, + "e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12": { + "describe": { + "columns": [ + { + "name": "exists", + "ordinal": 0, + "type_info": "Bool" + } + ], + "nullable": [ + null + ], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)" + }, "e3235e872f98eb85d3eb4a2518fb9dc88049ce62362bfd02623e9b49ac2e9fed": { "describe": { "columns": [ @@ -6022,6 +6527,18 @@ }, "query": "\n SELECT c.id FROM collections c\n WHERE c.user_id = $1\n " }, + "e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "\n DELETE FROM oauth_clients\n WHERE id = $1\n " + }, "e6db02891be261e61a25716b83c1298482eb9a04f0c026532030aeb374405f13": { "describe": { "columns": [ @@ -6186,6 +6703,21 @@ }, "query": "SELECT EXISTS(SELECT 1 FROM team_members WHERE id=$1)" }, + "e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Text", + "Text", + "Int8", + "Int8" + ] + } + }, + "query": "\n UPDATE oauth_clients\n SET name = $1, icon_url = $2, max_scopes = $3\n WHERE (id = $4)\n " + }, "e8d4589132b094df1e7a3ca0440344fc8013c0d20b3c71a1142ccbee91fb3c70": { "describe": { "columns": [ From 68481a16073a5ab5fb2bb4bf2a34ab7da22d8f93 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:10:48 -0500 Subject: [PATCH 22/72] dead code warning --- tests/common/asserts.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index a72e4ae2..3c7f585a 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { assert_eq!(response.status(), status, "{:#?}", response.response()); } From 075b1303935790059967d50acee05c7eb9c651d8 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:12:52 -0500 Subject: [PATCH 23/72] Auto clippy fixes --- src/auth/oauth/errors.rs | 4 ++-- src/auth/oauth/mod.rs | 18 +++++++++--------- src/auth/validate.rs | 4 ++-- .../models/oauth_client_authorization_item.rs | 8 ++++---- src/database/models/oauth_client_item.rs | 6 +++--- src/database/models/oauth_token_item.rs | 4 ++-- src/routes/v3/oauth_clients.rs | 2 +- tests/common/api_v3/oauth.rs | 6 +++--- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs index c8cafd8c..486f6180 100644 --- a/src/auth/oauth/errors.rs +++ b/src/auth/oauth/errors.rs @@ -76,9 +76,9 @@ impl actix_web::ResponseError for OAuthError { if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() { redirect_uri = format!( "{}?error={}&error_description={}", - redirect_uri.to_string(), + redirect_uri, self.error_type.error_name(), - self.error_type.to_string(), + self.error_type, ); if let Some(state) = self.state.as_ref() { diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 4e651c28..3801ec87 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -315,7 +315,7 @@ pub async fn accept_or_reject_client_scopes( Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), ) .await - .map_err(|e| OAuthError::error(e))? + .map_err(OAuthError::error)? .1; let flow = Flow::take_if( @@ -381,7 +381,7 @@ fn authenticate_client_token_request( req: &HttpRequest, client: &DBOAuthClient, ) -> Result<(), OAuthError> { - let client_secret = extract_authorization_header(&req).map_err(OAuthError::error)?; + let client_secret = extract_authorization_header(req).map_err(OAuthError::error)?; let hashed_client_secret = DBOAuthClient::hash_secret(client_secret); if client.secret_hash != hashed_client_secret { Err(OAuthError::error( @@ -448,21 +448,21 @@ impl ValidatedRedirectUri { if let Some(to_validate) = to_validate { if validate_against .into_iter() - .any(|uri| same_uri_except_query_components(&uri, to_validate)) + .any(|uri| same_uri_except_query_components(uri, to_validate)) { - return Ok(ValidatedRedirectUri(to_validate.clone())); + Ok(ValidatedRedirectUri(to_validate.clone())) } else { - return Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured( + Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured( to_validate.clone(), - ))); + ))) } } else { - return Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())); + Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())) } } else { - return Err(OAuthError::error( + Err(OAuthError::error( OAuthErrorType::ClientMissingRedirectURI { client_id }, - )); + )) } } } diff --git a/src/auth/validate.rs b/src/auth/validate.rs index 6f987c21..8c2a92e1 100644 --- a/src/auth/validate.rs +++ b/src/auth/validate.rs @@ -177,10 +177,10 @@ where pub fn extract_authorization_header(req: &HttpRequest) -> Result<&str, AuthenticationError> { let headers = req.headers(); let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION); - Ok(token_val + token_val .ok_or_else(|| AuthenticationError::InvalidAuthMethod)? .to_str() - .map_err(|_| AuthenticationError::InvalidCredentials)?) + .map_err(|_| AuthenticationError::InvalidCredentials) } pub async fn check_is_moderator_from_headers<'a, 'b, E>( diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs index 0972caf0..87603452 100644 --- a/src/database/models/oauth_client_authorization_item.rs +++ b/src/database/models/oauth_client_authorization_item.rs @@ -44,13 +44,13 @@ impl OAuthClientAuthorization { .fetch_optional(exec) .await?; - return Ok(value.map(|r| OAuthClientAuthorization { + Ok(value.map(|r| OAuthClientAuthorization { id: OAuthClientAuthorizationId(r.id), client_id: OAuthClientId(r.client_id), user_id: UserId(r.user_id), scopes: Scopes::from_postgres(r.scopes), created: r.created, - })); + })) } pub async fn get_all_for_user( @@ -77,7 +77,7 @@ impl OAuthClientAuthorization { .fetch_all(exec) .await?; - return Ok(results + Ok(results .into_iter() .map(|r| OAuthClientAuthorizationWithClientInfo { id: OAuthClientAuthorizationId(r.id), @@ -89,7 +89,7 @@ impl OAuthClientAuthorization { client_icon_url: r.client_icon_url, client_created_by: UserId(r.client_created_by), }) - .collect_vec()); + .collect_vec()) } pub async fn upsert( diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index e890b205..ae2f7d70 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -76,7 +76,7 @@ impl OAuthClient { .fetch_optional(exec) .await?; - return Ok(value.map(|r| r.into())); + Ok(value.map(|r| r.into())) } pub async fn get_all_user_clients( @@ -88,7 +88,7 @@ impl OAuthClient { .fetch_all(exec) .await?; - return Ok(clients.into_iter().map(|r| r.into()).collect()); + Ok(clients.into_iter().map(|r| r.into()).collect()) } pub async fn remove( @@ -212,7 +212,7 @@ impl From for OAuthClient { .zip(uris.iter()) .map(|(id, uri)| OAuthRedirectUri { id: OAuthRedirectUriId(*id), - client_id: OAuthClientId(r.id.clone()), + client_id: OAuthClientId(r.id), uri: uri.to_string(), }) .collect() diff --git a/src/database/models/oauth_token_item.rs b/src/database/models/oauth_token_item.rs index c2e95501..9c12f383 100644 --- a/src/database/models/oauth_token_item.rs +++ b/src/database/models/oauth_token_item.rs @@ -46,7 +46,7 @@ impl OAuthAccessToken { .fetch_optional(exec) .await?; - return Ok(value.map(|r| OAuthAccessToken { + Ok(value.map(|r| OAuthAccessToken { id: OAuthAccessTokenId(r.id), authorization_id: OAuthClientAuthorizationId(r.authorization_id), token_hash: r.token_hash, @@ -56,7 +56,7 @@ impl OAuthAccessToken { last_used: r.last_used, client_id: OAuthClientId(r.client_id), user_id: UserId(r.user_id), - })); + })) } /// Inserts and returns the time until the token expires diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 3a83baf4..edccf5d7 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -70,7 +70,7 @@ pub async fn get_user_clients( let response = clients .into_iter() - .map(|c| models::oauth_clients::OAuthClient::from(c)) + .map(models::oauth_clients::OAuthClient::from) .collect_vec(); Ok(HttpResponse::Ok().json(response)) diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index 1c626987..7778695b 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -25,13 +25,13 @@ impl ApiV3 { user_pat: &str, ) -> String { let auth_resp = self - .oauth_authorize(&client_id, scope, redirect_uri, state, user_pat) + .oauth_authorize(client_id, scope, redirect_uri, state, user_pat) .await; let flow_id = get_authorize_accept_flow_id(auth_resp).await; let redirect_resp = self.oauth_accept(&flow_id, user_pat).await; let auth_code = get_auth_code_from_redirect_params(&redirect_resp).await; let token_resp = self - .oauth_token(auth_code, None, client_id.to_string(), &client_secret) + .oauth_token(auth_code, None, client_id.to_string(), client_secret) .await; get_access_token(token_resp).await } @@ -139,7 +139,7 @@ pub async fn get_authorize_accept_flow_id(response: ServiceResponse) -> String { } pub async fn get_auth_code_from_redirect_params(response: &ServiceResponse) -> String { - assert_status(&response, StatusCode::FOUND); + assert_status(response, StatusCode::FOUND); let query_params = get_redirect_location_query_params(response); query_params.get("code").unwrap().to_string() } From 4e2bcabcc27f847df6701b7ca08914c4e8d81bc3 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:29:41 -0500 Subject: [PATCH 24/72] Uri refactoring --- src/auth/oauth/mod.rs | 119 +++++-------------------------- src/auth/oauth/uris.rs | 94 ++++++++++++++++++++++++ src/database/models/flow_item.rs | 5 +- tests/oauth.rs | 4 +- 4 files changed, 115 insertions(+), 107 deletions(-) create mode 100644 src/auth/oauth/uris.rs diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 3801ec87..a68f934d 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -1,4 +1,5 @@ use crate::auth::get_user_from_headers; +use crate::auth::oauth::uris::{OAuthRedirectUris, ValidatedRedirectUri}; use crate::auth::validate::extract_authorization_header; use crate::database::models::flow_item::Flow; use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization; @@ -27,6 +28,7 @@ use self::errors::{OAuthError, OAuthErrorType}; use super::AuthenticationError; pub mod errors; +pub mod uris; pub fn config(cfg: &mut ServiceConfig) { cfg.service( @@ -113,6 +115,8 @@ pub async fn init_oauth( OAuthClientAuthorization::get(client.id, user.id.into(), &**pool) .await .map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?; + let redirect_uris = + OAuthRedirectUris::new(oauth_info.redirect_uri.clone(), redirect_uri.clone()); match existing_authorization { Some(existing_authorization) if existing_authorization.scopes.contains(requested_scopes) => @@ -122,8 +126,7 @@ pub async fn init_oauth( client.id, existing_authorization.id, requested_scopes, - redirect_uri, - oauth_info.redirect_uri, + redirect_uris, oauth_info.state, &redis, ) @@ -135,8 +138,7 @@ pub async fn init_oauth( client_id: client.id, existing_authorization_id: existing_authorization.map(|a| a.id), scopes: requested_scopes, - validated_redirect_uri: redirect_uri.clone(), - original_redirect_uri: oauth_info.redirect_uri.clone(), + redirect_uris, state: oauth_info.state.clone(), } .insert(Duration::minutes(30), &redis) @@ -212,10 +214,10 @@ pub async fn request_token( pool: Data, redis: Data, ) -> Result { - let client_id = models::ids::OAuthClientId::parse(&req_params.client_id) + let req_client_id = models::ids::OAuthClientId::parse(&req_params.client_id) .map_err(OAuthError::error)? .into(); - let client = DBOAuthClient::get(client_id, &**pool) + let client = DBOAuthClient::get(req_client_id, &**pool) .await .map_err(OAuthError::error)?; if let Some(client) = client { @@ -239,7 +241,7 @@ pub async fn request_token( }) = flow { // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 - if client_id != client_id { + if req_client_id != client_id { return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient)); } @@ -294,7 +296,7 @@ pub async fn request_token( } } else { Err(OAuthError::error(OAuthErrorType::InvalidClientId( - client_id, + req_client_id, ))) } } @@ -330,8 +332,7 @@ pub async fn accept_or_reject_client_scopes( client_id, existing_authorization_id, scopes, - validated_redirect_uri, - original_redirect_uri, + redirect_uris, state, }) = flow { @@ -359,8 +360,7 @@ pub async fn accept_or_reject_client_scopes( client_id, auth_id, scopes, - validated_redirect_uri, - original_redirect_uri, + redirect_uris, state, &redis, ) @@ -369,7 +369,7 @@ pub async fn accept_or_reject_client_scopes( Err(OAuthError::redirect( OAuthErrorType::AccessDenied, &state, - &validated_redirect_uri, + &redirect_uris.validated, )) } } else { @@ -406,8 +406,7 @@ async fn init_oauth_code_flow( client_id: OAuthClientId, authorization_id: OAuthClientAuthorizationId, scopes: Scopes, - validated_redirect_uri: ValidatedRedirectUri, - original_redirect_uri: Option, + redirect_uris: OAuthRedirectUris, state: Option, redis: &RedisPool, ) -> Result { @@ -416,18 +415,18 @@ async fn init_oauth_code_flow( client_id, authorization_id, scopes, - original_redirect_uri, + original_redirect_uri: redirect_uris.original.clone(), } .insert(Duration::minutes(10), redis) .await - .map_err(|e| OAuthError::redirect(e, &state, &validated_redirect_uri.clone()))?; + .map_err(|e| OAuthError::redirect(e, &state, &redirect_uris.validated.clone()))?; let mut redirect_params = vec![format!("code={code}")]; if let Some(state) = state { redirect_params.push(format!("state={state}")); } - let redirect_uri = append_params_to_uri(&validated_redirect_uri.0, &redirect_params); + let redirect_uri = append_params_to_uri(&redirect_uris.validated.0, &redirect_params); // IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2) Ok(HttpResponse::Found() @@ -435,44 +434,6 @@ async fn init_oauth_code_flow( .finish()) } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ValidatedRedirectUri(pub String); - -impl ValidatedRedirectUri { - pub fn validate<'a>( - to_validate: &Option, - validate_against: impl IntoIterator + Clone, - client_id: OAuthClientId, - ) -> Result { - if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() { - if let Some(to_validate) = to_validate { - if validate_against - .into_iter() - .any(|uri| same_uri_except_query_components(uri, to_validate)) - { - Ok(ValidatedRedirectUri(to_validate.clone())) - } else { - Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured( - to_validate.clone(), - ))) - } - } else { - Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())) - } - } else { - Err(OAuthError::error( - OAuthErrorType::ClientMissingRedirectURI { client_id }, - )) - } - } -} - -fn same_uri_except_query_components(a: &str, b: &str) -> bool { - let mut a_components = a.split('?'); - let mut b_components = b.split('?'); - a_components.next() == b_components.next() -} - fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { let mut uri = uri.to_string(); let mut connector = if uri.contains('?') { "&" } else { "?" }; @@ -483,49 +444,3 @@ fn append_params_to_uri(uri: &str, params: &[impl AsRef]) -> String { uri } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_for_none_returns_first_valid_uri() { - let validate_against = vec!["https://modrinth.com/a"]; - - let validated = - ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0)) - .unwrap(); - - assert_eq!(validate_against[0], validated.0); - } - - #[test] - fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() { - let validate_against = vec![ - "https://modrinth.com/a?q3=p3&q4=p4", - "https://modrinth.com/a/b/c?q1=p1&q2=p2", - ]; - let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string(); - - let validated = ValidatedRedirectUri::validate( - &Some(to_validate.clone()), - validate_against, - OAuthClientId(0), - ) - .unwrap(); - - assert_eq!(to_validate, validated.0); - } - - #[test] - fn validate_for_invalid_uri_returns_err() { - let validate_against = vec!["https://modrinth.com/a"]; - let to_validate = "https://modrinth.com/a/b".to_string(); - - let validated = - ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); - - assert!(validated - .is_err_and(|e| matches!(e.error_type, OAuthErrorType::RedirectUriNotConfigured(_)))); - } -} diff --git a/src/auth/oauth/uris.rs b/src/auth/oauth/uris.rs new file mode 100644 index 00000000..708aa8a0 --- /dev/null +++ b/src/auth/oauth/uris.rs @@ -0,0 +1,94 @@ +use super::errors::OAuthError; +use crate::auth::oauth::OAuthErrorType; +use crate::database::models::OAuthClientId; +use serde::{Deserialize, Serialize}; + +#[derive(derive_new::new, Serialize, Deserialize)] +pub struct OAuthRedirectUris { + pub original: Option, + pub validated: ValidatedRedirectUri, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ValidatedRedirectUri(pub String); + +impl ValidatedRedirectUri { + pub fn validate<'a>( + to_validate: &Option, + validate_against: impl IntoIterator + Clone, + client_id: OAuthClientId, + ) -> Result { + if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() { + if let Some(to_validate) = to_validate { + if validate_against + .into_iter() + .any(|uri| same_uri_except_query_components(uri, to_validate)) + { + Ok(ValidatedRedirectUri(to_validate.clone())) + } else { + Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured( + to_validate.clone(), + ))) + } + } else { + Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string())) + } + } else { + Err(OAuthError::error( + OAuthErrorType::ClientMissingRedirectURI { client_id }, + )) + } + } +} + +fn same_uri_except_query_components(a: &str, b: &str) -> bool { + let mut a_components = a.split('?'); + let mut b_components = b.split('?'); + a_components.next() == b_components.next() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_for_none_returns_first_valid_uri() { + let validate_against = vec!["https://modrinth.com/a"]; + + let validated = + ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0)) + .unwrap(); + + assert_eq!(validate_against[0], validated.0); + } + + #[test] + fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() { + let validate_against = vec![ + "https://modrinth.com/a?q3=p3&q4=p4", + "https://modrinth.com/a/b/c?q1=p1&q2=p2", + ]; + let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string(); + + let validated = ValidatedRedirectUri::validate( + &Some(to_validate.clone()), + validate_against, + OAuthClientId(0), + ) + .unwrap(); + + assert_eq!(to_validate, validated.0); + } + + #[test] + fn validate_for_invalid_uri_returns_err() { + let validate_against = vec!["https://modrinth.com/a"]; + let to_validate = "https://modrinth.com/a/b".to_string(); + + let validated = + ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0)); + + assert!(validated + .is_err_and(|e| matches!(e.error_type, OAuthErrorType::RedirectUriNotConfigured(_)))); + } +} diff --git a/src/database/models/flow_item.rs b/src/database/models/flow_item.rs index 1676f1fd..fe81e4a8 100644 --- a/src/database/models/flow_item.rs +++ b/src/database/models/flow_item.rs @@ -1,5 +1,5 @@ use super::ids::*; -use crate::auth::oauth::ValidatedRedirectUri; +use crate::auth::oauth::uris::OAuthRedirectUris; use crate::database::models::DatabaseError; use crate::database::redis::RedisPool; use crate::{auth::flows::AuthProvider, models::pats::Scopes}; @@ -40,8 +40,7 @@ pub enum Flow { client_id: OAuthClientId, existing_authorization_id: Option, scopes: Scopes, - validated_redirect_uri: ValidatedRedirectUri, - original_redirect_uri: Option, + redirect_uris: OAuthRedirectUris, state: Option, }, OAuthAuthorizationCodeSupplied { diff --git a/tests/oauth.rs b/tests/oauth.rs index e4713a90..1ae59d32 100644 --- a/tests/oauth.rs +++ b/tests/oauth.rs @@ -47,9 +47,9 @@ async fn oauth_flow_happy_path() { let auth_code = query.get("code").unwrap(); let state = query.get("state").unwrap(); - let foo = query.get("foo").unwrap(); + let foo_val = query.get("foo").unwrap(); assert_eq!(state, original_state); - assert_eq!(foo, "bar"); + assert_eq!(foo_val, "bar"); // Get the token let resp = env From f03f6f5e09840434a5e7159e6f967d2f763becb6 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:39:01 -0500 Subject: [PATCH 25/72] minor name improvement --- src/queue/session.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/queue/session.rs b/src/queue/session.rs index 0b922b9c..2046d9bd 100644 --- a/src/queue/session.rs +++ b/src/queue/session.rs @@ -49,7 +49,7 @@ impl AuthQueue { std::mem::replace(&mut queue, HashMap::with_capacity(len)) } - pub async fn take_set(queue: &Mutex>) -> HashSet { + pub async fn take_hashset(queue: &Mutex>) -> HashSet { let mut queue = queue.lock().await; let len = queue.len(); @@ -58,8 +58,8 @@ impl AuthQueue { pub async fn index(&self, pool: &PgPool, redis: &RedisPool) -> Result<(), DatabaseError> { let session_queue = self.take_sessions().await; - let pat_queue = Self::take_set(&self.pat_queue).await; - let oauth_access_token_queue = Self::take_set(&self.oauth_access_token_queue).await; + let pat_queue = Self::take_hashset(&self.pat_queue).await; + let oauth_access_token_queue = Self::take_hashset(&self.oauth_access_token_queue).await; if !session_queue.is_empty() || !pat_queue.is_empty() From 022ea62930d7dd190c36b3d28c76a00bcde52090 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 14:50:18 -0500 Subject: [PATCH 26/72] Don't compile-time check the test sqlx queries --- tests/common/database.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/common/database.rs b/tests/common/database.rs index 1071c88a..70ddc766 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -150,11 +150,11 @@ impl TemporaryDatabase { if needs_update { println!("Dummy data updated, so template DB tables will be dropped and re-created"); // Drop all tables in the database so they can be re-created and later filled with updated dummy data - sqlx::query!("DROP SCHEMA public CASCADE;") + sqlx::query("DROP SCHEMA public CASCADE;") .execute(&pool) .await .unwrap(); - sqlx::query!("CREATE SCHEMA public;") + sqlx::query("CREATE SCHEMA public;") .execute(&pool) .await .unwrap(); From 6eaa35d0f5f57855a8277206417990c03aa7e4d0 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 15:59:55 -0500 Subject: [PATCH 27/72] Trying to fix db concurrency problem to get tests to pass --- tests/common/database.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/common/database.rs b/tests/common/database.rs index 70ddc766..1386a496 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -175,6 +175,7 @@ impl TemporaryDatabase { }) .await; dummy_data::add_dummy_data(&temporary_test_env).await; + temporary_test_env.db.pool.close().await; } pool.close().await; From 6bdb9c72f83d0dd665e01399d03bc28536da4627 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 17:56:32 -0500 Subject: [PATCH 28/72] Try fix from test PR --- tests/common/database.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/common/database.rs b/tests/common/database.rs index 1386a496..988cc028 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -178,9 +178,7 @@ impl TemporaryDatabase { temporary_test_env.db.pool.close().await; } pool.close().await; - - // Switch back to main database (as we cant create from template while connected to it) - let pool = PgPool::connect(url.as_str()).await.unwrap(); + drop(pool); // Create the temporary database from the template let create_db_query = format!( @@ -189,7 +187,7 @@ impl TemporaryDatabase { ); sqlx::query(&create_db_query) - .execute(&pool) + .execute(&main_pool) .await .expect("Database creation failed"); From f9246a1211c77979890d1d0a2933a3c92c46206b Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 12:32:34 -0500 Subject: [PATCH 29/72] Fixes for updated sqlx --- ...438d6222a26926f4b06865b84979fc92564ba.json | 15 ---- ...bab9da5fdd1b7ee72d411a9989deb4ee506bb.json | 15 ++++ ...7734a9af0e944f1671df71f9f4e25d835ffd9.json | 22 ++++++ ...31044f41d210dd64bbbb5e7c2347acc2304e9.json | 22 ++++++ ...ca546db44014cb9381f4f3a6d0af5d9d69511.json | 64 +++++++++++++++++ ...103cac7e4b850d446b29d2efd9757b642fc1c.json | 15 ++++ ...cff1889e5490f0d0d62170ed2b9515ffc5104.json | 47 +++++++++++++ ...3c2efd07808c4a859ab1b1e9e65e16439a8f3.json | 17 +++++ ...e26e68b10522d0f1df3f006d58f6b72be9911.json | 32 +++++++++ ...2876c5ae253df538f4cd4c3701e63137fb01b.json | 70 +++++++++++++++++++ ...2278605311a311def3cbe38846b8ca465737f.json | 19 +++++ ...23d53fcf26489a387d6e170a58a55a814201d.json | 70 +++++++++++++++++++ ...b81f2f47ccdded1e764c04d8b7651d9796ce0.json | 16 +++++ ...338a9e6a9d7763d5614066adb0255139d7956.json | 70 +++++++++++++++++++ ...7aaaaf533a4409633cf00c071049bb6816c96.json | 14 ++++ ...f501784de06acdddeed1db8f570cb04755f1a.json | 22 ++++++ ...41c55bf62928a96d9582d3e223d6473335428.json | 15 ++++ ...f3412349dd6bc5c836d2634bbcb376a6f7c12.json | 22 ++++++ ...b8f3133bea215a89964d78cb652f930465faf.json | 14 ++++ ...23f75ffdaebc76d67a5d35218fa9273d46d53.json | 17 +++++ .../models/oauth_client_authorization_item.rs | 2 +- src/database/models/oauth_client_item.rs | 4 +- src/queue/session.rs | 2 +- src/routes/v3/oauth_clients.rs | 6 +- 24 files changed, 590 insertions(+), 22 deletions(-) delete mode 100644 .sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json create mode 100644 .sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json create mode 100644 .sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json create mode 100644 .sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json create mode 100644 .sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json create mode 100644 .sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json create mode 100644 .sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json create mode 100644 .sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json create mode 100644 .sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json create mode 100644 .sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json create mode 100644 .sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json create mode 100644 .sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json create mode 100644 .sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json create mode 100644 .sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json create mode 100644 .sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json create mode 100644 .sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json create mode 100644 .sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json create mode 100644 .sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json create mode 100644 .sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json create mode 100644 .sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json diff --git a/.sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json b/.sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json deleted file mode 100644 index f4c25dc7..00000000 --- a/.sqlx/query-0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pats\n SET last_used = $2\n WHERE (id = $1)\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Int8", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "0472045549758d8eef84592908c438d6222a26926f4b06865b84979fc92564ba" -} diff --git a/.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json b/.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json new file mode 100644 index 00000000..42c7c646 --- /dev/null +++ b/.sqlx/query-2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pats\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "2040e7f0a9b66bc12dc89007b07bab9da5fdd1b7ee72d411a9989deb4ee506bb" +} diff --git a/.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json b/.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json new file mode 100644 index 00000000..bd806515 --- /dev/null +++ b/.sqlx/query-3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "3f8bd0280a59ad4561ca652cebc7734a9af0e944f1671df71f9f4e25d835ffd9" +} diff --git a/.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json b/.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json new file mode 100644 index 00000000..7348f759 --- /dev/null +++ b/.sqlx/query-49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "49d5360751072cc2cb5954cdecb31044f41d210dd64bbbb5e7c2347acc2304e9" +} diff --git a/.sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json b/.sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json new file mode 100644 index 00000000..cbc186c5 --- /dev/null +++ b/.sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json @@ -0,0 +1,64 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT \n auths.id,\n auths.client_id,\n auths.user_id,\n auths.scopes,\n auths.created,\n clients.name as client_name,\n clients.icon_url as client_icon_url,\n clients.created_by as client_created_by\n FROM oauth_client_authorizations auths\n JOIN oauth_clients clients ON clients.id = auths.client_id\n WHERE user_id=$1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "client_name", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "client_icon_url", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "client_created_by", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false + ] + }, + "hash": "4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511" +} diff --git a/.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json b/.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json new file mode 100644 index 00000000..f94f7a46 --- /dev/null +++ b/.sqlx/query-65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth_access_tokens\n SET last_used = $2\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "65c9f9cd010c14100839cd0b044103cac7e4b850d446b29d2efd9757b642fc1c" +} diff --git a/.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json b/.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json new file mode 100644 index 00000000..de61d14e --- /dev/null +++ b/.sqlx/query-65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104.json @@ -0,0 +1,47 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "65ddadc9d103ccb9d81e1f52565cff1889e5490f0d0d62170ed2b9515ffc5104" +} diff --git a/.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json b/.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json new file mode 100644 index 00000000..bbf10def --- /dev/null +++ b/.sqlx/query-68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_client_authorizations (\n id, client_id, user_id, scopes\n )\n VALUES (\n $1, $2, $3, $4\n )\n ON CONFLICT (id)\n DO UPDATE SET scopes = EXCLUDED.scopes\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "68ef15f50a067503dce124b50fb3c2efd07808c4a859ab1b1e9e65e16439a8f3" +} diff --git a/.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json b/.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json new file mode 100644 index 00000000..797214d3 --- /dev/null +++ b/.sqlx/query-6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_access_tokens (\n id, authorization_id, token_hash, scopes, last_used\n )\n VALUES (\n $1, $2, $3, $4, $5\n )\n RETURNING created, expires\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 1, + "name": "expires", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Text", + "Int8", + "Timestamptz" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "6b881555e610ddc6796cdcbfd2de26e68b10522d0f1df3f006d58f6b72be9911" +} diff --git a/.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json b/.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json new file mode 100644 index 00000000..454d523c --- /dev/null +++ b/.sqlx/query-7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n tokens.id,\n tokens.authorization_id,\n tokens.token_hash,\n tokens.scopes,\n tokens.created,\n tokens.expires,\n tokens.last_used,\n auths.client_id,\n auths.user_id\n FROM oauth_access_tokens tokens\n JOIN oauth_client_authorizations auths\n ON tokens.authorization_id = auths.id\n WHERE tokens.token_hash = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "authorization_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "token_hash", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "expires", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "last_used", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 8, + "name": "user_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false, + false + ] + }, + "hash": "7174cd941ff95260ad9c564daf92876c5ae253df538f4cd4c3701e63137fb01b" +} diff --git a/.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json b/.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json new file mode 100644 index 00000000..598dfeff --- /dev/null +++ b/.sqlx/query-93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_clients (\n id, name, icon_url, max_scopes, secret_hash, created_by\n )\n VALUES (\n $1, $2, $3, $4, $5, $6\n )\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Text", + "Text", + "Int8", + "Text", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "93f6a94a9b288916dbf9999338d2278605311a311def3cbe38846b8ca465737f" +} diff --git a/.sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json b/.sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json new file mode 100644 index 00000000..aa87c762 --- /dev/null +++ b/.sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n clients.id,\n clients.name,\n clients.icon_url,\n clients.max_scopes,\n clients.secret_hash,\n clients.created,\n clients.created_by,\n uris.uri_ids,\n uris.uri_vals\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE created_by = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "icon_url", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "max_scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "secret_hash", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "created_by", + "type_info": "Int8" + }, + { + "ordinal": 7, + "name": "uri_ids", + "type_info": "Int8Array" + }, + { + "ordinal": 8, + "name": "uri_vals", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + false, + false, + null, + null + ] + }, + "hash": "97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d" +} diff --git a/.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json b/.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json new file mode 100644 index 00000000..4c3c291b --- /dev/null +++ b/.sqlx/query-9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO oauth_client_redirect_uris (id, client_id, uri)\n SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Int8Array", + "VarcharArray" + ] + }, + "nullable": [] + }, + "hash": "9dadd6926a8429e60cb5fd53285b81f2f47ccdded1e764c04d8b7651d9796ce0" +} diff --git a/.sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json b/.sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json new file mode 100644 index 00000000..14be2209 --- /dev/null +++ b/.sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n clients.id,\n clients.name,\n clients.icon_url,\n clients.max_scopes,\n clients.secret_hash,\n clients.created,\n clients.created_by,\n uris.uri_ids,\n uris.uri_vals\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE clients.id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "icon_url", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "max_scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "secret_hash", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "created_by", + "type_info": "Int8" + }, + { + "ordinal": 7, + "name": "uri_ids", + "type_info": "Int8Array" + }, + { + "ordinal": 8, + "name": "uri_vals", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + false, + false, + null, + null + ] + }, + "hash": "b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956" +} diff --git a/.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json b/.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json new file mode 100644 index 00000000..0833383c --- /dev/null +++ b/.sqlx/query-cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_client_redirect_uris\n WHERE id IN\n (SELECT * FROM UNNEST($1::bigint[]))\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array" + ] + }, + "nullable": [] + }, + "hash": "cf84f5e2a594a90b2e7993758807aaaaf533a4409633cf00c071049bb6816c96" +} diff --git a/.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json b/.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json new file mode 100644 index 00000000..9a0a5492 --- /dev/null +++ b/.sqlx/query-d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "d0b2ddba90ce69a50d0260a191bf501784de06acdddeed1db8f570cb04755f1a" +} diff --git a/.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json b/.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json new file mode 100644 index 00000000..d80d7c90 --- /dev/null +++ b/.sqlx/query-db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_client_authorizations\n WHERE client_id=$1 AND user_id=$2\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "db1deb79fa509974f1cd68cacd541c55bf62928a96d9582d3e223d6473335428" +} diff --git a/.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json b/.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json new file mode 100644 index 00000000..5ed8687e --- /dev/null +++ b/.sqlx/query-e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "e15ff50bd75a49d50975d337f61f3412349dd6bc5c836d2634bbcb376a6f7c12" +} diff --git a/.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json b/.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json new file mode 100644 index 00000000..29200a65 --- /dev/null +++ b/.sqlx/query-e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM oauth_clients\n WHERE id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "e60f725571e7b7b716d19735ab3b8f3133bea215a89964d78cb652f930465faf" +} diff --git a/.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json b/.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json new file mode 100644 index 00000000..d01d5876 --- /dev/null +++ b/.sqlx/query-e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth_clients\n SET name = $1, icon_url = $2, max_scopes = $3\n WHERE (id = $4)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Text", + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "e8cc8895ebc8b1904a43e00f1e123f75ffdaebc76d67a5d35218fa9273d46d53" +} diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs index 87603452..befa8679 100644 --- a/src/database/models/oauth_client_authorization_item.rs +++ b/src/database/models/oauth_client_authorization_item.rs @@ -115,7 +115,7 @@ impl OAuthClientAuthorization { user_id.0, scopes.bits() as i64, ) - .execute(&mut *transaction) + .execute(&mut **transaction) .await?; Ok(()) diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index ae2f7d70..7088804e 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -129,10 +129,10 @@ impl OAuthClient { self.secret_hash, self.created_by.0 ) - .execute(&mut *transaction) + .execute(&mut **transaction) .await?; - Self::insert_redirect_uris(&self.redirect_uris, transaction).await?; + Self::insert_redirect_uris(&self.redirect_uris, &mut **transaction).await?; Ok(()) } diff --git a/src/queue/session.rs b/src/queue/session.rs index 2046d9bd..15e7731a 100644 --- a/src/queue/session.rs +++ b/src/queue/session.rs @@ -157,7 +157,7 @@ async fn process_oauth_access_tokens( &ids[..], Utc::now() ) - .execute(&mut *transaction) + .execute(&mut **transaction) .await?; Ok(()) } diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index edccf5d7..e2efd3c9 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -258,7 +258,7 @@ pub async fn oauth_client_edit( let mut transaction = pool.begin().await?; updated_client - .update_editable_fields(&mut transaction) + .update_editable_fields(&mut *transaction) .await?; if let Some(redirects) = redirect_uris { @@ -376,11 +376,11 @@ async fn edit_redirects( &mut *transaction, ) .await?; - OAuthClient::insert_redirect_uris(&redirects_to_add, &mut *transaction).await?; + OAuthClient::insert_redirect_uris(&redirects_to_add, &mut **transaction).await?; let mut redirects_to_remove = existing_client.redirect_uris.clone(); redirects_to_remove.retain(|r| !updated_redirects.contains(&r.uri)); - OAuthClient::remove_redirect_uris(redirects_to_remove.iter().map(|r| r.id), &mut *transaction) + OAuthClient::remove_redirect_uris(redirects_to_remove.iter().map(|r| r.id), &mut **transaction) .await?; Ok(()) From f486d99922c0e8efa4144da7180e44f2730ad247 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 11:03:36 -0500 Subject: [PATCH 30/72] Prevent restricted scopes from being requested or issued --- src/auth/oauth/mod.rs | 4 +++- src/queue/session.rs | 4 ++-- src/routes/v3/oauth_clients.rs | 1 + src/util/validate.rs | 12 +++++++++++ tests/common/api_v3/oauth_clients.rs | 7 ++----- tests/oauth_clients.rs | 30 +++++++++++++++++++++++++--- 6 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index a68f934d..a455b061 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -103,7 +103,7 @@ pub async fn init_oauth( }) })?; - if !client.max_scopes.contains(requested_scopes) { + if !client.max_scopes.contains(requested_scopes) || requested_scopes.is_restricted() { return Err(OAuthError::redirect( OAuthErrorType::ScopesTooBroad, &oauth_info.state, @@ -259,6 +259,8 @@ pub async fn request_token( )); } + let scopes = scopes - Scopes::restricted(); + let mut transaction = pool.begin().await.map_err(OAuthError::error)?; let token_id = generate_oauth_access_token_id(&mut transaction) .await diff --git a/src/queue/session.rs b/src/queue/session.rs index 15e7731a..ee4568a4 100644 --- a/src/queue/session.rs +++ b/src/queue/session.rs @@ -133,7 +133,7 @@ impl AuthQueue { PersonalAccessToken::clear_cache(clear_cache_pats, redis).await?; - process_oauth_access_tokens(oauth_access_token_queue, &mut transaction).await?; + update_oauth_access_token_last_used(oauth_access_token_queue, &mut transaction).await?; transaction.commit().await?; } @@ -142,7 +142,7 @@ impl AuthQueue { } } -async fn process_oauth_access_tokens( +async fn update_oauth_access_token_last_used( oauth_access_token_queue: HashSet, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result<(), DatabaseError> { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index e2efd3c9..900d5128 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -93,6 +93,7 @@ pub struct NewOAuthApp { )] pub icon_url: Option, + #[validate(custom(function = "crate::util::validate::validate_no_restricted_scopes"))] pub max_scopes: Scopes, #[validate(length(min = 1))] diff --git a/src/util/validate.rs b/src/util/validate.rs index 474bdf7a..7ebac27c 100644 --- a/src/util/validate.rs +++ b/src/util/validate.rs @@ -3,6 +3,8 @@ use lazy_static::lazy_static; use regex::Regex; use validator::{ValidationErrors, ValidationErrorsKind}; +use crate::models::pats::Scopes; + lazy_static! { pub static ref RE_URL_SAFE: Regex = Regex::new(r#"^[a-zA-Z0-9!@$()`.+,_"-]*$"#).unwrap(); } @@ -91,6 +93,16 @@ pub fn validate_url(value: &str) -> Result<(), validator::ValidationError> { Ok(()) } +pub fn validate_no_restricted_scopes(value: &Scopes) -> Result<(), validator::ValidationError> { + if value.is_restricted() { + return Err(validator::ValidationError::new( + "Restricted scopes not allowed", + )); + } + + Ok(()) +} + pub fn validate_name(value: &str) -> Result<(), validator::ValidationError> { if value.trim().is_empty() { return Err(validator::ValidationError::new( diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index c014fd14..b2251575 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -24,7 +24,7 @@ impl ApiV3 { max_scopes: Scopes, redirect_uris: Vec, pat: &str, - ) -> OAuthClientCreationResult { + ) -> ServiceResponse { let max_scopes = max_scopes.bits(); let req = TestRequest::post() .uri("/v3/oauth_app") @@ -36,10 +36,7 @@ impl ApiV3 { })) .to_request(); - let resp = self.call(req).await; - assert_status(&resp, StatusCode::OK); - - test::read_body_json(resp).await + self.call(req).await } pub async fn get_user_oauth_clients(&self, user_id: &str, pat: &str) -> Vec { diff --git a/tests/oauth_clients.rs b/tests/oauth_clients.rs index e277b267..3ffb1471 100644 --- a/tests/oauth_clients.rs +++ b/tests/oauth_clients.rs @@ -1,11 +1,15 @@ use actix_http::StatusCode; +use actix_web::test; use common::{ database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, dummy_data::DummyOAuthClientAlpha, environment::with_test_environment, get_json_val_str, }; -use labrinth::{models::pats::Scopes, routes::v3::oauth_clients::OAuthClientEdit}; +use labrinth::{ + models::{oauth_clients::OAuthClientCreationResult, pats::Scopes}, + routes::v3::oauth_clients::OAuthClientEdit, +}; use crate::common::{asserts::assert_status, database::USER_USER_ID_PARSED}; @@ -19,15 +23,17 @@ async fn can_create_edit_get_oauth_client() { "https://modrinth.com".to_string(), "https://modrinth.com/a".to_string(), ]; - let creation_result = env + let resp = env .v3 .add_oauth_client( client_name.clone(), - Scopes::all(), + Scopes::all() - Scopes::restricted(), redirect_uris.clone(), FRIEND_USER_PAT, ) .await; + assert_status(&resp, StatusCode::OK); + let creation_result: OAuthClientCreationResult = test::read_body_json(resp).await; let client_id = get_json_val_str(creation_result.client.id); let icon_url = Some("https://modrinth.com/icon".to_string()); @@ -61,6 +67,24 @@ async fn can_create_edit_get_oauth_client() { .await; } +#[actix_rt::test] +async fn create_oauth_client_with_restricted_scopes_fails() { + with_test_environment(|env| async move { + let resp = env + .v3 + .add_oauth_client( + "test_client".to_string(), + Scopes::restricted(), + vec!["https://modrinth.com".to_string()], + FRIEND_USER_PAT, + ) + .await; + + assert_status(&resp, StatusCode::BAD_REQUEST); + }) + .await; +} + #[actix_rt::test] async fn can_delete_oauth_client() { with_test_environment(|env| async move { From 723a377cf3df298506fdd926a7cb836f07c58b0e Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 14:22:50 -0500 Subject: [PATCH 31/72] Get OAuth client(s) --- src/auth/checks.rs | 37 ++++++++++++ src/auth/oauth/mod.rs | 2 +- src/database/models/mod.rs | 1 + src/database/models/oauth_client_item.rs | 58 ++++++++++++------ src/models/oauth_clients.rs | 5 ++ src/models/users.rs | 16 ----- src/routes/v3/oauth_clients.rs | 75 +++++++++++++++++++++--- tests/common/api_v3/oauth_clients.rs | 21 ++++++- tests/common/dummy_data.rs | 2 +- tests/oauth_clients.rs | 39 +++++++++++- 10 files changed, 212 insertions(+), 44 deletions(-) diff --git a/src/auth/checks.rs b/src/auth/checks.rs index b358494c..8d8fab68 100644 --- a/src/auth/checks.rs +++ b/src/auth/checks.rs @@ -8,6 +8,26 @@ use crate::routes::ApiError; use actix_web::web; use sqlx::PgPool; +pub trait ValidateAuthorized { + fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError>; +} + +pub trait ValidateAllAuthorized { + fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError>; +} + +impl<'a, T, A> ValidateAllAuthorized for T +where + T: IntoIterator, + A: ValidateAuthorized + 'a, +{ + fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError> { + self.into_iter() + .map(|c| c.validate_authorized(user_option.clone())) + .collect() + } +} + pub async fn is_authorized( project_data: &Project, user_option: &Option, @@ -156,6 +176,23 @@ pub async fn is_authorized_version( Ok(authorized) } +impl ValidateAuthorized for crate::database::models::OAuthClient { + fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError> { + if let Some(user) = user_option { + if user.role.is_mod() || user.id == self.created_by.into() { + return Ok(()); + } else { + return Err(crate::routes::ApiError::CustomAuthentication( + "You don't have sufficient permissions to interact with this OAuth application" + .to_string(), + )); + } + } + + Ok(()) + } +} + pub async fn filter_authorized_versions( versions: Vec, user_option: &Option, diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index a455b061..e232e331 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -103,7 +103,7 @@ pub async fn init_oauth( }) })?; - if !client.max_scopes.contains(requested_scopes) || requested_scopes.is_restricted() { + if !client.max_scopes.contains(requested_scopes) { return Err(OAuthError::redirect( OAuthErrorType::ScopesTooBroad, &oauth_info.state, diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index ffb2179d..5d5bc34f 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -22,6 +22,7 @@ pub mod version_item; pub use collection_item::Collection; pub use ids::*; pub use image_item::Image; +pub use oauth_client_item::OAuthClient; pub use organization_item::Organization; pub use project_item::Project; pub use team_item::Team; diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 7088804e..23be598b 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -1,9 +1,10 @@ use chrono::{DateTime, Utc}; +use futures::TryStreamExt; use itertools::Itertools; use serde::{Deserialize, Serialize}; use sha2::Digest; -use crate::models::pats::Scopes; +use crate::models::{ids::base62_impl::parse_base62, pats::Scopes}; use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId}; @@ -40,26 +41,29 @@ struct ClientQueryResult { macro_rules! select_clients_with_predicate { ($predicate:tt, $param:ident) => { + // The columns in this query have nullability type hints, because for some reason + // the combination of the JOIN and filter using ANY makes sqlx think all columns are nullable + // https://docs.rs/sqlx/latest/sqlx/macro.query.html#force-nullable sqlx::query_as!( ClientQueryResult, - " + r#" SELECT - clients.id, - clients.name, - clients.icon_url, - clients.max_scopes, - clients.secret_hash, - clients.created, - clients.created_by, - uris.uri_ids, - uris.uri_vals + clients.id as "id!", + clients.name as "name!", + clients.icon_url as "icon_url?", + clients.max_scopes as "max_scopes!", + clients.secret_hash as "secret_hash!", + clients.created as "created!", + clients.created_by as "created_by!", + uris.uri_ids as "uri_ids?", + uris.uri_vals as "uri_vals?" FROM oauth_clients clients LEFT JOIN ( SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals FROM oauth_client_redirect_uris GROUP BY client_id ) uris ON clients.id = uris.client_id - " + "# + $predicate, $param ) @@ -71,12 +75,32 @@ impl OAuthClient { id: OAuthClientId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { - let client_id_param = id.0; - let value = select_clients_with_predicate!("WHERE clients.id = $1", client_id_param) - .fetch_optional(exec) - .await?; + Ok(Self::get_many(&[id], exec).await?.into_iter().next()) + } + + pub async fn get_many_str( + ids: &[impl ToString], + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let ids = ids + .iter() + .flat_map(|x| parse_base62(&x.to_string()).map(|x| OAuthClientId(x as i64))) + .collect::>(); + Self::get_many(&ids, exec).await + } + + pub async fn get_many( + ids: &[OAuthClientId], + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let ids = ids.into_iter().map(|id| id.0).collect_vec(); + let ids_ref: &[i64] = &ids; + let results = + select_clients_with_predicate!("WHERE clients.id = ANY($1::bigint[])", ids_ref) + .fetch_all(exec) + .await?; - Ok(value.map(|r| r.into())) + Ok(results.into_iter().map(|r| r.into()).collect_vec()) } pub async fn get_all_user_clients( diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index a22eb16d..ce747b87 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -61,6 +61,11 @@ pub struct OAuthClientAuthorizationInfo { pub client_created_by: UserId, } +#[derive(Deserialize, Serialize)] +pub struct GetOAuthClientsRequest { + pub ids: Vec, +} + impl From for OAuthClient { fn from(value: crate::database::models::oauth_client_item::OAuthClient) -> Self { Self { diff --git a/src/models/users.rs b/src/models/users.rs index 19aa16d5..8a012bf1 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -89,22 +89,6 @@ impl From for User { } } -impl User { - pub fn validate_can_interact_with_oauth_client( - &self, - client_owner_id: UserId, - ) -> Result<(), crate::routes::ApiError> { - if self.role.is_mod() || self.id == client_owner_id { - Ok(()) - } else { - Err(crate::routes::ApiError::CustomAuthentication( - "You don't have sufficient permissions to interact with this OAuth application" - .to_string(), - )) - } - } -} - #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] #[serde(rename_all = "lowercase")] pub enum Role { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 900d5128..ae9a68cd 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -14,8 +14,9 @@ use sqlx::PgPool; use validator::Validate; use super::ApiError; +use crate::auth::checks::ValidateAllAuthorized; use crate::{ - auth::get_user_from_headers, + auth::{checks::ValidateAuthorized, get_user_from_headers}, database::{ models::{ generate_oauth_client_id, generate_oauth_redirect_id, @@ -25,7 +26,11 @@ use crate::{ }, redis::RedisPool, }, - models::{self, oauth_clients::OAuthClientCreationResult, pats::Scopes}, + models::{ + self, + oauth_clients::{GetOAuthClientsRequest, OAuthClientCreationResult}, + pats::Scopes, + }, queue::session::AuthQueue, routes::v2::project_creation::CreateError, util::validate::validation_errors_to_string, @@ -37,6 +42,8 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(oauth_client_create); cfg.service(oauth_client_edit); cfg.service(oauth_client_delete); + cfg.service(get_client); + cfg.service(get_clients); cfg.service(get_user_clients); cfg.service(get_user_oauth_authorizations); cfg.service(revoke_oauth_authorization); @@ -63,10 +70,10 @@ pub async fn get_user_clients( let target_user = User::get(&info.into_inner(), &**pool, &redis).await?; if let Some(target_user) = target_user { - let target_user_id: models::ids::UserId = target_user.id.into(); - current_user.validate_can_interact_with_oauth_client(target_user_id)?; - let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?; + clients + .iter() + .validate_all_authorized(Some(¤t_user))?; let response = clients .into_iter() @@ -79,6 +86,35 @@ pub async fn get_user_clients( } } +#[get("/oauth/app/{id}")] +pub async fn get_client( + req: HttpRequest, + id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let clients = get_clients_inner(&[id.into_inner()], req, pool, redis, session_queue).await?; + if let Some(client) = clients.into_iter().next() { + Ok(HttpResponse::Ok().json(client)) + } else { + Ok(HttpResponse::NotFound().body("")) + } +} + +#[get("/oauth/apps")] +pub async fn get_clients( + req: HttpRequest, + info: web::Json, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let clients = + get_clients_inner(&info.into_inner().ids, req, pool, redis, session_queue).await?; + Ok(HttpResponse::Ok().json(clients)) +} + #[derive(Deserialize, Validate)] pub struct NewOAuthApp { #[validate( @@ -174,7 +210,7 @@ pub async fn oauth_client_delete<'a>( let client = get_oauth_client_from_str_id(client_id.into_inner(), &**pool).await?; if let Some(client) = client { - current_user.validate_can_interact_with_oauth_client(client.created_by.into())?; + client.validate_authorized(Some(¤t_user))?; OAuthClient::remove(client.id, &**pool).await?; Ok(HttpResponse::NoContent().body("")) @@ -236,7 +272,7 @@ pub async fn oauth_client_edit( if let Some(existing_client) = get_oauth_client_from_str_id(client_id.into_inner(), &**pool).await? { - current_user.validate_can_interact_with_oauth_client(existing_client.created_by.into())?; + existing_client.validate_authorized(Some(¤t_user))?; let mut updated_client = existing_client.clone(); let OAuthClientEdit { @@ -386,3 +422,28 @@ async fn edit_redirects( Ok(()) } + +pub async fn get_clients_inner( + ids: &[String], + req: HttpRequest, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result, ApiError> { + let current_user = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::OAUTH_CLIENT_READ]), + ) + .await? + .1; + + let clients = OAuthClient::get_many_str(&ids, &**pool).await?; + clients + .iter() + .validate_all_authorized(Some(¤t_user))?; + + Ok(clients.into_iter().map(|c| c.into()).collect_vec()) +} diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index b2251575..5edec98a 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -5,7 +5,7 @@ use actix_web::{ }; use labrinth::{ models::{ - oauth_clients::{OAuthClient, OAuthClientAuthorizationInfo, OAuthClientCreationResult}, + oauth_clients::{GetOAuthClientsRequest, OAuthClient, OAuthClientAuthorizationInfo}, pats::Scopes, }, routes::v3::oauth_clients::OAuthClientEdit, @@ -50,6 +50,25 @@ impl ApiV3 { test::read_body_json(resp).await } + pub async fn get_oauth_client(&self, client_id: String, pat: &str) -> ServiceResponse { + let req = TestRequest::get() + .uri(&format!("/v3/oauth/app/{}", client_id)) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + + pub async fn get_oauth_clients(&self, clients: Vec, pat: &str) -> ServiceResponse { + let req = TestRequest::get() + .uri("/v3/oauth/apps") + .set_json(GetOAuthClientsRequest { ids: clients }) + .append_header((AUTHORIZATION, pat)) + .to_request(); + + self.call(req).await + } + pub async fn edit_oauth_client( &self, client_id: &str, diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs index 2ed669df..8e4b57fa 100644 --- a/tests/common/dummy_data.rs +++ b/tests/common/dummy_data.rs @@ -19,7 +19,7 @@ use super::{ request_data::get_public_project_creation_data, }; -pub const DUMMY_DATA_UPDATE: i64 = 2; +pub const DUMMY_DATA_UPDATE: i64 = 3; #[allow(dead_code)] pub const DUMMY_CATEGORIES: &[&str] = &[ diff --git a/tests/oauth_clients.rs b/tests/oauth_clients.rs index 3ffb1471..257bbc90 100644 --- a/tests/oauth_clients.rs +++ b/tests/oauth_clients.rs @@ -7,7 +7,10 @@ use common::{ get_json_val_str, }; use labrinth::{ - models::{oauth_clients::OAuthClientCreationResult, pats::Scopes}, + models::{ + oauth_clients::{OAuthClient, OAuthClientCreationResult}, + pats::Scopes, + }, routes::v3::oauth_clients::OAuthClientEdit, }; @@ -85,6 +88,40 @@ async fn create_oauth_client_with_restricted_scopes_fails() { .await; } +#[actix_rt::test] +async fn get_oauth_client_for_client_creator_succeeds() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { client_id, .. } = + env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .get_oauth_client(client_id.clone(), USER_USER_PAT) + .await; + + assert_status(&resp, StatusCode::OK); + let client: OAuthClient = test::read_body_json(resp).await; + assert_eq!(get_json_val_str(client.id), client_id); + }) + .await; +} + +#[actix_rt::test] +async fn get_oauth_client_for_unrelated_user_fails() { + with_test_environment(|env| async move { + let DummyOAuthClientAlpha { client_id, .. } = + env.dummy.as_ref().unwrap().oauth_client_alpha.clone(); + + let resp = env + .v3 + .get_oauth_client(client_id.clone(), FRIEND_USER_PAT) + .await; + + assert_status(&resp, StatusCode::UNAUTHORIZED); + }) + .await; +} + #[actix_rt::test] async fn can_delete_oauth_client() { with_test_environment(|env| async move { From bc625fe731e768bc77670d3becab79ab7bcc5a8e Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 14:39:57 -0500 Subject: [PATCH 32/72] Remove joined oauth client info from authorization returns --- .../models/oauth_client_authorization_item.rs | 70 +++++++------------ src/models/oauth_clients.rs | 28 ++++---- src/routes/v3/oauth_clients.rs | 2 +- tests/common/api_v3/oauth_clients.rs | 7 +- 4 files changed, 43 insertions(+), 64 deletions(-) diff --git a/src/database/models/oauth_client_authorization_item.rs b/src/database/models/oauth_client_authorization_item.rs index befa8679..617e6fcd 100644 --- a/src/database/models/oauth_client_authorization_item.rs +++ b/src/database/models/oauth_client_authorization_item.rs @@ -15,15 +15,24 @@ pub struct OAuthClientAuthorization { pub created: DateTime, } -pub struct OAuthClientAuthorizationWithClientInfo { - pub id: OAuthClientAuthorizationId, - pub client_id: OAuthClientId, - pub user_id: UserId, - pub scopes: Scopes, - pub created: DateTime, - pub client_name: String, - pub client_icon_url: Option, - pub client_created_by: UserId, +struct AuthorizationQueryResult { + id: i64, + client_id: i64, + user_id: i64, + scopes: i64, + created: DateTime, +} + +impl From for OAuthClientAuthorization { + fn from(value: AuthorizationQueryResult) -> Self { + OAuthClientAuthorization { + id: OAuthClientAuthorizationId(value.id), + client_id: OAuthClientId(value.client_id), + user_id: UserId(value.user_id), + scopes: Scopes::from_postgres(value.scopes), + created: value.created, + } + } } impl OAuthClientAuthorization { @@ -32,7 +41,8 @@ impl OAuthClientAuthorization { user_id: UserId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { - let value = sqlx::query!( + let value = sqlx::query_as!( + AuthorizationQueryResult, " SELECT id, client_id, user_id, scopes, created FROM oauth_client_authorizations @@ -44,32 +54,18 @@ impl OAuthClientAuthorization { .fetch_optional(exec) .await?; - Ok(value.map(|r| OAuthClientAuthorization { - id: OAuthClientAuthorizationId(r.id), - client_id: OAuthClientId(r.client_id), - user_id: UserId(r.user_id), - scopes: Scopes::from_postgres(r.scopes), - created: r.created, - })) + Ok(value.map(|r| r.into())) } pub async fn get_all_for_user( user_id: UserId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result, DatabaseError> { - let results = sqlx::query!( + ) -> Result, DatabaseError> { + let results = sqlx::query_as!( + AuthorizationQueryResult, " - SELECT - auths.id, - auths.client_id, - auths.user_id, - auths.scopes, - auths.created, - clients.name as client_name, - clients.icon_url as client_icon_url, - clients.created_by as client_created_by - FROM oauth_client_authorizations auths - JOIN oauth_clients clients ON clients.id = auths.client_id + SELECT id, client_id, user_id, scopes, created + FROM oauth_client_authorizations WHERE user_id=$1 ", user_id.0 @@ -77,19 +73,7 @@ impl OAuthClientAuthorization { .fetch_all(exec) .await?; - Ok(results - .into_iter() - .map(|r| OAuthClientAuthorizationWithClientInfo { - id: OAuthClientAuthorizationId(r.id), - client_id: OAuthClientId(r.client_id), - user_id: UserId(r.user_id), - scopes: Scopes::from_postgres(r.scopes), - created: r.created, - client_name: r.client_name, - client_icon_url: r.client_icon_url, - client_created_by: UserId(r.client_created_by), - }) - .collect_vec()) + Ok(results.into_iter().map(|r| r.into()).collect_vec()) } pub async fn upsert( diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index ce747b87..fe980064 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -5,6 +5,10 @@ use super::{ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization as DBOAuthClientAuthorization; +use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::database::models::oauth_client_item::OAuthRedirectUri as DBOAuthRedirectUri; + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(from = "Base62Id")] #[serde(into = "Base62Id")] @@ -50,15 +54,12 @@ pub struct OAuthClient { } #[derive(Deserialize, Serialize)] -pub struct OAuthClientAuthorizationInfo { +pub struct OAuthClientAuthorization { pub id: OAuthClientAuthorizationId, - pub client_id: OAuthClientId, + pub app_id: OAuthClientId, pub user_id: UserId, pub scopes: Scopes, pub created: DateTime, - pub client_name: String, - pub client_icon_url: Option, - pub client_created_by: UserId, } #[derive(Deserialize, Serialize)] @@ -66,8 +67,8 @@ pub struct GetOAuthClientsRequest { pub ids: Vec, } -impl From for OAuthClient { - fn from(value: crate::database::models::oauth_client_item::OAuthClient) -> Self { +impl From for OAuthClient { + fn from(value: DBOAuthClient) -> Self { Self { id: value.id.into(), name: value.name, @@ -79,8 +80,8 @@ impl From for OAuthClie } } -impl From for OAuthRedirectUri { - fn from(value: crate::database::models::oauth_client_item::OAuthRedirectUri) -> Self { +impl From for OAuthRedirectUri { + fn from(value: DBOAuthRedirectUri) -> Self { Self { id: value.id.into(), client_id: value.client_id.into(), @@ -89,17 +90,14 @@ impl From for OAut } } -impl From for OAuthClientAuthorizationInfo { - fn from(value: crate::database::models::oauth_client_authorization_item::OAuthClientAuthorizationWithClientInfo) -> Self { +impl From for OAuthClientAuthorization { + fn from(value: DBOAuthClientAuthorization) -> Self { Self { id: value.id.into(), - client_id: value.client_id.into(), + app_id: value.client_id.into(), user_id: value.user_id.into(), scopes: value.scopes, created: value.created, - client_name: value.client_name, - client_icon_url: value.client_icon_url, - client_created_by: value.client_created_by.into(), } } } diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index ae9a68cd..58cf53ee 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -330,7 +330,7 @@ pub async fn get_user_oauth_authorizations( let authorizations = OAuthClientAuthorization::get_all_for_user(current_user.id.into(), &**pool).await?; - let mapped: Vec = + let mapped: Vec = authorizations.into_iter().map(|a| a.into()).collect_vec(); Ok(HttpResponse::Ok().json(mapped)) diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index 5edec98a..181efa3e 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -5,7 +5,7 @@ use actix_web::{ }; use labrinth::{ models::{ - oauth_clients::{GetOAuthClientsRequest, OAuthClient, OAuthClientAuthorizationInfo}, + oauth_clients::{GetOAuthClientsRequest, OAuthClient, OAuthClientAuthorization}, pats::Scopes, }, routes::v3::oauth_clients::OAuthClientEdit, @@ -101,10 +101,7 @@ impl ApiV3 { self.call(req).await } - pub async fn get_user_oauth_authorizations( - &self, - pat: &str, - ) -> Vec { + pub async fn get_user_oauth_authorizations(&self, pat: &str) -> Vec { let req = TestRequest::get() .uri("/v3/user/oauth_authorizations") .append_header((AUTHORIZATION, pat)) From 5468b0aab81281c93ad403db88a610a5eebe2751 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 14:47:35 -0500 Subject: [PATCH 33/72] Add default conversion to OAuthError::error so we can use ? --- src/auth/oauth/errors.rs | 9 +++++++ src/auth/oauth/mod.rs | 52 +++++++++++++--------------------------- 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/src/auth/oauth/errors.rs b/src/auth/oauth/errors.rs index 486f6180..2b82da35 100644 --- a/src/auth/oauth/errors.rs +++ b/src/auth/oauth/errors.rs @@ -15,6 +15,15 @@ pub struct OAuthError { pub valid_redirect_uri: Option, } +impl From for OAuthError +where + T: Into, +{ + fn from(value: T) -> Self { + OAuthError::error(value.into()) + } +} + impl OAuthError { /// The OAuth request failed either because of an invalid redirection URI /// or before we could validate the one we were given, so return an error diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index e232e331..226b7d82 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -72,16 +72,11 @@ pub async fn init_oauth( &session_queue, Some(&[Scopes::USER_AUTH_WRITE]), ) - .await - .map_err(OAuthError::error)? + .await? .1; - let client_id: OAuthClientId = models::ids::OAuthClientId::parse(&oauth_info.client_id) - .map_err(OAuthError::error)? - .into(); - let client = DBOAuthClient::get(client_id, &**pool) - .await - .map_err(OAuthError::error)?; + let client_id: OAuthClientId = models::ids::OAuthClientId::parse(&oauth_info.client_id)?.into(); + let client = DBOAuthClient::get(client_id, &**pool).await?; if let Some(client) = client { let redirect_uri = ValidatedRedirectUri::validate( @@ -214,12 +209,8 @@ pub async fn request_token( pool: Data, redis: Data, ) -> Result { - let req_client_id = models::ids::OAuthClientId::parse(&req_params.client_id) - .map_err(OAuthError::error)? - .into(); - let client = DBOAuthClient::get(req_client_id, &**pool) - .await - .map_err(OAuthError::error)?; + let req_client_id = models::ids::OAuthClientId::parse(&req_params.client_id)?.into(); + let client = DBOAuthClient::get(req_client_id, &**pool).await?; if let Some(client) = client { authenticate_client_token_request(&req, &client)?; @@ -230,8 +221,7 @@ pub async fn request_token( |f| matches!(f, Flow::OAuthAuthorizationCodeSupplied { .. }), &redis, ) - .await - .map_err(OAuthError::error)?; + .await?; if let Some(Flow::OAuthAuthorizationCodeSupplied { user_id, client_id, @@ -261,10 +251,8 @@ pub async fn request_token( let scopes = scopes - Scopes::restricted(); - let mut transaction = pool.begin().await.map_err(OAuthError::error)?; - let token_id = generate_oauth_access_token_id(&mut transaction) - .await - .map_err(OAuthError::error)?; + let mut transaction = pool.begin().await?; + let token_id = generate_oauth_access_token_id(&mut transaction).await?; let token = generate_access_token(); let token_hash = OAuthAccessToken::hash_token(&token); let time_until_expiration = OAuthAccessToken { @@ -279,10 +267,9 @@ pub async fn request_token( user_id, } .insert(&mut *transaction) - .await - .map_err(OAuthError::error)?; + .await?; - transaction.commit().await.map_err(OAuthError::error)?; + transaction.commit().await?; // IETF RFC6749 Section 5.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.1) Ok(HttpResponse::Ok() @@ -318,8 +305,7 @@ pub async fn accept_or_reject_client_scopes( &session_queue, Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), ) - .await - .map_err(OAuthError::error)? + .await? .1; let flow = Flow::take_if( @@ -327,8 +313,7 @@ pub async fn accept_or_reject_client_scopes( |f| matches!(f, Flow::InitOAuthAppApproval { .. }), &redis, ) - .await - .map_err(OAuthError::error)?; + .await?; if let Some(Flow::InitOAuthAppApproval { user_id, client_id, @@ -343,19 +328,16 @@ pub async fn accept_or_reject_client_scopes( } if accept { - let mut transaction = pool.begin().await.map_err(OAuthError::error)?; + let mut transaction = pool.begin().await?; let auth_id = match existing_authorization_id { Some(id) => id, - None => generate_oauth_client_authorization_id(&mut transaction) - .await - .map_err(OAuthError::error)?, + None => generate_oauth_client_authorization_id(&mut transaction).await?, }; OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction) - .await - .map_err(OAuthError::error)?; + .await?; - transaction.commit().await.map_err(OAuthError::error)?; + transaction.commit().await?; init_oauth_code_flow( user_id, @@ -383,7 +365,7 @@ fn authenticate_client_token_request( req: &HttpRequest, client: &DBOAuthClient, ) -> Result<(), OAuthError> { - let client_secret = extract_authorization_header(req).map_err(OAuthError::error)?; + let client_secret = extract_authorization_header(req)?; let hashed_client_secret = DBOAuthClient::hash_secret(client_secret); if client.secret_hash != hashed_client_secret { Err(OAuthError::error( From 31787552b71c9c73ffbf530b675ffdb7109c95b8 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 15:10:05 -0500 Subject: [PATCH 34/72] Rework routes --- src/models/oauth_clients.rs | 5 ++++ src/routes/v3/oauth_clients.rs | 41 ++++++++++++++++------------ tests/common/api_v3/oauth_clients.rs | 13 +++++---- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index fe980064..58797433 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -67,6 +67,11 @@ pub struct GetOAuthClientsRequest { pub ids: Vec, } +#[derive(Deserialize, Serialize)] +pub struct DeleteOAuthClientQueryParam { + pub client_id: String, +} + impl From for OAuthClient { fn from(value: DBOAuthClient) -> Self { Self { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 58cf53ee..66251c5e 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -2,7 +2,7 @@ use std::{collections::HashSet, fmt::Display}; use actix_web::{ delete, get, patch, post, - web::{self}, + web::{self, scope}, HttpRequest, HttpResponse, }; use chrono::Utc; @@ -14,7 +14,9 @@ use sqlx::PgPool; use validator::Validate; use super::ApiError; -use crate::auth::checks::ValidateAllAuthorized; +use crate::{ + auth::checks::ValidateAllAuthorized, models::oauth_clients::DeleteOAuthClientQueryParam, +}; use crate::{ auth::{checks::ValidateAuthorized, get_user_from_headers}, database::{ @@ -39,14 +41,17 @@ use crate::{ use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; pub fn config(cfg: &mut web::ServiceConfig) { - cfg.service(oauth_client_create); - cfg.service(oauth_client_edit); - cfg.service(oauth_client_delete); - cfg.service(get_client); - cfg.service(get_clients); cfg.service(get_user_clients); - cfg.service(get_user_oauth_authorizations); - cfg.service(revoke_oauth_authorization); + cfg.service( + scope("oauth") + .service(oauth_client_create) + .service(oauth_client_edit) + .service(oauth_client_delete) + .service(get_client) + .service(get_clients) + .service(get_user_oauth_authorizations) + .service(revoke_oauth_authorization), + ); } #[get("user/{user_id}/oauth_apps")] @@ -86,7 +91,7 @@ pub async fn get_user_clients( } } -#[get("/oauth/app/{id}")] +#[get("app/{id}")] pub async fn get_client( req: HttpRequest, id: web::Path, @@ -102,7 +107,7 @@ pub async fn get_client( } } -#[get("/oauth/apps")] +#[get("apps")] pub async fn get_clients( req: HttpRequest, info: web::Json, @@ -136,7 +141,7 @@ pub struct NewOAuthApp { pub redirect_uris: Vec, } -#[post("oauth_app")] +#[post("app")] pub async fn oauth_client_create<'a>( req: HttpRequest, new_oauth_app: web::Json, @@ -190,7 +195,7 @@ pub async fn oauth_client_create<'a>( })) } -#[delete("oauth_app/{id}")] +#[delete("app/{id}")] pub async fn oauth_client_delete<'a>( req: HttpRequest, client_id: web::Path, @@ -239,7 +244,7 @@ pub struct OAuthClientEdit { pub redirect_uris: Option>, } -#[patch("oauth_app/{id}")] +#[patch("app/{id}")] pub async fn oauth_client_edit( req: HttpRequest, client_id: web::Path, @@ -310,7 +315,7 @@ pub async fn oauth_client_edit( } } -#[get("user/oauth_authorizations")] +#[get("authorizations")] pub async fn get_user_oauth_authorizations( req: HttpRequest, pool: web::Data, @@ -336,10 +341,10 @@ pub async fn get_user_oauth_authorizations( Ok(HttpResponse::Ok().json(mapped)) } -#[delete("user/oauth_authorizations/{client_id}")] +#[delete("authorizations")] pub async fn revoke_oauth_authorization( req: HttpRequest, - info: web::Path, + info: web::Query, pool: web::Data, redis: web::Data, session_queue: web::Data, @@ -354,7 +359,7 @@ pub async fn revoke_oauth_authorization( .await? .1; - let client_id = models::ids::OAuthClientId::parse(&info.into_inner())?; + let client_id = models::ids::OAuthClientId::parse(&info.into_inner().client_id)?; OAuthClientAuthorization::remove(client_id.into(), current_user.id.into(), &**pool).await?; diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index 181efa3e..61c47321 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -27,7 +27,7 @@ impl ApiV3 { ) -> ServiceResponse { let max_scopes = max_scopes.bits(); let req = TestRequest::post() - .uri("/v3/oauth_app") + .uri("/v3/oauth/app") .append_header((AUTHORIZATION, pat)) .set_json(json!({ "name": name, @@ -76,7 +76,7 @@ impl ApiV3 { pat: &str, ) -> ServiceResponse { let req = TestRequest::patch() - .uri(&format!("/v3/oauth_app/{}", client_id)) + .uri(&format!("/v3/oauth/app/{}", client_id)) .set_json(edit) .append_header((AUTHORIZATION, pat)) .to_request(); @@ -86,7 +86,7 @@ impl ApiV3 { pub async fn delete_oauth_client(&self, client_id: &str, pat: &str) -> ServiceResponse { let req = TestRequest::delete() - .uri(&format!("/v3/oauth_app/{}", client_id)) + .uri(&format!("/v3/oauth/app/{}", client_id)) .append_header((AUTHORIZATION, pat)) .to_request(); @@ -95,7 +95,10 @@ impl ApiV3 { pub async fn revoke_oauth_authorization(&self, client_id: &str, pat: &str) -> ServiceResponse { let req = TestRequest::delete() - .uri(&format!("/v3/user/oauth_authorizations/{}", client_id)) + .uri(&format!( + "/v3/oauth/authorizations?client_id={}", + urlencoding::encode(client_id) + )) .append_header((AUTHORIZATION, pat)) .to_request(); self.call(req).await @@ -103,7 +106,7 @@ impl ApiV3 { pub async fn get_user_oauth_authorizations(&self, pat: &str) -> Vec { let req = TestRequest::get() - .uri("/v3/user/oauth_authorizations") + .uri("/v3/oauth/authorizations") .append_header((AUTHORIZATION, pat)) .to_request(); let resp = self.call(req).await; From 3be6305ae1634fe77e66879193aa538f8159196c Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 15:16:59 -0500 Subject: [PATCH 35/72] Consolidate scopes into SESSION_ACCESS --- src/auth/oauth/mod.rs | 2 +- src/models/pats.rs | 22 +++------------------- src/routes/v3/oauth_clients.rs | 14 +++++++------- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 226b7d82..13163236 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -303,7 +303,7 @@ pub async fn accept_or_reject_client_scopes( &**pool, &redis, &session_queue, - Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; diff --git a/src/models/pats.rs b/src/models/pats.rs index ff63e807..07a58692 100644 --- a/src/models/pats.rs +++ b/src/models/pats.rs @@ -103,19 +103,8 @@ bitflags::bitflags! { // delete an organization const ORGANIZATION_DELETE = 1 << 38; - // create an OAuth client - const OAUTH_CLIENT_CREATE = 1 << 39; - // get OAuth clients - const OAUTH_CLIENT_READ = 1 << 40; - // modify an OAuth client - const OAUTH_CLIENT_WRITE = 1 << 41; - // delete an OAuth client - const OAUTH_CLIENT_DELETE = 1 << 42; - - // read a user's OAuth authorizations - const USER_OAUTH_AUTHORIZATIONS_READ = 1 << 43; - // write a user's OAuth authorizations - const USER_OAUTH_AUTHORIZATIONS_WRITE = 1 << 44; + // only accessible by modrinth-issued sessions + const SESSION_ACCESS = 1 << 39; const NONE = 0b0; } @@ -132,15 +121,10 @@ impl Scopes { | Scopes::PAT_DELETE | Scopes::SESSION_READ | Scopes::SESSION_DELETE + | Scopes::SESSION_ACCESS | Scopes::USER_AUTH_WRITE | Scopes::USER_DELETE | Scopes::PERFORM_ANALYTICS - | Scopes::OAUTH_CLIENT_CREATE - | Scopes::OAUTH_CLIENT_READ - | Scopes::OAUTH_CLIENT_WRITE - | Scopes::OAUTH_CLIENT_DELETE - | Scopes::USER_OAUTH_AUTHORIZATIONS_READ - | Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE } pub fn is_restricted(&self) -> bool { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 66251c5e..6ee3561b 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -67,7 +67,7 @@ pub async fn get_user_clients( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_READ]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; @@ -154,7 +154,7 @@ pub async fn oauth_client_create<'a>( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_CREATE]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; @@ -208,7 +208,7 @@ pub async fn oauth_client_delete<'a>( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_DELETE]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; @@ -258,7 +258,7 @@ pub async fn oauth_client_edit( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_WRITE]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; @@ -327,7 +327,7 @@ pub async fn get_user_oauth_authorizations( &**pool, &redis, &session_queue, - Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_READ]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; @@ -354,7 +354,7 @@ pub async fn revoke_oauth_authorization( &**pool, &redis, &session_queue, - Some(&[Scopes::USER_OAUTH_AUTHORIZATIONS_WRITE]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; @@ -440,7 +440,7 @@ pub async fn get_clients_inner( &**pool, &redis, &session_queue, - Some(&[Scopes::OAUTH_CLIENT_READ]), + Some(&[Scopes::SESSION_ACCESS]), ) .await? .1; From f26de300ad4ac0a4fec206123f5579de5ee6a403 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 15:20:56 -0500 Subject: [PATCH 36/72] Cargo sqlx prepare --- ...ca546db44014cb9381f4f3a6d0af5d9d69511.json | 64 ----------------- ...a2593a1556ef214f3bed519de6b6e21c7d477.json | 70 +++++++++++++++++++ ...23d53fcf26489a387d6e170a58a55a814201d.json | 70 ------------------- ...338a9e6a9d7763d5614066adb0255139d7956.json | 70 ------------------- ...17ef99fae60e5fbff5f0396f70787156de322.json | 46 ++++++++++++ ...549ab5c23306758168af38f955c06d251b0b7.json | 70 +++++++++++++++++++ 6 files changed, 186 insertions(+), 204 deletions(-) delete mode 100644 .sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json create mode 100644 .sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json delete mode 100644 .sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json delete mode 100644 .sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json create mode 100644 .sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json create mode 100644 .sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json diff --git a/.sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json b/.sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json deleted file mode 100644 index cbc186c5..00000000 --- a/.sqlx/query-4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511.json +++ /dev/null @@ -1,64 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT \n auths.id,\n auths.client_id,\n auths.user_id,\n auths.scopes,\n auths.created,\n clients.name as client_name,\n clients.icon_url as client_icon_url,\n clients.created_by as client_created_by\n FROM oauth_client_authorizations auths\n JOIN oauth_clients clients ON clients.id = auths.client_id\n WHERE user_id=$1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "client_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "user_id", - "type_info": "Int8" - }, - { - "ordinal": 3, - "name": "scopes", - "type_info": "Int8" - }, - { - "ordinal": 4, - "name": "created", - "type_info": "Timestamptz" - }, - { - "ordinal": 5, - "name": "client_name", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "client_icon_url", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "client_created_by", - "type_info": "Int8" - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - true, - false - ] - }, - "hash": "4eb4b52a089d2fa11b1036d7f9fca546db44014cb9381f4f3a6d0af5d9d69511" -} diff --git a/.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json b/.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json new file mode 100644 index 00000000..9be628e8 --- /dev/null +++ b/.sqlx/query-8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n clients.id as \"id!\",\n clients.name as \"name!\",\n clients.icon_url as \"icon_url?\",\n clients.max_scopes as \"max_scopes!\",\n clients.secret_hash as \"secret_hash!\",\n clients.created as \"created!\",\n clients.created_by as \"created_by!\",\n uris.uri_ids as \"uri_ids?\",\n uris.uri_vals as \"uri_vals?\"\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE created_by = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id!", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name!", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "icon_url?", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "max_scopes!", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "secret_hash!", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created!", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "created_by!", + "type_info": "Int8" + }, + { + "ordinal": 7, + "name": "uri_ids?", + "type_info": "Int8Array" + }, + { + "ordinal": 8, + "name": "uri_vals?", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + true, + false, + false, + false, + false, + null, + null + ] + }, + "hash": "8bfb350d4f539a110b05f42812ea2593a1556ef214f3bed519de6b6e21c7d477" +} diff --git a/.sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json b/.sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json deleted file mode 100644 index aa87c762..00000000 --- a/.sqlx/query-97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d.json +++ /dev/null @@ -1,70 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT\n clients.id,\n clients.name,\n clients.icon_url,\n clients.max_scopes,\n clients.secret_hash,\n clients.created,\n clients.created_by,\n uris.uri_ids,\n uris.uri_vals\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE created_by = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "name", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "icon_url", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "max_scopes", - "type_info": "Int8" - }, - { - "ordinal": 4, - "name": "secret_hash", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "created", - "type_info": "Timestamptz" - }, - { - "ordinal": 6, - "name": "created_by", - "type_info": "Int8" - }, - { - "ordinal": 7, - "name": "uri_ids", - "type_info": "Int8Array" - }, - { - "ordinal": 8, - "name": "uri_vals", - "type_info": "TextArray" - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - false, - true, - false, - false, - false, - false, - null, - null - ] - }, - "hash": "97f5b001699f8de4e80bdb5680f23d53fcf26489a387d6e170a58a55a814201d" -} diff --git a/.sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json b/.sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json deleted file mode 100644 index 14be2209..00000000 --- a/.sqlx/query-b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956.json +++ /dev/null @@ -1,70 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT\n clients.id,\n clients.name,\n clients.icon_url,\n clients.max_scopes,\n clients.secret_hash,\n clients.created,\n clients.created_by,\n uris.uri_ids,\n uris.uri_vals\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE clients.id = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "name", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "icon_url", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "max_scopes", - "type_info": "Int8" - }, - { - "ordinal": 4, - "name": "secret_hash", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "created", - "type_info": "Timestamptz" - }, - { - "ordinal": 6, - "name": "created_by", - "type_info": "Int8" - }, - { - "ordinal": 7, - "name": "uri_ids", - "type_info": "Int8Array" - }, - { - "ordinal": 8, - "name": "uri_vals", - "type_info": "TextArray" - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - false, - true, - false, - false, - false, - false, - null, - null - ] - }, - "hash": "b0699eef7dbf20f3ea457d1fe16338a9e6a9d7763d5614066adb0255139d7956" -} diff --git a/.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json b/.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json new file mode 100644 index 00000000..722f0589 --- /dev/null +++ b/.sqlx/query-c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, client_id, user_id, scopes, created\n FROM oauth_client_authorizations\n WHERE user_id=$1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "client_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "scopes", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "created", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "c5319631c46ffa46e218fcf308f17ef99fae60e5fbff5f0396f70787156de322" +} diff --git a/.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json b/.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json new file mode 100644 index 00000000..08fa78f0 --- /dev/null +++ b/.sqlx/query-fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7.json @@ -0,0 +1,70 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n clients.id as \"id!\",\n clients.name as \"name!\",\n clients.icon_url as \"icon_url?\",\n clients.max_scopes as \"max_scopes!\",\n clients.secret_hash as \"secret_hash!\",\n clients.created as \"created!\",\n clients.created_by as \"created_by!\",\n uris.uri_ids as \"uri_ids?\",\n uris.uri_vals as \"uri_vals?\"\n FROM oauth_clients clients\n LEFT JOIN (\n SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals\n FROM oauth_client_redirect_uris\n GROUP BY client_id\n ) uris ON clients.id = uris.client_id\n WHERE clients.id = ANY($1::bigint[])", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id!", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name!", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "icon_url?", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "max_scopes!", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "secret_hash!", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "created!", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "created_by!", + "type_info": "Int8" + }, + { + "ordinal": 7, + "name": "uri_ids?", + "type_info": "Int8Array" + }, + { + "ordinal": 8, + "name": "uri_vals?", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8Array" + ] + }, + "nullable": [ + true, + true, + true, + true, + true, + true, + true, + null, + null + ] + }, + "hash": "fdfb2433a8e407d42cec1791d67549ab5c23306758168af38f955c06d251b0b7" +} From 18fd30eb43517f5d88d6fefe9fd943e902ceb294 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Thu, 26 Oct 2023 18:46:15 -0500 Subject: [PATCH 37/72] Parse to OAuthClientId automatically through serde and actix --- src/auth/oauth/mod.rs | 14 +++++------ src/database/models/oauth_client_item.rs | 15 +----------- src/models/ids.rs | 11 --------- src/models/oauth_clients.rs | 4 +-- src/routes/v3/oauth_clients.rs | 31 +++++++++--------------- tests/common/api_v3/oauth.rs | 16 +----------- tests/common/api_v3/oauth_clients.rs | 14 ++--------- tests/common/dummy_data.rs | 4 +++ 8 files changed, 28 insertions(+), 81 deletions(-) diff --git a/src/auth/oauth/mod.rs b/src/auth/oauth/mod.rs index 13163236..0d64b53f 100644 --- a/src/auth/oauth/mod.rs +++ b/src/auth/oauth/mod.rs @@ -42,7 +42,7 @@ pub fn config(cfg: &mut ServiceConfig) { #[derive(Serialize, Deserialize)] pub struct OAuthInit { - pub client_id: String, + pub client_id: OAuthClientId, pub redirect_uri: Option, pub scope: Option, pub state: Option, @@ -75,7 +75,7 @@ pub async fn init_oauth( .await? .1; - let client_id: OAuthClientId = models::ids::OAuthClientId::parse(&oauth_info.client_id)?.into(); + let client_id = oauth_info.client_id; let client = DBOAuthClient::get(client_id, &**pool).await?; if let Some(client) = client { @@ -189,7 +189,7 @@ pub struct TokenRequest { pub grant_type: String, pub code: String, pub redirect_uri: Option, - pub client_id: String, + pub client_id: models::ids::OAuthClientId, } #[derive(Serialize, Deserialize)] @@ -209,8 +209,8 @@ pub async fn request_token( pool: Data, redis: Data, ) -> Result { - let req_client_id = models::ids::OAuthClientId::parse(&req_params.client_id)?.into(); - let client = DBOAuthClient::get(req_client_id, &**pool).await?; + let req_client_id = req_params.client_id; + let client = DBOAuthClient::get(req_client_id.into(), &**pool).await?; if let Some(client) = client { authenticate_client_token_request(&req, &client)?; @@ -231,7 +231,7 @@ pub async fn request_token( }) = flow { // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 - if req_client_id != client_id { + if req_client_id != client_id.into() { return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient)); } @@ -285,7 +285,7 @@ pub async fn request_token( } } else { Err(OAuthError::error(OAuthErrorType::InvalidClientId( - req_client_id, + req_client_id.into(), ))) } } diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 23be598b..48907bb8 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -1,12 +1,10 @@ use chrono::{DateTime, Utc}; -use futures::TryStreamExt; use itertools::Itertools; use serde::{Deserialize, Serialize}; use sha2::Digest; -use crate::models::{ids::base62_impl::parse_base62, pats::Scopes}; - use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId}; +use crate::models::pats::Scopes; #[derive(Deserialize, Serialize, Clone, Debug)] pub struct OAuthRedirectUri { @@ -78,17 +76,6 @@ impl OAuthClient { Ok(Self::get_many(&[id], exec).await?.into_iter().next()) } - pub async fn get_many_str( - ids: &[impl ToString], - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result, DatabaseError> { - let ids = ids - .iter() - .flat_map(|x| parse_base62(&x.to_string()).map(|x| OAuthClientId(x as i64))) - .collect::>(); - Self::get_many(&ids, exec).await - } - pub async fn get_many( ids: &[OAuthClientId], exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, diff --git a/src/models/ids.rs b/src/models/ids.rs index 6c8bf79a..8cea089f 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -105,16 +105,6 @@ macro_rules! impl_base62_display { } impl_base62_display!(Base62Id); -macro_rules! impl_base62_parse { - ($struct:ty) => { - impl $struct { - pub fn parse(value: &str) -> Result { - Ok(Self(base62_impl::parse_base62(value)?)) - } - } - }; -} - macro_rules! base62_id_impl { ($struct:ty, $cons:expr) => { from_base62id!($struct, $cons;); @@ -135,7 +125,6 @@ base62_id_impl!(SessionId, SessionId); base62_id_impl!(PatId, PatId); base62_id_impl!(ImageId, ImageId); base62_id_impl!(OAuthClientId, OAuthClientId); -impl_base62_parse!(OAuthClientId); base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId); base62_id_impl!(OAuthClientAuthorizationId, OAuthClientAuthorizationId); diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index 58797433..fd8a8c6d 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -64,12 +64,12 @@ pub struct OAuthClientAuthorization { #[derive(Deserialize, Serialize)] pub struct GetOAuthClientsRequest { - pub ids: Vec, + pub ids: Vec, } #[derive(Deserialize, Serialize)] pub struct DeleteOAuthClientQueryParam { - pub client_id: String, + pub client_id: OAuthClientId, } impl From for OAuthClient { diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 6ee3561b..59d3f68c 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -39,6 +39,7 @@ use crate::{ }; use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient; +use crate::models::ids::OAuthClientId as ApiOAuthClientId; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service(get_user_clients); @@ -94,7 +95,7 @@ pub async fn get_user_clients( #[get("app/{id}")] pub async fn get_client( req: HttpRequest, - id: web::Path, + id: web::Path, pool: web::Data, redis: web::Data, session_queue: web::Data, @@ -198,7 +199,7 @@ pub async fn oauth_client_create<'a>( #[delete("app/{id}")] pub async fn oauth_client_delete<'a>( req: HttpRequest, - client_id: web::Path, + client_id: web::Path, pool: web::Data, redis: web::Data, session_queue: web::Data, @@ -213,7 +214,7 @@ pub async fn oauth_client_delete<'a>( .await? .1; - let client = get_oauth_client_from_str_id(client_id.into_inner(), &**pool).await?; + let client = OAuthClient::get(client_id.into_inner().into(), &**pool).await?; if let Some(client) = client { client.validate_authorized(Some(¤t_user))?; OAuthClient::remove(client.id, &**pool).await?; @@ -247,7 +248,7 @@ pub struct OAuthClientEdit { #[patch("app/{id}")] pub async fn oauth_client_edit( req: HttpRequest, - client_id: web::Path, + client_id: web::Path, client_updates: web::Json, pool: web::Data, redis: web::Data, @@ -274,9 +275,7 @@ pub async fn oauth_client_edit( return Err(ApiError::InvalidInput("No changes provided".to_string())); } - if let Some(existing_client) = - get_oauth_client_from_str_id(client_id.into_inner(), &**pool).await? - { + if let Some(existing_client) = OAuthClient::get(client_id.into_inner().into(), &**pool).await? { existing_client.validate_authorized(Some(¤t_user))?; let mut updated_client = existing_client.clone(); @@ -359,9 +358,8 @@ pub async fn revoke_oauth_authorization( .await? .1; - let client_id = models::ids::OAuthClientId::parse(&info.into_inner().client_id)?; - - OAuthClientAuthorization::remove(client_id.into(), current_user.id.into(), &**pool).await?; + OAuthClientAuthorization::remove(info.client_id.into(), current_user.id.into(), &**pool) + .await?; Ok(HttpResponse::Ok().body("")) } @@ -374,14 +372,6 @@ fn generate_oauth_client_secret() -> String { .collect::() } -async fn get_oauth_client_from_str_id( - client_id: String, - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, -) -> Result, ApiError> { - let client_id = models::ids::OAuthClientId::parse(&client_id)?.into(); - Ok(OAuthClient::get(client_id, exec).await?) -} - async fn create_redirect_uris( uri_strings: impl IntoIterator, client_id: OAuthClientId, @@ -429,7 +419,7 @@ async fn edit_redirects( } pub async fn get_clients_inner( - ids: &[String], + ids: &[ApiOAuthClientId], req: HttpRequest, pool: web::Data, redis: web::Data, @@ -445,7 +435,8 @@ pub async fn get_clients_inner( .await? .1; - let clients = OAuthClient::get_many_str(&ids, &**pool).await?; + let ids: Vec = ids.into_iter().map(|i| (*i).into()).collect(); + let clients = OAuthClient::get_many(&ids, &**pool).await?; clients .iter() .validate_all_authorized(Some(¤t_user))?; diff --git a/tests/common/api_v3/oauth.rs b/tests/common/api_v3/oauth.rs index 7778695b..6212dffa 100644 --- a/tests/common/api_v3/oauth.rs +++ b/tests/common/api_v3/oauth.rs @@ -93,26 +93,12 @@ impl ApiV3 { grant_type: "authorization_code".to_string(), code: auth_code, redirect_uri: original_redirect_uri, - client_id, + client_id: serde_json::from_str(&format!("\"{}\"", client_id)).unwrap(), }) .to_request(), ) .await } - - pub async fn get_oauth_access_token( - &self, - auth_code: String, - original_redirect_uri: Option, - client_id: String, - client_secret: &str, - ) -> String { - let response = self - .oauth_token(auth_code, original_redirect_uri, client_id, client_secret) - .await; - let token_resp: TokenResponse = test::read_body_json(response).await; - token_resp.access_token - } } pub fn generate_authorize_uri( diff --git a/tests/common/api_v3/oauth_clients.rs b/tests/common/api_v3/oauth_clients.rs index 61c47321..a08ad1d6 100644 --- a/tests/common/api_v3/oauth_clients.rs +++ b/tests/common/api_v3/oauth_clients.rs @@ -5,7 +5,7 @@ use actix_web::{ }; use labrinth::{ models::{ - oauth_clients::{GetOAuthClientsRequest, OAuthClient, OAuthClientAuthorization}, + oauth_clients::{OAuthClient, OAuthClientAuthorization}, pats::Scopes, }, routes::v3::oauth_clients::OAuthClientEdit, @@ -59,16 +59,6 @@ impl ApiV3 { self.call(req).await } - pub async fn get_oauth_clients(&self, clients: Vec, pat: &str) -> ServiceResponse { - let req = TestRequest::get() - .uri("/v3/oauth/apps") - .set_json(GetOAuthClientsRequest { ids: clients }) - .append_header((AUTHORIZATION, pat)) - .to_request(); - - self.call(req).await - } - pub async fn edit_oauth_client( &self, client_id: &str, @@ -76,7 +66,7 @@ impl ApiV3 { pat: &str, ) -> ServiceResponse { let req = TestRequest::patch() - .uri(&format!("/v3/oauth/app/{}", client_id)) + .uri(&format!("/v3/oauth/app/{}", urlencoding::encode(client_id))) .set_json(edit) .append_header((AUTHORIZATION, pat)) .to_request(); diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs index 8e4b57fa..ed3b7f08 100644 --- a/tests/common/dummy_data.rs +++ b/tests/common/dummy_data.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +use actix_http::StatusCode; use actix_web::test::{self, TestRequest}; use labrinth::{ models::projects::Project, @@ -13,6 +14,7 @@ use crate::common::{actix::AppendsMultipart, database::USER_USER_PAT}; use super::{ actix::{MultipartSegment, MultipartSegmentData}, + asserts::assert_status, database::USER_USER_ID, environment::TestEnvironment, get_json_val_str, @@ -324,6 +326,7 @@ pub async fn get_project_beta(test_env: &TestEnvironment) -> (Project, Version) .append_header(("Authorization", USER_USER_PAT)) .to_request(); let resp = test_env.call(req).await; + assert_status(&resp, StatusCode::OK); let project: Project = test::read_body_json(resp).await; // Get project's versions @@ -332,6 +335,7 @@ pub async fn get_project_beta(test_env: &TestEnvironment) -> (Project, Version) .append_header(("Authorization", USER_USER_PAT)) .to_request(); let resp = test_env.call(req).await; + assert_status(&resp, StatusCode::OK); let versions: Vec = test::read_body_json(resp).await; let version = versions.into_iter().next().unwrap(); From 54a1c7c41cd5fc66df2637e2d3bc4dc67bebb446 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Fri, 20 Oct 2023 18:14:32 -0500 Subject: [PATCH 38/72] First steps --- ...20223602_user_and_organization_follows.sql | 12 +++ src/database/models/creator_follows.rs | 79 +++++++++++++++++++ src/database/models/mod.rs | 1 + src/database/models/project_item.rs | 10 --- 4 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 migrations/20231020223602_user_and_organization_follows.sql create mode 100644 src/database/models/creator_follows.rs diff --git a/migrations/20231020223602_user_and_organization_follows.sql b/migrations/20231020223602_user_and_organization_follows.sql new file mode 100644 index 00000000..742adaab --- /dev/null +++ b/migrations/20231020223602_user_and_organization_follows.sql @@ -0,0 +1,12 @@ +CREATE TABLE user_follows( + follower_id bigint NOT NULL REFERENCES users ON DELETE CASCADE, + target_id bigint NOT NULL REFERENCES users (id) ON DELETE CASCADE, + created timestamptz DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (follower_id, target_id) +); +CREATE TABLE organization_follows( + follower_id bigint NOT NULL REFERENCES users ON DELETE CASCADE, + target_id bigint NOT NULL REFERENCES organizations (id) ON DELETE CASCADE, + created timestamptz DEFAULT CURRENT_TIMESTAMP NOT NULL, + PRIMARY KEY (follower_id, target_id) +); \ No newline at end of file diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs new file mode 100644 index 00000000..f53d6c9e --- /dev/null +++ b/src/database/models/creator_follows.rs @@ -0,0 +1,79 @@ +use itertools::Itertools; + +use super::{OrganizationId, UserId}; +use crate::database::models::DatabaseError; + +pub struct UserFollow { + follower_id: UserId, + target_id: UserId, +} + +pub struct OrganizationFollow { + follower_id: UserId, + target_id: OrganizationId, +} + +impl UserFollow { + pub async fn insert( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + INSERT INTO user_follows (follower_id, target_id) VALUES ($1, $2) + ", + self.follower_id.0, + self.target_id.0 + ) + .execute(exec) + .await?; + + Ok(()) + } + + pub async fn get_follows_targeting( + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = sqlx::query!( + " + SELECT follower_id, target_id FROM user_follows + WHERE target_id=$1 + ", + user_id.0 + ) + .fetch_all(exec) + .await?; + + Ok(res + .into_iter() + .map(|r| UserFollow { + follower_id: UserId(r.follower_id), + target_id: UserId(r.target_id), + }) + .collect_vec()) + } + + pub async fn get_follows_from( + user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = sqlx::query!( + " + SELECT follower_id, target_id FROM user_follows + WHERE follower_id=$1 + ", + user_id.0 + ) + .fetch_all(exec) + .await?; + + Ok(res + .into_iter() + .map(|r| UserFollow { + follower_id: UserId(r.follower_id), + target_id: UserId(r.target_id), + }) + .collect_vec()) + } +} diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index 5d5bc34f..217c436a 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -2,6 +2,7 @@ use thiserror::Error; pub mod categories; pub mod collection_item; +pub mod creator_follows; pub mod flow_item; pub mod ids; pub mod image_item; diff --git a/src/database/models/project_item.rs b/src/database/models/project_item.rs index 365dd473..a3fddec2 100644 --- a/src/database/models/project_item.rs +++ b/src/database/models/project_item.rs @@ -348,16 +348,6 @@ impl Project { if let Some(project) = project { Project::clear_cache(id, project.inner.slug, Some(true), redis).await?; - sqlx::query!( - " - DELETE FROM mod_follows - WHERE mod_id = $1 - ", - id as ProjectId - ) - .execute(&mut **transaction) - .await?; - sqlx::query!( " DELETE FROM mods_gallery From 20d08f5761497ea1e06cba7e35a01b08ef545c66 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 23 Oct 2023 13:05:35 -0500 Subject: [PATCH 39/72] Some progress toward user following before needing to upgrade sqlx --- src/database/models/creator_follows.rs | 27 ++++++-- src/routes/v3/mod.rs | 3 + src/routes/v3/users.rs | 86 ++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 src/routes/v3/users.rs diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index f53d6c9e..32bf1c94 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -4,13 +4,13 @@ use super::{OrganizationId, UserId}; use crate::database::models::DatabaseError; pub struct UserFollow { - follower_id: UserId, - target_id: UserId, + pub follower_id: UserId, + pub target_id: UserId, } pub struct OrganizationFollow { - follower_id: UserId, - target_id: OrganizationId, + pub follower_id: UserId, + pub target_id: OrganizationId, } impl UserFollow { @@ -76,4 +76,23 @@ impl UserFollow { }) .collect_vec()) } + + pub async fn unfollow( + follower_id: UserId, + target_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " + DELETE FROM user_follows + WHERE follower_id=$1 AND target_id=$2 + ", + follower_id.0, + target_id.0, + ) + .execute(exec) + .await?; + + Ok(()) + } } diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index d90429c2..59793b8a 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -1,3 +1,5 @@ +pub mod users; + pub use super::ApiError; use crate::{auth::oauth, util::cors::default_cors}; use actix_web::{web, HttpResponse}; @@ -10,6 +12,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { web::scope("v3") .wrap(default_cors()) .route("", web::get().to(hello_world)) + .configure(users::config) .configure(oauth::config) .configure(oauth_clients::config), ); diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs new file mode 100644 index 00000000..02d4dcae --- /dev/null +++ b/src/routes/v3/users.rs @@ -0,0 +1,86 @@ +use crate::{ + auth::get_user_from_headers, + database::{self, models::DatabaseError, redis::RedisPool}, + models::pats::Scopes, + queue::session::AuthQueue, + routes::ApiError, +}; +use actix_web::{ + delete, post, + web::{self}, + HttpRequest, HttpResponse, +}; +use sqlx::PgPool; + +use database::models as db_models; +use database::models::creator_follows::UserFollow as DBUserFollow; + +pub fn config(cfg: &mut web::ServiceConfig) { + cfg.service(web::scope("user").service(user_follow)); +} + +#[post("{id}/follow")] +pub async fn user_follow( + req: HttpRequest, + target_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), + ) + .await?; + + let target_user = db_models::User::get(&target_id.into_inner(), &**pool, &redis) + .await? + .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; + + DBUserFollow { + follower_id: current_user.id.into(), + target_id: target_user.id, + } + .insert(&**pool) + .await + .map_err(|e| match e { + DatabaseError::Database(e) + if e.as_database_error() + .is_some_and(|e| /*e.is_unique_violation()*/todo!()) => + { + ApiError::InvalidInput("You are already following this user!".to_string()) + } + e => e.into(), + })?; + + Ok(HttpResponse::NoContent().body("")) +} + +#[delete("{id}/follow")] +pub async fn user_unfollow( + req: HttpRequest, + target_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), + ) + .await?; + + let target_user = db_models::User::get(&target_id.into_inner(), &**pool, &redis) + .await? + .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; + + DBUserFollow::unfollow(current_user.id.into(), target_user.id, &**pool).await?; + + Ok(HttpResponse::NoContent().body("")) +} From e3ce8d2cd8251625b6a06708b3f479339996edd3 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 23 Oct 2023 16:27:18 -0500 Subject: [PATCH 40/72] little things before starting on feeds --- src/database/models/creator_follows.rs | 6 +++--- src/routes/v3/users.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index 32bf1c94..2139289f 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -31,8 +31,8 @@ impl UserFollow { Ok(()) } - pub async fn get_follows_targeting( - user_id: UserId, + pub async fn get_followers( + target_id: UserId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { let res = sqlx::query!( @@ -40,7 +40,7 @@ impl UserFollow { SELECT follower_id, target_id FROM user_follows WHERE target_id=$1 ", - user_id.0 + target_id.0 ) .fetch_all(exec) .await?; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 02d4dcae..0cee2be5 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -49,7 +49,7 @@ pub async fn user_follow( .map_err(|e| match e { DatabaseError::Database(e) if e.as_database_error() - .is_some_and(|e| /*e.is_unique_violation()*/todo!()) => + .is_some_and(|e| e.is_unique_violation()) => { ApiError::InvalidInput("You are already following this user!".to_string()) } From 20cef003ec9f8b19694c5ebd593cd336017c9f30 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 23 Oct 2023 18:46:45 -0500 Subject: [PATCH 41/72] Starting work towards feeds (mostly dealing with dynamic DB stuff) --- migrations/20231023212957_feeds.sql | 11 ++ src/database/models/event_item.rs | 211 ++++++++++++++++++++++++++++ src/database/models/ids.rs | 69 +++++++++ src/database/models/mod.rs | 9 ++ 4 files changed, 300 insertions(+) create mode 100644 migrations/20231023212957_feeds.sql create mode 100644 src/database/models/event_item.rs diff --git a/migrations/20231023212957_feeds.sql b/migrations/20231023212957_feeds.sql new file mode 100644 index 00000000..7738659d --- /dev/null +++ b/migrations/20231023212957_feeds.sql @@ -0,0 +1,11 @@ +CREATE TYPE event_type AS ENUM ('project_created'); +CREATE TYPE id_type AS ENUM ('project_id', 'user_id', 'organization_id'); +CREATE TYPE dynamic_id AS (id bigint, type id_type); +CREATE TABLE events( + id bigint NOT NULL PRIMARY KEY, + target_id dynamic_id NOT NULL, + triggerer_id dynamic_id NULL, + type event_type NOT NULL, + metadata jsonb NULL, + created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs new file mode 100644 index 00000000..10dadfd0 --- /dev/null +++ b/src/database/models/event_item.rs @@ -0,0 +1,211 @@ +use std::convert::{TryFrom, TryInto}; + +use super::{ + dynamic::{DynamicId, IdType}, + DatabaseError, EventId, OrganizationId, ProjectId, UserId, +}; +use chrono::{DateTime, Utc}; +use itertools::Itertools; +use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; + +#[derive(sqlx::Type)] +#[sqlx(type_name = "event_type", rename_all = "snake_case")] +pub enum EventType { + ProjectCreated, +} + +pub enum CreatorId { + User(UserId), + Organization(OrganizationId), +} + +pub enum EventData { + ProjectCreated { + project_id: ProjectId, + creator_id: CreatorId, + }, +} + +pub struct Event { + pub id: EventId, + pub event_data: EventData, +} + +impl From for DynamicId { + fn from(value: CreatorId) -> Self { + match value { + CreatorId::User(user_id) => user_id.into(), + CreatorId::Organization(organization_id) => organization_id.into(), + } + } +} + +impl TryFrom for CreatorId { + type Error = DatabaseError; + + fn try_from(value: DynamicId) -> Result { + match value.id_type { + IdType::UserId => Ok(CreatorId::User(value.try_into()?)), + _ => Ok(CreatorId::Organization(value.try_into()?)), + } + } +} + +impl From for RawEvent { + fn from(value: Event) -> Self { + match value.event_data { + EventData::ProjectCreated { + project_id, + creator_id, + } => RawEvent { + id: value.id, + target_id: project_id.into(), + triggerer_id: Some(creator_id.into()), + event_type: EventType::ProjectCreated, + metadata: None, + created: None, + }, + } + } +} + +impl TryFrom for Event { + type Error = DatabaseError; + + fn try_from(value: RawEvent) -> Result { + Ok(match value.event_type { + EventType::ProjectCreated => Event { + id: value.id, + event_data: EventData::ProjectCreated { + project_id: value.target_id.try_into()?, + creator_id: value.triggerer_id.map_or_else( + || { + Err(DatabaseError::UnexpectedNull( + "triggerer_id should not be null for project creation".to_string(), + )) + }, + |v| v.try_into(), + )?, + }, + }, + }) + } +} + +struct RawEvent { + pub id: EventId, + pub target_id: DynamicId, + pub triggerer_id: Option, + pub event_type: EventType, + // #[serde::serde(flatten)] //TODO: is this necessary? + pub metadata: Option, + pub created: Option>, +} + +impl PgHasArrayType for EventType { + fn array_type_info() -> sqlx::postgres::PgTypeInfo { + PgTypeInfo::with_name("event_type") + } +} + +impl PgHasArrayType for DynamicId { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("dynamic_id") + } +} + +impl Event { + pub async fn insert_many( + events: Vec, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + RawEvent::insert_many( + events.into_iter().map(|e| e.into()).collect_vec(), + transaction, + ) + .await + } + + pub async fn get_triggerer_feed( + triggerer_id: DynamicId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + Ok(sqlx::query_as!( + RawEvent, + r#" + SELECT + id, + target_id as "target_id: _", + triggerer_id as "triggerer_id: _", + type as "event_type: _", + metadata, + created + FROM events + WHERE triggerer_id=$1 + "#, + triggerer_id as DynamicId + ) + .fetch_all(exec) + .await? + .into_iter() + .map(|r| r.try_into()) + .collect::, _>>()?) + } +} + +impl RawEvent { + pub async fn insert_many( + events: Vec, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + let (ids, target_ids, triggerer_ids, event_types, metadata): ( + Vec<_>, + Vec<_>, + Vec<_>, + Vec<_>, + Vec<_>, + ) = events + .into_iter() + .map(|e| { + ( + e.id.0, + e.target_id, + e.triggerer_id, + e.event_type, + e.metadata, + ) + }) + .multiunzip(); + sqlx::query!( + " + INSERT INTO events ( + id, + target_id.id, + target_id.type, + triggerer_id.id, + triggerer_id.type, + type, + metadata + ) + SELECT * FROM UNNEST ( + $1::bigint[], + $2::dynamic_id[], + $3::dynamic_id[], + $4::event_type[], + $5::jsonb[] + ) + ", + &ids[..], + &target_ids[..] as &[DynamicId], + &triggerer_ids[..] as &[Option], + &event_types[..] as &[EventType], + &metadata[..] as &[Option] + ) + .execute(&mut **transaction) + .await?; + + Ok(()) + } + + // pub async fn get_triggerer_feed(triggerer_id: DynamicId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,) -> Result< +} diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index b8953462..d864b224 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -183,6 +183,13 @@ generate_ids!( "SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)", OAuthAccessTokenId ); +generate_ids!( + pub generate_event_id, + EventId, + 8, + "SELECT EXISTS(SELECT 1 FROM events WHERE id=$1)", + EventId +); #[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)] #[sqlx(transparent)] @@ -286,6 +293,16 @@ pub struct OAuthRedirectUriId(pub i64); #[sqlx(transparent)] pub struct OAuthAccessTokenId(pub i64); +#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)] +#[sqlx(transparent)] +pub struct EventId(pub i64); + +impl From for EventId { + fn from(value: i64) -> Self { + EventId(value) + } +} + use crate::models::ids; impl From for ProjectId { @@ -408,6 +425,7 @@ impl From for ids::PatId { ids::PatId(id.0 as u64) } } + impl From for ids::OAuthClientId { fn from(id: OAuthClientId) -> Self { ids::OAuthClientId(id.0 as u64) @@ -428,3 +446,54 @@ impl From for ids::OAuthClientAuthorizationId { ids::OAuthClientAuthorizationId(id.0 as u64) } } + +pub mod dynamic { + use super::*; + + #[derive(sqlx::Type, PartialEq, Debug)] + #[sqlx(type_name = "id_type", rename_all = "snake_case")] + pub enum IdType { + ProjectId, + UserId, + OrganizationId, + } + + #[derive(sqlx::Type)] + #[sqlx(type_name = "dynamic_id")] + pub struct DynamicId { + pub id: i64, + pub id_type: IdType, + } + + macro_rules! from_static_impl { + ($struct:ident, $variant:expr) => { + impl From<$struct> for DynamicId { + fn from(value: $struct) -> Self { + DynamicId { + id: value.0, + id_type: $variant, + } + } + } + + impl std::convert::TryFrom for $struct { + type Error = DatabaseError; + + fn try_from(value: DynamicId) -> Result { + if value.id_type == $variant { + return Ok($struct(value.id)); + } + + Err(DatabaseError::DynamicIdConversionError { + expected: $variant, + actual: value.id_type, + }) + } + } + }; + } + + from_static_impl!(ProjectId, IdType::ProjectId); + from_static_impl!(UserId, IdType::UserId); + from_static_impl!(OrganizationId, IdType::OrganizationId); +} diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index 217c436a..41ff5fa8 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -3,6 +3,7 @@ use thiserror::Error; pub mod categories; pub mod collection_item; pub mod creator_follows; +pub mod event_item; pub mod flow_item; pub mod ids; pub mod image_item; @@ -32,10 +33,18 @@ pub use thread_item::{Thread, ThreadMessage}; pub use user_item::User; pub use version_item::Version; +use self::dynamic::IdType; + #[derive(Error, Debug)] pub enum DatabaseError { #[error("Error while interacting with the database: {0}")] Database(#[from] sqlx::Error), + #[error( + "Error converting from a dynamic id in the database (expected {expected:#?}, was {actual:#?})" + )] + DynamicIdConversionError { expected: IdType, actual: IdType }, + #[error("Didn't expect value to be null: {0}")] + UnexpectedNull(String), #[error("Error while trying to generate random ID")] RandomId, #[error("Error while interacting with the cache: {0}")] From cae428877930e34b311169ff301d9233a9b965d5 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 24 Oct 2023 10:14:56 -0500 Subject: [PATCH 42/72] insert project create event --- src/database/models/event_item.rs | 35 +++++++++++++++++++++++-------- src/routes/v2/project_creation.rs | 28 ++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 10dadfd0..75080fa8 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -2,7 +2,7 @@ use std::convert::{TryFrom, TryInto}; use super::{ dynamic::{DynamicId, IdType}, - DatabaseError, EventId, OrganizationId, ProjectId, UserId, + generate_event_id, DatabaseError, EventId, OrganizationId, ProjectId, UserId, }; use chrono::{DateTime, Utc}; use itertools::Itertools; @@ -27,8 +27,18 @@ pub enum EventData { } pub struct Event { - pub id: EventId, - pub event_data: EventData, + id: EventId, + event_data: EventData, +} + +impl Event { + pub async fn new( + event_data: EventData, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result { + let id = generate_event_id(transaction).await?; + Ok(Self { id, event_data }) + } } impl From for DynamicId { @@ -115,6 +125,13 @@ impl PgHasArrayType for DynamicId { } impl Event { + pub async fn insert( + self, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + Self::insert_many(vec![self], transaction).await + } + pub async fn insert_many( events: Vec, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, @@ -127,7 +144,8 @@ impl Event { } pub async fn get_triggerer_feed( - triggerer_id: DynamicId, + triggerer_ids: &[DynamicId], + event_types: &[EventType], exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { Ok(sqlx::query_as!( @@ -140,10 +158,11 @@ impl Event { type as "event_type: _", metadata, created - FROM events - WHERE triggerer_id=$1 + FROM events e + WHERE triggerer_id=ANY($1) AND type=ANY($2) "#, - triggerer_id as DynamicId + &triggerer_ids[..] as &[DynamicId], + &event_types[..] as &[EventType] ) .fetch_all(exec) .await? @@ -206,6 +225,4 @@ impl RawEvent { Ok(()) } - - // pub async fn get_triggerer_feed(triggerer_id: DynamicId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,) -> Result< } diff --git a/src/routes/v2/project_creation.rs b/src/routes/v2/project_creation.rs index 91db4d0c..f7652fb2 100644 --- a/src/routes/v2/project_creation.rs +++ b/src/routes/v2/project_creation.rs @@ -1,5 +1,6 @@ use super::version_creation::InitialVersionData; use crate::auth::{get_user_from_headers, AuthenticationError}; +use crate::database::models::event_item::{CreatorId, Event, EventData}; use crate::database::models::thread_item::ThreadBuilder; use crate::database::models::{self, image_item, User}; use crate::database::redis::RedisPool; @@ -748,11 +749,15 @@ async fn project_create_inner( } } + let organization_id = project_create_data.organization_id; + insert_project_create_event(project_id, organization_id, ¤t_user, transaction) + .await?; + let project_builder_actual = models::project_item::ProjectBuilder { project_id: project_id.into(), project_type_id, team_id, - organization_id: project_create_data.organization_id, + organization_id, title: project_create_data.title, description: project_create_data.description, body: project_create_data.body, @@ -889,6 +894,27 @@ async fn project_create_inner( } } +async fn insert_project_create_event( + project_id: ProjectId, + organization_id: Option, + current_user: &crate::models::users::User, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), CreateError> { + let event = Event::new( + EventData::ProjectCreated { + project_id: project_id.into(), + creator_id: organization_id.map_or_else( + || CreatorId::User(current_user.id.into()), + |org| CreatorId::Organization(org), + ), + }, + transaction, + ) + .await?; + event.insert(transaction).await?; + Ok(()) +} + async fn create_initial_version( version_data: &InitialVersionData, project_id: ProjectId, From d82c936f6a1e9605901bf34f84e0de11f2220022 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 24 Oct 2023 11:50:30 -0500 Subject: [PATCH 43/72] More generalized get from events --- migrations/20231023212957_feeds.sql | 14 ++++++++-- src/database/models/event_item.rs | 40 +++++++++++++++++++++-------- src/database/models/ids.rs | 4 +-- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/migrations/20231023212957_feeds.sql b/migrations/20231023212957_feeds.sql index 7738659d..a8301aca 100644 --- a/migrations/20231023212957_feeds.sql +++ b/migrations/20231023212957_feeds.sql @@ -1,11 +1,21 @@ CREATE TYPE event_type AS ENUM ('project_created'); CREATE TYPE id_type AS ENUM ('project_id', 'user_id', 'organization_id'); -CREATE TYPE dynamic_id AS (id bigint, type id_type); +CREATE TYPE dynamic_id AS (id bigint, id_type id_type); CREATE TABLE events( id bigint NOT NULL PRIMARY KEY, target_id dynamic_id NOT NULL, triggerer_id dynamic_id NULL, - type event_type NOT NULL, + event_type event_type NOT NULL, metadata jsonb NULL, created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX events_targets ON events ( + ((target_id).id), + ((target_id).id_type), + event_type +); +CREATE INDEX events_triggerers ON events ( + ((triggerer_id).id), + ((triggerer_id).id_type), + event_type ); \ No newline at end of file diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 75080fa8..30c039fa 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -8,7 +8,7 @@ use chrono::{DateTime, Utc}; use itertools::Itertools; use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; -#[derive(sqlx::Type)] +#[derive(sqlx::Type, Clone, Copy)] #[sqlx(type_name = "event_type", rename_all = "snake_case")] pub enum EventType { ProjectCreated, @@ -112,6 +112,11 @@ struct RawEvent { pub created: Option>, } +pub struct EventSelector { + pub id: DynamicId, + pub event_type: EventType, +} + impl PgHasArrayType for EventType { fn array_type_info() -> sqlx::postgres::PgTypeInfo { PgTypeInfo::with_name("event_type") @@ -143,11 +148,19 @@ impl Event { .await } - pub async fn get_triggerer_feed( - triggerer_ids: &[DynamicId], - event_types: &[EventType], + pub async fn get_events( + target_selectors: &[EventSelector], + triggerer_selectors: &[EventSelector], exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { + let (target_ids, target_event_types): (Vec<_>, Vec<_>) = target_selectors + .into_iter() + .map(|t| (t.id.clone(), t.event_type)) + .unzip(); + let (triggerer_ids, triggerer_event_types): (Vec<_>, Vec<_>) = triggerer_selectors + .into_iter() + .map(|t| (t.id.clone(), t.event_type)) + .unzip(); Ok(sqlx::query_as!( RawEvent, r#" @@ -155,14 +168,21 @@ impl Event { id, target_id as "target_id: _", triggerer_id as "triggerer_id: _", - type as "event_type: _", + event_type as "event_type: _", metadata, created FROM events e - WHERE triggerer_id=ANY($1) AND type=ANY($2) + WHERE + ((target_id).id, (target_id).id_type, event_type) + = ANY(SELECT * FROM UNNEST ($1::dynamic_id[], $2::event_type[])) + OR + ((triggerer_id).id, (triggerer_id).id_type, event_type) + = ANY(SELECT * FROM UNNEST ($3::dynamic_id[], $4::event_type[])) "#, + &target_ids[..] as &[DynamicId], + &target_event_types[..] as &[EventType], &triggerer_ids[..] as &[DynamicId], - &event_types[..] as &[EventType] + &triggerer_event_types[..] as &[EventType] ) .fetch_all(exec) .await? @@ -200,10 +220,10 @@ impl RawEvent { INSERT INTO events ( id, target_id.id, - target_id.type, + target_id.id_type, triggerer_id.id, - triggerer_id.type, - type, + triggerer_id.id_type, + event_type, metadata ) SELECT * FROM UNNEST ( diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index d864b224..f46a1d18 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -450,7 +450,7 @@ impl From for ids::OAuthClientAuthorizationId { pub mod dynamic { use super::*; - #[derive(sqlx::Type, PartialEq, Debug)] + #[derive(sqlx::Type, PartialEq, Debug, Clone)] #[sqlx(type_name = "id_type", rename_all = "snake_case")] pub enum IdType { ProjectId, @@ -458,7 +458,7 @@ pub mod dynamic { OrganizationId, } - #[derive(sqlx::Type)] + #[derive(sqlx::Type, Clone)] #[sqlx(type_name = "dynamic_id")] pub struct DynamicId { pub id: i64, From c8e88bec16cdabba35dcdcd86b1d35494958fc26 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 24 Oct 2023 13:09:48 -0500 Subject: [PATCH 44/72] Starting basics of getting a user's feed --- Local Postgres.session.sql | 8 +++ src/database/models/creator_follows.rs | 6 +- src/database/models/event_item.rs | 49 +++++++++------ src/database/models/ids.rs | 5 ++ src/models/feed_item.rs | 51 ++++++++++++++++ src/models/ids.rs | 2 + src/models/mod.rs | 1 + src/routes/v3/users.rs | 82 ++++++++++++++++++++++++-- 8 files changed, 179 insertions(+), 25 deletions(-) create mode 100644 Local Postgres.session.sql create mode 100644 src/models/feed_item.rs diff --git a/Local Postgres.session.sql b/Local Postgres.session.sql new file mode 100644 index 00000000..2b70cadf --- /dev/null +++ b/Local Postgres.session.sql @@ -0,0 +1,8 @@ +DROP INDEX events_targets; +DROP INDEX events_triggerers; +DROP TABLE events; +DROP TYPE dynamic_id; +DROP TYPE event_type; +DROP TYPE id_type; +DELETE FROM _sqlx_migrations +WHERE version = 20231023212957; \ No newline at end of file diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index 2139289f..d5b1c0a7 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -54,8 +54,8 @@ impl UserFollow { .collect_vec()) } - pub async fn get_follows_from( - user_id: UserId, + pub async fn get_follows_from_follower( + follower_user_id: UserId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { let res = sqlx::query!( @@ -63,7 +63,7 @@ impl UserFollow { SELECT follower_id, target_id FROM user_follows WHERE follower_id=$1 ", - user_id.0 + follower_user_id.0 ) .fetch_all(exec) .await?; diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 30c039fa..344f9a86 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -27,8 +27,24 @@ pub enum EventData { } pub struct Event { - id: EventId, - event_data: EventData, + pub id: EventId, + pub event_data: EventData, + pub time: DateTime, +} + +struct RawEvent { + pub id: EventId, + pub target_id: DynamicId, + pub triggerer_id: Option, + pub event_type: EventType, + // #[serde::serde(flatten)] //TODO: is this necessary? + pub metadata: Option, + pub created: Option>, +} + +pub struct EventSelector { + pub id: DynamicId, + pub event_type: EventType, } impl Event { @@ -37,7 +53,11 @@ impl Event { transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result { let id = generate_event_id(transaction).await?; - Ok(Self { id, event_data }) + Ok(Self { + id, + event_data, + time: Default::default(), + }) } } @@ -97,26 +117,19 @@ impl TryFrom for Event { |v| v.try_into(), )?, }, + time: value.created.map_or_else( + || { + Err(DatabaseError::UnexpectedNull( + "the value of created should not be null".to_string(), + )) + }, + |c| Ok(c), + )?, }, }) } } -struct RawEvent { - pub id: EventId, - pub target_id: DynamicId, - pub triggerer_id: Option, - pub event_type: EventType, - // #[serde::serde(flatten)] //TODO: is this necessary? - pub metadata: Option, - pub created: Option>, -} - -pub struct EventSelector { - pub id: DynamicId, - pub event_type: EventType, -} - impl PgHasArrayType for EventType { fn array_type_info() -> sqlx::postgres::PgTypeInfo { PgTypeInfo::with_name("event_type") diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index f46a1d18..84d25077 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -425,6 +425,11 @@ impl From for ids::PatId { ids::PatId(id.0 as u64) } } +impl From for ids::FeedItemId { + fn from(value: crate::database::models::ids::EventId) -> Self { + Self(value.0 as u64) + } +} impl From for ids::OAuthClientId { fn from(id: OAuthClientId) -> Self { diff --git a/src/models/feed_item.rs b/src/models/feed_item.rs new file mode 100644 index 00000000..5538a153 --- /dev/null +++ b/src/models/feed_item.rs @@ -0,0 +1,51 @@ +use super::ids::Base62Id; +use super::ids::OrganizationId; +use super::users::UserId; +use crate::database::models::notification_item::Notification as DBNotification; +use crate::database::models::notification_item::NotificationAction as DBNotificationAction; +use crate::models::ids::{ProjectId, ReportId, TeamId, ThreadId, ThreadMessageId, VersionId}; +use crate::models::projects::ProjectStatus; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::database::models::event_item as DBEvent; + +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "Base62Id")] +#[serde(into = "Base62Id")] +pub struct FeedItemId(pub u64); + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum CreatorId { + User(UserId), + Organization(OrganizationId), +} + +#[derive(Serialize, Deserialize)] +pub struct FeedItem { + pub id: FeedItemId, + pub body: FeedItemBody, + pub time: DateTime, +} + +#[derive(Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum FeedItemBody { + ProjectCreated { + project_id: ProjectId, + creator_id: CreatorId, + project_title: String, + }, +} + +impl From for CreatorId { + fn from(value: crate::database::models::event_item::CreatorId) -> Self { + match value { + DBEvent::CreatorId::User(user_id) => CreatorId::User(user_id.into()), + DBEvent::CreatorId::Organization(organization_id) => { + CreatorId::Organization(organization_id.into()) + } + } + } +} diff --git a/src/models/ids.rs b/src/models/ids.rs index 8cea089f..c5a68b4e 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -1,6 +1,7 @@ use thiserror::Error; pub use super::collections::CollectionId; +pub use super::feed_item::FeedItemId; pub use super::images::ImageId; pub use super::notifications::NotificationId; pub use super::oauth_clients::OAuthClientAuthorizationId; @@ -127,6 +128,7 @@ base62_id_impl!(ImageId, ImageId); base62_id_impl!(OAuthClientId, OAuthClientId); base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId); base62_id_impl!(OAuthClientAuthorizationId, OAuthClientAuthorizationId); +base62_id_impl!(FeedItemId, FeedItemId); pub mod base62_impl { use serde::de::{self, Deserializer, Visitor}; diff --git a/src/models/mod.rs b/src/models/mod.rs index 7c97ad31..2c7c644b 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,6 +1,7 @@ pub mod analytics; pub mod collections; pub mod error; +pub mod feed_item; pub mod ids; pub mod images; pub mod notifications; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 0cee2be5..96535dd5 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -1,22 +1,38 @@ use crate::{ auth::get_user_from_headers, - database::{self, models::DatabaseError, redis::RedisPool}, - models::pats::Scopes, + database::{ + self, + models::{ + event_item::{EventData, EventSelector, EventType}, + DatabaseError, + }, + redis::RedisPool, + }, + models::{ + feed_item::{FeedItem, FeedItemBody}, + pats::Scopes, + }, queue::session::AuthQueue, routes::ApiError, }; use actix_web::{ - delete, post, + delete, get, post, web::{self}, HttpRequest, HttpResponse, }; +use itertools::Itertools; use sqlx::PgPool; use database::models as db_models; use database::models::creator_follows::UserFollow as DBUserFollow; +use database::models::event_item::Event as DBEvent; pub fn config(cfg: &mut web::ServiceConfig) { - cfg.service(web::scope("user").service(user_follow)); + cfg.service( + web::scope("user") + .service(user_follow) + .service(user_unfollow), + ); } #[post("{id}/follow")] @@ -84,3 +100,61 @@ pub async fn user_unfollow( Ok(HttpResponse::NoContent().body("")) } + +#[get("feed")] +pub async fn current_user_feed( + req: HttpRequest, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::NOTIFICATION_READ]), + ) + .await?; + + let followed_users = + DBUserFollow::get_follows_from_follower(current_user.id.into(), &**pool).await?; + + let selectors = followed_users + .into_iter() + .map(|follow| EventSelector { + id: follow.target_id.into(), + event_type: EventType::ProjectCreated, + }) + .collect_vec(); + let events = DBEvent::get_events(&[], &selectors, &**pool).await?; + + let mut feed_items: Vec = Vec::new(); + for event in events { + let body = match event.event_data { + EventData::ProjectCreated { + project_id, + creator_id, + } => { + let project = db_models::Project::get_id(project_id, &**pool, &redis).await?; + project.map(|p| FeedItemBody::ProjectCreated { + project_id: project_id.into(), + creator_id: creator_id.into(), + project_title: p.inner.title, + }) + } + }; + + if let Some(body) = body { + let feed_item = FeedItem { + id: event.id.into(), + body, + time: event.time, + }; + + feed_items.push(feed_item); + } + } + + todo!(); +} From 03ade2b03114acf2036fe13e5c0b091134d75c06 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 24 Oct 2023 20:49:28 -0500 Subject: [PATCH 45/72] As requested, remove use of custom postgres enums and types --- migrations/20231023212957_feeds.sql | 19 ++-- src/database/models/event_item.rs | 159 ++++++++++++++++------------ src/database/models/ids.rs | 15 ++- 3 files changed, 111 insertions(+), 82 deletions(-) diff --git a/migrations/20231023212957_feeds.sql b/migrations/20231023212957_feeds.sql index a8301aca..3ee02945 100644 --- a/migrations/20231023212957_feeds.sql +++ b/migrations/20231023212957_feeds.sql @@ -1,21 +1,20 @@ -CREATE TYPE event_type AS ENUM ('project_created'); -CREATE TYPE id_type AS ENUM ('project_id', 'user_id', 'organization_id'); -CREATE TYPE dynamic_id AS (id bigint, id_type id_type); CREATE TABLE events( id bigint NOT NULL PRIMARY KEY, - target_id dynamic_id NOT NULL, - triggerer_id dynamic_id NULL, - event_type event_type NOT NULL, + target_id bigint NOT NULL, + target_id_type text NOT NULL, + triggerer_id bigint NULL, + triggerer_id_type text NULL, + event_type text NOT NULL, metadata jsonb NULL, created timestamptz NOT NULL DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX events_targets ON events ( - ((target_id).id), - ((target_id).id_type), + target_id, + target_id_type, event_type ); CREATE INDEX events_triggerers ON events ( - ((triggerer_id).id), - ((triggerer_id).id_type), + triggerer_id, + triggerer_id_type, event_type ); \ No newline at end of file diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 344f9a86..887564c0 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -1,5 +1,3 @@ -use std::convert::{TryFrom, TryInto}; - use super::{ dynamic::{DynamicId, IdType}, generate_event_id, DatabaseError, EventId, OrganizationId, ProjectId, UserId, @@ -7,13 +5,21 @@ use super::{ use chrono::{DateTime, Utc}; use itertools::Itertools; use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; +use std::convert::{TryFrom, TryInto}; #[derive(sqlx::Type, Clone, Copy)] -#[sqlx(type_name = "event_type", rename_all = "snake_case")] +#[sqlx(type_name = "text")] +#[sqlx(rename_all = "snake_case")] pub enum EventType { ProjectCreated, } +impl PgHasArrayType for EventType { + fn array_type_info() -> sqlx::postgres::PgTypeInfo { + PgTypeInfo::with_name("_text") + } +} + pub enum CreatorId { User(UserId), Organization(OrganizationId), @@ -34,10 +40,11 @@ pub struct Event { struct RawEvent { pub id: EventId, - pub target_id: DynamicId, - pub triggerer_id: Option, + pub target_id: i64, + pub target_id_type: IdType, + pub triggerer_id: Option, + pub triggerer_id_type: Option, pub event_type: EventType, - // #[serde::serde(flatten)] //TODO: is this necessary? pub metadata: Option, pub created: Option>, } @@ -87,14 +94,20 @@ impl From for RawEvent { EventData::ProjectCreated { project_id, creator_id, - } => RawEvent { - id: value.id, - target_id: project_id.into(), - triggerer_id: Some(creator_id.into()), - event_type: EventType::ProjectCreated, - metadata: None, - created: None, - }, + } => { + let target_id = DynamicId::from(project_id); + let triggerer_id = DynamicId::from(creator_id); + RawEvent { + id: value.id, + target_id: target_id.id, + target_id_type: target_id.id_type, + triggerer_id: Some(triggerer_id.id), + triggerer_id_type: Some(triggerer_id.id_type), + event_type: EventType::ProjectCreated, + metadata: None, + created: None, + } + } } } } @@ -103,19 +116,24 @@ impl TryFrom for Event { type Error = DatabaseError; fn try_from(value: RawEvent) -> Result { + let target_id = DynamicId { + id: value.target_id, + id_type: value.target_id_type, + }; + let triggerer_id = match (value.triggerer_id, value.triggerer_id_type) { + (Some(id), Some(id_type)) => Some(DynamicId { id, id_type }), + _ => None, + }; Ok(match value.event_type { EventType::ProjectCreated => Event { id: value.id, event_data: EventData::ProjectCreated { - project_id: value.target_id.try_into()?, - creator_id: value.triggerer_id.map_or_else( - || { - Err(DatabaseError::UnexpectedNull( - "triggerer_id should not be null for project creation".to_string(), - )) - }, - |v| v.try_into(), - )?, + project_id: target_id.try_into()?, + creator_id: triggerer_id.map_or_else(|| { + Err(DatabaseError::UnexpectedNull( + "Neither triggerer_id nor triggerer_id_type should be null for project creation".to_string(), + )) + }, |v| v.try_into())?, }, time: value.created.map_or_else( || { @@ -130,18 +148,6 @@ impl TryFrom for Event { } } -impl PgHasArrayType for EventType { - fn array_type_info() -> sqlx::postgres::PgTypeInfo { - PgTypeInfo::with_name("event_type") - } -} - -impl PgHasArrayType for DynamicId { - fn array_type_info() -> PgTypeInfo { - PgTypeInfo::with_name("dynamic_id") - } -} - impl Event { pub async fn insert( self, @@ -166,35 +172,35 @@ impl Event { triggerer_selectors: &[EventSelector], exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { - let (target_ids, target_event_types): (Vec<_>, Vec<_>) = target_selectors - .into_iter() - .map(|t| (t.id.clone(), t.event_type)) - .unzip(); - let (triggerer_ids, triggerer_event_types): (Vec<_>, Vec<_>) = triggerer_selectors - .into_iter() - .map(|t| (t.id.clone(), t.event_type)) - .unzip(); + let (target_ids, target_id_types, target_event_types) = + unzip_event_selectors(target_selectors); + let (triggerer_ids, triggerer_id_types, triggerer_event_types) = + unzip_event_selectors(triggerer_selectors); Ok(sqlx::query_as!( RawEvent, r#" SELECT id, - target_id as "target_id: _", - triggerer_id as "triggerer_id: _", + target_id, + target_id_type as "target_id_type: _", + triggerer_id, + triggerer_id_type as "triggerer_id_type: _", event_type as "event_type: _", metadata, created FROM events e WHERE - ((target_id).id, (target_id).id_type, event_type) - = ANY(SELECT * FROM UNNEST ($1::dynamic_id[], $2::event_type[])) + (target_id, target_id_type, event_type) + = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[])) OR - ((triggerer_id).id, (triggerer_id).id_type, event_type) - = ANY(SELECT * FROM UNNEST ($3::dynamic_id[], $4::event_type[])) + (triggerer_id, triggerer_id_type, event_type) + = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[])) "#, - &target_ids[..] as &[DynamicId], + &target_ids[..], + &target_id_types[..] as &[IdType], &target_event_types[..] as &[EventType], - &triggerer_ids[..] as &[DynamicId], + &triggerer_ids[..], + &triggerer_id_types[..] as &[IdType], &triggerer_event_types[..] as &[EventType] ) .fetch_all(exec) @@ -210,19 +216,23 @@ impl RawEvent { events: Vec, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result<(), DatabaseError> { - let (ids, target_ids, triggerer_ids, event_types, metadata): ( - Vec<_>, - Vec<_>, - Vec<_>, - Vec<_>, - Vec<_>, - ) = events + let ( + ids, + target_ids, + target_id_types, + triggerer_ids, + triggerer_id_types, + event_types, + metadata, + ): (Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>) = events .into_iter() .map(|e| { ( e.id.0, e.target_id, + e.target_id_type, e.triggerer_id, + e.triggerer_id_type, e.event_type, e.metadata, ) @@ -232,24 +242,28 @@ impl RawEvent { " INSERT INTO events ( id, - target_id.id, - target_id.id_type, - triggerer_id.id, - triggerer_id.id_type, + target_id, + target_id_type, + triggerer_id, + triggerer_id_type, event_type, metadata ) SELECT * FROM UNNEST ( $1::bigint[], - $2::dynamic_id[], - $3::dynamic_id[], - $4::event_type[], - $5::jsonb[] + $2::bigint[], + $3::text[], + $4::bigint[], + $5::text[], + $6::text[], + $7::jsonb[] ) ", &ids[..], - &target_ids[..] as &[DynamicId], - &triggerer_ids[..] as &[Option], + &target_ids[..], + &target_id_types[..] as &[IdType], + &triggerer_ids[..] as &[Option], + &triggerer_id_types[..] as &[Option], &event_types[..] as &[EventType], &metadata[..] as &[Option] ) @@ -259,3 +273,12 @@ impl RawEvent { Ok(()) } } + +fn unzip_event_selectors( + target_selectors: &[EventSelector], +) -> (Vec, Vec, Vec) { + target_selectors + .into_iter() + .map(|t| (t.id.id, t.id.id_type, t.event_type)) + .multiunzip() +} diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index 84d25077..4f740ec8 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -454,17 +454,24 @@ impl From for ids::OAuthClientAuthorizationId { pub mod dynamic { use super::*; + use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; - #[derive(sqlx::Type, PartialEq, Debug, Clone)] - #[sqlx(type_name = "id_type", rename_all = "snake_case")] + #[derive(sqlx::Type, PartialEq, Debug, Clone, Copy)] + #[sqlx(type_name = "text")] + #[sqlx(rename_all = "snake_case")] pub enum IdType { ProjectId, UserId, OrganizationId, } - #[derive(sqlx::Type, Clone)] - #[sqlx(type_name = "dynamic_id")] + impl PgHasArrayType for IdType { + fn array_type_info() -> sqlx::postgres::PgTypeInfo { + PgTypeInfo::with_name("_text") + } + } + + #[derive(Clone)] pub struct DynamicId { pub id: i64, pub id_type: IdType, From 196b3533198c3972f11b2b3ba5374193009037c7 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 24 Oct 2023 21:37:26 -0500 Subject: [PATCH 46/72] Remove file that shouldn't have been committed --- Local Postgres.session.sql | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 Local Postgres.session.sql diff --git a/Local Postgres.session.sql b/Local Postgres.session.sql deleted file mode 100644 index 2b70cadf..00000000 --- a/Local Postgres.session.sql +++ /dev/null @@ -1,8 +0,0 @@ -DROP INDEX events_targets; -DROP INDEX events_triggerers; -DROP TABLE events; -DROP TYPE dynamic_id; -DROP TYPE event_type; -DROP TYPE id_type; -DELETE FROM _sqlx_migrations -WHERE version = 20231023212957; \ No newline at end of file From 97585a9b9198e9ff77cac8f3103d3aab59dfba17 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Tue, 24 Oct 2023 23:30:47 -0500 Subject: [PATCH 47/72] Organization following --- src/database/models/creator_follows.rs | 162 ++++++++++++++----------- src/models/feed_item.rs | 5 +- src/routes/v3/follow.rs | 107 ++++++++++++++++ src/routes/v3/mod.rs | 5 +- src/routes/v3/organizations.rs | 73 +++++++++++ src/routes/v3/users.rs | 79 +++++------- 6 files changed, 304 insertions(+), 127 deletions(-) create mode 100644 src/routes/v3/follow.rs create mode 100644 src/routes/v3/organizations.rs diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index d5b1c0a7..2d16e4ae 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -13,86 +13,100 @@ pub struct OrganizationFollow { pub target_id: OrganizationId, } -impl UserFollow { - pub async fn insert( - &self, - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result<(), DatabaseError> { - sqlx::query!( - " - INSERT INTO user_follows (follower_id, target_id) VALUES ($1, $2) - ", - self.follower_id.0, - self.target_id.0 - ) - .execute(exec) - .await?; +struct FollowQuery { + follower_id: i64, + target_id: i64, +} - Ok(()) +impl From for UserFollow { + fn from(value: FollowQuery) -> Self { + UserFollow { + follower_id: UserId(value.follower_id), + target_id: UserId(value.target_id), + } } +} - pub async fn get_followers( - target_id: UserId, - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result, DatabaseError> { - let res = sqlx::query!( - " - SELECT follower_id, target_id FROM user_follows - WHERE target_id=$1 - ", - target_id.0 - ) - .fetch_all(exec) - .await?; - - Ok(res - .into_iter() - .map(|r| UserFollow { - follower_id: UserId(r.follower_id), - target_id: UserId(r.target_id), - }) - .collect_vec()) +impl From for OrganizationFollow { + fn from(value: FollowQuery) -> Self { + OrganizationFollow { + follower_id: UserId(value.follower_id), + target_id: OrganizationId(value.target_id), + } } +} - pub async fn get_follows_from_follower( - follower_user_id: UserId, - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result, DatabaseError> { - let res = sqlx::query!( - " - SELECT follower_id, target_id FROM user_follows - WHERE follower_id=$1 - ", - follower_user_id.0 - ) - .fetch_all(exec) - .await?; +macro_rules! impl_follow { + ($target_struct:ident, $table_name:tt, $target_id_type:ident, $target_id_ctor:expr) => { + impl $target_struct { + pub async fn insert( + &self, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + " INSERT INTO " + $table_name + " (follower_id, target_id) VALUES ($1, $2)", + self.follower_id.0, + self.target_id.0 + ) + .execute(exec) + .await?; - Ok(res - .into_iter() - .map(|r| UserFollow { - follower_id: UserId(r.follower_id), - target_id: UserId(r.target_id), - }) - .collect_vec()) - } + Ok(()) + } - pub async fn unfollow( - follower_id: UserId, - target_id: UserId, - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result<(), DatabaseError> { - sqlx::query!( - " - DELETE FROM user_follows - WHERE follower_id=$1 AND target_id=$2 - ", - follower_id.0, - target_id.0, - ) - .execute(exec) - .await?; + pub async fn get_followers( + target_id: $target_id_type, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = sqlx::query_as!( + FollowQuery, + "SELECT follower_id, target_id FROM " + $table_name + " WHERE target_id=$1", + target_id.0 + ) + .fetch_all(exec) + .await?; - Ok(()) - } + Ok(res.into_iter().map(|r| r.into()).collect_vec()) + } + + pub async fn get_follows_from_follower( + follower_user_id: UserId, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result, DatabaseError> { + let res = sqlx::query_as!( + FollowQuery, + "SELECT follower_id, target_id FROM " + $table_name + " WHERE follower_id=$1", + follower_user_id.0 + ) + .fetch_all(exec) + .await?; + + Ok(res.into_iter().map(|r| r.into()).collect_vec()) + } + + pub async fn unfollow( + follower_id: UserId, + target_id: $target_id_type, + exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, + ) -> Result<(), DatabaseError> { + sqlx::query!( + "DELETE FROM " + $table_name + " WHERE follower_id=$1 AND target_id=$2", + follower_id.0, + target_id.0, + ) + .execute(exec) + .await?; + + Ok(()) + } + } + }; } + +impl_follow!(UserFollow, "user_follows", UserId, UserId); +impl_follow!( + OrganizationFollow, + "organization_follows", + OrganizationId, + OrganizationId +); diff --git a/src/models/feed_item.rs b/src/models/feed_item.rs index 5538a153..f1cee015 100644 --- a/src/models/feed_item.rs +++ b/src/models/feed_item.rs @@ -1,10 +1,7 @@ use super::ids::Base62Id; use super::ids::OrganizationId; use super::users::UserId; -use crate::database::models::notification_item::Notification as DBNotification; -use crate::database::models::notification_item::NotificationAction as DBNotificationAction; -use crate::models::ids::{ProjectId, ReportId, TeamId, ThreadId, ThreadMessageId, VersionId}; -use crate::models::projects::ProjectStatus; +use crate::models::ids::ProjectId; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; diff --git a/src/routes/v3/follow.rs b/src/routes/v3/follow.rs new file mode 100644 index 00000000..a27bbe86 --- /dev/null +++ b/src/routes/v3/follow.rs @@ -0,0 +1,107 @@ +use super::ApiError; +use crate::{ + auth::get_user_from_headers, + database::{models::DatabaseError, redis::RedisPool}, + models::pats::Scopes, + queue::session::AuthQueue, +}; +use actix_web::{web, HttpRequest, HttpResponse}; +use sqlx::PgPool; + +use crate::database::models as db_models; + +pub trait HasId { + fn id(&self) -> T; +} + +impl HasId for db_models::User { + fn id(&self) -> db_models::UserId { + self.id + } +} + +impl HasId for db_models::Organization { + fn id(&self) -> db_models::OrganizationId { + self.id + } +} + +pub async fn follow( + req: HttpRequest, + target_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, + get_db_model: impl FnOnce(String, web::Data, web::Data) -> Fut1, + insert_follow: impl FnOnce(db_models::UserId, TId, web::Data) -> Fut2, + error_word: &str, +) -> Result +where + Fut1: futures::Future, DatabaseError>>, + Fut2: futures::Future>, + T: HasId, +{ + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), + ) + .await?; + + let target = get_db_model(target_id.into_inner(), pool.clone(), redis.clone()) + .await? + .ok_or_else(|| { + ApiError::InvalidInput(format!("The specified {} does not exist!", error_word)) + })?; + + insert_follow(current_user.id.into(), target.id(), pool.clone()) + .await + .map_err(|e| match e { + DatabaseError::Database(e) + if e.as_database_error() + .is_some_and(|e| e.is_unique_violation()) => + { + ApiError::InvalidInput(format!("You are already following this {}!", error_word)) + } + e => e.into(), + })?; + + Ok(HttpResponse::NoContent().body("")) +} + +pub async fn unfollow( + req: HttpRequest, + target_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, + get_db_model: impl FnOnce(String, web::Data, web::Data) -> Fut1, + unfollow: impl FnOnce(db_models::UserId, TId, web::Data) -> Fut2, + error_word: &str, +) -> Result +where + Fut1: futures::Future, DatabaseError>>, + Fut2: futures::Future>, + T: HasId, +{ + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), + ) + .await?; + + let target = get_db_model(target_id.into_inner(), pool.clone(), redis.clone()) + .await? + .ok_or_else(|| { + ApiError::InvalidInput(format!("The specified {} does not exist!", error_word)) + })?; + + unfollow(current_user.id.into(), target.id(), pool.clone()).await?; + + Ok(HttpResponse::NoContent().body("")) +} diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index 59793b8a..bf8aa97f 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -1,3 +1,5 @@ +pub mod follow; +pub mod organizations; pub mod users; pub use super::ApiError; @@ -14,7 +16,8 @@ pub fn config(cfg: &mut web::ServiceConfig) { .route("", web::get().to(hello_world)) .configure(users::config) .configure(oauth::config) - .configure(oauth_clients::config), + .configure(oauth_clients::config) + .configure(organizations::config), ); } diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs new file mode 100644 index 00000000..1ae9b66b --- /dev/null +++ b/src/routes/v3/organizations.rs @@ -0,0 +1,73 @@ +use crate::{ + database::{self, redis::RedisPool}, + queue::session::AuthQueue, + routes::ApiError, +}; +use actix_web::{ + delete, post, + web::{self}, + HttpRequest, HttpResponse, +}; +use sqlx::PgPool; + +use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; +use database::models::organization_item::Organization as DBOrganization; + +pub fn config(cfg: &mut web::ServiceConfig) { + cfg.service( + web::scope("organization") + .service(organization_follow) + .service(organization_unfollow), + ); +} + +#[post("{id}/follow")] +pub async fn organization_follow( + req: HttpRequest, + target_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + crate::routes::v3::follow::follow( + req, + target_id, + pool, + redis, + session_queue, + |id, pool, redis| async move { DBOrganization::get(&id, &**pool, &redis).await }, + |follower_id, target_id, pool| async move { + DBOrganizationFollow { + follower_id, + target_id, + } + .insert(&**pool) + .await + }, + "organization", + ) + .await +} + +#[delete("{id}/follow")] +pub async fn organization_unfollow( + req: HttpRequest, + target_id: web::Path, + pool: web::Data, + redis: web::Data, + session_queue: web::Data, +) -> Result { + crate::routes::v3::follow::unfollow( + req, + target_id, + pool, + redis, + session_queue, + |id, pool, redis| async move { DBOrganization::get(&id, &**pool, &redis).await }, + |follower_id, target_id, pool| async move { + DBOrganizationFollow::unfollow(follower_id, target_id, &**pool).await + }, + "organization", + ) + .await +} diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 96535dd5..dd99880e 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -2,10 +2,7 @@ use crate::{ auth::get_user_from_headers, database::{ self, - models::{ - event_item::{EventData, EventSelector, EventType}, - DatabaseError, - }, + models::event_item::{EventData, EventSelector, EventType}, redis::RedisPool, }, models::{ @@ -26,6 +23,7 @@ use sqlx::PgPool; use database::models as db_models; use database::models::creator_follows::UserFollow as DBUserFollow; use database::models::event_item::Event as DBEvent; +use database::models::user_item::User as DBUser; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( @@ -43,36 +41,24 @@ pub async fn user_follow( redis: web::Data, session_queue: web::Data, ) -> Result { - let (_, current_user) = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_WRITE]), + crate::routes::v3::follow::follow( + req, + target_id, + pool, + redis, + session_queue, + |id, pool, redis| async move { DBUser::get(&id, &**pool, &redis).await }, + |follower_id, target_id, pool| async move { + DBUserFollow { + follower_id, + target_id, + } + .insert(&**pool) + .await + }, + "user", ) - .await?; - - let target_user = db_models::User::get(&target_id.into_inner(), &**pool, &redis) - .await? - .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; - - DBUserFollow { - follower_id: current_user.id.into(), - target_id: target_user.id, - } - .insert(&**pool) .await - .map_err(|e| match e { - DatabaseError::Database(e) - if e.as_database_error() - .is_some_and(|e| e.is_unique_violation()) => - { - ApiError::InvalidInput("You are already following this user!".to_string()) - } - e => e.into(), - })?; - - Ok(HttpResponse::NoContent().body("")) } #[delete("{id}/follow")] @@ -83,22 +69,19 @@ pub async fn user_unfollow( redis: web::Data, session_queue: web::Data, ) -> Result { - let (_, current_user) = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_WRITE]), + crate::routes::v3::follow::unfollow( + req, + target_id, + pool, + redis, + session_queue, + |id, pool, redis| async move { DBUser::get(&id, &**pool, &redis).await }, + |follower_id, target_id, pool| async move { + DBUserFollow::unfollow(follower_id, target_id, &**pool).await + }, + "organization", ) - .await?; - - let target_user = db_models::User::get(&target_id.into_inner(), &**pool, &redis) - .await? - .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; - - DBUserFollow::unfollow(current_user.id.into(), target_user.id, &**pool).await?; - - Ok(HttpResponse::NoContent().body("")) + .await } #[get("feed")] @@ -156,5 +139,5 @@ pub async fn current_user_feed( } } - todo!(); + Ok(HttpResponse::Ok().json(feed_items)) } From a5d82d47d2d24b0cbb6fc4a6dce7b78b7691fcd4 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 00:04:22 -0500 Subject: [PATCH 48/72] Only return authorized projects in the feed --- src/auth/checks.rs | 11 +++---- src/database/models/mod.rs | 1 + src/models/projects.rs | 2 +- src/routes/v2/analytics_get.rs | 2 +- src/routes/v2/organizations.rs | 2 +- src/routes/v2/projects.rs | 2 +- src/routes/v2/version_file.rs | 2 +- src/routes/v3/users.rs | 58 ++++++++++++++++++++++++++-------- 8 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/auth/checks.rs b/src/auth/checks.rs index 8d8fab68..1ff70e13 100644 --- a/src/auth/checks.rs +++ b/src/auth/checks.rs @@ -80,7 +80,7 @@ pub async fn is_authorized( pub async fn filter_authorized_projects( projects: Vec, - user_option: &Option, + current_user: Option<&User>, pool: &web::Data, ) -> Result, ApiError> { let mut return_projects = Vec::new(); @@ -88,19 +88,16 @@ pub async fn filter_authorized_projects( for project in projects { if !project.inner.status.is_hidden() - || user_option - .as_ref() - .map(|x| x.role.is_mod()) - .unwrap_or(false) + || current_user.map(|x| x.role.is_mod()).unwrap_or(false) { return_projects.push(project.into()); - } else if user_option.is_some() { + } else if current_user.is_some() { check_projects.push(project); } } if !check_projects.is_empty() { - if let Some(user) = user_option { + if let Some(user) = current_user { let user_id: models::ids::UserId = user.id.into(); use futures::TryStreamExt; diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index 41ff5fa8..54023cc4 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -22,6 +22,7 @@ pub mod user_item; pub mod version_item; pub use collection_item::Collection; +pub use event_item::Event; pub use ids::*; pub use image_item::Image; pub use oauth_client_item::OAuthClient; diff --git a/src/models/projects.rs b/src/models/projects.rs index 63f48f11..c4a323ad 100644 --- a/src/models/projects.rs +++ b/src/models/projects.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use validator::Validate; /// The ID of a specific project, encoded as base62 for usage in the API -#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] #[serde(from = "Base62Id")] #[serde(into = "Base62Id")] pub struct ProjectId(pub u64); diff --git a/src/routes/v2/analytics_get.rs b/src/routes/v2/analytics_get.rs index d6b61584..d6ab2c21 100644 --- a/src/routes/v2/analytics_get.rs +++ b/src/routes/v2/analytics_get.rs @@ -574,7 +574,7 @@ async fn filter_allowed_ids( .map(|id| Ok(ProjectId(parse_base62(id)?).into())) .collect::, ApiError>>()?; let projects = project_item::Project::get_many_ids(&ids, &***pool, redis).await?; - let ids: Vec = filter_authorized_projects(projects, &Some(user.clone()), pool) + let ids: Vec = filter_authorized_projects(projects, Some(&user), pool) .await? .into_iter() .map(|x| x.id) diff --git a/src/routes/v2/organizations.rs b/src/routes/v2/organizations.rs index d4c8a056..5503452b 100644 --- a/src/routes/v2/organizations.rs +++ b/src/routes/v2/organizations.rs @@ -531,7 +531,7 @@ pub async fn organization_projects_get( let projects_data = crate::database::models::Project::get_many_ids(&project_ids, &**pool, &redis).await?; - let projects = filter_authorized_projects(projects_data, ¤t_user, &pool).await?; + let projects = filter_authorized_projects(projects_data, current_user.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(projects)) } diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index 7da763d8..d1883798 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -141,7 +141,7 @@ pub async fn projects_get( .map(|x| x.1) .ok(); - let projects = filter_authorized_projects(projects_data, &user_option, &pool).await?; + let projects = filter_authorized_projects(projects_data, user_option.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(projects)) } diff --git a/src/routes/v2/version_file.rs b/src/routes/v2/version_file.rs index 171788b1..cf029f81 100644 --- a/src/routes/v2/version_file.rs +++ b/src/routes/v2/version_file.rs @@ -427,7 +427,7 @@ pub async fn get_projects_from_hashes( let projects_data = filter_authorized_projects( database::models::Project::get_many_ids(&project_ids, &**pool, &redis).await?, - &user_option, + user_option.as_ref(), &pool, ) .await?; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index dd99880e..938536c2 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -1,5 +1,7 @@ +use std::{collections::HashMap, iter::FromIterator}; + use crate::{ - auth::get_user_from_headers, + auth::{filter_authorized_projects, get_user_from_headers}, database::{ self, models::event_item::{EventData, EventSelector, EventType}, @@ -7,7 +9,10 @@ use crate::{ }, models::{ feed_item::{FeedItem, FeedItemBody}, + ids::ProjectId, pats::Scopes, + projects::Project, + users::User, }, queue::session::AuthQueue, routes::ApiError, @@ -113,20 +118,22 @@ pub async fn current_user_feed( let events = DBEvent::get_events(&[], &selectors, &**pool).await?; let mut feed_items: Vec = Vec::new(); + let authorized_projects = + prefetch_authorized_event_projects(&events, &pool, &redis, ¤t_user).await?; for event in events { - let body = match event.event_data { - EventData::ProjectCreated { - project_id, - creator_id, - } => { - let project = db_models::Project::get_id(project_id, &**pool, &redis).await?; - project.map(|p| FeedItemBody::ProjectCreated { - project_id: project_id.into(), - creator_id: creator_id.into(), - project_title: p.inner.title, - }) - } - }; + let body = + match event.event_data { + EventData::ProjectCreated { + project_id, + creator_id, + } => authorized_projects.get(&project_id.into()).map(|p| { + FeedItemBody::ProjectCreated { + project_id: project_id.into(), + creator_id: creator_id.into(), + project_title: p.title.clone(), + } + }), + }; if let Some(body) = body { let feed_item = FeedItem { @@ -141,3 +148,26 @@ pub async fn current_user_feed( Ok(HttpResponse::Ok().json(feed_items)) } + +async fn prefetch_authorized_event_projects( + events: &[db_models::Event], + pool: &web::Data, + redis: &RedisPool, + current_user: &User, +) -> Result, ApiError> { + let project_ids = events + .iter() + .filter_map(|e| match &e.event_data { + EventData::ProjectCreated { + project_id, + creator_id: _, + } => Some(project_id.clone()), + }) + .collect_vec(); + let projects = db_models::Project::get_many_ids(&project_ids, &***pool, &redis).await?; + let authorized_projects = + filter_authorized_projects(projects, Some(¤t_user), &pool).await?; + Ok(HashMap::::from_iter( + authorized_projects.into_iter().map(|p| (p.id, p)), + )) +} From 1a81eea8814c6db0ae985df9acbbe39a90b5dd4a Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 00:07:58 -0500 Subject: [PATCH 49/72] cargo sqlx prepare --- ...6d9ea52ea89a38fd0420f351080e3e93eb4fb.json | 15 ++++ ...49782e217f550394641eec3ecaf06ac269276.json | 28 ++++++++ ...da3a3c9a8a3eee6985b548207eea7debcaafb.json | 69 +++++++++++++++++++ ...73208e7897cad64a4e371a6f3b76310d0479b.json | 20 ++++++ ...a0a1fae6852ad81db66e0333bb0f373d9b9a3.json | 28 ++++++++ ...f49e4f76c06ceb1ce1ab3fb9317b1035deb65.json | 28 ++++++++ ...1c612370373c6263566fc870da4a82b8ded7b.json | 15 ++++ ...e84cab2a5ab0953deb6b490d1e4df1c126036.json | 28 ++++++++ ...816ae3421cfa811be32ce478a405d5ef99c06.json | 15 ++++ ...e36c5ae2c2cb0067ff551a4ecea8a1685b87c.json | 22 ++++++ ...af09e5db89675320111e8006587b2666e72a3.json | 15 ++++ 11 files changed, 283 insertions(+) create mode 100644 .sqlx/query-35984b9f53e0c54a2b34d715c926d9ea52ea89a38fd0420f351080e3e93eb4fb.json create mode 100644 .sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json create mode 100644 .sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json create mode 100644 .sqlx/query-79889e768dd4b129ed10f46a6b973208e7897cad64a4e371a6f3b76310d0479b.json create mode 100644 .sqlx/query-7d4a842888bb3ed24090f9cb903a0a1fae6852ad81db66e0333bb0f373d9b9a3.json create mode 100644 .sqlx/query-9699feae1d917bc90738f3e6b46f49e4f76c06ceb1ce1ab3fb9317b1035deb65.json create mode 100644 .sqlx/query-9d73bfe2094b88fb7bcb0ef47741c612370373c6263566fc870da4a82b8ded7b.json create mode 100644 .sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json create mode 100644 .sqlx/query-c337079ccc53022395a4f3c35fa816ae3421cfa811be32ce478a405d5ef99c06.json create mode 100644 .sqlx/query-cbeb96926e6d958bcaf5ed28201e36c5ae2c2cb0067ff551a4ecea8a1685b87c.json create mode 100644 .sqlx/query-e50f350a4eb67ddbbfec7597185af09e5db89675320111e8006587b2666e72a3.json diff --git a/.sqlx/query-35984b9f53e0c54a2b34d715c926d9ea52ea89a38fd0420f351080e3e93eb4fb.json b/.sqlx/query-35984b9f53e0c54a2b34d715c926d9ea52ea89a38fd0420f351080e3e93eb4fb.json new file mode 100644 index 00000000..4ec3dcaa --- /dev/null +++ b/.sqlx/query-35984b9f53e0c54a2b34d715c926d9ea52ea89a38fd0420f351080e3e93eb4fb.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM user_follows WHERE follower_id=$1 AND target_id=$2", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "35984b9f53e0c54a2b34d715c926d9ea52ea89a38fd0420f351080e3e93eb4fb" +} diff --git a/.sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json b/.sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json new file mode 100644 index 00000000..094458da --- /dev/null +++ b/.sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT follower_id, target_id FROM organization_follows WHERE target_id=$1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "follower_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "target_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276" +} diff --git a/.sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json b/.sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json new file mode 100644 index 00000000..12276d94 --- /dev/null +++ b/.sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json @@ -0,0 +1,69 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT \n id,\n target_id,\n target_id_type as \"target_id_type: _\",\n triggerer_id,\n triggerer_id_type as \"triggerer_id_type: _\",\n event_type as \"event_type: _\",\n metadata,\n created\n FROM events e\n WHERE \n (target_id, target_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[]))\n OR\n (triggerer_id, triggerer_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[]))\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "target_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "target_id_type: _", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "triggerer_id", + "type_info": "Int8" + }, + { + "ordinal": 4, + "name": "triggerer_id_type: _", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "event_type: _", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "metadata", + "type_info": "Jsonb" + }, + { + "ordinal": 7, + "name": "created", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8Array", + "TextArray", + "TextArray", + "Int8Array", + "TextArray", + "TextArray" + ] + }, + "nullable": [ + false, + false, + false, + true, + true, + false, + true, + false + ] + }, + "hash": "5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb" +} diff --git a/.sqlx/query-79889e768dd4b129ed10f46a6b973208e7897cad64a4e371a6f3b76310d0479b.json b/.sqlx/query-79889e768dd4b129ed10f46a6b973208e7897cad64a4e371a6f3b76310d0479b.json new file mode 100644 index 00000000..b1468732 --- /dev/null +++ b/.sqlx/query-79889e768dd4b129ed10f46a6b973208e7897cad64a4e371a6f3b76310d0479b.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO events (\n id,\n target_id,\n target_id_type,\n triggerer_id,\n triggerer_id_type,\n event_type,\n metadata\n )\n SELECT * FROM UNNEST (\n $1::bigint[],\n $2::bigint[],\n $3::text[],\n $4::bigint[],\n $5::text[],\n $6::text[],\n $7::jsonb[]\n )\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8Array", + "Int8Array", + "TextArray", + "Int8Array", + "TextArray", + "TextArray", + "JsonbArray" + ] + }, + "nullable": [] + }, + "hash": "79889e768dd4b129ed10f46a6b973208e7897cad64a4e371a6f3b76310d0479b" +} diff --git a/.sqlx/query-7d4a842888bb3ed24090f9cb903a0a1fae6852ad81db66e0333bb0f373d9b9a3.json b/.sqlx/query-7d4a842888bb3ed24090f9cb903a0a1fae6852ad81db66e0333bb0f373d9b9a3.json new file mode 100644 index 00000000..e4d13f68 --- /dev/null +++ b/.sqlx/query-7d4a842888bb3ed24090f9cb903a0a1fae6852ad81db66e0333bb0f373d9b9a3.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT follower_id, target_id FROM organization_follows WHERE follower_id=$1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "follower_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "target_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "7d4a842888bb3ed24090f9cb903a0a1fae6852ad81db66e0333bb0f373d9b9a3" +} diff --git a/.sqlx/query-9699feae1d917bc90738f3e6b46f49e4f76c06ceb1ce1ab3fb9317b1035deb65.json b/.sqlx/query-9699feae1d917bc90738f3e6b46f49e4f76c06ceb1ce1ab3fb9317b1035deb65.json new file mode 100644 index 00000000..3ca3e3af --- /dev/null +++ b/.sqlx/query-9699feae1d917bc90738f3e6b46f49e4f76c06ceb1ce1ab3fb9317b1035deb65.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT follower_id, target_id FROM user_follows WHERE follower_id=$1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "follower_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "target_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "9699feae1d917bc90738f3e6b46f49e4f76c06ceb1ce1ab3fb9317b1035deb65" +} diff --git a/.sqlx/query-9d73bfe2094b88fb7bcb0ef47741c612370373c6263566fc870da4a82b8ded7b.json b/.sqlx/query-9d73bfe2094b88fb7bcb0ef47741c612370373c6263566fc870da4a82b8ded7b.json new file mode 100644 index 00000000..af3befe5 --- /dev/null +++ b/.sqlx/query-9d73bfe2094b88fb7bcb0ef47741c612370373c6263566fc870da4a82b8ded7b.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM organization_follows WHERE follower_id=$1 AND target_id=$2", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "9d73bfe2094b88fb7bcb0ef47741c612370373c6263566fc870da4a82b8ded7b" +} diff --git a/.sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json b/.sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json new file mode 100644 index 00000000..38ea7f22 --- /dev/null +++ b/.sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT follower_id, target_id FROM user_follows WHERE target_id=$1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "follower_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "target_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036" +} diff --git a/.sqlx/query-c337079ccc53022395a4f3c35fa816ae3421cfa811be32ce478a405d5ef99c06.json b/.sqlx/query-c337079ccc53022395a4f3c35fa816ae3421cfa811be32ce478a405d5ef99c06.json new file mode 100644 index 00000000..69bbe2ae --- /dev/null +++ b/.sqlx/query-c337079ccc53022395a4f3c35fa816ae3421cfa811be32ce478a405d5ef99c06.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": " INSERT INTO user_follows (follower_id, target_id) VALUES ($1, $2)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "c337079ccc53022395a4f3c35fa816ae3421cfa811be32ce478a405d5ef99c06" +} diff --git a/.sqlx/query-cbeb96926e6d958bcaf5ed28201e36c5ae2c2cb0067ff551a4ecea8a1685b87c.json b/.sqlx/query-cbeb96926e6d958bcaf5ed28201e36c5ae2c2cb0067ff551a4ecea8a1685b87c.json new file mode 100644 index 00000000..8179cbf5 --- /dev/null +++ b/.sqlx/query-cbeb96926e6d958bcaf5ed28201e36c5ae2c2cb0067ff551a4ecea8a1685b87c.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT EXISTS(SELECT 1 FROM events WHERE id=$1)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "exists", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "cbeb96926e6d958bcaf5ed28201e36c5ae2c2cb0067ff551a4ecea8a1685b87c" +} diff --git a/.sqlx/query-e50f350a4eb67ddbbfec7597185af09e5db89675320111e8006587b2666e72a3.json b/.sqlx/query-e50f350a4eb67ddbbfec7597185af09e5db89675320111e8006587b2666e72a3.json new file mode 100644 index 00000000..c56772e1 --- /dev/null +++ b/.sqlx/query-e50f350a4eb67ddbbfec7597185af09e5db89675320111e8006587b2666e72a3.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": " INSERT INTO organization_follows (follower_id, target_id) VALUES ($1, $2)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + }, + "hash": "e50f350a4eb67ddbbfec7597185af09e5db89675320111e8006587b2666e72a3" +} From 8f9b6f8132915930bcc7623760c1ec31ea7ad2d6 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 00:13:50 -0500 Subject: [PATCH 50/72] Actually use organization follows to get events too --- src/database/models/creator_follows.rs | 2 +- src/routes/v3/users.rs | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index 2d16e4ae..6251189f 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -69,7 +69,7 @@ macro_rules! impl_follow { Ok(res.into_iter().map(|r| r.into()).collect_vec()) } - pub async fn get_follows_from_follower( + pub async fn get_follows_by_follower( follower_user_id: UserId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 938536c2..1794ef7d 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -26,6 +26,7 @@ use itertools::Itertools; use sqlx::PgPool; use database::models as db_models; +use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; use database::models::creator_follows::UserFollow as DBUserFollow; use database::models::event_item::Event as DBEvent; use database::models::user_item::User as DBUser; @@ -106,7 +107,9 @@ pub async fn current_user_feed( .await?; let followed_users = - DBUserFollow::get_follows_from_follower(current_user.id.into(), &**pool).await?; + DBUserFollow::get_follows_by_follower(current_user.id.into(), &**pool).await?; + let followed_organizations = + DBOrganizationFollow::get_follows_by_follower(current_user.id.into(), &**pool).await?; let selectors = followed_users .into_iter() @@ -114,6 +117,14 @@ pub async fn current_user_feed( id: follow.target_id.into(), event_type: EventType::ProjectCreated, }) + .chain( + followed_organizations + .into_iter() + .map(|follow| EventSelector { + id: follow.target_id.into(), + event_type: EventType::ProjectCreated, + }), + ) .collect_vec(); let events = DBEvent::get_events(&[], &selectors, &**pool).await?; From 23120d8f395d5a7360f68c29564b1849b0c87063 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 10:50:53 -0500 Subject: [PATCH 51/72] Add offset/limit parameters and explicitly sort by created descending --- src/database/models/event_item.rs | 1 + src/routes/v3/users.rs | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 887564c0..64ca8617 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -195,6 +195,7 @@ impl Event { OR (triggerer_id, triggerer_id_type, event_type) = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[])) + ORDER BY created DESC "#, &target_ids[..], &target_id_types[..] as &[IdType], diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 1794ef7d..1dcd3f70 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -23,6 +23,7 @@ use actix_web::{ HttpRequest, HttpResponse, }; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use sqlx::PgPool; use database::models as db_models; @@ -90,9 +91,16 @@ pub async fn user_unfollow( .await } +#[derive(Serialize, Deserialize)] +pub struct FeedParameters { + pub limit: Option, + pub offset: Option, +} + #[get("feed")] pub async fn current_user_feed( req: HttpRequest, + web::Query(params): web::Query, pool: web::Data, redis: web::Data, session_queue: web::Data, @@ -126,7 +134,12 @@ pub async fn current_user_feed( }), ) .collect_vec(); - let events = DBEvent::get_events(&[], &selectors, &**pool).await?; + let events = DBEvent::get_events(&[], &selectors, &**pool) + .await? + .into_iter() + .skip(params.offset.unwrap_or(0)) + .take(params.offset.unwrap_or(usize::MAX)) + .collect_vec(); let mut feed_items: Vec = Vec::new(); let authorized_projects = From f5f7b83d7aee968ce041550d55d32a24e7b5f0b7 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 12:38:12 -0500 Subject: [PATCH 52/72] formatting --- src/routes/v3/mod.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index bf8aa97f..bd053a25 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -1,13 +1,12 @@ -pub mod follow; -pub mod organizations; -pub mod users; - pub use super::ApiError; use crate::{auth::oauth, util::cors::default_cors}; use actix_web::{web, HttpResponse}; use serde_json::json; +pub mod follow; pub mod oauth_clients; +pub mod organizations; +pub mod users; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( From 5694ca787aba1b51504206c4b538a4df615cb1f0 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Wed, 25 Oct 2023 15:56:58 -0500 Subject: [PATCH 53/72] Feed tests --- Cargo.lock | 7 +++ Cargo.toml | 1 + src/models/feed_item.rs | 16 +++---- src/models/organizations.rs | 2 +- src/routes/v3/mod.rs | 2 +- src/routes/v3/users.rs | 3 +- tests/common/actix.rs | 10 ++++ tests/common/api_v2/organization.rs | 11 +++-- tests/common/api_v3/mod.rs | 2 + tests/common/api_v3/organization.rs | 25 ++++++++++ tests/common/api_v3/user.rs | 41 +++++++++++++++++ tests/common/dummy_data.rs | 2 +- tests/common/permissions.rs | 2 +- tests/common/request_data.rs | 4 +- tests/feed.rs | 71 +++++++++++++++++++++++++++++ 15 files changed, 182 insertions(+), 17 deletions(-) create mode 100644 tests/common/api_v3/organization.rs create mode 100644 tests/common/api_v3/user.rs create mode 100644 tests/feed.rs diff --git a/Cargo.lock b/Cargo.lock index fc2b10ae..f87d9e0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -444,6 +444,12 @@ version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "619743e34b5ba4e9703bba34deac3427c72507c7159f5fd030aea8cac0cfe341" +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + [[package]] name = "async-channel" version = "1.9.0" @@ -2244,6 +2250,7 @@ dependencies = [ "actix-web-prom", "actix-ws", "argon2", + "assert_matches", "async-trait", "base64 0.21.4", "bitflags 2.4.0", diff --git a/Cargo.toml b/Cargo.toml index 89ca5c31..471fc0e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,3 +109,4 @@ derive-new = "0.5.9" [dev-dependencies] actix-http = "3.4.0" +assert_matches = "1.5.0" diff --git a/src/models/feed_item.rs b/src/models/feed_item.rs index f1cee015..b1557e52 100644 --- a/src/models/feed_item.rs +++ b/src/models/feed_item.rs @@ -12,11 +12,11 @@ use crate::database::models::event_item as DBEvent; #[serde(into = "Base62Id")] pub struct FeedItemId(pub u64); -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum CreatorId { - User(UserId), - Organization(OrganizationId), + User { id: UserId }, + Organization { id: OrganizationId }, } #[derive(Serialize, Deserialize)] @@ -26,7 +26,7 @@ pub struct FeedItem { pub time: DateTime, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum FeedItemBody { ProjectCreated { @@ -39,10 +39,10 @@ pub enum FeedItemBody { impl From for CreatorId { fn from(value: crate::database::models::event_item::CreatorId) -> Self { match value { - DBEvent::CreatorId::User(user_id) => CreatorId::User(user_id.into()), - DBEvent::CreatorId::Organization(organization_id) => { - CreatorId::Organization(organization_id.into()) - } + DBEvent::CreatorId::User(user_id) => CreatorId::User { id: user_id.into() }, + DBEvent::CreatorId::Organization(organization_id) => CreatorId::Organization { + id: organization_id.into(), + }, } } } diff --git a/src/models/organizations.rs b/src/models/organizations.rs index 6163ddee..b583aced 100644 --- a/src/models/organizations.rs +++ b/src/models/organizations.rs @@ -5,7 +5,7 @@ use super::{ use serde::{Deserialize, Serialize}; /// The ID of a team -#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] #[serde(from = "Base62Id")] #[serde(into = "Base62Id")] pub struct OrganizationId(pub u64); diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index bd053a25..7709aff3 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -13,9 +13,9 @@ pub fn config(cfg: &mut web::ServiceConfig) { web::scope("v3") .wrap(default_cors()) .route("", web::get().to(hello_world)) - .configure(users::config) .configure(oauth::config) .configure(oauth_clients::config) + .configure(users::config) .configure(organizations::config), ); } diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 1dcd3f70..951aceaa 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -36,7 +36,8 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("user") .service(user_follow) - .service(user_unfollow), + .service(user_unfollow) + .service(current_user_feed), ); } diff --git a/tests/common/actix.rs b/tests/common/actix.rs index 11759d7f..e8c11f90 100644 --- a/tests/common/actix.rs +++ b/tests/common/actix.rs @@ -80,3 +80,13 @@ fn generate_multipart(data: impl IntoIterator) -> (Stri (boundary, Bytes::from(payload)) } + +pub trait TestRequestExtensions { + fn append_auth(self, pat: &str) -> TestRequest; +} + +impl TestRequestExtensions for TestRequest { + fn append_auth(self, pat: &str) -> TestRequest { + self.append_header((reqwest::header::AUTHORIZATION, pat)) + } +} diff --git a/tests/common/api_v2/organization.rs b/tests/common/api_v2/organization.rs index 31f0ea4c..99394bf1 100644 --- a/tests/common/api_v2/organization.rs +++ b/tests/common/api_v2/organization.rs @@ -1,3 +1,4 @@ +use actix_http::StatusCode; use actix_web::{ dev::ServiceResponse, test::{self, TestRequest}, @@ -6,7 +7,7 @@ use bytes::Bytes; use labrinth::models::{organizations::Organization, projects::Project}; use serde_json::json; -use crate::common::request_data::ImageData; +use crate::common::{asserts::assert_status, request_data::ImageData}; use super::ApiV2; @@ -42,8 +43,7 @@ impl ApiV2 { pat: &str, ) -> Organization { let resp = self.get_organization(id_or_title, pat).await; - assert_eq!(resp.status(), 200); - test::read_body_json(resp).await + deser_organization(resp).await } pub async fn get_organization_projects(&self, id_or_title: &str, pat: &str) -> ServiceResponse { @@ -150,3 +150,8 @@ impl ApiV2 { self.call(req).await } } + +pub async fn deser_organization(response: ServiceResponse) -> Organization { + assert_status(&response, StatusCode::OK); + test::read_body_json(response).await +} diff --git a/tests/common/api_v3/mod.rs b/tests/common/api_v3/mod.rs index 2155aa3c..f4e08f34 100644 --- a/tests/common/api_v3/mod.rs +++ b/tests/common/api_v3/mod.rs @@ -6,6 +6,8 @@ use std::rc::Rc; pub mod oauth; pub mod oauth_clients; +pub mod organization; +pub mod user; #[derive(Clone)] pub struct ApiV3 { diff --git a/tests/common/api_v3/organization.rs b/tests/common/api_v3/organization.rs new file mode 100644 index 00000000..700144c7 --- /dev/null +++ b/tests/common/api_v3/organization.rs @@ -0,0 +1,25 @@ +use actix_web::{dev::ServiceResponse, test::TestRequest}; + +use crate::common::actix::TestRequestExtensions; + +use super::ApiV3; + +impl ApiV3 { + pub async fn follow_organization(&self, organization_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::post() + .uri(&format!("/v3/organization/{}/follow", organization_id)) + .append_auth(pat) + .to_request(); + + self.call(req).await + } + + pub async fn unfollow_organization(&self, organization_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::delete() + .uri(&format!("/v3/organization/{}/follow", organization_id)) + .append_auth(pat) + .to_request(); + + self.call(req).await + } +} diff --git a/tests/common/api_v3/user.rs b/tests/common/api_v3/user.rs new file mode 100644 index 00000000..2574c4f5 --- /dev/null +++ b/tests/common/api_v3/user.rs @@ -0,0 +1,41 @@ +use actix_http::StatusCode; +use actix_web::{ + dev::ServiceResponse, + test::{self, TestRequest}, +}; +use labrinth::models::feed_item::FeedItem; + +use crate::common::{actix::TestRequestExtensions, asserts::assert_status}; + +use super::ApiV3; + +impl ApiV3 { + pub async fn follow_user(&self, user_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::post() + .uri(&format!("/v3/user/{}/follow", user_id)) + .append_auth(pat) + .to_request(); + + self.call(req).await + } + + pub async fn unfollow_user(&self, user_id: &str, pat: &str) -> ServiceResponse { + let req = TestRequest::delete() + .uri(&format!("/v3/user/{}/follow", user_id)) + .append_auth(pat) + .to_request(); + + self.call(req).await + } + + pub async fn get_feed(&self, pat: &str) -> Vec { + let req = TestRequest::get() + .uri("/v3/user/feed") + .append_auth(pat) + .to_request(); + let resp = self.call(req).await; + assert_status(&resp, StatusCode::OK); + + test::read_body_json(resp).await + } +} diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs index ed3b7f08..0bd6981a 100644 --- a/tests/common/dummy_data.rs +++ b/tests/common/dummy_data.rs @@ -214,7 +214,7 @@ pub async fn add_project_alpha(test_env: &TestEnvironment) -> (Project, Version) let (project, versions) = test_env .v2 .add_public_project( - get_public_project_creation_data("alpha", Some(DummyJarFile::DummyProjectAlpha)), + get_public_project_creation_data("alpha", Some(DummyJarFile::DummyProjectAlpha), None), USER_USER_PAT, ) .await; diff --git a/tests/common/permissions.rs b/tests/common/permissions.rs index 4ab33900..58dda70b 100644 --- a/tests/common/permissions.rs +++ b/tests/common/permissions.rs @@ -849,7 +849,7 @@ async fn create_dummy_project(test_env: &TestEnvironment) -> (String, String) { // Create a very simple project let slug = generate_random_name("test_project"); - let creation_data = request_data::get_public_project_creation_data(&slug, None); + let creation_data = request_data::get_public_project_creation_data(&slug, None, None); let (project, _) = api.add_public_project(creation_data, ADMIN_USER_PAT).await; let project_id = project.id.to_string(); let team_id = project.team.to_string(); diff --git a/tests/common/request_data.rs b/tests/common/request_data.rs index bd5eb284..20167e44 100644 --- a/tests/common/request_data.rs +++ b/tests/common/request_data.rs @@ -22,6 +22,7 @@ pub struct ImageData { pub fn get_public_project_creation_data( slug: &str, version_jar: Option, + organization_id: Option<&str>, ) -> ProjectCreationRequestData { let initial_versions = if let Some(ref jar) = version_jar { json!([{ @@ -51,7 +52,8 @@ pub fn get_public_project_creation_data( "initial_versions": initial_versions, "is_draft": is_draft, "categories": [], - "license_id": "MIT" + "license_id": "MIT", + "organization_id": organization_id } ); diff --git a/tests/feed.rs b/tests/feed.rs new file mode 100644 index 00000000..d7addd3a --- /dev/null +++ b/tests/feed.rs @@ -0,0 +1,71 @@ +use assert_matches::assert_matches; +use common::{ + api_v2::organization::deser_organization, + database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, + environment::with_test_environment, + request_data::get_public_project_creation_data, +}; +use labrinth::models::feed_item::FeedItemBody; + +use crate::common::dummy_data::DummyProjectAlpha; + +mod common; + +#[actix_rt::test] +async fn user_feed_before_following_user_shows_no_projects() { + with_test_environment(|env| async move { + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + + assert_eq!(feed.len(), 0); + }) + .await +} + +#[actix_rt::test] +async fn user_feed_after_following_user_shows_previously_created_public_projects() { + with_test_environment(|env| async move { + let DummyProjectAlpha { + project_id: alpha_project_id, + .. + } = env.dummy.as_ref().unwrap().project_alpha.clone(); + env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; + + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + + assert_eq!(feed.len(), 1); + assert_matches!( + feed[0].body, + FeedItemBody::ProjectCreated { project_id, .. } if project_id.to_string() == alpha_project_id + ) + }) + .await +} + +#[actix_rt::test] +async fn user_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( +) { + with_test_environment(|env| async move { + let resp = env + .v2 + .create_organization("test", "desc", USER_USER_ID) + .await; + let organization = deser_organization(resp).await; + let org_id = organization.id.to_string(); + let project_create_data = get_public_project_creation_data("a", None, Some(&org_id)); + let (project, _) = env + .v2 + .add_public_project(project_create_data, USER_USER_PAT) + .await; + + env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + assert_eq!(feed.len(), 1); + assert_matches!(feed[0].body, FeedItemBody::ProjectCreated { project_id, .. } if project_id != project.id); + + env.v3.follow_organization(&org_id, FRIEND_USER_PAT).await; + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + assert_eq!(feed.len(), 1); + assert_matches!(feed[0].body, FeedItemBody::ProjectCreated { project_id, .. } if project_id == project.id); + }) + .await; +} From 32bd67e34d401d0ca43b1ecb1a6e8b5767b44673 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Fri, 27 Oct 2023 11:01:22 -0500 Subject: [PATCH 54/72] Fix organization id bug and feed tests --- src/models/feed_item.rs | 4 ++-- src/routes/v2/project_creation.rs | 4 ++-- tests/common/asserts.rs | 9 +++++++++ tests/common/dummy_data.rs | 2 +- tests/feed.rs | 17 +++++++++-------- tests/user.rs | 4 ++-- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/models/feed_item.rs b/src/models/feed_item.rs index b1557e52..0820b888 100644 --- a/src/models/feed_item.rs +++ b/src/models/feed_item.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::database::models::event_item as DBEvent; -#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] #[serde(from = "Base62Id")] #[serde(into = "Base62Id")] pub struct FeedItemId(pub u64); @@ -19,7 +19,7 @@ pub enum CreatorId { Organization { id: OrganizationId }, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct FeedItem { pub id: FeedItemId, pub body: FeedItemBody, diff --git a/src/routes/v2/project_creation.rs b/src/routes/v2/project_creation.rs index f7652fb2..f86d3504 100644 --- a/src/routes/v2/project_creation.rs +++ b/src/routes/v2/project_creation.rs @@ -245,7 +245,7 @@ struct ProjectCreateData { pub uploaded_images: Vec, /// The id of the organization to create the project in - pub organization_id: Option, + pub organization_id: Option, } #[derive(Serialize, Deserialize, Validate, Clone)] @@ -749,7 +749,7 @@ async fn project_create_inner( } } - let organization_id = project_create_data.organization_id; + let organization_id = project_create_data.organization_id.map(|id| id.into()); insert_project_create_event(project_id, organization_id, ¤t_user, transaction) .await?; diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index 3c7f585a..576b7548 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -1,5 +1,7 @@ #![allow(dead_code)] +use labrinth::models::feed_item::{FeedItem, FeedItemBody}; + pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { assert_eq!(response.status(), status, "{:#?}", response.response()); } @@ -10,3 +12,10 @@ pub fn assert_any_status_except( ) { assert_ne!(response.status(), status, "{:#?}", response.response()); } + +pub fn assert_feed_contains_project_created( + feed: &[FeedItem], + expected_project_id: labrinth::models::projects::ProjectId, +) { + assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::ProjectCreated { project_id, .. } if project_id == expected_project_id)), "{:#?}", &feed); +} diff --git a/tests/common/dummy_data.rs b/tests/common/dummy_data.rs index 0bd6981a..c5c7bea0 100644 --- a/tests/common/dummy_data.rs +++ b/tests/common/dummy_data.rs @@ -21,7 +21,7 @@ use super::{ request_data::get_public_project_creation_data, }; -pub const DUMMY_DATA_UPDATE: i64 = 3; +pub const DUMMY_DATA_UPDATE: i64 = 4; #[allow(dead_code)] pub const DUMMY_CATEGORIES: &[&str] = &[ diff --git a/tests/feed.rs b/tests/feed.rs index d7addd3a..695a3783 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -7,7 +7,9 @@ use common::{ }; use labrinth::models::feed_item::FeedItemBody; -use crate::common::dummy_data::DummyProjectAlpha; +use crate::common::{ + asserts::assert_feed_contains_project_created, dummy_data::DummyProjectAlpha, get_json_val_str, +}; mod common; @@ -45,13 +47,12 @@ async fn user_feed_after_following_user_shows_previously_created_public_projects async fn user_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( ) { with_test_environment(|env| async move { - let resp = env - .v2 - .create_organization("test", "desc", USER_USER_ID) + let resp = env.v2 + .create_organization("test", "desc", USER_USER_PAT) .await; let organization = deser_organization(resp).await; - let org_id = organization.id.to_string(); - let project_create_data = get_public_project_creation_data("a", None, Some(&org_id)); + let org_id = get_json_val_str(organization.id); + let project_create_data = get_public_project_creation_data("thisisaslug", None, Some(&org_id)); let (project, _) = env .v2 .add_public_project(project_create_data, USER_USER_PAT) @@ -64,8 +65,8 @@ async fn user_feed_when_following_user_that_creates_project_as_org_only_shows_ev env.v3.follow_organization(&org_id, FRIEND_USER_PAT).await; let feed = env.v3.get_feed(FRIEND_USER_PAT).await; - assert_eq!(feed.len(), 1); - assert_matches!(feed[0].body, FeedItemBody::ProjectCreated { project_id, .. } if project_id == project.id); + assert_eq!(feed.len(), 2); + assert_feed_contains_project_created(&feed, project.id); }) .await; } diff --git a/tests/user.rs b/tests/user.rs index 664bbdc1..f252a7b2 100644 --- a/tests/user.rs +++ b/tests/user.rs @@ -25,7 +25,7 @@ pub async fn get_user_projects_after_creating_project_returns_new_project() { let (project, _) = api .add_public_project( - get_public_project_creation_data("slug", Some(DummyJarFile::BasicMod)), + get_public_project_creation_data("slug", Some(DummyJarFile::BasicMod), None), USER_USER_PAT, ) .await; @@ -44,7 +44,7 @@ pub async fn get_user_projects_after_deleting_project_shows_removal() { let api = test_env.v2; let (project, _) = api .add_public_project( - get_public_project_creation_data("iota", Some(DummyJarFile::BasicMod)), + get_public_project_creation_data("iota", Some(DummyJarFile::BasicMod), None), USER_USER_PAT, ) .await; From f6e86b12574843725a65c4a5b47047db0431b60a Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Fri, 27 Oct 2023 13:38:18 -0500 Subject: [PATCH 55/72] More feed + following tests --- tests/common/api_v2/organization.rs | 8 ++++- tests/common/api_v2/project.rs | 9 ++++- tests/feed.rs | 53 ++++++++++++++++++----------- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/tests/common/api_v2/organization.rs b/tests/common/api_v2/organization.rs index 99394bf1..8db5b648 100644 --- a/tests/common/api_v2/organization.rs +++ b/tests/common/api_v2/organization.rs @@ -7,7 +7,7 @@ use bytes::Bytes; use labrinth::models::{organizations::Organization, projects::Project}; use serde_json::json; -use crate::common::{asserts::assert_status, request_data::ImageData}; +use crate::common::{asserts::assert_status, get_json_val_str, request_data::ImageData}; use super::ApiV2; @@ -29,6 +29,12 @@ impl ApiV2 { self.call(req).await } + pub async fn create_default_organization(&self, pat: &str) -> String { + let resp = self.create_organization("test", "desc", pat).await; + let organization = deser_organization(resp).await; + get_json_val_str(organization.id).to_string() + } + pub async fn get_organization(&self, id_or_title: &str, pat: &str) -> ServiceResponse { let req = TestRequest::get() .uri(&format!("/v2/organization/{id_or_title}")) diff --git a/tests/common/api_v2/project.rs b/tests/common/api_v2/project.rs index 7b3af132..4f8da97d 100644 --- a/tests/common/api_v2/project.rs +++ b/tests/common/api_v2/project.rs @@ -11,12 +11,19 @@ use crate::common::{ actix::AppendsMultipart, asserts::assert_status, database::MOD_USER_PAT, - request_data::{ImageData, ProjectCreationRequestData}, + request_data::{get_public_project_creation_data, ImageData, ProjectCreationRequestData}, }; use super::ApiV2; impl ApiV2 { + pub async fn add_default_org_project(&self, org_id: &str, pat: &str) -> Project { + let project_create_data = + get_public_project_creation_data("thisisaslug", None, Some(org_id)); + let (project, _) = self.add_public_project(project_create_data, pat).await; + project + } + pub async fn add_public_project( &self, creation_data: ProjectCreationRequestData, diff --git a/tests/feed.rs b/tests/feed.rs index 695a3783..5cf2412f 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -1,20 +1,15 @@ +use crate::common::{asserts::assert_feed_contains_project_created, dummy_data::DummyProjectAlpha}; use assert_matches::assert_matches; use common::{ - api_v2::organization::deser_organization, database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, environment::with_test_environment, - request_data::get_public_project_creation_data, }; use labrinth::models::feed_item::FeedItemBody; -use crate::common::{ - asserts::assert_feed_contains_project_created, dummy_data::DummyProjectAlpha, get_json_val_str, -}; - mod common; #[actix_rt::test] -async fn user_feed_before_following_user_shows_no_projects() { +async fn get_feed_before_following_user_shows_no_projects() { with_test_environment(|env| async move { let feed = env.v3.get_feed(FRIEND_USER_PAT).await; @@ -24,7 +19,7 @@ async fn user_feed_before_following_user_shows_no_projects() { } #[actix_rt::test] -async fn user_feed_after_following_user_shows_previously_created_public_projects() { +async fn get_feed_after_following_user_shows_previously_created_public_projects() { with_test_environment(|env| async move { let DummyProjectAlpha { project_id: alpha_project_id, @@ -44,19 +39,11 @@ async fn user_feed_after_following_user_shows_previously_created_public_projects } #[actix_rt::test] -async fn user_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( +async fn get_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( ) { with_test_environment(|env| async move { - let resp = env.v2 - .create_organization("test", "desc", USER_USER_PAT) - .await; - let organization = deser_organization(resp).await; - let org_id = get_json_val_str(organization.id); - let project_create_data = get_public_project_creation_data("thisisaslug", None, Some(&org_id)); - let (project, _) = env - .v2 - .add_public_project(project_create_data, USER_USER_PAT) - .await; + let org_id = env.v2.create_default_organization(USER_USER_PAT).await; + let project = env.v2.add_default_org_project(&org_id, USER_USER_PAT).await; env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; let feed = env.v3.get_feed(FRIEND_USER_PAT).await; @@ -70,3 +57,31 @@ async fn user_feed_when_following_user_that_creates_project_as_org_only_shows_ev }) .await; } + +#[actix_rt::test] +async fn get_feed_after_unfollowing_user_no_longer_shows_feed_items() { + with_test_environment(|env| async move { + env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; + + env.v3.unfollow_user(USER_USER_ID, FRIEND_USER_PAT).await; + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + + assert_eq!(feed.len(), 0); + }) + .await; +} + +#[actix_rt::test] +async fn get_feed_after_unfollowing_organization_no_longer_shows_feed_items() { + with_test_environment(|env| async move { + let org_id = env.v2.create_default_organization(USER_USER_PAT).await; + env.v2.add_default_org_project(&org_id, USER_USER_PAT).await; + env.v3.follow_organization(&org_id, FRIEND_USER_PAT).await; + + env.v3.unfollow_organization(&org_id, FRIEND_USER_PAT).await; + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + + assert_eq!(feed.len(), 0); + }) + .await; +} From 5530a87049d941bb3b400fef8dd7f33925551162 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Fri, 27 Oct 2023 13:41:23 -0500 Subject: [PATCH 56/72] cargo sqlx prepare --- ...0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename .sqlx/{query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json => query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json} (94%) diff --git a/.sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json b/.sqlx/query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json similarity index 94% rename from .sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json rename to .sqlx/query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json index 12276d94..b2394db1 100644 --- a/.sqlx/query-5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb.json +++ b/.sqlx/query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT \n id,\n target_id,\n target_id_type as \"target_id_type: _\",\n triggerer_id,\n triggerer_id_type as \"triggerer_id_type: _\",\n event_type as \"event_type: _\",\n metadata,\n created\n FROM events e\n WHERE \n (target_id, target_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[]))\n OR\n (triggerer_id, triggerer_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[]))\n ", + "query": "\n SELECT \n id,\n target_id,\n target_id_type as \"target_id_type: _\",\n triggerer_id,\n triggerer_id_type as \"triggerer_id_type: _\",\n event_type as \"event_type: _\",\n metadata,\n created\n FROM events e\n WHERE \n (target_id, target_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[]))\n OR\n (triggerer_id, triggerer_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[]))\n ORDER BY created DESC\n ", "describe": { "columns": [ { @@ -65,5 +65,5 @@ false ] }, - "hash": "5352e8689a9d78c6976842afa12da3a3c9a8a3eee6985b548207eea7debcaafb" + "hash": "502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6" } From e6df9be46e61bef97bab4da9530e9a6187eb84f0 Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Fri, 27 Oct 2023 14:16:56 -0500 Subject: [PATCH 57/72] Undo rename --- src/auth/checks.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/auth/checks.rs b/src/auth/checks.rs index 1ff70e13..c8769bec 100644 --- a/src/auth/checks.rs +++ b/src/auth/checks.rs @@ -80,7 +80,7 @@ pub async fn is_authorized( pub async fn filter_authorized_projects( projects: Vec, - current_user: Option<&User>, + user_option: Option<&User>, pool: &web::Data, ) -> Result, ApiError> { let mut return_projects = Vec::new(); @@ -88,16 +88,16 @@ pub async fn filter_authorized_projects( for project in projects { if !project.inner.status.is_hidden() - || current_user.map(|x| x.role.is_mod()).unwrap_or(false) + || user_option.map(|x| x.role.is_mod()).unwrap_or(false) { return_projects.push(project.into()); - } else if current_user.is_some() { + } else if user_option.is_some() { check_projects.push(project); } } if !check_projects.is_empty() { - if let Some(user) = current_user { + if let Some(user) = user_option { let user_id: models::ids::UserId = user.id.into(); use futures::TryStreamExt; From 84287c6a1f8f7900387cbcfde512fb1ac92f722f Mon Sep 17 00:00:00 2001 From: Jackson Kruger Date: Fri, 27 Oct 2023 19:34:05 -0500 Subject: [PATCH 58/72] Cargo clippy --- src/auth/checks.rs | 3 +-- src/database/models/oauth_client_item.rs | 2 +- src/routes/v3/oauth_clients.rs | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/auth/checks.rs b/src/auth/checks.rs index 8d8fab68..4d47e72c 100644 --- a/src/auth/checks.rs +++ b/src/auth/checks.rs @@ -23,8 +23,7 @@ where { fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError> { self.into_iter() - .map(|c| c.validate_authorized(user_option.clone())) - .collect() + .try_for_each(|c| c.validate_authorized(user_option)) } } diff --git a/src/database/models/oauth_client_item.rs b/src/database/models/oauth_client_item.rs index 48907bb8..48870c67 100644 --- a/src/database/models/oauth_client_item.rs +++ b/src/database/models/oauth_client_item.rs @@ -80,7 +80,7 @@ impl OAuthClient { ids: &[OAuthClientId], exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, ) -> Result, DatabaseError> { - let ids = ids.into_iter().map(|id| id.0).collect_vec(); + let ids = ids.iter().map(|id| id.0).collect_vec(); let ids_ref: &[i64] = &ids; let results = select_clients_with_predicate!("WHERE clients.id = ANY($1::bigint[])", ids_ref) diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 59d3f68c..19b12c1c 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -435,7 +435,7 @@ pub async fn get_clients_inner( .await? .1; - let ids: Vec = ids.into_iter().map(|i| (*i).into()).collect(); + let ids: Vec = ids.iter().map(|i| (*i).into()).collect(); let clients = OAuthClient::get_many(&ids, &**pool).await?; clients .iter() From fc3d31b35478e3b7155c51666167cf1724557c4b Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 30 Oct 2023 10:21:43 -0500 Subject: [PATCH 59/72] Remove validation requiring 1 redirect URI on oauth client creation --- src/routes/v3/oauth_clients.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/routes/v3/oauth_clients.rs b/src/routes/v3/oauth_clients.rs index 19b12c1c..5f3839b6 100644 --- a/src/routes/v3/oauth_clients.rs +++ b/src/routes/v3/oauth_clients.rs @@ -138,7 +138,6 @@ pub struct NewOAuthApp { #[validate(custom(function = "crate::util::validate::validate_no_restricted_scopes"))] pub max_scopes: Scopes, - #[validate(length(min = 1))] pub redirect_uris: Vec, } From 19929259a882ca409e4acca34a0f924dafda66e0 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 30 Oct 2023 10:36:27 -0500 Subject: [PATCH 60/72] Use serde(flatten) on OAuthClientCreationResult --- src/models/oauth_clients.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/models/oauth_clients.rs b/src/models/oauth_clients.rs index fd8a8c6d..16795aa3 100644 --- a/src/models/oauth_clients.rs +++ b/src/models/oauth_clients.rs @@ -33,8 +33,10 @@ pub struct OAuthRedirectUri { #[derive(Serialize, Deserialize)] pub struct OAuthClientCreationResult { - pub client_secret: String, + #[serde(flatten)] pub client: OAuthClient, + + pub client_secret: String, } #[derive(Deserialize, Serialize)] From 780aae65738cae309a99fa3fe52bceb63a32546c Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 30 Oct 2023 12:10:27 -0500 Subject: [PATCH 61/72] Clippy --- src/database/models/event_item.rs | 8 ++++---- src/routes/v2/project_creation.rs | 4 ++-- src/routes/v3/users.rs | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 64ca8617..e402249d 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -141,7 +141,7 @@ impl TryFrom for Event { "the value of created should not be null".to_string(), )) }, - |c| Ok(c), + Ok, )?, }, }) @@ -176,7 +176,7 @@ impl Event { unzip_event_selectors(target_selectors); let (triggerer_ids, triggerer_id_types, triggerer_event_types) = unzip_event_selectors(triggerer_selectors); - Ok(sqlx::query_as!( + sqlx::query_as!( RawEvent, r#" SELECT @@ -208,7 +208,7 @@ impl Event { .await? .into_iter() .map(|r| r.try_into()) - .collect::, _>>()?) + .collect::, _>>() } } @@ -279,7 +279,7 @@ fn unzip_event_selectors( target_selectors: &[EventSelector], ) -> (Vec, Vec, Vec) { target_selectors - .into_iter() + .iter() .map(|t| (t.id.id, t.id.id_type, t.event_type)) .multiunzip() } diff --git a/src/routes/v2/project_creation.rs b/src/routes/v2/project_creation.rs index f86d3504..0ab29d7e 100644 --- a/src/routes/v2/project_creation.rs +++ b/src/routes/v2/project_creation.rs @@ -848,7 +848,7 @@ async fn project_create_inner( slug: project_builder.slug.clone(), project_type: project_create_data.project_type.clone(), team: team_id.into(), - organization: project_create_data.organization_id.map(|x| x.into()), + organization: project_create_data.organization_id, title: project_builder.title.clone(), description: project_builder.description.clone(), body: project_builder.body.clone(), @@ -905,7 +905,7 @@ async fn insert_project_create_event( project_id: project_id.into(), creator_id: organization_id.map_or_else( || CreatorId::User(current_user.id.into()), - |org| CreatorId::Organization(org), + CreatorId::Organization, ), }, transaction, diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 951aceaa..f1d3726e 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -182,16 +182,16 @@ async fn prefetch_authorized_event_projects( ) -> Result, ApiError> { let project_ids = events .iter() - .filter_map(|e| match &e.event_data { + .map(|e| match &e.event_data { EventData::ProjectCreated { project_id, creator_id: _, - } => Some(project_id.clone()), + } => *project_id, }) .collect_vec(); - let projects = db_models::Project::get_many_ids(&project_ids, &***pool, &redis).await?; + let projects = db_models::Project::get_many_ids(&project_ids, &***pool, redis).await?; let authorized_projects = - filter_authorized_projects(projects, Some(¤t_user), &pool).await?; + filter_authorized_projects(projects, Some(current_user), pool).await?; Ok(HashMap::::from_iter( authorized_projects.into_iter().map(|p| (p.id, p)), )) From 48037b13495abb2c58b8f17636246a8dd5196b79 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 30 Oct 2023 13:29:24 -0500 Subject: [PATCH 62/72] Code review improvements --- ...49782e217f550394641eec3ecaf06ac269276.json | 28 ------------------- ...bc1bcd97b78c724b6d08791d98811d966194.json} | 4 +-- ...e84cab2a5ab0953deb6b490d1e4df1c126036.json | 28 ------------------- src/database/models/creator_follows.rs | 15 ---------- src/database/models/event_item.rs | 4 +-- src/models/{feed_item.rs => feeds.rs} | 0 src/models/ids.rs | 2 +- src/models/mod.rs | 2 +- src/routes/v3/users.rs | 2 +- tests/common/api_v3/user.rs | 2 +- tests/common/asserts.rs | 2 +- tests/feed.rs | 2 +- 12 files changed, 10 insertions(+), 81 deletions(-) delete mode 100644 .sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json rename .sqlx/{query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json => query-84fde217f2982fbb918157da1509bc1bcd97b78c724b6d08791d98811d966194.json} (78%) delete mode 100644 .sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json rename src/models/{feed_item.rs => feeds.rs} (100%) diff --git a/.sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json b/.sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json deleted file mode 100644 index 094458da..00000000 --- a/.sqlx/query-4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT follower_id, target_id FROM organization_follows WHERE target_id=$1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "follower_id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "target_id", - "type_info": "Int8" - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "4bb2dddd46698d1aa3a5060f92549782e217f550394641eec3ecaf06ac269276" -} diff --git a/.sqlx/query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json b/.sqlx/query-84fde217f2982fbb918157da1509bc1bcd97b78c724b6d08791d98811d966194.json similarity index 78% rename from .sqlx/query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json rename to .sqlx/query-84fde217f2982fbb918157da1509bc1bcd97b78c724b6d08791d98811d966194.json index b2394db1..45dfe2e4 100644 --- a/.sqlx/query-502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6.json +++ b/.sqlx/query-84fde217f2982fbb918157da1509bc1bcd97b78c724b6d08791d98811d966194.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT \n id,\n target_id,\n target_id_type as \"target_id_type: _\",\n triggerer_id,\n triggerer_id_type as \"triggerer_id_type: _\",\n event_type as \"event_type: _\",\n metadata,\n created\n FROM events e\n WHERE \n (target_id, target_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[]))\n OR\n (triggerer_id, triggerer_id_type, event_type) \n = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[]))\n ORDER BY created DESC\n ", + "query": "\n SELECT \n id,\n target_id,\n target_id_type as \"target_id_type: _\",\n triggerer_id,\n triggerer_id_type as \"triggerer_id_type: _\",\n event_type as \"event_type: _\",\n metadata,\n created\n FROM events e\n WHERE \n (target_id, target_id_type, event_type)\n = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[]))\n OR\n (triggerer_id, triggerer_id_type, event_type)\n = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[]))\n ORDER BY created DESC\n ", "describe": { "columns": [ { @@ -65,5 +65,5 @@ false ] }, - "hash": "502c57ae375e0f9a906164adc778e367cde9f0912ab6a040453f36d9b7ba49c6" + "hash": "84fde217f2982fbb918157da1509bc1bcd97b78c724b6d08791d98811d966194" } diff --git a/.sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json b/.sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json deleted file mode 100644 index 38ea7f22..00000000 --- a/.sqlx/query-a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT follower_id, target_id FROM user_follows WHERE target_id=$1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "follower_id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "target_id", - "type_info": "Int8" - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "a260edc97e46f27b6d3bcc4ba3fe84cab2a5ab0953deb6b490d1e4df1c126036" -} diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index 6251189f..15fc9e2d 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -54,21 +54,6 @@ macro_rules! impl_follow { Ok(()) } - pub async fn get_followers( - target_id: $target_id_type, - exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, - ) -> Result, DatabaseError> { - let res = sqlx::query_as!( - FollowQuery, - "SELECT follower_id, target_id FROM " + $table_name + " WHERE target_id=$1", - target_id.0 - ) - .fetch_all(exec) - .await?; - - Ok(res.into_iter().map(|r| r.into()).collect_vec()) - } - pub async fn get_follows_by_follower( follower_user_id: UserId, exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>, diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index e402249d..32ebc281 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -190,10 +190,10 @@ impl Event { created FROM events e WHERE - (target_id, target_id_type, event_type) + (target_id, target_id_type, event_type) = ANY(SELECT * FROM UNNEST ($1::bigint[], $2::text[], $3::text[])) OR - (triggerer_id, triggerer_id_type, event_type) + (triggerer_id, triggerer_id_type, event_type) = ANY(SELECT * FROM UNNEST ($4::bigint[], $5::text[], $6::text[])) ORDER BY created DESC "#, diff --git a/src/models/feed_item.rs b/src/models/feeds.rs similarity index 100% rename from src/models/feed_item.rs rename to src/models/feeds.rs diff --git a/src/models/ids.rs b/src/models/ids.rs index c5a68b4e..165a5d41 100644 --- a/src/models/ids.rs +++ b/src/models/ids.rs @@ -1,7 +1,7 @@ use thiserror::Error; pub use super::collections::CollectionId; -pub use super::feed_item::FeedItemId; +pub use super::feeds::FeedItemId; pub use super::images::ImageId; pub use super::notifications::NotificationId; pub use super::oauth_clients::OAuthClientAuthorizationId; diff --git a/src/models/mod.rs b/src/models/mod.rs index 2c7c644b..e0bc2c23 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,7 +1,7 @@ pub mod analytics; pub mod collections; pub mod error; -pub mod feed_item; +pub mod feeds; pub mod ids; pub mod images; pub mod notifications; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index f1d3726e..3f2366d4 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -8,7 +8,7 @@ use crate::{ redis::RedisPool, }, models::{ - feed_item::{FeedItem, FeedItemBody}, + feeds::{FeedItem, FeedItemBody}, ids::ProjectId, pats::Scopes, projects::Project, diff --git a/tests/common/api_v3/user.rs b/tests/common/api_v3/user.rs index 2574c4f5..8529748f 100644 --- a/tests/common/api_v3/user.rs +++ b/tests/common/api_v3/user.rs @@ -3,7 +3,7 @@ use actix_web::{ dev::ServiceResponse, test::{self, TestRequest}, }; -use labrinth::models::feed_item::FeedItem; +use labrinth::models::feeds::FeedItem; use crate::common::{actix::TestRequestExtensions, asserts::assert_status}; diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index 576b7548..ca818fd6 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -use labrinth::models::feed_item::{FeedItem, FeedItemBody}; +use labrinth::models::feeds::{FeedItem, FeedItemBody}; pub fn assert_status(response: &actix_web::dev::ServiceResponse, status: actix_http::StatusCode) { assert_eq!(response.status(), status, "{:#?}", response.response()); diff --git a/tests/feed.rs b/tests/feed.rs index 5cf2412f..9c44cf07 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -4,7 +4,7 @@ use common::{ database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, environment::with_test_environment, }; -use labrinth::models::feed_item::FeedItemBody; +use labrinth::models::feeds::FeedItemBody; mod common; From d6d2b0d9f1473926d3445a2dc49a35acf0c265f4 Mon Sep 17 00:00:00 2001 From: OmegaJak Date: Mon, 30 Oct 2023 13:47:51 -0500 Subject: [PATCH 63/72] Revert attempts at deduplication --- src/routes/v3/follow.rs | 107 --------------------------------- src/routes/v3/mod.rs | 1 - src/routes/v3/organizations.rs | 79 +++++++++++++++--------- src/routes/v3/users.rs | 76 ++++++++++++++--------- 4 files changed, 97 insertions(+), 166 deletions(-) delete mode 100644 src/routes/v3/follow.rs diff --git a/src/routes/v3/follow.rs b/src/routes/v3/follow.rs deleted file mode 100644 index a27bbe86..00000000 --- a/src/routes/v3/follow.rs +++ /dev/null @@ -1,107 +0,0 @@ -use super::ApiError; -use crate::{ - auth::get_user_from_headers, - database::{models::DatabaseError, redis::RedisPool}, - models::pats::Scopes, - queue::session::AuthQueue, -}; -use actix_web::{web, HttpRequest, HttpResponse}; -use sqlx::PgPool; - -use crate::database::models as db_models; - -pub trait HasId { - fn id(&self) -> T; -} - -impl HasId for db_models::User { - fn id(&self) -> db_models::UserId { - self.id - } -} - -impl HasId for db_models::Organization { - fn id(&self) -> db_models::OrganizationId { - self.id - } -} - -pub async fn follow( - req: HttpRequest, - target_id: web::Path, - pool: web::Data, - redis: web::Data, - session_queue: web::Data, - get_db_model: impl FnOnce(String, web::Data, web::Data) -> Fut1, - insert_follow: impl FnOnce(db_models::UserId, TId, web::Data) -> Fut2, - error_word: &str, -) -> Result -where - Fut1: futures::Future, DatabaseError>>, - Fut2: futures::Future>, - T: HasId, -{ - let (_, current_user) = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_WRITE]), - ) - .await?; - - let target = get_db_model(target_id.into_inner(), pool.clone(), redis.clone()) - .await? - .ok_or_else(|| { - ApiError::InvalidInput(format!("The specified {} does not exist!", error_word)) - })?; - - insert_follow(current_user.id.into(), target.id(), pool.clone()) - .await - .map_err(|e| match e { - DatabaseError::Database(e) - if e.as_database_error() - .is_some_and(|e| e.is_unique_violation()) => - { - ApiError::InvalidInput(format!("You are already following this {}!", error_word)) - } - e => e.into(), - })?; - - Ok(HttpResponse::NoContent().body("")) -} - -pub async fn unfollow( - req: HttpRequest, - target_id: web::Path, - pool: web::Data, - redis: web::Data, - session_queue: web::Data, - get_db_model: impl FnOnce(String, web::Data, web::Data) -> Fut1, - unfollow: impl FnOnce(db_models::UserId, TId, web::Data) -> Fut2, - error_word: &str, -) -> Result -where - Fut1: futures::Future, DatabaseError>>, - Fut2: futures::Future>, - T: HasId, -{ - let (_, current_user) = get_user_from_headers( - &req, - &**pool, - &redis, - &session_queue, - Some(&[Scopes::USER_WRITE]), - ) - .await?; - - let target = get_db_model(target_id.into_inner(), pool.clone(), redis.clone()) - .await? - .ok_or_else(|| { - ApiError::InvalidInput(format!("The specified {} does not exist!", error_word)) - })?; - - unfollow(current_user.id.into(), target.id(), pool.clone()).await?; - - Ok(HttpResponse::NoContent().body("")) -} diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index 7709aff3..e90e0bc6 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -3,7 +3,6 @@ use crate::{auth::oauth, util::cors::default_cors}; use actix_web::{web, HttpResponse}; use serde_json::json; -pub mod follow; pub mod oauth_clients; pub mod organizations; pub mod users; diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs index 1ae9b66b..a990f8d3 100644 --- a/src/routes/v3/organizations.rs +++ b/src/routes/v3/organizations.rs @@ -1,5 +1,7 @@ use crate::{ - database::{self, redis::RedisPool}, + auth::get_user_from_headers, + database::{self, models::DatabaseError, redis::RedisPool}, + models::pats::Scopes, queue::session::AuthQueue, routes::ApiError, }; @@ -29,24 +31,38 @@ pub async fn organization_follow( redis: web::Data, session_queue: web::Data, ) -> Result { - crate::routes::v3::follow::follow( - req, - target_id, - pool, - redis, - session_queue, - |id, pool, redis| async move { DBOrganization::get(&id, &**pool, &redis).await }, - |follower_id, target_id, pool| async move { - DBOrganizationFollow { - follower_id, - target_id, - } - .insert(&**pool) - .await - }, - "organization", + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), ) + .await?; + + let target = DBOrganization::get(&target_id, &**pool, &redis) + .await? + .ok_or_else(|| { + ApiError::InvalidInput("The specified organization does not exist!".to_string()) + })?; + + DBOrganizationFollow { + follower_id: current_user.id.into(), + target_id: target.id, + } + .insert(&**pool) .await + .map_err(|e| match e { + DatabaseError::Database(e) + if e.as_database_error() + .is_some_and(|e| e.is_unique_violation()) => + { + ApiError::InvalidInput("You are already following this organization!".to_string()) + } + e => e.into(), + })?; + + Ok(HttpResponse::NoContent().body("")) } #[delete("{id}/follow")] @@ -57,17 +73,22 @@ pub async fn organization_unfollow( redis: web::Data, session_queue: web::Data, ) -> Result { - crate::routes::v3::follow::unfollow( - req, - target_id, - pool, - redis, - session_queue, - |id, pool, redis| async move { DBOrganization::get(&id, &**pool, &redis).await }, - |follower_id, target_id, pool| async move { - DBOrganizationFollow::unfollow(follower_id, target_id, &**pool).await - }, - "organization", + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), ) - .await + .await?; + + let target = DBOrganization::get(&target_id, &**pool, &redis) + .await? + .ok_or_else(|| { + ApiError::InvalidInput("The specified organization does not exist!".to_string()) + })?; + + DBOrganizationFollow::unfollow(current_user.id.into(), target.id.into(), &**pool).await?; + + Ok(HttpResponse::NoContent().body("")) } diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 3f2366d4..02e5263f 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -4,7 +4,10 @@ use crate::{ auth::{filter_authorized_projects, get_user_from_headers}, database::{ self, - models::event_item::{EventData, EventSelector, EventType}, + models::{ + event_item::{EventData, EventSelector, EventType}, + DatabaseError, + }, redis::RedisPool, }, models::{ @@ -49,24 +52,36 @@ pub async fn user_follow( redis: web::Data, session_queue: web::Data, ) -> Result { - crate::routes::v3::follow::follow( - req, - target_id, - pool, - redis, - session_queue, - |id, pool, redis| async move { DBUser::get(&id, &**pool, &redis).await }, - |follower_id, target_id, pool| async move { - DBUserFollow { - follower_id, - target_id, - } - .insert(&**pool) - .await - }, - "user", + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), ) + .await?; + + let target = DBUser::get(&target_id, &**pool, &redis) + .await? + .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; + + DBUserFollow { + follower_id: current_user.id.into(), + target_id: target.id, + } + .insert(&**pool) .await + .map_err(|e| match e { + DatabaseError::Database(e) + if e.as_database_error() + .is_some_and(|e| e.is_unique_violation()) => + { + ApiError::InvalidInput("You are already following this user!".to_string()) + } + e => e.into(), + })?; + + Ok(HttpResponse::NoContent().body("")) } #[delete("{id}/follow")] @@ -77,19 +92,22 @@ pub async fn user_unfollow( redis: web::Data, session_queue: web::Data, ) -> Result { - crate::routes::v3::follow::unfollow( - req, - target_id, - pool, - redis, - session_queue, - |id, pool, redis| async move { DBUser::get(&id, &**pool, &redis).await }, - |follower_id, target_id, pool| async move { - DBUserFollow::unfollow(follower_id, target_id, &**pool).await - }, - "organization", + let (_, current_user) = get_user_from_headers( + &req, + &**pool, + &redis, + &session_queue, + Some(&[Scopes::USER_WRITE]), ) - .await + .await?; + + let target = DBUser::get(&target_id, &**pool, &redis) + .await? + .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; + + DBUserFollow::unfollow(current_user.id.into(), target.id.into(), &**pool).await?; + + Ok(HttpResponse::NoContent().body("")) } #[derive(Serialize, Deserialize)] From 3edae49c9db52d3b955fe17b27cc7d04c4dfd6b5 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Wed, 1 Nov 2023 19:17:51 -0700 Subject: [PATCH 64/72] adds version create and project update --- src/auth/checks.rs | 2 +- src/database/models/creator_follows.rs | 2 + src/database/models/event_item.rs | 90 +++++++++++++-- src/database/models/ids.rs | 2 + src/models/feeds.rs | 13 ++- src/routes/updates.rs | 2 +- src/routes/v2/analytics_get.rs | 2 +- src/routes/v2/projects.rs | 41 ++++++- src/routes/v2/version_creation.rs | 33 +++++- src/routes/v2/version_file.rs | 2 +- src/routes/v2/versions.rs | 4 +- src/routes/v3/organizations.rs | 2 +- src/routes/v3/users.rs | 153 +++++++++++++++++++++---- tests/feed.rs | 59 ++++++++++ tests/version.rs | 4 +- 15 files changed, 367 insertions(+), 44 deletions(-) diff --git a/src/auth/checks.rs b/src/auth/checks.rs index 769ffcee..9f60cf7d 100644 --- a/src/auth/checks.rs +++ b/src/auth/checks.rs @@ -191,7 +191,7 @@ impl ValidateAuthorized for crate::database::models::OAuthClient { pub async fn filter_authorized_versions( versions: Vec, - user_option: &Option, + user_option: Option<&User>, pool: &web::Data, ) -> Result, ApiError> { let mut return_versions = Vec::new(); diff --git a/src/database/models/creator_follows.rs b/src/database/models/creator_follows.rs index 15fc9e2d..28f9c360 100644 --- a/src/database/models/creator_follows.rs +++ b/src/database/models/creator_follows.rs @@ -3,11 +3,13 @@ use itertools::Itertools; use super::{OrganizationId, UserId}; use crate::database::models::DatabaseError; +#[derive(Copy, Clone, Debug)] pub struct UserFollow { pub follower_id: UserId, pub target_id: UserId, } +#[derive(Copy, Clone, Debug)] pub struct OrganizationFollow { pub follower_id: UserId, pub target_id: OrganizationId, diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 32ebc281..caf0c51e 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -1,6 +1,6 @@ use super::{ dynamic::{DynamicId, IdType}, - generate_event_id, DatabaseError, EventId, OrganizationId, ProjectId, UserId, + generate_event_id, DatabaseError, EventId, OrganizationId, ProjectId, UserId, VersionId, }; use chrono::{DateTime, Utc}; use itertools::Itertools; @@ -12,6 +12,8 @@ use std::convert::{TryFrom, TryInto}; #[sqlx(rename_all = "snake_case")] pub enum EventType { ProjectCreated, + VersionCreated, + ProjectUpdated, } impl PgHasArrayType for EventType { @@ -20,18 +22,29 @@ impl PgHasArrayType for EventType { } } +#[derive(Debug)] pub enum CreatorId { User(UserId), Organization(OrganizationId), } +#[derive(Debug)] pub enum EventData { ProjectCreated { project_id: ProjectId, creator_id: CreatorId, }, + VersionCreated { + version_id: VersionId, + creator_id: CreatorId, + }, + ProjectUpdated { + project_id: ProjectId, + updater_id: CreatorId, + }, } +#[derive(Debug)] pub struct Event { pub id: EventId, pub event_data: EventData, @@ -108,6 +121,40 @@ impl From for RawEvent { created: None, } } + EventData::VersionCreated { + version_id, + creator_id, + } => { + let target_id = DynamicId::from(version_id); + let triggerer_id = DynamicId::from(creator_id); + RawEvent { + id: value.id, + target_id: target_id.id, + target_id_type: target_id.id_type, + triggerer_id: Some(triggerer_id.id), + triggerer_id_type: Some(triggerer_id.id_type), + event_type: EventType::VersionCreated, + metadata: None, + created: None, + } + } + EventData::ProjectUpdated { + project_id, + updater_id, + } => { + let target_id = DynamicId::from(project_id); + let triggerer_id = DynamicId::from(updater_id); + RawEvent { + id: value.id, + target_id: target_id.id, + target_id_type: target_id.id_type, + triggerer_id: Some(triggerer_id.id), + triggerer_id_type: Some(triggerer_id.id_type), + event_type: EventType::ProjectUpdated, + metadata: None, + created: None, + } + } } } } @@ -124,10 +171,11 @@ impl TryFrom for Event { (Some(id), Some(id_type)) => Some(DynamicId { id, id_type }), _ => None, }; - Ok(match value.event_type { - EventType::ProjectCreated => Event { - id: value.id, - event_data: EventData::ProjectCreated { + + let event = Event { + id : value.id, + event_data : match value.event_type { + EventType::ProjectCreated => EventData::ProjectCreated { project_id: target_id.try_into()?, creator_id: triggerer_id.map_or_else(|| { Err(DatabaseError::UnexpectedNull( @@ -135,16 +183,34 @@ impl TryFrom for Event { )) }, |v| v.try_into())?, }, - time: value.created.map_or_else( - || { + EventType::VersionCreated => EventData::VersionCreated { + version_id: target_id.try_into()?, + creator_id: triggerer_id.map_or_else(|| { Err(DatabaseError::UnexpectedNull( - "the value of created should not be null".to_string(), + "Neither triggerer_id nor triggerer_id_type should be null for version creation".to_string(), )) - }, - Ok, - )?, + }, |v| v.try_into())?, + }, + EventType::ProjectUpdated => EventData::ProjectUpdated { + project_id: target_id.try_into()?, + updater_id: triggerer_id.map_or_else(|| { + Err(DatabaseError::UnexpectedNull( + "Neither triggerer_id nor triggerer_id_type should be null for project update".to_string(), + )) + }, |v| v.try_into())?, + }, }, - }) + time : value.created.map_or_else( + || { + Err(DatabaseError::UnexpectedNull( + "the value of created should not be null".to_string(), + )) + }, + Ok, + )?, + }; + + Ok(event) } } diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index 4f740ec8..c07fc38d 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -463,6 +463,7 @@ pub mod dynamic { ProjectId, UserId, OrganizationId, + VersionId, } impl PgHasArrayType for IdType { @@ -506,6 +507,7 @@ pub mod dynamic { } from_static_impl!(ProjectId, IdType::ProjectId); + from_static_impl!(VersionId, IdType::VersionId); from_static_impl!(UserId, IdType::UserId); from_static_impl!(OrganizationId, IdType::OrganizationId); } diff --git a/src/models/feeds.rs b/src/models/feeds.rs index 0820b888..1d3fa387 100644 --- a/src/models/feeds.rs +++ b/src/models/feeds.rs @@ -1,7 +1,7 @@ use super::ids::Base62Id; use super::ids::OrganizationId; use super::users::UserId; -use crate::models::ids::ProjectId; +use crate::models::ids::{ProjectId, VersionId}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -34,6 +34,17 @@ pub enum FeedItemBody { creator_id: CreatorId, project_title: String, }, + ProjectUpdated { + project_id: ProjectId, + updater_id: CreatorId, + project_title: String, + }, + VersionCreated { + project_id: ProjectId, + version_id: VersionId, + creator_id: CreatorId, + project_title: String, + }, } impl From for CreatorId { diff --git a/src/routes/updates.rs b/src/routes/updates.rs index 004621a9..eaf8eced 100644 --- a/src/routes/updates.rs +++ b/src/routes/updates.rs @@ -72,7 +72,7 @@ pub async fn forge_updates( .into_iter() .filter(|x| x.loaders.iter().any(loaders)) .collect(), - &user_option, + user_option.as_ref(), &pool, ) .await?; diff --git a/src/routes/v2/analytics_get.rs b/src/routes/v2/analytics_get.rs index d6ab2c21..8e907ffb 100644 --- a/src/routes/v2/analytics_get.rs +++ b/src/routes/v2/analytics_get.rs @@ -590,7 +590,7 @@ async fn filter_allowed_ids( .map(|id| Ok(VersionId(parse_base62(id)?).into())) .collect::, ApiError>>()?; let versions = version_item::Version::get_many(&ids, &***pool, redis).await?; - let ids: Vec = filter_authorized_versions(versions, &Some(user), pool) + let ids: Vec = filter_authorized_versions(versions, Some(&user), pool) .await? .into_iter() .map(|x| x.id) diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index d1883798..876a7e1a 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -1,9 +1,10 @@ use crate::auth::{filter_authorized_projects, get_user_from_headers, is_authorized}; use crate::database; -use crate::database::models::image_item; +use crate::database::models::event_item::{CreatorId, EventData}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::project_item::{GalleryItem, ModCategory}; use crate::database::models::thread_item::ThreadMessageBuilder; +use crate::database::models::{image_item, Event}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models; @@ -24,6 +25,7 @@ use crate::util::routes::read_from_payload; use crate::util::validate::validation_errors_to_string; use actix_web::{delete, get, patch, post, web, HttpRequest, HttpResponse}; use chrono::{DateTime, Utc}; +use db_ids::OrganizationId; use futures::TryStreamExt; use meilisearch_sdk::indexes::IndexesResults; use serde::{Deserialize, Serialize}; @@ -1118,6 +1120,14 @@ pub async fn project_edit( ) .await?; + insert_project_update_event( + id.into(), + project_item.inner.organization_id, + &user, + &mut transaction, + ) + .await?; + transaction.commit().await?; Ok(HttpResponse::NoContent().body("")) } else { @@ -1467,6 +1477,14 @@ pub async fn projects_edit( .await?; } + insert_project_update_event( + project.inner.id.into(), + project.inner.organization_id, + &user, + &mut transaction, + ) + .await?; + db_models::Project::clear_cache(project.inner.id, project.inner.slug, None, &redis).await?; } @@ -2558,3 +2576,24 @@ pub async fn delete_from_index( Ok(()) } + +async fn insert_project_update_event( + project_id: ProjectId, + organization_id: Option, + current_user: &crate::models::users::User, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), ApiError> { + let event = Event::new( + EventData::ProjectUpdated { + project_id: project_id.into(), + updater_id: organization_id.map_or_else( + || CreatorId::User(current_user.id.into()), + CreatorId::Organization, + ), + }, + transaction, + ) + .await?; + event.insert(transaction).await?; + Ok(()) +} diff --git a/src/routes/v2/version_creation.rs b/src/routes/v2/version_creation.rs index e94284cf..83da5a0b 100644 --- a/src/routes/v2/version_creation.rs +++ b/src/routes/v2/version_creation.rs @@ -1,10 +1,11 @@ use super::project_creation::{CreateError, UploadedFile}; use crate::auth::get_user_from_headers; +use crate::database::models::event_item::{CreatorId, EventData}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::version_item::{ DependencyBuilder, VersionBuilder, VersionFileBuilder, }; -use crate::database::models::{self, image_item, Organization}; +use crate::database::models::{self, image_item, Event, Organization}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::images::{Image, ImageContext, ImageId}; @@ -304,6 +305,14 @@ async fn version_create_inner( }) .collect::>(); + insert_version_create_event( + version_id, + organization.map(|o| o.id), + &user, + transaction, + ) + .await?; + version_builder = Some(VersionBuilder { version_id: version_id.into(), project_id, @@ -961,3 +970,25 @@ pub fn get_name_ext( }; Ok((file_name, file_extension)) } + +async fn insert_version_create_event( + version_id: VersionId, + organization_id: Option, + current_user: &crate::models::users::User, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), CreateError> { + println!("Adding version create event"); + let event = Event::new( + EventData::VersionCreated { + version_id: version_id.into(), + creator_id: organization_id.map_or_else( + || CreatorId::User(current_user.id.into()), + CreatorId::Organization, + ), + }, + transaction, + ) + .await?; + event.insert(transaction).await?; + Ok(()) +} diff --git a/src/routes/v2/version_file.rs b/src/routes/v2/version_file.rs index be74701e..3e948166 100644 --- a/src/routes/v2/version_file.rs +++ b/src/routes/v2/version_file.rs @@ -378,7 +378,7 @@ pub async fn get_versions_from_hashes( let version_ids = files.iter().map(|x| x.version_id).collect::>(); let versions_data = filter_authorized_versions( database::models::Version::get_many(&version_ids, &**pool, &redis).await?, - &user_option, + user_option.as_ref(), &pool, ) .await?; diff --git a/src/routes/v2/versions.rs b/src/routes/v2/versions.rs index 44517a84..8eca3ed9 100644 --- a/src/routes/v2/versions.rs +++ b/src/routes/v2/versions.rs @@ -158,7 +158,7 @@ pub async fn version_list( response.sort(); response.dedup_by(|a, b| a.inner.id == b.inner.id); - let response = filter_authorized_versions(response, &user_option, &pool).await?; + let response = filter_authorized_versions(response, user_option.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(response)) } else { @@ -243,7 +243,7 @@ pub async fn versions_get( .map(|x| x.1) .ok(); - let versions = filter_authorized_versions(versions_data, &user_option, &pool).await?; + let versions = filter_authorized_versions(versions_data, user_option.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(versions)) } diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs index a990f8d3..e2093f57 100644 --- a/src/routes/v3/organizations.rs +++ b/src/routes/v3/organizations.rs @@ -88,7 +88,7 @@ pub async fn organization_unfollow( ApiError::InvalidInput("The specified organization does not exist!".to_string()) })?; - DBOrganizationFollow::unfollow(current_user.id.into(), target.id.into(), &**pool).await?; + DBOrganizationFollow::unfollow(current_user.id.into(), target.id, &**pool).await?; Ok(HttpResponse::NoContent().body("")) } diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 02e5263f..1e459dea 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, iter::FromIterator}; use crate::{ - auth::{filter_authorized_projects, get_user_from_headers}, + auth::{filter_authorized_projects, filter_authorized_versions, get_user_from_headers}, database::{ self, models::{ @@ -12,9 +12,9 @@ use crate::{ }, models::{ feeds::{FeedItem, FeedItemBody}, - ids::ProjectId, + ids::{ProjectId, VersionId}, pats::Scopes, - projects::Project, + projects::{Project, Version}, users::User, }, queue::session::AuthQueue, @@ -105,7 +105,7 @@ pub async fn user_unfollow( .await? .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; - DBUserFollow::unfollow(current_user.id.into(), target.id.into(), &**pool).await?; + DBUserFollow::unfollow(current_user.id.into(), target.id, &**pool).await?; Ok(HttpResponse::NoContent().body("")) } @@ -138,20 +138,32 @@ pub async fn current_user_feed( let followed_organizations = DBOrganizationFollow::get_follows_by_follower(current_user.id.into(), &**pool).await?; + // Feed by default shows the following: + // - Projects created by users you follow + // - Projects created by organizations you follow + // - Versions created by users you follow + // - Versions created by organizations you follow + // - Projects updated by users you follow + // - Projects updated by organizations you follow + let event_types = [ + EventType::ProjectCreated, + EventType::VersionCreated, + EventType::ProjectUpdated, + ]; let selectors = followed_users .into_iter() - .map(|follow| EventSelector { - id: follow.target_id.into(), - event_type: EventType::ProjectCreated, + .flat_map(|follow| { + event_types.iter().map(move |event_type| EventSelector { + id: follow.target_id.into(), + event_type: *event_type, + }) }) - .chain( - followed_organizations - .into_iter() - .map(|follow| EventSelector { - id: follow.target_id.into(), - event_type: EventType::ProjectCreated, - }), - ) + .chain(followed_organizations.into_iter().flat_map(|follow| { + event_types.iter().map(move |event_type| EventSelector { + id: follow.target_id.into(), + event_type: *event_type, + }) + })) .collect_vec(); let events = DBEvent::get_events(&[], &selectors, &**pool) .await? @@ -160,9 +172,32 @@ pub async fn current_user_feed( .take(params.offset.unwrap_or(usize::MAX)) .collect_vec(); + println!("ALL EVENTS: {:#?}", events); + let mut feed_items: Vec = Vec::new(); - let authorized_projects = - prefetch_authorized_event_projects(&events, &pool, &redis, ¤t_user).await?; + let authorized_versions = + prefetch_authorized_event_versions(&events, &pool, &redis, ¤t_user).await?; + let authorized_version_project_ids = authorized_versions + .values() + .map(|versions| versions.project_id) + .collect_vec(); + let authorized_projects = prefetch_authorized_event_projects( + &events, + Some(&authorized_version_project_ids), + &pool, + &redis, + ¤t_user, + ) + .await?; + + println!( + "All authorized versoins: {:#?}", + serde_json::to_string(&authorized_versions).unwrap() + ); + println!( + "All authorized projects: {:#?}", + serde_json::to_string(&authorized_projects).unwrap() + ); for event in events { let body = match event.event_data { @@ -176,6 +211,45 @@ pub async fn current_user_feed( project_title: p.title.clone(), } }), + EventData::ProjectUpdated { + project_id, + updater_id, + } => authorized_projects.get(&project_id.into()).map(|p| { + FeedItemBody::ProjectUpdated { + project_id: project_id.into(), + updater_id: updater_id.into(), + project_title: p.title.clone(), + } + }), + EventData::VersionCreated { + version_id, + creator_id, + } => { + println!("Making version ev"); + let authorized_version = authorized_versions.get(&version_id.into()); + let authorized_project = + authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); + println!( + "av: {:#?}", + serde_json::to_string(&authorized_version).unwrap() + ); + println!( + "ap: {:#?}", + serde_json::to_string(&authorized_project).unwrap() + ); + if let (Some(authorized_version), Some(authorized_project)) = + (authorized_version, authorized_project) + { + Some(FeedItemBody::VersionCreated { + project_id: authorized_project.id, + version_id: authorized_version.id, + creator_id: creator_id.into(), + project_title: authorized_project.title.clone(), + }) + } else { + None + } + } }; if let Some(body) = body { @@ -194,19 +268,33 @@ pub async fn current_user_feed( async fn prefetch_authorized_event_projects( events: &[db_models::Event], + additional_ids: Option<&[ProjectId]>, pool: &web::Data, redis: &RedisPool, current_user: &User, ) -> Result, ApiError> { - let project_ids = events + let mut project_ids = events .iter() - .map(|e| match &e.event_data { + .filter_map(|e| match &e.event_data { EventData::ProjectCreated { project_id, creator_id: _, - } => *project_id, + } => Some(*project_id), + EventData::ProjectUpdated { + project_id, + updater_id: _, + } => Some(*project_id), + EventData::VersionCreated { .. } => None, }) .collect_vec(); + if let Some(additional_ids) = additional_ids { + project_ids.extend( + additional_ids + .iter() + .copied() + .map(db_models::ProjectId::from), + ); + } let projects = db_models::Project::get_many_ids(&project_ids, &***pool, redis).await?; let authorized_projects = filter_authorized_projects(projects, Some(current_user), pool).await?; @@ -214,3 +302,28 @@ async fn prefetch_authorized_event_projects( authorized_projects.into_iter().map(|p| (p.id, p)), )) } + +async fn prefetch_authorized_event_versions( + events: &[db_models::Event], + pool: &web::Data, + redis: &RedisPool, + current_user: &User, +) -> Result, ApiError> { + let version_ids = events + .iter() + .filter_map(|e| match &e.event_data { + EventData::VersionCreated { + version_id, + creator_id: _, + } => Some(*version_id), + EventData::ProjectCreated { .. } => None, + EventData::ProjectUpdated { .. } => None, + }) + .collect_vec(); + let versions = db_models::Version::get_many(&version_ids, &***pool, redis).await?; + let authorized_versions = + filter_authorized_versions(versions, Some(current_user), pool).await?; + Ok(HashMap::::from_iter( + authorized_versions.into_iter().map(|v| (v.id, v)), + )) +} diff --git a/tests/feed.rs b/tests/feed.rs index 9c44cf07..134eed3a 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -38,6 +38,65 @@ async fn get_feed_after_following_user_shows_previously_created_public_projects( .await } +#[actix_rt::test] +async fn get_feed_after_following_user_shows_previously_created_public_versions() { + with_test_environment(|env| async move { + let DummyProjectAlpha { + project_id: alpha_project_id, + .. + } = env.dummy.as_ref().unwrap().project_alpha.clone(); + + // Add version + let v = env.v2.create_default_version(&alpha_project_id, None, USER_USER_PAT) + .await; + + env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; + + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + + assert_eq!(feed.len(), 2); + assert_matches!( + feed[1].body, + FeedItemBody::ProjectCreated { project_id, .. } if project_id.to_string() == alpha_project_id + ); + assert_matches!( + feed[0].body, + FeedItemBody::VersionCreated { version_id, .. } if version_id == v.id + ); + }) + .await +} + +#[actix_rt::test] + +async fn get_feed_after_following_user_shows_previously_edited_public_versions() { + with_test_environment(|env| async move { + let DummyProjectAlpha { + project_id: alpha_project_id, + .. + } = env.dummy.as_ref().unwrap().project_alpha.clone(); + + // Empty patch + env.v2.edit_project(&alpha_project_id, serde_json::json!({}), USER_USER_PAT) + .await; + + env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; + + let feed = env.v3.get_feed(FRIEND_USER_PAT).await; + + assert_eq!(feed.len(), 2); + assert_matches!( + feed[1].body, + FeedItemBody::ProjectCreated { project_id, .. } if project_id.to_string() == alpha_project_id + ); + assert_matches!( + feed[0].body, + FeedItemBody::ProjectUpdated { project_id, .. } if project_id.to_string() == alpha_project_id + ); + }) + .await +} + #[actix_rt::test] async fn get_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( ) { diff --git a/tests/version.rs b/tests/version.rs index a1a6d84d..04198c5b 100644 --- a/tests/version.rs +++ b/tests/version.rs @@ -81,7 +81,7 @@ async fn version_ordering_when_unspecified_orders_oldest_first() { let alpha_version_id = env.dummy.as_ref().unwrap().project_alpha.version_id.clone(); let new_version_id = get_json_val_str( env.v2 - .create_default_version(&alpha_project_id, None, USER_USER_PAT) + .create_default_version(alpha_project_id, None, USER_USER_PAT) .await .id, ); @@ -105,7 +105,7 @@ async fn version_ordering_when_specified_orders_specified_before_unspecified() { let alpha_version_id = env.dummy.as_ref().unwrap().project_alpha.version_id.clone(); let new_version_id = get_json_val_str( env.v2 - .create_default_version(&alpha_project_id, Some(10000), USER_USER_PAT) + .create_default_version(alpha_project_id, Some(10000), USER_USER_PAT) .await .id, ); From 474ac889c3e9fcb2990f4bb7197fd27bab9c7dc4 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Thu, 2 Nov 2023 02:27:36 -0700 Subject: [PATCH 65/72] fixed bugs; delayed notification to publishment --- ...0bafd0038147973cf7c9a2c4d0bd33f4776ab.json | 23 ++++++ ...d7e38197123e2af364dd7d3b01c322345d7b.json} | 7 +- src/database/models/event_item.rs | 21 +++--- src/database/models/project_item.rs | 7 +- src/database/models/team_item.rs | 26 +++++++ src/models/feeds.rs | 2 +- src/routes/v2/project_creation.rs | 50 +++++++------ src/routes/v2/projects.rs | 71 ++++++++++++++----- src/routes/v2/version_creation.rs | 1 - src/routes/v3/users.rs | 29 ++------ tests/common/asserts.rs | 14 +++- tests/feed.rs | 53 ++++++++------ tests/organizations.rs | 71 +++++++++++++++++++ 13 files changed, 269 insertions(+), 106 deletions(-) create mode 100644 .sqlx/query-0be48b74255947b8550eb98dfba0bafd0038147973cf7c9a2c4d0bd33f4776ab.json rename .sqlx/{query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json => query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json} (63%) diff --git a/.sqlx/query-0be48b74255947b8550eb98dfba0bafd0038147973cf7c9a2c4d0bd33f4776ab.json b/.sqlx/query-0be48b74255947b8550eb98dfba0bafd0038147973cf7c9a2c4d0bd33f4776ab.json new file mode 100644 index 00000000..93e9d929 --- /dev/null +++ b/.sqlx/query-0be48b74255947b8550eb98dfba0bafd0038147973cf7c9a2c4d0bd33f4776ab.json @@ -0,0 +1,23 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT user_id\n FROM team_members\n WHERE (team_id = $1 AND role = $2)\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user_id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8", + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "0be48b74255947b8550eb98dfba0bafd0038147973cf7c9a2c4d0bd33f4776ab" +} diff --git a/.sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json b/.sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json similarity index 63% rename from .sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json rename to .sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json index 74091817..a598df51 100644 --- a/.sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json +++ b/.sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO mods (\n id, team_id, title, description, body,\n published, downloads, icon_url, issues_url,\n source_url, wiki_url, status, requested_status, discord_url,\n client_side, server_side, license_url, license,\n slug, project_type, color, monetization_status\n )\n VALUES (\n $1, $2, $3, $4, $5,\n $6, $7, $8, $9,\n $10, $11, $12, $13, $14,\n $15, $16, $17, $18,\n LOWER($19), $20, $21, $22\n )\n ", + "query": "\n INSERT INTO mods (\n id, team_id, title, description, body,\n published, downloads, icon_url, issues_url,\n source_url, wiki_url, status, requested_status, discord_url,\n client_side, server_side, license_url, license,\n slug, project_type, color, monetization_status,\n organization_id\n )\n VALUES (\n $1, $2, $3, $4, $5,\n $6, $7, $8, $9,\n $10, $11, $12, $13, $14,\n $15, $16, $17, $18,\n LOWER($19), $20, $21, $22,\n $23\n )\n ", "describe": { "columns": [], "parameters": { @@ -26,10 +26,11 @@ "Text", "Int4", "Int4", - "Varchar" + "Varchar", + "Int8" ] }, "nullable": [] }, - "hash": "b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c" + "hash": "fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b" } diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index caf0c51e..567d3934 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -7,11 +7,11 @@ use itertools::Itertools; use sqlx::postgres::{PgHasArrayType, PgTypeInfo}; use std::convert::{TryFrom, TryInto}; -#[derive(sqlx::Type, Clone, Copy)] +#[derive(sqlx::Type, Clone, Copy, Debug)] #[sqlx(type_name = "text")] #[sqlx(rename_all = "snake_case")] pub enum EventType { - ProjectCreated, + ProjectPublished, VersionCreated, ProjectUpdated, } @@ -30,7 +30,7 @@ pub enum CreatorId { #[derive(Debug)] pub enum EventData { - ProjectCreated { + ProjectPublished { project_id: ProjectId, creator_id: CreatorId, }, @@ -104,7 +104,7 @@ impl TryFrom for CreatorId { impl From for RawEvent { fn from(value: Event) -> Self { match value.event_data { - EventData::ProjectCreated { + EventData::ProjectPublished { project_id, creator_id, } => { @@ -116,7 +116,7 @@ impl From for RawEvent { target_id_type: target_id.id_type, triggerer_id: Some(triggerer_id.id), triggerer_id_type: Some(triggerer_id.id_type), - event_type: EventType::ProjectCreated, + event_type: EventType::ProjectPublished, metadata: None, created: None, } @@ -175,11 +175,11 @@ impl TryFrom for Event { let event = Event { id : value.id, event_data : match value.event_type { - EventType::ProjectCreated => EventData::ProjectCreated { + EventType::ProjectPublished => EventData::ProjectPublished { project_id: target_id.try_into()?, creator_id: triggerer_id.map_or_else(|| { Err(DatabaseError::UnexpectedNull( - "Neither triggerer_id nor triggerer_id_type should be null for project creation".to_string(), + "Neither triggerer_id nor triggerer_id_type should be null for project publishing".to_string(), )) }, |v| v.try_into())?, }, @@ -242,7 +242,8 @@ impl Event { unzip_event_selectors(target_selectors); let (triggerer_ids, triggerer_id_types, triggerer_event_types) = unzip_event_selectors(triggerer_selectors); - sqlx::query_as!( + + let r = sqlx::query_as!( RawEvent, r#" SELECT @@ -274,7 +275,9 @@ impl Event { .await? .into_iter() .map(|r| r.try_into()) - .collect::, _>>() + .collect::, _>>()?; + + Ok(r) } } diff --git a/src/database/models/project_item.rs b/src/database/models/project_item.rs index a3fddec2..a8ee1c44 100644 --- a/src/database/models/project_item.rs +++ b/src/database/models/project_item.rs @@ -299,14 +299,16 @@ impl Project { published, downloads, icon_url, issues_url, source_url, wiki_url, status, requested_status, discord_url, client_side, server_side, license_url, license, - slug, project_type, color, monetization_status + slug, project_type, color, monetization_status, + organization_id ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, - LOWER($19), $20, $21, $22 + LOWER($19), $20, $21, $22, + $23 ) ", self.id as ProjectId, @@ -331,6 +333,7 @@ impl Project { self.project_type as ProjectTypeId, self.color.map(|x| x as i32), self.monetization_status.as_str(), + self.organization_id.map(|x| x.0), ) .execute(&mut **transaction) .await?; diff --git a/src/database/models/team_item.rs b/src/database/models/team_item.rs index a513aefe..df59c623 100644 --- a/src/database/models/team_item.rs +++ b/src/database/models/team_item.rs @@ -390,6 +390,32 @@ impl TeamMember { } } + pub async fn get_owner_id<'a, 'b, E>( + id: TeamId, + executor: E, + ) -> Result, super::DatabaseError> + where + E: sqlx::Executor<'a, Database = sqlx::Postgres>, + { + let result = sqlx::query!( + " + SELECT user_id + FROM team_members + WHERE (team_id = $1 AND role = $2) + ", + id as TeamId, + crate::models::teams::OWNER_ROLE, + ) + .fetch_optional(executor) + .await?; + + if let Some(m) = result { + Ok(Some(UserId(m.user_id))) + } else { + Ok(None) + } + } + pub async fn insert( &self, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, diff --git a/src/models/feeds.rs b/src/models/feeds.rs index 1d3fa387..be3fdb5b 100644 --- a/src/models/feeds.rs +++ b/src/models/feeds.rs @@ -29,7 +29,7 @@ pub struct FeedItem { #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum FeedItemBody { - ProjectCreated { + ProjectPublished { project_id: ProjectId, creator_id: CreatorId, project_title: String, diff --git a/src/routes/v2/project_creation.rs b/src/routes/v2/project_creation.rs index 91751c37..ea287745 100644 --- a/src/routes/v2/project_creation.rs +++ b/src/routes/v2/project_creation.rs @@ -1,6 +1,5 @@ use super::version_creation::InitialVersionData; use crate::auth::{get_user_from_headers, AuthenticationError}; -use crate::database::models::event_item::{CreatorId, Event, EventData}; use crate::database::models::thread_item::ThreadBuilder; use crate::database::models::{self, image_item, User}; use crate::database::redis::RedisPool; @@ -14,7 +13,7 @@ use crate::models::projects::{ DonationLink, License, MonetizationStatus, ProjectId, ProjectStatus, SideType, VersionId, VersionStatus, }; -use crate::models::teams::ProjectPermissions; +use crate::models::teams::{OrganizationPermissions, ProjectPermissions}; use crate::models::threads::ThreadType; use crate::models::users::UserId; use crate::queue::session::AuthQueue; @@ -453,6 +452,29 @@ async fn project_create_inner( } } + // If organization_id is set, make sure the user is a member of the organization + if let Some(organization_id) = create_data.organization_id { + let organization_team_member = + models::team_item::TeamMember::get_from_user_id_organization( + organization_id.into(), + current_user.id.into(), + &mut **transaction, + ) + .await?; + + let permissions = OrganizationPermissions::get_permissions_by_role( + ¤t_user.role, + &organization_team_member, + ) + .unwrap_or_default(); + + if !permissions.contains(OrganizationPermissions::ADD_PROJECT) { + return Err(CreateError::CustomAuthenticationError( + "You do not have permission to add projects to this organization!".to_string(), + )); + } + } + // Create VersionBuilders for the versions specified in `initial_versions` versions = Vec::with_capacity(create_data.initial_versions.len()); for (i, data) in create_data.initial_versions.iter().enumerate() { @@ -749,9 +771,6 @@ async fn project_create_inner( } let organization_id = project_create_data.organization_id.map(|id| id.into()); - insert_project_create_event(project_id, organization_id, ¤t_user, transaction) - .await?; - let project_builder_actual = models::project_item::ProjectBuilder { project_id: project_id.into(), project_type_id, @@ -893,27 +912,6 @@ async fn project_create_inner( } } -async fn insert_project_create_event( - project_id: ProjectId, - organization_id: Option, - current_user: &crate::models::users::User, - transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, -) -> Result<(), CreateError> { - let event = Event::new( - EventData::ProjectCreated { - project_id: project_id.into(), - creator_id: organization_id.map_or_else( - || CreatorId::User(current_user.id.into()), - CreatorId::Organization, - ), - }, - transaction, - ) - .await?; - event.insert(transaction).await?; - Ok(()) -} - async fn create_initial_version( version_data: &InitialVersionData, project_id: ProjectId, diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index 876a7e1a..29b25683 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -4,7 +4,7 @@ use crate::database::models::event_item::{CreatorId, EventData}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::project_item::{GalleryItem, ModCategory}; use crate::database::models::thread_item::ThreadMessageBuilder; -use crate::database::models::{image_item, Event}; +use crate::database::models::{image_item, team_item, Event}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models; @@ -25,13 +25,14 @@ use crate::util::routes::read_from_payload; use crate::util::validate::validation_errors_to_string; use actix_web::{delete, get, patch, post, web, HttpRequest, HttpResponse}; use chrono::{DateTime, Utc}; -use db_ids::OrganizationId; +use db_ids::{OrganizationId, UserId}; use futures::TryStreamExt; use meilisearch_sdk::indexes::IndexesResults; use serde::{Deserialize, Serialize}; use serde_json::json; use sqlx::PgPool; use std::sync::Arc; +use team_item::TeamMember; use validator::Validate; use database::models as db_models; @@ -528,6 +529,21 @@ pub async fn project_edit( ) .execute(&mut *transaction) .await?; + + // On publish event, we send out notification using team *owner* as publishing user, at the time of mod approval + // (even though 'user' is the mod and doing the publishing) + let owner_id = + TeamMember::get_owner_id(project_item.inner.team_id, &mut *transaction) + .await?; + if let Some(owner_id) = owner_id { + insert_project_publish_event( + id.into(), + project_item.inner.organization_id, + owner_id, + &mut transaction, + ) + .await?; + } } if status.is_searchable() && !project_item.inner.webhook_sent { @@ -1120,13 +1136,15 @@ pub async fn project_edit( ) .await?; - insert_project_update_event( - id.into(), - project_item.inner.organization_id, - &user, - &mut transaction, - ) - .await?; + if project_item.inner.status.is_searchable() { + insert_project_update_event( + project_item.inner.id.into(), + project_item.inner.organization_id, + &user, + &mut transaction, + ) + .await?; + } transaction.commit().await?; Ok(HttpResponse::NoContent().body("")) @@ -1477,13 +1495,15 @@ pub async fn projects_edit( .await?; } - insert_project_update_event( - project.inner.id.into(), - project.inner.organization_id, - &user, - &mut transaction, - ) - .await?; + if project.inner.status.is_searchable() { + insert_project_update_event( + project.inner.id.into(), + project.inner.organization_id, + &user, + &mut transaction, + ) + .await?; + } db_models::Project::clear_cache(project.inner.id, project.inner.slug, None, &redis).await?; } @@ -2597,3 +2617,22 @@ async fn insert_project_update_event( event.insert(transaction).await?; Ok(()) } + +async fn insert_project_publish_event( + project_id: ProjectId, + organization_id: Option, + owner_id: UserId, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), ApiError> { + let event = Event::new( + EventData::ProjectPublished { + project_id: project_id.into(), + creator_id: organization_id + .map_or_else(|| CreatorId::User(owner_id), CreatorId::Organization), + }, + transaction, + ) + .await?; + event.insert(transaction).await?; + Ok(()) +} diff --git a/src/routes/v2/version_creation.rs b/src/routes/v2/version_creation.rs index 83da5a0b..fef61899 100644 --- a/src/routes/v2/version_creation.rs +++ b/src/routes/v2/version_creation.rs @@ -977,7 +977,6 @@ async fn insert_version_create_event( current_user: &crate::models::users::User, transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, ) -> Result<(), CreateError> { - println!("Adding version create event"); let event = Event::new( EventData::VersionCreated { version_id: version_id.into(), diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 1e459dea..8593540e 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -146,7 +146,7 @@ pub async fn current_user_feed( // - Projects updated by users you follow // - Projects updated by organizations you follow let event_types = [ - EventType::ProjectCreated, + EventType::ProjectPublished, EventType::VersionCreated, EventType::ProjectUpdated, ]; @@ -172,8 +172,6 @@ pub async fn current_user_feed( .take(params.offset.unwrap_or(usize::MAX)) .collect_vec(); - println!("ALL EVENTS: {:#?}", events); - let mut feed_items: Vec = Vec::new(); let authorized_versions = prefetch_authorized_event_versions(&events, &pool, &redis, ¤t_user).await?; @@ -190,22 +188,14 @@ pub async fn current_user_feed( ) .await?; - println!( - "All authorized versoins: {:#?}", - serde_json::to_string(&authorized_versions).unwrap() - ); - println!( - "All authorized projects: {:#?}", - serde_json::to_string(&authorized_projects).unwrap() - ); for event in events { let body = match event.event_data { - EventData::ProjectCreated { + EventData::ProjectPublished { project_id, creator_id, } => authorized_projects.get(&project_id.into()).map(|p| { - FeedItemBody::ProjectCreated { + FeedItemBody::ProjectPublished { project_id: project_id.into(), creator_id: creator_id.into(), project_title: p.title.clone(), @@ -225,18 +215,9 @@ pub async fn current_user_feed( version_id, creator_id, } => { - println!("Making version ev"); let authorized_version = authorized_versions.get(&version_id.into()); let authorized_project = authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); - println!( - "av: {:#?}", - serde_json::to_string(&authorized_version).unwrap() - ); - println!( - "ap: {:#?}", - serde_json::to_string(&authorized_project).unwrap() - ); if let (Some(authorized_version), Some(authorized_project)) = (authorized_version, authorized_project) { @@ -276,7 +257,7 @@ async fn prefetch_authorized_event_projects( let mut project_ids = events .iter() .filter_map(|e| match &e.event_data { - EventData::ProjectCreated { + EventData::ProjectPublished { project_id, creator_id: _, } => Some(*project_id), @@ -316,7 +297,7 @@ async fn prefetch_authorized_event_versions( version_id, creator_id: _, } => Some(*version_id), - EventData::ProjectCreated { .. } => None, + EventData::ProjectPublished { .. } => None, EventData::ProjectUpdated { .. } => None, }) .collect_vec(); diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index 0137d7aa..131cadba 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -31,5 +31,17 @@ pub fn assert_feed_contains_project_created( feed: &[FeedItem], expected_project_id: labrinth::models::projects::ProjectId, ) { - assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::ProjectCreated { project_id, .. } if project_id == expected_project_id)), "{:#?}", &feed); + assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::ProjectPublished { project_id, .. } if project_id == expected_project_id)), "{:#?}", &feed); +} +pub fn assert_feed_contains_project_updated( + feed: &[FeedItem], + expected_project_id: labrinth::models::projects::ProjectId, +) { + assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::ProjectUpdated { project_id, .. } if project_id == expected_project_id)), "{:#?}", &feed); +} +pub fn assert_feed_contains_version_created( + feed: &[FeedItem], + expected_version_id: labrinth::models::projects::VersionId, +) { + assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::VersionCreated { version_id, .. } if version_id == expected_version_id)), "{:#?}", &feed); } diff --git a/tests/feed.rs b/tests/feed.rs index 134eed3a..a8f52a65 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -1,10 +1,16 @@ -use crate::common::{asserts::assert_feed_contains_project_created, dummy_data::DummyProjectAlpha}; +use crate::common::{ + asserts::{ + assert_feed_contains_project_created, assert_feed_contains_project_updated, + assert_feed_contains_version_created, + }, + dummy_data::DummyProjectAlpha, +}; use assert_matches::assert_matches; use common::{ database::{FRIEND_USER_PAT, USER_USER_ID, USER_USER_PAT}, environment::with_test_environment, }; -use labrinth::models::feeds::FeedItemBody; +use labrinth::models::{feeds::FeedItemBody, ids::base62_impl::parse_base62, projects::ProjectId}; mod common; @@ -30,10 +36,10 @@ async fn get_feed_after_following_user_shows_previously_created_public_projects( let feed = env.v3.get_feed(FRIEND_USER_PAT).await; assert_eq!(feed.len(), 1); - assert_matches!( - feed[0].body, - FeedItemBody::ProjectCreated { project_id, .. } if project_id.to_string() == alpha_project_id - ) + assert_feed_contains_project_created( + &feed, + ProjectId(parse_base62(&alpha_project_id).unwrap()), + ); }) .await } @@ -47,7 +53,9 @@ async fn get_feed_after_following_user_shows_previously_created_public_versions( } = env.dummy.as_ref().unwrap().project_alpha.clone(); // Add version - let v = env.v2.create_default_version(&alpha_project_id, None, USER_USER_PAT) + let v = env + .v2 + .create_default_version(&alpha_project_id, None, USER_USER_PAT) .await; env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; @@ -55,20 +63,17 @@ async fn get_feed_after_following_user_shows_previously_created_public_versions( let feed = env.v3.get_feed(FRIEND_USER_PAT).await; assert_eq!(feed.len(), 2); - assert_matches!( - feed[1].body, - FeedItemBody::ProjectCreated { project_id, .. } if project_id.to_string() == alpha_project_id - ); - assert_matches!( - feed[0].body, - FeedItemBody::VersionCreated { version_id, .. } if version_id == v.id + assert_feed_contains_project_created( + &feed, + ProjectId(parse_base62(&alpha_project_id).unwrap()), ); + assert_feed_contains_version_created(&feed, v.id); + // Notably, this should *not* have a projectupdated from the publishing. }) .await } #[actix_rt::test] - async fn get_feed_after_following_user_shows_previously_edited_public_versions() { with_test_environment(|env| async move { let DummyProjectAlpha { @@ -77,7 +82,8 @@ async fn get_feed_after_following_user_shows_previously_edited_public_versions() } = env.dummy.as_ref().unwrap().project_alpha.clone(); // Empty patch - env.v2.edit_project(&alpha_project_id, serde_json::json!({}), USER_USER_PAT) + env.v2 + .edit_project(&alpha_project_id, serde_json::json!({}), USER_USER_PAT) .await; env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; @@ -85,13 +91,13 @@ async fn get_feed_after_following_user_shows_previously_edited_public_versions() let feed = env.v3.get_feed(FRIEND_USER_PAT).await; assert_eq!(feed.len(), 2); - assert_matches!( - feed[1].body, - FeedItemBody::ProjectCreated { project_id, .. } if project_id.to_string() == alpha_project_id + assert_feed_contains_project_created( + &feed, + ProjectId(parse_base62(&alpha_project_id).unwrap()), ); - assert_matches!( - feed[0].body, - FeedItemBody::ProjectUpdated { project_id, .. } if project_id.to_string() == alpha_project_id + assert_feed_contains_project_updated( + &feed, + ProjectId(parse_base62(&alpha_project_id).unwrap()), ); }) .await @@ -107,7 +113,8 @@ async fn get_feed_when_following_user_that_creates_project_as_org_only_shows_eve env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; let feed = env.v3.get_feed(FRIEND_USER_PAT).await; assert_eq!(feed.len(), 1); - assert_matches!(feed[0].body, FeedItemBody::ProjectCreated { project_id, .. } if project_id != project.id); + + assert_matches!(feed[0].body, FeedItemBody::ProjectPublished { project_id, .. } if project_id != project.id); env.v3.follow_organization(&org_id, FRIEND_USER_PAT).await; let feed = env.v3.get_feed(FRIEND_USER_PAT).await; diff --git a/tests/organizations.rs b/tests/organizations.rs index edbb84bd..f427c22f 100644 --- a/tests/organizations.rs +++ b/tests/organizations.rs @@ -7,8 +7,10 @@ use crate::common::{ use actix_web::test; use bytes::Bytes; use common::{ + actix::AppendsMultipart, database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_PAT}, permissions::{PermissionsTest, PermissionsTestContext}, + request_data::get_public_project_creation_data, }; use labrinth::models::teams::{OrganizationPermissions, ProjectPermissions}; use serde_json::json; @@ -289,6 +291,75 @@ async fn add_remove_organization_projects() { test_env.cleanup().await; } +#[actix_rt::test] +async fn create_project_in_organization() { + let test_env = TestEnvironment::build(None).await; + let zeta_organization_id: &str = &test_env + .dummy + .as_ref() + .unwrap() + .organization_zeta + .organization_id; + + // Create project in organization + let resp = test_env + .v2 + .add_default_org_project(zeta_organization_id, USER_USER_PAT) + .await; + + // Get project + let project = test_env + .v2 + .get_project_deserialized(&resp.id.to_string(), USER_USER_PAT) + .await; + + // Ensure organization id is correectly set in both returned project and + // fetched project. + assert_eq!(resp.organization.unwrap().to_string(), zeta_organization_id); + assert_eq!( + project.organization.unwrap().to_string(), + zeta_organization_id + ); + + test_env.cleanup().await; +} + +#[actix_rt::test] +async fn permissions_create_project_in_organization() { + let test_env = TestEnvironment::build(None).await; + + let zeta_organization_id = &test_env + .dummy + .as_ref() + .unwrap() + .organization_zeta + .organization_id; + let zeta_team_id = &test_env.dummy.as_ref().unwrap().organization_zeta.team_id; + + // Requires ADD_PROJECT to create project in org + let add_project = OrganizationPermissions::ADD_PROJECT; + + let req_gen = |ctx: &PermissionsTestContext| { + let multipart = get_public_project_creation_data( + &generate_random_name("randomslug"), + None, + Some(ctx.organization_id.unwrap()), + ) + .segment_data; + test::TestRequest::post() + .uri("/v2/project") + .set_multipart(multipart) + }; + PermissionsTest::new(&test_env) + .with_existing_organization(zeta_organization_id, zeta_team_id) + .with_user(FRIEND_USER_ID, FRIEND_USER_PAT, true) + .simple_organization_permissions_test(add_project, req_gen) + .await + .unwrap(); + + test_env.cleanup().await; +} + #[actix_rt::test] async fn permissions_patch_organization() { let test_env = TestEnvironment::build(Some(8)).await; From 77617c21ea5c74c117c32c63e52c6e44c6a17ae7 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Thu, 2 Nov 2023 15:03:22 -0700 Subject: [PATCH 66/72] removed project update --- src/database/models/event_item.rs | 30 ---------------------- src/models/feeds.rs | 5 ---- src/routes/v2/projects.rs | 41 ------------------------------- src/routes/v3/users.rs | 18 -------------- tests/common/asserts.rs | 6 ----- tests/feed.rs | 32 +----------------------- 6 files changed, 1 insertion(+), 131 deletions(-) diff --git a/src/database/models/event_item.rs b/src/database/models/event_item.rs index 567d3934..a93ffdcb 100644 --- a/src/database/models/event_item.rs +++ b/src/database/models/event_item.rs @@ -13,7 +13,6 @@ use std::convert::{TryFrom, TryInto}; pub enum EventType { ProjectPublished, VersionCreated, - ProjectUpdated, } impl PgHasArrayType for EventType { @@ -38,10 +37,6 @@ pub enum EventData { version_id: VersionId, creator_id: CreatorId, }, - ProjectUpdated { - project_id: ProjectId, - updater_id: CreatorId, - }, } #[derive(Debug)] @@ -138,23 +133,6 @@ impl From for RawEvent { created: None, } } - EventData::ProjectUpdated { - project_id, - updater_id, - } => { - let target_id = DynamicId::from(project_id); - let triggerer_id = DynamicId::from(updater_id); - RawEvent { - id: value.id, - target_id: target_id.id, - target_id_type: target_id.id_type, - triggerer_id: Some(triggerer_id.id), - triggerer_id_type: Some(triggerer_id.id_type), - event_type: EventType::ProjectUpdated, - metadata: None, - created: None, - } - } } } } @@ -191,14 +169,6 @@ impl TryFrom for Event { )) }, |v| v.try_into())?, }, - EventType::ProjectUpdated => EventData::ProjectUpdated { - project_id: target_id.try_into()?, - updater_id: triggerer_id.map_or_else(|| { - Err(DatabaseError::UnexpectedNull( - "Neither triggerer_id nor triggerer_id_type should be null for project update".to_string(), - )) - }, |v| v.try_into())?, - }, }, time : value.created.map_or_else( || { diff --git a/src/models/feeds.rs b/src/models/feeds.rs index be3fdb5b..956bb6c5 100644 --- a/src/models/feeds.rs +++ b/src/models/feeds.rs @@ -34,11 +34,6 @@ pub enum FeedItemBody { creator_id: CreatorId, project_title: String, }, - ProjectUpdated { - project_id: ProjectId, - updater_id: CreatorId, - project_title: String, - }, VersionCreated { project_id: ProjectId, version_id: VersionId, diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index 29b25683..9c55c0a9 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -1136,16 +1136,6 @@ pub async fn project_edit( ) .await?; - if project_item.inner.status.is_searchable() { - insert_project_update_event( - project_item.inner.id.into(), - project_item.inner.organization_id, - &user, - &mut transaction, - ) - .await?; - } - transaction.commit().await?; Ok(HttpResponse::NoContent().body("")) } else { @@ -1495,16 +1485,6 @@ pub async fn projects_edit( .await?; } - if project.inner.status.is_searchable() { - insert_project_update_event( - project.inner.id.into(), - project.inner.organization_id, - &user, - &mut transaction, - ) - .await?; - } - db_models::Project::clear_cache(project.inner.id, project.inner.slug, None, &redis).await?; } @@ -2597,27 +2577,6 @@ pub async fn delete_from_index( Ok(()) } -async fn insert_project_update_event( - project_id: ProjectId, - organization_id: Option, - current_user: &crate::models::users::User, - transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, -) -> Result<(), ApiError> { - let event = Event::new( - EventData::ProjectUpdated { - project_id: project_id.into(), - updater_id: organization_id.map_or_else( - || CreatorId::User(current_user.id.into()), - CreatorId::Organization, - ), - }, - transaction, - ) - .await?; - event.insert(transaction).await?; - Ok(()) -} - async fn insert_project_publish_event( project_id: ProjectId, organization_id: Option, diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 8593540e..fa2155ea 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -143,12 +143,9 @@ pub async fn current_user_feed( // - Projects created by organizations you follow // - Versions created by users you follow // - Versions created by organizations you follow - // - Projects updated by users you follow - // - Projects updated by organizations you follow let event_types = [ EventType::ProjectPublished, EventType::VersionCreated, - EventType::ProjectUpdated, ]; let selectors = followed_users .into_iter() @@ -201,16 +198,6 @@ pub async fn current_user_feed( project_title: p.title.clone(), } }), - EventData::ProjectUpdated { - project_id, - updater_id, - } => authorized_projects.get(&project_id.into()).map(|p| { - FeedItemBody::ProjectUpdated { - project_id: project_id.into(), - updater_id: updater_id.into(), - project_title: p.title.clone(), - } - }), EventData::VersionCreated { version_id, creator_id, @@ -261,10 +248,6 @@ async fn prefetch_authorized_event_projects( project_id, creator_id: _, } => Some(*project_id), - EventData::ProjectUpdated { - project_id, - updater_id: _, - } => Some(*project_id), EventData::VersionCreated { .. } => None, }) .collect_vec(); @@ -298,7 +281,6 @@ async fn prefetch_authorized_event_versions( creator_id: _, } => Some(*version_id), EventData::ProjectPublished { .. } => None, - EventData::ProjectUpdated { .. } => None, }) .collect_vec(); let versions = db_models::Version::get_many(&version_ids, &***pool, redis).await?; diff --git a/tests/common/asserts.rs b/tests/common/asserts.rs index 131cadba..111d0b93 100644 --- a/tests/common/asserts.rs +++ b/tests/common/asserts.rs @@ -33,12 +33,6 @@ pub fn assert_feed_contains_project_created( ) { assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::ProjectPublished { project_id, .. } if project_id == expected_project_id)), "{:#?}", &feed); } -pub fn assert_feed_contains_project_updated( - feed: &[FeedItem], - expected_project_id: labrinth::models::projects::ProjectId, -) { - assert!(feed.iter().any(|fi| matches!(fi.body, FeedItemBody::ProjectUpdated { project_id, .. } if project_id == expected_project_id)), "{:#?}", &feed); -} pub fn assert_feed_contains_version_created( feed: &[FeedItem], expected_version_id: labrinth::models::projects::VersionId, diff --git a/tests/feed.rs b/tests/feed.rs index a8f52a65..09efe51c 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -1,6 +1,6 @@ use crate::common::{ asserts::{ - assert_feed_contains_project_created, assert_feed_contains_project_updated, + assert_feed_contains_project_created, assert_feed_contains_version_created, }, dummy_data::DummyProjectAlpha, @@ -73,36 +73,6 @@ async fn get_feed_after_following_user_shows_previously_created_public_versions( .await } -#[actix_rt::test] -async fn get_feed_after_following_user_shows_previously_edited_public_versions() { - with_test_environment(|env| async move { - let DummyProjectAlpha { - project_id: alpha_project_id, - .. - } = env.dummy.as_ref().unwrap().project_alpha.clone(); - - // Empty patch - env.v2 - .edit_project(&alpha_project_id, serde_json::json!({}), USER_USER_PAT) - .await; - - env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; - - let feed = env.v3.get_feed(FRIEND_USER_PAT).await; - - assert_eq!(feed.len(), 2); - assert_feed_contains_project_created( - &feed, - ProjectId(parse_base62(&alpha_project_id).unwrap()), - ); - assert_feed_contains_project_updated( - &feed, - ProjectId(parse_base62(&alpha_project_id).unwrap()), - ); - }) - .await -} - #[actix_rt::test] async fn get_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( ) { From f3b837514d4613be7d00e68800cdedf84e6309dc Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Fri, 3 Nov 2023 10:26:33 -0700 Subject: [PATCH 67/72] fixed nightly failure --- src/routes/v3/organizations.rs | 14 +++--- src/routes/v3/users.rs | 80 +++++++++++++++++----------------- tests/common/database.rs | 2 +- tests/feed.rs | 5 +-- 4 files changed, 51 insertions(+), 50 deletions(-) diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs index e2093f57..5fe23809 100644 --- a/src/routes/v3/organizations.rs +++ b/src/routes/v3/organizations.rs @@ -53,11 +53,15 @@ pub async fn organization_follow( .insert(&**pool) .await .map_err(|e| match e { - DatabaseError::Database(e) - if e.as_database_error() - .is_some_and(|e| e.is_unique_violation()) => - { - ApiError::InvalidInput("You are already following this organization!".to_string()) + DatabaseError::Database(e) => { + if let Some(db_err) = e.as_database_error() { + if db_err.is_unique_violation() { + return ApiError::InvalidInput( + "You are already following this organization!".to_string(), + ); + } + } + e.into() } e => e.into(), })?; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index fa2155ea..80f422ea 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -72,11 +72,15 @@ pub async fn user_follow( .insert(&**pool) .await .map_err(|e| match e { - DatabaseError::Database(e) - if e.as_database_error() - .is_some_and(|e| e.is_unique_violation()) => - { - ApiError::InvalidInput("You are already following this user!".to_string()) + DatabaseError::Database(e) => { + if let Some(db_err) = e.as_database_error() { + if db_err.is_unique_violation() { + return ApiError::InvalidInput( + "You are already following this user!".to_string(), + ); + } + } + e.into() } e => e.into(), })?; @@ -143,10 +147,7 @@ pub async fn current_user_feed( // - Projects created by organizations you follow // - Versions created by users you follow // - Versions created by organizations you follow - let event_types = [ - EventType::ProjectPublished, - EventType::VersionCreated, - ]; + let event_types = [EventType::ProjectPublished, EventType::VersionCreated]; let selectors = followed_users .into_iter() .flat_map(|follow| { @@ -186,39 +187,38 @@ pub async fn current_user_feed( .await?; for event in events { - let body = - match event.event_data { - EventData::ProjectPublished { - project_id, - creator_id, - } => authorized_projects.get(&project_id.into()).map(|p| { - FeedItemBody::ProjectPublished { - project_id: project_id.into(), + let body = match event.event_data { + EventData::ProjectPublished { + project_id, + creator_id, + } => authorized_projects.get(&project_id.into()).map(|p| { + FeedItemBody::ProjectPublished { + project_id: project_id.into(), + creator_id: creator_id.into(), + project_title: p.title.clone(), + } + }), + EventData::VersionCreated { + version_id, + creator_id, + } => { + let authorized_version = authorized_versions.get(&version_id.into()); + let authorized_project = + authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); + if let (Some(authorized_version), Some(authorized_project)) = + (authorized_version, authorized_project) + { + Some(FeedItemBody::VersionCreated { + project_id: authorized_project.id, + version_id: authorized_version.id, creator_id: creator_id.into(), - project_title: p.title.clone(), - } - }), - EventData::VersionCreated { - version_id, - creator_id, - } => { - let authorized_version = authorized_versions.get(&version_id.into()); - let authorized_project = - authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); - if let (Some(authorized_version), Some(authorized_project)) = - (authorized_version, authorized_project) - { - Some(FeedItemBody::VersionCreated { - project_id: authorized_project.id, - version_id: authorized_version.id, - creator_id: creator_id.into(), - project_title: authorized_project.title.clone(), - }) - } else { - None - } + project_title: authorized_project.title.clone(), + }) + } else { + None } - }; + } + }; if let Some(body) = body { let feed_item = FeedItem { diff --git a/tests/common/database.rs b/tests/common/database.rs index 988cc028..38bf6585 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -146,7 +146,7 @@ impl TemporaryDatabase { .fetch_optional(&pool) .await .unwrap(); - let needs_update = !dummy_data_update.is_some_and(|d| d == DUMMY_DATA_UPDATE); + let needs_update = !matches!(dummy_data_update, Some(DUMMY_DATA_UPDATE)); if needs_update { println!("Dummy data updated, so template DB tables will be dropped and re-created"); // Drop all tables in the database so they can be re-created and later filled with updated dummy data diff --git a/tests/feed.rs b/tests/feed.rs index 09efe51c..46324c43 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -1,8 +1,5 @@ use crate::common::{ - asserts::{ - assert_feed_contains_project_created, - assert_feed_contains_version_created, - }, + asserts::{assert_feed_contains_project_created, assert_feed_contains_version_created}, dummy_data::DummyProjectAlpha, }; use assert_matches::assert_matches; From 0943ae4d3882529f64f0b56e437103092ef74588 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Fri, 3 Nov 2023 10:28:27 -0700 Subject: [PATCH 68/72] Revert "fixed nightly failure" This reverts commit f3b837514d4613be7d00e68800cdedf84e6309dc. --- src/routes/v3/organizations.rs | 14 +++--- src/routes/v3/users.rs | 80 +++++++++++++++++----------------- tests/common/database.rs | 2 +- tests/feed.rs | 5 ++- 4 files changed, 50 insertions(+), 51 deletions(-) diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs index 5fe23809..e2093f57 100644 --- a/src/routes/v3/organizations.rs +++ b/src/routes/v3/organizations.rs @@ -53,15 +53,11 @@ pub async fn organization_follow( .insert(&**pool) .await .map_err(|e| match e { - DatabaseError::Database(e) => { - if let Some(db_err) = e.as_database_error() { - if db_err.is_unique_violation() { - return ApiError::InvalidInput( - "You are already following this organization!".to_string(), - ); - } - } - e.into() + DatabaseError::Database(e) + if e.as_database_error() + .is_some_and(|e| e.is_unique_violation()) => + { + ApiError::InvalidInput("You are already following this organization!".to_string()) } e => e.into(), })?; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index 80f422ea..fa2155ea 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -72,15 +72,11 @@ pub async fn user_follow( .insert(&**pool) .await .map_err(|e| match e { - DatabaseError::Database(e) => { - if let Some(db_err) = e.as_database_error() { - if db_err.is_unique_violation() { - return ApiError::InvalidInput( - "You are already following this user!".to_string(), - ); - } - } - e.into() + DatabaseError::Database(e) + if e.as_database_error() + .is_some_and(|e| e.is_unique_violation()) => + { + ApiError::InvalidInput("You are already following this user!".to_string()) } e => e.into(), })?; @@ -147,7 +143,10 @@ pub async fn current_user_feed( // - Projects created by organizations you follow // - Versions created by users you follow // - Versions created by organizations you follow - let event_types = [EventType::ProjectPublished, EventType::VersionCreated]; + let event_types = [ + EventType::ProjectPublished, + EventType::VersionCreated, + ]; let selectors = followed_users .into_iter() .flat_map(|follow| { @@ -187,38 +186,39 @@ pub async fn current_user_feed( .await?; for event in events { - let body = match event.event_data { - EventData::ProjectPublished { - project_id, - creator_id, - } => authorized_projects.get(&project_id.into()).map(|p| { - FeedItemBody::ProjectPublished { - project_id: project_id.into(), - creator_id: creator_id.into(), - project_title: p.title.clone(), - } - }), - EventData::VersionCreated { - version_id, - creator_id, - } => { - let authorized_version = authorized_versions.get(&version_id.into()); - let authorized_project = - authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); - if let (Some(authorized_version), Some(authorized_project)) = - (authorized_version, authorized_project) - { - Some(FeedItemBody::VersionCreated { - project_id: authorized_project.id, - version_id: authorized_version.id, + let body = + match event.event_data { + EventData::ProjectPublished { + project_id, + creator_id, + } => authorized_projects.get(&project_id.into()).map(|p| { + FeedItemBody::ProjectPublished { + project_id: project_id.into(), creator_id: creator_id.into(), - project_title: authorized_project.title.clone(), - }) - } else { - None + project_title: p.title.clone(), + } + }), + EventData::VersionCreated { + version_id, + creator_id, + } => { + let authorized_version = authorized_versions.get(&version_id.into()); + let authorized_project = + authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); + if let (Some(authorized_version), Some(authorized_project)) = + (authorized_version, authorized_project) + { + Some(FeedItemBody::VersionCreated { + project_id: authorized_project.id, + version_id: authorized_version.id, + creator_id: creator_id.into(), + project_title: authorized_project.title.clone(), + }) + } else { + None + } } - } - }; + }; if let Some(body) = body { let feed_item = FeedItem { diff --git a/tests/common/database.rs b/tests/common/database.rs index 38bf6585..988cc028 100644 --- a/tests/common/database.rs +++ b/tests/common/database.rs @@ -146,7 +146,7 @@ impl TemporaryDatabase { .fetch_optional(&pool) .await .unwrap(); - let needs_update = !matches!(dummy_data_update, Some(DUMMY_DATA_UPDATE)); + let needs_update = !dummy_data_update.is_some_and(|d| d == DUMMY_DATA_UPDATE); if needs_update { println!("Dummy data updated, so template DB tables will be dropped and re-created"); // Drop all tables in the database so they can be re-created and later filled with updated dummy data diff --git a/tests/feed.rs b/tests/feed.rs index 46324c43..09efe51c 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -1,5 +1,8 @@ use crate::common::{ - asserts::{assert_feed_contains_project_created, assert_feed_contains_version_created}, + asserts::{ + assert_feed_contains_project_created, + assert_feed_contains_version_created, + }, dummy_data::DummyProjectAlpha, }; use assert_matches::assert_matches; From 930cd9afb646c5916f185dd7191edaa6cb3705c5 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Fri, 3 Nov 2023 10:30:24 -0700 Subject: [PATCH 69/72] bumped rust version --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index c0a65d8e..f57fa864 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.68.0 as build +FROM rust:1.70.0 as build ENV PKG_CONFIG_ALLOW_CROSS=1 WORKDIR /usr/src/labrinth From 2b5b5be341c7613fe80657692010ef97cf3ca0d0 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Wed, 15 Nov 2023 10:21:00 -0800 Subject: [PATCH 70/72] merge conflict fixes; passes tests --- src/models/mod.rs | 2 +- src/models/{ => v3}/feeds.rs | 0 src/models/v3/mod.rs | 1 + src/routes/v2/project_creation.rs | 1 - src/routes/v2/projects.rs | 1 - src/routes/v2/version_creation.rs | 23 +------- src/routes/v3/analytics_get.rs | 4 +- src/routes/v3/mod.rs | 2 - src/routes/v3/organizations.rs | 40 +++++-------- src/routes/v3/project_creation.rs | 25 +++++++- src/routes/v3/projects.rs | 40 ++++++++++++- src/routes/v3/users.rs | 88 ++++++++++++----------------- src/routes/v3/version_creation.rs | 32 ++++++++++- src/routes/v3/version_file.rs | 4 +- src/routes/v3/versions.rs | 4 +- tests/common/api_v2/project.rs | 2 +- tests/common/api_v3/organization.rs | 3 +- tests/common/api_v3/user.rs | 5 +- tests/common/request_data.rs | 3 +- tests/organizations.rs | 3 +- tests/project.rs | 2 + tests/search.rs | 2 +- 22 files changed, 160 insertions(+), 127 deletions(-) rename src/models/{ => v3}/feeds.rs (100%) diff --git a/src/models/mod.rs b/src/models/mod.rs index 76aa34df..e6f1653a 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -4,7 +4,7 @@ pub mod v3; pub use v3::analytics; pub use v3::collections; pub use v3::error; -pub mod feeds; +pub use v3::feeds; pub use v3::ids; pub use v3::images; pub use v3::notifications; diff --git a/src/models/feeds.rs b/src/models/v3/feeds.rs similarity index 100% rename from src/models/feeds.rs rename to src/models/v3/feeds.rs diff --git a/src/models/v3/mod.rs b/src/models/v3/mod.rs index 7c97ad31..e0bc2c23 100644 --- a/src/models/v3/mod.rs +++ b/src/models/v3/mod.rs @@ -1,6 +1,7 @@ pub mod analytics; pub mod collections; pub mod error; +pub mod feeds; pub mod ids; pub mod images; pub mod notifications; diff --git a/src/routes/v2/project_creation.rs b/src/routes/v2/project_creation.rs index 39611787..56f6f267 100644 --- a/src/routes/v2/project_creation.rs +++ b/src/routes/v2/project_creation.rs @@ -1,7 +1,6 @@ use crate::database::models::version_item; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; -use crate::models; use crate::models::ids::ImageId; use crate::models::projects::{DonationLink, Loader, Project, ProjectStatus, SideType}; use crate::models::v2::projects::LegacyProject; diff --git a/src/routes/v2/projects.rs b/src/routes/v2/projects.rs index e1d96121..f85f785c 100644 --- a/src/routes/v2/projects.rs +++ b/src/routes/v2/projects.rs @@ -17,7 +17,6 @@ use serde_json::json; use sqlx::PgPool; use std::collections::HashMap; use std::sync::Arc; -use team_item::TeamMember; use validator::Validate; pub fn config(cfg: &mut web::ServiceConfig) { diff --git a/src/routes/v2/version_creation.rs b/src/routes/v2/version_creation.rs index a9a8f46e..ebc1180a 100644 --- a/src/routes/v2/version_creation.rs +++ b/src/routes/v2/version_creation.rs @@ -168,25 +168,4 @@ pub async fn upload_file_to_version( ) .await?; Ok(response) -} - -async fn insert_version_create_event( - version_id: VersionId, - organization_id: Option, - current_user: &crate::models::users::User, - transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, -) -> Result<(), CreateError> { - let event = Event::new( - EventData::VersionCreated { - version_id: version_id.into(), - creator_id: organization_id.map_or_else( - || CreatorId::User(current_user.id.into()), - CreatorId::Organization, - ), - }, - transaction, - ) - .await?; - event.insert(transaction).await?; - Ok(()) -} +} \ No newline at end of file diff --git a/src/routes/v3/analytics_get.rs b/src/routes/v3/analytics_get.rs index dc31c69c..bbe7b131 100644 --- a/src/routes/v3/analytics_get.rs +++ b/src/routes/v3/analytics_get.rs @@ -580,7 +580,7 @@ async fn filter_allowed_ids( .map(|id| Ok(ProjectId(parse_base62(id)?).into())) .collect::, ApiError>>()?; let projects = project_item::Project::get_many_ids(&ids, &***pool, redis).await?; - let ids: Vec = filter_authorized_projects(projects, &Some(user.clone()), pool) + let ids: Vec = filter_authorized_projects(projects, Some(&user), pool) .await? .into_iter() .map(|x| x.id) @@ -596,7 +596,7 @@ async fn filter_allowed_ids( .map(|id| Ok(VersionId(parse_base62(id)?).into())) .collect::, ApiError>>()?; let versions = version_item::Version::get_many(&ids, &***pool, redis).await?; - let ids: Vec = filter_authorized_versions(versions, &Some(user), pool) + let ids: Vec = filter_authorized_versions(versions, Some(&user), pool) .await? .into_iter() .map(|x| x.id) diff --git a/src/routes/v3/mod.rs b/src/routes/v3/mod.rs index de343a11..b48421db 100644 --- a/src/routes/v3/mod.rs +++ b/src/routes/v3/mod.rs @@ -22,8 +22,6 @@ pub mod version_file; pub mod versions; pub mod oauth_clients; -pub mod organizations; -pub mod users; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs index 9a62289f..4ab72687 100644 --- a/src/routes/v3/organizations.rs +++ b/src/routes/v3/organizations.rs @@ -21,6 +21,10 @@ use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use validator::Validate; +use crate::database::models::DatabaseError; +use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; +use database::models::organization_item::Organization as DBOrganization; + pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( @@ -39,6 +43,14 @@ pub fn config(cfg: &mut web::ServiceConfig) { .route( "{id}/members", web::get().to(super::teams::team_members_get_organization), + ) + .route( + "{id}/follow", + web::post().to(organization_follow), + ) + .route( + "{id}/follow", + web::delete().to(organization_unfollow), ), ); } @@ -82,7 +94,7 @@ pub async fn organization_projects_get( let projects_data = crate::database::models::Project::get_many_ids(&project_ids, &**pool, &redis).await?; - let projects = filter_authorized_projects(projects_data, ¤t_user, &pool).await?; + let projects = filter_authorized_projects(projects_data, current_user.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(projects)) } @@ -916,32 +928,7 @@ pub async fn delete_organization_icon( Ok(HttpResponse::NoContent().body("")) } -use crate::{ - auth::get_user_from_headers, - database::{self, models::DatabaseError, redis::RedisPool}, - models::pats::Scopes, - queue::session::AuthQueue, - routes::ApiError, -}; -use actix_web::{ - delete, post, - web::{self}, - HttpRequest, HttpResponse, -}; -use sqlx::PgPool; - -use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; -use database::models::organization_item::Organization as DBOrganization; - -pub fn config(cfg: &mut web::ServiceConfig) { - cfg.service( - web::scope("organization") - .service(organization_follow) - .service(organization_unfollow), - ); -} -#[post("{id}/follow")] pub async fn organization_follow( req: HttpRequest, target_id: web::Path, @@ -983,7 +970,6 @@ pub async fn organization_follow( Ok(HttpResponse::NoContent().body("")) } -#[delete("{id}/follow")] pub async fn organization_unfollow( req: HttpRequest, target_id: web::Path, diff --git a/src/routes/v3/project_creation.rs b/src/routes/v3/project_creation.rs index a762b168..7f115383 100644 --- a/src/routes/v3/project_creation.rs +++ b/src/routes/v3/project_creation.rs @@ -14,7 +14,7 @@ use crate::models::pats::Scopes; use crate::models::projects::{ DonationLink, License, MonetizationStatus, ProjectId, ProjectStatus, VersionId, VersionStatus, }; -use crate::models::teams::ProjectPermissions; +use crate::models::teams::{ProjectPermissions, OrganizationPermissions}; use crate::models::threads::ThreadType; use crate::models::users::UserId; use crate::queue::session::AuthQueue; @@ -447,6 +447,29 @@ async fn project_create_inner( } } + // If organization_id is set, make sure the user is a member of the organization + if let Some(organization_id) = create_data.organization_id { + let organization_team_member = + models::team_item::TeamMember::get_from_user_id_organization( + organization_id.into(), + current_user.id.into(), + &mut **transaction, + ) + .await?; + + let permissions = OrganizationPermissions::get_permissions_by_role( + ¤t_user.role, + &organization_team_member, + ) + .unwrap_or_default(); + + if !permissions.contains(OrganizationPermissions::ADD_PROJECT) { + return Err(CreateError::CustomAuthenticationError( + "You do not have permission to add projects to this organization!".to_string(), + )); + } + } + // Create VersionBuilders for the versions specified in `initial_versions` versions = Vec::with_capacity(create_data.initial_versions.len()); for (i, data) in create_data.initial_versions.iter().enumerate() { diff --git a/src/routes/v3/projects.rs b/src/routes/v3/projects.rs index d35182aa..a4c5112c 100644 --- a/src/routes/v3/projects.rs +++ b/src/routes/v3/projects.rs @@ -1,10 +1,11 @@ use std::sync::Arc; use crate::auth::{filter_authorized_projects, get_user_from_headers, is_authorized}; +use crate::database::models::event_item::{EventData, CreatorId}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::project_item::{GalleryItem, ModCategory}; use crate::database::models::thread_item::ThreadMessageBuilder; -use crate::database::models::{ids as db_ids, image_item}; +use crate::database::models::{ids as db_ids, image_item, Event, TeamMember}; use crate::database::redis::RedisPool; use crate::database::{self, models as db_models}; use crate::file_hosting::FileHost; @@ -26,6 +27,7 @@ use crate::util::routes::read_from_payload; use crate::util::validate::validation_errors_to_string; use actix_web::{web, HttpRequest, HttpResponse}; use chrono::{DateTime, Utc}; +use db_ids::{UserId, OrganizationId}; use futures::TryStreamExt; use meilisearch_sdk::indexes::IndexesResults; use serde::{Deserialize, Serialize}; @@ -134,7 +136,7 @@ pub async fn projects_get( .map(|x| x.1) .ok(); - let projects = filter_authorized_projects(projects_data, &user_option, &pool).await?; + let projects = filter_authorized_projects(projects_data, user_option.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(projects)) } @@ -411,6 +413,21 @@ pub async fn project_edit( ) .execute(&mut *transaction) .await?; + + // On publish event, we send out notification using team *owner* as publishing user, at the time of mod approval + // (even though 'user' is the mod and doing the publishing) + let owner_id = + TeamMember::get_owner_id(project_item.inner.team_id, &mut *transaction) + .await?; + if let Some(owner_id) = owner_id { + insert_project_publish_event( + id.into(), + project_item.inner.organization_id, + owner_id, + &mut transaction, + ) + .await?; + } } if status.is_searchable() && !project_item.inner.webhook_sent { if let Ok(webhook_url) = dotenvy::var("PUBLIC_DISCORD_WEBHOOK") { @@ -2493,3 +2510,22 @@ pub async fn project_unfollow( )) } } + +async fn insert_project_publish_event( + project_id: ProjectId, + organization_id: Option, + owner_id: UserId, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), ApiError> { + let event = Event::new( + EventData::ProjectPublished { + project_id: project_id.into(), + creator_id: organization_id + .map_or_else(|| CreatorId::User(owner_id), CreatorId::Organization), + }, + transaction, + ) + .await?; + event.insert(transaction).await?; + Ok(()) +} diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index a6cea63c..b3c01f5b 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -12,9 +12,20 @@ use validator::Validate; use crate::{ auth::get_user_from_headers, + auth::{filter_authorized_projects, filter_authorized_versions}, + database::{ + self, + models::{ + event_item::{EventData, EventSelector, EventType}, + DatabaseError, + }, + }, database::{models::User, redis::RedisPool}, file_hosting::FileHost, models::{ + feeds::{FeedItem, FeedItemBody}, + ids::{ProjectId, VersionId}, + projects::Version, collections::{Collection, CollectionStatus}, ids::UserId, notifications::Notification, @@ -25,6 +36,14 @@ use crate::{ queue::{payouts::PayoutsQueue, session::AuthQueue}, util::{routes::read_from_payload, validate::validation_errors_to_string}, }; +use std::iter::FromIterator; +use itertools::Itertools; + +use database::models as db_models; +use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; +use database::models::creator_follows::UserFollow as DBUserFollow; +use database::models::event_item::Event as DBEvent; +use database::models::user_item::User as DBUser; use super::ApiError; @@ -34,6 +53,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("user") + .route("feed", web::get().to(current_user_feed)) .route("{user_id}/projects", web::get().to(projects_list)) .route("{id}", web::get().to(user_get)) .route("{user_id}/collections", web::get().to(collections_list)) @@ -45,7 +65,9 @@ pub fn config(cfg: &mut web::ServiceConfig) { .route("{id}/notifications", web::get().to(user_notifications)) .route("{id}/payouts", web::get().to(user_payouts)) .route("{id}/payouts_fees", web::get().to(user_payouts_fees)) - .route("{id}/payouts", web::post().to(user_payouts_request)), + .route("{id}/payouts", web::post().to(user_payouts_request)) + .route("{id}/follow", web::post().to(user_follow)) + .route("{id}/follow", web::delete().to(user_unfollow)) ); } @@ -911,53 +933,7 @@ pub async fn user_payouts_request( Ok(HttpResponse::NotFound().body("")) } } -use std::{collections::HashMap, iter::FromIterator}; -use crate::{ - auth::{filter_authorized_projects, filter_authorized_versions, get_user_from_headers}, - database::{ - self, - models::{ - event_item::{EventData, EventSelector, EventType}, - DatabaseError, - }, - redis::RedisPool, - }, - models::{ - feeds::{FeedItem, FeedItemBody}, - ids::{ProjectId, VersionId}, - pats::Scopes, - projects::{Project, Version}, - users::User, - }, - queue::session::AuthQueue, - routes::ApiError, -}; -use actix_web::{ - delete, get, post, - web::{self}, - HttpRequest, HttpResponse, -}; -use itertools::Itertools; -use serde::{Deserialize, Serialize}; -use sqlx::PgPool; - -use database::models as db_models; -use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; -use database::models::creator_follows::UserFollow as DBUserFollow; -use database::models::event_item::Event as DBEvent; -use database::models::user_item::User as DBUser; - -pub fn config(cfg: &mut web::ServiceConfig) { - cfg.service( - web::scope("user") - .service(user_follow) - .service(user_unfollow) - .service(current_user_feed), - ); -} - -#[post("{id}/follow")] pub async fn user_follow( req: HttpRequest, target_id: web::Path, @@ -965,6 +941,7 @@ pub async fn user_follow( redis: web::Data, session_queue: web::Data, ) -> Result { + println!("inside user_follow"); let (_, current_user) = get_user_from_headers( &req, &**pool, @@ -973,11 +950,13 @@ pub async fn user_follow( Some(&[Scopes::USER_WRITE]), ) .await?; + println!("current_user: {:?}", current_user); let target = DBUser::get(&target_id, &**pool, &redis) .await? .ok_or_else(|| ApiError::InvalidInput("The specified user does not exist!".to_string()))?; + println!("target: {:?}", target); DBUserFollow { follower_id: current_user.id.into(), target_id: target.id, @@ -993,11 +972,10 @@ pub async fn user_follow( } e => e.into(), })?; - + println!("inserted"); Ok(HttpResponse::NoContent().body("")) } -#[delete("{id}/follow")] pub async fn user_unfollow( req: HttpRequest, target_id: web::Path, @@ -1029,7 +1007,6 @@ pub struct FeedParameters { pub offset: Option, } -#[get("feed")] pub async fn current_user_feed( req: HttpRequest, web::Query(params): web::Query, @@ -1037,6 +1014,7 @@ pub async fn current_user_feed( redis: web::Data, session_queue: web::Data, ) -> Result { + println!("In current_user_feed"); let (_, current_user) = get_user_from_headers( &req, &**pool, @@ -1045,12 +1023,14 @@ pub async fn current_user_feed( Some(&[Scopes::NOTIFICATION_READ]), ) .await?; + println!("current_user: {:?}", current_user); let followed_users = DBUserFollow::get_follows_by_follower(current_user.id.into(), &**pool).await?; let followed_organizations = DBOrganizationFollow::get_follows_by_follower(current_user.id.into(), &**pool).await?; + println!("followed_users: {:?}", followed_users); // Feed by default shows the following: // - Projects created by users you follow // - Projects created by organizations you follow @@ -1075,13 +1055,14 @@ pub async fn current_user_feed( }) })) .collect_vec(); + println!("selectors:"); let events = DBEvent::get_events(&[], &selectors, &**pool) .await? .into_iter() .skip(params.offset.unwrap_or(0)) .take(params.offset.unwrap_or(usize::MAX)) .collect_vec(); - + println!("events: {:?}", events); let mut feed_items: Vec = Vec::new(); let authorized_versions = prefetch_authorized_event_versions(&events, &pool, &redis, ¤t_user).await?; @@ -1097,6 +1078,7 @@ pub async fn current_user_feed( ¤t_user, ) .await?; + println!("authorized projects"); for event in events { let body = @@ -1152,7 +1134,7 @@ async fn prefetch_authorized_event_projects( additional_ids: Option<&[ProjectId]>, pool: &web::Data, redis: &RedisPool, - current_user: &User, + current_user: &crate::models::v3::users::User, ) -> Result, ApiError> { let mut project_ids = events .iter() @@ -1184,7 +1166,7 @@ async fn prefetch_authorized_event_versions( events: &[db_models::Event], pool: &web::Data, redis: &RedisPool, - current_user: &User, + current_user: &crate::models::v3::users::User, ) -> Result, ApiError> { let version_ids = events .iter() diff --git a/src/routes/v3/version_creation.rs b/src/routes/v3/version_creation.rs index 106b956d..816d3304 100644 --- a/src/routes/v3/version_creation.rs +++ b/src/routes/v3/version_creation.rs @@ -1,11 +1,12 @@ use super::project_creation::{CreateError, UploadedFile}; use crate::auth::get_user_from_headers; +use crate::database::models::event_item::{EventData, CreatorId}; use crate::database::models::loader_fields::{LoaderField, LoaderFieldEnumValue, VersionField}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::version_item::{ DependencyBuilder, VersionBuilder, VersionFileBuilder, }; -use crate::database::models::{self, image_item, Organization}; +use crate::database::models::{self, image_item, Organization, Event}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::images::{Image, ImageContext, ImageId}; @@ -312,6 +313,14 @@ async fn version_create_inner( }) .collect::>(); + insert_version_create_event( + version_id, + organization.map(|o| o.id), + &user, + transaction, + ) + .await?; + version_builder = Some(VersionBuilder { version_id: version_id.into(), project_id, @@ -966,3 +975,24 @@ pub fn get_name_ext( }; Ok((file_name, file_extension)) } + +async fn insert_version_create_event( + version_id: VersionId, + organization_id: Option, + current_user: &crate::models::users::User, + transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>, +) -> Result<(), CreateError> { + let event = Event::new( + EventData::VersionCreated { + version_id: version_id.into(), + creator_id: organization_id.map_or_else( + || CreatorId::User(current_user.id.into()), + CreatorId::Organization, + ), + }, + transaction, + ) + .await?; + event.insert(transaction).await?; + Ok(()) +} diff --git a/src/routes/v3/version_file.rs b/src/routes/v3/version_file.rs index 558cf5f9..f3b5cc4d 100644 --- a/src/routes/v3/version_file.rs +++ b/src/routes/v3/version_file.rs @@ -211,7 +211,7 @@ pub async fn get_versions_from_hashes( let version_ids = files.iter().map(|x| x.version_id).collect::>(); let versions_data = filter_authorized_versions( database::models::Version::get_many(&version_ids, &**pool, &redis).await?, - &user_option, + user_option.as_ref(), &pool, ) .await?; @@ -259,7 +259,7 @@ pub async fn get_projects_from_hashes( let projects_data = filter_authorized_projects( database::models::Project::get_many_ids(&project_ids, &**pool, &redis).await?, - &user_option, + user_option.as_ref(), &pool, ) .await?; diff --git a/src/routes/v3/versions.rs b/src/routes/v3/versions.rs index b3504ee9..423b1663 100644 --- a/src/routes/v3/versions.rs +++ b/src/routes/v3/versions.rs @@ -129,7 +129,7 @@ pub async fn versions_get( .map(|x| x.1) .ok(); - let versions = filter_authorized_versions(versions_data, &user_option, &pool).await?; + let versions = filter_authorized_versions(versions_data, user_option.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(versions)) } @@ -808,7 +808,7 @@ pub async fn version_list( response.sort(); response.dedup_by(|a, b| a.inner.id == b.inner.id); - let response = filter_authorized_versions(response, &user_option, &pool).await?; + let response = filter_authorized_versions(response, user_option.as_ref(), &pool).await?; Ok(HttpResponse::Ok().json(response)) } else { diff --git a/tests/common/api_v2/project.rs b/tests/common/api_v2/project.rs index 658f21a6..956e135b 100644 --- a/tests/common/api_v2/project.rs +++ b/tests/common/api_v2/project.rs @@ -24,7 +24,7 @@ use crate::common::{ use super::ApiV2; impl ApiV2 { - pub async fn add_default_org_project(&self, org_id: &str, pat: &str) -> Project { + pub async fn add_default_org_project(&self, org_id: &str, pat: &str) -> LegacyProject { let project_create_data = get_public_project_creation_data("thisisaslug", None, Some(org_id)); let (project, _) = self.add_public_project(project_create_data, pat).await; diff --git a/tests/common/api_v3/organization.rs b/tests/common/api_v3/organization.rs index 700144c7..10cadbee 100644 --- a/tests/common/api_v3/organization.rs +++ b/tests/common/api_v3/organization.rs @@ -1,6 +1,5 @@ use actix_web::{dev::ServiceResponse, test::TestRequest}; - -use crate::common::actix::TestRequestExtensions; +use labrinth::util::actix::TestRequestExtensions; // TODO: extend other tests to do this use super::ApiV3; diff --git a/tests/common/api_v3/user.rs b/tests/common/api_v3/user.rs index 8529748f..6ca446ce 100644 --- a/tests/common/api_v3/user.rs +++ b/tests/common/api_v3/user.rs @@ -3,9 +3,8 @@ use actix_web::{ dev::ServiceResponse, test::{self, TestRequest}, }; -use labrinth::models::feeds::FeedItem; - -use crate::common::{actix::TestRequestExtensions, asserts::assert_status}; +use labrinth::{models::feeds::FeedItem, util::actix::TestRequestExtensions}; +use crate::common::asserts::assert_status; use super::ApiV3; diff --git a/tests/common/request_data.rs b/tests/common/request_data.rs index af56fe43..d5ee6c2c 100644 --- a/tests/common/request_data.rs +++ b/tests/common/request_data.rs @@ -30,7 +30,7 @@ pub fn get_public_project_creation_data( version_jar: Option, organization_id: Option<&str>, ) -> ProjectCreationRequestData { - let json_data = get_public_project_creation_data_json(slug, version_jar.as_ref()); + let json_data = get_public_project_creation_data_json(slug, version_jar.as_ref(), organization_id); let multipart_data = get_public_creation_data_multipart(&json_data, version_jar.as_ref()); ProjectCreationRequestData { slug: slug.to_string(), @@ -73,6 +73,7 @@ pub fn get_public_version_creation_data_json( pub fn get_public_project_creation_data_json( slug: &str, version_jar: Option<&TestFile>, + organization_id: Option<&str>, ) -> serde_json::Value { let initial_versions = if let Some(jar) = version_jar { json!([get_public_version_creation_data_json("1.2.3", jar)]) diff --git a/tests/organizations.rs b/tests/organizations.rs index f427c22f..32bedd9d 100644 --- a/tests/organizations.rs +++ b/tests/organizations.rs @@ -7,12 +7,11 @@ use crate::common::{ use actix_web::test; use bytes::Bytes; use common::{ - actix::AppendsMultipart, database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_PAT}, permissions::{PermissionsTest, PermissionsTestContext}, request_data::get_public_project_creation_data, }; -use labrinth::models::teams::{OrganizationPermissions, ProjectPermissions}; +use labrinth::{models::teams::{OrganizationPermissions, ProjectPermissions}, util::actix::AppendsMultipart}; use serde_json::json; mod common; diff --git a/tests/project.rs b/tests/project.rs index 5bdf5890..7a8c78a5 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -288,10 +288,12 @@ async fn test_project_type_sanity() { let test_creation_mod = request_data::get_public_project_creation_data( "test-mod", Some(TestFile::build_random_jar()), + None, ); let test_creation_modpack = request_data::get_public_project_creation_data( "test-modpack", Some(TestFile::build_random_mrpack()), + None, ); for (mod_or_modpack, test_creation_data) in [ ("mod", test_creation_mod), diff --git a/tests/search.rs b/tests/search.rs index 9e87ee18..d7cedb0e 100644 --- a/tests/search.rs +++ b/tests/search.rs @@ -36,7 +36,7 @@ async fn search_projects() { TestFile::build_random_jar() }; let mut basic_project_json = - request_data::get_public_project_creation_data_json(&slug, Some(&jar)); + request_data::get_public_project_creation_data_json(&slug, Some(&jar), None); modify_json(&mut basic_project_json); let basic_project_multipart = From 647786e3966c2c3e6f2a230a04f56d6bb158584f Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Wed, 15 Nov 2023 10:25:22 -0800 Subject: [PATCH 71/72] fmt; clippy; prepare --- ...d602f5ff06c9e1a9e40b2ef8016011943502.json} | 7 +- ...28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json | 1 - ...cd7e38197123e2af364dd7d3b01c322345d7b.json | 36 --------- src/routes/v2/version_creation.rs | 2 +- src/routes/v3/organizations.rs | 17 ++--- src/routes/v3/project_creation.rs | 2 +- src/routes/v3/projects.rs | 4 +- src/routes/v3/users.rs | 76 +++++++++---------- src/routes/v3/version_creation.rs | 6 +- tests/common/api_v3/mod.rs | 2 +- tests/common/api_v3/user.rs | 2 +- tests/common/request_data.rs | 3 +- tests/feed.rs | 5 +- tests/organizations.rs | 5 +- 14 files changed, 61 insertions(+), 107 deletions(-) rename .sqlx/{query-fefb4f07a0f0c0cf74e554d120f8707d698fc8b4dbb66d2830f4ec0229bc1019.json => query-26bf18543c97850a1387221a8d15d602f5ff06c9e1a9e40b2ef8016011943502.json} (65%) delete mode 100644 .sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json delete mode 100644 .sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json diff --git a/.sqlx/query-fefb4f07a0f0c0cf74e554d120f8707d698fc8b4dbb66d2830f4ec0229bc1019.json b/.sqlx/query-26bf18543c97850a1387221a8d15d602f5ff06c9e1a9e40b2ef8016011943502.json similarity index 65% rename from .sqlx/query-fefb4f07a0f0c0cf74e554d120f8707d698fc8b4dbb66d2830f4ec0229bc1019.json rename to .sqlx/query-26bf18543c97850a1387221a8d15d602f5ff06c9e1a9e40b2ef8016011943502.json index e2b9c106..e00c8502 100644 --- a/.sqlx/query-fefb4f07a0f0c0cf74e554d120f8707d698fc8b4dbb66d2830f4ec0229bc1019.json +++ b/.sqlx/query-26bf18543c97850a1387221a8d15d602f5ff06c9e1a9e40b2ef8016011943502.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n INSERT INTO mods (\n id, team_id, title, description, body,\n published, downloads, icon_url, issues_url,\n source_url, wiki_url, status, requested_status, discord_url,\n license_url, license,\n slug, color, monetization_status\n )\n VALUES (\n $1, $2, $3, $4, $5,\n $6, $7, $8, $9,\n $10, $11, $12, $13, $14,\n $15, $16, \n LOWER($17), $18, $19\n )\n ", + "query": "\n INSERT INTO mods (\n id, team_id, title, description, body,\n published, downloads, icon_url, issues_url,\n source_url, wiki_url, status, requested_status, discord_url,\n license_url, license,\n slug, color, monetization_status,\n organization_id\n )\n VALUES (\n $1, $2, $3, $4, $5,\n $6, $7, $8, $9,\n $10, $11, $12, $13, $14,\n $15, $16, \n LOWER($17), $18, $19,\n $20\n )\n ", "describe": { "columns": [], "parameters": { @@ -23,10 +23,11 @@ "Varchar", "Text", "Int4", - "Varchar" + "Varchar", + "Int8" ] }, "nullable": [] }, - "hash": "fefb4f07a0f0c0cf74e554d120f8707d698fc8b4dbb66d2830f4ec0229bc1019" + "hash": "26bf18543c97850a1387221a8d15d602f5ff06c9e1a9e40b2ef8016011943502" } diff --git a/.sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json b/.sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json deleted file mode 100644 index 8b137891..00000000 --- a/.sqlx/query-b36877d60945eaae76680770a5d28d2cbb26cfbb0ec94ecc8f0741f48178ec1c.json +++ /dev/null @@ -1 +0,0 @@ - diff --git a/.sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json b/.sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json deleted file mode 100644 index a598df51..00000000 --- a/.sqlx/query-fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO mods (\n id, team_id, title, description, body,\n published, downloads, icon_url, issues_url,\n source_url, wiki_url, status, requested_status, discord_url,\n client_side, server_side, license_url, license,\n slug, project_type, color, monetization_status,\n organization_id\n )\n VALUES (\n $1, $2, $3, $4, $5,\n $6, $7, $8, $9,\n $10, $11, $12, $13, $14,\n $15, $16, $17, $18,\n LOWER($19), $20, $21, $22,\n $23\n )\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Int8", - "Int8", - "Varchar", - "Varchar", - "Varchar", - "Timestamptz", - "Int4", - "Varchar", - "Varchar", - "Varchar", - "Varchar", - "Varchar", - "Varchar", - "Varchar", - "Int4", - "Int4", - "Varchar", - "Varchar", - "Text", - "Int4", - "Int4", - "Varchar", - "Int8" - ] - }, - "nullable": [] - }, - "hash": "fbfa35f79abe9662cfbad62db8ecd7e38197123e2af364dd7d3b01c322345d7b" -} diff --git a/src/routes/v2/version_creation.rs b/src/routes/v2/version_creation.rs index ebc1180a..5652e829 100644 --- a/src/routes/v2/version_creation.rs +++ b/src/routes/v2/version_creation.rs @@ -168,4 +168,4 @@ pub async fn upload_file_to_version( ) .await?; Ok(response) -} \ No newline at end of file +} diff --git a/src/routes/v3/organizations.rs b/src/routes/v3/organizations.rs index 4ab72687..552fc118 100644 --- a/src/routes/v3/organizations.rs +++ b/src/routes/v3/organizations.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use super::ApiError; use crate::auth::{filter_authorized_projects, get_user_from_headers}; use crate::database::models::team_item::TeamMember; +use crate::database::models::DatabaseError; use crate::database::models::{generate_organization_id, team_item, Organization}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; @@ -17,14 +18,12 @@ use crate::util::routes::read_from_payload; use crate::util::validate::validation_errors_to_string; use crate::{database, models}; use actix_web::{web, HttpRequest, HttpResponse}; +use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; +use database::models::organization_item::Organization as DBOrganization; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use validator::Validate; -use crate::database::models::DatabaseError; -use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; -use database::models::organization_item::Organization as DBOrganization; - pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( @@ -44,14 +43,8 @@ pub fn config(cfg: &mut web::ServiceConfig) { "{id}/members", web::get().to(super::teams::team_members_get_organization), ) - .route( - "{id}/follow", - web::post().to(organization_follow), - ) - .route( - "{id}/follow", - web::delete().to(organization_unfollow), - ), + .route("{id}/follow", web::post().to(organization_follow)) + .route("{id}/follow", web::delete().to(organization_unfollow)), ); } diff --git a/src/routes/v3/project_creation.rs b/src/routes/v3/project_creation.rs index 7f115383..f8b9ae14 100644 --- a/src/routes/v3/project_creation.rs +++ b/src/routes/v3/project_creation.rs @@ -14,7 +14,7 @@ use crate::models::pats::Scopes; use crate::models::projects::{ DonationLink, License, MonetizationStatus, ProjectId, ProjectStatus, VersionId, VersionStatus, }; -use crate::models::teams::{ProjectPermissions, OrganizationPermissions}; +use crate::models::teams::{OrganizationPermissions, ProjectPermissions}; use crate::models::threads::ThreadType; use crate::models::users::UserId; use crate::queue::session::AuthQueue; diff --git a/src/routes/v3/projects.rs b/src/routes/v3/projects.rs index a4c5112c..228dee8a 100644 --- a/src/routes/v3/projects.rs +++ b/src/routes/v3/projects.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::auth::{filter_authorized_projects, get_user_from_headers, is_authorized}; -use crate::database::models::event_item::{EventData, CreatorId}; +use crate::database::models::event_item::{CreatorId, EventData}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::project_item::{GalleryItem, ModCategory}; use crate::database::models::thread_item::ThreadMessageBuilder; @@ -27,7 +27,7 @@ use crate::util::routes::read_from_payload; use crate::util::validate::validation_errors_to_string; use actix_web::{web, HttpRequest, HttpResponse}; use chrono::{DateTime, Utc}; -use db_ids::{UserId, OrganizationId}; +use db_ids::{OrganizationId, UserId}; use futures::TryStreamExt; use meilisearch_sdk::indexes::IndexesResults; use serde::{Deserialize, Serialize}; diff --git a/src/routes/v3/users.rs b/src/routes/v3/users.rs index b3c01f5b..79f8e567 100644 --- a/src/routes/v3/users.rs +++ b/src/routes/v3/users.rs @@ -23,21 +23,21 @@ use crate::{ database::{models::User, redis::RedisPool}, file_hosting::FileHost, models::{ - feeds::{FeedItem, FeedItemBody}, - ids::{ProjectId, VersionId}, - projects::Version, collections::{Collection, CollectionStatus}, + feeds::{FeedItem, FeedItemBody}, ids::UserId, + ids::{ProjectId, VersionId}, notifications::Notification, pats::Scopes, projects::Project, + projects::Version, users::{Badges, Payout, PayoutStatus, RecipientStatus, Role, UserPayoutData}, }, queue::{payouts::PayoutsQueue, session::AuthQueue}, util::{routes::read_from_payload, validate::validation_errors_to_string}, }; -use std::iter::FromIterator; use itertools::Itertools; +use std::iter::FromIterator; use database::models as db_models; use database::models::creator_follows::OrganizationFollow as DBOrganizationFollow; @@ -67,7 +67,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { .route("{id}/payouts_fees", web::get().to(user_payouts_fees)) .route("{id}/payouts", web::post().to(user_payouts_request)) .route("{id}/follow", web::post().to(user_follow)) - .route("{id}/follow", web::delete().to(user_unfollow)) + .route("{id}/follow", web::delete().to(user_unfollow)), ); } @@ -1036,10 +1036,7 @@ pub async fn current_user_feed( // - Projects created by organizations you follow // - Versions created by users you follow // - Versions created by organizations you follow - let event_types = [ - EventType::ProjectPublished, - EventType::VersionCreated, - ]; + let event_types = [EventType::ProjectPublished, EventType::VersionCreated]; let selectors = followed_users .into_iter() .flat_map(|follow| { @@ -1081,39 +1078,38 @@ pub async fn current_user_feed( println!("authorized projects"); for event in events { - let body = - match event.event_data { - EventData::ProjectPublished { - project_id, - creator_id, - } => authorized_projects.get(&project_id.into()).map(|p| { - FeedItemBody::ProjectPublished { - project_id: project_id.into(), + let body = match event.event_data { + EventData::ProjectPublished { + project_id, + creator_id, + } => authorized_projects.get(&project_id.into()).map(|p| { + FeedItemBody::ProjectPublished { + project_id: project_id.into(), + creator_id: creator_id.into(), + project_title: p.title.clone(), + } + }), + EventData::VersionCreated { + version_id, + creator_id, + } => { + let authorized_version = authorized_versions.get(&version_id.into()); + let authorized_project = + authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); + if let (Some(authorized_version), Some(authorized_project)) = + (authorized_version, authorized_project) + { + Some(FeedItemBody::VersionCreated { + project_id: authorized_project.id, + version_id: authorized_version.id, creator_id: creator_id.into(), - project_title: p.title.clone(), - } - }), - EventData::VersionCreated { - version_id, - creator_id, - } => { - let authorized_version = authorized_versions.get(&version_id.into()); - let authorized_project = - authorized_version.and_then(|v| authorized_projects.get(&v.project_id)); - if let (Some(authorized_version), Some(authorized_project)) = - (authorized_version, authorized_project) - { - Some(FeedItemBody::VersionCreated { - project_id: authorized_project.id, - version_id: authorized_version.id, - creator_id: creator_id.into(), - project_title: authorized_project.title.clone(), - }) - } else { - None - } + project_title: authorized_project.title.clone(), + }) + } else { + None } - }; + } + }; if let Some(body) = body { let feed_item = FeedItem { diff --git a/src/routes/v3/version_creation.rs b/src/routes/v3/version_creation.rs index 816d3304..94a650f6 100644 --- a/src/routes/v3/version_creation.rs +++ b/src/routes/v3/version_creation.rs @@ -1,12 +1,12 @@ use super::project_creation::{CreateError, UploadedFile}; use crate::auth::get_user_from_headers; -use crate::database::models::event_item::{EventData, CreatorId}; +use crate::database::models::event_item::{CreatorId, EventData}; use crate::database::models::loader_fields::{LoaderField, LoaderFieldEnumValue, VersionField}; use crate::database::models::notification_item::NotificationBuilder; use crate::database::models::version_item::{ DependencyBuilder, VersionBuilder, VersionFileBuilder, }; -use crate::database::models::{self, image_item, Organization, Event}; +use crate::database::models::{self, image_item, Event, Organization}; use crate::database::redis::RedisPool; use crate::file_hosting::FileHost; use crate::models::images::{Image, ImageContext, ImageId}; @@ -320,7 +320,7 @@ async fn version_create_inner( transaction, ) .await?; - + version_builder = Some(VersionBuilder { version_id: version_id.into(), project_id, diff --git a/tests/common/api_v3/mod.rs b/tests/common/api_v3/mod.rs index b5e88f9b..26ad8888 100644 --- a/tests/common/api_v3/mod.rs +++ b/tests/common/api_v3/mod.rs @@ -7,8 +7,8 @@ use std::rc::Rc; pub mod oauth; pub mod oauth_clients; pub mod organization; -pub mod user; pub mod tags; +pub mod user; #[derive(Clone)] pub struct ApiV3 { diff --git a/tests/common/api_v3/user.rs b/tests/common/api_v3/user.rs index 6ca446ce..e5a4f9ef 100644 --- a/tests/common/api_v3/user.rs +++ b/tests/common/api_v3/user.rs @@ -1,10 +1,10 @@ +use crate::common::asserts::assert_status; use actix_http::StatusCode; use actix_web::{ dev::ServiceResponse, test::{self, TestRequest}, }; use labrinth::{models::feeds::FeedItem, util::actix::TestRequestExtensions}; -use crate::common::asserts::assert_status; use super::ApiV3; diff --git a/tests/common/request_data.rs b/tests/common/request_data.rs index d5ee6c2c..bbdb1fdb 100644 --- a/tests/common/request_data.rs +++ b/tests/common/request_data.rs @@ -30,7 +30,8 @@ pub fn get_public_project_creation_data( version_jar: Option, organization_id: Option<&str>, ) -> ProjectCreationRequestData { - let json_data = get_public_project_creation_data_json(slug, version_jar.as_ref(), organization_id); + let json_data = + get_public_project_creation_data_json(slug, version_jar.as_ref(), organization_id); let multipart_data = get_public_creation_data_multipart(&json_data, version_jar.as_ref()); ProjectCreationRequestData { slug: slug.to_string(), diff --git a/tests/feed.rs b/tests/feed.rs index 09efe51c..46324c43 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -1,8 +1,5 @@ use crate::common::{ - asserts::{ - assert_feed_contains_project_created, - assert_feed_contains_version_created, - }, + asserts::{assert_feed_contains_project_created, assert_feed_contains_version_created}, dummy_data::DummyProjectAlpha, }; use assert_matches::assert_matches; diff --git a/tests/organizations.rs b/tests/organizations.rs index 32bedd9d..da29b65f 100644 --- a/tests/organizations.rs +++ b/tests/organizations.rs @@ -11,7 +11,10 @@ use common::{ permissions::{PermissionsTest, PermissionsTestContext}, request_data::get_public_project_creation_data, }; -use labrinth::{models::teams::{OrganizationPermissions, ProjectPermissions}, util::actix::AppendsMultipart}; +use labrinth::{ + models::teams::{OrganizationPermissions, ProjectPermissions}, + util::actix::AppendsMultipart, +}; use serde_json::json; mod common; From 5ffa116246bc3a44af48ff8037b4536b33f43847 Mon Sep 17 00:00:00 2001 From: Wyatt Verchere Date: Fri, 17 Nov 2023 18:29:40 -0800 Subject: [PATCH 72/72] merge conflicts, fixes deleted data --- tests/common/api_v2/project.rs | 5 ++++- tests/common/api_v3/mod.rs | 2 +- tests/common/api_v3/organization.rs | 22 +++++++++++++++------- tests/common/api_v3/project.rs | 9 ++++++++- tests/common/api_v3/request_data.rs | 13 ++++++++++--- tests/feed.rs | 15 ++++++++++----- tests/organizations.rs | 2 +- tests/v2/project.rs | 2 ++ tests/v2/search.rs | 2 +- 9 files changed, 52 insertions(+), 20 deletions(-) diff --git a/tests/common/api_v2/project.rs b/tests/common/api_v2/project.rs index 218f23dd..d57ffeee 100644 --- a/tests/common/api_v2/project.rs +++ b/tests/common/api_v2/project.rs @@ -17,7 +17,10 @@ use std::collections::HashMap; use crate::common::{asserts::assert_status, database::MOD_USER_PAT}; -use super::{request_data::get_public_project_creation_data, ImageData, ApiV2}; +use super::{ + request_data::{get_public_project_creation_data, ImageData}, + ApiV2, +}; impl ApiV2 { pub async fn add_default_org_project(&self, org_id: &str, pat: &str) -> LegacyProject { diff --git a/tests/common/api_v3/mod.rs b/tests/common/api_v3/mod.rs index 842ef746..bdc7e721 100644 --- a/tests/common/api_v3/mod.rs +++ b/tests/common/api_v3/mod.rs @@ -11,8 +11,8 @@ pub mod project; pub mod request_data; pub mod tags; pub mod team; -pub mod version; pub mod user; +pub mod version; #[derive(Clone)] pub struct ApiV3 { diff --git a/tests/common/api_v3/organization.rs b/tests/common/api_v3/organization.rs index 17b0d589..689b2648 100644 --- a/tests/common/api_v3/organization.rs +++ b/tests/common/api_v3/organization.rs @@ -4,7 +4,8 @@ use actix_web::{ }; use bytes::Bytes; use labrinth::models::{organizations::Organization, v3::projects::Project}; -use serde_json::json; +use labrinth::util::actix::TestRequestExtensions; +use serde_json::json; // TODO: extend other tests to do this use super::{request_data::ImageData, ApiV3}; @@ -26,6 +27,19 @@ impl ApiV3 { self.call(req).await } + pub async fn create_organization_deserialized( + &self, + organization_title: &str, + description: &str, + pat: &str, + ) -> Organization { + let resp = self + .create_organization(organization_title, description, pat) + .await; + assert_eq!(resp.status(), 200); + test::read_body_json(resp).await + } + pub async fn get_organization(&self, id_or_title: &str, pat: &str) -> ServiceResponse { let req = TestRequest::get() .uri(&format!("/v3/organization/{id_or_title}")) @@ -147,13 +161,7 @@ impl ApiV3 { self.call(req).await } -} -use actix_web::{dev::ServiceResponse, test::TestRequest}; -use labrinth::util::actix::TestRequestExtensions; // TODO: extend other tests to do this - -use super::ApiV3; -impl ApiV3 { pub async fn follow_organization(&self, organization_id: &str, pat: &str) -> ServiceResponse { let req = TestRequest::post() .uri(&format!("/v3/organization/{}/follow", organization_id)) diff --git a/tests/common/api_v3/project.rs b/tests/common/api_v3/project.rs index b4365d9c..8100c36f 100644 --- a/tests/common/api_v3/project.rs +++ b/tests/common/api_v3/project.rs @@ -18,11 +18,18 @@ use serde_json::json; use crate::common::{asserts::assert_status, database::MOD_USER_PAT}; use super::{ - request_data::{ImageData, ProjectCreationRequestData}, + request_data::{get_public_project_creation_data, ImageData, ProjectCreationRequestData}, ApiV3, }; impl ApiV3 { + pub async fn add_default_org_project(&self, org_id: &str, pat: &str) -> Project { + let project_create_data = + get_public_project_creation_data("thisisaslug", None, Some(org_id)); + let (project, _) = self.add_public_project(project_create_data, pat).await; + project + } + pub async fn add_public_project( &self, creation_data: ProjectCreationRequestData, diff --git a/tests/common/api_v3/request_data.rs b/tests/common/api_v3/request_data.rs index 6091992c..50ed8a59 100644 --- a/tests/common/api_v3/request_data.rs +++ b/tests/common/api_v3/request_data.rs @@ -28,8 +28,10 @@ pub struct ImageData { pub fn get_public_project_creation_data( slug: &str, version_jar: Option, + organization_id: Option<&str>, ) -> ProjectCreationRequestData { - let json_data = get_public_project_creation_data_json(slug, version_jar.as_ref()); + let json_data = + get_public_project_creation_data_json(slug, version_jar.as_ref(), organization_id); let multipart_data = get_public_creation_data_multipart(&json_data, version_jar.as_ref()); ProjectCreationRequestData { slug: slug.to_string(), @@ -88,6 +90,7 @@ pub fn get_public_version_creation_data_json( pub fn get_public_project_creation_data_json( slug: &str, version_jar: Option<&TestFile>, + organization_id: Option<&str>, ) -> serde_json::Value { let initial_versions = if let Some(jar) = version_jar { json!([get_public_version_creation_data_json("1.2.3", jar)]) @@ -96,7 +99,7 @@ pub fn get_public_project_creation_data_json( }; let is_draft = version_jar.is_none(); - json!( + let mut j = json!( { "title": format!("Test Project {slug}"), "slug": slug, @@ -107,7 +110,11 @@ pub fn get_public_project_creation_data_json( "categories": [], "license_id": "MIT", } - ) + ); + if let Some(organization_id) = organization_id { + j["organization_id"] = json!(organization_id); + } + j } pub fn get_public_creation_data_multipart( diff --git a/tests/feed.rs b/tests/feed.rs index 46324c43..b59d9415 100644 --- a/tests/feed.rs +++ b/tests/feed.rs @@ -51,7 +51,7 @@ async fn get_feed_after_following_user_shows_previously_created_public_versions( // Add version let v = env - .v2 + .v3 .create_default_version(&alpha_project_id, None, USER_USER_PAT) .await; @@ -74,8 +74,8 @@ async fn get_feed_after_following_user_shows_previously_created_public_versions( async fn get_feed_when_following_user_that_creates_project_as_org_only_shows_event_when_following_org( ) { with_test_environment(|env| async move { - let org_id = env.v2.create_default_organization(USER_USER_PAT).await; - let project = env.v2.add_default_org_project(&org_id, USER_USER_PAT).await; + let org_id = env.v3.create_organization_deserialized("test", "desc", USER_USER_PAT).await.id.to_string(); + let project = env.v3.add_default_org_project(&org_id, USER_USER_PAT).await; env.v3.follow_user(USER_USER_ID, FRIEND_USER_PAT).await; let feed = env.v3.get_feed(FRIEND_USER_PAT).await; @@ -107,8 +107,13 @@ async fn get_feed_after_unfollowing_user_no_longer_shows_feed_items() { #[actix_rt::test] async fn get_feed_after_unfollowing_organization_no_longer_shows_feed_items() { with_test_environment(|env| async move { - let org_id = env.v2.create_default_organization(USER_USER_PAT).await; - env.v2.add_default_org_project(&org_id, USER_USER_PAT).await; + let org_id = env + .v3 + .create_organization_deserialized("test", "desc", USER_USER_PAT) + .await + .id + .to_string(); + env.v3.add_default_org_project(&org_id, USER_USER_PAT).await; env.v3.follow_organization(&org_id, FRIEND_USER_PAT).await; env.v3.unfollow_organization(&org_id, FRIEND_USER_PAT).await; diff --git a/tests/organizations.rs b/tests/organizations.rs index 3e1f556e..b51bbebd 100644 --- a/tests/organizations.rs +++ b/tests/organizations.rs @@ -7,9 +7,9 @@ use crate::common::{ use actix_web::test; use bytes::Bytes; use common::{ + api_v3::request_data::get_public_project_creation_data, database::{FRIEND_USER_ID, FRIEND_USER_PAT, USER_USER_PAT}, permissions::{PermissionsTest, PermissionsTestContext}, - request_data::get_public_project_creation_data, }; use labrinth::{ models::teams::{OrganizationPermissions, ProjectPermissions}, diff --git a/tests/v2/project.rs b/tests/v2/project.rs index 609b8481..07b32181 100644 --- a/tests/v2/project.rs +++ b/tests/v2/project.rs @@ -23,10 +23,12 @@ async fn test_project_type_sanity() { let test_creation_mod = request_data::get_public_project_creation_data( "test-mod", Some(TestFile::build_random_jar()), + None, ); let test_creation_modpack = request_data::get_public_project_creation_data( "test-modpack", Some(TestFile::build_random_mrpack()), + None, ); for (mod_or_modpack, test_creation_data) in [ ("mod", test_creation_mod), diff --git a/tests/v2/search.rs b/tests/v2/search.rs index fbe39ca6..1e7beb74 100644 --- a/tests/v2/search.rs +++ b/tests/v2/search.rs @@ -37,7 +37,7 @@ async fn search_projects() { TestFile::build_random_jar() }; let mut basic_project_json = - request_data::get_public_project_creation_data_json(&slug, Some(&jar)); + request_data::get_public_project_creation_data_json(&slug, Some(&jar), None); modify_json(&mut basic_project_json); let basic_project_multipart =