diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..94e9d3a0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +# 0.2.0 + +> WIP + +# 0.1.0 + +> Release date: `2022-09-01` + +Reserve the name `rama` on crates.io and +start the R&D and design work in Rust of this project. diff --git a/Cargo.lock b/Cargo.lock index 8393d288..b724b73c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,53 +60,59 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] -name = "anstyle" -version = "1.0.7" +name = "anstream" +version = "0.6.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] [[package]] -name = "anyhow" -version = "1.0.86" +name = "anstyle" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" [[package]] -name = "arbitrary" -version = "1.3.2" +name = "anstyle-parse" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +dependencies = [ + "utf8parse", +] [[package]] -name = "argh" -version = "0.1.12" +name = "anstyle-query" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219" +checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5" dependencies = [ - "argh_derive", - "argh_shared", + "windows-sys 0.52.0", ] [[package]] -name = "argh_derive" -version = "0.1.12" +name = "anstyle-wincon" +version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a" +checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" dependencies = [ - "argh_shared", - "proc-macro2", - "quote", - "syn", + "anstyle", + "windows-sys 0.52.0", ] [[package]] -name = "argh_shared" -version = "0.1.12" +name = "arbitrary" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531" -dependencies = [ - "serde", -] +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" [[package]] name = "async-compression" @@ -280,6 +286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -288,23 +295,63 @@ version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", "terminal_size", ] +[[package]] +name = "clap_derive" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "clap_lex" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "colorchoice" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" + [[package]] name = "condtype" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" +[[package]] +name = "const_format" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a214c7af3d04997541b18d432afaff4c455e79e2029079647e72fc2bd27673" +dependencies = [ + "const_format_proc_macros", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f6ff08fd20f4f299298a28e2dfa8a8ba1036e6cd2460ac1de7b425d76f2500" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -657,12 +704,24 @@ dependencies = [ "http", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "1.1.0" @@ -777,6 +836,12 @@ dependencies = [ "serde", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" + [[package]] name = "itertools" version = "0.13.0" @@ -1185,13 +1250,14 @@ checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" [[package]] name = "rama" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ "async-compression", "base64 0.22.1", "bitflags", "brotli", "bytes", + "const_format", "divan", "escargot", "flate2", @@ -1203,7 +1269,6 @@ dependencies = [ "http-body", "http-body-util", "http-range-header", - "httparse", "httpdate", "hyper", "hyper-util", @@ -1251,24 +1316,27 @@ dependencies = [ [[package]] name = "rama-cli" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ - "anyhow", - "argh", + "bytes", + "clap", + "hex", "rama", + "serde_json", + "terminal-prompt", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] name = "rama-fp" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ - "anyhow", - "argh", "base64 0.22.1", + "clap", "rama", "serde", - "serde_html_form", "serde_json", "tokio", "tracing", @@ -1285,7 +1353,7 @@ dependencies = [ [[package]] name = "rama-macros" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ "proc-macro2", "quote", @@ -1654,6 +1722,12 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.5.0" @@ -1704,6 +1778,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminal-prompt" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "572818b3472910acbd5dff46a3413715c18e934b071ab2ba464a7b2c2af16376" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "terminal_size" version = "0.3.0" @@ -1780,9 +1864,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -1812,9 +1896,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", @@ -2014,12 +2098,24 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "uuid" version = "1.8.0" diff --git a/Cargo.toml b/Cargo.toml index a914033a..f4a56d76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = [".", "fuzz", "rama-cli", "rama-fp", "rama-macros"] [workspace.package] -version = "0.2.0" +version = "0.2.0-alpha.0" license = "MIT OR Apache-2.0" edition = "2021" repository = "https://github.com/plabayo/rama" @@ -13,24 +13,23 @@ authors = ["Glen De Cauwsemaecker "] rust-version = "1.75.0" [workspace.dependencies] -anyhow = "1.0" async-compression = "0.4" base64 = "0.22" bitflags = "2.4" brotli = "6" bytes = "1" -argh = "0.1" +clap = { version = "4.5.4", features = ["derive"] } crossterm = "0.27" flate2 = "1.0" futures-lite = "2.3.0" futures-core = "0.3" h2 = "0.4" headers = "0.4" +hex = "0.4" http = "1" http-body = "1" http-body-util = "0.1" http-range-header = "0.4.0" -httparse = "1.8" httpdate = "1.0" hyper = "1.2" hyper-util = "0.1.4" @@ -79,7 +78,9 @@ iri-string = "0.7.0" escargot = "0.5.10" divan = "0.1.14" webpki-roots = "0.26.1" +terminal-prompt = "0.2.3" parking_lot = "0.12.3" +const_format = "0.2.32" [package] name = "rama" @@ -110,6 +111,7 @@ async-compression = { workspace = true, features = ["tokio", "brotli", "zlib", " base64 = { workspace = true } bitflags = { workspace = true } bytes = { workspace = true } +const_format = { workspace = true } futures-core = { workspace = true } futures-lite = { workspace = true } h2 = { workspace = true } @@ -118,7 +120,6 @@ http = { workspace = true } http-body = { workspace = true } http-body-util = { workspace = true } http-range-header = { workspace = true } -httparse = { workspace = true } httpdate = { workspace = true } hyper = { workspace = true, features = ["http1", "http2", "server", "client"] } hyper-util = { workspace = true, features = ["tokio", "server-auto"] } @@ -147,7 +148,7 @@ serde = { workspace = true, features = ["derive"] } serde_html_form = { workspace = true } serde_json = { workspace = true } sync_wrapper = { workspace = true } -tokio = { workspace = true, features = ["macros", "fs"] } +tokio = { workspace = true, features = ["macros", "fs", "io-std"] } tokio-graceful = { workspace = true } tokio-rustls = { workspace = true } tokio-util = { workspace = true } diff --git a/README.md b/README.md index 4b691145..cdb34228 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ This framework comes with πŸ”‹ batteries included, giving you the full freedome | πŸ—οΈ [User Agent (UA)](https://ramaproxy.org/book/intro/user_agent) | πŸ—οΈ Http Emulation (1) βΈ± πŸ—οΈ Tls Emulation (1) βΈ± βœ… [UA Parsing](https://ramaproxy.org/docs/rama/ua/struct.UserAgent.html) | | πŸ—οΈ utilities | βœ… [error handling](https://ramaproxy.org/docs/rama/error/index.html) βΈ± βœ… [graceful shutdown](https://ramaproxy.org/docs/rama/utils/graceful/index.html) βΈ± πŸ—οΈ Connection Pool (1) βΈ± πŸ—οΈ IP2Loc (2) | | πŸ—οΈ [TUI](https://ratatui.rs/) | πŸ—οΈ traffic logger (2) βΈ± πŸ—οΈ curl export (2) βΈ± ❌ traffic intercept (3) βΈ± ❌ traffic replay (3) | -| πŸ—οΈ binary | πŸ—οΈ prebuilt binaries (1) βΈ± πŸ—οΈ proxy config (2) βΈ± πŸ—οΈ http client (1) βΈ± ❌ WASM Plugins (3) | +| βœ… binary | βœ… [prebuilt binaries](https://ramaproxy.org/book/binary/rama) βΈ± πŸ—οΈ proxy config (2) βΈ± βœ… http client βΈ± ❌ WASM Plugins (3) | | πŸ—οΈ data scraping | πŸ—οΈ Html Processor (2) βΈ± ❌ Json Processor (3) | | ❌ browser | ❌ JS Engine (3) βΈ± ❌ [Web API](https://developer.mozilla.org/en-US/docs/Web/API) Emulation (3) | @@ -125,6 +125,16 @@ rama = { git = "https://github.com/plabayo/rama" } πŸ’¬ Come join us at [Discord][discord-url] on the `#rama` public channel. To ask questions, discuss ideas and ask how rama may be useful for you. +## ⌨️ | `rama` binary + +The `rama` binary allows you to use a lot of what `rama` has to offer without +having to code yourself. It comes with a working http client for CLI, which emulates +User-Agents and has other utilities. And it also comes with IP/Echo services. + +It also allows you to run a `rama` proxy, configured to your needs. + +Learn more about the `rama` binary and how to install it at . + ## πŸ§ͺ | Experimental πŸ¦™ Rama (γƒ©γƒž) is to be considered experimental software for the foreseeable future. In the meanwhile it is already used diff --git a/docs/book/src/SUMMARY.md b/docs/book/src/SUMMARY.md index 85cf6cf4..23a2d5ba 100644 --- a/docs/book/src/SUMMARY.md +++ b/docs/book/src/SUMMARY.md @@ -31,5 +31,11 @@ - [πŸ”Ž MITM proxies](./proxies/mitm.md) - [πŸ•΅οΈβ€β™€οΈ Distortion proxies](./proxies/distort.md) -[❓ FAQ](./faq.md) -[πŸ’– Sponsor](./sponsor.md) +# Binary + +- [⌨️ `rama` binary](./binary/rama.md) + +# Appendices + +- [❓ FAQ](./faq.md) +- [πŸ’– Sponsor](./sponsor.md) diff --git a/docs/book/src/binary/rama.md b/docs/book/src/binary/rama.md new file mode 100644 index 00000000..c4a57f68 --- /dev/null +++ b/docs/book/src/binary/rama.md @@ -0,0 +1,53 @@ +# ⌨️ `rama` binary + +The `rama` binary allows you to use a lot of what `rama` has to offer without +having to code yourself. It comes with a working http client for CLI, which emulates +User-Agents and has other utilities. And it also comes with IP/Echo services. + +It also allows you to run a `rama` proxy, configured to your needs. + +## Usage + +```text +rama cli to move and transform network packets + +Usage: rama + +Commands: + echo rama echo service (echos the http request and tls client config) + http rama http client + proxy rama proxy runner + ip rama ip service (returns the ip address of the client) + help Print this message or the help of the given subcommand(s) + +Options: + -h, --help Print help + -V, --version Print version +``` + +## Install + +> ❗ None of these install instructions work at the moment, +> as we still need to release a first alpha version of `rama` to make this work. +> These instructions are for now just preparation towards that. + +The easiest way to install `rama` is by using `cargo`: + +```sh +cargo install rama-cli +``` + +This will install `rama-cli` from source and make it available +under your cargo _bin_ folder as `rama`. In case you want to install +a pre-built binary when available for your platform you can do so +using [`cargo binstall`](https://github.com/cargo-bins/cargo-binstall): + +```sh +cargo binstall rama-cli +``` + +On 🍎 MacOS you can also install the `rama` binary using [HomeBrew](https://brew.sh/): + +``` +brew install rama +``` diff --git a/docs/book/src/faq.md b/docs/book/src/faq.md index 3f3f26dd..3e452ed4 100644 --- a/docs/book/src/faq.md +++ b/docs/book/src/faq.md @@ -106,3 +106,33 @@ Most commonly you might get this error, especially the difficult ones, for high - return a Result as the output of an `Endpoint` service/fn (when using the `WebService` router), instead of only returning the happy path value; There are other possibilities to get long wielded compiler errors as well. It is not feasible to list all possible reasons here, but know most likely it is among the lines of the examples above. If not, and you continue to be stuck, to feel free to join our discord at and reach out for help. We're here for you. + +## my cargo check/build/... commands take forever + +[Service stacks](./intro/service_stack.md) can become quiet complex in Rama. In case you notice that your current change +makes the `cargo check` command (or something similar) becomes very slow, it should hopefully be clear +why by checking `git diff` or a similar VCS action. + +The most common reasons for this is if: + +1. you have a very large function which also contains deeply nested generic types; +2. you have a lot of [`Either`] service/layer stuff within your [Service stacks](./intro/service_stack.md). + +It's especially (2) that can slow you down if you overuse it. This usually comes op in case you use +plenty of `Option>` code to optionally create a layer based on a certain input/config variable. +While this might seem like a good idea, and it can be if used sparsly, it can really slow you down once you +use a couple of these. This is because under the hood this results in `Either`, meaning your +`S` service (stack) will be twice in that signature. Do that a couple of times and you very quickly have a very long long type. + +Therefore it is recommended for optional layers/services to instead provide an option to create the same kind of layer/service +type, but in a "nop" mode. Meaning the (middleware) service would essentially do nothing more then passing the request and response. + +Middleware provided by `rama` should provide this for all types that are commonly used in a setting where they might be opt-in. +Please do [open an issue](https://github.com/plabayo/rama/issues) if you notice a case for which this is not yet possible. + +Another option is to use [`Either`] on the internal policy/config items used by your layer. +[`follow_redirect::policy::Unlimited`](https://ramaproxy.org/docs/rama/http/layer/follow_redirect/policy/struct.Unlimited.html) is an example +of this, to allow you to have a `redirect` layer which is either limited or not. This is fine, +because your `Either` has only a depth of one, in contrast to having it contain the entire inner "service stack". + +[`Either`]: https://ramaproxy.org/docs/rama/service/util/combinators/enum.Either.html diff --git a/docs/book/src/preface.md b/docs/book/src/preface.md index 0e513cbc..79bb367b 100644 --- a/docs/book/src/preface.md +++ b/docs/book/src/preface.md @@ -53,7 +53,7 @@ This framework comes with πŸ”‹ batteries included, giving you the full freedome | πŸ—οΈ [User Agent (UA)](https://ramaproxy.org/book/intro/user_agent) | πŸ—οΈ Http Emulation (1) βΈ± πŸ—οΈ Tls Emulation (1) βΈ± βœ… [UA Parsing](https://ramaproxy.org/docs/rama/ua/struct.UserAgent.html) | | πŸ—οΈ utilities | βœ… [error handling](https://ramaproxy.org/docs/rama/error/index.html) βΈ± βœ… [graceful shutdown](https://ramaproxy.org/docs/rama/utils/graceful/index.html) βΈ± πŸ—οΈ Connection Pool (1) βΈ± πŸ—οΈ IP2Loc (2) | | πŸ—οΈ [TUI](https://ratatui.rs/) | πŸ—οΈ traffic logger (2) βΈ± πŸ—οΈ curl export (2) βΈ± ❌ traffic intercept (3) βΈ± ❌ traffic replay (3) | -| πŸ—οΈ binary | πŸ—οΈ prebuilt binaries (1) βΈ± πŸ—οΈ proxy config (2) βΈ± πŸ—οΈ http client (1) βΈ± ❌ WASM Plugins (3) | +| βœ… binary | βœ… [prebuilt binaries](https://ramaproxy.org/book/binary/rama) βΈ± πŸ—οΈ proxy config (2) βΈ± βœ… http client βΈ± ❌ WASM Plugins (3) | | πŸ—οΈ data scraping | πŸ—οΈ Html Processor (2) βΈ± ❌ Json Processor (3) | | ❌ browser | ❌ JS Engine (3) βΈ± ❌ [Web API](https://developer.mozilla.org/en-US/docs/Web/API) Emulation (3) | @@ -115,6 +115,16 @@ to know how to use rama for your purposes. πŸ’– Please consider becoming [a sponsor][ghs-url] if you critically depend upon Rama (γƒ©γƒž) or if you are a fan of the project. +## ⌨️ | `rama` binary + +The `rama` binary allows you to use a lot of what `rama` has to offer without +having to code yourself. It comes with a working http client for CLI, which emulates +User-Agents and has other utilities. And it also comes with IP/Echo services. + +It also allows you to run a `rama` proxy, configured to your needs. + +Learn more about the `rama` binary and how to install it at [/binary/rama.md](./binary/rama.md). + ## πŸ§ͺ | Experimental πŸ¦™ Rama (γƒ©γƒž) is to be considered experimental software for the foreseeable future. In the meanwhile it is already used diff --git a/justfile b/justfile index 07f407ca..6e647458 100644 --- a/justfile +++ b/justfile @@ -91,3 +91,15 @@ vet: miri: cargo +nightly miri test + +detect-unused-deps: + cargo machete --skip-target-dir + +detect-biggest-fn: + cargo bloat --package rama-cli --release -n 10 + +detect-biggest-crates: + cargo bloat --package rama-cli --release --crates + +mdbook-serve: + cd docs/book && mdbook serve diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 5e77d56f..8f6ff88d 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rama-cli" -description = "binary version of and cli utility for rama, a modular service framework" +description = "rama cli to move and transform network packets" version = { workspace = true } license = { workspace = true } edition = { workspace = true } @@ -12,10 +12,15 @@ rust-version = { workspace = true } default-run = "rama" [dependencies] -anyhow = { workspace = true } -argh = { workspace = true } -rama = { version = "0.2", path = ".." } +bytes = { workspace = true } +clap = { workspace = true } +hex = { workspace = true } +rama = { version = "0.2.0-alpha.0", path = ".." } +serde_json = { workspace = true } +terminal-prompt = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } [[bin]] name = "rama" diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs new file mode 100644 index 00000000..5af03f92 --- /dev/null +++ b/rama-cli/src/echo/mod.rs @@ -0,0 +1,187 @@ +use clap::Args; +use rama::{ + error::BoxError, + http::{ + dep::http_body_util::BodyExt, + layer::{required_header::AddRequiredResponseHeadersLayer, trace::TraceLayer}, + response::Json, + server::HttpServer, + IntoResponse, Request, RequestContext, Response, + }, + proxy::pp::server::HaProxyLayer, + rt::Executor, + service::{ + layer::{limit::policy::ConcurrentPolicy, LimitLayer, TimeoutLayer}, + Context, ServiceBuilder, + }, + stream::{layer::http::BodyLimitLayer, SocketInfo}, + tcp::server::TcpListener, + tls::rustls::server::IncomingClientHello, + ua::{UserAgent, UserAgentClassifierLayer}, +}; +use serde_json::json; +use std::{convert::Infallible, time::Duration}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(Debug, Args)] +/// rama echo service (echos the http request and tls client config) +pub struct CliCommandEcho { + #[arg(short = 'p', long, default_value_t = 8080)] + /// the port to listen on + port: u16, + + #[arg(short = 'i', long, default_value = "127.0.0.1")] + /// the interface to listen on + interface: String, + + #[arg(short = 'c', long, default_value_t = 0)] + /// the number of concurrent connections to allow (0 = no limit) + concurrent: usize, + + #[arg(short = 't', long, default_value_t = 8)] + /// the timeout in seconds for each connection (0 = no timeout) + timeout: u64, + + #[arg(short = 'a', long)] + /// enable HaProxy PROXY Protocol + ha_proxy: bool, +} + +pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let graceful = rama::utils::graceful::Shutdown::default(); + + let address = format!("{}:{}", cfg.interface, cfg.port); + tracing::info!("starting echo service on: {}", address); + + graceful.spawn_task_fn(move |guard| async move { + let tcp_listener = TcpListener::build() + .bind(address) + .await + .expect("bind echo service to 127.0.0.1:62001"); + + let tcp_service_builder = ServiceBuilder::new() + .layer( + (cfg.concurrent > 0) + .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), + ) + .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))) + .layer((cfg.ha_proxy).then(HaProxyLayer::default)) + // Limit the body size to 1MB for requests + .layer(BodyLimitLayer::request_only(1024 * 1024)); + + // TODO: support opt-in TLS + + // TODO document how one would force IPv4 or IPv6 + + let http_service = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(AddRequiredResponseHeadersLayer::default()) + .layer(UserAgentClassifierLayer::new()) + .service_fn(echo); + + let tcp_service = tcp_service_builder + .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + + tracing::info!("echo service ready"); + + tcp_listener.serve_graceful(guard, tcp_service).await; + }); + + graceful + .shutdown_with_limit(Duration::from_secs(30)) + .await?; + + Ok(()) +} + +pub async fn echo(ctx: Context, req: Request) -> Result { + let user_agent_info = ctx + .get() + .map(|ua: &UserAgent| { + json!({ + "user_agent": ua.header_str().to_owned(), + "kind": ua.info().map(|info| info.kind.to_string()), + "version": ua.info().and_then(|info| info.version), + "platform": ua.platform().map(|v| v.to_string()), + }) + }) + .unwrap_or_default(); + + let authority = ctx + .get::() + .and_then(RequestContext::authority); + + // TODO: get in correct order + // TODO: get in correct case + // TODO: get also pseudo headers (or separate?!) + + let headers: Vec<_> = req + .headers() + .iter() + .map(|(name, value)| { + ( + name.as_str().to_owned(), + value.to_str().map(|v| v.to_owned()).unwrap_or_default(), + ) + }) + .collect(); + + let (parts, body) = req.into_parts(); + + let body = body.collect().await.unwrap().to_bytes(); + let body = hex::encode(body.as_ref()); + + let tls_client_hello = ctx.get::().map(|hello| { + json!({ + "server_name": hello.server_name.clone(), + "signature_schemes": hello + .signature_schemes + .iter() + .map(|v| format!("{:?}", v)) + .collect::>(), + "alpn": hello.alpn.clone(), + "cipher_suites": hello + .cipher_suites + .iter() + .map(|v| format!("{:?}", v)) + .collect::>(), + }) + }); + + Ok(Json(json!({ + "ua": user_agent_info, + "http": { + "version": format!("{:?}", parts.version), + "scheme": parts.uri + .scheme_str() + .map(|v| v.to_owned()) + .unwrap_or_else(|| { + if ctx.get::().is_some() { + "https" + } else { + "http" + } + .to_owned() + }), + "method": format!("{:?}", parts.method), + "authority": authority, + "path": parts.uri.path().to_string(), + "query": parts.uri.query().map(str::to_owned), + "headers": headers, + "payload": body, + }, + "tls": tls_client_hello, + "ip": ctx.get::().map(|v| v.peer_addr().to_string()), + })) + .into_response()) +} diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs new file mode 100644 index 00000000..cdbf4d97 --- /dev/null +++ b/rama-cli/src/http/mod.rs @@ -0,0 +1,434 @@ +use clap::Args; +use rama::{ + cli::args::RequestArgsBuilder, + error::{error, BoxError, ErrorContext, OpaqueError}, + http::{ + client::HttpClient, + layer::{ + auth::AddAuthorizationLayer, + decompression::DecompressionLayer, + follow_redirect::{policy::Limited, FollowRedirectLayer}, + required_header::AddRequiredRequestHeadersLayer, + timeout::TimeoutLayer, + traffic_writer::WriterMode, + }, + IntoResponse, Request, Response, StatusCode, + }, + proxy::http::client::HttpProxyConnectorLayer, + rt::Executor, + service::{layer::HijackLayer, service_fn, Context, Service, ServiceBuilder}, + tcp::service::HttpConnector, + tls::rustls::client::HttpsConnectorLayer, + utils::graceful::{self, Shutdown, ShutdownGuard}, +}; +use std::{io::IsTerminal, time::Duration}; +use terminal_prompt::Terminal; +use tokio::sync::oneshot; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +mod tls; +mod writer; + +#[derive(Args, Debug, Clone)] +/// rama http client +pub struct CliCommandHttp { + #[arg(short = 'j', long)] + /// data items from the command line are serialized as a JSON object. + /// The `Content-Type` and `Accept headers` are set to `application/json` + /// (if not specified) + /// + /// (default) + json: bool, + + #[arg(short = 'f', long)] + /// data items from the command line are serialized as form fields. + /// + /// The `Content-Type` is set to `application/x-www-form-urlencoded` (if not specified). + form: bool, + + #[arg(short = 'F', long)] + /// follow 30 Location redirects + follow: bool, + + #[arg(long, default_value_t = 30)] + /// the maximum number of redirects to follow + max_redirects: usize, + + #[arg(long, short = 'a')] + /// client authentication: `USER[:PASS]` | TOKEN, + /// if basic and no password is given it will be promped + auth: Option, + + #[arg(long, short = 'A', default_value = "basic")] + /// the type of authentication to use (basic, bearer) + auth_type: String, + + #[arg(short = 'k', long)] + /// skip Tls certificate verification + insecure: bool, + + #[arg(long)] + /// the desired tls version to use (automatically defined by default, choices are: 1.2, 1.3) + tls: Option, + + #[arg(long)] + /// the client tls certificate file path to use + cert: Option, + + #[arg(long)] + /// the client tls key file path to use + cert_key: Option, + + #[arg(long, short = 't', default_value = "0")] + /// the timeout in seconds for each connection (0 = default timeout of 180s) + timeout: u64, + + #[arg(long)] + /// fail if status code is not 2xx (4 if 4xx and 5 if 5xx) + check_status: bool, + + #[arg(long, short = 'p')] + /// define what the output should contain ('h'/'H' for headers, 'b'/'B' for body (response/request) + print: Option, + + #[arg(short = 'b', long)] + /// print the response body (short for --print b) + body: bool, + + #[arg(short = 'H', long)] + /// print the response headers (short for --print h) + headers: bool, + + #[arg(short = 'v', long)] + /// print verbose output, alias for --all --print hHbB (not used in offline mode) + verbose: bool, + + #[arg(long)] + /// show output for all requests/responses (including redirects) + all: bool, + + #[arg(long)] + /// print the request instead of executing it + offline: bool, + + #[arg(long, short = 'o')] + /// write output to file instead of stdout + output: Option, + + #[arg(long)] + /// print debug info + debug: bool, + + #[arg(trailing_var_arg = true, allow_hyphen_values = true)] + /// positional arguments to populate request headers and body + /// + /// These arguments come after any flags and in the order they are listed here. + /// Only the URL is required. + /// + /// # METHOD + /// + /// The HTTP method to be used for the request (GET, POST, PUT, DELETE, ...). + /// + /// This argument can be omitted in which case HTTPie will use POST if there + /// is some data to be sent, otherwise GET: + /// + /// $ rama http example.org # => GET + /// + /// $ rama http example.org hello=world # => POST + /// + /// # URL + /// + /// The request URL. Scheme defaults to 'http://' if the URL + /// does not include one. + /// + /// You can also use a shorthand for localhost + /// + /// $ rama http :3000 # => http://localhost:3000 + /// + /// $ rama http :/foo # => http://localhost/foo + /// + /// # REQUEST_ITEM + /// + /// Optional key-value pairs to be included in the request. The separator used + /// determines the type: + /// + /// ':' HTTP headers: + /// + /// Referer:https://ramaproxy.org Cookie:foo=bar User-Agent:rama/0.2.0 + /// + /// '==' URL parameters to be appended to the request URI: + /// + /// search==rama + /// + /// '=' Data fields to be serialized into a JSON object or form data: + /// + /// name=rama language=Rust description='CLI HTTP client' + /// + /// ':=' Non-string data fields: + /// + /// awesome:=true amount:=42 colors:='["red", "green", "blue"]' + /// + /// You can use a backslash to escape a colliding separator in the field name: + /// + /// field-name-with\:colon=value + args: Vec, +} + +// TODO in future: +// - http sessions (e.g. cookies) +// - fix bug in body print (we seem to print garbage) +// - this might to do with fact that decompressor comes later + +pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive( + if cfg.debug { + if cfg.verbose { + LevelFilter::TRACE + } else { + LevelFilter::DEBUG + } + } else { + LevelFilter::ERROR + } + .into(), + ) + .from_env_lossy(), + ) + .init(); + + let (tx, rx) = oneshot::channel(); + let (tx_final, rx_final) = oneshot::channel(); + + let shutdown = Shutdown::new(async move { + tokio::select! { + _ = graceful::default_signal() => { + let _ = tx_final.send(Ok(())); + } + result = rx => { + match result { + Ok(result) => { + let _ = tx_final.send(result); + } + Err(_) => { + let _ = tx_final.send(Ok(())); + } + } + } + } + }); + + shutdown.spawn_task_fn(move |guard| async move { + let result = run_inner(guard, cfg).await; + let _ = tx.send(result); + }); + + let _ = shutdown.shutdown_with_limit(Duration::from_secs(1)).await; + + rx_final.await? +} + +async fn run_inner(guard: ShutdownGuard, cfg: CliCommandHttp) -> Result<(), BoxError> { + let mut request_args_builder = if cfg.json { + RequestArgsBuilder::new_json() + } else if cfg.form { + RequestArgsBuilder::new_form() + } else { + RequestArgsBuilder::new() + }; + + for arg in cfg.args.clone() { + request_args_builder.parse_arg(arg); + } + + let request = request_args_builder.build()?; + + let client = create_client(guard, cfg.clone()).await?; + + let response = client.serve(Context::default(), request).await?; + + if cfg.check_status { + let status = response.status(); + if status.is_client_error() { + eprintln!("client error: {}", status); + std::process::exit(4); + } else if status.is_server_error() { + eprintln!("server error: {}", status); + std::process::exit(5); + } + } + + Ok(()) +} + +async fn create_client( + guard: ShutdownGuard, + mut cfg: CliCommandHttp, +) -> Result, BoxError> +where + S: Send + Sync + 'static, +{ + let (request_writer_mode, response_writer_mode) = if cfg.offline { + (Some(WriterMode::All), None) + } else if cfg.verbose { + cfg.all = true; + (Some(WriterMode::All), Some(WriterMode::All)) + } else if cfg.body { + if cfg.headers { + (None, Some(WriterMode::All)) + } else { + (None, Some(WriterMode::Body)) + } + } else if cfg.headers { + (None, Some(WriterMode::Headers)) + } else { + match &cfg.print { + Some(mode) => parse_print_mode(mode) + .map_err(OpaqueError::from_boxed) + .context("parse CLI print option")?, + None => { + if std::io::stdout().is_terminal() { + (None, Some(WriterMode::All)) + } else { + (None, Some(WriterMode::Body)) + } + } + } + }; + + let writer_kind = match cfg.output.take() { + Some(path) => writer::WriterKind::File(path.into()), + None => writer::WriterKind::Stdout, + }; + + let executor = Executor::graceful(guard); + let (request_writer, response_writer) = writer::create_traffic_writers( + &executor, + writer_kind, + cfg.all, + request_writer_mode, + response_writer_mode, + ) + .await?; + + let client_builder = ServiceBuilder::new() + .map_result(map_internal_client_error) + .layer(TimeoutLayer::new(if cfg.timeout > 0 { + Duration::from_secs(cfg.timeout) + } else { + Duration::from_secs(180) + })) + .layer(FollowRedirectLayer::with_policy(Limited::new( + if cfg.follow { cfg.max_redirects } else { 0 }, + ))) + .layer(response_writer) + .layer(DecompressionLayer::new()) + .layer( + cfg.auth + .as_deref() + .map(|auth| { + let auth = auth.trim().trim_end_matches(':'); + match cfg.auth_type.trim().to_lowercase().as_str() { + "basic" => match auth.split_once(':') { + Some((user, pass)) => AddAuthorizationLayer::basic(user, pass), + None => { + let mut terminal = + Terminal::open().expect("open terminal for password prompting"); + let password = terminal + .prompt_sensitive("password: ") + .expect("prompt password from terminal"); + AddAuthorizationLayer::basic(auth, password.as_str()) + } + }, + "bearer" => AddAuthorizationLayer::bearer(auth), + unknown => panic!("unknown auth type: {} (known: basic, bearer)", unknown), + } + }) + .unwrap_or_else(AddAuthorizationLayer::none), + ) + .layer(AddRequiredRequestHeadersLayer::default()) + .layer(request_writer) + .layer(HijackLayer::new(cfg.offline, service_fn(dummy_response))); + + let tls_client_config = + tls::create_tls_client_config(cfg.insecure, cfg.tls, cfg.cert, cfg.cert_key).await?; + + Ok(client_builder.service(HttpClient::new( + ServiceBuilder::new() + .layer(HttpsConnectorLayer::auto().with_config(tls_client_config)) + .layer(HttpProxyConnectorLayer::proxy_from_context()) + .layer(HttpsConnectorLayer::tunnel()) + .service(HttpConnector::default()), + ))) +} + +fn parse_print_mode(mode: &str) -> Result<(Option, Option), BoxError> { + let mut request_mode = None; + let mut response_mode = None; + + for c in mode.chars() { + match c { + 'h' => { + response_mode = Some(match response_mode { + Some(mode) => match mode { + WriterMode::All | WriterMode::Body => WriterMode::All, + WriterMode::Headers => WriterMode::Headers, + }, + None => WriterMode::Headers, + }); + } + 'H' => { + request_mode = Some(match request_mode { + Some(mode) => match mode { + WriterMode::All | WriterMode::Body => WriterMode::All, + WriterMode::Headers => WriterMode::Headers, + }, + None => WriterMode::Headers, + }); + } + 'b' => { + response_mode = Some(match response_mode { + Some(mode) => match mode { + WriterMode::All | WriterMode::Headers => WriterMode::All, + WriterMode::Body => WriterMode::Body, + }, + None => WriterMode::Body, + }); + } + 'B' => { + request_mode = Some(match request_mode { + Some(mode) => match mode { + WriterMode::All | WriterMode::Headers => WriterMode::All, + WriterMode::Body => WriterMode::Body, + }, + None => WriterMode::Body, + }); + } + c => return Err(error!("unknown print mode character: {}", c).into()), + } + } + + Ok((request_mode, response_mode)) +} + +async fn dummy_response(_ctx: Context, _req: Request) -> Result { + Ok(StatusCode::OK.into_response()) +} + +fn map_internal_client_error( + result: Result, E>, +) -> Result +where + E: Into, + Body: rama::http::dep::http_body::Body + Send + Sync + 'static, + Body::Error: Into, +{ + match result { + Ok(response) => Ok(response.map(rama::http::Body::new)), + Err(err) => Err(err.into()), + } +} diff --git a/rama-cli/src/http/tls.rs b/rama-cli/src/http/tls.rs new file mode 100644 index 00000000..340c6bc0 --- /dev/null +++ b/rama-cli/src/http/tls.rs @@ -0,0 +1,69 @@ +use rama::{ + error::BoxError, + tls::rustls::{ + dep::{ + pki_types::{CertificateDer, PrivateKeyDer}, + rustls::{ + version::{TLS12, TLS13}, + ClientConfig, KeyLogFile, RootCertStore, + }, + webpki_roots, + }, + verify::NoServerCertVerifier, + }, +}; +use std::sync::Arc; + +/// Create a new [`ClientConfig`] for a TLS cli client. +pub async fn create_tls_client_config( + insecure: bool, + tls_version: Option, + client_cert_path: Option, + client_key_path: Option, +) -> Result, BoxError> { + let config = if let Some(version) = tls_version { + match version.as_str() { + "1.2" => ClientConfig::builder_with_protocol_versions(&[&TLS12]), + "1.3" => ClientConfig::builder_with_protocol_versions(&[&TLS13]), + _ => return Err(format!("Unsupported TLS version: {}", version).into()), + } + } else { + ClientConfig::builder() + }; + + // TODO: allow root certs to be passed in / customised (e.g. use system roots perhaps by default?!) + let mut root_storage = RootCertStore::empty(); + root_storage.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let config = config.with_root_certificates(root_storage); + + let mut config = if let Some(client_cert_path) = client_cert_path { + let client_key_path = match client_key_path { + Some(path) => path, + None => { + return Err( + "client_key_path must be provided if client_cert_path is provided".into(), + ) + } + }; + let client_cert = tokio::fs::read(client_cert_path).await?; + let cert = CertificateDer::from(client_cert); + + let client_key = tokio::fs::read(client_key_path).await?; + let key = PrivateKeyDer::try_from(client_key)?; + config.with_client_auth_cert(vec![cert], key)? + } else { + config.with_no_client_auth() + }; + + if insecure { + config + .dangerous() + .set_certificate_verifier(Arc::new(NoServerCertVerifier::new())); + } + + if std::env::var("SSLKEYLOGFILE").is_ok() { + config.key_log = Arc::new(KeyLogFile::new()); + } + + Ok(Arc::new(config)) +} diff --git a/rama-cli/src/http/writer.rs b/rama-cli/src/http/writer.rs new file mode 100644 index 00000000..8f6fb87e --- /dev/null +++ b/rama-cli/src/http/writer.rs @@ -0,0 +1,53 @@ +use rama::{ + error::BoxError, + http::layer::traffic_writer::{ + BidirectionalMessage, BidirectionalWriter, RequestWriterLayer, ResponseWriterLayer, + WriterMode, + }, + rt::Executor, + service::util::combinators::Either, +}; +use std::path::PathBuf; +use tokio::{fs::OpenOptions, io::stdout, sync::mpsc::Sender}; + +#[derive(Debug, Clone)] +pub enum WriterKind { + Stdout, + File(PathBuf), +} + +pub async fn create_traffic_writers( + executor: &Executor, + kind: WriterKind, + all: bool, + request_mode: Option, + response_mode: Option, +) -> Result< + ( + RequestWriterLayer>>, + ResponseWriterLayer>>, + ), + BoxError, +> { + let writer = match kind { + WriterKind::Stdout => Either::A(stdout()), + WriterKind::File(path) => Either::B( + OpenOptions::new() + .create(true) + .append(true) + .open(path) + .await?, + ), + }; + + let bidirectional_writer = if all { + BidirectionalWriter::new(executor, writer, 32, request_mode, response_mode) + } else { + BidirectionalWriter::last(executor, writer, request_mode, response_mode) + }; + + Ok(( + RequestWriterLayer::new(bidirectional_writer.clone()), + ResponseWriterLayer::new(bidirectional_writer), + )) +} diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs new file mode 100644 index 00000000..e8820abd --- /dev/null +++ b/rama-cli/src/ip/mod.rs @@ -0,0 +1,176 @@ +use clap::Args; +use rama::{ + error::BoxError, + http::{ + layer::{required_header::AddRequiredResponseHeadersLayer, trace::TraceLayer}, + server::HttpServer, + IntoResponse, Request, Response, StatusCode, + }, + proxy::pp::server::HaProxyLayer, + rt::Executor, + service::{ + layer::{ + limit::policy::{ConcurrentPolicy, UnlimitedPolicy}, + LimitLayer, TimeoutLayer, + }, + util::combinators::Either, + Context, ServiceBuilder, + }, + stream::{layer::http::BodyLimitLayer, SocketInfo, Stream}, + tcp::server::TcpListener, +}; +use std::{convert::Infallible, time::Duration}; +use tokio::io::AsyncWriteExt; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(Debug, Args)] +/// rama ip service (returns the ip address of the client) +pub struct CliCommandIp { + #[arg(long, short = 'p', default_value_t = 8080)] + /// the port to listen on + port: u16, + + #[arg(long, short = 'i', default_value = "127.0.0.1")] + /// the interface to listen on + interface: String, + + #[arg(long, short = 'c', default_value_t = 0)] + /// the number of concurrent connections to allow (0 = no limit) + concurrent: usize, + + #[arg(long, short = 't', default_value = "8")] + /// the timeout in seconds for each connection (0 = default timeout of 30s) + timeout: u64, + + #[arg(long, short = 'a')] + /// enable HaProxy PROXY Protocol + ha_proxy: bool, + + #[arg(long, short = 'T')] + /// operate the IP service on transport layer (tcp) + transport: bool, +} + +pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let graceful = rama::utils::graceful::Shutdown::default(); + + let address = format!("{}:{}", cfg.interface, cfg.port); + tracing::info!("starting ip service on: {}", address); + + graceful.spawn_task_fn(move |guard| async move { + let tcp_listener = TcpListener::build() + .bind(address) + .await + .expect("bind ip service to 127.0.0.1:62001"); + + let tcp_service_builder = ServiceBuilder::new() + .layer(LimitLayer::new(if cfg.concurrent > 0 { + Either::A(ConcurrentPolicy::max(cfg.concurrent)) + } else { + Either::B(UnlimitedPolicy::default()) + })) + .layer(TimeoutLayer::new(if cfg.timeout > 0 { + Duration::from_secs(cfg.timeout) + } else { + Duration::from_secs(30) + })) + .layer((cfg.ha_proxy).then(HaProxyLayer::default)); + + // TODO document how one would force IPv4 or IPv6 + + // TODO: support opt-in TLS + + if cfg.transport { + let tcp_service = tcp_service_builder.service(IpTransportEchoService); + + tracing::info!("ip service ready"); + + tcp_listener.serve_graceful(guard, tcp_service).await; + } else { + let http_service = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(AddRequiredResponseHeadersLayer::default()) + .service_fn(ip); + + let tcp_service = tcp_service_builder + // Limit the body size to 1MB for requests + .layer(BodyLimitLayer::request_only(1024 * 1024)) + .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + + tracing::info!("ip service ready"); + + tcp_listener.serve_graceful(guard, tcp_service).await; + } + }); + + graceful + .shutdown_with_limit(Duration::from_secs(30)) + .await?; + + Ok(()) +} + +pub async fn ip(ctx: Context, _: Request) -> Result +where + State: Send + Sync + 'static, +{ + Ok( + match ctx.get::().map(|v| v.peer_addr().to_string()) { + Some(ip) => ip.into_response(), + None => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }, + ) +} + +#[derive(Debug, Clone)] +struct IpTransportEchoService; + +impl rama::service::Service for IpTransportEchoService +where + State: Send + Sync + 'static, + Input: Stream, +{ + type Response = (); + type Error = BoxError; + + async fn serve( + &self, + ctx: rama::service::Context, + stream: Input, + ) -> Result { + let socket_info = match ctx.get::() { + Some(socket_info) => socket_info, + None => { + tracing::error!("missing socket info"); + return Ok(()); + } + }; + + let mut stream = std::pin::pin!(stream); + + match socket_info.peer_addr().ip() { + std::net::IpAddr::V4(ip) => { + if let Err(err) = stream.write_all(&ip.octets()).await { + tracing::error!("error writing IPv4 of peer to peer: {}", err); + } + } + std::net::IpAddr::V6(ip) => { + if let Err(err) = stream.write_all(&ip.octets()).await { + tracing::error!("error writing IPv6 of peer to peer: {}", err); + } + } + }; + + Ok(()) + } +} diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index 6774208d..34073ae9 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -1,11 +1,43 @@ -use argh::FromArgs; +use clap::{Parser, Subcommand}; +use rama::error::BoxError; -#[derive(Debug, FromArgs)] -/// a distortion proxy cli -struct Cli {} +mod echo; +use echo::CliCommandEcho; + +mod http; +use http::CliCommandHttp; + +mod proxy; +use proxy::CliCommandProxy; + +mod ip; +use ip::CliCommandIp; + +#[derive(Debug, Parser)] +#[command(name = "rama")] +#[command(bin_name = "rama")] +#[command(version, about, long_about = None)] +struct Cli { + #[command(subcommand)] + cmds: CliCommands, +} + +#[derive(Debug, Subcommand)] +enum CliCommands { + Echo(CliCommandEcho), + Http(CliCommandHttp), + Proxy(CliCommandProxy), + Ip(CliCommandIp), +} #[tokio::main] -async fn main() -> anyhow::Result<()> { - let _: Cli = argh::from_env(); - Ok(()) +async fn main() -> Result<(), BoxError> { + let cli = Cli::parse(); + + match cli.cmds { + CliCommands::Echo(cfg) => echo::run(cfg).await, + CliCommands::Http(cfg) => http::run(cfg).await, + CliCommands::Proxy(cfg) => proxy::run(cfg).await, + CliCommands::Ip(cfg) => ip::run(cfg).await, + } } diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs new file mode 100644 index 00000000..0200bd7e --- /dev/null +++ b/rama-cli/src/proxy/mod.rs @@ -0,0 +1,170 @@ +use clap::Args; +use rama::{ + error::BoxError, + http::{ + client::HttpClient, + layer::{ + remove_header::{RemoveRequestHeaderLayer, RemoveResponseHeaderLayer}, + trace::TraceLayer, + upgrade::{UpgradeLayer, Upgraded}, + }, + matcher::MethodMatcher, + server::HttpServer, + Body, IntoResponse, Request, RequestContext, Response, StatusCode, + }, + rt::Executor, + service::{ + layer::{limit::policy::ConcurrentPolicy, LimitLayer, TimeoutLayer}, + service_fn, Context, Service, ServiceBuilder, + }, + stream::layer::http::BodyLimitLayer, + tcp::{server::TcpListener, utils::is_connection_error}, +}; +use std::{convert::Infallible, time::Duration}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(Debug, Args)] +/// rama proxy runner +pub struct CliCommandProxy { + #[arg(long, short = 'p', default_value_t = 8080)] + /// the port to listen on + port: u16, + + #[arg(long, short = 'i', default_value = "127.0.0.1")] + /// the interface to listen on + interface: String, + + #[arg(long, short = 'c', default_value_t = 0)] + /// the number of concurrent connections to allow (0 = no limit) + concurrent: usize, + + #[arg(long, short = 't', default_value_t = 8)] + /// the timeout in seconds for each connection (0 = no timeout) + timeout: u64, +} + +pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let graceful = rama::utils::graceful::Shutdown::default(); + + let address = format!("{}:{}", cfg.interface, cfg.port); + tracing::info!("starting proxy on: {}", address); + + graceful.spawn_task_fn(move |guard| async move { + let tcp_service = TcpListener::build() + .bind(address) + .await + .expect("bind proxy to 127.0.0.1:62001"); + + let exec = Executor::graceful(guard.clone()); + let http_service = HttpServer::auto(exec).service( + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(UpgradeLayer::new( + MethodMatcher::CONNECT, + service_fn(http_connect_accept), + service_fn(http_connect_proxy), + )) + .service( + ServiceBuilder::new() + .layer(RemoveResponseHeaderLayer::hop_by_hop()) + .layer(RemoveRequestHeaderLayer::hop_by_hop()) + .service_fn(http_plain_proxy), + ), + ); + + let tcp_service_builder = ServiceBuilder::new() + // protect the http proxy from too large bodies, both from request and response end + .layer(BodyLimitLayer::symmetric(2 * 1024 * 1024)) + .layer( + (cfg.concurrent > 0) + .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), + ) + .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))); + + tcp_service + .serve_graceful(guard, tcp_service_builder.service(http_service)) + .await; + }); + + graceful + .shutdown_with_limit(Duration::from_secs(30)) + .await?; + + Ok(()) +} + +async fn http_connect_accept( + mut ctx: Context, + req: Request, +) -> Result<(Response, Context, Request), Response> +where + S: Send + Sync + 'static, +{ + match ctx + .get_or_insert_with::(|| RequestContext::from(&req)) + .host + .as_ref() + { + Some(host) => tracing::info!("accept CONNECT to {host}"), + None => { + tracing::error!("error extracting host"); + return Err(StatusCode::BAD_REQUEST.into_response()); + } + } + + Ok((StatusCode::OK.into_response(), ctx, req)) +} + +async fn http_connect_proxy(ctx: Context, mut upgraded: Upgraded) -> Result<(), Infallible> +where + S: Send + Sync + 'static, +{ + let host = ctx + .get::() + .unwrap() + .host + .as_ref() + .unwrap() + .clone(); + tracing::info!("CONNECT to {}", host); + let mut stream = match tokio::net::TcpStream::connect(&host).await { + Ok(stream) => stream, + Err(err) => { + tracing::error!(error = %err, "error connecting to host"); + return Ok(()); + } + }; + if let Err(err) = tokio::io::copy_bidirectional(&mut upgraded, &mut stream).await { + if !is_connection_error(&err) { + tracing::error!(error = %err, "error copying data"); + } + } + Ok(()) +} + +async fn http_plain_proxy(ctx: Context, req: Request) -> Result +where + S: Send + Sync + 'static, +{ + let client = HttpClient::default(); + match client.serve(ctx, req).await { + Ok(resp) => Ok(resp), + Err(err) => { + tracing::error!(error = %err, "error in client request"); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()) + } + } +} diff --git a/rama-fp/Cargo.toml b/rama-fp/Cargo.toml index 87836c7f..9acdbee7 100644 --- a/rama-fp/Cargo.toml +++ b/rama-fp/Cargo.toml @@ -12,12 +12,10 @@ rust-version = { workspace = true } default-run = "rama-fp" [dependencies] -anyhow = { workspace = true } -argh = { workspace = true } base64 = { workspace = true } -rama = { version = "0.2", path = "..", features = ["full"] } +clap = { workspace = true } +rama = { version = "0.2.0-alpha.0", path = "..", features = ["full"] } serde = { workspace = true } -serde_html_form = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } diff --git a/rama-fp/src/main.rs b/rama-fp/src/main.rs index bbfdfbd9..4092e099 100644 --- a/rama-fp/src/main.rs +++ b/rama-fp/src/main.rs @@ -1,40 +1,42 @@ -use argh::FromArgs; +use clap::{Args, Parser, Subcommand}; +use rama::error::BoxError; pub mod service; -#[derive(Debug, FromArgs)] -/// a fingerprinting service for rama +#[derive(Debug, Parser)] +#[command(name = "rama-fp")] +#[command(bin_name = "rama-fp")] +#[command(version, about, long_about = None)] struct Cli { /// the interface to listen on - #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + #[arg(long, short = 'i', default_value = "127.0.0.1")] interface: String, /// the port to listen on - #[argh(option, short = 'p', default = "8080")] + #[arg(long, short = 'p', default_value_t = 8080)] port: u16, /// the port to listen on for the TLS service - #[argh(option, short = 's', default = "8443")] + #[arg(long, short = 's', default_value_t = 8443)] secure_port: u16, /// the port to listen on for the TLS service - #[argh(option, short = 't', default = "9091")] + #[arg(long, short = 't', default_value_t = 9091)] prometheus_port: u16, /// http version to serve FP Service from - #[argh(option, default = "String::from(\"auto\")")] + #[arg(long, default_value = "auto")] http_version: String, /// serve as an HaProxy - #[argh(switch, short = 'f')] + #[arg(long, short = 'f')] ha_proxy: bool, - #[argh(subcommand)] + #[command(subcommand)] command: Option, } -#[derive(Debug, FromArgs)] -#[argh(subcommand)] +#[derive(Debug, Subcommand)] enum Commands { Run(RunSubCommand), Echo(EchoSubCommand), @@ -46,19 +48,17 @@ impl Default for Commands { } } -#[derive(FromArgs, Debug)] +#[derive(Debug, Args)] /// Run the regular FP Server -#[argh(subcommand, name = "run")] -struct RunSubCommand {} +struct RunSubCommand; -#[derive(FromArgs, Debug)] +#[derive(Debug, Args)] /// Run an echo server -#[argh(subcommand, name = "echo")] -struct EchoSubCommand {} +struct EchoSubCommand; #[tokio::main] -async fn main() -> anyhow::Result<()> { - let args: Cli = argh::from_env(); +async fn main() -> Result<(), BoxError> { + let args = Cli::parse(); match args.command.unwrap_or_default() { Commands::Run(_) => { diff --git a/rama-fp/src/service/mod.rs b/rama-fp/src/service/mod.rs index 380d6254..35cbb434 100644 --- a/rama-fp/src/service/mod.rs +++ b/rama-fp/src/service/mod.rs @@ -1,11 +1,11 @@ use base64::Engine as _; use rama::{ + error::BoxError, http::{ - headers::Server, layer::{ catch_panic::CatchPanicLayer, compression::CompressionLayer, - opentelemetry::RequestMetricsLayer, set_header::SetResponseHeaderLayer, - trace::TraceLayer, + opentelemetry::RequestMetricsLayer, required_header::AddRequiredResponseHeadersLayer, + set_header::SetResponseHeaderLayer, trace::TraceLayer, }, matcher::HttpMatcher, response::Redirect, @@ -17,10 +17,10 @@ use rama::{ rt::Executor, service::{ layer::{ - limit::policy::ConcurrentPolicy, HijackLayer, LimitLayer, MapErrLayer, TimeoutLayer, + limit::policy::ConcurrentPolicy, ConsumeErrLayer, HijackLayer, LimitLayer, TimeoutLayer, }, service_fn, - util::{backoff::ExponentialBackoff, combinators::Either}, + util::backoff::ExponentialBackoff, ServiceBuilder, }, stream::layer::{http::BodyLimitLayer, opentelemetry::NetworkMetricsLayer}, @@ -61,7 +61,7 @@ pub struct Config { pub ha_proxy: bool, } -pub async fn run(cfg: Config) -> anyhow::Result<()> { +pub async fn run(cfg: Config) -> Result<(), BoxError> { tracing_subscriber::registry() .with(fmt::layer()) .with( @@ -159,7 +159,7 @@ pub async fn run(cfg: Config) -> anyhow::Result<()> { .layer(RequestMetricsLayer::default()) .layer(CompressionLayer::new()) .layer(CatchPanicLayer::new()) - .layer(SetResponseHeaderLayer::overriding_typed(format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION).parse::().unwrap())) + .layer(AddRequiredResponseHeadersLayer::default()) .layer(SetResponseHeaderLayer::overriding( HeaderName::from_static("x-sponsored-by"), HeaderValue::from_static("fly.io"), @@ -193,12 +193,7 @@ pub async fn run(cfg: Config) -> anyhow::Result<()> { ); let tcp_service_builder = ServiceBuilder::new() - .map_result(|result| { - if let Err(err) = result { - tracing::warn!(error = %err, "rama service failed"); - } - Ok::<_, Infallible>(()) - }) + .layer(ConsumeErrLayer::trace(tracing::Level::WARN)) .layer(NetworkMetricsLayer::default()) .layer(TimeoutLayer::new(Duration::from_secs(16))) .layer(LimitLayer::new(ConcurrentPolicy::max_with_backoff( @@ -219,11 +214,8 @@ pub async fn run(cfg: Config) -> anyhow::Result<()> { let http_service = http_service.clone(); - let tcp_service_builder = if ha_proxy { - tcp_service_builder.clone().layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.clone().layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder.clone() + .layer(ha_proxy.then(HaProxyLayer::default)); // create tls service builder let server_config = @@ -282,11 +274,8 @@ pub async fn run(cfg: Config) -> anyhow::Result<()> { }); } - let tcp_service_builder = if ha_proxy { - tcp_service_builder.layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder + .layer(ha_proxy.then(HaProxyLayer::default)); let tcp_listener = TcpListener::build_with_state(State::new(acme_data)) .bind(&http_address) @@ -353,7 +342,7 @@ pub async fn run(cfg: Config) -> anyhow::Result<()> { Ok(()) } -pub async fn echo(cfg: Config) -> anyhow::Result<()> { +pub async fn echo(cfg: Config) -> Result<(), BoxError> { tracing_subscriber::registry() .with(fmt::layer()) .with( @@ -408,7 +397,7 @@ pub async fn echo(cfg: Config) -> anyhow::Result<()> { .layer(RequestMetricsLayer::default()) .layer(CompressionLayer::new()) .layer(CatchPanicLayer::new()) - .layer(SetResponseHeaderLayer::overriding_typed(format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION).parse::().unwrap())) + .layer(AddRequiredResponseHeadersLayer::default()) .layer(SetResponseHeaderLayer::overriding( HeaderName::from_static("x-sponsored-by"), HeaderValue::from_static("fly.io"), @@ -422,12 +411,7 @@ pub async fn echo(cfg: Config) -> anyhow::Result<()> { ); let tcp_service_builder = ServiceBuilder::new() - .map_result(|result| { - if let Err(err) = result { - tracing::warn!(error = %err, "rama service failed"); - } - Ok::<_, Infallible>(()) - }) + .layer(ConsumeErrLayer::trace(tracing::Level::WARN)) .layer(NetworkMetricsLayer::default()) .layer(TimeoutLayer::new(Duration::from_secs(16))) // Why the below layer makes it no longer cloneable?!?! @@ -449,11 +433,8 @@ pub async fn echo(cfg: Config) -> anyhow::Result<()> { let http_service = http_service.clone(); - let tcp_service_builder = if ha_proxy { - tcp_service_builder.clone().layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.clone().layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder.clone() + .layer(ha_proxy.then(HaProxyLayer::default)); // create tls service builder let server_config = @@ -517,11 +498,8 @@ pub async fn echo(cfg: Config) -> anyhow::Result<()> { .await .expect("bind TCP Listener"); - let tcp_service_builder = if ha_proxy { - tcp_service_builder.layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder + .layer(ha_proxy.then(HaProxyLayer::default)); match cfg.http_version.as_str() { "" | "auto" => { @@ -587,7 +565,7 @@ async fn get_server_config( tls_cert_pem_raw: String, tls_key_pem_raw: String, http_version: &str, -) -> anyhow::Result { +) -> Result { // server TLS Certs let tls_cert_pem_raw = BASE64.decode(tls_cert_pem_raw.as_bytes())?; let mut pem = BufReader::new(&tls_cert_pem_raw[..]); diff --git a/src/cli/args.rs b/src/cli/args.rs new file mode 100644 index 00000000..3dd50442 --- /dev/null +++ b/src/cli/args.rs @@ -0,0 +1,599 @@ +//! build requests from command line arguments + +use crate::{ + error::{ErrorContext, OpaqueError}, + http::{ + header::{Entry, HeaderValue, ACCEPT, CONTENT_LENGTH, CONTENT_TYPE}, + Body, Method, Request, Uri, + }, +}; +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, Clone)] +/// A builder to create a request from command line arguments. +pub struct RequestArgsBuilder { + state: BuilderState, +} + +impl Default for RequestArgsBuilder { + fn default() -> Self { + Self::new() + } +} + +impl RequestArgsBuilder { + /// Create a new [`RequestArgsBuilder`], which auto-detects the content type. + pub fn new() -> Self { + Self { + state: BuilderState::MethodOrUrl { content_type: None }, + } + } + + /// Create a new [`RequestArgsBuilder`], which expects JSON data. + pub fn new_json() -> RequestArgsBuilder { + RequestArgsBuilder { + state: BuilderState::MethodOrUrl { + content_type: Some(ContentType::Json), + }, + } + } + + /// Create a new [`RequestArgsBuilder`], which expects Form data. + pub fn new_form() -> RequestArgsBuilder { + RequestArgsBuilder { + state: BuilderState::MethodOrUrl { + content_type: Some(ContentType::Form), + }, + } + } + + /// parse a command line argument, the possible meaning + /// depend on the current state of the builder, driven by the position of the argument. + pub fn parse_arg(&mut self, arg: String) { + let new_state = match &mut self.state { + BuilderState::MethodOrUrl { content_type } => { + if let Some(method) = parse_arg_as_method(&arg) { + Some(BuilderState::Url { + content_type: *content_type, + method: Some(method), + }) + } else { + Some(BuilderState::Data { + content_type: *content_type, + method: None, + url: arg, + query: HashMap::new(), + headers: HashMap::new(), + body: HashMap::new(), + }) + } + } + BuilderState::Url { + content_type, + method, + } => Some(BuilderState::Data { + content_type: *content_type, + method: method.clone(), + url: arg, + query: HashMap::new(), + headers: HashMap::new(), + body: HashMap::new(), + }), + BuilderState::Data { + ref mut query, + ref mut headers, + ref mut body, + .. + } => match parse_arg_as_data(arg, query, headers, body) { + Ok(_) => None, + Err(msg) => Some(BuilderState::Error { + message: msg, + ignored: vec![], + }), + }, + BuilderState::Error { + ref mut ignored, .. + } => { + ignored.push(arg); + None + } + }; + if let Some(new_state) = new_state { + self.state = new_state; + } + } + + /// Build the request from the parsed arguments. + pub fn build(self) -> Result { + match self.state { + BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => { + Err(OpaqueError::from_display("no url defined")) + } + BuilderState::Error { message, ignored } => { + Err(OpaqueError::from_display(if ignored.is_empty() { + format!("request arg parser failed: {}", message) + } else { + format!( + "request arg parser failed: {} (ignored: {:?})", + message, ignored + ) + })) + } + BuilderState::Data { + content_type, + method, + url, + query, + headers, + body, + } => { + let mut req = Request::builder(); + + let url = expand_url(url); + + let uri: Uri = url + .parse() + .map_err(OpaqueError::from_std) + .context("parse base uri")?; + + if query.is_empty() { + req = req.uri(url); + } else { + let mut uri_parts = uri.into_parts(); + uri_parts.path_and_query = Some(match uri_parts.path_and_query { + Some(pq) => match pq.query() { + Some(q) => { + let mut existing_query: HashMap> = + serde_html_form::from_str(q) + .map_err(OpaqueError::from_std) + .context("parse existing query")?; + for (k, v) in query { + existing_query.entry(k).or_default().extend(v); + } + let query = serde_html_form::to_string(&existing_query) + .map_err(OpaqueError::from_std) + .context("serialize extended query")?; + format!("{}?{}", pq.path(), query) + .parse() + .map_err(OpaqueError::from_std) + .context("create new path+query from extended query")? + } + None => { + let query = serde_html_form::to_string(&query) + .map_err(OpaqueError::from_std) + .context("serialize new and only query params")?; + format!("{}?{}", pq.path(), query) + .parse() + .map_err(OpaqueError::from_std) + .context("create path+query from given query params")? + } + }, + None => { + let query = serde_html_form::to_string(&query) + .map_err(OpaqueError::from_std)?; + format!("/?{}", query) + .parse() + .map_err(OpaqueError::from_std)? + } + }); + req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?); + } + + match method { + Some(method) => req = req.method(method), + None => { + if body.is_empty() { + req = req.method(Method::GET); + } else { + req = req.method(Method::POST); + } + } + } + for (name, value) in headers { + req = req.header(name, value); + } + + if body.is_empty() { + return req + .body(Body::empty()) + .map_err(OpaqueError::from_std) + .context("create request without body"); + } + + let ct = content_type.unwrap_or_else(|| { + match req + .headers_ref() + .and_then(|h| h.get(CONTENT_TYPE)) + .and_then(|h| h.to_str().ok()) + { + Some(cv) if cv.contains("application/x-www-form-urlencoded") => { + ContentType::Form + } + _ => ContentType::Json, + } + }); + + let req = if req.headers_ref().is_none() { + let req = req.header(CONTENT_TYPE, ct.header_value()); + if ct == ContentType::Json { + req.header(ACCEPT, ct.header_value()) + } else { + req + } + } else { + let headers = req.headers_mut().unwrap(); + + if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) { + entry.insert(ct.header_value()); + } + + if ct == ContentType::Json { + if let Entry::Vacant(entry) = headers.entry(ACCEPT) { + entry.insert(ct.header_value()); + } + } + + req + }; + + match ct { + ContentType::Json => { + let body = serde_json::to_string(&body) + .map_err(OpaqueError::from_std) + .context("serialize form body")?; + req.header(CONTENT_LENGTH, body.len().to_string()) + .body(Body::from(body)) + } + ContentType::Form => { + let body = serde_html_form::to_string(&body) + .map_err(OpaqueError::from_std) + .context("serialize json body")?; + req.header(CONTENT_LENGTH, body.len().to_string()) + .body(Body::from(body)) + } + } + .map_err(OpaqueError::from_std) + .context("create request with body") + } + } + } +} + +fn parse_arg_as_data( + arg: String, + query: &mut HashMap>, + headers: &mut HashMap, + body: &mut HashMap, +) -> Result<(), String> { + let mut state = DataParseArgState::None; + for (i, c) in arg.chars().enumerate() { + match state { + DataParseArgState::None => match c { + '\\' => state = DataParseArgState::Escaped, + '=' => state = DataParseArgState::Equal, + ':' => state = DataParseArgState::Colon, + _ => (), + }, + DataParseArgState::Escaped => { + // \* + state = DataParseArgState::None; + } + DataParseArgState::Equal => { + let (name, value) = arg.split_at(i - 1); + if c == '=' { + // == + let value = &value[2..]; + query + .entry(name.to_owned()) + .or_default() + .push(value.to_owned()); + } else { + // = + let value = &value[1..]; + body.insert(name.to_owned(), Value::String(value.to_owned())); + } + break; + } + DataParseArgState::Colon => { + let (name, value) = arg.split_at(i - 1); + if c == '=' { + // := + let value = &value[2..]; + let value: Value = + serde_json::from_str(value).map_err(|err| err.to_string())?; + body.insert(name.to_owned(), value); + } else { + // : + let value = &value[1..]; + headers.insert(name.to_owned(), value.to_owned()); + } + break; + } + } + } + Ok(()) +} + +fn parse_arg_as_method(arg: impl AsRef) -> Option { + match_ignore_ascii_case_str! { + match (arg.as_ref()) { + "GET" => Some(Method::GET), + "POST" => Some(Method::POST), + "PUT" => Some(Method::PUT), + "DELETE" => Some(Method::DELETE), + "PATCH" => Some(Method::PATCH), + "HEAD" => Some(Method::HEAD), + "OPTIONS" => Some(Method::OPTIONS), + "CONNECT" => Some(Method::CONNECT), + "TRACE" => Some(Method::TRACE), + _ => None, + + } + } +} + +/// Expand a URL string to a full URL, +/// e.g. `example.com` -> `http://example.com` +fn expand_url(url: String) -> String { + if url.is_empty() { + "http://localhost".to_owned() + } else if let Some(stripped_url) = url.strip_prefix(':') { + if stripped_url.is_empty() { + "http://localhost".to_owned() + } else if stripped_url + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or_default() + { + format!("http://localhost{}", url) + } else { + format!("http://localhost{}", stripped_url) + } + } else if !url.contains("://") { + format!("http://{}", url) + } else { + url.to_string() + } +} + +enum DataParseArgState { + None, + Escaped, + Equal, + Colon, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum ContentType { + Json, + Form, +} + +impl ContentType { + fn header_value(&self) -> HeaderValue { + HeaderValue::from_static(match self { + ContentType::Json => "application/json", + ContentType::Form => "application/x-www-form-urlencoded", + }) + } +} + +#[derive(Debug, Clone)] +enum BuilderState { + MethodOrUrl { + content_type: Option, + }, + Url { + content_type: Option, + method: Option, + }, + Data { + content_type: Option, + method: Option, + url: String, + query: HashMap>, + headers: HashMap, + body: HashMap, + }, + Error { + message: String, + ignored: Vec, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::io::write_http_request; + + #[test] + fn test_parse_arg_as_method() { + for (arg, expected) in [ + ("GET", Some(Method::GET)), + ("POST", Some(Method::POST)), + ("PUT", Some(Method::PUT)), + ("DELETE", Some(Method::DELETE)), + ("PATCH", Some(Method::PATCH)), + ("HEAD", Some(Method::HEAD)), + ("OPTIONS", Some(Method::OPTIONS)), + ("CONNECT", Some(Method::CONNECT)), + ("TRACE", Some(Method::TRACE)), + ("get", Some(Method::GET)), + ("post", Some(Method::POST)), + ("put", Some(Method::PUT)), + ("delete", Some(Method::DELETE)), + ("patch", Some(Method::PATCH)), + ("head", Some(Method::HEAD)), + ("options", Some(Method::OPTIONS)), + ("connect", Some(Method::CONNECT)), + ("trace", Some(Method::TRACE)), + ("invalid", None), + ("", None), + ] { + assert_eq!(parse_arg_as_method(arg), expected); + } + } + + #[test] + fn test_expand_url() { + for (url, expected) in [ + ("example.com", "http://example.com"), + ("http://example.com", "http://example.com"), + ("https://example.com", "https://example.com"), + ("example.com:8080", "http://example.com:8080"), + (":8080/foo", "http://localhost:8080/foo"), + (":8080", "http://localhost:8080"), + ("", "http://localhost"), + ] { + assert_eq!(expand_url(url.to_owned()), expected); + } + } + + #[tokio::test] + async fn test_request_args_builder_happy() { + for (args, expected_request_str) in [ + (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"), + (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"), + ( + vec![ + "example.com/foo", + "c=d", + "Content-Type:application/x-www-form-urlencoded", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", + ), + ( + vec![ + "example.com/foo", + "a=b", + "Content-Type:application/json", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ( + vec![ + "example.com/foo", + "a=b", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ( + vec![ + "example.com/foo", + "x-a:1", + "a=b", + ], + "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ( + vec![ + "put", + "example.com/foo?a=2", + "x-a:1", + "a:=42", + "a==3" + ], + "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}", + ), + ( + vec![ + ":3000", + "Cookie:foo=bar", + ], + "GET / HTTP/1.1\r\ncookie: foo=bar\r\n\r\n", + ), + ( + vec![ + ":/foo", + "search==rama", + ], + "GET /foo?search=rama HTTP/1.1\r\n\r\n", + ), + ( + vec![ + "example.com", + "description='CLI HTTP client'", + ], + "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}", + ) + ] { + let mut builder = RequestArgsBuilder::new(); + for arg in args { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build().unwrap(); + let mut w = Vec::new(); + let _ = write_http_request(&mut w, request, true, true) + .await + .unwrap(); + assert_eq!(String::from_utf8(w).unwrap(), expected_request_str); + } + } + + #[tokio::test] + async fn test_request_args_builder_form_happy() { + for (args, expected_request_str) in [ + ( + vec![ + "example.com/foo", + "c=d", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", + ), + ] { + let mut builder = RequestArgsBuilder::new_form(); + for arg in args { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build().unwrap(); + let mut w = Vec::new(); + let _ = write_http_request(&mut w, request, true, true) + .await + .unwrap(); + assert_eq!(String::from_utf8(w).unwrap(), expected_request_str); + } + } + + #[tokio::test] + async fn test_request_args_builder_json_happy() { + for (args, expected_request_str) in [ + ( + vec![ + "example.com/foo", + "a=b", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ] { + let mut builder = RequestArgsBuilder::new(); + for arg in args { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build().unwrap(); + let mut w = Vec::new(); + let _ = write_http_request(&mut w, request, true, true) + .await + .unwrap(); + assert_eq!(String::from_utf8(w).unwrap(), expected_request_str); + } + } + + #[tokio::test] + async fn test_request_args_builder_error() { + for test in [ + vec![], + vec!["invalid url"], + vec!["get"], + vec!["get", "invalid url"], + ] { + let mut builder = RequestArgsBuilder::new(); + for arg in test { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build(); + assert!(request.is_err()); + } + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 00000000..4964965d --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,3 @@ +//! rama cli utilities + +pub mod args; diff --git a/src/error/ext/wrapper.rs b/src/error/ext/wrapper.rs index 48887d8a..9f5af1c6 100644 --- a/src/error/ext/wrapper.rs +++ b/src/error/ext/wrapper.rs @@ -87,6 +87,12 @@ impl std::error::Error for OpaqueError { } } +impl From for OpaqueError { + fn from(error: BoxError) -> Self { + Self(error) + } +} + #[repr(transparent)] /// An error type that wraps a message. pub(crate) struct MessageError(pub(crate) M); diff --git a/src/http/client/error.rs b/src/http/client/error.rs index cfe78c06..1d23917e 100644 --- a/src/http/client/error.rs +++ b/src/http/client/error.rs @@ -68,3 +68,12 @@ impl std::error::Error for HttpClientError { self.inner.source() } } + +impl From for HttpClientError { + fn from(err: BoxError) -> Self { + Self { + inner: OpaqueError::from_boxed(err), + uri: None, + } + } +} diff --git a/src/http/client/ext.rs b/src/http/client/ext.rs index 23ca9a8a..f7660603 100644 --- a/src/http/client/ext.rs +++ b/src/http/client/ext.rs @@ -647,7 +647,7 @@ where /// /// This method fails if there was an error while sending [`Request`]. pub async fn send(self, ctx: Context) -> Result, HttpClientError> { - let mut request = match self.state { + let request = match self.state { RequestBuilderState::PreBody(builder) => builder .body(crate::http::Body::empty()) .map_err(HttpClientError::from_std)?, @@ -655,23 +655,6 @@ where RequestBuilderState::Error(err) => return Err(err), }; - // add user-agent header if not already set - if !request - .headers() - .contains_key(crate::http::header::USER_AGENT) - { - request.headers_mut().insert( - crate::http::header::USER_AGENT, - format!( - "{}/{}", - crate::utils::info::NAME, - crate::utils::info::VERSION - ) - .parse() - .unwrap(), - ); - } - let uri = request.uri().clone(); match self.http_client_service.serve(ctx, request).await { Ok(response) => Ok(response), @@ -688,6 +671,7 @@ mod test { use crate::{ http::{ layer::{ + required_header::AddRequiredRequestHeadersLayer, retry::{ManagedPolicy, RetryLayer}, trace::TraceLayer, }, @@ -752,6 +736,7 @@ mod test { .layer(RetryLayer::new( ManagedPolicy::default().with_backoff(ExponentialBackoff::default()), )) + .layer(AddRequiredRequestHeadersLayer::default()) .service_fn(fake_client_fn) .boxed() } diff --git a/src/http/io/mod.rs b/src/http/io/mod.rs new file mode 100644 index 00000000..164477c6 --- /dev/null +++ b/src/http/io/mod.rs @@ -0,0 +1,9 @@ +//! http I/O utilities, e.g. writing http requests/responses in std http format. + +mod request; +#[doc(inline)] +pub use request::write_http_request; + +mod response; +#[doc(inline)] +pub use response::write_http_response; diff --git a/src/http/io/request.rs b/src/http/io/request.rs new file mode 100644 index 00000000..c3f0a7e2 --- /dev/null +++ b/src/http/io/request.rs @@ -0,0 +1,141 @@ +use crate::{ + error::BoxError, + http::{ + dep::{http_body, http_body_util::BodyExt}, + Body, Request, + }, +}; +use bytes::Bytes; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +/// Write an HTTP request to a writer in std http format. +pub async fn write_http_request( + w: &mut W, + req: Request, + write_headers: bool, + write_body: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + Sync + 'static, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into, +{ + let (parts, body) = req.into_parts(); + + if write_headers { + w.write_all( + format!( + "{} {}{} {:?}\r\n", + parts.method, + parts.uri.path(), + parts + .uri + .query() + .map(|q| format!("?{}", q)) + .unwrap_or_default(), + parts.version + ) + .as_bytes(), + ) + .await?; + + for (key, value) in parts.headers.iter() { + w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) + .await?; + } + } + + let body = if write_body { + let body = body.collect().await.map_err(Into::into)?.to_bytes(); + w.write_all(b"\r\n").await?; + if !body.is_empty() { + w.write_all(body.as_ref()).await?; + } + Body::from(body) + } else { + Body::new(body) + }; + + let req = Request::from_parts(parts, body); + Ok(req) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_http_request_get() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("GET") + .uri("http://example.com") + .body(Body::empty()) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!(req, "GET / HTTP/1.1\r\n\r\n"); + } + + #[tokio::test] + async fn test_write_http_request_get_with_headers() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("GET") + .uri("http://example.com") + .header("content-type", "text/plain") + .header("user-agent", "test/0") + .body(Body::empty()) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!( + req, + "GET / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n" + ); + } + + #[tokio::test] + async fn test_write_http_request_get_with_headers_and_query() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("GET") + .uri("http://example.com?foo=bar") + .header("content-type", "text/plain") + .header("user-agent", "test/0") + .body(Body::empty()) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!( + req, + "GET /?foo=bar HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n" + ); + } + + #[tokio::test] + async fn test_write_http_request_post_with_headers_and_body() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("POST") + .uri("http://example.com") + .header("content-type", "text/plain") + .header("user-agent", "test/0") + .body(Body::from("hello")) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!( + req, + "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello" + ); + } +} diff --git a/src/http/io/response.rs b/src/http/io/response.rs new file mode 100644 index 00000000..8f3b513c --- /dev/null +++ b/src/http/io/response.rs @@ -0,0 +1,120 @@ +use crate::{ + error::BoxError, + http::{ + dep::{http_body, http_body_util::BodyExt}, + Body, Response, + }, +}; +use bytes::Bytes; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +/// Write an HTTP response to a writer in std http format. +pub async fn write_http_response( + w: &mut W, + res: Response, + write_headers: bool, + write_body: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + Sync + 'static, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into, +{ + let (parts, body) = res.into_parts(); + + if write_headers { + w.write_all( + format!( + "{:?} {}{}\r\n", + parts.version, + parts.status.as_u16(), + parts + .status + .canonical_reason() + .map(|r| format!(" {}", r)) + .unwrap_or_default(), + ) + .as_bytes(), + ) + .await?; + + for (key, value) in parts.headers.iter() { + w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) + .await?; + } + } + + let body = if write_body { + let body = body.collect().await.map_err(Into::into)?.to_bytes(); + w.write_all(b"\r\n").await?; + if !body.is_empty() { + w.write_all(body.as_ref()).await?; + } + Body::from(body) + } else { + Body::new(body) + }; + + let req = Response::from_parts(parts, body); + Ok(req) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_response_ok() { + let mut buf = Vec::new(); + let res = Response::builder().status(200).body(Body::empty()).unwrap(); + + write_http_response(&mut buf, res, true, true) + .await + .unwrap(); + + let res = String::from_utf8(buf).unwrap(); + assert_eq!(res, "HTTP/1.1 200 OK\r\n\r\n"); + } + + #[tokio::test] + async fn test_write_response_redirect() { + let mut buf = Vec::new(); + let res = Response::builder() + .status(301) + .header("location", "http://example.com") + .header("server", "test/0") + .body(Body::empty()) + .unwrap(); + + write_http_response(&mut buf, res, true, true) + .await + .unwrap(); + + let res = String::from_utf8(buf).unwrap(); + assert_eq!( + res, + "HTTP/1.1 301 Moved Permanently\r\nlocation: http://example.com\r\nserver: test/0\r\n\r\n" + ); + } + + #[tokio::test] + async fn test_write_response_with_headers_and_body() { + let mut buf = Vec::new(); + let res = Response::builder() + .status(200) + .header("content-type", "text/plain") + .header("server", "test/0") + .body(Body::from("hello")) + .unwrap(); + + write_http_response(&mut buf, res, true, true) + .await + .unwrap(); + + let res = String::from_utf8(buf).unwrap(); + assert_eq!( + res, + "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello" + ); + } +} diff --git a/src/http/layer/auth/add_authorization.rs b/src/http/layer/auth/add_authorization.rs index a5911b69..504bdb8f 100644 --- a/src/http/layer/auth/add_authorization.rs +++ b/src/http/layer/auth/add_authorization.rs @@ -59,11 +59,22 @@ const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose:: /// [`SetRequestHeader`]: crate::http::layer::set_header::SetRequestHeader #[derive(Debug, Clone)] pub struct AddAuthorizationLayer { - value: HeaderValue, + value: Option, if_not_present: bool, } impl AddAuthorizationLayer { + /// Create a new [`AddAuthorizationLayer`] that does not add any authorization. + /// + /// Can be useful if you only want to add authorization for some branches + /// of your service. + pub fn none() -> Self { + Self { + value: None, + if_not_present: false, + } + } + /// Authorize requests using a username and password pair. /// /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is @@ -75,7 +86,7 @@ impl AddAuthorizationLayer { let encoded = BASE64.encode(format!("{}:{}", username, password)); let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap(); Self { - value, + value: Some(value), if_not_present: false, } } @@ -91,7 +102,7 @@ impl AddAuthorizationLayer { let value = HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header"); Self { - value, + value: Some(value), if_not_present: false, } } @@ -103,7 +114,9 @@ impl AddAuthorizationLayer { /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[allow(clippy::wrong_self_convention)] pub fn as_sensitive(mut self, sensitive: bool) -> Self { - self.value.set_sensitive(sensitive); + if let Some(value) = &mut self.value { + value.set_sensitive(sensitive); + } self } @@ -140,11 +153,19 @@ impl Layer for AddAuthorizationLayer { #[derive(Debug, Clone)] pub struct AddAuthorization { inner: S, - value: HeaderValue, + value: Option, if_not_present: bool, } impl AddAuthorization { + /// Create a new [`AddAuthorization`] that does not add any authorization. + /// + /// Can be useful if you only want to add authorization for some branches + /// of your service. + pub fn none(inner: S) -> Self { + AddAuthorizationLayer::none().layer(inner) + } + /// Authorize requests using a username and password pair. /// /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is @@ -176,7 +197,9 @@ impl AddAuthorization { /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[allow(clippy::wrong_self_convention)] pub fn as_sensitive(mut self, sensitive: bool) -> Self { - self.value.set_sensitive(sensitive); + if let Some(value) = &mut self.value { + value.set_sensitive(sensitive); + } self } @@ -204,9 +227,11 @@ where ctx: Context, mut req: Request, ) -> Result { - if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) { - req.headers_mut() - .insert(http::header::AUTHORIZATION, self.value.clone()); + if let Some(value) = &self.value { + if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) { + req.headers_mut() + .insert(http::header::AUTHORIZATION, value.clone()); + } } self.inner.serve(ctx, req).await } diff --git a/src/http/layer/mod.rs b/src/http/layer/mod.rs index c203d7f9..b6a47190 100644 --- a/src/http/layer/mod.rs +++ b/src/http/layer/mod.rs @@ -31,12 +31,14 @@ pub mod propagate_headers; pub mod proxy_auth; pub mod remove_header; pub mod request_id; +pub mod required_header; pub mod retry; pub mod sensitive_headers; pub mod set_header; pub mod set_status; pub mod timeout; pub mod trace; +pub mod traffic_writer; pub mod upgrade; pub mod validate_request; diff --git a/src/http/layer/required_header/mod.rs b/src/http/layer/required_header/mod.rs new file mode 100644 index 00000000..91988ff7 --- /dev/null +++ b/src/http/layer/required_header/mod.rs @@ -0,0 +1,12 @@ +//! Middleware for setting required headers on requests and responses, if they are missing. +//! +//! See [request] and [response] for more details. + +pub mod request; +pub mod response; + +#[doc(inline)] +pub use self::{ + request::{AddRequiredRequestHeaders, AddRequiredRequestHeadersLayer}, + response::{AddRequiredResponseHeaders, AddRequiredResponseHeadersLayer}, +}; diff --git a/src/http/layer/required_header/request.rs b/src/http/layer/required_header/request.rs new file mode 100644 index 00000000..67bb17b8 --- /dev/null +++ b/src/http/layer/required_header/request.rs @@ -0,0 +1,195 @@ +//! Set required headers on the request, if they are missing. +//! +//! For now this only sets `Host` header on http/1.1, +//! as well as always a User-Agent for all versions. + +use http::header::{HOST, USER_AGENT}; + +use crate::http::{ + header::{self, RAMA_ID_HEADER_VALUE}, + Request, RequestContext, Response, +}; +use crate::service::{Context, Layer, Service}; +use std::fmt; + +/// Layer that applies [`AddRequiredRequestHeaders`] which adds a request header. +/// +/// See [`AddRequiredRequestHeaders`] for more details. +#[derive(Debug, Clone, Default)] +pub struct AddRequiredRequestHeadersLayer { + overwrite: bool, +} + +impl AddRequiredRequestHeadersLayer { + /// Create a new [`AddRequiredRequestHeadersLayer`]. + pub fn new() -> Self { + Self { overwrite: false } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self + } +} + +impl Layer for AddRequiredRequestHeadersLayer { + type Service = AddRequiredRequestHeaders; + + fn layer(&self, inner: S) -> Self::Service { + AddRequiredRequestHeaders { + inner, + overwrite: self.overwrite, + } + } +} + +/// Middleware that sets a header on the request. +#[derive(Clone)] +pub struct AddRequiredRequestHeaders { + inner: S, + overwrite: bool, +} + +impl AddRequiredRequestHeaders { + /// Create a new [`AddRequiredRequestHeaders`]. + pub fn new(inner: S) -> Self { + Self { + inner, + overwrite: false, + } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self + } + + define_inner_service_accessors!(); +} + +impl fmt::Debug for AddRequiredRequestHeaders +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AddRequiredRequestHeaders") + .field("inner", &self.inner) + .finish() + } +} + +impl Service> for AddRequiredRequestHeaders +where + ReqBody: Send + 'static, + ResBody: Send + 'static, + State: Send + Sync + 'static, + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + + async fn serve( + &self, + mut ctx: Context, + mut req: Request, + ) -> Result { + if self.overwrite || !req.headers().contains_key(HOST) { + if let Some(host) = ctx + .get_or_insert_with(|| RequestContext::from(&req)) + .host + .as_deref() + .and_then(|host| host.parse().ok()) + { + req.headers_mut().insert(HOST, host); + }; + } + + if self.overwrite { + req.headers_mut() + .insert(USER_AGENT, RAMA_ID_HEADER_VALUE.clone()); + } else if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) { + header.insert(RAMA_ID_HEADER_VALUE.clone()); + } + + self.inner.serve(ctx, req).await + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::http::{Body, Request}; + use crate::service::{Context, Service, ServiceBuilder}; + use std::convert::Infallible; + + #[tokio::test] + async fn add_required_request_headers() { + let svc = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::default()) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(req.headers().contains_key(HOST)); + assert!(req.headers().contains_key(USER_AGENT)); + Ok::<_, Infallible>(http::Response::new(Body::empty())) + }); + + let req = Request::builder() + .uri("http://www.example.com/") + .body(Body::empty()) + .unwrap(); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert!(!resp.headers().contains_key(HOST)); + assert!(!resp.headers().contains_key(USER_AGENT)); + } + + #[tokio::test] + async fn add_required_request_headers_overwrite() { + let svc = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::new().overwrite(true)) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert_eq!(req.headers().get(HOST).unwrap(), "example.com"); + assert_eq!( + req.headers().get(USER_AGENT).unwrap(), + RAMA_ID_HEADER_VALUE.to_str().unwrap() + ); + Ok::<_, Infallible>(http::Response::new(Body::empty())) + }); + + let req = Request::builder() + .uri("http://127.0.0.1/") + .header(HOST, "example.com") + .header(USER_AGENT, "test") + .body(Body::empty()) + .unwrap(); + + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert!(!resp.headers().contains_key(HOST)); + assert!(!resp.headers().contains_key(USER_AGENT)); + } + + #[tokio::test] + async fn add_required_request_headers_no_host() { + let svc = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::default()) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(!req.headers().contains_key(HOST)); + assert!(req.headers().contains_key(USER_AGENT)); + Ok::<_, Infallible>(http::Response::new(Body::empty())) + }); + + let req = Request::builder().body(Body::empty()).unwrap(); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert!(!resp.headers().contains_key(HOST)); + assert!(!resp.headers().contains_key(USER_AGENT)); + } +} diff --git a/src/http/layer/required_header/response.rs b/src/http/layer/required_header/response.rs new file mode 100644 index 00000000..2fcb9f29 --- /dev/null +++ b/src/http/layer/required_header/response.rs @@ -0,0 +1,174 @@ +//! Set required headers on the response, if they are missing. +//! +//! For now this only sets `Server` and `Date` heades. + +use crate::http::{ + header::{self, RAMA_ID_HEADER_VALUE}, + Request, Response, +}; +use crate::http::{ + header::{DATE, SERVER}, + headers::{Date, HeaderMapExt}, +}; +use crate::service::{Context, Layer, Service}; +use std::{fmt, time::SystemTime}; + +/// Layer that applies [`AddRequiredResponseHeaders`] which adds a request header. +/// +/// See [`AddRequiredResponseHeaders`] for more details. +#[derive(Debug, Clone, Default)] +pub struct AddRequiredResponseHeadersLayer { + overwrite: bool, +} + +impl AddRequiredResponseHeadersLayer { + /// Create a new [`AddRequiredResponseHeadersLayer`]. + pub fn new() -> Self { + Self { overwrite: false } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self + } +} + +impl Layer for AddRequiredResponseHeadersLayer { + type Service = AddRequiredResponseHeaders; + + fn layer(&self, inner: S) -> Self::Service { + AddRequiredResponseHeaders { + inner, + overwrite: self.overwrite, + } + } +} + +/// Middleware that sets a header on the request. +#[derive(Clone)] +pub struct AddRequiredResponseHeaders { + inner: S, + overwrite: bool, +} + +impl AddRequiredResponseHeaders { + /// Create a new [`AddRequiredResponseHeaders`]. + pub fn new(inner: S) -> Self { + Self { + inner, + overwrite: false, + } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self + } + + define_inner_service_accessors!(); +} + +impl fmt::Debug for AddRequiredResponseHeaders +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AddRequiredResponseHeaders") + .field("inner", &self.inner) + .finish() + } +} + +impl Service> for AddRequiredResponseHeaders +where + ReqBody: Send + 'static, + ResBody: Send + 'static, + State: Send + Sync + 'static, + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let mut resp = self.inner.serve(ctx, req).await?; + + if self.overwrite { + resp.headers_mut() + .insert(SERVER, RAMA_ID_HEADER_VALUE.clone()); + } else if let header::Entry::Vacant(header) = resp.headers_mut().entry(SERVER) { + header.insert(RAMA_ID_HEADER_VALUE.clone()); + } + + if self.overwrite || !resp.headers().contains_key(DATE) { + resp.headers_mut() + .typed_insert(Date::from(SystemTime::now())); + } + + Ok(resp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{http::Body, service::ServiceBuilder}; + use std::convert::Infallible; + + #[tokio::test] + async fn add_required_response_headers() { + let svc = ServiceBuilder::new() + .layer(AddRequiredResponseHeadersLayer::default()) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(!req.headers().contains_key(SERVER)); + assert!(!req.headers().contains_key(DATE)); + Ok::<_, Infallible>(Response::new(Body::empty())) + }); + + let req = Request::new(Body::empty()); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert_eq!( + resp.headers().get(SERVER).unwrap(), + RAMA_ID_HEADER_VALUE.as_ref() + ); + assert!(resp.headers().contains_key(DATE)); + } + + #[tokio::test] + async fn add_required_response_headers_overwrite() { + let svc = ServiceBuilder::new() + .layer(AddRequiredResponseHeadersLayer::new().overwrite(true)) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(!req.headers().contains_key(SERVER)); + assert!(!req.headers().contains_key(DATE)); + Ok::<_, Infallible>( + Response::builder() + .header(SERVER, "foo") + .header(DATE, "bar") + .body(Body::empty()) + .unwrap(), + ) + }); + + let req = Request::new(Body::empty()); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert_eq!( + resp.headers().get(SERVER).unwrap(), + RAMA_ID_HEADER_VALUE.to_str().unwrap() + ); + assert_ne!(resp.headers().get(DATE).unwrap(), "bar"); + } +} diff --git a/src/http/layer/traffic_writer/mod.rs b/src/http/layer/traffic_writer/mod.rs new file mode 100644 index 00000000..1299b470 --- /dev/null +++ b/src/http/layer/traffic_writer/mod.rs @@ -0,0 +1,367 @@ +//! Middleware to write Http traffic in std format. +//! +//! Can be useful for cli / debug purposes. + +use crate::{ + http::{ + io::{write_http_request, write_http_response}, + Request, Response, + }, + rt::Executor, +}; +use tokio::{ + io::{AsyncWrite, AsyncWriteExt}, + sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}, +}; + +mod request; +#[doc(inline)] +pub use request::{DoNotWriteRequest, RequestWriter, RequestWriterLayer, RequestWriterService}; + +mod response; +#[doc(inline)] +pub use response::{ + DoNotWriteResponse, ResponseWriter, ResponseWriterLayer, ResponseWriterService, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +/// Http writer mode. +pub enum WriterMode { + /// Print the entire request / response. + All, + /// Print only the headers of the request / response. + Headers, + /// Print only the body of the request / response. + Body, +} + +/// A writer that can write both requests and responses. +pub struct BidirectionalWriter { + sender: S, +} + +impl std::fmt::Debug for BidirectionalWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BidirectionalWriter") + .field("sender", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for BidirectionalWriter { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + } + } +} + +impl BidirectionalWriter> { + /// Create a new [`BidirectionalWriter`] with a custom writer gated behind an unbounded sender. + pub fn unbounded( + executor: &Executor, + mut writer: W, + request_mode: Option, + response_mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = unbounded_channel(); + let (write_request_headers, write_request_body) = match request_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + let (write_response_headers, write_response_body) = match response_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + executor.spawn_task(async move { + while let Some(msg) = rx.recv().await { + match msg { + BidirectionalMessage::Request(req) => { + if let Err(err) = write_http_request( + &mut writer, + req, + write_request_headers, + write_request_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + } + BidirectionalMessage::Response(res) => { + if let Err(err) = write_http_response( + &mut writer, + res, + write_response_headers, + write_response_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } + } + }); + + Self { sender: tx } + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stdout + /// over an unbounded channel. + pub fn stdout_unbounded( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::unbounded(executor, tokio::io::stdout(), request_mode, response_mode) + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stderr + /// over an unbounded channel. + pub fn stderr_unbounded( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::unbounded(executor, tokio::io::stderr(), request_mode, response_mode) + } +} + +impl BidirectionalWriter> { + /// Create a new [`BidirectionalWriter`] with a custom writer gated behind a custom bounded channel. + pub fn new( + executor: &Executor, + mut writer: W, + buffer: usize, + request_mode: Option, + response_mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(buffer); + let (write_request_headers, write_request_body) = match request_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + let (write_response_headers, write_response_body) = match response_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + executor.spawn_task(async move { + while let Some(msg) = rx.recv().await { + match msg { + BidirectionalMessage::Request(req) => { + if let Err(err) = write_http_request( + &mut writer, + req, + write_request_headers, + write_request_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + } + BidirectionalMessage::Response(res) => { + if let Err(err) = write_http_response( + &mut writer, + res, + write_response_headers, + write_response_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } + } + }); + + Self { sender: tx } + } + + /// Create a new [`BidirectionalWriter`] with a custom writer that only writes the last request and response received. + pub fn last( + executor: &Executor, + mut writer: W, + request_mode: Option, + response_mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(2); + let (write_request_headers, write_request_body) = match request_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + let (write_response_headers, write_response_body) = match response_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + executor.spawn_task(async move { + let mut last_request = None; + let mut last_response = None; + + while let Some(msg) = rx.recv().await { + match msg { + BidirectionalMessage::Request(req) => last_request = Some(req), + BidirectionalMessage::Response(res) => last_response = Some(res), + } + } + + if let Some(req) = last_request { + if let Err(err) = + write_http_request(&mut writer, req, write_request_headers, write_request_body) + .await + { + tracing::error!(err = %err, "failed to write last http request to writer") + } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } + } + + if let Some(res) = last_response { + if let Err(err) = write_http_response( + &mut writer, + res, + write_response_headers, + write_response_body, + ) + .await + { + tracing::error!(err = %err, "failed to write last http response to writer") + } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } + } + }); + + Self { sender: tx } + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stdout + /// over a bounded channel. + pub fn stdout( + executor: &Executor, + buffer: usize, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::new( + executor, + tokio::io::stdout(), + buffer, + request_mode, + response_mode, + ) + } + + /// Create a new [`BidirectionalWriter`] that prints the last request and response to stdout. + pub fn stdout_last( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::last(executor, tokio::io::stdout(), request_mode, response_mode) + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stderr + /// over a bounded channel. + pub fn stderr( + executor: &Executor, + buffer: usize, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::new( + executor, + tokio::io::stderr(), + buffer, + request_mode, + response_mode, + ) + } + + /// Create a new [`BidirectionalWriter`] that prints the last request and responses to stderr. + pub fn stderr_last( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::last(executor, tokio::io::stderr(), request_mode, response_mode) + } +} + +impl RequestWriter for BidirectionalWriter> { + async fn write_request(&self, req: Request) { + if let Err(err) = self.sender.send(BidirectionalMessage::Request(req)) { + tracing::error!(err = %err, "failed to send request to writer over unbounded channel") + } + } +} + +impl ResponseWriter for BidirectionalWriter> { + async fn write_response(&self, res: Response) { + if let Err(err) = self.sender.send(BidirectionalMessage::Response(res)) { + tracing::error!(err = %err, "failed to send response to writer over unbounded channel") + } + } +} + +impl RequestWriter for BidirectionalWriter> { + async fn write_request(&self, req: Request) { + if let Err(err) = self.sender.send(BidirectionalMessage::Request(req)).await { + tracing::error!(err = %err, "failed to send request to writer over bounded channel") + } + } +} + +impl ResponseWriter for BidirectionalWriter> { + async fn write_response(&self, res: Response) { + if let Err(err) = self.sender.send(BidirectionalMessage::Response(res)).await { + tracing::error!(err = %err, "failed to send response to writer over bounded channel") + } + } +} + +/// The internal message type for the [`BidirectionalWriter`]. +#[derive(Debug)] +pub enum BidirectionalMessage { + /// A request to be written. + Request(Request), + /// A response to be written. + Response(Response), +} diff --git a/src/http/layer/traffic_writer/request.rs b/src/http/layer/traffic_writer/request.rs new file mode 100644 index 00000000..1550e946 --- /dev/null +++ b/src/http/layer/traffic_writer/request.rs @@ -0,0 +1,328 @@ +use super::WriterMode; +use crate::error::{BoxError, ErrorExt, OpaqueError}; +use crate::http::dep::http_body; +use crate::http::dep::http_body_util::BodyExt; +use crate::http::io::write_http_request; +use crate::http::{Body, Request, Response}; +use crate::rt::Executor; +use crate::service::{Context, Layer, Service}; +use bytes::Bytes; +use std::fmt::Debug; +use std::future::Future; +use tokio::io::{stderr, stdout, AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}; + +/// Layer that applies [`RequestWriterService`] which prints the http request in std format. +pub struct RequestWriterLayer { + writer: W, +} + +impl Debug for RequestWriterLayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RequestWriterLayer") + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for RequestWriterLayer { + fn clone(&self) -> Self { + Self { + writer: self.writer.clone(), + } + } +} + +impl RequestWriterLayer { + /// Create a new [`RequestWriterLayer`] with a custom [`RequestWriter`]. + pub fn new(writer: W) -> Self { + Self { writer } + } +} + +/// A trait for writing http requests. +pub trait RequestWriter: Send + Sync + 'static { + /// Write the http request. + fn write_request(&self, req: Request) -> impl Future + Send + '_; +} + +/// Marker struct to indicate that the request should not be printed. +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct DoNotWriteRequest; + +impl DoNotWriteRequest { + /// Create a new [`DoNotWriteRequest`] marker. + pub fn new() -> Self { + Self + } +} + +impl RequestWriterLayer> { + /// Create a new [`RequestWriterLayer`] that prints requests to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded(executor: &Executor, mut writer: W, mode: Option) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = unbounded_channel(); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(req) = rx.recv().await { + if let Err(err) = + write_http_request(&mut writer, req, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stdout(), mode) + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stderr(), mode) + } +} + +impl RequestWriterLayer> { + /// Create a new [`RequestWriterLayer`] that prints requests to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + mut writer: W, + buffer_size: usize, + mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(buffer_size); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(req) = rx.recv().await { + if let Err(err) = + write_http_request(&mut writer, req, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stdout(), buffer_size, mode) + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stderr(), buffer_size, mode) + } +} + +impl Layer for RequestWriterLayer { + type Service = RequestWriterService; + + fn layer(&self, inner: S) -> Self::Service { + RequestWriterService { + inner, + writer: self.writer.clone(), + } + } +} + +/// Middleware to print Http request in std format. +/// +/// See the [module docs](super) for more details. +pub struct RequestWriterService { + inner: S, + writer: W, +} + +impl RequestWriterService { + /// Create a new [`RequestWriterService`] with a custom [`RequestWriter`]. + pub fn new(writer: W, inner: S) -> Self { + Self { inner, writer } + } +} + +impl Debug for RequestWriterService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RequestWriterService") + .field("inner", &self.inner) + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for RequestWriterService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + writer: self.writer.clone(), + } + } +} + +impl RequestWriterService> { + /// Create a new [`RequestWriterService`] that prints requests to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded( + executor: &Executor, + writer: W, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = RequestWriterLayer::writer_unbounded(executor, writer, mode); + layer.layer(inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stdout(), mode, inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stderr(), mode, inner) + } +} + +impl RequestWriterService> { + /// Create a new [`RequestWriterService`] that prints requests to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + writer: W, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = RequestWriterLayer::writer(executor, writer, buffer_size, mode); + layer.layer(inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stdout(), buffer_size, mode, inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stderr(), buffer_size, mode, inner) + } +} + +impl RequestWriterService {} + +impl Service> for RequestWriterService +where + State: Send + Sync + 'static, + S: Service>, + S::Error: Into, + W: RequestWriter, + ReqBody: http_body::Body + Send + Sync + 'static, + ReqBody::Error: Into, + ResBody: Send + 'static, +{ + type Response = Response; + type Error = BoxError; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let req = match ctx.get::() { + Some(_) => req.map(Body::new), + None => { + let (parts, body) = req.into_parts(); + let body_bytes = body + .collect() + .await + .map_err(|err| { + OpaqueError::from_boxed(err.into()) + .context("printer prepare: collect request body") + })? + .to_bytes(); + let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone())); + self.writer.write_request(req).await; + Request::from_parts(parts, Body::from(body_bytes)) + } + }; + self.inner.serve(ctx, req).await.map_err(Into::into) + } +} + +impl RequestWriter for Sender { + async fn write_request(&self, req: Request) { + if let Err(err) = self.send(req).await { + tracing::error!(err = %err, "failed to send request to channel") + } + } +} + +impl RequestWriter for UnboundedSender { + async fn write_request(&self, req: Request) { + if let Err(err) = self.send(req) { + tracing::error!(err = %err, "failed to send request to unbounded channel") + } + } +} + +impl RequestWriter for F +where + F: Fn(Request) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + async fn write_request(&self, req: Request) { + self(req).await + } +} diff --git a/src/http/layer/traffic_writer/response.rs b/src/http/layer/traffic_writer/response.rs new file mode 100644 index 00000000..ed534cfd --- /dev/null +++ b/src/http/layer/traffic_writer/response.rs @@ -0,0 +1,323 @@ +use super::WriterMode; +use crate::error::{BoxError, ErrorContext, OpaqueError}; +use crate::http::dep::http_body; +use crate::http::dep::http_body_util::BodyExt; +use crate::http::io::write_http_response; +use crate::http::{Body, Request, Response}; +use crate::rt::Executor; +use crate::service::{Context, Layer, Service}; +use bytes::Bytes; +use std::fmt::Debug; +use std::future::Future; +use tokio::io::{stderr, stdout, AsyncWrite}; +use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}; + +/// Layer that applies [`ResponseWriterService`] which prints the http response in std format. +pub struct ResponseWriterLayer { + writer: W, +} + +impl Debug for ResponseWriterLayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponseWriterLayer") + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for ResponseWriterLayer { + fn clone(&self) -> Self { + Self { + writer: self.writer.clone(), + } + } +} + +impl ResponseWriterLayer { + /// Create a new [`ResponseWriterLayer`] with a custom [`ResponseWriter`]. + pub fn new(writer: W) -> Self { + Self { writer } + } +} + +/// A trait for writing http responses. +pub trait ResponseWriter: Send + Sync + 'static { + /// Write the http response. + fn write_response(&self, res: Response) -> impl Future + Send + '_; +} + +/// Marker struct to indicate that the response should not be printed. +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct DoNotWriteResponse; + +impl DoNotWriteResponse { + /// Create a new [`DoNotWriteResponse`] marker. + pub fn new() -> Self { + Self + } +} + +impl ResponseWriterLayer> { + /// Create a new [`ResponseWriterLayer`] that prints responses to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded(executor: &Executor, mut writer: W, mode: Option) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = unbounded_channel(); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(res) = rx.recv().await { + if let Err(err) = + write_http_response(&mut writer, res, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stdout(), mode) + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stderr(), mode) + } +} + +impl ResponseWriterLayer> { + /// Create a new [`ResponseWriterLayer`] that prints responses to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + mut writer: W, + buffer_size: usize, + mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(buffer_size); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(res) = rx.recv().await { + if let Err(err) = + write_http_response(&mut writer, res, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stdout(), buffer_size, mode) + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stderr(), buffer_size, mode) + } +} + +impl Layer for ResponseWriterLayer { + type Service = ResponseWriterService; + + fn layer(&self, inner: S) -> Self::Service { + ResponseWriterService { + inner, + writer: self.writer.clone(), + } + } +} + +/// Middleware to print Http request in std format. +/// +/// See the [module docs](super) for more details. +pub struct ResponseWriterService { + inner: S, + writer: W, +} + +impl ResponseWriterService { + /// Create a new [`ResponseWriterService`] with a custom [`ResponseWriter`]. + pub fn new(writer: W, inner: S) -> Self { + Self { inner, writer } + } +} + +impl Debug for ResponseWriterService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponseWriterService") + .field("inner", &self.inner) + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for ResponseWriterService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + writer: self.writer.clone(), + } + } +} + +impl ResponseWriterService> { + /// Create a new [`ResponseWriterService`] that prints responses to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded( + executor: &Executor, + writer: W, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = ResponseWriterLayer::writer_unbounded(executor, writer, mode); + layer.layer(inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stdout(), mode, inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stderr(), mode, inner) + } +} + +impl ResponseWriterService> { + /// Create a new [`ResponseWriterService`] that prints responses to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + writer: W, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = ResponseWriterLayer::writer(executor, writer, buffer_size, mode); + layer.layer(inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stdout(), buffer_size, mode, inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stderr(), buffer_size, mode, inner) + } +} + +impl ResponseWriterService {} + +impl Service> for ResponseWriterService +where + State: Send + Sync + 'static, + S: Service, Response = Response>, + S::Error: Into, + W: ResponseWriter, + ReqBody: Send + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, +{ + type Response = Response; + type Error = BoxError; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let do_not_print_response: Option = ctx.get().cloned(); + let resp = self.inner.serve(ctx, req).await.map_err(Into::into)?; + let resp = match do_not_print_response { + Some(_) => resp.map(Body::new), + None => { + let (parts, body) = resp.into_parts(); + let body_bytes = body + .collect() + .await + .map_err(|err| OpaqueError::from_boxed(err.into())) + .context("printer prepare: collect response body")? + .to_bytes(); + let resp: http::Response = + Response::from_parts(parts.clone(), Body::from(body_bytes.clone())); + self.writer.write_response(resp).await; + Response::from_parts(parts, Body::from(body_bytes)) + } + }; + Ok(resp) + } +} + +impl ResponseWriter for Sender { + async fn write_response(&self, res: Response) { + if let Err(err) = self.send(res).await { + tracing::error!(err = %err, "failed to send response to channel") + } + } +} + +impl ResponseWriter for UnboundedSender { + async fn write_response(&self, res: Response) { + if let Err(err) = self.send(res) { + tracing::error!(err = %err, "failed to send response to unbounded channel") + } + } +} + +impl ResponseWriter for F +where + F: Fn(Response) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + async fn write_response(&self, res: Response) { + self(res).await + } +} diff --git a/src/http/mod.rs b/src/http/mod.rs index 1329c5d6..80b915e2 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -36,6 +36,8 @@ pub mod server; pub mod client; +pub mod io; + pub mod dep { //! Dependencies for rama http modules. //! @@ -106,6 +108,14 @@ pub mod header { /// Key str constant for the `Proxy-Connection` header. pub const PROXY_CONNECTION_HEADER_KEY: &str = "proxy-connection"; + + /// Static Header Value that is can be used as `User-Agent` or `Server` header. + pub static RAMA_ID_HEADER_VALUE: HeaderValue = + HeaderValue::from_static(const_format::formatcp!( + "{}/{}", + crate::utils::info::NAME, + crate::utils::info::VERSION, + )); } pub use self::dep::http::header::HeaderMap; diff --git a/src/lib.rs b/src/lib.rs index 847558aa..e1f1b352 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ //! | πŸ—οΈ [User Agent (UA)](https://ramaproxy.org/book/intro/user_agent) | πŸ—οΈ Http Emulation (1) βΈ± πŸ—οΈ Tls Emulation (1) βΈ± βœ… [UA Parsing](crate::ua::UserAgent) | //! | πŸ—οΈ utilities | βœ… [error handling](crate::error) βΈ± βœ… [graceful shutdown](crate::utils::graceful) βΈ± πŸ—οΈ Connection Pool (1) βΈ± πŸ—οΈ IP2Loc (2) | //! | πŸ—οΈ [TUI](https://ratatui.rs/) | πŸ—οΈ traffic logger (2) βΈ± πŸ—οΈ curl export (2) βΈ± ❌ traffic intercept (3) βΈ± ❌ traffic replay (3) | -//! | πŸ—οΈ binary | πŸ—οΈ prebuilt binaries (1) βΈ± πŸ—οΈ proxy config (2) βΈ± πŸ—οΈ http client (1) βΈ± ❌ WASM Plugins (3) | +//! | βœ… binary | βœ… [prebuilt binaries](https://ramaproxy.org/book/binary/rama) βΈ± πŸ—οΈ proxy config (2) βΈ± βœ… http client (1) βΈ± ❌ WASM Plugins (3) | //! | πŸ—οΈ data scraping | πŸ—οΈ Html Processor (2) βΈ± ❌ Json Processor (3) | //! | ❌ browser | ❌ JS Engine (3) βΈ± ❌ [Web API](https://developer.mozilla.org/en-US/docs/Web/API) Emulation (3) | //! @@ -299,3 +299,5 @@ pub mod http; pub mod proxy; pub mod ua; + +pub mod cli; diff --git a/src/proxy/http/client/layer.rs b/src/proxy/http/client/layer.rs index f8050731..b5272c4e 100644 --- a/src/proxy/http/client/layer.rs +++ b/src/proxy/http/client/layer.rs @@ -234,7 +234,7 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -313,7 +313,7 @@ where let authority = match request_context.authority() { Some(authority) => authority, None => { - return Err(OpaqueError::from_display("missing http authority")); + return Err("missing http authority".into()); } }; diff --git a/src/service/builder.rs b/src/service/builder.rs index 4192d22d..41fdbe40 100644 --- a/src/service/builder.rs +++ b/src/service/builder.rs @@ -6,9 +6,7 @@ use super::{ layer_fn, AndThenLayer, Identity, LayerFn, MapErrLayer, MapRequestLayer, MapResponseLayer, MapResultLayer, MapStateLayer, Stack, ThenLayer, TraceErrLayer, }, - service_fn, - util::combinators::Either, - BoxService, Layer, Service, + service_fn, BoxService, Layer, Service, }; use std::fmt; use std::future::Future; @@ -63,19 +61,6 @@ impl ServiceBuilder { } } - /// Optionally add a new layer `T` into the [`ServiceBuilder`]. - pub fn option_layer( - self, - layer: Option, - ) -> ServiceBuilder, L>> { - let layer = if let Some(layer) = layer { - Either::A(layer) - } else { - Either::B(Identity::new()) - }; - self.layer(layer) - } - /// Add a [`Layer`] built from a function that accepts a service and returns another service. /// /// See the documentation for [`layer_fn`] for more details. diff --git a/src/service/layer/consume_err.rs b/src/service/layer/consume_err.rs new file mode 100644 index 00000000..a8592ce5 --- /dev/null +++ b/src/service/layer/consume_err.rs @@ -0,0 +1,168 @@ +use crate::{ + error::BoxError, + service::{Context, Layer, Service}, +}; +use std::{convert::Infallible, fmt}; + +use sealed::Trace; + +/// Consumes this service's error value and returns [`Infallible`]. +#[derive(Clone)] +pub struct ConsumeErr { + inner: S, + f: F, +} + +impl fmt::Debug for ConsumeErr +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConsumeErr") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +/// A [`Layer`] that produces [`ConsumeErr`] services. +/// +/// [`Layer`]: crate::service::Layer +#[derive(Clone)] +pub struct ConsumeErrLayer { + f: F, +} + +impl fmt::Debug for ConsumeErrLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConsumeErrLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Default for ConsumeErrLayer { + fn default() -> Self { + Self::trace(tracing::Level::ERROR) + } +} + +impl ConsumeErr { + /// Creates a new [`ConsumeErr`] service. + pub fn new(inner: S, f: F) -> Self { + ConsumeErr { f, inner } + } +} + +impl ConsumeErr { + /// Trace the error passed to this [`ConsumeErr`] service for the provided trace level. + pub fn trace(inner: S, level: tracing::Level) -> Self { + Self::new(inner, Trace(level)) + } +} + +impl Service for ConsumeErr +where + S: Service, + S::Response: Default, + F: FnOnce(S::Error) + Clone + Send + Sync + 'static, + State: Send + Sync + 'static, + Request: Send + 'static, +{ + type Response = S::Response; + type Error = Infallible; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + match self.inner.serve(ctx, req).await { + Ok(resp) => Ok(resp), + Err(err) => { + (self.f.clone())(err); + Ok(S::Response::default()) + } + } + } +} + +impl Service for ConsumeErr +where + S: Service, + S::Response: Default, + S::Error: Into, + State: Send + Sync + 'static, + Request: Send + 'static, +{ + type Response = S::Response; + type Error = Infallible; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + match self.inner.serve(ctx, req).await { + Ok(resp) => Ok(resp), + Err(err) => { + const MESSAGE: &str = "unhandled service error consumed"; + match self.f.0 { + tracing::Level::TRACE => { + tracing::trace!(error = err.into(), MESSAGE); + } + tracing::Level::DEBUG => { + tracing::debug!(error = err.into(), MESSAGE); + } + tracing::Level::INFO => { + tracing::info!(error = err.into(), MESSAGE); + } + tracing::Level::WARN => { + tracing::warn!(error = err.into(), MESSAGE); + } + tracing::Level::ERROR => { + tracing::error!(error = err.into(), MESSAGE); + } + } + Ok(S::Response::default()) + } + } + } +} + +impl ConsumeErrLayer { + /// Creates a new [`ConsumeErrLayer`]. + pub fn new(f: F) -> Self { + ConsumeErrLayer { f } + } +} + +impl ConsumeErrLayer { + /// Creates a new [`ConsumeErrLayer`] to trace the consumed error. + pub fn trace(level: tracing::Level) -> Self { + Self::new(Trace(level)) + } +} + +impl Layer for ConsumeErrLayer +where + F: Clone, +{ + type Service = ConsumeErr; + + fn layer(&self, inner: S) -> Self::Service { + ConsumeErr { + f: self.f.clone(), + inner, + } + } +} + +mod sealed { + #[derive(Debug, Clone)] + /// A sealed new type to prevent downstream users from + /// passing the trace level directly to the [`ConsumeErr::new`] method. + /// + /// [`ConsumeErr::new`]: crate::service::layer::ConsumeErr::new + pub struct Trace(pub tracing::Level); +} diff --git a/src/service/layer/http/body_limit.rs b/src/service/layer/http/body_limit.rs index cb7ef28e..222c3a98 100644 --- a/src/service/layer/http/body_limit.rs +++ b/src/service/layer/http/body_limit.rs @@ -1,6 +1,7 @@ use bytes::Bytes; use crate::{ + error::BoxError, http::{dep::http_body::Body as HttpBody, Body, BodyLimit, IntoResponse, Request, Response}, service::{Context, Layer, Service}, }; @@ -88,7 +89,7 @@ where S::Response: IntoResponse, State: Send + Sync + 'static, ReqBody: HttpBody + Send + Sync + 'static, - ReqBody::Error: std::error::Error + Send + Sync + 'static, + ReqBody::Error: Into, { type Response = Response; type Error = S::Error; diff --git a/src/service/layer/limit/layer.rs b/src/service/layer/limit/layer.rs index e1df9f85..abf6cec0 100644 --- a/src/service/layer/limit/layer.rs +++ b/src/service/layer/limit/layer.rs @@ -1,4 +1,4 @@ -use super::Limit; +use super::{policy::UnlimitedPolicy, Limit}; use crate::service::Layer; /// Limit requests based on a [`Policy`]. @@ -16,6 +16,15 @@ impl

