Merge pull request #85 from tanishq-ids/changes
Added code for saving metrics and a relevance-to-KPI-curation bridge
tanishq-ids authored Dec 2, 2024
2 parents 212826f + ce23dab commit 446fa2e
Showing 2 changed files with 74 additions and 2 deletions.
@@ -13,14 +13,16 @@
import shutil
import pandas as pd
import torch
import json
import time
from transformers import (
    TrainingArguments,
    Trainer,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from sklearn.model_selection import train_test_split
-from sklearn.metrics import accuracy_score
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import Dataset


@@ -219,8 +221,23 @@ def fine_tune_model(
    checkpoint_dir = os.path.join(saved_model_path, "checkpoints")
    os.makedirs(saved_model_path, exist_ok=True)

    def compute_metrics(pred):
        predictions, labels = pred
        predictions = predictions.argmax(axis=-1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average="weighted"
        )
        acc = accuracy_score(labels, predictions)
        return {
            "accuracy": acc,
            "f1": f1,
            "precision": precision,
            "recall": recall,
        }

    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        dataloader_pin_memory=False,
        evaluation_strategy="epoch",  # Evaluate at the end of each epoch
        logging_dir="./logs",  # Directory for logs
        logging_steps=10,  # Log every 10 steps
@@ -236,6 +253,7 @@ def fine_tune_model(
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_total_limit=1,
        seed=42,
    )

    # Initialize Trainer
@@ -244,11 +262,14 @@ def fine_tune_model(
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    start = time.time()
    # Start Training
    trainer.train()
    end = time.time()
    print("Time needed to train model =", end - start)

    # Save the final trained model and config
    trainer.save_model(saved_model_path)

@@ -261,6 +282,12 @@ def fine_tune_model(
    for key, value in eval_result.items():
        print(f"{key}: {value}")

    results_path = os.path.join(output_dir, f"evaluation_results_{export_model_name}.json")
    with open(results_path, "w") as json_file:
        json.dump(eval_result, json_file, indent=4)
    print(f"Evaluation results saved to {results_path}")

    # Predict labels for the evaluation dataset
    predictions = trainer.predict(eval_dataset)
    predicted_labels = predictions.predictions.argmax(axis=1)
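
For downstream use, the saved metrics can be read back from the JSON file. A minimal sketch, assuming the same output_dir and export_model_name values that were passed to fine_tune_model (both hypothetical here):

import json
import os

# Hypothetical values; substitute whatever was passed to fine_tune_model.
output_dir = "models/relevance"
export_model_name = "relevance_model"

metrics_path = os.path.join(output_dir, f"evaluation_results_{export_model_name}.json")
with open(metrics_path) as json_file:
    eval_result = json.load(json_file)

# trainer.evaluate() prefixes metric names with "eval_", so the keys added by
# compute_metrics appear as eval_accuracy, eval_f1, eval_precision, eval_recall.
print(eval_result["eval_accuracy"], eval_result["eval_f1"])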
@@ -7,6 +7,43 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig


def combine_and_filter_xlsx_files(folder_path, output_file):
    """
    Combine all .xlsx files in a folder, filter rows where paragraph_relevance_flag is 1,
    and save the result as a new .xlsx file.

    Parameters:
        folder_path (str): Path to the folder containing .xlsx files.
        output_file (str): Path to the folder in which to save the filtered combined file.
    """
    all_dataframes = []

    # Iterate through all files in the folder
    for file_name in os.listdir(folder_path):
        if file_name.endswith(".xlsx"):
            file_path = os.path.join(folder_path, file_name)
            try:
                # Read the Excel file into a DataFrame
                df = pd.read_excel(file_path)
                all_dataframes.append(df)
            except Exception as e:
                print(f"Error reading {file_name}: {e}")

    # Combine all DataFrames into one
    if all_dataframes:
        combined_df = pd.concat(all_dataframes, axis=0, ignore_index=True)

        # Filter rows where paragraph_relevance_flag is 1
        filtered_df = combined_df[combined_df["paragraph_relevance_flag"] == 1]

        # Save the filtered DataFrame to an Excel file
        save_path = os.path.join(output_file, "combined_inference.xlsx")
        filtered_df.to_excel(save_path, index=False)
        print(f"Filtered data saved to {save_path}")
    else:
        print("No valid .xlsx files found in the folder.")

def validate_path_exists(path: str, which_path: str):
"""
Validate if the given path exists.
@@ -170,3 +207,11 @@ def run_full_inference(
print(f"Successfully SAVED resulting file at {output_file_path}")
except Exception as e:
print(f"Error saving file {excel_name}: {e}")

combine_file_path = os.path.join(output_path, "combined_inference")
os.makedirs(combine_file_path, exist_ok=True)
combine_and_filter_xlsx_files(output_path, combine_file_path)
print(
"Successfully SAVED combined inference file for KPI DETECTION in ",
combine_file_path,
)
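
The combined file is what bridges relevance detection to KPI curation. A minimal sketch of consuming it, assuming a hypothetical output_path of out/relevance:

import os
import pandas as pd

# Hypothetical paths; "combined_inference.xlsx" is the file name hard-coded
# in combine_and_filter_xlsx_files above.
output_path = "out/relevance"
bridge_file = os.path.join(output_path, "combined_inference", "combined_inference.xlsx")

relevant_df = pd.read_excel(bridge_file)
# Every remaining row has paragraph_relevance_flag == 1, so the DataFrame can
# be handed directly to the KPI-curation stage.
print(f"{len(relevant_df)} relevant paragraphs ready for KPI curation")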
