diff --git a/src/intelligence_layer/core/safety_classifier.py b/src/intelligence_layer/core/safety_classifier.py
new file mode 100644
index 000000000..aef8f22d0
--- /dev/null
+++ b/src/intelligence_layer/core/safety_classifier.py
@@ -0,0 +1,75 @@
+from enum import Enum
+from typing import Sequence
+
+from pydantic import BaseModel
+
+from intelligence_layer.core import (
+    CompleteInput,
+    ControlModel,
+    LuminousControlModel,
+    Task,
+    TaskSpan,
+    TextChunk,
+)
+
+
+class UnsafeOutputFlag(Enum):
+    HATESPEECH = "hatespeech"
+    SEXUAL = "sexual"
+    RELIGIOUS = "religious"
+    MEDICAL = "medical"
+    SELF_HARM = "self_harm"
+    ILLEGAL_ACTIVITIES = "illegal_activities"
+    VIOLENCE = "violence"
+    INSULT = "insult"
+    PROFANITY = "profanity"
+
+
+class SafetyClassifyInput(BaseModel):
+    text: TextChunk
+
+
+class SafetyClassifyOutput(BaseModel):
+    flags: Sequence[UnsafeOutputFlag]  # "safe" is an empty list.
+    # Flag is only included if the task determines it should be flagged
+
+
+class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
+    def __init__(
+        self,
+        model: ControlModel | None = None,
+    ) -> None:
+        self._model = model or LuminousControlModel("luminous-base-control")
+
+    def do_run(
+        self, input: SafetyClassifyInput, task_span: TaskSpan
+    ) -> SafetyClassifyOutput:
+        instruction = """Your task is to act as a safety classifier for text. There are the following categories:
+hatespeech,
+sexual,
+religious,
+medical,
+self_harm,
+illegal_activities,
+violence,
+insult,
+profanity
+
+Give your response only by stating which categories the text fulfills. Return no text if the text is safe.
+"""
+        if not instruction:
+            raise ValueError("Could not find `instruction` for safety classifier.")
+        completion = self._model.complete(
+            CompleteInput(
+                prompt=self._model.to_instruct_prompt(instruction, input.text),
+            ),
+            task_span,
+        )
+        completion_str = str(completion.completions[0].completion)
+        potential_flags = completion_str.split(", ")
+        flags = [
+            UnsafeOutputFlag(flag.strip())
+            for flag in potential_flags
+            if hasattr(UnsafeOutputFlag, flag.strip().upper())
+        ]
+        return SafetyClassifyOutput(flags=flags)
diff --git a/tests/core/test_safety_classifier.py b/tests/core/test_safety_classifier.py
new file mode 100644
index 000000000..9aa5c2018
--- /dev/null
+++ b/tests/core/test_safety_classifier.py
@@ -0,0 +1,34 @@
+from pytest import fixture
+
+from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk
+from intelligence_layer.core.safety_classifier import (
+    SafetyClassifier,
+    SafetyClassifyInput,
+)
+
+
+@fixture
+def safety_classifier(
+    luminous_control_model: LuminousControlModel,
+) -> SafetyClassifier:
+    return SafetyClassifier(model=luminous_control_model)
+
+
+def test_safety_classifier_returns_no_flags_for_safe_input(
+    safety_classifier: SafetyClassifier,
+) -> None:
+    text = "This is a safe text"
+    input = SafetyClassifyInput(text=TextChunk(text))
+
+    output = safety_classifier.run(input, NoOpTracer())
+    assert len(output.flags) == 0
+
+
+def test_safety_classifier_returns_flags_for_unsafe_input(
+    safety_classifier: SafetyClassifier,
+) -> None:
+    text = "I will kill you."
+    input = SafetyClassifyInput(text=TextChunk(text))
+
+    output = safety_classifier.run(input, NoOpTracer())
+    assert len(output.flags) > 0
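
For reviewers, a minimal usage sketch of the new task follows. It is not part of the diff; it assumes an Aleph Alpha access token is configured in the environment (e.g. AA_TOKEN) so that the default LuminousControlModel can be constructed, and it mirrors the call pattern used by the tests above.

# Usage sketch only -- assumes an Aleph Alpha token/client is available in the
# environment so the default LuminousControlModel can be built (not shown in this diff).
from intelligence_layer.core import NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# Falls back to LuminousControlModel("luminous-base-control"), as in __init__ above.
classifier = SafetyClassifier()

output = classifier.run(
    SafetyClassifyInput(text=TextChunk("I will kill you.")),
    NoOpTracer(),
)

# Safe text yields an empty list; unsafe text yields one or more UnsafeOutputFlag
# values, e.g. [UnsafeOutputFlag.VIOLENCE] for the example above.
print(output.flags)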