Commit

Add context_extractor_module tests

drobison00 committed Jan 17, 2024
1 parent 05c6fa9 commit 4baac10
Showing 5 changed files with 219 additions and 82 deletions.
2 changes: 1 addition & 1 deletion examples/llm/common/content_extractor_module.py
@@ -91,7 +91,7 @@ def convert(self,
     if meta is None:
         text_column_name = "content"
     else:
-        text_column_name = meta["csv"]["text_column_name"]
+        text_column_name = meta.get("csv", {}).get("text_column_name", "content")

     for path in file_path:
         df = pd.read_csv(path, encoding=encoding)
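For context on the one-line change above: the new lookup is defensive, so a missing `csv` section or `text_column_name` entry falls back to the default `content` column instead of raising `KeyError`. A minimal sketch of the difference (the `meta` values below are made-up examples, not taken from the diff):

```python
# Hypothetical converters_meta dict missing the "text_column_name" entry
meta = {"csv": {}}

# Old behavior: meta["csv"]["text_column_name"] would raise KeyError here.
# New behavior: fall back to the default "content" column.
text_column_name = meta.get("csv", {}).get("text_column_name", "content")
assert text_column_name == "content"

# When the key is present, it is used exactly as before.
meta = {"csv": {"text_column_name": "some_column"}}
assert meta.get("csv", {}).get("text_column_name", "content") == "some_column"
```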
26 changes: 26 additions & 0 deletions examples/llm/vdb_upload/README.md
@@ -242,6 +242,32 @@ vdb_pipeline:
watch: false
```
*Example: Defining a custom source via a config file*

Note: See `vdb_config.yaml` for a full configuration example.

Note: This example uses the same module and config as the filesystem source example above, but explicitly specifies the module to load.

`vdb_config.yaml`

```yaml
vdb_pipeline:
  sources:
    - type: custom
      name: "demo_custom_filesystem_source"
      module_id: "file_source_pipe"  # Required for custom source, defines the source module to load
      module_output_id: "output"  # Required for custom source, defines the output of the module to use
      namespace: "morpheus_examples_llm"  # Required for custom source, defines the namespace of the module to load
      config:
        batch_size: 1024
        extractor_config:
          chunk_size: 512
          num_threads: 10  # Number of threads to use for file reads
        config_name_mapping: "file_source_config"
        filenames:
          - "/path/to/data/*"
        watch: false
```

```bash
python examples/llm/main.py vdb_upload pipeline \
    --vdb_config_path "./vdb_config.yaml"
```
65 changes: 65 additions & 0 deletions morpheus/stages/input/in_memory_data_generation_stage.py
@@ -0,0 +1,65 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any
from typing import Callable
from typing import Generator
from typing import Type

import mrc

from morpheus.config import Config
from morpheus.pipeline.single_output_source import SingleOutputSource
from morpheus.pipeline.stage_schema import StageSchema

logger = logging.getLogger(f"morpheus.{__name__}")


class InMemoryDataGenStage(SingleOutputSource):
"""
Source stage that generates data in-memory using a provided generator function.
Parameters
----------
c : `morpheus.config.Config`
Pipeline configuration instance.
generator : Callable[[], Generator[Any, None, None]]
A generator function that yields data to be processed by the pipeline.
output_data_type : Type
The data type of the objects that the generator yields.
"""

def __init__(self, c: Config, generator: Callable[[], Generator[Any, None, None]],
output_data_type: Type = Any):
super().__init__(c)
self._generator = generator
self._output_data_type = output_data_type

@property
def name(self) -> str:
return "in-memory-data-gen"

def compute_schema(self, schema: StageSchema):
        # Set the output schema based on the provided output_data_type
schema.output_schema.set_type(self._output_data_type)

def supports_cpp_node(self):
return False

def _generate_data(self) -> Generator[Any, None, None]:
yield from self._generator()

def _build_source(self, builder: mrc.Builder) -> mrc.SegmentObject:
return builder.make_source(self.unique_name, self._generate_data())
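A minimal usage sketch for the new stage, assuming a default pipeline `Config` and the existing `InMemorySinkStage` to collect output; the generator and wiring here are illustrative only and are not part of the commit:

```python
from typing import Generator, List

from morpheus.config import Config
from morpheus.pipeline import LinearPipeline
from morpheus.stages.input.in_memory_data_generation_stage import InMemoryDataGenStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage


def batch_generator() -> Generator[List[int], None, None]:
    # Yield three small batches of integers as the in-memory "data source"
    for start in range(0, 9, 3):
        yield list(range(start, start + 3))


config = Config()
pipe = LinearPipeline(config)
pipe.set_source(InMemoryDataGenStage(config, batch_generator, output_data_type=List[int]))
sink = pipe.add_stage(InMemorySinkStage(config))
pipe.run()

# Each yielded batch arrives at the sink as one message
assert len(sink.get_messages()) == 3
```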
206 changes: 126 additions & 80 deletions tests/examples/llm/common/test_content_extractor_module.py
@@ -13,111 +13,157 @@

The previous contents of this file (a web-scraper pipeline test named `test_http_client_source_stage_pipe` and several mocked `your_module.parse_files` placeholder tests) are removed. The new contents of the file:

# limitations under the License.

import logging
import os
import random
import shutil
import string
import sys
import tempfile
import types
import uuid
from functools import partial
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import List

import fsspec
import fsspec.core
import pandas as pd
import pytest
from _utils import TEST_DIRS

from morpheus.config import Config
from morpheus.messages import MessageMeta
from morpheus.pipeline import LinearPipeline
from morpheus.stages.general.linear_modules_stage import LinearModulesStage
from morpheus.stages.input.in_memory_data_generation_stage import InMemoryDataGenStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage

logger = logging.getLogger(f"morpheus.{__name__}")


class TempCSVFiles:

    def __init__(self, num_files: int, columns: Dict[str, Callable[[], Any]]):
        self.temp_dir = None
        self.temp_files = []
        self.num_files = num_files
        self.columns = columns
        self._create_temp_dir_and_files()

    def _create_temp_dir_and_files(self):
        # Create a temporary directory
        self.temp_dir = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
        os.makedirs(self.temp_dir, exist_ok=True)

        for _ in range(self.num_files):
            # Create a random filename within the temp directory
            file_path = os.path.join(self.temp_dir, f"{uuid.uuid4().hex}.csv")

            # Generate CSV data based on the specified column generators
            data = {col_name: col_func() for col_name, col_func in self.columns.items()}
            df = pd.DataFrame(data)
            df.to_csv(file_path, index=False)

            # Store the file path for later use
            self.temp_files.append(file_path)

    def __enter__(self):
        return self.temp_files

    def __exit__(self, exc_type, exc_value, traceback):
        # Cleanup the temporary directory and its contents
        if self.temp_dir and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)


# Generator function that uses TempCSVFiles to yield batches of opened CSV files
def csv_file_generator(csv_files: TempCSVFiles, batch_size: int) -> Generator[List[fsspec.core.OpenFile], None, None]:
    open_files = fsspec.open_files(csv_files.temp_files)
    for i in range(0, len(open_files), batch_size):
        yield open_files[i:i + batch_size]


def generate_random_string(length: int) -> str:
    return ''.join(random.choices(string.ascii_letters + string.digits, k=length))


# Fixture for importing the module under test
@pytest.fixture(scope="module")
def import_content_extractor_module():
    sys.path.insert(0, os.path.join(TEST_DIRS.examples_dir, 'llm/common'))
    import content_extractor_module
    sys.path.remove(os.path.join(TEST_DIRS.examples_dir, 'llm/common'))
    return content_extractor_module


@pytest.mark.use_python
@pytest.mark.use_cudf
@pytest.mark.parametrize("data_len, num_rows_per_file, batch_size", [
    (40, 5, 2),
    (51, 3, 1),
    (150, 10, 5),
    (500, 3, 2),
    (1000, 5, 3),
    (50, 10, 2),
    (100, 20, 3),
    (50, 5, 1),
    (100, 10, 1),
    (49, 5, 2),
    (99, 5, 2),
    (60, 7, 2),
    (120, 6, 3),
    (1000, 50, 10),
    (2000, 100, 20),
])
def test_content_extractor_module(data_len, num_rows_per_file, batch_size, config: Config,
                                  import_content_extractor_module: types.ModuleType):
    chunk_size = 50
    chunk_overlap = 10
    # The text splitter handles evenly divisible boundaries differently: the effective chunk
    # boundary only shrinks by the overlap when the data is longer than a single chunk.
    chunk_boundary_size = (chunk_size - chunk_overlap) if (data_len > chunk_size) else chunk_size
    module_config = {
        "converters_meta": {
            "csv": {
                "chunk_size": chunk_size,
                "chunk_overlap": chunk_overlap,
                "text_column_name": "some_column",
            }
        },
        "content_extractor_config": {
            "batch_size": batch_size,
            "num_threads": 1,
        },
        "enable_monitor": False,
    }
    content_extractor_def = import_content_extractor_module.FileContentExtractorInterface.get_definition(
        "content_extractor", module_config=module_config)

    temp_csv_files = TempCSVFiles(
        num_files=5,
        columns={'some_column': lambda: [generate_random_string(data_len) for _ in range(num_rows_per_file)]})
    file_generator = partial(csv_file_generator, temp_csv_files, batch_size=1)

    pipe = LinearPipeline(config)
    pipe.set_source(InMemoryDataGenStage(config, file_generator, output_data_type=List[fsspec.core.OpenFile]))
    pipe.add_stage(LinearModulesStage(config,
                                      content_extractor_def,
                                      input_type=List[fsspec.core.OpenFile],
                                      output_type=MessageMeta,
                                      input_port_name="input",
                                      output_port_name="output"))
    sink_stage = pipe.add_stage(InMemorySinkStage(config))
    pipe.run()

    expected_columns = ["title", "source", "summary", "content"]
    for message in sink_stage.get_messages():
        output = message.df
        assert set(expected_columns) == set(output.columns)
        expected_rows = num_rows_per_file * ((data_len // chunk_boundary_size) +
                                             (1 if data_len % chunk_boundary_size else 0))
        assert output.shape == (expected_rows, 4)
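The row-count expectation in the final assertion follows from the splitter arithmetic used in the test: with `chunk_size=50` and `chunk_overlap=10`, text longer than a single chunk advances in effective steps of `chunk_size - chunk_overlap = 40` characters. A small sketch of that calculation, using values from the parametrize list:

```python
def expected_chunks(data_len: int, chunk_size: int = 50, chunk_overlap: int = 10) -> int:
    # Mirrors the test's chunk_boundary_size logic: text that fits in one chunk is not split
    boundary = (chunk_size - chunk_overlap) if data_len > chunk_size else chunk_size
    return (data_len // boundary) + (1 if data_len % boundary else 0)


assert expected_chunks(40) == 1    # fits within a single chunk
assert expected_chunks(51) == 2    # one full 40-character step plus an 11-character remainder
assert expected_chunks(120) == 3   # exactly three 40-character steps, no remainder
```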
2 changes: 1 addition & 1 deletion tests/examples/llm/common/test_web_scraper_module.py
@@ -32,7 +32,7 @@
@pytest.mark.use_python
@pytest.mark.use_cudf
@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'llm/common/web_scraper_module.py'))
-def test_http_client_source_stage_pipe(config: Config, mock_rest_server: str, import_mod: types.ModuleType):
+def test_web_scraper_module(config: Config, mock_rest_server: str, import_mod: types.ModuleType):
url = f"{mock_rest_server}/www/index"

df = cudf.DataFrame({"link": [url]})
