diff --git a/src/intelligence_layer/core/safety_classifier.py b/src/intelligence_layer/core/safety_classifier.py
index 2a85f2f17..7dcc00eb3 100644
--- a/src/intelligence_layer/core/safety_classifier.py
+++ b/src/intelligence_layer/core/safety_classifier.py
@@ -4,54 +4,29 @@
 from pydantic import BaseModel
 
 from intelligence_layer.core import (
-    CompleteInput,
-    ControlModel,
     Language,
     LuminousControlModel,
     Task,
     TaskSpan,
     TextChunk,
 )
+from intelligence_layer.use_cases import ClassifyInput, PromptBasedClassify
 
 INSTRUCTION_CONFIGS = {
     Language(
         "en"
-    ): """Your task is to act as a safety classifier for text. There are the following categories:
--hatespeech,
--sexual,
--religious,
--medical,
--self_harm,
--illegal_activities,
--violence,
--insult,
--profanity
-
-Your response should be a single line containing the labels of which categories the text fulfills.
-Return the label 'safe' if the text is safe.
-Please make sure your response is a single line.
-""",
+    ): """Your task is to act as a safety classifier for text. Identify a class that describes the text adequately.
+    Reply with only the class label.""",
     Language(
         "de"
-    ): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
--hatespeech,
--sexual,
--religious,
--medical,
--self_harm,
--illegal_activities,
--violence,
--insult,
--profanity
-
-Deine Antwort solle eine Zeile mit Labels der jeweiligen Kategorien umfassen, die der Text erfüllt.
-Gib das Label 'safe' zurück, wenn der Text sicher ist.
-Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
-""",
+    ): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Identifiziere eine Klasse,
+    die den Text adäquat beschreibt. Antworte nur mit dem Label der Klasse.
+    """,
 }
 
 
-class UnsafeOutputFlag(Enum):
+class SafetyOutputFlag(Enum):
+    SAFE = "safe"
     HATESPEECH = "hatespeech"
     SEXUAL = "sexual"
     RELIGIOUS = "religious"
@@ -60,7 +35,6 @@ class UnsafeOutputFlag(Enum):
     ILLEGAL_ACTIVITIES = "illegal_activities"
     VIOLENCE = "violence"
     INSULT = "insult"
-    PROFANITY = "profanity"
 
 
 class SafetyClassifyInput(BaseModel):
@@ -69,14 +43,14 @@ class SafetyClassifyInput(BaseModel):
 
 
 class SafetyClassifyOutput(BaseModel):
-    flags: Sequence[UnsafeOutputFlag]  # "safe" is an empty list.
+    flags: Sequence[SafetyOutputFlag]  # "safe" is an empty list.
     # Flag is only included if the task determines it should be flagged
 
 
 class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
     def __init__(
         self,
-        model: ControlModel | None = None,
+        model: LuminousControlModel | None = None,
         instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
     ) -> None:
         self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
@@ -89,17 +63,22 @@ def do_run(
         if not instruction:
             raise ValueError(f"Could not find `prompt_config` for {input.language}.")
 
-        completion = self._model.complete(
-            CompleteInput(
-                prompt=self._model.to_instruct_prompt(instruction, input.text),
-            ),
-            task_span,
+        classify_inputs = ClassifyInput(
+            chunk=input.text,
+            labels=frozenset({flag.value for flag in SafetyOutputFlag}),
         )
-        completion_str = str(completion.completions[0].completion)
-        potential_flags = completion_str.split(", ")
-        flags = [
-            UnsafeOutputFlag(flag.strip())
-            for flag in potential_flags
-            if hasattr(UnsafeOutputFlag, flag.strip().upper())
-        ]
-        return SafetyClassifyOutput(flags=flags)
+        prompt_based_classify = PromptBasedClassify(
+            model=self._model, instruction=instruction
+        )
+        output_probabilities_per_flag = prompt_based_classify.run(
+            classify_inputs, task_span
+        )
+
+        most_probable_flag = SafetyOutputFlag(
+            output_probabilities_per_flag.sorted_scores[0][0]
+        )
+
+        if most_probable_flag == SafetyOutputFlag.SAFE:
+            return SafetyClassifyOutput(flags=[])
+
+        return SafetyClassifyOutput(flags=[most_probable_flag])
diff --git a/tests/core/test_safety_classifier.py b/tests/core/test_safety_classifier.py
index 14acd1eb1..40d7b28e7 100644
--- a/tests/core/test_safety_classifier.py
+++ b/tests/core/test_safety_classifier.py
@@ -12,7 +12,7 @@
 from intelligence_layer.core.safety_classifier import (
     SafetyClassifier,
     SafetyClassifyInput,
-    UnsafeOutputFlag,
+    SafetyOutputFlag,
 )
 
 
@@ -36,19 +36,19 @@ def test_safety_classifier_returns_no_flags_for_safe_input(
 @pytest.mark.parametrize(
     "text, expected_labels",
     [
-        ("I will kill you.", [UnsafeOutputFlag.VIOLENCE]),
-        ("I want to kill myself.", [UnsafeOutputFlag.SELF_HARM]),
-        ("You fucking asshole.", [UnsafeOutputFlag.INSULT]),
+        ("I will kill you.", [SafetyOutputFlag.VIOLENCE]),
+        ("I want to kill myself.", [SafetyOutputFlag.SELF_HARM]),
+        ("You fucking asshole.", [SafetyOutputFlag.INSULT]),
         (
             "The holy lord is going to deliver us from evil.",
-            [UnsafeOutputFlag.RELIGIOUS],
+            [SafetyOutputFlag.RELIGIOUS],
         ),
     ],
 )
 def test_safety_classifier_returns_flags_for_unsafe_english_input(
     safety_classifier: SafetyClassifier,
     text: str,
-    expected_labels: List[UnsafeOutputFlag],
+    expected_labels: List[SafetyOutputFlag],
 ) -> None:
     input = SafetyClassifyInput(text=TextChunk(text))
 
@@ -59,19 +59,19 @@ def test_safety_classifier_returns_flags_for_unsafe_english_input(
 @pytest.mark.parametrize(
     "text, expected_labels",
     [
-        ("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
-        ("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
-        ("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
+        ("Ich töte dich.", [SafetyOutputFlag.VIOLENCE]),
+        ("Ich möchte mich umbringen.", [SafetyOutputFlag.SELF_HARM]),
+        ("Du verdammtes Arschloch.", [SafetyOutputFlag.INSULT]),
         (
             "Gott ist gütig und wird uns von allem Bösen erlösen.",
-            [UnsafeOutputFlag.RELIGIOUS],
+            [SafetyOutputFlag.RELIGIOUS],
        ),
     ],
 )
 def test_safety_classifier_returns_flags_for_unsafe_german_input(
     safety_classifier: SafetyClassifier,
     text: str,
-    expected_labels: List[UnsafeOutputFlag],
+    expected_labels: List[SafetyOutputFlag],
 ) -> None:
     input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))
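Usage sketch (not part of the diff): assuming the repo's usual task entry point `Task.run(input, tracer)` and that `NoOpTracer` is importable from `intelligence_layer.core` (an assumption; adjust the import if the tracer lives elsewhere), the reworked classifier would be driven roughly like this. Inputs and expected flags mirror the tests above.

```python
# Illustrative sketch only -- not part of the diff.
# Assumes Task.run(input, tracer) and a NoOpTracer export in intelligence_layer.core.
from intelligence_layer.core import Language, NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# Defaults to LuminousControlModel("luminous-nextgen-7b-control-384k"), as in __init__ above.
classifier = SafetyClassifier()

# English input: at most one flag is returned -- the single most probable class,
# or an empty list when "safe" wins.
output = classifier.run(
    SafetyClassifyInput(text=TextChunk("I will kill you.")),
    NoOpTracer(),
)
print(output.flags)  # per the tests above: [SafetyOutputFlag.VIOLENCE]

# German input: the language field selects the German instruction config.
output_de = classifier.run(
    SafetyClassifyInput(text=TextChunk("Ich töte dich."), language=Language("de")),
    NoOpTracer(),
)
print(output_de.flags)  # per the tests above: [SafetyOutputFlag.VIOLENCE]
```

Note that, unlike the old completion parsing, `do_run` now yields at most one flag per input (the most probable class via `PromptBasedClassify`), so callers that relied on several simultaneous flags only see the single top label.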