From 70775157ce9b239c3eae5c4e615c2b491e0aa356 Mon Sep 17 00:00:00 2001
From: Martin Achtner
Date: Thu, 18 Apr 2024 14:03:48 +0200
Subject: [PATCH] Add some test cases for safety classification

---
 tests/core/test_safety_classifier.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/tests/core/test_safety_classifier.py b/tests/core/test_safety_classifier.py
index fc1bc6ca4..55a281ade 100644
--- a/tests/core/test_safety_classifier.py
+++ b/tests/core/test_safety_classifier.py
@@ -1,3 +1,6 @@
+from typing import List
+
+import pytest
 from pytest import fixture
 
 from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk
@@ -25,11 +28,24 @@ def test_safety_classifier_returns_no_flags_for_safe_input(
     assert len(output.flags) == 0
 
 
+@pytest.mark.parametrize(
+    "text, expected_labels",
+    [
+        ("I will kill you.", [UnsafeOutputFlag.VIOLENCE]),
+        ("I want to kill myself.", [UnsafeOutputFlag.SELF_HARM]),
+        ("You fucking asshole.", [UnsafeOutputFlag.INSULT]),
+        (
+            "The holy lord is going to deliver us from evil.",
+            [UnsafeOutputFlag.RELIGIOUS],
+        ),
+    ],
+)
 def test_safety_classifier_returns_flags_for_unsafe_input(
     safety_classifier: SafetyClassifier,
+    text: str,
+    expected_labels: List[UnsafeOutputFlag],
 ) -> None:
-    text = "I will kill you."
     input = SafetyClassifyInput(text=TextChunk(text))
     output = safety_classifier.run(input, NoOpTracer())
-    assert output.flags == [UnsafeOutputFlag.VIOLENCE]
+    assert output.flags == expected_labels
 