Skip to content

Commit

Permalink
add macros for generating payload structs, use them in generated task…
Browse files Browse the repository at this point in the history
…s controller
  • Loading branch information
hdoordt committed Oct 29, 2024
1 parent c1c89a8 commit bddbd86
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 22 deletions.
2 changes: 1 addition & 1 deletion blueprint/db/src/entities/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct Task {
/// ```
/// let task_changeset: TaskChangeset = Faker.fake();
/// ```
#[derive(Deserialize, Validate, Clone)]
#[derive(Debug, Deserialize, Validate, Clone)]
#[cfg_attr(feature = "test-helpers", derive(Serialize, Dummy))]
pub struct TaskChangeset {
/// The description must be at least 1 character long.
Expand Down
129 changes: 126 additions & 3 deletions blueprint/macros/src/lib.rs.liquid
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! The {{crate_name}}-macros crate contains the `test`{%- unless template_type == "minimal" %} and `db_test`{%- endunless %} macro{%- unless template_type == "minimal" -%} s{% endunless -%}.

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, Fields, Ident, ItemFn, ItemStruct, Type};

#[allow(clippy::test_attr_in_doctest)]
/// Used to mark an application test.
Expand Down Expand Up @@ -110,4 +110,127 @@ pub fn db_test(_: TokenStream, item: TokenStream) -> TokenStream {

TokenStream::from(output)
}
{%- endunless %}
{%- endunless %}

#[proc_macro_attribute]
pub fn request_payload(_: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let PayloadStructInfo {
outer_ty,
inner_ty,
inner_ty_lit_str,
} = PayloadStructInfo::from_input(&input);

TokenStream::from(quote! {
#[derive(::serde::Deserialize)]
#[serde(try_from = #inner_ty_lit_str)]
#input

impl TryFrom<#inner_ty> for #outer_ty {
type Error = ::validator::ValidationErrors;

fn try_from(inner: #inner_ty) -> Result<Self, Self::Error> {
::validator::Validate::validate(&inner)?;
Ok(Self(inner))
}
}

impl From<#outer_ty> for #inner_ty {
fn from(#outer_ty(inner): #outer_ty) -> Self {
inner
}
}
})
}

#[proc_macro_attribute]
pub fn batch_request_payload(_: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let PayloadStructInfo {
outer_ty,
inner_ty,
inner_ty_lit_str,
} = PayloadStructInfo::from_input(&input);

TokenStream::from(quote! {
#[derive(::serde::Deserialize)]
#[serde(try_from = #inner_ty_lit_str)]
#input

impl TryFrom<#inner_ty> for #outer_ty {
type Error = ::validator::ValidationErrors;

fn try_from(inner: #inner_ty) -> Result<Self, Self::Error> {
let cap = inner.len();

inner
.into_iter()
.try_fold(Vec::with_capacity(cap), |mut v, item| {
::validator::Validate::validate(&item)?;
v.push(item);
Ok(v)
})
.map(Self)
}
}

impl From<#outer_ty> for #inner_ty {
fn from(#outer_ty(inner): #outer_ty) -> Self {
inner
}
}
})
}

#[proc_macro_attribute]
pub fn response_payload(_: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemStruct);
let PayloadStructInfo {
outer_ty,
inner_ty,
inner_ty_lit_str,
} = PayloadStructInfo::from_input(&input);

TokenStream::from(quote! {
#[derive(::serde::Serialize)]
#[serde(try_from = #inner_ty_lit_str)]
#input

impl From<#inner_ty> for #outer_ty {
fn from(inner: #inner_ty) -> Self {
Self(inner)
}
}
})
}

struct PayloadStructInfo<'input> {
outer_ty: &'input Ident,
inner_ty: &'input Type,
inner_ty_lit_str: String,
}

impl<'input> PayloadStructInfo<'input> {
fn from_input(input: &'input ItemStruct) -> Self {
fn error() -> ! {
panic!("Macro can only be applied to tuple structs with a single field")
}

let outer_ty = &input.ident;

let Fields::Unnamed(fields) = &input.fields else {
error()
};
let mut fields = fields.unnamed.iter();
let Some(field) = fields.next() else { error() };
let None = fields.next() else { error() };

let inner_ty = &field.ty;
let inner_ty_lit_str = inner_ty.clone().to_token_stream().to_string();
Self {
outer_ty,
inner_ty,
inner_ty_lit_str,
}
}
}
8 changes: 3 additions & 5 deletions blueprint/web/Cargo.toml.liquid
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ publish = false
doctest = false

[features]
test-helpers = ["dep:serde_json", "dep:tower", "dep:hyper", "dep:{{project-name}}-macros"]
test-helpers = ["dep:serde_json", "dep:tower", "dep:hyper"]

