From 7cf47628c4e6d3ff24f94323ceab479c27d1aec1 Mon Sep 17 00:00:00 2001 From: Evan Chang Date: Fri, 19 Jan 2024 02:10:42 -0500 Subject: [PATCH] Use reddit auth for share links --- Cargo.lock | 1 + Cargo.toml | 4 ++ src/link_embed/reddit.rs | 121 +++++++++++++++++++++++++++++++++++---- 3 files changed, 116 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a7e3641..1ebdf58 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1113,6 +1113,7 @@ dependencies = [ "serde", "serde_json", "songbird", + "time", "tokio", "toml 0.8.6", "unicode-segmentation", diff --git a/Cargo.toml b/Cargo.toml index 2f5e154..76d8a08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,5 +47,9 @@ features = ["macros", "rt-multi-thread", "signal", "time"] version = "0.11" features = ["macros", "sqlx-postgres", "runtime-tokio-rustls"] +[dependencies.time] +version = "0.3" +features = ["macros", "parsing", "formatting"] + [dev-dependencies] approx = "0.5" diff --git a/src/link_embed/reddit.rs b/src/link_embed/reddit.rs index 7e0620f..f918f79 100644 --- a/src/link_embed/reddit.rs +++ b/src/link_embed/reddit.rs @@ -1,6 +1,11 @@ +use std::sync::Arc; + use once_cell::sync::Lazy; use regex::{Captures, Regex}; -use reqwest::header; +use reqwest::{header, Client}; +use serde::{Deserialize, Deserializer, Serialize}; +use time::{Duration, OffsetDateTime}; +use tokio::sync::Mutex; use super::ReplacedLink; use crate::CLIENT; @@ -64,6 +69,9 @@ pub fn reddit_normal_links(text: &str) -> Vec { /// Reddit app share urls async fn reddit_share_links(text: &str) -> Vec { + static REDDIT_ACCESS_TOKEN: Lazy = Lazy::new(AccessToken::default); + static REDDIT_USER_AGENT: &str = "Reddit"; + let share_links = REDDIT_SHARE_RE .find_iter(text) .map(|m| (m.start(), m.as_str())); @@ -71,15 +79,19 @@ async fn reddit_share_links(text: &str) -> Vec { let links = { let mut links: Vec<(usize, Box)> = Vec::new(); for (start, share_link) in share_links { - if let Ok(response) = CLIENT - .head(share_link) - .header( - header::USER_AGENT, - header::HeaderValue::from_static("insomnia"), - ) - .send() - .await - { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::USER_AGENT, + header::HeaderValue::from_static(REDDIT_USER_AGENT), + ); + if let Some(auth) = REDDIT_ACCESS_TOKEN.authentication(&CLIENT).await { + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(&auth).unwrap(), + ); + } + + if let Ok(response) = CLIENT.head(share_link).headers(headers).send().await { links.push((start, response.url().as_str().into())); } } @@ -114,3 +126,92 @@ fn match_reddit_link(m: Captures<'_>) -> Option { _ => None, } } + +// Use an access token to bypass rate limits +#[derive(Default, Clone)] +struct AccessToken { + token: Arc>>, +} + +impl AccessToken { + /// Return stored authorization, refresh it if needed + async fn authentication(&self, client: &Client) -> Option { + let mut access_token_guard = self.token.lock().await; + if access_token_guard.is_none() { + // Get a new token if none exists + let new_token = AccessTokenInternal::access_token(client).await?; + *access_token_guard = Some(new_token); + } else if let Some(token) = &*access_token_guard { + // Get a new token if current one is expiring soon + let expiry = token.expiry; + let buffer_time = Duration::new(4 * 60 * 60, 0); // 4 hours + if expiry - OffsetDateTime::now_utc() < buffer_time { + let new_token = AccessTokenInternal::access_token(client).await?; + *access_token_guard = Some(new_token); + } + } + Some(format!( + "Bearer {}", + access_token_guard.as_ref().unwrap().access_token.clone() + )) + } +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Deserialize)] +struct AccessTokenInternal { + access_token: String, + #[serde(rename = "expiry_ts")] + #[serde(deserialize_with = "deserialize_timestamp")] + expiry: OffsetDateTime, + #[serde(deserialize_with = "deserialize_duration")] + expires_in: Duration, + scope: Vec, + token_type: String, +} + +fn deserialize_timestamp<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let ts = i64::deserialize(deserializer)?; + let dt = OffsetDateTime::from_unix_timestamp(ts).map_err(serde::de::Error::custom)?; + Ok(dt) +} + +fn deserialize_duration<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let s = i64::deserialize(deserializer)?; + let d = Duration::new(s, 0); + Ok(d) +} + +#[derive(Serialize)] +struct Body { + scopes: Vec, +} + +impl AccessTokenInternal { + async fn access_token(client: &Client) -> Option { + static ENDPOINT: &str = "https://accounts.reddit.com/api/access_token"; + static AUTHORIZATION: &str = "basic b2hYcG9xclpZdWIxa2c6"; + let body = Body { + scopes: vec!["*".into(), "email".into(), "pii".into()], + }; + let response = client + .post(ENDPOINT) + .header(header::AUTHORIZATION, AUTHORIZATION) + .json(&body) + .send() + .await + .ok()?; + let status = response.status(); + if status.is_success() { + Some(response.json().await.ok()?) + } else { + None + } + } +}