From 37fe93edad71b1d12252d41c836e2acbc3b1e63f Mon Sep 17 00:00:00 2001 From: Evan Chang Date: Sun, 28 Apr 2024 17:53:22 -0400 Subject: [PATCH] Whitelist reddit media links --- src/link_embed/reddit.rs | 86 ++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 47 deletions(-) diff --git a/src/link_embed/reddit.rs b/src/link_embed/reddit.rs index 689230b..72491e3 100644 --- a/src/link_embed/reddit.rs +++ b/src/link_embed/reddit.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use futures::{stream, StreamExt}; use once_cell::sync::Lazy; use regex::{Captures, Regex}; use reqwest::{header, Client}; @@ -36,7 +37,9 @@ enum RedditLink { impl RedditLink { fn url(&self) -> Box { match self { - RedditLink::Submission { id, .. } => format!("https://rxddit.com/{id}").into_boxed_str(), + RedditLink::Submission { id, .. } => { + format!("https://rxddit.com/{id}").into_boxed_str() + } RedditLink::Comment { subreddit, submission_id, @@ -65,25 +68,17 @@ pub async fn reddit_links(text: &str) -> Vec { /// Regular reddit urls async fn reddit_normal_links(text: &str) -> Vec { - let captures = REDDIT_RE.captures_iter(text); - let mut new_links = Vec::new(); - for capture in captures { - let start = capture.get(0).unwrap().start(); - let links = match_reddit_link(capture) - .await - .into_iter() - .map(|link| (start, link)); - new_links.extend(links); - } - let content = new_links - .iter() - .map(|(start, link)| ReplacedLink { - start: *start, - link: link.url(), - media: link.media(), + stream::iter(REDDIT_RE.captures_iter(text)) + .filter_map(|c| async { + let start = c.get(0).unwrap().start(); + match_reddit_link(c).await.map(|link| ReplacedLink { + start, + link: link.url(), + media: link.media(), + }) }) - .collect(); - content + .collect() + .await } /// Reddit app share urls @@ -114,26 +109,19 @@ async fn reddit_share_links(text: &str) -> Vec { links }; - let capture_links = links + let iter = links .iter() .filter_map(|(start, link)| REDDIT_RE.captures(link).map(|c| (start, c))); - let mut new_links = Vec::new(); - for (start, link) in capture_links { - let links = match_reddit_link(link) - .await - .into_iter() - .map(|link| (start, link)); - new_links.extend(links); - } - let content = new_links - .iter() - .map(|(start, link)| ReplacedLink { - start: **start, - link: link.url(), - media: link.media(), + stream::iter(iter) + .filter_map(|(start, c)| async { + match_reddit_link(c).await.map(|link| ReplacedLink { + start: *start, + link: link.url(), + media: link.media(), + }) }) - .collect(); - content + .collect() + .await } async fn match_reddit_link(m: Captures<'_>) -> Option { @@ -143,12 +131,10 @@ async fn match_reddit_link(m: Captures<'_>) -> Option { submission_id: submission_id.as_str().into(), comment_id: comment_id.as_str().into(), }), - (_, Some(submission_id), _) => { - Some(RedditLink::Submission { - id: submission_id.as_str().into(), - media: reddit_post_media(submission_id.as_str()).await, - }) - } + (_, Some(submission_id), _) => Some(RedditLink::Submission { + id: submission_id.as_str().into(), + media: reddit_post_media(submission_id.as_str()).await, + }), _ => None, } } @@ -186,10 +172,7 @@ async fn reddit_post_media(submission_id: &str) -> Option> { url: Option>, } - let url = format!( - "https://oauth.reddit.com/comments/{}/", - submission_id - ); + let url = format!("https://oauth.reddit.com/comments/{}/", submission_id); let response = CLIENT .get(url) .headers(headers) @@ -206,11 +189,20 @@ async fn reddit_post_media(submission_id: &str) -> Option> { let url = response .first() .and_then(|r| r.data.children.first()) - .and_then(|c| c.data.url.clone()); + .and_then(|c| c.data.url.clone()) + .and_then(|url| reqwest::Url::parse(&url).ok()) + .and_then(|url| media_url(&url)); url } +fn media_url(url: &reqwest::Url) -> Option> { + match url.domain() { + Some("streamable.com") => Some(url.as_str().into()), + _ => None, + } +} + // Use an access token to bypass rate limits #[derive(Default, Clone)] struct AccessToken {