From 0a36f6ccd4bf17d119a0b884c812fdec454b63db Mon Sep 17 00:00:00 2001
From: stolzenp
Date: Thu, 8 Feb 2024 22:15:59 +0100
Subject: [PATCH] use HfArgumentParser()

---
 src/small_model_training/config.json          | 14 +++++
 .../text_classification.py                    | 54 +++++++------------
 .../training_parameters.json                  |  4 --
 3 files changed, 33 insertions(+), 39 deletions(-)
 create mode 100644 src/small_model_training/config.json
 delete mode 100644 src/small_model_training/training_parameters.json

diff --git a/src/small_model_training/config.json b/src/small_model_training/config.json
new file mode 100644
index 0000000..a807b77
--- /dev/null
+++ b/src/small_model_training/config.json
@@ -0,0 +1,14 @@
+{
+    "model_name_or_path": "bert-base-uncased",
+    "tokenizer_name": "distilbert-base-uncased",
+    "output_dir": "my_awesome_model",
+    "learning_rate": 2e-5,
+    "per_device_train_batch_size": 16,
+    "per_device_eval_batch_size": 16,
+    "num_train_epochs": 2,
+    "weight_decay": 0.01,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "load_best_model_at_end": true,
+    "push_to_hub": false
+}
\ No newline at end of file
diff --git a/src/small_model_training/text_classification.py b/src/small_model_training/text_classification.py
index 1bda908..9c17e3c 100644
--- a/src/small_model_training/text_classification.py
+++ b/src/small_model_training/text_classification.py
@@ -1,16 +1,28 @@
 import numpy as np
+from dataclasses import dataclass, field
 from transformers import AutoTokenizer
 from transformers import DataCollatorWithPadding
-from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
+from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, HfArgumentParser
 from datasets import load_dataset
-import json
 import evaluate
 
+@dataclass
+class ModelArguments:
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model from huggingface.co/models"}
+    )
+    tokenizer_name: str = field(
+        metadata={"help": "Path to pretrained tokenizer or model from huggingface.co/models"}
+    )
 def get_influential_subset(dataset):
-    # get parameters from dict
-    data = get_training_parameters()
-    small_model = data['small_model']
-    batch_size = data['batch_size']
+    # get parameters from config
+    parser = HfArgumentParser((ModelArguments, TrainingArguments))
+    model_args, training_args = parser.parse_json_file('config.json')
+
+    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name)
+
+    def preprocess_function(examples):
+        return tokenizer(examples["text"], truncation=True)
 
     tokenized_imdb = dataset.map(preprocess_function, batched=True)
 
@@ -20,20 +32,7 @@ def get_influential_subset(dataset):
     label2id = {"NEGATIVE": 0, "POSITIVE": 1}
 
     model = AutoModelForSequenceClassification.from_pretrained(
-        "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
-    )
-
-    training_args = TrainingArguments(
-        output_dir="my_awesome_model",
-        learning_rate=2e-5,
-        per_device_train_batch_size=16,
-        per_device_eval_batch_size=16,
-        num_train_epochs=2,
-        weight_decay=0.01,
-        evaluation_strategy="epoch",
-        save_strategy="epoch",
-        load_best_model_at_end=True,
-        push_to_hub=False,
+        model_args.model_name_or_path, num_labels=2, id2label=id2label, label2id=label2id
     )
 
     trainer = Trainer(
@@ -55,20 +54,6 @@ def get_influential_subset(dataset):
     # TO-DO: check for pre-processing
     return inf_subset
 
-def get_training_parameters():
-
-    # open config file
-    f = open('training_parameters.json')
-
-    # return json object as dict
-    data = json.load(f)
-
-    # close file
-    f.close()
-    return data
-
-def preprocess_function(examples):
-    return tokenizer(examples["text"], truncation=True)
 
 def compute_metrics(eval_pred):
     predictions, labels = eval_pred
@@ -76,7 +61,6 @@ def compute_metrics(eval_pred):
     return accuracy.compute(predictions=predictions, references=labels)
 
 accuracy = evaluate.load("accuracy")
-tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
 
 # example dataset for debugging
 imdb = load_dataset("imdb")
diff --git a/src/small_model_training/training_parameters.json b/src/small_model_training/training_parameters.json
deleted file mode 100644
index b715c8e..0000000
--- a/src/small_model_training/training_parameters.json
+++ /dev/null
@@ -1,4 +0,0 @@
-{
-    "small_model": "bert-base-uncased",
-    "batch_size": 128
-}
\ No newline at end of file
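
Note: a minimal sketch of what the new config flow does, not part of the patch itself, assuming the documented behavior of transformers' HfArgumentParser. parse_json_file() splits the keys of a single JSON file across the dataclasses the parser was built with: model_name_or_path and tokenizer_name fill ModelArguments, the remaining keys fill TrainingArguments. The relative path 'config.json' is resolved against the current working directory.

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser, TrainingArguments

    @dataclass
    class ModelArguments:
        model_name_or_path: str = field(metadata={"help": "model name or path"})
        tokenizer_name: str = field(metadata={"help": "tokenizer name or path"})

    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    # expects the config.json added by this patch in the working directory
    model_args, training_args = parser.parse_json_file("config.json")
    assert model_args.model_name_or_path == "bert-base-uncased"
    assert training_args.num_train_epochs == 2  # TrainingArguments stores this as a float

The same parser can also read the command line via parse_args_into_dataclasses(), so the dataclass route keeps JSON and CLI configuration interchangeable, which the removed json.load helper did not.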