Skip to content

Commit

Permalink
Add cache_path arg to WebScraperStage (#1358)
Browse files Browse the repository at this point in the history
Allows injecting an optional cache path into the WebScraperStage. If `cache_path` is not None, it is used as the path for the response caching system's SQLite database. If `cache_path` is None, the response caching system is bypassed.

Closes #1355

Authors:
  - Christopher Harris (https://github.com/cwharris)

Approvers:
  - Devin Robison (https://github.com/drobison00)

URL: #1358
  • Loading branch information
cwharris authored Nov 14, 2023
1 parent 5d1f815 commit 264d950
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
16 changes: 12 additions & 4 deletions examples/llm/common/web_scraper_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import mrc
import mrc.core.operators as ops
import pandas as pd
import requests
import requests_cache
from bs4 import BeautifulSoup
from langchain.text_splitter import RecursiveCharacterTextSplitter
Expand All @@ -45,9 +46,11 @@ class WebScraperStage(SinglePortStage):
Size in which to split the scraped content
link_column : str, default="link"
Column which contains the links to scrape
cache_path : str, default=None
    Path for the response caching system's SQLite database. If None, caching is bypassed.
"""

def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link"):
def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link", cache_path: str = None):
super().__init__(c)

self._link_column = link_column
Expand All @@ -61,8 +64,10 @@ def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link"):
chunk_overlap=self._chunk_size // 10,
length_function=len)

self._session = requests_cache.CachedSession(os.path.join("./.cache/http", "RSSDownloadStage.sqlite"),
backend="sqlite")
if cache_path is None:
self._session = requests.Session()
else:
self._session = requests_cache.CachedSession(cache_path, backend="sqlite")

self._session.headers.update({
"User-Agent":
Expand Down Expand Up @@ -152,7 +157,10 @@ def _download_and_split(self, msg: MessageMeta) -> MessageMeta:
row_cp.update({"page_content": text})
final_rows.append(row_cp)

logger.debug("Processed page: '%s'. Cache hit: %s", url, response.from_cache)
if isinstance(response, requests_cache.models.response.CachedResponse):
logger.debug("Processed page: '%s'. Cache hit: %s", url, response.from_cache)
else:
logger.debug("Processed page: '%s'", url)

except ValueError as exc:
logger.error("Error parsing document: %s", exc)
Expand Down
3 changes: 2 additions & 1 deletion examples/llm/vdb_upload/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def pipeline(num_threads: int,

pipe.add_stage(MonitorStage(config, description="Source rate", unit='pages'))

pipe.add_stage(WebScraperStage(config, chunk_size=model_fea_length))
pipe.add_stage(
WebScraperStage(config, chunk_size=model_fea_length, cache_path="./.cache/http/RSSDownloadStage.sqlite"))

pipe.add_stage(MonitorStage(config, description="Download rate", unit='pages'))

Expand Down

0 comments on commit 264d950

Please sign in to comment.