F13-114 simple safety classification task #758

Closed
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -51,7 +51,7 @@ repos:
args:
[
"-L",
"newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke, vor,fille,ans",
"newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke,vor,fille,ans,deine,Deine",
]
exclude: '^(poetry\.lock|trace-viewer/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/use_cases/qa/multiple_chunk_qa.py|src/intelligence_layer/use_cases/summarize/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/use_cases/classify/keyword_extract.py|tests/use_cases/summarize/test_single_chunk_few_shot_summarize.py|tests/use_cases/summarize/very_long_text.txt)$'
- repo: https://github.com/akaihola/darglint2
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@


### New Features
- feature: New `SafetyClassifier` allows flagging safe/unsafe text
Maybe we should mark this as beta or something? Have you evaluated the performance yet?

I am happy to mark it as beta here, or maybe not even 'document' it for now.

The simplicity of this task is due to time pressure to deliver something in this direction for the upcoming release; therefore, we have not evaluated it yet.

There is also a prompt-based classifier that could perhaps be adapted to also return `MultiLabelClassifyOutput` (see the sketch after this comment).

These are all topics for future improvement.
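
For illustration, a rough sketch of one way the safety flags could be expressed in that multi-label shape, i.e. as a label-to-score mapping. The helper name `to_label_scores` and the binary 1.0/0.0 scoring are assumptions made for this sketch, not part of this PR:

from typing import Mapping

from intelligence_layer.core.safety_classifier import (
    SafetyClassifyOutput,
    UnsafeOutputFlag,
)


def to_label_scores(output: SafetyClassifyOutput) -> Mapping[str, float]:
    # Score every known flag: 1.0 if the classifier raised it, 0.0 otherwise,
    # giving the same shape as a multi-label classify score mapping.
    raised = set(output.flags)
    return {flag.value: 1.0 if flag in raised else 0.0 for flag in UnsafeOutputFlag}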


### Fixes
- fix: `ChunkWithIndices` now additionally returns `end_index`
1 change: 1 addition & 0 deletions src/intelligence_layer/core/model.py
@@ -214,6 +214,7 @@ def __init__(
"luminous-base-control-20240215",
"luminous-extended-control-20240215",
"luminous-supreme-control-20240215",
"luminous-nextgen-7b-control-384k",
] = "luminous-base-control",
client: Optional[AlephAlphaClientProtocol] = None,
) -> None:
92 changes: 92 additions & 0 deletions src/intelligence_layer/core/safety_classifier.py

Is 'core' the correct module for this task?

I would probably put it into use-cases, but I am not 100% sure and it is up for discussion, since just having this safety classification is not a use case in itself. However, since this is very close to the prompt-based classify, it might fit better there.

@@ -0,0 +1,92 @@
from enum import Enum
from typing import Mapping, Sequence

from pydantic import BaseModel

from intelligence_layer.core import (
CompleteInput,
ControlModel,
Language,
LuminousControlModel,
Task,
TaskSpan,
TextChunk,
)


class UnsafeOutputFlag(Enum):
HATESPEECH = "hatespeech"
SEXUAL = "sexual"
RELIGIOUS = "religious"
MEDICAL = "medical"
SELF_HARM = "self_harm"
ILLEGAL_ACTIVITIES = "illegal_activities"
VIOLENCE = "violence"
INSULT = "insult"
PROFANITY = "profanity"
Comment on lines +17 to +26
Where did you get these labels from? Are they from some sort of requirement, or did you just think of them?


ALL_FLAGS = "\n-".join([e.value for e in UnsafeOutputFlag])

INSTRUCTION_CONFIGS = {

The prompts are WIP and subject to findings from https://aleph-alpha.atlassian.net/wiki/spaces/EN/pages/605716507/Eliminate+all+of+our+key+customers+safety-related+blockers+IL+Safety+Module

There is an existing PromptBasedClassify task; however, it comes with some caveats, and the tests looked less promising than with this custom prompt logic.

I suspect this "safety-classification" task will become quite large and complex in the end, because solving it is very hard in itself. Have you looked at alternatives for implementing this? Afaik there are whole Python libraries that try to achieve this, like guardrails etc.

Language(
"en"
): f"""Your task is to act as a safety classifier for text. There are the following categories:
-{ALL_FLAGS}

Your response should be a single line containing the labels of which categories the text fulfills.
Return the label 'safe' if the text is safe.
Please make sure your response is a single line.
""",
Language(
"de"
): f"""Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
-{ALL_FLAGS}

Deine Antwort sollte eine Zeile mit Labels der jeweiligen Kategorien umfassen, die der Text erfüllt.
Gib das Label 'safe' zurück, wenn der Text sicher ist.
Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
""",
}


class SafetyClassifyInput(BaseModel):
text: TextChunk
language: Language = Language("en")


class SafetyClassifyOutput(BaseModel):
flags: Sequence[UnsafeOutputFlag] # "safe" is an empty list.
# Flag is only included if the task determines it should be flagged


class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
def __init__(
self,
model: ControlModel | None = None,
instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
) -> None:
self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
self._instruction_configs = instruction_configs

def do_run(
self, input: SafetyClassifyInput, task_span: TaskSpan
) -> SafetyClassifyOutput:
instruction = self._instruction_configs.get(input.language)

if not instruction:
raise ValueError(f"Could not find `prompt_config` for {input.language}.")
completion = self._model.complete(
CompleteInput(
prompt=self._model.to_instruct_prompt(instruction, input.text),
),
task_span,
)
completion_str = str(completion.completions[0].completion)
potential_flags = completion_str.split(", ")
flags = [
UnsafeOutputFlag(flag.strip())
for flag in potential_flags
if hasattr(UnsafeOutputFlag, flag.strip().upper())
]
return SafetyClassifyOutput(flags=flags)
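
For context, a minimal usage sketch of the new task, mirroring the tests below; it assumes a configured Aleph Alpha client is available in the environment so the default model can be constructed:

from intelligence_layer.core import NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# model=None falls back to the default "luminous-nextgen-7b-control-384k" control model.
classifier = SafetyClassifier(model=None)
output = classifier.run(
    SafetyClassifyInput(text=TextChunk("I will kill you.")),
    NoOpTracer(),
)
print(output.flags)  # expected: [UnsafeOutputFlag.VIOLENCE]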
108 changes: 108 additions & 0 deletions tests/core/test_safety_classifier.py
@@ -0,0 +1,108 @@
from typing import List

import pytest
from pytest import fixture

from intelligence_layer.core import (
Language,
LuminousControlModel,
NoOpTracer,
TextChunk,
)
from intelligence_layer.core.safety_classifier import (
SafetyClassifier,
SafetyClassifyInput,
UnsafeOutputFlag,
)


@fixture
def safety_classifier(
luminous_control_model: LuminousControlModel,
) -> SafetyClassifier:
return SafetyClassifier(model=None)


@fixture
def long_text() -> str:
return """Green Day is an American rock band formed in the East Bay of California in 1987 by lead vocalist and guitarist Billie Joe Armstrong, together with bassist and backing vocalist Mike Dirnt. For most of the band's career they have been a power trio[4] with drummer Tré Cool, who replaced John Kiffmeyer in 1990 before the recording of the band's second studio album, Kerplunk (1991). Before taking its current name in 1989, Green Day was called Blood Rage, then Sweet Children and they were part of the late 1980s/early 1990s Bay Area punk scene that emerged from the 924 Gilman Street club in Berkeley, California. The band's early releases were with the independent record label Lookout! Records. In 1994, their major-label debut Dookie, released through Reprise Records, became a breakout success and eventually shipped over 10 million copies in the U.S. Alongside fellow California punk bands Bad Religion, the Offspring, Rancid, NOFX, Pennywise and Social Distortion, Green Day is credited with popularizing mainstream interest in punk rock in the U.S.
Though the albums Insomniac (1995), Nimrod (1997) and Warning (2000) did not match the success of Dookie, they were still successful, with the first two reaching double platinum status, while the last achieved gold. Green Day's seventh album, a rock opera called American Idiot (2004), found popularity with a younger generation, selling six million copies in the U.S. Their next album, 21st Century Breakdown, was released in 2009 and achieved the band's best chart performance. It was followed by a trilogy of albums, ¡Uno!, ¡Dos!, and ¡Tré!, released in September, November, and December 2012, respectively. The trilogy did not perform as well as expected commercially, in comparison to their previous albums, largely due to a lack of promotion and Armstrong entering rehab. Their twelfth studio album, Revolution Radio, was released in October 2016 and became their third to debut at No. 1 on the Billboard 200. The band's thirteenth studio album, Father of All Motherfuckers, was released on February 7, 2020. The band's fourteenth studio album, titled Saviors, is planned to come out on January 19, 2024.[5][6]
In 2010, a stage adaptation of American Idiot debuted on Broadway. The musical was nominated for three Tony Awards: Best Musical, Best Scenic Design and Best Lighting Design, winning the latter two. The band was inducted into the Rock and Roll Hall of Fame in 2015, their first year of eligibility. Members of the band have collaborated on the side projects Pinhead Gunpowder, The Network, Foxboro Hot Tubs, The Longshot and The Coverups. They have also worked on solo careers.
Green Day has sold roughly 75 million records worldwide,[7][unreliable source?] making them one of the world's best-selling artists. The group has been nominated for 20 Grammy awards and has won five of them with Best Alternative Album for Dookie, Best Rock Album for American Idiot and 21st Century Breakdown, Record of the Year for "Boulevard of Broken Dreams" and Best Musical Show Album for American Idiot: The Original Broadway Cast Recording."""


def test_safety_classifier_returns_no_flags_for_safe_input(
safety_classifier: SafetyClassifier,
) -> None:
text = "This is a safe text"
input = SafetyClassifyInput(text=TextChunk(text))

output = safety_classifier.run(input, NoOpTracer())
assert len(output.flags) == 0


@pytest.mark.parametrize(
"text, expected_labels",
[
("I will kill you.", [UnsafeOutputFlag.VIOLENCE]),
("I want to kill myself.", [UnsafeOutputFlag.SELF_HARM]),
("You fucking asshole.", [UnsafeOutputFlag.INSULT]),
(
"The holy lord is going to deliver us from evil.",
[UnsafeOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_english_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text))

output = safety_classifier.run(input, NoOpTracer())
assert output.flags == expected_labels


@pytest.mark.parametrize(
"text, expected_labels",
[
("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
(
"Gott ist gütig und wird uns von allem Bösen erlösen.",
[UnsafeOutputFlag.RELIGIOUS],
),
],
)
def test_safety_classifier_returns_flags_for_unsafe_german_input(
safety_classifier: SafetyClassifier,
text: str,
expected_labels: List[UnsafeOutputFlag],
) -> None:
input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))

output = safety_classifier.run(input, NoOpTracer())
assert output.flags == expected_labels


def test_safety_classifier_returns_no_flags_for_safe_long_input(
safety_classifier: SafetyClassifier, long_text: str
) -> None:
input = SafetyClassifyInput(text=TextChunk(long_text), language=Language("en"))

output = safety_classifier.run(input, NoOpTracer())
assert len(output.flags) == 0


def test_safety_classifier_returns_flags_for_unsafe_long_input(
safety_classifier: SafetyClassifier, long_text: str
) -> None:
unsafe_text = "\nThe band is known for their history of open drug abuse and has made some headlines about violence towards minors."
input = SafetyClassifyInput(
text=TextChunk(long_text + unsafe_text), language=Language("en")
)

output = safety_classifier.run(input, NoOpTracer())
assert len(output.flags) == 1