From 5b4b94809bdd13e761d51984773840bc8a3b64b7 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Tue, 11 Apr 2023 10:27:52 -0500 Subject: [PATCH] improve discourse loader --- src/marvin/loaders/discourse.py | 86 ++++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/src/marvin/loaders/discourse.py b/src/marvin/loaders/discourse.py index 4cda178a2..481dedb37 100644 --- a/src/marvin/loaders/discourse.py +++ b/src/marvin/loaders/discourse.py @@ -1,3 +1,4 @@ +import math from typing import Callable, Dict import httpx @@ -7,12 +8,6 @@ import marvin from marvin.loaders.base import Loader from marvin.models.documents import Document -from marvin.models.metadata import Metadata - - -def should_include_post(post: dict) -> bool: - """Return whether the post should be included in the results.""" - return post["accepted_answer"] class DiscoursePost(BaseModel): @@ -20,12 +15,10 @@ class DiscoursePost(BaseModel): base_url: str id: int - category_id: int + topic_id: int cooked: str created_at: pendulum.DateTime - topic_id: int topic_slug: str - topic_title: str @property def url(self) -> str: @@ -34,14 +27,16 @@ def url(self) -> str: class DiscourseLoader(Loader): - """Loader for Discourse posts.""" + """Loader for Discourse topics.""" source_type: str = Field(default="discourse") url: str = Field(default="https://discourse.prefect.io") - n_posts: int = Field(default=50) + n_topic: int = Field(default=30) + per_page: int = Field(default=30) request_headers: Dict[str, str] = Field(default_factory=dict) - include_post_filter: Callable[[dict], bool] = Field(default=should_include_post) + include_topic_filter: Callable[[dict], bool] = Field(default=lambda _: True) + include_post_filter: Callable[[dict], bool] = Field(default=lambda _: True) @validator("request_headers", always=True) def auth_headers(cls, v): @@ -63,29 +58,72 @@ def auth_headers(cls, v): async def load(self) -> list[Document]: """Load Discourse posts.""" documents = [] - for post in await self._get_posts(): + for post in await self._get_all_posts(): documents.extend( await Document( text=post.cooked, - metadata=Metadata( - source=self.source_type, - title=post.topic_title, - link=post.url, - created_at=post.created_at.timestamp(), - ), + metadata={ + "source": self.source_type, + "title": post.topic_slug.replace("-", " ").capitalize(), + "link": post.url, + "created_at": post.created_at.timestamp(), + }, ).to_excerpts() ) return documents - async def _get_posts(self) -> list[DiscoursePost]: - """Get posts from a Discourse forum.""" + async def _get_posts_for_topic(self, topic_id: int) -> list[dict]: + """Get posts for a specific topic.""" async with httpx.AsyncClient() as client: response = await client.get( - f"{self.url}/posts.json", headers=self.request_headers + f"{self.url}/t/{topic_id}.json", headers=self.request_headers ) response.raise_for_status() return [ - DiscoursePost(base_url=self.url, **post) - for post in response.json()["latest_posts"] + post + for post in response.json()["post_stream"]["posts"] if self.include_post_filter(post) ] + + async def _get_all_posts(self) -> list[DiscoursePost]: + """Get topics and posts from a Discourse forum filtered by a specific tag.""" + all_topics = [] + pages = math.ceil(self.n_topic / self.per_page) + + async with httpx.AsyncClient() as client: + for page in range(pages): + response = await client.get( + f"{self.url}/latest.json", + headers=self.request_headers, + params={"page": page, "per_page": self.per_page}, + ) + response.raise_for_status() + + topics = response.json()["topic_list"]["topics"] + all_topics.extend(topics) + + # Break the loop if we have fetched the desired number of topics + if len(all_topics) >= self.n_topic: + break + + filtered_topics = [ + topic for topic in all_topics if self.include_topic_filter(topic) + ] + + all_posts = [] + for topic in filtered_topics: + self.logger.info( + f"Fetching posts for retrieved topic {topic['title']!r}" + ) + posts = await self._get_posts_for_topic(topic["id"]) + all_posts.append( + DiscoursePost(base_url=self.url, **posts[0]) + ) # original post + all_posts.extend( + [ + DiscoursePost(base_url=self.url, **post) + for post in posts[1:] + if self.include_post_filter(post) + ] + ) + return all_posts