diff --git a/Cargo.toml b/Cargo.toml index 0cc6254..d744983 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,4 +32,5 @@ regex = "1.10.0" ctor = "0.2.6" axum-client-ip = "0.4.0" jsonwebtoken = "9" -tower = "0.4.13" \ No newline at end of file +tower = "0.4.13" +mockall = "0.13.1" diff --git a/src/config.rs b/src/config.rs index b2cae9e..d661664 100644 --- a/src/config.rs +++ b/src/config.rs @@ -229,8 +229,10 @@ pub fn load() -> Config { let args: Vec = env::args().collect(); let config_path = if args.len() <= 1 { "config.toml" - } else { + } else if (args.len() > 1) && (args.get(1).unwrap().contains("toml")) { args.get(1).unwrap() + } else { + "config.toml" }; let file_contents = fs::read_to_string(config_path); if file_contents.is_err() { diff --git a/src/endpoints/get_quest_participants.rs b/src/endpoints/get_quest_participants.rs index 4ae8f29..3623292 100644 --- a/src/endpoints/get_quest_participants.rs +++ b/src/endpoints/get_quest_participants.rs @@ -52,44 +52,49 @@ pub async fn handler( } } }, + // First group by address to get the max timestamp for each participant doc! { "$group": { "_id": "$address", - "count" : { "$sum": 1 } + "count": { "$sum": 1 }, + "last_completion": { "$max": "$timestamp" } // Get the timestamp of their last task } }, + // Filter for participants who completed all tasks doc! { "$match": { "count": tasks_count as i64 } }, + // Sort by last completion time + doc! { + "$sort": { + "last_completion": 1 + } + }, doc! { "$facet": { - "count": [ - { - "$count": "count" - } + "total": [ + { "$count": "count" } ], - "firstParticipants": [ - { - "$limit": 3 - } + "participants": [ + { "$limit": 3 }, + { "$project": { + "address": "$_id", + "completion_time": "$last_completion", + "_id": 0 + }} ] } }, doc! { "$project": { - "count": { - "$arrayElemAt": [ - "$count.count", - 0 - ] - }, - "firstParticipants": "$firstParticipants._id" + "count": { "$ifNull": [{ "$arrayElemAt": ["$total.count", 0] }, 0] }, + "first_participants": "$participants" } - }, + } ]; - + let completed_tasks_collection = state.db.collection::("completed_tasks"); let mut cursor = completed_tasks_collection .aggregate(pipeline, None) @@ -106,5 +111,133 @@ pub async fn handler( } } - return (StatusCode::OK, Json(res)).into_response(); + (StatusCode::OK, Json(res)).into_response() } +#[cfg(test)] +mod tests { + use crate::{config::{self, Config}, logger}; + + use super::*; + use mongodb::{bson::doc, Client, Database}; + use reqwest::Url; + use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient}; + use tokio::sync::Mutex; + use std::sync::Arc; + use axum::{body::Bytes, http::StatusCode}; + use serde_json::Value; + use axum::body::HttpBody; + + async fn setup_test_db() -> Database { + let client = Client::with_uri_str("mongodb://localhost:27017") + .await + .expect("Failed to create MongoDB client"); + let db = client.database("test_db"); + + // Clear collections before each test + db.collection::("tasks").drop(None).await.ok(); + db.collection::("completed_tasks").drop(None).await.ok(); + + db + } + + async fn insert_test_data(db: Database, quest_id: i64, num_tasks: i64, num_participants: i64) { + let tasks_collection = db.collection::("tasks"); + let completed_tasks_collection = db.collection::("completed_tasks"); + + // Insert tasks + for task_id in 1..=num_tasks { + tasks_collection + .insert_one( + doc! { + "id": task_id, + "quest_id": quest_id, + }, + None, + ) + .await + .unwrap(); + } + + // Insert completed tasks for participants + for participant in 1..=num_participants { + let address = format!("participant_{}", participant); + let base_timestamp = 1000 - (participant * 10); // Spaces out timestamps more clearly + + for task_id in 1..=num_tasks { + completed_tasks_collection + .insert_one( + doc! { + "address": address.clone(), + "task_id": task_id, + // Last task for each participant will have the highest timestamp + "timestamp": base_timestamp + task_id + }, + None, + ) + .await + .unwrap(); + } + } + } + + #[tokio::test] + async fn test_get_quest_participants() { + // Setup + let db = setup_test_db().await; + let conf = config::load(); + let logger = logger::Logger::new(&conf.watchtower); + let provider= JsonRpcClient::new(HttpTransport::new( + Url::parse(&conf.variables.rpc_url).unwrap(), + )); + + let app_state = Arc::new(AppState { db: db.clone(), last_task_id: Mutex::new(0), last_question_id: Mutex::new(0), conf, logger, provider }); + + // Test data + let quest_id = 1; + let num_tasks = 3; + let num_participants = 5; + + insert_test_data(db.clone(), quest_id, num_tasks, num_participants).await; + + // Create request + let query = GetQuestParticipantsQuery { + quest_id: + quest_id as u32, + }; + + // Execute request + let response = handler( + State(app_state), + Query(query), + ) + .await + .into_response(); + + // Verify response + assert_eq!(response.status(), StatusCode::OK); + + + // Get the response body as bytes + let body_bytes = match response.into_body().data().await { + Some(Ok(bytes)) => bytes, + _ => panic!("Failed to get response body"), + }; + + // Parse the body + let body: Value = serde_json::from_slice(&body_bytes).unwrap(); + + assert_eq!(body["count"], num_participants); + assert_eq!(body["first_participants"].as_array().unwrap().len(), 3); + println!("{:?}", body); + + // Verify first participants format + let first_participants = body["first_participants"].as_array().unwrap(); + for participant in first_participants { + assert!(participant.as_str().unwrap().starts_with("participant_")); + } + + // Verify quest completion timestamp format + let quest_completion_timestamp = body["first_participants"][0]["completion_time"].as_i64().unwrap(); + assert_eq!(quest_completion_timestamp, 973); + } +} \ No newline at end of file diff --git a/src/tests/endpoints.rs b/src/tests/endpoints.rs index 1ed2a89..9ad02c7 100644 --- a/src/tests/endpoints.rs +++ b/src/tests/endpoints.rs @@ -26,4 +26,8 @@ pub mod tests { let response = client.get(endpoint).send().await.unwrap(); assert_eq!(response.status(), StatusCode::OK); } + + // #[tokio::test] + // pub async fn test_get_quest_participants() { + // let endpoint = format!("http:// }