Commit d3a44ca (0 parents): 7 changed files with 416 additions and 0 deletions.
@@ -0,0 +1,4 @@
# file: ~/.gitignore_global
.idea
/shelf/
.xml
@@ -0,0 +1,17 @@
from torch import nn
from transformers import AutoModelForSequenceClassification


class BertModel(nn.Module):
    def __init__(self, checkpoint, unique_labels):
        super(BertModel, self).__init__()
        # Pretrained encoder with a sequence-classification head, one output per label
        self.bert = AutoModelForSequenceClassification.from_pretrained(checkpoint,
                                                                       num_labels=len(unique_labels),
                                                                       output_attentions=False,
                                                                       output_hidden_states=False)

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask, labels=labels, return_dict=False)
        return output

    def save(self, path="model"):
        self.bert.save_pretrained(f"./{path}")
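For reference, a minimal usage sketch of this wrapper (not part of the commit); the checkpoint name, label list, and sample sentence below are placeholders, not values from the repository:

```python
import torch
from transformers import AutoTokenizer
from BERT import BertModel

# Placeholder checkpoint and label set, for illustration only
checkpoint = "distilbert-base-uncased"
unique_labels = ["label_0", "label_1", "label_2", "label_3", "label_4"]

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BertModel(checkpoint, unique_labels)
model.eval()

enc = tokenizer("Airline takes delivery of its first wide-body jet.",
                return_tensors="pt", truncation=True)
with torch.no_grad():
    # With labels=None and return_dict=False, the underlying model returns (logits,)
    logits = model(enc["input_ids"], enc["attention_mask"])[0]
print(logits.shape)  # torch.Size([1, 5]) -- one score per label
```

When labels are supplied, the same call returns (loss, logits), which is what the training loop in main.py relies on.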
@@ -0,0 +1,24 @@
![AviBert-cover](https://www.aeroturk.info/wp-content/uploads/2019/11/Aircraft-Systems.jpg)
<h2 align="center">AviBert: On Classifying the News about Aircraft</h2>
<p align="center">
Developed by <a href="https://github.com/ByUnal"> M.Cihat Unal </a>
</p>

## Overview

This repository focuses on aircraft news: we develop an aircraft-specific classification model on a multi-class dataset using BERT and its lightweight and heavyweight variants. It also introduces a pipeline that comprises data collection, data tagging, and model training.
Since both the data and the targets are unique, the model presented in this study also breaks new ground. Details of the dataset can be investigated further, and results are compared between models using macro-F1 and accuracy scores.

## Setup
Install the requirements. Torch is included in requirements.txt, but you may prefer to install it yourself to match your CUDA version and resources.
```commandline
pip install -r requirements.txt
```
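For example, if you need a build for a different CUDA version, you can install PyTorch from the matching wheel index before the rest of the requirements (the cu118 index below is only an illustration; pick the one that fits your setup):

```commandline
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
```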

## Run the Code
Hyperparameter tuning was carried out with Optuna, so main.py is set up accordingly. You can also train a standalone model by calling *train_loop()* directly; a minimal sketch is given below.
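A rough sketch of such a standalone run, assuming the tokenizer, datasets, `device`, and `MODEL_NAME` globals from main.py are already in place; the checkpoint and hyperparameters below are illustrative placeholders, not the tuned values:

```python
import torch.optim as optim
from BERT import BertModel

# unique_labels, train_dataset, val_dataset, device and MODEL_NAME are expected to
# exist already, exactly as main.py prepares them before launching the Optuna study.
model = BertModel(MODEL_NAME, unique_labels)
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-5)

trained = train_loop(model, optimizer, BATCH_SIZE=32, EPOCHS=5, PATIENCE=2)
trained.save("standalone-model")
```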
## Acknowledgment
A paper on this project, including the data collection steps, has been prepared. However, we are running additional novel experiments on this topic,
so the paper link/details will be shared as soon as it is published.
@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
"""hyperoptbert-newer-training.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1P4RT9W1rjk6qtuwCnV4aH9nWpieyHEeN
"""

# !pip install transformers optuna nltk scikit-learn pandas

import gc
import os
import random
import warnings

import numpy as np
import optuna
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from preprocess import preprocessing
from BERT import BertModel
from utils import createBertDataset, evaluate

warnings.filterwarnings("ignore")

gc.collect()
torch.cuda.empty_cache()

rs = 2
seed_val = rs
random.seed(seed_val)
np.random.seed(seed_val)
os.environ['PYTHONHASHSEED'] = str(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

df = pd.read_csv("cmp711_5class.csv")
df = df.reset_index(drop=True)

# Check how many labels there are in the dataset
unique_labels = df.label.unique().tolist()

# Map each label to its id representation and vice versa
labels_to_ids = {k: v for v, k in enumerate(sorted(unique_labels))}
ids_to_labels = {v: k for v, k in enumerate(sorted(unique_labels))}

df["text"] = df["text"].apply(preprocessing)
df["label"] = df["label"].apply(lambda x: labels_to_ids[x])
df.head()

# Split data into train, validation & test (70% / 10% / 20%)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=rs), [int(.70 * len(df)), int(.80 * len(df))])

# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    # accuracy_score(pred_flat, labels_flat)
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

def train_loop(model, optimizer, BATCH_SIZE, EPOCHS, PATIENCE):
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=BATCH_SIZE)

    # Move the model onto the GPU (or CPU fallback)
    model = model.to(device)

    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=len(train_dataloader) * EPOCHS)

    last_loss = 100
    patience = 0
    for epoch in range(EPOCHS):

        # print("Epoch: %s\n" % (epoch + 1))

        model.train()
        loss_train_total = 0
        acc_train_total = 0

        loop = tqdm(enumerate(train_dataloader), leave=False, total=len(train_dataloader))
        for step, batch in loop:
            # Clear previously calculated gradients to start fresh on this step
            optimizer.zero_grad()

            batch = [r.to(device) for r in batch]
            input_ids, attention_mask, labels = batch
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            # Calculate loss & backpropagate
            loss = outputs[0]
            loss_train_total += loss.item()
            loss.backward()

            # Calculate accuracy
            logits = outputs[1]
            # print(logits)
            classes = torch.argmax(logits, dim=1)
            acc_train_total += torch.mean((classes == labels).float())

            # Clip the norm of the gradients to 1.0.
            # This is to help prevent the "exploding gradients" problem.
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            # Show progress while training
            loop.set_description(f'Epoch = {epoch + 1}/{EPOCHS}, training_loss: {(loss.item() / len(batch)):.3f}')

        train_loss = loss_train_total / len(train_dataloader)

        # Early stopping on the training loss
        if train_loss >= last_loss:
            patience += 1

            if patience == PATIENCE:
                print("Early Stopping!\n")
                with open("es-log.txt", "a+") as es:
                    es.write(f"{MODEL_NAME} - stopped at {epoch} epoch\n")

                return model

        else:
            patience = 0

        last_loss = train_loss

        # Validation
        model.eval()

        val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=32)

        loss_val_total = 0
        total_eval_accuracy = 0

        # print("\nValidating...")
        for batch in tqdm(val_dataloader, total=len(val_dataloader), leave=False):
            batch = tuple(b.to(device) for b in batch)

            input_ids, attention_mask, labels = batch

            with torch.no_grad():
                outputs = model(input_ids, attention_mask, labels)

            loss = outputs[0]
            logits = outputs[1]
            loss_val_total += loss.item()

            logits = logits.detach().cpu().numpy()
            label_ids = labels.cpu().numpy()

            total_eval_accuracy += flat_accuracy(logits, label_ids)

        if epoch == EPOCHS - 1:
            print(f'Training Loss: {train_loss: .3f}')
            print(f'Train Acc.: {acc_train_total / len(train_dataloader)}')

        # Report the average accuracy and loss for this validation run
        val_acc_avg = total_eval_accuracy / len(val_dataloader)
        loss_val_avg = loss_val_total / len(val_dataloader)

        print('Val. Average loss: {:.3f}'.format(loss_val_avg))
        print('Val. Average Acc.: {:.3f}\n'.format(val_acc_avg))

    return model


# Hyperparameter Optimization
def objective(trial):
    gc.collect()
    torch.cuda.empty_cache()

    # MODEL_NAME, unique_labels and the datasets are module-level globals set in the __main__ loop below
    raw_model = BertModel(MODEL_NAME, unique_labels)
    print(f"{MODEL_NAME} - {trial.number} is training...")

    EPOCHS = trial.suggest_int("EPOCHS", low=4, high=7)
    BATCH_SIZE = trial.suggest_categorical('BATCH_SIZE', [16, 32, 48, 64])
    LR = trial.suggest_categorical("LR", [1e-3, 3e-3, 3e-4, 1e-5, 3e-5, 5e-5])
    WD = trial.suggest_categorical("WD", [1e-4, 1e-5, 2e-5])
    OPT = trial.suggest_categorical("OPT", ["AdamW", "SGD", "RMSprop"])

    optimizer = getattr(optim, OPT)(raw_model.parameters(), lr=LR, weight_decay=WD)

    parameters = {
        "model": raw_model,
        "optimizer": optimizer,
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "PATIENCE": 2
    }

    trained_model = train_loop(**parameters)
    loss, f1, acc = evaluate(trained_model, test_dataset)

    # Keep a copy of any trial that clears the F1 threshold
    if f1 > 0.63:
        trained_model.save(f"./{MODEL_NAME}-{trial.number}")

    return f1
if __name__ == "__main__": | ||
model_names = ["distilbert-base-uncased", "albert-base-v2", "huawei-noah/TinyBERT_General_6L_768D", "roberta-base", | ||
"bert-base-uncased", "google/electra-base-discriminator", "YituTech/conv-bert-base"] | ||
|
||
for model in model_names: | ||
MODEL_NAME = model | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | ||
|
||
# Create datasets For BERT | ||
train_dataset = createBertDataset(df_train, tokenizer) | ||
val_dataset = createBertDataset(df_val, tokenizer) | ||
test_dataset = createBertDataset(df_test, tokenizer) | ||
|
||
# We want to maximize the f1 | ||
study = optuna.create_study(study_name='airbert-hyperopt', | ||
direction='maximize', | ||
sampler=optuna.samplers.TPESampler(seed=rs)) | ||
|
||
# Optimize the objective using 50 different trials | ||
study.optimize(objective, n_trials=50) | ||
|
||
with open("param_new.txt", "a+") as file: | ||
trial = study.best_trial | ||
file.write(f"Model Name: {MODEL_NAME}\n") | ||
file.write(f"Best Score: {trial.value}\n") | ||
file.write("Best Params: \n") | ||
for key, value in trial.params.items(): | ||
file.write(" {}: {}\n".format(key, value)) | ||
file.write("*" * 60) | ||
file.write("\n") | ||
|
||
del train_dataset, val_dataset, test_dataset, tokenizer, MODEL_NAME | ||
|
||
"""- I realized in DistilBert training, Model overtfits the training data up to %93 accuracy score however it generalizes badly. | ||
- Also, torch.clip_norm function demolishes the model success rate, it shows that additional alghoritmhs are unnecessar for bert base models. | ||
""" | ||
|
||
# # See the evaluation of any model | ||
# load_model = BertModel("<model-name>", unique_labels) | ||
# load_model = load_model.to(device) | ||
# tokenizer = AutoTokenizer.from_pretrained("<model-name>") | ||
# test_dataset = createBertDataset(df_test, tokenizer) | ||
# evaluate(load_model, test_dataset) | ||
|
||
# test_loss, test_f1, test_acc = evaluate(load_model, test_dataset) | ||
|
||
# print(f'Test loss: {test_loss}') | ||
# print(f'F1 Score (Weighted): {test_f1}') | ||
# print(f'Acc Score: {test_acc}') |
@@ -0,0 +1,33 @@
import re
import string

import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

stop_words = set(stopwords.words('english'))
stop_words.add('subject')
stop_words.add('http')


def remove_stopwords(text):
    return " ".join([word for word in str(text).split() if word not in stop_words])


lemmatizer = WordNetLemmatizer()


def lemmatize_words(text):
    return " ".join([lemmatizer.lemmatize(word) for word in text.split()])


def preprocessing(text):
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    text = re.sub('[^A-Za-z0-9]+', ' ', text)
    text = text.replace("\t", " ")
    text = text.replace("\n", " ")
    text = re.sub(' +', ' ', text)  # remove extra whitespaces
    text = remove_stopwords(text)
    text = text.lower()
    text = lemmatize_words(text)

    return text
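A quick, hedged illustration of what this pipeline does (the sentence is made up, and the output shown is roughly what the steps above produce). Note that stop-word removal runs before lowercasing, so capitalized stop words such as a leading "The" pass through unchanged.

```python
from preprocess import preprocessing

# Made-up example sentence, just to show the effect of the cleaning steps
sample = "Boeing delivered 3 new aircraft to the airline!"
print(preprocessing(sample))
# expected output, roughly: "boeing delivered 3 new aircraft airline"
```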
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu116
torch
torchvision
torchaudio
transformers
optuna
nltk
scikit-learn
pandas
numpy