Skip to content

Commit

Permalink
Add docs and tests for WebScraperStage (#1349)
Browse files Browse the repository at this point in the history
Closes #1283

Authors:
  - Christopher Harris (https://github.com/cwharris)

Approvers:
  - Devin Robison (https://github.com/drobison00)

URL: #1349
  • Loading branch information
cwharris authored Nov 8, 2023
1 parent 895d915 commit 782c142
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 15 deletions.
42 changes: 27 additions & 15 deletions examples/llm/common/web_scraper_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,30 @@
from morpheus.config import Config
from morpheus.messages import MessageMeta
from morpheus.pipeline.single_port_stage import SinglePortStage
from morpheus.pipeline.stream_pair import StreamPair
from morpheus.pipeline.stage_schema import StageSchema

logger = logging.getLogger(f"morpheus.{__name__}")


class WebScraperStage(SinglePortStage):

def __init__(self, c: Config, *, chunk_size, link_column: str = "link"):
"""
Stage for scraping web based content using the HTTP GET protocol.
Parameters
----------
c : morpheus.config.Config
Pipeline configuration instance.
chunk_size : int
Size in which to split the scraped content
link_column : str, default="link"
Column which contains the links to scrape
"""

def __init__(self, c: Config, *, chunk_size: int, link_column: str = "link"):
super().__init__(c)

self._link_column = link_column
self._chunk_size = chunk_size

self._cache_dir = "./.cache/llm/rss/"

# Ensure the directory exists
Expand Down Expand Up @@ -79,19 +90,26 @@ def supports_cpp_node(self):
"""Indicates whether this stage supports a C++ node."""
return False

def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair:
def compute_schema(self, schema: StageSchema):
schema.output_schema.set_type(MessageMeta)

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:

node = builder.make_node(self.unique_name,
ops.map(self._download_and_split),
ops.filter(lambda x: x is not None))

node.launch_options.pe_count = self._config.num_threads

builder.make_edge(input_stream[0], node)
builder.make_edge(input_node, node)

return node, input_stream[1]
return node

def _download_and_split(self, msg: MessageMeta) -> MessageMeta:

"""
Uses the HTTP GET method to download/scrape the links found in the message, splits the scraped data, and stores
it in the output, excludes output for any links which produce an error.
"""
if self._link_column not in msg.get_column_names():
return None

Expand Down Expand Up @@ -125,13 +143,7 @@ def _download_and_split(self, msg: MessageMeta) -> MessageMeta:

soup = BeautifulSoup(raw_html, "html.parser")

text = soup.get_text(strip=True)

# article = Article(url)
# article.download()
# article.parse()
# print(article.text)
# text = article.text
text = soup.get_text(strip=True, separator=' ')

split_text = self._text_splitter.split_text(text)

Expand Down
28 changes: 28 additions & 0 deletions tests/examples/llm/common/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from _utils import import_or_skip


@pytest.fixture(name="nemollm", autouse=True, scope='session')
def nemollm_fixture(fail_missing: bool):
"""
All of the tests in this subdir require nemollm
"""
skip_reason = ("Tests for the WebScraperStage require the langchain package to be installed, to install this run:\n"
"`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")
yield import_or_skip("langchain", reason=skip_reason, fail_missing=fail_missing)
49 changes: 49 additions & 0 deletions tests/examples/llm/common/test_web_scraper_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import types

import pytest

import cudf

from _utils import TEST_DIRS
from _utils import assert_results
from morpheus.config import Config
from morpheus.pipeline import LinearPipeline
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage


@pytest.mark.slow
@pytest.mark.use_python
@pytest.mark.use_cudf
@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'llm/common/web_scraper_stage.py'))
def test_http_client_source_stage_pipe(config: Config, mock_rest_server: str, import_mod: types.ModuleType):
url = f"{mock_rest_server}/www/index"

df = cudf.DataFrame({"link": [url]})

df_expected = cudf.DataFrame({"link": [url], "page_content": "website title some paragraph"})

pipe = LinearPipeline(config)
pipe.set_source(InMemorySourceStage(config, [df]))
pipe.add_stage(import_mod.WebScraperStage(config, chunk_size=config.feature_length))
comp_stage = pipe.add_stage(CompareDataFrameStage(config, compare_df=df_expected))
pipe.run()

print(comp_stage.get_messages())

assert_results(comp_stage.get_results())
4 changes: 4 additions & 0 deletions tests/mock_rest_server/mocks/www/index/GET.mock
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
HTTP/1.1 200 OK
Content-Type: application/json

<!DOCTYPE html><html><body><title>website title</title><p>some paragraph</p></body></html>

0 comments on commit 782c142

Please sign in to comment.