Start adding unit tests
drobison00 committed Jan 12, 2024
1 parent 8808b4c commit c2b1706
Showing 8 changed files with 242 additions and 160 deletions.
125 changes: 66 additions & 59 deletions examples/llm/common/web_scraper_module.py
@@ -22,6 +22,7 @@
logging

import cudf
from functools import partial

import logging

@@ -46,6 +47,69 @@ class WebScraperParamContract(BaseModel):
cache_dir: str = "./.cache/llm/rss"


def download_and_split(msg: MessageMeta, text_splitter, link_column, session) -> MessageMeta:
"""
Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
it in the output, excluding output for any links that produce an error.
"""
if (link_column not in msg.get_column_names()):
return None

df = msg.df

if isinstance(df, cudf.DataFrame):
df: pd.DataFrame = df.to_pandas()

# Convert the dataframe into a list of dictionaries
df_dicts = df.to_dict(orient="records")

final_rows: list[dict] = []

for row in df_dicts:
url = row[link_column]

try:
# Try to get the page content
response = session.get(url)
logger.debug("Response text: %s", response.text)

if (not response.ok):
logger.warning(
"Error downloading document from URL '%s'. " +
"Returned code: %s. With reason: '%s'",
url,
response.status_code,
response.reason)
continue

raw_html = response.text
soup = BeautifulSoup(raw_html, "html.parser")

text = soup.get_text(strip=True, separator=' ')
split_text = text_splitter.split_text(text)

for text in split_text:
row_cp = row.copy()
row_cp.update({"page_content": text})
final_rows.append(row_cp)

logger.debug("Accumulated rows: %s", final_rows)

if isinstance(response, requests_cache.models.response.CachedResponse):
logger.debug("Processed cached page: '%s'", url)
else:
logger.debug("Processed page: '%s'", url)

except ValueError as exc:
logger.error("Error parsing document: %s", exc)
continue
except Exception as exc:
logger.error("Error downloading document from URL '%s'. Error: %s", url, exc)
continue

return MessageMeta(df=pd.DataFrame(final_rows))


@register_module("web_scraper", "morpheus_examples_llm")
def web_scraper(builder: mrc.Builder):
module_config = builder.get_current_module_config()
@@ -82,66 +146,9 @@ def web_scraper(builder: mrc.Builder):
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36"
})

def download_and_split(msg: MessageMeta) -> MessageMeta:
"""
Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
it in the output, excludes output for any links which produce an error.
"""
if link_column not in msg.get_column_names():
return None

df = msg.df

if isinstance(df, cudf.DataFrame):
df: pd.DataFrame = df.to_pandas()

# Convert the dataframe into a list of dictionaries
df_dicts = df.to_dict(orient="records")

final_rows: list[dict] = []

for row in df_dicts:
url = row[link_column]

try:
# Try to get the page content
response = session.get(url)

if (not response.ok):
logger.warning(
"Error downloading document from URL '%s'. " +
"Returned code: %s. With reason: '%s'",
url,
response.status_code,
response.reason)
continue

raw_html = response.text
soup = BeautifulSoup(raw_html, "html.parser")

text = soup.get_text(strip=True, separator=' ')
split_text = text_splitter.split_text(text)

for text in split_text:
row_cp = row.copy()
row_cp.update({"page_content": text})
final_rows.append(row_cp)

if isinstance(response, requests_cache.models.response.CachedResponse):
logger.debug("Processed cached page: '%s'", url)
else:
logger.debug("Processed page: '%s'", url)

except ValueError as exc:
logger.error("Error parsing document: %s", exc)
continue
except Exception as exc:
logger.error("Error downloading document from URL '%s'. Error: %s", url, exc)
continue

return MessageMeta(df=pd.DataFrame(final_rows))
op_func = partial(download_and_split, text_splitter=text_splitter, link_column=link_column, session=session)

node = builder.make_node("web_scraper", ops.map(download_and_split), ops.filter(lambda x: x is not None))
node = builder.make_node("web_scraper", ops.map(op_func), ops.filter(lambda x: x is not None))

builder.register_module_input("input", node)
builder.register_module_output("output", node)
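
In keeping with the commit message ("Start adding unit tests"), moving download_and_split to module scope and binding its collaborators with functools.partial lets the scraping logic be unit tested without building a pipeline. The sketch below is not part of this commit; it assumes examples/llm/common is on sys.path (as the conftest fixtures further down arrange) and substitutes mock objects for the HTTP session and text splitter.

from unittest import mock

import pandas as pd

from morpheus.messages import MessageMeta
from web_scraper_module import download_and_split


class _StubSplitter:
    """Minimal stand-in for RecursiveCharacterTextSplitter."""

    def split_text(self, text: str) -> list[str]:
        # Return two fixed chunks so the per-chunk row fan-out is observable.
        midpoint = len(text) // 2
        return [text[:midpoint], text[midpoint:]]


def test_download_and_split_emits_page_content_rows():
    # Fake HTTP session: every GET "succeeds" with a small HTML body.
    response = mock.Mock(ok=True, text="<html><body>hello unit tests</body></html>")
    session = mock.Mock()
    session.get.return_value = response

    msg = MessageMeta(df=pd.DataFrame({"link": ["https://example.com"]}))

    result = download_and_split(msg, text_splitter=_StubSplitter(), link_column="link", session=session)

    out_df = result.df
    # One input row fans out into one output row per text chunk.
    assert len(out_df) == 2
    assert "page_content" in out_df.columns
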
115 changes: 18 additions & 97 deletions examples/llm/common/web_scraper_stage.py
@@ -13,27 +13,19 @@
# limitations under the License.

import logging
import os
import typing

import cudf
import mrc
import mrc.core.operators as ops
import pandas as pd
import requests
import requests_cache
from bs4 import BeautifulSoup
from langchain.text_splitter import RecursiveCharacterTextSplitter

from morpheus.config import Config
from morpheus.messages import MessageMeta
from morpheus.pipeline.single_port_stage import SinglePortStage
from morpheus.pipeline.stage_schema import StageSchema
from web_scraper_module import WebScraperInterface

logger = logging.getLogger(f"morpheus.{__name__}")


# TODO(Devin) Convert to use module
class WebScraperStage(SinglePortStage):
"""
Stage for scraping web based content using the HTTP GET protocol.
@@ -61,27 +53,20 @@ def __init__(self,
cache_path: str = "./.cache/http/RSSDownloadStage.sqlite"):
super().__init__(c)

self._link_column = link_column
self._chunk_size = chunk_size
self._cache_dir = "./.cache/llm/rss/"
self._module_config = {
"web_scraper_config": {
"link_column": link_column,
"chunk_size": chunk_size,
"enable_cache": enable_cache,
"cache_path": cache_path,
"cache_dir": "./.cache/llm/rss",
}
}

# Ensure the directory exists
if (enable_cache):
os.makedirs(self._cache_dir, exist_ok=True)
self._input_port_name = "input"
self._output_port_name = "output"

self._text_splitter = RecursiveCharacterTextSplitter(chunk_size=self._chunk_size,
chunk_overlap=self._chunk_size // 10,
length_function=len)

if enable_cache:
self._session = requests_cache.CachedSession(cache_path, backend="sqlite")
else:
self._session = requests.Session()

self._session.headers.update({
"User-Agent":
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36"
})
self._module_definition = WebScraperInterface.get_definition("web_scraper", self._module_config)

@property
def name(self) -> str:
@@ -108,75 +93,11 @@ def compute_schema(self, schema: StageSchema):
schema.output_schema.set_type(MessageMeta)

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
module = self._module_definition.load(builder=builder)

node = builder.make_node(self.unique_name,
ops.map(self._download_and_split),
ops.filter(lambda x: x is not None))

node.launch_options.pe_count = self._config.num_threads

builder.make_edge(input_node, node)

return node

def _download_and_split(self, msg: MessageMeta) -> MessageMeta:
"""
Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
it in the output, excludes output for any links which produce an error.
"""
if self._link_column not in msg.get_column_names():
return None

df = msg.df

if isinstance(df, cudf.DataFrame):
df: pd.DataFrame = df.to_pandas()

# Convert the dataframe into a list of dictionaries
df_dicts = df.to_dict(orient="records")

final_rows: list[dict] = []

for row in df_dicts:

url = row[self._link_column]

try:
# Try to get the page content
response = self._session.get(url)

if (not response.ok):
logger.warning(
"Error downloading document from URL '%s'. " + "Returned code: %s. With reason: '%s'",
url,
response.status_code,
response.reason)
continue

raw_html = response.text

soup = BeautifulSoup(raw_html, "html.parser")

text = soup.get_text(strip=True, separator=' ')

split_text = self._text_splitter.split_text(text)

for text in split_text:
row_cp = row.copy()
row_cp.update({"page_content": text})
final_rows.append(row_cp)

if isinstance(response, requests_cache.models.response.CachedResponse):
logger.debug("Processed page: '%s'. Cache hit: %s", url, response.from_cache)
else:
logger.debug("Processed page: '%s'", url)
mod_in_node = module.input_port(self._input_port_name)
mod_out_node = module.output_port(self._output_port_name)

except ValueError as exc:
logger.error("Error parsing document: %s", exc)
continue
except Exception as exc:
logger.error("Error downloading document from URL '%s'. Error: %s", url, exc)
continue
builder.make_edge(input_node, mod_in_node)

# Not using cudf to avoid error: pyarrow.lib.ArrowInvalid: cannot mix list and non-list, non-null values
return MessageMeta(pd.DataFrame(final_rows))
return mod_out_node
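
With the scraping logic delegated to the "web_scraper" module, WebScraperStage now only builds a module config, loads the definition, and exposes its ports. A hedged usage sketch follows; the in-memory source/sink stages, the "link" column, and the chunk size are illustrative choices, and actually running this would issue real HTTP requests.

import cudf

from morpheus.config import Config
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from web_scraper_stage import WebScraperStage

config = Config()
links = cudf.DataFrame({"link": ["https://example.com"]})

pipeline = LinearPipeline(config)
pipeline.set_source(InMemorySourceStage(config, dataframes=[links]))
pipeline.add_stage(WebScraperStage(config, link_column="link", chunk_size=512, enable_cache=False))
sink = pipeline.add_stage(InMemorySinkStage(config))
pipeline.run()

# Each scraped link yields one output row per text chunk, carried in "page_content".
for meta in sink.get_messages():
    print(meta.df["page_content"])
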
3 changes: 1 addition & 2 deletions examples/llm/vdb_upload/module/file_source_pipe.py
@@ -158,5 +158,4 @@ def _file_source_pipe(builder: mrc.Builder):


FileSourcePipe = ModuleInterface("file_source_pipe", "morpheus_examples_llm",
FileSourceParamContract)
FileSourcePipe.print_schema()
FileSourceParamContract)
7 changes: 6 additions & 1 deletion morpheus/stages/general/linear_modules_stage.py
@@ -63,9 +63,14 @@ def __init__(self,
self._input_port_name = input_port_name
self._output_port_name = output_port_name

if (isinstance(self._module_config, dict)):
self._unique_name = self._module_config.get("unique_name", "linear_module")
else:
self._unique_name = self._module_config.name

@property
def name(self) -> str:
return self._module_config.get("module_name", "linear_module")
return self._unique_name

def supports_cpp_node(self):
return False
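
The name property change above lets LinearModulesStage accept either a plain dict config or a module definition object such as the one WebScraperInterface.get_definition() returns. The sketch below is illustrative rather than the commit's own test code, and it assumes the definition object reports the name it was created with; the dict keys, port names, and config values are examples.

from morpheus.config import Config
from morpheus.stages.general.linear_modules_stage import LinearModulesStage
from web_scraper_module import WebScraperInterface

config = Config()

# Dict config: the stage name comes from "unique_name", defaulting to "linear_module".
dict_stage = LinearModulesStage(config,
                                {"module_id": "web_scraper",
                                 "namespace": "morpheus_examples_llm",
                                 "module_name": "web_scraper",
                                 "unique_name": "web_scraper_dict"},
                                input_port_name="input",
                                output_port_name="output")
assert dict_stage.name == "web_scraper_dict"

# Module definition config: the stage name comes from the definition's .name attribute.
definition = WebScraperInterface.get_definition("web_scraper_def",
                                                {"web_scraper_config": {"link_column": "link",
                                                                        "chunk_size": 512,
                                                                        "enable_cache": False}})
def_stage = LinearModulesStage(config,
                               definition,
                               input_port_name="input",
                               output_port_name="output")
assert def_stage.name == "web_scraper_def"
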
24 changes: 23 additions & 1 deletion tests/examples/llm/common/conftest.py
@@ -13,11 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import os
import sys

import pytest
from _utils import TEST_DIRS
from _utils import import_or_skip


@pytest.fixture(scope="module")
def import_utils():
utils_path = os.path.join(TEST_DIRS.examples_dir, 'llm/common/')
sys.path.insert(0, utils_path)

import utils

return utils


@pytest.fixture(scope="module")
def import_web_scraper_module():
web_scraper_path = os.path.join(TEST_DIRS.examples_dir, 'llm/common/')
sys.path.insert(0, web_scraper_path)

import web_scraper_module

return web_scraper_module


@pytest.fixture(name="nemollm", autouse=True, scope='session')
def nemollm_fixture(fail_missing: bool):
"""
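
A hypothetical companion test for the import_web_scraper_module fixture added above; the attribute checked below is only a smoke test, chosen because web_scraper_stage.py imports WebScraperInterface from this module.

def test_web_scraper_module_exposes_interface(import_web_scraper_module):
    web_scraper_module = import_web_scraper_module
    # The refactored stage relies on this interface, so it should be importable from the example module.
    assert hasattr(web_scraper_module, "WebScraperInterface")
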
(3 remaining changed files not shown)
