Skip to content

Commit

Permalink
feat: introduce HeaderName instead of raw strings
Browse files Browse the repository at this point in the history
  • Loading branch information
augustoccesar committed Oct 17, 2023
1 parent b05cd77 commit 3b20311
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 46 deletions.
47 changes: 37 additions & 10 deletions linkup/src/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@ use unicase::UniCase;

pub struct HeaderMap(HashMap<UniCase<String>, String>);

pub enum HeaderName {
ForwardedHost,
TraceParent,
TraceState,
LinkupDestination,
Referer,
Origin,
}

impl From<HeaderName> for UniCase<String> {
fn from(value: HeaderName) -> Self {
match value {
HeaderName::ForwardedHost => "x-forwarded-host".into(),
HeaderName::TraceParent => "traceparent".into(),
HeaderName::TraceState => "tracestate".into(),
HeaderName::LinkupDestination => "linkup-destination".into(),
HeaderName::Referer => "referer".into(),
HeaderName::Origin => "origin".into(),
}
}
}

impl IntoIterator for &HeaderMap {
type Item = (UniCase<String>, String);
type IntoIter = std::collections::hash_map::IntoIter<UniCase<String>, String>;
Expand All @@ -30,26 +52,31 @@ impl HeaderMap {
Self(HashMap::new())
}

pub fn contains_key(&self, key: impl ToString) -> bool {
self.0.contains_key(&UniCase::new(key.to_string()))
pub fn contains_key(&self, key: impl Into<UniCase<String>>) -> bool {
self.0.contains_key(&key.into())
}

pub fn get(&self, key: impl ToString) -> Option<&str> {
self.0
.get(&UniCase::new(key.to_string()))
.map(String::as_ref)
pub fn get(&self, key: impl Into<UniCase<String>>) -> Option<&str> {
self.0.get(&key.into()).map(String::as_ref)
}

pub fn get_or_default<'a>(&'a self, key: impl ToString, default: &'a str) -> &'a str {
pub fn get_or_default<'a>(
&'a self,
key: impl Into<UniCase<String>>,
default: &'a str,
) -> &'a str {
match self.get(key) {
Some(value) => value,
None => default,
}
}

pub fn insert(&mut self, key: impl ToString, value: impl ToString) -> Option<String> {
self.0
.insert(UniCase::new(key.to_string()), value.to_string())
pub fn insert(
&mut self,
key: impl Into<UniCase<String>>,
value: impl ToString,
) -> Option<String> {
self.0.insert(key.into(), value.to_string())
}

pub fn extend(&mut self, iter: &HeaderMap) {
Expand Down
76 changes: 43 additions & 33 deletions linkup/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod session;
mod session_allocator;

use async_trait::async_trait;
use headers::HeaderName;
use rand::Rng;
use thiserror::Error;

Expand Down Expand Up @@ -49,7 +50,7 @@ pub fn get_additional_headers(
) -> HeaderMap {
let mut additional_headers = HeaderMap::new();

if !headers.contains_key("traceparent") {
if !headers.contains_key(HeaderName::ForwardedHost) {
let mut rng = rand::thread_rng();
let trace: [u8; 16] = rng.gen();
let parent: [u8; 8] = rng.gen();
Expand All @@ -62,29 +63,32 @@ pub fn get_additional_headers(
let flags_hex = hex::encode(flags);

let traceparent = format!("{}-{}-{}-{}", version_hex, trace_hex, parent_hex, flags_hex);
additional_headers.insert("traceparent", traceparent);
additional_headers.insert(HeaderName::TraceParent, traceparent);
}

let tracestate = headers.get("tracestate");
let tracestate = headers.get(HeaderName::TraceState);
let linkup_session = format!("linkup-session={}", session_name,);
match tracestate {
Some(ts) if !ts.contains(&linkup_session) => {
let new_tracestate = format!("{},{}", ts, linkup_session);
additional_headers.insert("tracestate", new_tracestate);
additional_headers.insert(HeaderName::TraceState, new_tracestate);
}
None => {
let new_tracestate = linkup_session;
additional_headers.insert("tracestate", new_tracestate);
additional_headers.insert(HeaderName::TraceState, new_tracestate);
}
_ => {}
}

if !headers.contains_key("linkup-destination") {
additional_headers.insert("linkup-destination", &target_service.name);
if !headers.contains_key(HeaderName::LinkupDestination) {
additional_headers.insert(HeaderName::LinkupDestination, &target_service.name);
}

if !headers.contains_key("x-forwarded-host") {
additional_headers.insert("x-forwarded-host", get_target_domain(url, session_name));
if !headers.contains_key(HeaderName::ForwardedHost) {
additional_headers.insert(
HeaderName::ForwardedHost,
get_target_domain(url, session_name),
);
}

additional_headers.insert("host", Url::parse(&target_service.url).unwrap());
Expand Down Expand Up @@ -128,7 +132,7 @@ pub fn get_target_service(

// If there was a destination created in a previous linkup, we don't want to
// re-do path rewrites, so we use the destination service.
if let Some(destination_service) = headers.get("linkup-destination") {
if let Some(destination_service) = headers.get(HeaderName::LinkupDestination) {
if let Some(service) = config.services.get(destination_service) {
let target = redirect(target.clone(), &service.origin, Some(path.to_string()));
return Some(TargetService {
Expand All @@ -142,19 +146,19 @@ pub fn get_target_service(

// Forwarded hosts persist over the tunnel
let forwarded_host_target = config.domains.get(&get_target_domain(
headers.get_or_default("X-Forwarded-Host", "does-not-exist"),
headers.get_or_default(HeaderName::ForwardedHost, "does-not-exist"),
session_name,
));

// This is more for e2e tests to work
let referer_target = config.domains.get(&get_target_domain(
headers.get_or_default("referer", "does-not-exist"),
headers.get_or_default(HeaderName::Referer, "does-not-exist"),
session_name,
));

// This one is for redirects, where the referer doesn't exist
let origin_target = config.domains.get(&get_target_domain(
headers.get_or_default("origin", "does-not-exist"),
headers.get_or_default(HeaderName::Origin, "does-not-exist"),
session_name,
));

Expand Down Expand Up @@ -360,7 +364,7 @@ mod tests {
// Trace state
let mut trace_headers = HeaderMap::new();
trace_headers.insert(
"tracestate",
HeaderName::TraceState,
format!("some-other=xyz,linkup-session={}", name),
);
sessions
Expand All @@ -369,7 +373,7 @@ mod tests {
.unwrap();

let mut trace_headers_two = HeaderMap::new();
trace_headers_two.insert("tracestate", format!("linkup-session={}", name));
trace_headers_two.insert(HeaderName::TraceState, format!("linkup-session={}", name));
sessions
.get_request_session("example.com", &trace_headers_two)
.await
Expand All @@ -391,46 +395,52 @@ mod tests {
&target_service,
);

assert_eq!(add_headers.get("traceparent").unwrap().len(), 55);
assert_eq!(add_headers.get(HeaderName::TraceParent).unwrap().len(), 55);
assert_eq!(
add_headers.get("tracestate").unwrap(),
add_headers.get(HeaderName::TraceState).unwrap(),
"linkup-session=tiny-cow"
);
assert_eq!(add_headers.get("x-forwarded-host").unwrap(), "example.com");
assert_eq!(add_headers.get("linkup-destination").unwrap(), "frontend");
assert_eq!(
add_headers.get(HeaderName::ForwardedHost).unwrap(),
"example.com"
);
assert_eq!(
add_headers.get(HeaderName::LinkupDestination).unwrap(),
"frontend"
);

let mut already_headers = HeaderMap::new();
already_headers.insert("traceparent", "anything");
already_headers.insert("tracestate", "linkup-session=tiny-cow");
already_headers.insert("X-Forwarded-Host", "example.com");
already_headers.insert("linkup-destination", "frontend");
already_headers.insert(HeaderName::TraceParent, "anything");
already_headers.insert(HeaderName::TraceState, "linkup-session=tiny-cow");
already_headers.insert(HeaderName::ForwardedHost, "example.com");
already_headers.insert(HeaderName::LinkupDestination, "frontend");
let add_headers = get_additional_headers(
"https://abc.some-tunnel.com/abc-xyz",
&already_headers,
&session_name,
&target_service,
);

assert!(add_headers.get("traceparent").is_none());
assert!(add_headers.get("tracestate").is_none());
assert!(add_headers.get("X-Forwarded-Host").is_none());
assert!(add_headers.get("linkup-destination").is_none());
assert!(add_headers.get(HeaderName::TraceParent).is_none());
assert!(add_headers.get(HeaderName::TraceState).is_none());
assert!(add_headers.get(HeaderName::ForwardedHost).is_none());
assert!(add_headers.get(HeaderName::LinkupDestination).is_none());

let mut already_headers_two = HeaderMap::new();
already_headers_two.insert("traceparent", "anything");
already_headers_two.insert("tracestate", "other-service=32");
already_headers_two.insert("X-Forwarded-Host", "example.com");
already_headers_two.insert(HeaderName::TraceParent, "anything");
already_headers_two.insert(HeaderName::TraceState, "other-service=32");
already_headers_two.insert(HeaderName::ForwardedHost, "example.com");
let add_headers = get_additional_headers(
"https://abc.some-tunnel.com/abc-xyz",
&already_headers_two,
&session_name,
&target_service,
);

assert!(add_headers.get("traceparent").is_none());
assert!(add_headers.get("X-Forwarded-Host").is_none());
assert!(add_headers.get(HeaderName::TraceParent).is_none());
assert!(add_headers.get(HeaderName::ForwardedHost).is_none());
assert_eq!(
add_headers.get("tracestate").unwrap(),
add_headers.get(HeaderName::TraceState).unwrap(),
"other-service=32,linkup-session=tiny-cow"
);
}
Expand Down
7 changes: 4 additions & 3 deletions linkup/src/session_allocator.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::sync::Arc;

use crate::{
extract_tracestate_session, first_subdomain, random_animal, random_six_char, session_to_json,
ConfigError, HeaderMap, NameKind, Session, SessionError, StringStore,
extract_tracestate_session, first_subdomain, headers::HeaderName, random_animal,
random_six_char, session_to_json, ConfigError, HeaderMap, NameKind, Session, SessionError,
StringStore,
};

pub struct SessionAllocator {
Expand All @@ -24,7 +25,7 @@ impl SessionAllocator {
return Ok((url_name, config));
}

if let Some(forwarded_host) = headers.get("x-forwarded-host") {
if let Some(forwarded_host) = headers.get(HeaderName::ForwardedHost) {
let forwarded_host_name = first_subdomain(forwarded_host);
if let Some(config) = self
.get_session_config(forwarded_host_name.to_string())
Expand Down

0 comments on commit 3b20311

Please sign in to comment.