Skip to content

Commit

Permalink
Merge branch 'branch-23.11' into david-branch-23.11-docs-1356-1361
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv authored Nov 22, 2023
2 parents 5d26d52 + c4f2b56 commit 8cbfd3f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
19 changes: 10 additions & 9 deletions tests/llm/test_vdb_upload_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def _run_pipeline(config: Config,

pipe = LinearPipeline(config)

pipe.set_source(RSSSourceStage(config, feed_input=rss_files, batch_size=128, run_indefinitely=False))
pipe.add_stage(web_scraper_stage_mod.WebScraperStage(config, chunk_size=MODEL_FEA_LENGTH))
pipe.set_source(
RSSSourceStage(config, feed_input=rss_files, batch_size=128, run_indefinitely=False, enable_cache=False))
pipe.add_stage(web_scraper_stage_mod.WebScraperStage(config, chunk_size=MODEL_FEA_LENGTH, enable_cache=False))
pipe.add_stage(DeserializeStage(config))

pipe.add_stage(
Expand All @@ -70,7 +71,7 @@ def _run_pipeline(config: Config,
column='page_content'))

pipe.add_stage(
TritonInferenceStage(config, model_name='all-MiniLM-L6-v2', server_url='test:0000', force_convert_inputs=True))
TritonInferenceStage(config, model_name='test-model', server_url='test:0000', force_convert_inputs=True))

pipe.add_stage(
WriteToVectorDBStage(config,
Expand All @@ -89,10 +90,10 @@ def _run_pipeline(config: Config,
os.path.join(TEST_DIRS.examples_dir, 'llm/common/utils.py'),
os.path.join(TEST_DIRS.examples_dir, 'llm/common/web_scraper_stage.py')
])
@mock.patch('requests_cache.CachedSession')
@mock.patch('requests.Session')
@mock.patch('tritonclient.grpc.InferenceServerClient')
def test_vdb_upload_pipe(mock_triton_client: mock.MagicMock,
mock_cache_session: mock.MagicMock,
mock_requests_session: mock.MagicMock,
config: Config,
dataset: DatasetManager,
milvus_server_uri: str,
Expand Down Expand Up @@ -134,17 +135,17 @@ def test_vdb_upload_pipe(mock_triton_client: mock.MagicMock,
async_infer = mk_async_infer(inf_results)
mock_triton_client.async_infer.side_effect = async_infer

# Mock requests_cache, since we are feeding the RSSSourceStage with a local file it won't be using the
# requests_cache lib, only web_scraper_stage.py will use it.
# Mock requests, since we are feeding the RSSSourceStage with a local file it won't be using the
# requests lib, only web_scraper_stage.py will use it.
def mock_get_fn(url: str):
mock_response = mock.MagicMock()
mock_response.ok = True
mock_response.status_code = 200
mock_response.text = web_responses[url]
return mock_response

mock_cache_session.return_value = mock_cache_session
mock_cache_session.get.side_effect = mock_get_fn
mock_requests_session.return_value = mock_requests_session
mock_requests_session.get.side_effect = mock_get_fn

(utils_mod, web_scraper_stage_mod) = import_mod
collection_name = "test_vdb_upload_pipe"
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_data/service/milvus_rss_data.json
Git LFS file not shown

0 comments on commit 8cbfd3f

Please sign in to comment.