use PromptBasedClassify instead of custom text prompt
SebastianWolfschmidtAA committed Apr 18, 2024
1 parent fd578fb commit 2a6a6db
Showing 2 changed files with 48 additions and 64 deletions.
79 changes: 29 additions & 50 deletions src/intelligence_layer/core/safety_classifier.py
@@ -1,57 +1,32 @@
from enum import Enum
from typing import Sequence, Mapping
from typing import Mapping, Sequence

from pydantic import BaseModel

from intelligence_layer.core import (
CompleteInput,
ControlModel,
Language,
LuminousControlModel,
Task,
TaskSpan,
TextChunk,
)
from intelligence_layer.use_cases import ClassifyInput, PromptBasedClassify

INSTRUCTION_CONFIGS = {
Language(
"en"
): """Your task is to act as a safety classifier for text. There are the following categories:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity
Your response should be a single line containing the labels of which categories the text fulfills.
Return the label 'safe' if the text is safe.
Please make sure your response is a single line.
""",
): """Your task is to act as a safety classifier for text. Identify a class that describes the text adequately.
Reply with only the class label.""",
Language(
"de"
): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity
Deine Antwort solle eine Zeile mit Labels der jeweiligen Kategorien umfassen, die der Text erfüllt.
Gib das Label 'safe' zurück, wenn der Text sicher ist.
Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
""",
): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Identifiziere eine Klasse,
die den Text adäquat beschreibt. Antworte nur mit dem Label der Klasse.
""",
}


class UnsafeOutputFlag(Enum):
class SafetyOutputFlag(Enum):
SAFE = "safe"
HATESPEECH = "hatespeech"
SEXUAL = "sexual"
RELIGIOUS = "religious"
@@ -60,7 +35,6 @@ class UnsafeOutputFlag(Enum):
ILLEGAL_ACTIVITIES = "illegal_activities"
VIOLENCE = "violence"
INSULT = "insult"
PROFANITY = "profanity"


class SafetyClassifyInput(BaseModel):
@@ -69,14 +43,14 @@ class SafetyClassifyInput(BaseModel):


class SafetyClassifyOutput(BaseModel):
flags: Sequence[UnsafeOutputFlag] # "safe" is an empty list.
flags: Sequence[SafetyOutputFlag] # "safe" is an empty list.
# Flag is only included if the task determines it should be flagged


class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
def __init__(
self,
model: ControlModel | None = None,
model: LuminousControlModel | None = None,
instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
) -> None:
self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
@@ -89,17 +63,22 @@ def do_run(

if not instruction:
raise ValueError(f"Could not find `prompt_config` for {input.language}.")
completion = self._model.complete(
CompleteInput(
prompt=self._model.to_instruct_prompt(instruction, input.text),
),
task_span,
classify_inputs = ClassifyInput(
chunk=input.text,
labels=frozenset({flag.value for flag in SafetyOutputFlag}),
)
completion_str = str(completion.completions[0].completion)
potential_flags = completion_str.split(", ")
flags = [
UnsafeOutputFlag(flag.strip())
for flag in potential_flags
if hasattr(UnsafeOutputFlag, flag.strip().upper())
]
return SafetyClassifyOutput(flags=flags)
prompt_based_classify = PromptBasedClassify(
model=self._model, instruction=instruction
)
output_probabilities_per_flag = prompt_based_classify.run(
classify_inputs, task_span
)

most_probable_flag = SafetyOutputFlag(
output_probabilities_per_flag.sorted_scores[0][0]
)

if most_probable_flag == SafetyOutputFlag.SAFE:
return SafetyClassifyOutput(flags=[])

return SafetyClassifyOutput(flags=[most_probable_flag])
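
For orientation, here is a minimal usage sketch of the reworked SafetyClassifier. It is not part of the commit; the imports and the run(input, tracer) call mirror the tests in the second changed file below, and the example text and expected flag are taken from those tests.

# Minimal usage sketch (not part of this commit); mirrors the tests below.
from intelligence_layer.core import NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# With no arguments, the task falls back to
# LuminousControlModel("luminous-nextgen-7b-control-384k"), as in __init__ above.
classifier = SafetyClassifier()
safety_input = SafetyClassifyInput(text=TextChunk("I will kill you."))
output = classifier.run(safety_input, NoOpTracer())
# Per the tests, output.flags should be [SafetyOutputFlag.VIOLENCE];
# an empty list means the text was classified as safe.
print(output.flags)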
33 changes: 19 additions & 14 deletions tests/core/test_safety_classifier.py
@@ -3,11 +3,16 @@
import pytest
from pytest import fixture

from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk, Language
from intelligence_layer.core import (
Language,
LuminousControlModel,
NoOpTracer,
TextChunk,
)
from intelligence_layer.core.safety_classifier import (
SafetyClassifier,
SafetyClassifyInput,
UnsafeOutputFlag,
SafetyOutputFlag,
)


@@ -31,19 +36,19 @@ def test_safety_classifier_returns_no_flags_for_safe_input(
@pytest.mark.parametrize(
"text, expected_labels",
[
("I will kill you.", [UnsafeOutputFlag.VIOLENCE]),
("I want to kill myself.", [UnsafeOutputFlag.SELF_HARM]),
("You fucking asshole.", [UnsafeOutputFlag.INSULT]),
("I will kill you.", [SafetyOutputFlag.VIOLENCE]),
("I want to kill myself.", [SafetyOutputFlag.SELF_HARM]),
("You fucking asshole.", [SafetyOutputFlag.INSULT]),
(
"The holy lord is going to deliver us from evil.",
[UnsafeOutputFlag.RELIGIOUS],
[SafetyOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_input(
def test_safety_classifier_returns_flags_for_unsafe_english_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
expected_labels: List[SafetyOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text))

@@ -54,19 +59,19 @@ def test_safety_classifier_returns_flags_for_unsafe_input(
@pytest.mark.parametrize(
"text, expected_labels",
[
("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
("Ich töte dich.", [SafetyOutputFlag.VIOLENCE]),
("Ich möchte mich umbringen.", [SafetyOutputFlag.SELF_HARM]),
("Du verdammtes Arschloch.", [SafetyOutputFlag.INSULT]),
(
"Gott ist gütig und wird uns von allem Bösen erlösen.",
[UnsafeOutputFlag.RELIGIOUS],
[SafetyOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_input(
def test_safety_classifier_returns_flags_for_unsafe_german_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
expected_labels: List[SafetyOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))

