diff --git a/lm_eval/tasks/superglue.py b/lm_eval/tasks/superglue.py index 667dc54271..e1ebe5d189 100644 --- a/lm_eval/tasks/superglue.py +++ b/lm_eval/tasks/superglue.py @@ -305,3 +305,39 @@ def higher_is_better(self): def aggregation(self): return {"acc": mean} + + +class WinogenderSchemaDiagnostics(PromptSourceTask): + VERSION = 0 + DATASET_PATH = "super_glue" + DATASET_NAME = "axg" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def test_docs(self): + return self.dataset["test"] + + +class BroadcoverageDiagnostics(PromptSourceTask): + VERSION = 0 + DATASET_PATH = "super_glue" + DATASET_NAME = "axb" + + def has_training_docs(self): + return False + + def has_validation_docs(self): + return False + + def has_test_docs(self): + return True + + def test_docs(self): + return self.dataset["test"]