[WIP] Use prompt based classifier for safety classifier #764

Closed
77 changes: 28 additions & 49 deletions src/intelligence_layer/core/safety_classifier.py
@@ -4,54 +4,29 @@
from pydantic import BaseModel

from intelligence_layer.core import (
CompleteInput,
ControlModel,
Language,
LuminousControlModel,
Task,
TaskSpan,
TextChunk,
)
from intelligence_layer.use_cases import ClassifyInput, PromptBasedClassify

INSTRUCTION_CONFIGS = {
Language(
"en"
): """Your task is to act as a safety classifier for text. There are the following categories:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity

Your response should be a single line containing the labels of which categories the text fulfills.
Return the label 'safe' if the text is safe.
Please make sure your response is a single line.
""",
): """Your task is to act as a safety classifier for text. Identify a class that describes the text adequately.
Reply with only the class label.""",
Language(
"de"
): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity

Deine Antwort solle eine Zeile mit Labels der jeweiligen Kategorien umfassen, die der Text erfüllt.
Gib das Label 'safe' zurück, wenn der Text sicher ist.
Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
""",
): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Identifiziere eine Klasse,
die den Text adäquat beschreibt. Antworte nur mit dem Label der Klasse.
""",
}


class UnsafeOutputFlag(Enum):
class SafetyOutputFlag(Enum):
SAFE = "safe"
HATESPEECH = "hatespeech"
SEXUAL = "sexual"
RELIGIOUS = "religious"
@@ -60,7 +35,6 @@ class UnsafeOutputFlag(Enum):
ILLEGAL_ACTIVITIES = "illegal_activities"
VIOLENCE = "violence"
INSULT = "insult"
PROFANITY = "profanity"


class SafetyClassifyInput(BaseModel):
@@ -69,14 +43,14 @@ class SafetyClassifyInput(BaseModel):


class SafetyClassifyOutput(BaseModel):
flags: Sequence[UnsafeOutputFlag] # "safe" is an empty list.
flags: Sequence[SafetyOutputFlag] # "safe" is an empty list.
# Flag is only included if the task determines it should be flagged


class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
def __init__(
self,
model: ControlModel | None = None,
model: LuminousControlModel | None = None,
instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
) -> None:
self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
@@ -89,17 +63,22 @@ def do_run(

if not instruction:
raise ValueError(f"Could not find `prompt_config` for {input.language}.")
completion = self._model.complete(
CompleteInput(
prompt=self._model.to_instruct_prompt(instruction, input.text),
),
task_span,
classify_inputs = ClassifyInput(
chunk=input.text,
labels=frozenset({flag.value for flag in SafetyOutputFlag}),
)
completion_str = str(completion.completions[0].completion)
potential_flags = completion_str.split(", ")
flags = [
UnsafeOutputFlag(flag.strip())
for flag in potential_flags
if hasattr(UnsafeOutputFlag, flag.strip().upper())
]
return SafetyClassifyOutput(flags=flags)
prompt_based_classify = PromptBasedClassify(
model=self._model, instruction=instruction
)
output_probabilities_per_flag = prompt_based_classify.run(
classify_inputs, task_span
)

most_probable_flag = SafetyOutputFlag(
output_probabilities_per_flag.sorted_scores[0][0]
)

if most_probable_flag == SafetyOutputFlag.SAFE:
return SafetyClassifyOutput(flags=[])

return SafetyClassifyOutput(flags=[most_probable_flag])
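
For context (not part of this PR's diff): a minimal usage sketch of the reworked classifier, following the types introduced above. The InMemoryTracer and the classifier.run(input, tracer) call are assumptions about the surrounding intelligence_layer API, not something this PR changes.

from intelligence_layer.core import InMemoryTracer, Language, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# Defaults to LuminousControlModel("luminous-nextgen-7b-control-384k"), as in __init__ above.
classifier = SafetyClassifier()
tracer = InMemoryTracer()  # assumed tracer; any Tracer implementation should work

# The PromptBasedClassify path picks the single most probable label and returns
# an empty flag list when that label is "safe".
unsafe_input = SafetyClassifyInput(
    text=TextChunk("I will kill you."), language=Language("en")
)
output = classifier.run(unsafe_input, tracer)
print(output.flags)  # expected to be [SafetyOutputFlag.VIOLENCE] for this input

Note the behavioral change: the previous completion-parsing approach could return several flags at once, whereas after this change the output carries at most one flag per run.
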
22 changes: 11 additions & 11 deletions tests/core/test_safety_classifier.py
@@ -12,7 +12,7 @@
from intelligence_layer.core.safety_classifier import (
SafetyClassifier,
SafetyClassifyInput,
UnsafeOutputFlag,
SafetyOutputFlag,
)


@@ -36,19 +36,19 @@ def test_safety_classifier_returns_no_flags_for_safe_input(
@pytest.mark.parametrize(
"text, expected_labels",
[
("I will kill you.", [UnsafeOutputFlag.VIOLENCE]),
("I want to kill myself.", [UnsafeOutputFlag.SELF_HARM]),
("You fucking asshole.", [UnsafeOutputFlag.INSULT]),
("I will kill you.", [SafetyOutputFlag.VIOLENCE]),
("I want to kill myself.", [SafetyOutputFlag.SELF_HARM]),
("You fucking asshole.", [SafetyOutputFlag.INSULT]),
(
"The holy lord is going to deliver us from evil.",
[UnsafeOutputFlag.RELIGIOUS],
[SafetyOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_english_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
expected_labels: List[SafetyOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text))

@@ -59,19 +59,19 @@ def test_safety_classifier_returns_flags_for_unsafe_english_input(
@pytest.mark.parametrize(
"text, expected_labels",
[
("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
("Ich töte dich.", [SafetyOutputFlag.VIOLENCE]),
("Ich möchte mich umbringen.", [SafetyOutputFlag.SELF_HARM]),
("Du verdammtes Arschloch.", [SafetyOutputFlag.INSULT]),
(
"Gott ist gütig und wird uns von allem Bösen erlösen.",
[UnsafeOutputFlag.RELIGIOUS],
[SafetyOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_german_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
expected_labels: List[SafetyOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))
