Skip to content

Commit

Permalink
Update evaluation script with newest version of fabricator
Browse files Browse the repository at this point in the history
  • Loading branch information
HallerPatrick committed Aug 6, 2023
1 parent ce5fa46 commit 4b2e09c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
4 changes: 3 additions & 1 deletion paper_experiments/fine_tune_ner/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ python -m pip install -e .
Install relevant requirements for experiment

```bash
python -m pip instal -r requirements.txt
python -m pip install -r requirements.txt
```

## Fine-tune Model
Expand Down Expand Up @@ -41,6 +41,8 @@ torchrun --nproc_per_node=2 train.py \

## Evaluate LLM with library

This will generate NER tokens based on the CONLL03 evaluation split and evaluate at it against the gold labels.

```bash
python evaluate.py --model_name_or_path "<HF_MODEL>"
```
19 changes: 6 additions & 13 deletions paper_experiments/fine_tune_ner/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import argparse
import os
import random

from datasets import load_dataset, load_from_disk
from datasets import load_dataset, load_from_disk, Dataset
from haystack.nodes import PromptNode

from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer

from ai_dataset_generator import DatasetGenerator
from ai_dataset_generator.prompts import BasePrompt
from ai_dataset_generator.samplers import random_sampler
from ai_dataset_generator.dataset_transformations.text_classification import convert_label_ids_to_texts
from fabricator import DatasetGenerator
from fabricator.prompts import BasePrompt
from fabricator.samplers import random_sampler


ner_prompt = (
Expand All @@ -21,16 +18,14 @@

def main(args):

# model_name_or_path = "EleutherAI/pythia-70M-deduped"
dataset = load_dataset(args.dataset_name, split=args.split)
fewshot_examples = dataset.select(random.sample(range(len(dataset)), 3))


prompt = BasePrompt(
task_description=ner_prompt,
generate_data_for_column="ner_tags",
fewshot_example_columns="tokens",
label_options=None #{"O", 1: "B-PER", 2: "I-PER", 3: "B-ORG", 4: "I-ORG", 5: "B-LOC", 6: "I-LOC"},
label_options={0: "O", 1: "B-PER", 2: "I-PER", 3: "B-ORG", 4: "I-ORG", 5: "B-LOC", 6: "I-LOC"},
)

unlabeled_data = random_sampler(dataset, 30)
Expand All @@ -46,12 +41,10 @@ def main(args):


generator = DatasetGenerator(prompt_node)
generated_dataset = generator.generate(
generated_dataset: Dataset = generator.generate(
prompt_template=prompt,
fewshot_dataset=None,
unlabeled_dataset=unlabeled_data,
max_prompt_calls=30,
fewshot_examples_per_class=0,
timeout_per_prompt=2,
)

Expand Down

0 comments on commit 4b2e09c

Please sign in to comment.