F13-114 simple safety classification task #758
Is 'core' the correct module for this task?

I would probably put it into use-case, but I am not 100% sure and it is up to discussion, since just having this safety classification is not a use case in itself. However, since this is very close to the prompt-based classify, it might be better there.
@@ -0,0 +1,92 @@
from enum import Enum
from typing import Mapping, Sequence

from pydantic import BaseModel

from intelligence_layer.core import (
    CompleteInput,
    ControlModel,
    Language,
    LuminousControlModel,
    Task,
    TaskSpan,
    TextChunk,
)


class UnsafeOutputFlag(Enum):
    HATESPEECH = "hatespeech"
    SEXUAL = "sexual"
    RELIGIOUS = "religious"
    MEDICAL = "medical"
    SELF_HARM = "self_harm"
    ILLEGAL_ACTIVITIES = "illegal_activities"
    VIOLENCE = "violence"
    INSULT = "insult"
    PROFANITY = "profanity"
Comment on lines +17 to +26

Where did you get these labels from? Are they from some sort of requirement, or did you just think of them?

ALL_FLAGS = "\n-".join([e.value for e in UnsafeOutputFlag])

INSTRUCTION_CONFIGS = {
The prompts are WIP and subject to findings from https://aleph-alpha.atlassian.net/wiki/spaces/EN/pages/605716507/Eliminate+all+of+our+key+customers+safety-related+blockers+IL+Safety+Module
There exists a PromptBasedClassify task; however, it comes with some caveats, and the tests looked less promising than with this custom prompt logic.

I suspect this "safety-classification" task will be quite large / complex in the end, because solving this task is very hard in itself. Have you looked at alternatives to implement this? Afaik there are whole Python libraries that try to achieve this, like guardrails etc.
    Language(
        "en"
    ): f"""Your task is to act as a safety classifier for text. There are the following categories:
-{ALL_FLAGS}

Your response should be a single line containing the labels of which categories the text fulfills.
Return the label 'safe' if the text is safe.
Please make sure your response is a single line.
""",
    Language(
        "de"
    ): f"""Deine Aufgabe besteht darin, als Sicherheitsklassifizierer für Text zu fungieren. Es gibt folgende Kategorien:
-{ALL_FLAGS}

Deine Antwort sollte eine Zeile mit Labels der jeweiligen Kategorien umfassen, die der Text erfüllt.
Gib das Label 'safe' zurück, wenn der Text sicher ist.
Bitte stelle sicher, dass deine Antwort eine Zeile lang ist.
""",
}


class SafetyClassifyInput(BaseModel):
    text: TextChunk
    language: Language = Language("en")


class SafetyClassifyOutput(BaseModel):
    flags: Sequence[UnsafeOutputFlag]  # "safe" is an empty list.
    # A flag is only included if the task determines the text should be flagged.


class SafetyClassifier(Task[SafetyClassifyInput, SafetyClassifyOutput]):
    def __init__(
        self,
        model: ControlModel | None = None,
        instruction_configs: Mapping[Language, str] = INSTRUCTION_CONFIGS,
    ) -> None:
        self._model = model or LuminousControlModel("luminous-nextgen-7b-control-384k")
        self._instruction_configs = instruction_configs

    def do_run(
        self, input: SafetyClassifyInput, task_span: TaskSpan
    ) -> SafetyClassifyOutput:
        instruction = self._instruction_configs.get(input.language)

        if not instruction:
            raise ValueError(f"Could not find `prompt_config` for {input.language}.")
        completion = self._model.complete(
            CompleteInput(
                prompt=self._model.to_instruct_prompt(instruction, input.text),
            ),
            task_span,
        )
        completion_str = str(completion.completions[0].completion)
        # Parse the comma-separated labels returned by the model and keep only
        # those that correspond to known UnsafeOutputFlag members.
        potential_flags = completion_str.split(", ")
        flags = [
            UnsafeOutputFlag(flag.strip())
            for flag in potential_flags
            if hasattr(UnsafeOutputFlag, flag.strip().upper())
        ]
        return SafetyClassifyOutput(flags=flags)
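For reference, a minimal usage sketch of the new task (not part of the diff): it assumes the Aleph Alpha client credentials are configured the same way as for other Intelligence Layer tasks and uses NoOpTracer purely for brevity.

from intelligence_layer.core import NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
)

# Without an explicit model, the task falls back to "luminous-nextgen-7b-control-384k".
classifier = SafetyClassifier()
output = classifier.run(
    SafetyClassifyInput(text=TextChunk("I will kill you.")),
    NoOpTracer(),
)
# output.flags is a list of UnsafeOutputFlag values; an empty list means "safe".
print(output.flags)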
@@ -0,0 +1,108 @@
from typing import List

import pytest
from pytest import fixture

from intelligence_layer.core import (
    Language,
    LuminousControlModel,
    NoOpTracer,
    TextChunk,
)
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
    UnsafeOutputFlag,
)


@fixture
def safety_classifier(
    luminous_control_model: LuminousControlModel,
) -> SafetyClassifier:
    return SafetyClassifier(model=None)


@fixture
def long_text() -> str:
    return """Green Day is an American rock band formed in the East Bay of California in 1987 by lead vocalist and guitarist Billie Joe Armstrong, together with bassist and backing vocalist Mike Dirnt. For most of the band's career they have been a power trio[4] with drummer Tré Cool, who replaced John Kiffmeyer in 1990 before the recording of the band's second studio album, Kerplunk (1991). Before taking its current name in 1989, Green Day was called Blood Rage, then Sweet Children and they were part of the late 1980s/early 1990s Bay Area punk scene that emerged from the 924 Gilman Street club in Berkeley, California. The band's early releases were with the independent record label Lookout! Records. In 1994, their major-label debut Dookie, released through Reprise Records, became a breakout success and eventually shipped over 10 million copies in the U.S. Alongside fellow California punk bands Bad Religion, the Offspring, Rancid, NOFX, Pennywise and Social Distortion, Green Day is credited with popularizing mainstream interest in punk rock in the U.S.
Though the albums Insomniac (1995), Nimrod (1997) and Warning (2000) did not match the success of Dookie, they were still successful, with the first two reaching double platinum status, while the last achieved gold. Green Day's seventh album, a rock opera called American Idiot (2004), found popularity with a younger generation, selling six million copies in the U.S. Their next album, 21st Century Breakdown, was released in 2009 and achieved the band's best chart performance. It was followed by a trilogy of albums, ¡Uno!, ¡Dos!, and ¡Tré!, released in September, November, and December 2012, respectively. The trilogy did not perform as well as expected commercially, in comparison to their previous albums, largely due to a lack of promotion and Armstrong entering rehab. Their twelfth studio album, Revolution Radio, was released in October 2016 and became their third to debut at No. 1 on the Billboard 200. The band's thirteenth studio album, Father of All Motherfuckers, was released on February 7, 2020. The band's fourteenth studio album, titled Saviors, is planned to come out on January 19, 2024.[5][6]
In 2010, a stage adaptation of American Idiot debuted on Broadway. The musical was nominated for three Tony Awards: Best Musical, Best Scenic Design and Best Lighting Design, winning the latter two. The band was inducted into the Rock and Roll Hall of Fame in 2015, their first year of eligibility. Members of the band have collaborated on the side projects Pinhead Gunpowder, The Network, Foxboro Hot Tubs, The Longshot and The Coverups. They have also worked on solo careers.
Green Day has sold roughly 75 million records worldwide,[7][unreliable source?] making them one of the world's best-selling artists. The group has been nominated for 20 Grammy awards and has won five of them with Best Alternative Album for Dookie, Best Rock Album for American Idiot and 21st Century Breakdown, Record of the Year for "Boulevard of Broken Dreams" and Best Musical Show Album for American Idiot: The Original Broadway Cast Recording."""


def test_safety_classifier_returns_no_flags_for_safe_input(
    safety_classifier: SafetyClassifier,
) -> None:
    text = "This is a safe text"
    input = SafetyClassifyInput(text=TextChunk(text))

    output = safety_classifier.run(input, NoOpTracer())
    assert len(output.flags) == 0


@pytest.mark.parametrize(
    "text, expected_labels",
    [
        ("I will kill you.", [UnsafeOutputFlag.VIOLENCE]),
        ("I want to kill myself.", [UnsafeOutputFlag.SELF_HARM]),
        ("You fucking asshole.", [UnsafeOutputFlag.INSULT]),
        (
            "The holy lord is going to deliver us from evil.",
            [UnsafeOutputFlag.RELIGIOUS],
        ),
    ],
)
def test_safety_classifier_returns_flags_for_unsafe_english_input(
    safety_classifier: SafetyClassifier,
    text: str,
    expected_labels: List[UnsafeOutputFlag],
) -> None:
    input = SafetyClassifyInput(text=TextChunk(text))

    output = safety_classifier.run(input, NoOpTracer())
    assert output.flags == expected_labels


@pytest.mark.parametrize(
    "text, expected_labels",
    [
        ("Ich töte dich.", [UnsafeOutputFlag.VIOLENCE]),
        ("Ich möchte mich umbringen.", [UnsafeOutputFlag.SELF_HARM]),
        ("Du verdammtes Arschloch.", [UnsafeOutputFlag.INSULT]),
        (
            "Gott ist gütig und wird uns von allem Bösen erlösen.",
            [UnsafeOutputFlag.RELIGIOUS],
        ),
    ],
)
def test_safety_classifier_returns_flags_for_unsafe_german_input(
    safety_classifier: SafetyClassifier,
    text: str,
    expected_labels: List[UnsafeOutputFlag],
) -> None:
    input = SafetyClassifyInput(text=TextChunk(text), language=Language("de"))

    output = safety_classifier.run(input, NoOpTracer())
    assert output.flags == expected_labels


def test_safety_classifier_returns_flags_for_safe_long_input(
    safety_classifier: SafetyClassifier, long_text: str
) -> None:
    input = SafetyClassifyInput(text=TextChunk(long_text), language=Language("en"))

    output = safety_classifier.run(input, NoOpTracer())
    assert len(output.flags) == 0


def test_safety_classifier_returns_flags_for_unsafe_long_input(
    safety_classifier: SafetyClassifier, long_text: str
) -> None:
    unsafe_text = "\nThe band is known for their history of open drug abuse and has made some headlines about violence towards minors."
    input = SafetyClassifyInput(
        text=TextChunk(long_text + unsafe_text), language=Language("en")
    )

    output = safety_classifier.run(input, NoOpTracer())
    assert len(output.flags) == 1
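Beyond these unit tests, a quick accuracy sanity check could look like the sketch below. This is only an illustration (not part of the diff) using the hand-labelled examples from the tests above; a proper evaluation, as raised in the review comments below, would go through the Intelligence Layer evaluation tooling instead.

from intelligence_layer.core import NoOpTracer, TextChunk
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
    UnsafeOutputFlag,
)

# Hand-labelled examples taken from the tests above.
LABELLED_EXAMPLES = [
    ("This is a safe text", set()),
    ("I will kill you.", {UnsafeOutputFlag.VIOLENCE}),
    ("I want to kill myself.", {UnsafeOutputFlag.SELF_HARM}),
    ("You fucking asshole.", {UnsafeOutputFlag.INSULT}),
]


def exact_match_accuracy(classifier: SafetyClassifier) -> float:
    tracer = NoOpTracer()
    hits = 0
    for text, expected_flags in LABELLED_EXAMPLES:
        output = classifier.run(SafetyClassifyInput(text=TextChunk(text)), tracer)
        # Count an example as correct only if the predicted flag set matches exactly.
        hits += int(set(output.flags) == expected_flags)
    return hits / len(LABELLED_EXAMPLES)


print(f"Exact-match accuracy: {exact_match_accuracy(SafetyClassifier()):.2f}")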
Maybe we should mark this as beta or something? Have you evaluated the performance yet?
I am happy to mark it as beta here, or maybe not even 'document' it for now.
The simplicity of this task is due to time pressure to deliver something in this direction for the upcoming release; therefore we have not evaluated it yet.
There also exists a prompt-based classifier which could maybe be adapted to also return MultiLabelClassifyOutput.
These are all topics for future improvement.
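As a rough illustration of that last point, one possible shape for such an adaptation is sketched below. Rather than touching PromptBasedClassify itself, this sketch simply wraps the SafetyClassifier from this PR; MultiLabelScores is a local stand-in for the real MultiLabelClassifyOutput, whose exact fields would need to be checked against the classify module, and the 1.0/0.0 scores are placeholders because the current prompt yields binary flags rather than calibrated probabilities.

from typing import Mapping

from pydantic import BaseModel

from intelligence_layer.core import Task, TaskSpan
from intelligence_layer.core.safety_classifier import (
    SafetyClassifier,
    SafetyClassifyInput,
    UnsafeOutputFlag,
)


class MultiLabelScores(BaseModel):
    # Local stand-in for MultiLabelClassifyOutput: label -> score in [0, 1].
    scores: Mapping[str, float]


class SafetyAsMultiLabelClassify(Task[SafetyClassifyInput, MultiLabelScores]):
    """Exposes the safety flags in a multi-label-classify-like shape."""

    def __init__(self, classifier: SafetyClassifier | None = None) -> None:
        self._classifier = classifier or SafetyClassifier()

    def do_run(
        self, input: SafetyClassifyInput, task_span: TaskSpan
    ) -> MultiLabelScores:
        flags = self._classifier.run(input, task_span).flags
        # Binary flags are mapped to placeholder scores of 1.0 (flagged) or 0.0.
        return MultiLabelScores(
            scores={
                flag.value: 1.0 if flag in flags else 0.0
                for flag in UnsafeOutputFlag
            }
        )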