Skip to content

Commit

Permalink
Use axum instead of actix
Browse files Browse the repository at this point in the history
  • Loading branch information
vehlwn committed Dec 29, 2024
1 parent 464792b commit 498dac8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 57 deletions.
20 changes: 10 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ lto = true
strip = true

[dependencies]
actix-cors = "0.7.0"
actix-web = "4.7.0"
anyhow = "1.0.86"
clap = { version = "4.5.4", features = ["derive"] }
env_logger = "0.11.3"
log = "0.4.21"
mime = "0.3.17"
anyhow = "1.0.95"
axum = "0.7.9"
clap = { version = "4.5.23", features = ["derive"] }
env_logger = "0.11.6"
log = "0.4.22"
pleco = "0.5.0"
rand = "0.8.5"
rayon = "1.10.0"
regex = "1.10.4"
serde = { version = "1.0.203", features = ["derive"] }
systemd-journal-logger = "2.1.1"
regex = "1.11.1"
serde = { version = "1.0", features = ["derive"] }
systemd-journal-logger = "2.2.0"
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread", "net"] }
tower-http = { version = "0.6.2", features = ["cors"] }
91 changes: 44 additions & 47 deletions src/bin/chess-alpha-beta-server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use actix_cors::Cors;
use actix_web::{http::header, post, web, App, HttpResponse, HttpServer, Responder};
use anyhow::Context;

use axum::extract::Json;
use axum::http::header;
use axum::http::{Method, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use tower_http::cors::{AllowOrigin, CorsLayer};

use chess_alpha_beta::alpha_beta::{get_best_move, ValueType};

#[derive(serde::Deserialize)]
Expand All @@ -16,30 +21,33 @@ struct GetBestMoveResponse {
value: ValueType,
}

#[post("/get_best_move")]
async fn api_get_best_move(json: web::Json<GetBestMoveRequest>) -> impl Responder {
async fn api_get_best_move(Json(json): Json<GetBestMoveRequest>) -> Response {
if json.search_depth.get() > 10 {
return HttpResponse::BadRequest()
.content_type(mime::TEXT_PLAIN_UTF_8)
.body("search_depth is too large!");
return (StatusCode::BAD_REQUEST, "search_depth is too large!")
.into_response();
}

let board = match pleco::Board::from_fen(&json.fen) {
Ok(x) => x,
Err(e) => {
return HttpResponse::BadRequest()
.content_type(mime::TEXT_PLAIN_UTF_8)
.body(format!("Failed to parse FEN: {e:?}"));
return (
StatusCode::BAD_REQUEST,
format!("Failed to parse FEN: {e:?}"),
)
.into_response();
}
};
return match get_best_move(&board, json.search_depth) {
Ok(ok) => HttpResponse::Ok().json(GetBestMoveResponse {
Ok(ok) => axum::response::Json(GetBestMoveResponse {
m: ok.m.to_string(),
value: ok.value,
}),
Err(e) => HttpResponse::BadRequest()
.content_type(mime::TEXT_PLAIN_UTF_8)
.body(format!("get_best_move failed: {e}")),
})
.into_response(),
Err(e) => (
StatusCode::BAD_REQUEST,
format!("get_best_move failed: {e}"),
)
.into_response(),
};
}

Expand Down Expand Up @@ -74,41 +82,30 @@ fn init_logging(args: &Args) -> anyhow::Result<()> {
return Ok(());
}

#[actix_web::main]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
use clap::Parser;
let args = Args::parse();
init_logging(&args).context("Failed to init logging")?;

let server = HttpServer::new(|| {
App::new()
.wrap(actix_web::middleware::Logger::default())
.wrap(
// default settings are overly restrictive to reduce chance of
// misconfiguration leading to security concerns
Cors::default()
// allow any port on localhost
.allowed_origin_fn(|origin, _req_head| {
origin.as_bytes().starts_with(b"http://localhost")
})
// set allowed methods list
.allowed_methods(vec!["GET", "POST"])
// set allowed request header list
.allowed_headers(&[header::AUTHORIZATION, header::ACCEPT])
// add header to allowed list
.allowed_header(header::CONTENT_TYPE)
// set list of headers that are safe to expose
.expose_headers(&[header::CONTENT_DISPOSITION])
// allow cURL/HTTPie from working without providing Origin headers
.block_on_origin_mismatch(false)
// set preflight cache TTL
.max_age(3600),
)
.service(web::scope("/api").service(api_get_best_move))
.route("/healthy", web::get().to(HttpResponse::Ok))
})
.bind(args.bind_addr)
.with_context(|| format!("Failed to bind HTTP server to {}", args.bind_addr))?;
log::info!("Server listening {:?}", server.addrs());
return server.run().await.context("Failed to run server");
let app = axum::Router::new()
.route("/api/get_best_move", post(api_get_best_move))
.route("/healthy", get(|| std::future::ready("ok")))
.layer(
CorsLayer::new()
.allow_origin(AllowOrigin::predicate(|origin, _request_parts| {
origin.as_bytes().starts_with(b"http://localhost")
}))
.allow_methods([Method::GET, Method::POST])
.allow_headers([header::CONTENT_TYPE]),
);

let listener = tokio::net::TcpListener::bind(&args.bind_addr)
.await
.with_context(|| {
format!("Failed to bind HTTP server to '{}'", args.bind_addr)
})?;
log::info!("Server listening {:?}", listener.local_addr().unwrap());
axum::serve(listener, app).await?;
return Ok(());
}

0 comments on commit 498dac8

Please sign in to comment.