diff --git a/tests/examples/log_parsing/test_postprocessing.py b/tests/examples/log_parsing/test_postprocessing.py index 68242a8860..c76c7dcab4 100644 --- a/tests/examples/log_parsing/test_postprocessing.py +++ b/tests/examples/log_parsing/test_postprocessing.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os +import re import types import typing @@ -81,7 +83,8 @@ def test_log_parsing_post_processing_stage(config: Config, @pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'log_parsing', 'postprocessing.py')) -def test_undefined_variable_error(config: Config, +def test_undefined_variable_error(caplog: pytest.LogCaptureFixture, + config: Config, dataset_cudf: DatasetManager, import_mod: typing.List[types.ModuleType], bert_cased_vocab: str, @@ -101,5 +104,16 @@ def test_undefined_variable_error(config: Config, post_proc_message = build_post_proc_message(dataset_cudf, log_test_data_dir) post_proc_message.get_tensor('input_ids')[0] = 27716.0 - with pytest.warns(UserWarning, match=r'Ignoring unexecpected subword token'): + expected_log_re = re.compile(r"^Ignoring unexecpected subword token:.*") + + caplog.clear() + with caplog.at_level(logging.WARNING): stage._postprocess(post_proc_message) + + logged_warning = False + for rec in caplog.records: + if rec.levelno == logging.WARNING and expected_log_re.match(rec.message) is not None: + logged_warning = True + break + + assert logged_warning, "Expected warning message not found in logs"