From c2b17065a8044b916b908d00f93cf5a3f91be949 Mon Sep 17 00:00:00 2001
From: Devin Robison
Date: Fri, 12 Jan 2024 16:30:01 -0700
Subject: [PATCH] Start adding unit tests

---
 examples/llm/common/web_scraper_module.py     | 125 +++++++++---------
 examples/llm/common/web_scraper_stage.py      | 115 +++-------------
 .../llm/vdb_upload/module/file_source_pipe.py |   3 +-
 .../stages/general/linear_modules_stage.py    |   7 +-
 tests/examples/llm/common/conftest.py         |  24 +++-
 tests/examples/llm/common/test_utils.py       |  61 +++++++++
 .../llm/common/test_web_scraper_module.py     |  47 +++++++
 .../mocks/RSS/single_entry/GET.mock           |  20 +++
 8 files changed, 242 insertions(+), 160 deletions(-)
 create mode 100644 tests/examples/llm/common/test_utils.py
 create mode 100644 tests/examples/llm/common/test_web_scraper_module.py
 create mode 100644 tests/mock_rest_server/mocks/RSS/single_entry/GET.mock

diff --git a/examples/llm/common/web_scraper_module.py b/examples/llm/common/web_scraper_module.py
index eada36932e..ad4d7509ce 100644
--- a/examples/llm/common/web_scraper_module.py
+++ b/examples/llm/common/web_scraper_module.py
@@ -22,6 +22,7 @@ logging
 import cudf
+from functools import partial
 import logging
@@ -46,6 +47,69 @@ class WebScraperParamContract(BaseModel):
     cache_dir: str = "./.cache/llm/rss"
 
 
+def download_and_split(msg: MessageMeta, text_splitter, link_column, session) -> MessageMeta:
+    """
+    Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
+    it in the output, excluding output for any links that produce an error.
+    """
+    if (link_column not in msg.get_column_names()):
+        return None
+
+    df = msg.df
+
+    if isinstance(df, cudf.DataFrame):
+        df: pd.DataFrame = df.to_pandas()
+
+    # Convert the dataframe into a list of dictionaries
+    df_dicts = df.to_dict(orient="records")
+
+    final_rows: list[dict] = []
+
+    for row in df_dicts:
+        url = row[link_column]
+
+        try:
+            # Try to get the page content
+            response = session.get(url)
+            logger.info(f"RESPONSE TEXT: {response.text}")
+
+            if (not response.ok):
+                logger.warning(
+                    "Error downloading document from URL '%s'. " +
+                    "Returned code: %s. With reason: '%s'",
+                    url,
+                    response.status_code,
+                    response.reason)
+                continue
+
+            raw_html = response.text
+            soup = BeautifulSoup(raw_html, "html.parser")
+
+            text = soup.get_text(strip=True, separator=' ')
+            split_text = text_splitter.split_text(text)
+
+            for text in split_text:
+                row_cp = row.copy()
+                row_cp.update({"page_content": text})
+                final_rows.append(row_cp)
+
+            logger.info(final_rows)
+
+            if isinstance(response, requests_cache.models.response.CachedResponse):
+                logger.debug("Processed cached page: '%s'", url)
+            else:
+                logger.debug("Processed page: '%s'", url)
+
+        except ValueError as exc:
+            logger.error("Error parsing document: %s", exc)
+            continue
+        except Exception as exc:
+            logger.error("Error downloading document from URL '%s'. Error: %s", url, exc)
+            continue
+
+    return MessageMeta(df=pd.DataFrame(final_rows))
+
+
 @register_module("web_scraper", "morpheus_examples_llm")
 def web_scraper(builder: mrc.Builder):
     module_config = builder.get_current_module_config()
@@ -82,66 +146,9 @@ def web_scraper(builder: mrc.Builder):
         "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36"
     })
 
-    def download_and_split(msg: MessageMeta) -> MessageMeta:
-        """
-        Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
-        it in the output, excludes output for any links which produce an error.
-        """
-        if link_column not in msg.get_column_names():
-            return None
-
-        df = msg.df
-
-        if isinstance(df, cudf.DataFrame):
-            df: pd.DataFrame = df.to_pandas()
-
-        # Convert the dataframe into a list of dictionaries
-        df_dicts = df.to_dict(orient="records")
-
-        final_rows: list[dict] = []
-
-        for row in df_dicts:
-            url = row[link_column]
-
-            try:
-                # Try to get the page content
-                response = session.get(url)
-
-                if (not response.ok):
-                    logger.warning(
-                        "Error downloading document from URL '%s'. " +
-                        "Returned code: %s. With reason: '%s'",
-                        url,
-                        response.status_code,
-                        response.reason)
-                    continue
-
-                raw_html = response.text
-                soup = BeautifulSoup(raw_html, "html.parser")
-
-                text = soup.get_text(strip=True, separator=' ')
-                split_text = text_splitter.split_text(text)
-
-                for text in split_text:
-                    row_cp = row.copy()
-                    row_cp.update({"page_content": text})
-                    final_rows.append(row_cp)
-
-                if isinstance(response, requests_cache.models.response.CachedResponse):
-                    logger.debug("Processed cached page: '%s'", url)
-                else:
-                    logger.debug("Processed page: '%s'", url)
-
-            except ValueError as exc:
-                logger.error("Error parsing document: %s", exc)
-                continue
-            except Exception as exc:
-                logger.error("Error downloading document from URL '%s'. Error: %s", url, exc)
-                continue
-
-        return MessageMeta(df=pd.DataFrame(final_rows))
+    op_func = partial(download_and_split, text_splitter=text_splitter, link_column=link_column, session=session)
 
-    node = builder.make_node("web_scraper", ops.map(download_and_split), ops.filter(lambda x: x is not None))
+    node = builder.make_node("web_scraper", ops.map(op_func), ops.filter(lambda x: x is not None))
 
     builder.register_module_input("input", node)
     builder.register_module_output("output", node)
diff --git a/examples/llm/common/web_scraper_stage.py b/examples/llm/common/web_scraper_stage.py
index 8925c77fec..cec06544b8 100644
--- a/examples/llm/common/web_scraper_stage.py
+++ b/examples/llm/common/web_scraper_stage.py
@@ -13,27 +13,19 @@
 # limitations under the License.
 
 import logging
-import os
 import typing
 
-import cudf
 import mrc
-import mrc.core.operators as ops
-import pandas as pd
-import requests
-import requests_cache
-from bs4 import BeautifulSoup
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from morpheus.config import Config
 from morpheus.messages import MessageMeta
 from morpheus.pipeline.single_port_stage import SinglePortStage
 from morpheus.pipeline.stage_schema import StageSchema
+from web_scraper_module import WebScraperInterface
 
 logger = logging.getLogger(f"morpheus.{__name__}")
 
-# TODO(Devin) Convert to use module
 
 class WebScraperStage(SinglePortStage):
     """
     Stage for scraping web based content using the HTTP GET protocol.
@@ -61,27 +53,20 @@ def __init__(self,
                  cache_path: str = "./.cache/http/RSSDownloadStage.sqlite"):
         super().__init__(c)
 
-        self._link_column = link_column
-        self._chunk_size = chunk_size
-        self._cache_dir = "./.cache/llm/rss/"
+        self._module_config = {
+            "web_scraper_config": {
+                "link_column": link_column,
+                "chunk_size": chunk_size,
+                "enable_cache": enable_cache,
+                "cache_path": cache_path,
+                "cache_dir": "./.cache/llm/rss",
+            }
+        }
 
-        # Ensure the directory exists
-        if (enable_cache):
-            os.makedirs(self._cache_dir, exist_ok=True)
+        self._input_port_name = "input"
+        self._output_port_name = "output"
 
-        self._text_splitter = RecursiveCharacterTextSplitter(chunk_size=self._chunk_size,
-                                                             chunk_overlap=self._chunk_size // 10,
-                                                             length_function=len)
-
-        if enable_cache:
-            self._session = requests_cache.CachedSession(cache_path, backend="sqlite")
-        else:
-            self._session = requests.Session()
-
-        self._session.headers.update({
-            "User-Agent":
-                "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36"
-        })
+        self._module_definition = WebScraperInterface.get_definition("web_scraper", self._module_config)
 
     @property
     def name(self) -> str:
@@ -108,75 +93,11 @@ def compute_schema(self, schema: StageSchema):
         schema.output_schema.set_type(MessageMeta)
 
     def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
+        module = self._module_definition.load(builder=builder)
 
-        node = builder.make_node(self.unique_name,
-                                 ops.map(self._download_and_split),
-                                 ops.filter(lambda x: x is not None))
-
-        node.launch_options.pe_count = self._config.num_threads
-
-        builder.make_edge(input_node, node)
-
-        return node
-
-    def _download_and_split(self, msg: MessageMeta) -> MessageMeta:
-        """
-        Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
-        it in the output, excludes output for any links which produce an error.
-        """
-        if self._link_column not in msg.get_column_names():
-            return None
-
-        df = msg.df
-
-        if isinstance(df, cudf.DataFrame):
-            df: pd.DataFrame = df.to_pandas()
-
-        # Convert the dataframe into a list of dictionaries
-        df_dicts = df.to_dict(orient="records")
-
-        final_rows: list[dict] = []
-
-        for row in df_dicts:
-
-            url = row[self._link_column]
-
-            try:
-                # Try to get the page content
-                response = self._session.get(url)
-
-                if (not response.ok):
-                    logger.warning(
-                        "Error downloading document from URL '%s'. " + "Returned code: %s. With reason: '%s'",
-                        url,
-                        response.status_code,
-                        response.reason)
-                    continue
-
-                raw_html = response.text
-
-                soup = BeautifulSoup(raw_html, "html.parser")
-
-                text = soup.get_text(strip=True, separator=' ')
-
-                split_text = self._text_splitter.split_text(text)
-
-                for text in split_text:
-                    row_cp = row.copy()
-                    row_cp.update({"page_content": text})
-                    final_rows.append(row_cp)
-
-                if isinstance(response, requests_cache.models.response.CachedResponse):
-                    logger.debug("Processed page: '%s'. Cache hit: %s", url, response.from_cache)
-                else:
-                    logger.debug("Processed page: '%s'", url)
+        mod_in_node = module.input_port(self._input_port_name)
+        mod_out_node = module.output_port(self._output_port_name)
 
-            except ValueError as exc:
-                logger.error("Error parsing document: %s", exc)
-                continue
-            except Exception as exc:
-                logger.error("Error downloading document from URL '%s'. Error: %s", url, exc)
-                continue
+        builder.make_edge(input_node, mod_in_node)
 
-        # Not using cudf to avoid error: pyarrow.lib.ArrowInvalid: cannot mix list and non-list, non-null values
-        return MessageMeta(pd.DataFrame(final_rows))
+        return mod_out_node
diff --git a/examples/llm/vdb_upload/module/file_source_pipe.py b/examples/llm/vdb_upload/module/file_source_pipe.py
index e9b59b1b0c..67f15d456f 100644
--- a/examples/llm/vdb_upload/module/file_source_pipe.py
+++ b/examples/llm/vdb_upload/module/file_source_pipe.py
@@ -158,5 +158,4 @@ def _file_source_pipe(builder: mrc.Builder):
 
 FileSourcePipe = ModuleInterface("file_source_pipe",
                                  "morpheus_examples_llm",
-                                 FileSourceParamContract)
-FileSourcePipe.print_schema()
+                                 FileSourceParamContract)
\ No newline at end of file
diff --git a/morpheus/stages/general/linear_modules_stage.py b/morpheus/stages/general/linear_modules_stage.py
index 79ecff9d7f..a0a67c081a 100644
--- a/morpheus/stages/general/linear_modules_stage.py
+++ b/morpheus/stages/general/linear_modules_stage.py
@@ -63,9 +63,14 @@ def __init__(self,
         self._input_port_name = input_port_name
         self._output_port_name = output_port_name
 
+        if (isinstance(self._module_config, dict)):
+            self._unique_name = self._module_config.get("unique_name", "linear_module")
+        else:
+            self._unique_name = self._module_config.name
+
     @property
     def name(self) -> str:
-        return self._module_config.get("module_name", "linear_module")
+        return self._unique_name
 
     def supports_cpp_node(self):
         return False
diff --git a/tests/examples/llm/common/conftest.py b/tests/examples/llm/common/conftest.py
index 11ef4bad0c..2b1384bf07 100644
--- a/tests/examples/llm/common/conftest.py
+++ b/tests/examples/llm/common/conftest.py
@@ -13,11 +13,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pytest
+import os
+import sys
 
+import pytest
+from _utils import TEST_DIRS
 from _utils import import_or_skip
 
 
+@pytest.fixture(scope="module")
+def import_utils():
+    utils_path = os.path.join(TEST_DIRS.examples_dir, 'llm/common/')
+    sys.path.insert(0, utils_path)
+
+    import utils
+
+    return utils
+
+
+@pytest.fixture(scope="module")
+def import_web_scraper_module():
+    web_scraper_path = os.path.join(TEST_DIRS.examples_dir, 'llm/common/')
+    sys.path.insert(0, web_scraper_path)
+
+    import web_scraper_module
+
+    return web_scraper_module
+
+
 @pytest.fixture(name="nemollm", autouse=True, scope='session')
 def nemollm_fixture(fail_missing: bool):
     """
diff --git a/tests/examples/llm/common/test_utils.py b/tests/examples/llm/common/test_utils.py
new file mode 100644
index 0000000000..73603ad41e
--- /dev/null
+++ b/tests/examples/llm/common/test_utils.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+
+import pymilvus
+
+
+# TODO(Devin)
+# build_huggingface_embeddings, build_milvus_service, and build_llm_service
+
+def test_build_milvus_config_with_valid_embedding_size(import_utils: types.ModuleType):
+    embedding_size = 128
+    config = import_utils.build_milvus_config(embedding_size)
+
+    assert 'index_conf' in config
+    assert 'schema_conf' in config
+
+    embedding_field_schema = next(
+        (field for field in config['schema_conf']['schema_fields'] if field["name"] == 'embedding'),
+        None
+    )
+    assert embedding_field_schema is not None
+    assert embedding_field_schema['params']['dim'] == embedding_size
+
+def test_build_milvus_config_uses_correct_field_types(import_utils: types.ModuleType):
+    embedding_size = 128
+    config = import_utils.build_milvus_config(embedding_size)
+
+    for field in config['schema_conf']['schema_fields']:
+        assert 'name' in field
+        assert 'type' in field
+        assert 'description' in field
+
+        if field['name'] == 'embedding':
+            assert field['type'] == pymilvus.DataType.FLOAT_VECTOR
+        else:
+            assert field['type'] in [pymilvus.DataType.INT64, pymilvus.DataType.VARCHAR]
+
+def test_build_milvus_config_index_configuration(import_utils: types.ModuleType):
+    embedding_size = 128
+    config = import_utils.build_milvus_config(embedding_size)
+
+    index_conf = config['index_conf']
+    assert index_conf['field_name'] == 'embedding'
+    assert index_conf['metric_type'] == 'L2'
+    assert index_conf['index_type'] == 'HNSW'
+    assert 'params' in index_conf
+    assert index_conf['params']['M'] == 8
+    assert index_conf['params']['efConstruction'] == 64
diff --git a/tests/examples/llm/common/test_web_scraper_module.py b/tests/examples/llm/common/test_web_scraper_module.py
new file mode 100644
index 0000000000..868afa59db
--- /dev/null
+++ b/tests/examples/llm/common/test_web_scraper_module.py
@@ -0,0 +1,47 @@
+import os
+import types
+
+import cudf
+import pytest
+
+from _utils import TEST_DIRS
+from _utils import assert_results
+
+from morpheus.config import Config
+from morpheus.messages import MessageMeta
+from morpheus.pipeline import LinearPipeline
+from morpheus.stages.general.linear_modules_stage import LinearModulesStage
+from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
+from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage
+
+
+@pytest.mark.slow
+@pytest.mark.use_python
+@pytest.mark.use_cudf
+@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'llm/common/web_scraper_module.py'))
+def test_http_client_source_stage_pipe(config: Config, mock_rest_server: str, import_mod: types.ModuleType):
+    url = f"{mock_rest_server}/www/index"
+
+    df = cudf.DataFrame({"link": [url]})
+    df_expected = cudf.DataFrame({"link": [url], "page_content": "website title some paragraph"})
+
+    web_scraper_definition = import_mod.WebScraperInterface.get_definition("web_scraper",
+                                                                           module_config={"web_scraper_config": {
+                                                                               "link_column": "link", "chunk_size": 100,
+                                                                               "enable_cache": False,
+                                                                               "cache_path": "./.cache/http/RSSDownloadStage.sqlite",
+                                                                               "cache_dir": "./.cache/llm/rss"}})
+
+    pipe = LinearPipeline(config)
+    pipe.set_source(InMemorySourceStage(config, [df]))
+    pipe.add_stage(LinearModulesStage(config,
+                                      web_scraper_definition,
+                                      input_type=MessageMeta,
+                                      output_type=MessageMeta,
+                                      input_port_name="input",
+                                      output_port_name="output"))
+    comp_stage = pipe.add_stage(CompareDataFrameStage(config, compare_df=df_expected))
+    pipe.run()
+
+    print(comp_stage.get_messages())
+
+    assert_results(comp_stage.get_results())
diff --git a/tests/mock_rest_server/mocks/RSS/single_entry/GET.mock b/tests/mock_rest_server/mocks/RSS/single_entry/GET.mock
new file mode 100644
index 0000000000..39dbe4bb24
--- /dev/null
+++ b/tests/mock_rest_server/mocks/RSS/single_entry/GET.mock
@@ -0,0 +1,20 @@
+HTTP/1.1 200 OK
+Content-Type: application/json
+
+<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+  <channel>
+    <title>Cyber Security News</title>
+    <link>http://localhost:8080/RSS/feed_link</link>
+    <description>Latest updates and articles in cybersecurity</description>
+    <language>en-us</language>
+    <pubDate>Mon, 10 Jan 2024 10:00:00 GMT</pubDate>
+    <item>
+      <title>New Vulnerability Discovered in Popular Web Framework</title>
+      <link>https://www.cybersecuritynews.com/new-vulnerability-web-framework</link>
+      <description>A new security vulnerability has been identified in the popular XYZ Web Framework, which could allow attackers to execute arbitrary code on affected systems. Users are advised to apply the latest patches immediately.</description>
+      <pubDate>Mon, 10 Jan 2024 09:00:00 GMT</pubDate>
+      <guid>https://www.cybersecuritynews.com/new-vulnerability-web-framework</guid>
+    </item>
+  </channel>
+</rss>
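
Note on running the new tests: assuming the standard Morpheus development container and test setup, the tests added here should be runnable from the repository root with something like `pytest tests/examples/llm/common/test_utils.py tests/examples/llm/common/test_web_scraper_module.py`. The web scraper test is marked `slow`, `use_python`, and `use_cudf`, and relies on the `mock_rest_server` fixture, so the project's option for enabling slow tests (for example a `--run_slow` style flag, if defined in the top-level conftest) may also be required; the exact flag name is an assumption here.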