From 8d8e8eeaaa3d7b4aa7fa554f956de7d27dafd99a Mon Sep 17 00:00:00 2001
From: Martin Achtner
Date: Thu, 18 Apr 2024 14:25:05 +0200
Subject: [PATCH] feat: Add German prompt for safety classifier

---
 .pre-commit-config.yaml              |  2 +-
 .../core/safety_classifier.py        | 61 +++++++++++++------
 tests/core/test_safety_classifier.py | 25 +++++++-
 3 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fe59b8084..a3e22d40b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -51,7 +51,7 @@ repos:
         args:
           [
             "-L",
-            "newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke, vor,fille,ans",
+            "newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke,vor,fille,ans,deine,Deine",
           ]
         exclude: '^(poetry\.lock|trace-viewer/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/use_cases/qa/multiple_chunk_qa.py|src/intelligence_layer/use_cases/summarize/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/use_cases/classify/keyword_extract.py|tests/use_cases/summarize/test_single_chunk_few_shot_summarize.py|tests/use_cases/summarize/very_long_text.txt)$'
   - repo: https://github.com/akaihola/darglint2
diff --git a/src/intelligence_layer/core/safety_classifier.py b/src/intelligence_layer/core/safety_classifier.py
index 57736d426..804ef8627 100644
--- a/src/intelligence_layer/core/safety_classifier.py
+++ b/src/intelligence_layer/core/safety_classifier.py
@@ -1,17 +1,55 @@
 from enum import Enum
-from typing import Sequence
+from typing import Mapping, Sequence
 
 from pydantic import BaseModel
 
 from intelligence_layer.core import (
     CompleteInput,
     ControlModel,
+    Language,
     LuminousControlModel,
     Task,
     TaskSpan,
     TextChunk,
 )
 
+INSTRUCTION_CONFIGS = {
+    Language(
+        "en"
+    ): """Your task is to act as a safety classifier for text. There are the following categories:
+-hatespeech,
+-sexual,
+-religious,
+-medical,
+-self_harm,
+-illegal_activities,
+-violence,
+-insult,
+-profanity
+
+Your response should be a single line containing the labels of which categories the text fulfills.
+Return the label 'safe' if the text is safe.
+Please make sure your response is a single line.
+""",
+    Language(
+        "de"
+    ): """Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
+-hatespeech,
+-sexual,
+-religious,
+-medical,
+-self_harm,
+-illegal_activities,
+-violence,
+-insult,
+-profanity
+
+Deine Antwort sollte eine Zeile mit den Labels der Kategorien umfassen, die der Text erfüllt.
+Gib das Label 'safe' zurück, wenn der Text sicher ist.
+Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
+""",
+}
+
 
 class UnsafeOutputFlag(Enum):
     HATESPEECH = "hatespeech"
@@ -27,6 +65,7 @@ class UnsafeOutputFlag(Enum):
 
 class SafetyClassifyInput(BaseModel):
     text: TextChunk
+    language: Language = Language("en")
 
 
 class SafetyClassifyOutput(BaseModel):
@@ -38,30 +77,18 @@ class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
     def __init__(
         self,
         model: ControlModel | None = None,
+        instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
     ) -> None:
         self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
+        self._instruction_configs = instruction_configs
 
     def do_run(
         self, input: SafetyClassifyInput, task_span: TaskSpan
     ) -> SafetyClassifyOutput:
-        instruction = """Your task is to act as a safety classifier for text. There are the following categories:
--hatespeech,
--sexual,
--religious,
--medical,
--self_harm,
--illegal_activities,
--violence,
--insult,
--profanity
-
-Your response should be a single line containing the labels of which categories the text fulfills.
-Return the label 'safe' if the text is safe.
-Please make sure your response is a single line.
-"""
+        instruction = self._instruction_configs.get(input.language)
         if not instruction:
-            raise ValueError("Could not find `instruction` for safety classifier.")
+            raise ValueError(f"Could not find an instruction for language {input.language}.")
         completion = self._model.complete(
             CompleteInput(
                 prompt=self._model.to_instruct_prompt(instruction, input.text),
diff --git a/tests/core/test_safety_classifier.py b/tests/core/test_safety_classifier.py
index 55a281ade..3aa1a06cb 100644
--- a/tests/core/test_safety_classifier.py
+++ b/tests/core/test_safety_classifier.py
@@ -3,7 +3,7 @@
 import pytest
 from pytest import fixture
 
-from intelligence_layer.core import LuminousControlModel, NoOpTracer, TextChunk
+from intelligence_layer.core import Language, LuminousControlModel, NoOpTracer, TextChunk
 from intelligence_layer.core.safety_classifier import (
     SafetyClassifier,
     SafetyClassifyInput,
@@ -49,3 +49,26 @@ def test_safety_classifier_returns_flags_for_unsafe_input(
 
     output = safety_classifier.run(input, NoOpTracer())
     assert output.flags == expected_labels
+
+
+@pytest.mark.parametrize(
+    "text, expected_labels",
+    [
+        ("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
+        ("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
+        ("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
+        (
+            "Gott ist gütig und wird uns von allem Bösen erlösen.",
+            [UnsafeOutputFlag.RELIGIOUS],
+        ),
+    ],
+)
+def test_safety_classifier_returns_flags_for_unsafe_german_input(
+    safety_classifier: SafetyClassifier,
+    text: str,
+    expected_labels: List[UnsafeOutputFlag],
+) -> None:
+    input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))
+
+    output = safety_classifier.run(input, NoOpTracer())
+    assert output.flags == expected_labels
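
Usage sketch (not part of the patch): a minimal example of how the new `language` field selects the German instruction after this change. It only uses names that appear in the diff (SafetyClassifier, SafetyClassifyInput, Language, TextChunk, NoOpTracer) and assumes a configured client for the default LuminousControlModel; the example text and the expected flag in the comment are illustrative, not a guaranteed model output.

from intelligence_layer.core import Language, NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import SafetyClassifier, SafetyClassifyInput

# Uses the default model and the default INSTRUCTION_CONFIGS mapping added by this patch.
classifier = SafetyClassifier()

# language=Language("de") selects the German instruction; omitting it falls back
# to the English default declared on SafetyClassifyInput.
german_input = SafetyClassifyInput(
    text=TextChunk("Du verdammtes Arschloch."),
    language=Language("de"),
)

output = classifier.run(german_input, NoOpTracer())
print(output.flags)  # expected to include UnsafeOutputFlag.INSULT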