LimitLayer

{ } } +impl LimitLayer { + /// Creates a new [`LimitLayer`] with an unlimited policy. + /// + /// Meaning that all requests are allowed to proceed. + pub fn unlimited() -> Self { + Self::new(UnlimitedPolicy::default()) + } +} + impl

Clone for LimitLayer

where P: Clone, diff --git a/src/service/layer/limit/mod.rs b/src/service/layer/limit/mod.rs index 37581931..abba5693 100644 --- a/src/service/layer/limit/mod.rs +++ b/src/service/layer/limit/mod.rs @@ -6,6 +6,7 @@ use crate::error::BoxError; use crate::service::{Context, Service}; pub mod policy; +use policy::UnlimitedPolicy; pub use policy::{Policy, PolicyOutput}; mod layer; @@ -29,6 +30,18 @@ impl Limit { } } +impl Limit { + /// Creates a new [`Limit`] with an unlimited policy. + /// + /// Meaning that all requests are allowed to proceed. + pub fn unlimited(inner: T) -> Self { + Limit { + inner, + policy: UnlimitedPolicy, + } + } +} + impl Clone for Limit where T: Clone, diff --git a/src/service/layer/limit/policy/matcher.rs b/src/service/layer/limit/policy/matcher.rs index 7c11821e..3fccbb85 100644 --- a/src/service/layer/limit/policy/matcher.rs +++ b/src/service/layer/limit/policy/matcher.rs @@ -88,7 +88,6 @@ mod tests { use crate::service::{ context::Extensions, layer::limit::policy::{ConcurrentCounter, ConcurrentPolicy}, - matcher::Always, }; use super::*; @@ -109,7 +108,7 @@ mod tests { #[tokio::test] async fn matcher_policy_empty() { - let policy = Vec::<(Always, ConcurrentPolicy<(), ConcurrentCounter>)>::new(); + let policy = Vec::<(bool, ConcurrentPolicy<(), ConcurrentCounter>)>::new(); for i in 0..10 { assert_ready(policy.check(Context::default(), i).await); @@ -120,7 +119,7 @@ mod tests { async fn matcher_policy_always() { let concurrency_policy = ConcurrentPolicy::max(2); - let policy = Arc::new(vec![(Always, concurrency_policy)]); + let policy = Arc::new(vec![(true, concurrency_policy)]); let guard_1 = assert_ready(policy.check(Context::default(), ()).await); let guard_2 = assert_ready(policy.check(Context::default(), ()).await); diff --git a/src/service/layer/limit/policy/mod.rs b/src/service/layer/limit/policy/mod.rs index 6a16680b..5d372a83 100644 --- a/src/service/layer/limit/policy/mod.rs +++ b/src/service/layer/limit/policy/mod.rs @@ -14,7 +14,7 @@ //! The first matching policy is used. //! If no policy matches, the request is allowed to proceed as well. //! If you want to enforce a default policy, you can add a policy with a [`Matcher`] that always matches, -//! such as [`matcher::Always`]. +//! such as the bool `true`. //! //! Note that the [`Matcher`]s will not receive the mutable [`Extensions`], //! as polices are not intended to keep track of what is matched on. @@ -24,13 +24,11 @@ //! See the [`http_rate_limit.rs`] example for a use case. //! //! [`Matcher`]: crate::service::Matcher -//! [`matcher::Always`]: crate::service::matcher::Always //! [`Extensions`]: crate::service::context::Extensions //! [`http_listener_hello.rs`]: https://github.com/plabayo/rama/blob/main/examples/http_rate_limit.rs -use std::sync::Arc; - use crate::service::Context; +use std::{convert::Infallible, sync::Arc}; mod concurrent; #[doc(inline)] @@ -168,3 +166,36 @@ where self.as_ref().check(ctx, request).await } } + +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +/// An unlimited policy that allows all requests to proceed. +pub struct UnlimitedPolicy; + +impl UnlimitedPolicy { + /// Create a new [`UnlimitedPolicy`]. + pub fn new() -> Self { + UnlimitedPolicy + } +} + +impl Policy for UnlimitedPolicy +where + State: Send + Sync + 'static, + Request: Send + 'static, +{ + type Guard = (); + type Error = Infallible; + + async fn check( + &self, + ctx: Context, + request: Request, + ) -> PolicyResult { + PolicyResult { + ctx, + request, + output: PolicyOutput::Ready(()), + } + } +} diff --git a/src/service/layer/map_err.rs b/src/service/layer/map_err.rs index 8f69414b..ae6c6028 100644 --- a/src/service/layer/map_err.rs +++ b/src/service/layer/map_err.rs @@ -28,11 +28,19 @@ where /// A [`Layer`] that produces [`MapErr`] services. /// /// [`Layer`]: crate::service::Layer -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct MapErrLayer { f: F, } +impl std::fmt::Debug for MapErrLayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MapErrLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + impl MapErr { /// Creates a new [`MapErr`] service. pub fn new(inner: S, f: F) -> Self { diff --git a/src/service/layer/map_request.rs b/src/service/layer/map_request.rs index 62e5c120..729502fd 100644 --- a/src/service/layer/map_request.rs +++ b/src/service/layer/map_request.rs @@ -52,11 +52,19 @@ where /// A [`Layer`] that produces [`MapRequest`] services. /// /// [`Layer`]: crate::service::Layer -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct MapRequestLayer { f: F, } +impl fmt::Debug for MapRequestLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapRequestLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + impl MapRequestLayer { /// Creates a new [`MapRequestLayer`]. pub fn new(f: F) -> Self { diff --git a/src/service/layer/map_result.rs b/src/service/layer/map_result.rs index 923bf359..5681de08 100644 --- a/src/service/layer/map_result.rs +++ b/src/service/layer/map_result.rs @@ -59,11 +59,18 @@ where /// A [`Layer`] that produces a [`MapResult`] service. /// /// [`Layer`]: crate::service::Layer -#[derive(Debug)] pub struct MapResultLayer { f: F, } +impl fmt::Debug for MapResultLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapResultLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + impl Clone for MapResultLayer where F: Clone, diff --git a/src/service/layer/map_state.rs b/src/service/layer/map_state.rs index 18e45b62..a8badefa 100644 --- a/src/service/layer/map_state.rs +++ b/src/service/layer/map_state.rs @@ -8,9 +8,12 @@ pub struct MapState { f: F, } -impl std::fmt::Debug for MapState { +impl std::fmt::Debug for MapState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MapState").finish() + f.debug_struct("MapState") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() } } @@ -64,7 +67,9 @@ pub struct MapStateLayer { impl std::fmt::Debug for MapStateLayer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MapStateLayer").finish() + f.debug_struct("MapStateLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() } } diff --git a/src/service/layer/mod.rs b/src/service/layer/mod.rs index 2d780c9b..0db23b91 100644 --- a/src/service/layer/mod.rs +++ b/src/service/layer/mod.rs @@ -13,6 +13,20 @@ pub trait Layer { fn layer(&self, inner: S) -> Self::Service; } +impl Layer for Option +where + L: Layer, +{ + type Service = Either; + + fn layer(&self, inner: S) -> Self::Service { + match self { + Some(layer) => Either::A(layer.layer(inner)), + None => Either::B(inner), + } + } +} + mod into_error; #[doc(inline)] pub use into_error::{LayerErrorFn, LayerErrorStatic, MakeLayerError}; @@ -57,6 +71,10 @@ mod map_err; #[doc(inline)] pub use map_err::{MapErr, MapErrLayer}; +mod consume_err; +#[doc(inline)] +pub use consume_err::{ConsumeErr, ConsumeErrLayer}; + mod trace_err; #[doc(inline)] pub use trace_err::{TraceErr, TraceErrLayer}; @@ -74,4 +92,6 @@ pub use limit::{Limit, LimitLayer}; pub mod add_extension; pub use add_extension::{AddExtension, AddExtensionLayer}; +use super::util::combinators::Either; + pub mod http; diff --git a/src/service/matcher/always.rs b/src/service/matcher/always.rs deleted file mode 100644 index 66fefe9f..00000000 --- a/src/service/matcher/always.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::service::{context::Extensions, Context}; - -use super::Matcher; - -#[derive(Debug, Default)] -#[non_exhaustive] -/// Matches any request. -pub struct Always; - -impl Always { - /// Create a new instance of `Always`. - pub fn new() -> Self { - Self - } -} - -impl Matcher for Always { - fn matches(&self, _: Option<&mut Extensions>, _: &Context, _: &Request) -> bool { - true - } -} diff --git a/src/service/matcher/mod.rs b/src/service/matcher/mod.rs index 19f0000b..f3edbf19 100644 --- a/src/service/matcher/mod.rs +++ b/src/service/matcher/mod.rs @@ -5,7 +5,7 @@ //! //! - Examples of this are iterator "reducers" as made available via [`IteratorMatcherExt`], //! as well as optional [`Matcher::or`] and [`Matcher::and`] trait methods. -//! - These all serve as building blocks together with [`And`], [`Or`], [`Not`] and [`Always`] +//! - These all serve as building blocks together with [`And`], [`Or`], [`Not`] and a bool //! to combine and transform any kind of [`Matcher`]. //! - And finally there is [`MatchFn`], easily created using [`match_fn`] to create a [`Matcher`] //! from any compatible [`Fn`]. @@ -24,10 +24,6 @@ use super::{context::Extensions, Context}; -mod always; -#[doc(inline)] -pub use always::Always; - mod op_or; #[doc(inline)] pub use op_or::{or, Or}; @@ -116,5 +112,11 @@ where } } +impl Matcher for bool { + fn matches(&self, _: Option<&mut Extensions>, _: &Context, _: &Request) -> bool { + *self + } +} + #[cfg(test)] mod test; diff --git a/src/service/matcher/test.rs b/src/service/matcher/test.rs index 62370c97..32c209f3 100644 --- a/src/service/matcher/test.rs +++ b/src/service/matcher/test.rs @@ -1,28 +1,16 @@ use super::*; -#[test] -fn test_always() { - assert!(Always.matches(None, &Context::default(), &())); - assert!(Always.matches(None, &Context::default(), &0)); - assert!(Always.matches(None, &Context::default(), &false)); - assert!(Always.matches(None, &Context::default(), &"foo")); -} - #[test] fn test_not() { - assert!(!Not::new(Always).matches(None, &Context::default(), &())); + assert!(!Not::new(true).matches(None, &Context::default(), &())); } #[test] fn test_not_builder() { - assert!(!Always::new().not().matches(None, &Context::default(), &())); - assert!(!Always::new().not().matches(None, &Context::default(), &0)); - assert!(!Always::new() - .not() - .matches(None, &Context::default(), &false)); - assert!(!Always::new() - .not() - .matches(None, &Context::default(), &"foo")); + assert!(!true.not().matches(None, &Context::default(), &())); + assert!(!true.not().matches(None, &Context::default(), &0)); + assert!(!true.not().matches(None, &Context::default(), &false)); + assert!(!true.not().matches(None, &Context::default(), &"foo")); } mod marker { diff --git a/src/service/util/combinators/either.rs b/src/service/util/combinators/either.rs index 7cf01854..17550b05 100644 --- a/src/service/util/combinators/either.rs +++ b/src/service/util/combinators/either.rs @@ -1,7 +1,12 @@ +use crate::error::BoxError; use crate::http::{self, layer::retry}; use crate::service::{ context::Extensions, layer::limit, matcher::Matcher, Context, Layer, Service, }; +use std::io::IoSlice; +use std::pin::Pin; +use std::task::{Context as TaskContext, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ReadBuf, Result as IoResult}; macro_rules! create_either { ($id:ident, $($param:ident),+ $(,)?) => { @@ -44,21 +49,23 @@ macro_rules! create_either { } } - impl<$($param),+, State, Request, Response, Error> Service for $id<$($param),+> + impl<$($param),+, State, Request, Response> Service for $id<$($param),+> where - $($param: Service),+, + $( + $param: Service, + $param::Error: Into, + )+ Request: Send + 'static, State: Send + Sync + 'static, Response: Send + 'static, - Error: Send + Sync + 'static, { type Response = Response; - type Error = Error; + type Error = BoxError; async fn serve(&self, ctx: Context, req: Request) -> Result { match self { $( - $id::$param(s) => s.serve(ctx, req).await, + $id::$param(s) => s.serve(ctx, req).await.map_err(Into::into), )+ } } @@ -99,15 +106,17 @@ macro_rules! create_either { } } - impl<$($param),+, State, Request, Error> limit::Policy for $id<$($param),+> + impl<$($param),+, State, Request> limit::Policy for $id<$($param),+> where - $($param: limit::Policy),+, + $( + $param: limit::Policy, + $param::Error: Into, + )+ Request: Send + 'static, State: Send + Sync + 'static, - Error: Send + Sync + 'static, { type Guard = $id<$($param::Guard),+>; - type Error = Error; + type Error = BoxError; async fn check( &self, @@ -127,7 +136,7 @@ macro_rules! create_either { limit::policy::PolicyOutput::Abort(err) => limit::policy::PolicyResult { ctx: result.ctx, request: result.request, - output: limit::policy::PolicyOutput::Abort(err), + output: limit::policy::PolicyOutput::Abort(err.into()), }, limit::policy::PolicyOutput::Retry => limit::policy::PolicyResult { ctx: result.ctx, @@ -173,6 +182,76 @@ macro_rules! create_either { } } } + + impl<$($param),+> AsyncRead for $id<$($param),+> + where + $($param: AsyncRead + Unpin),+, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + $( + $id::$param(reader) => Pin::new(reader).poll_read(cx, buf), + )+ + } + } + } + + impl<$($param),+> AsyncWrite for $id<$($param),+> + where + $($param: AsyncWrite + Unpin),+, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_write(cx, buf), + )+ + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_flush(cx), + )+ + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_shutdown(cx), + )+ + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_write_vectored(cx, bufs), + )+ + } + } + + fn is_write_vectored(&self) -> bool { + match self { + $( + $id::$param(reader) => reader.is_write_vectored(), + )+ + } + } + } }; } diff --git a/src/tls/rustls/client/http.rs b/src/tls/rustls/client/http.rs index cd31eb26..288e3e4b 100644 --- a/src/tls/rustls/client/http.rs +++ b/src/tls/rustls/client/http.rs @@ -1,4 +1,4 @@ -use crate::error::{BoxError, ErrorExt, OpaqueError}; +use crate::error::{BoxError, ErrorExt}; use crate::http::client::{ClientConnection, EstablishedClientConnection}; use crate::http::{Request, RequestContext}; use crate::service::{Context, Service}; @@ -177,7 +177,7 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection, Body, State>; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -185,10 +185,7 @@ where req: Request, ) -> Result { let EstablishedClientConnection { mut ctx, req, conn } = - self.inner - .serve(ctx, req) - .await - .map_err(|err| OpaqueError::from_boxed(err.into()))?; + self.inner.serve(ctx, req).await.map_err(Into::into)?; let (addr, stream) = conn.into_parts(); let request_ctx = ctx.get_or_insert_with(|| RequestContext::new(&req)); @@ -209,11 +206,11 @@ where let host = match request_ctx.host.as_deref() { Some(host) => host, None => { - return Err(OpaqueError::from_display("missing http host")); + return Err("missing http host".into()); } }; let domain = pki_types::ServerName::try_from(host) - .map_err(|err| OpaqueError::from_std(err).context("invalid DNS Hostname (tls)"))? + .map_err(|err| err.context("invalid DNS Hostname (tls)"))? .to_owned(); let stream = self.handshake(domain, stream).await?; @@ -240,7 +237,7 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection, Body, State>; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -251,11 +248,7 @@ where mut ctx, mut req, conn, - } = self - .inner - .serve(ctx, req) - .await - .map_err(|err| OpaqueError::from_boxed(err.into()))?; + } = self.inner.serve(ctx, req).await.map_err(Into::into)?; let (addr, stream) = conn.into_parts(); @@ -276,11 +269,11 @@ where let host = match request_ctx.host.as_deref() { Some(host) => host, None => { - return Err(OpaqueError::from_display("missing http host")); + return Err("missing http host".into()); } }; let domain = pki_types::ServerName::try_from(host) - .map_err(|err| OpaqueError::from_std(err).context("invalid DNS Hostname (tls)"))? + .map_err(|err| err.context("invalid DNS Hostname (tls)"))? .to_owned(); let stream = self.handshake(domain, stream).await?; @@ -302,27 +295,21 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection, Body, State>; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, ctx: Context, req: Request, ) -> Result { - let EstablishedClientConnection { ctx, req, conn } = self - .inner - .serve(ctx, req) - .await - .map_err(|err| OpaqueError::from_boxed(err.into()))?; + let EstablishedClientConnection { ctx, req, conn } = + self.inner.serve(ctx, req).await.map_err(Into::into)?; let (addr, stream) = conn.into_parts(); let domain = match ctx.get::() { Some(tunnel) => pki_types::ServerName::try_from(tunnel.server_name.as_str()) - .map_err(|err| { - OpaqueError::from_std(err) - .context("invalid DNS Hostname (tls) for https tunnel") - })? + .map_err(|err| err.context("invalid DNS Hostname (tls) for https tunnel"))? .to_owned(), None => { return Ok(EstablishedClientConnection { @@ -358,7 +345,7 @@ impl HttpsConnector { &self, server_name: ServerName<'static>, stream: T, - ) -> Result, OpaqueError> + ) -> Result, BoxError> where T: Stream + Unpin, { @@ -371,7 +358,7 @@ impl HttpsConnector { connector .connect(server_name, stream) .await - .map_err(OpaqueError::from_std) + .map_err(Into::into) } } diff --git a/src/ua/layer.rs b/src/ua/layer.rs index 873c62e0..6d59b71f 100644 --- a/src/ua/layer.rs +++ b/src/ua/layer.rs @@ -177,6 +177,7 @@ mod tests { use super::*; use crate::http::client::HttpClientExt; use crate::http::headers; + use crate::http::layer::required_header::AddRequiredRequestHeadersLayer; use crate::http::{IntoResponse, StatusCode}; use crate::ua::{PlatformKind, UserAgentKind}; use crate::{ @@ -206,6 +207,7 @@ mod tests { } let service = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::default()) .layer(UserAgentClassifierLayer::new()) .service_fn(handle); diff --git a/src/utils/graceful.rs b/src/utils/graceful.rs index 0517c6ad..c3f1571e 100644 --- a/src/utils/graceful.rs +++ b/src/utils/graceful.rs @@ -1,3 +1,3 @@ //! Shutdown management for graceful shutdown of async-first applications. -pub use tokio_graceful::{Shutdown, ShutdownGuard, WeakShutdownGuard}; +pub use tokio_graceful::{default_signal, Shutdown, ShutdownGuard, WeakShutdownGuard}; diff --git a/src/utils/username.rs b/src/utils/username.rs index 913ddd6f..8cf2c99c 100644 --- a/src/utils/username.rs +++ b/src/utils/username.rs @@ -51,7 +51,7 @@ //! assert!(filter.mobile.is_none()); //! ``` -use crate::error::OpaqueError; +use crate::error::{BoxError, OpaqueError}; use crate::service::context::Extensions; use std::{convert::Infallible, fmt}; @@ -68,7 +68,7 @@ pub fn parse_username

( ) -> Result where P: UsernameLabelParser, - P::Error: std::error::Error + Send + Sync + 'static, + P::Error: Into, { let username_ref = username_ref.as_ref(); let mut label_it = username_ref.split(separator); @@ -93,7 +93,9 @@ where } } - parser.build(ext).map_err(OpaqueError::from_std)?; + parser + .build(ext) + .map_err(|err| OpaqueError::from_boxed(err.into()))?; Ok(username.to_owned()) } @@ -124,7 +126,7 @@ pub enum UsernameLabelState { /// as it is what is used to create the parser instances for one-time usage. pub trait UsernameLabelParser: Default + Send + Sync + 'static { /// Error which can occur during the building phase. - type Error: std::error::Error + Send + Sync + 'static; + type Error: Into; /// Interpret the label and return whether or not the label was recognised and valid. /// @@ -166,7 +168,7 @@ macro_rules! username_label_parser_tuple_impl { where $( $T: UsernameLabelParser, - $T::Error: std::error::Error + Send + Sync + 'static, + $T::Error: Into, )+ { type Error = OpaqueError; @@ -185,7 +187,7 @@ macro_rules! username_label_parser_tuple_impl { fn build(self, ext: &mut Extensions) -> Result<(), Self::Error> { let ($($T,)+) = self; $( - $T.build(ext).map_err(OpaqueError::from_std)?; + $T.build(ext).map_err(|err| OpaqueError::from_boxed(err.into()))?; )+ Ok(()) } @@ -202,7 +204,7 @@ macro_rules! username_label_parser_tuple_exclusive_labels_impl { where $( $T: UsernameLabelParser, - $T::Error: std::error::Error + Send + Sync + 'static, + $T::Error: Into, )+ { type Error = OpaqueError; @@ -220,7 +222,7 @@ macro_rules! username_label_parser_tuple_exclusive_labels_impl { fn build(self, ext: &mut Extensions) -> Result<(), Self::Error> { let ($($T,)+) = self.0; $( - $T.build(ext).map_err(OpaqueError::from_std)?; + $T.build(ext).map_err(|err| OpaqueError::from_boxed(err.into()))?; )+ Ok(()) } diff --git a/tests/cli.rs b/tests/cli.rs new file mode 100644 index 00000000..66accc3a --- /dev/null +++ b/tests/cli.rs @@ -0,0 +1 @@ +mod cli_tests; diff --git a/tests/cli_tests/help.rs b/tests/cli_tests/help.rs new file mode 100644 index 00000000..f9caa004 --- /dev/null +++ b/tests/cli_tests/help.rs @@ -0,0 +1,40 @@ +use super::utils; + +#[tokio::test] +#[ignore] +async fn test_help() { + let lines = utils::RamaService::run(vec!["help"]).unwrap(); + assert!(lines.contains("rama cli to move and transform network packets")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Commands:")); + assert!(lines.contains("Options:")); +} + +#[tokio::test] +#[ignore] +async fn test_help_ip() { + let lines = utils::RamaService::run(vec!["help", "ip"]).unwrap(); + assert!(lines.contains("rama ip service")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Options:")); +} + +#[tokio::test] +#[ignore] +async fn test_help_echo() { + let lines = utils::RamaService::run(vec!["help", "echo"]).unwrap(); + assert!(lines.contains("rama echo service")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Options:")); +} + +#[tokio::test] +#[ignore] +async fn test_help_http() { + let lines = utils::RamaService::run(vec!["help", "http"]).unwrap(); + assert!(lines.contains("rama http client")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Arguments:")); + assert!(lines.contains("rama http :3000")); + assert!(lines.contains("Options:")); +} diff --git a/tests/cli_tests/http_echo.rs b/tests/cli_tests/http_echo.rs new file mode 100644 index 00000000..61c76e52 --- /dev/null +++ b/tests/cli_tests/http_echo.rs @@ -0,0 +1,24 @@ +use super::utils; + +#[tokio::test] +#[ignore] +async fn test_http_echo() { + let _guard = utils::RamaService::echo(63101); + + let lines = utils::RamaService::http(vec!["http://127.0.0.1:63101"]).unwrap(); + assert!(lines.contains("HTTP/1.1 200 OK"), "lines: {:?}", lines); + + let lines = + utils::RamaService::http(vec!["http://127.0.0.1:63101", "foo:bar", "a=4", "q==1"]).unwrap(); + assert!(lines.contains("HTTP/1.1 200 OK"), "lines: {:?}", lines); + assert!(lines.contains(r##""method":"POST""##), "lines: {:?}", lines); + assert!(lines.contains(r##""foo","bar""##), "lines: {:?}", lines); + assert!( + lines.contains(r##""content-type","application/json""##), + "lines: {:?}", + lines + ); + assert!(lines.contains(r##""a":"4""##), "lines: {:?}", lines); + assert!(lines.contains(r##""path":"/""##), "lines: {:?}", lines); + assert!(lines.contains(r##""query":"q=1""##), "lines: {:?}", lines); +} diff --git a/tests/cli_tests/http_ip.rs b/tests/cli_tests/http_ip.rs new file mode 100644 index 00000000..be521753 --- /dev/null +++ b/tests/cli_tests/http_ip.rs @@ -0,0 +1,11 @@ +use super::utils; + +#[tokio::test] +#[ignore] +async fn test_http_ip() { + let _guard = utils::RamaService::ip(63100); + + let lines = utils::RamaService::http(vec!["http://127.0.0.1:63100"]).unwrap(); + assert!(lines.contains("HTTP/1.1 200 OK")); + assert!(lines.contains("127.0.0.1:")); +} diff --git a/tests/cli_tests/mod.rs b/tests/cli_tests/mod.rs new file mode 100644 index 00000000..3097d1f9 --- /dev/null +++ b/tests/cli_tests/mod.rs @@ -0,0 +1,5 @@ +mod utils; + +mod help; +mod http_echo; +mod http_ip; diff --git a/tests/cli_tests/utils/mod.rs b/tests/cli_tests/utils/mod.rs new file mode 100644 index 00000000..fbd660b2 --- /dev/null +++ b/tests/cli_tests/utils/mod.rs @@ -0,0 +1,120 @@ +#![allow(dead_code)] + +use std::{ + io::{BufRead, BufReader}, + process::Child, + thread, +}; + +#[derive(Debug)] +/// A wrapper around a rama service process. +pub struct RamaService { + process: Child, +} + +impl RamaService { + /// Start the rama Ip service with the given port. + pub fn ip(port: u16) -> Self { + let mut process = escargot::CargoBuild::new() + .package("rama-cli") + .bin("rama") + .target_dir("./target/") + .run() + .unwrap() + .command() + .stdout(std::process::Stdio::piped()) + .arg("ip") + .arg("-p") + .arg(port.to_string()) + .spawn() + .unwrap(); + + let stdout = process.stdout.take().unwrap(); + let mut stdout = BufReader::new(stdout).lines(); + + for line in &mut stdout { + let line = line.unwrap(); + if line.contains("ip service ready") { + break; + } + } + + thread::spawn(move || { + for line in stdout { + let line = line.unwrap(); + eprintln!("rama ip >> {}", line); + } + }); + + Self { process } + } + + /// Start the rama echo service with the given port. + pub fn echo(port: u16) -> Self { + let mut process = escargot::CargoBuild::new() + .package("rama-cli") + .bin("rama") + .target_dir("./target/") + .run() + .unwrap() + .command() + .stdout(std::process::Stdio::piped()) + .arg("echo") + .arg("-p") + .arg(port.to_string()) + .spawn() + .unwrap(); + + let stdout = process.stdout.take().unwrap(); + let mut stdout = BufReader::new(stdout).lines(); + + for line in &mut stdout { + let line = line.unwrap(); + if line.contains("echo service ready") { + break; + } + } + + thread::spawn(move || { + for line in stdout { + let line = line.unwrap(); + println!("rama echo >> {}", line); + } + }); + + Self { process } + } + + /// Run any rama cmd + pub fn run(args: Vec<&'static str>) -> Result> { + let child = escargot::CargoBuild::new() + .package("rama-cli") + .bin("rama") + .target_dir("./target/") + .run() + .unwrap() + .command() + .stdout(std::process::Stdio::piped()) + .args(args) + .spawn() + .unwrap(); + + let output = child.wait_with_output()?; + assert!(output.status.success()); + let output = String::from_utf8(output.stdout)?; + Ok(output) + } + + /// Run the http command + pub fn http(input_args: Vec<&'static str>) -> Result> { + let mut args = vec!["http", "--debug", "-v", "--all", "-F"]; + args.extend(input_args); + Self::run(args) + } +} + +impl Drop for RamaService { + fn drop(&mut self) { + self.process.kill().expect("kill server process"); + } +} diff --git a/tests/example_tests/utils/mod.rs b/tests/example_tests/utils/mod.rs index e5b4ba61..1df7c38f 100644 --- a/tests/example_tests/utils/mod.rs +++ b/tests/example_tests/utils/mod.rs @@ -7,6 +7,7 @@ use rama::{ layer::{ decompression::DecompressionLayer, follow_redirect::FollowRedirectLayer, + required_header::AddRequiredRequestHeadersLayer, retry::{ManagedPolicy, RetryLayer}, trace::TraceLayer, }, @@ -95,6 +96,7 @@ where .unwrap(), ), )) + .layer(AddRequiredRequestHeadersLayer::default()) .service(HttpClient::new( ServiceBuilder::new() .layer(HttpsConnectorLayer::auto())