Skip to content

Commit

Permalink
Use reddit auth for share links
Browse files Browse the repository at this point in the history
  • Loading branch information
evanc577 committed Jan 19, 2024
1 parent d86bc5b commit 7cf4762
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 10 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,9 @@ features = ["macros", "rt-multi-thread", "signal", "time"]
version = "0.11"
features = ["macros", "sqlx-postgres", "runtime-tokio-rustls"]

[dependencies.time]
version = "0.3"
features = ["macros", "parsing", "formatting"]

[dev-dependencies]
approx = "0.5"
121 changes: 111 additions & 10 deletions src/link_embed/reddit.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use std::sync::Arc;

use once_cell::sync::Lazy;
use regex::{Captures, Regex};
use reqwest::header;
use reqwest::{header, Client};
use serde::{Deserialize, Deserializer, Serialize};
use time::{Duration, OffsetDateTime};
use tokio::sync::Mutex;

use super::ReplacedLink;
use crate::CLIENT;
Expand Down Expand Up @@ -64,22 +69,29 @@ pub fn reddit_normal_links(text: &str) -> Vec<ReplacedLink> {

/// Reddit app share urls
async fn reddit_share_links(text: &str) -> Vec<ReplacedLink> {
static REDDIT_ACCESS_TOKEN: Lazy<AccessToken> = Lazy::new(AccessToken::default);
static REDDIT_USER_AGENT: &str = "Reddit";

let share_links = REDDIT_SHARE_RE
.find_iter(text)
.map(|m| (m.start(), m.as_str()));

let links = {
let mut links: Vec<(usize, Box<str>)> = Vec::new();
for (start, share_link) in share_links {
if let Ok(response) = CLIENT
.head(share_link)
.header(
header::USER_AGENT,
header::HeaderValue::from_static("insomnia"),
)
.send()
.await
{
let mut headers = header::HeaderMap::new();
headers.insert(
header::USER_AGENT,
header::HeaderValue::from_static(REDDIT_USER_AGENT),
);
if let Some(auth) = REDDIT_ACCESS_TOKEN.authentication(&CLIENT).await {
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&auth).unwrap(),
);
}

if let Ok(response) = CLIENT.head(share_link).headers(headers).send().await {
links.push((start, response.url().as_str().into()));
}
}
Expand Down Expand Up @@ -114,3 +126,92 @@ fn match_reddit_link(m: Captures<'_>) -> Option<RedditLink> {
_ => None,
}
}

// Use an access token to bypass rate limits
#[derive(Default, Clone)]
struct AccessToken {
token: Arc<Mutex<Option<AccessTokenInternal>>>,
}

impl AccessToken {
/// Return stored authorization, refresh it if needed
async fn authentication(&self, client: &Client) -> Option<String> {
let mut access_token_guard = self.token.lock().await;
if access_token_guard.is_none() {
// Get a new token if none exists
let new_token = AccessTokenInternal::access_token(client).await?;
*access_token_guard = Some(new_token);
} else if let Some(token) = &*access_token_guard {
// Get a new token if current one is expiring soon
let expiry = token.expiry;
let buffer_time = Duration::new(4 * 60 * 60, 0); // 4 hours
if expiry - OffsetDateTime::now_utc() < buffer_time {
let new_token = AccessTokenInternal::access_token(client).await?;
*access_token_guard = Some(new_token);
}
}
Some(format!(
"Bearer {}",
access_token_guard.as_ref().unwrap().access_token.clone()
))
}
}

#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
struct AccessTokenInternal {
access_token: String,
#[serde(rename = "expiry_ts")]
#[serde(deserialize_with = "deserialize_timestamp")]
expiry: OffsetDateTime,
#[serde(deserialize_with = "deserialize_duration")]
expires_in: Duration,
scope: Vec<String>,
token_type: String,
}

fn deserialize_timestamp<'de, D>(deserializer: D) -> Result<OffsetDateTime, D::Error>
where
D: Deserializer<'de>,
{
let ts = i64::deserialize(deserializer)?;
let dt = OffsetDateTime::from_unix_timestamp(ts).map_err(serde::de::Error::custom)?;
Ok(dt)
}

fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let s = i64::deserialize(deserializer)?;
let d = Duration::new(s, 0);
Ok(d)
}

#[derive(Serialize)]
struct Body {
scopes: Vec<String>,
}

impl AccessTokenInternal {
async fn access_token(client: &Client) -> Option<Self> {
static ENDPOINT: &str = "https://accounts.reddit.com/api/access_token";
static AUTHORIZATION: &str = "basic b2hYcG9xclpZdWIxa2c6";
let body = Body {
scopes: vec!["*".into(), "email".into(), "pii".into()],
};
let response = client
.post(ENDPOINT)
.header(header::AUTHORIZATION, AUTHORIZATION)
.json(&body)
.send()
.await
.ok()?;
let status = response.status();
if status.is_success() {
Some(response.json().await.ok()?)
} else {
None
}
}
}

0 comments on commit 7cf4762

Please sign in to comment.