feat: Add german prompt for safety classifier
MartinAchtnerAA authored and SebastianWolfschmidtAA committed Apr 18, 2024
1 parent b0f88ae commit fd578fb
Showing 3 changed files with 69 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -51,7 +51,7 @@ repos:
args:
[
"-L",
"newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke, vor,fille,ans",
"newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke,vor,fille,ans,deine,Deine",
]
exclude: '^(poetry\.lock|trace-viewer/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/use_cases/qa/multiple_chunk_qa.py|src/intelligence_layer/use_cases/summarize/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/use_cases/classify/keyword_extract.py|tests/use_cases/summarize/test_single_chunk_few_shot_summarize.py|tests/use_cases/summarize/very_long_text.txt)$'
- repo: https://github.com/akaihola/darglint2
61 changes: 44 additions & 17 deletions src/intelligence_layer/core/safety_classifier.py
@@ -1,17 +1,55 @@
from enum import Enum
from typing import Sequence
from typing import Sequence, Mapping

from pydantic import BaseModel

from intelligence_layer.core import (
CompleteInput,
ControlModel,
Language,
LuminousControlModel,
Task,
TaskSpan,
TextChunk,
)

INSTRUCTION_CONFIGS = {
Language(
"en"
): """Your task is to act as a safety classifier for text. There are the following categories:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity
Your response should be a single line containing the labels of which categories the text fulfills.
Return the label 'safe' if the text is safe.
Please make sure your response is a single line.
""",
Language(
"de"
): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity
Deine Antwort sollte eine Zeile mit den Labels der jeweiligen Kategorien umfassen, die der Text erfüllt.
Gib das Label 'safe' zurück, wenn der Text sicher ist.
Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
""",
}


class UnsafeOutputFlag(Enum):
HATESPEECH = "hatespeech"
@@ -27,6 +65,7 @@ class UnsafeOutputFlag(Enum):

class SafetyClassifyInput(BaseModel):
text: TextChunk
language: Language = Language("en")


class SafetyClassifyOutput(BaseModel):
@@ -38,30 +77,18 @@ class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
def __init__(
self,
model: ControlModel | None = None,
instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
) -> None:
self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
self._instruction_configs = instruction_configs

def do_run(
self, input: SafetyClassifyInput, task_span: TaskSpan
) -> SafetyClassifyOutput:
instruction = """Your task is to act as a safety classifier for text. There are the following categories:
-hatespeech,
-sexual,
-religious,
-medical,
-self_harm,
-illegal_activities,
-violence,
-insult,
-profanity
Your response should be a single line containing the labels of which categories the text fulfills.
Return the label 'safe' if the text is safe.
Please make sure your response is a single line.
"""
instruction = self._instruction_configs.get(input.language)

if not instruction:
raise ValueError("Could not find `instruction` for safety classifier.")
raise ValueError(f"Could not find `prompt_config` for {input.language}.")
completion = self._model.complete(
CompleteInput(
prompt=self._model.to_instruct_prompt(instruction, input.text),
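
The new language field on SafetyClassifyInput selects which prompt from INSTRUCTION_CONFIGS is used at completion time. A minimal usage sketch, assuming only the classes and defaults visible in the diff above (the example text is taken from the tests below; model and environment setup are not part of this commit):

from intelligence_layer.core import Language, NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# Falls back to LuminousControlModel("luminous-nextgen-7b-control-384k"), per __init__.
classifier = SafetyClassifier()

# The German instruction from INSTRUCTION_CONFIGS is selected via the language field.
input = SafetyClassifyInput(
    text=TextChunk("Du verdammtes Arschloch."),
    language=Language("de"),
)
output = classifier.run(input, NoOpTracer())
print(output.flags)  # expected to contain UnsafeOutputFlag.INSULT
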
25 changes: 24 additions & 1 deletion tests/core/test_safety_classifier.py
@@ -3,7 +3,7 @@
import pytest
from pytest import fixture

from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk
from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk, Language
from intelligence_layer.core.safety_classifier import (
SafetyClassifier,
SafetyClassifyInput,
@@ -49,3 +49,26 @@ def test_safety_classifier_returns_flags_for_unsafe_input

output = safety_classifier.run(input, NoOpTracer())
assert output.flags == expected_labels


@pytest.mark.parametrize(
"text, expected_labels",
[
("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
(
"Gott ist gütig und wird uns von allem Bösen erlösen.",
[UnsafeOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_german_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))

output = safety_classifier.run(input, NoOpTracer())
assert output.flags == expected_labels
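
Since the constructor now accepts instruction_configs: Mapping[Language, str] with INSTRUCTION_CONFIGS as the default, callers can extend the prompt set without modifying the module. A hypothetical sketch, assuming INSTRUCTION_CONFIGS is importable from intelligence_layer.core.safety_classifier as defined above; the French entry and its placeholder prompt are illustrative and not part of this commit:

from intelligence_layer.core import Language
from intelligence_layer.core.safety_classifier import INSTRUCTION_CONFIGS, SafetyClassifier

# Start from the shipped English and German prompts and add another language.
custom_configs = {
    **INSTRUCTION_CONFIGS,
    Language("fr"): "<French safety-classifier instruction, mirroring the English prompt above>",
}
classifier = SafetyClassifier(instruction_configs=custom_configs)

Inputs with language=Language("fr") would then resolve to the added prompt, while unsupported languages still raise the ValueError in do_run.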
