From 264d95071275b6aa039a4da62311787a14c5d90c Mon Sep 17 00:00:00 2001
From: Christopher Harris
Date: Tue, 14 Nov 2023 12:17:02 -0600
Subject: [PATCH] Add `cache_path` arg to `WebScraperStage` (#1358)

Adds an optional `cache_path` argument to `WebScraperStage`. If
`cache_path` is not None, it is used as the path to the response caching
system's SQLite database; if it is None, the response caching system is
bypassed entirely.

Closes #1355

Authors:
  - Christopher Harris (https://github.com/cwharris)

Approvers:
  - Devin Robison (https://github.com/drobison00)

URL: https://github.com/nv-morpheus/Morpheus/pull/1358
---
 examples/llm/common/web_scraper_stage.py | 16 ++++++++++++----
 examples/llm/vdb_upload/pipeline.py      |  3 ++-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/examples/llm/common/web_scraper_stage.py b/examples/llm/common/web_scraper_stage.py
index 95c199df92..fbf6d3e3a3 100644
--- a/examples/llm/common/web_scraper_stage.py
+++ b/examples/llm/common/web_scraper_stage.py
@@ -19,6 +19,7 @@
 import mrc
 import mrc.core.operators as ops
 import pandas as pd
+import requests
 import requests_cache
 from bs4 import BeautifulSoup
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -45,9 +46,11 @@ class WebScraperStage(SinglePortStage):
         Size in which to split the scraped content
     link_column : str, default="link"
         Column which contains the links to scrape
+    cache_path : str, default=None
+        The path for the response caching system's SQLite database; if None, caching is bypassed
     """

-    def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link"):
+    def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link", cache_path: str = None):
         super().__init__(c)

         self._link_column = link_column
@@ -61,8 +64,10 @@ def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link"):
                                                              chunk_overlap=self._chunk_size // 10,
                                                              length_function=len)

-        self._session = requests_cache.CachedSession(os.path.join("./.cache/http", "RSSDownloadStage.sqlite"),
-                                                     backend="sqlite")
+        if cache_path is None:
+            self._session = requests.Session()
+        else:
+            self._session = requests_cache.CachedSession(cache_path, backend="sqlite")

         self._session.headers.update({
             "User-Agent":
@@ -152,7 +157,10 @@ def _download_and_split(self, msg: MessageMeta) -> MessageMeta:
                         row_cp.update({"page_content": text})
                         final_rows.append(row_cp)

-                        logger.debug("Processed page: '%s'. Cache hit: %s", url, response.from_cache)
+                        if isinstance(response, requests_cache.models.response.CachedResponse):
+                            logger.debug("Processed page: '%s'. Cache hit: %s", url, response.from_cache)
+                        else:
+                            logger.debug("Processed page: '%s'", url)

                     except ValueError as exc:
                         logger.error("Error parsing document: %s", exc)
diff --git a/examples/llm/vdb_upload/pipeline.py b/examples/llm/vdb_upload/pipeline.py
index fc888697d4..05b087e88b 100644
--- a/examples/llm/vdb_upload/pipeline.py
+++ b/examples/llm/vdb_upload/pipeline.py
@@ -75,7 +75,8 @@ def pipeline(num_threads: int,

     pipe.add_stage(MonitorStage(config, description="Source rate", unit='pages'))

-    pipe.add_stage(WebScraperStage(config, chunk_size=model_fea_length))
+    pipe.add_stage(
+        WebScraperStage(config, chunk_size=model_fea_length, cache_path="./.cache/http/RSSDownloadStage.sqlite"))

     pipe.add_stage(MonitorStage(config, description="Download rate", unit='pages'))
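
Usage sketch: the snippet below shows how the new `cache_path` argument
is expected to be used, mirroring the change to
examples/llm/vdb_upload/pipeline.py above. The import path and the
chunk_size value are illustrative assumptions, not part of this patch;
in the repo the stage is imported relative to examples/llm/common/.

    from morpheus.config import Config

    # Illustrative import; the stage lives in
    # examples/llm/common/web_scraper_stage.py, not an installed package.
    from web_scraper_stage import WebScraperStage

    config = Config()

    # With a cache path: responses are persisted to the SQLite database at
    # the given path, so repeated runs re-use previously fetched pages.
    cached_stage = WebScraperStage(config,
                                   chunk_size=512,
                                   cache_path="./.cache/http/RSSDownloadStage.sqlite")

    # Without a cache path (the default): a plain requests.Session is used
    # and every page is fetched over the network on each run.
    uncached_stage = WebScraperStage(config, chunk_size=512)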