[dependencies]
anyhow = "1.0"
Expand All @@ -31,10 +31,8 @@ serde_json = { version = "1.0", optional = true }
thiserror = "1.0"
tower = { version = "0.5", features = ["util"], optional = true }
hyper = { version = "1.0", features = ["full"], optional = true }
{% unless template_type == "minimal" -%}
validator = "0.18"
{%- endunless %}
{{project-name}}-macros = { path = "../macros", optional = true }
{{project-name}}-macros = { path = "../macros" }
validator = { version = "0.18.1", features = ["derive"] }

[dev-dependencies]
fake = "2.9"
Expand Down
56 changes: 43 additions & 13 deletions blueprint/web/src/controllers/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{error::Error, state::AppState};
use axum::{extract::Path, extract::State, http::StatusCode, Json};
use {{crate_name}}_db::{entities::tasks, transaction};
use payloads::*;
use tracing::info;
use uuid::Uuid;

Expand All @@ -10,34 +11,34 @@ use uuid::Uuid;
#[axum::debug_handler]
pub async fn create(
State(app_state): State<AppState>,
Json(task): Json<tasks::TaskChangeset>,
) -> Result<(StatusCode, Json<tasks::Task>), Error> {
Ok(tasks::create(task, &app_state.db_pool)
Json(payload): Json<CreateRequestPayload>,
) -> Result<(StatusCode, Json<CreateResponsePayload>), Error> {
Ok(tasks::create(payload.into(), &app_state.db_pool)
.await
.map(|task| (StatusCode::CREATED, Json(task)))?)
.map(|task| (StatusCode::CREATED, Json(task.into())))?)
}

/// Creates multiple tasks in the database.
///
/// This function creates multiple tasks in the database (see [`{{crate_name}}_db::entities::tasks::create`]) based on [`{{crate_name}}_db::entities::tasks::TaskChangeset`]s (sent as JSON). If all tasks are created successfully, a 201 response is returned with the created [`{{crate_name}}_db::entities::tasks::Task`]s' JSON representation in the response body. If any of the passed changesets is invalid, a 422 response is returned.
/// This function creates multiple tasks in the database (see [`getest_db::entities::tasks::create`]) based on [`getest_db::entities::tasks::TaskChangeset`]s (sent as JSON). If all tasks are created successfully, a 201 response is returned with the created [`getest_db::entities::tasks::Task`]s' JSON representation in the response body. If any of the passed changesets is invalid, a 422 response is returned.
///
/// This function creates all tasks in a transaction so that either all are created successfully or none is.
#[axum::debug_handler]
pub async fn create_batch(
State(app_state): State<AppState>,
Json(tasks): Json<Vec<tasks::TaskChangeset>>,
) -> Result<(StatusCode, Json<Vec<tasks::Task>>), Error> {
Json(payload): Json<CreateBatchRequestPayload>,
) -> Result<(StatusCode, Json<CreateBatchResponsePayload>), Error> {
let mut tx = transaction(&app_state.db_pool).await?;

let mut results: Vec<tasks::Task> = vec![];
for task in tasks {
for task in Vec::<_>::from(payload) {
let task = tasks::create(task, &mut *tx).await?;
results.push(task);
}

tx.commit().await.map_err(anyhow::Error::from)?;

Ok((StatusCode::CREATED, Json(results)))
Ok((StatusCode::CREATED, Json(results.into())))
}

/// Reads and responds with all the tasks currently present in the database.
Expand Down Expand Up @@ -71,10 +72,10 @@ pub async fn read_one(
pub async fn update(
State(app_state): State<AppState>,
Path(id): Path<Uuid>,
Json(task): Json<tasks::TaskChangeset>,
) -> Result<Json<tasks::Task>, Error> {
let task = tasks::update(id, task, &app_state.db_pool).await?;
Ok(Json(task))
Json(payload): Json<UpdateRequestPayload>,
) -> Result<Json<UpdateResponsePayload>, Error> {
let task = tasks::update(id, payload.into(), &app_state.db_pool).await?;
Ok(Json(task.into()))
}

/// Deletes a task identified by its ID from the database.
Expand All @@ -88,3 +89,32 @@ pub async fn delete(
tasks::delete(id, &app_state.db_pool).await?;
Ok(StatusCode::NO_CONTENT)
}

mod payloads {
use {{crate_name}}_db::entities::tasks::{Task, TaskChangeset};
use {{crate_name}}_macros::{batch_request_payload, request_payload, response_payload};

#[derive(Debug)]
#[request_payload]
pub struct CreateRequestPayload(TaskChangeset);

#[derive(Debug)]
#[response_payload]
pub struct CreateResponsePayload(Task);

#[derive(Debug)]
#[batch_request_payload]
pub struct CreateBatchRequestPayload(Vec<TaskChangeset>);

#[derive(Debug)]
#[response_payload]
pub struct CreateBatchResponsePayload(Vec<Task>);

#[derive(Debug)]
#[request_payload]
pub struct UpdateRequestPayload(TaskChangeset);

#[derive(Debug)]
#[response_payload]
pub struct UpdateResponsePayload(Task);
}

0 comments on commit bddbd86

Please sign in to comment.