diff --git a/src/osc_transformer_based_extractor/relevance_detector/fine_tune.py b/src/osc_transformer_based_extractor/relevance_detector/fine_tune.py
index 5b85193..456e0d8 100644
--- a/src/osc_transformer_based_extractor/relevance_detector/fine_tune.py
+++ b/src/osc_transformer_based_extractor/relevance_detector/fine_tune.py
@@ -13,6 +13,8 @@
 import shutil
 import pandas as pd
 import torch
+import json
+import time
 from transformers import (
     TrainingArguments,
     Trainer,
@@ -20,7 +22,7 @@
     AutoTokenizer,
 )
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import accuracy_score
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support
 from torch.utils.data import Dataset
@@ -219,8 +221,23 @@
     checkpoint_dir = os.path.join(saved_model_path, "checkpoints")
     os.makedirs(saved_model_path, exist_ok=True)

+    def compute_metrics(pred):
+        predictions, labels = pred
+        predictions = predictions.argmax(axis=-1)
+        precision, recall, f1, _ = precision_recall_fscore_support(
+            labels, predictions, average="weighted"
+        )
+        acc = accuracy_score(labels, predictions)
+        return {
+            "accuracy": acc,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall,
+        }
+
     training_args = TrainingArguments(
         output_dir=checkpoint_dir,
+        dataloader_pin_memory=False,
         evaluation_strategy="epoch",  # Evaluate at the end of each epoch
         logging_dir="./logs",  # Directory for logs
         logging_steps=10,  # Log every 10 steps
@@ -236,6 +253,7 @@
         metric_for_best_model="eval_loss",
         greater_is_better=False,
         save_total_limit=1,
+        seed=42,
     )

     # Initialize Trainer
@@ -244,11 +262,14 @@
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
+        compute_metrics=compute_metrics,
     )

+    start = time.time()
     # Start Training
     trainer.train()
-
+    end = time.time()
+    print(f"Training time: {end - start:.2f} seconds")
     # Save the final trained model and config
     trainer.save_model(saved_model_path)
@@ -261,6 +282,12 @@
     for key, value in eval_result.items():
         print(f"{key}: {value}")

+    with open(
+        os.path.join(output_dir, f"evaluation_results_{export_model_name}.json"), "w"
+    ) as json_file:
+        json.dump(eval_result, json_file, indent=4)
+    print(f"Evaluation results saved to 'evaluation_results_{export_model_name}.json'")
+
     # Predict labels for the evaluation dataset
     predictions = trainer.predict(eval_dataset)
     predicted_labels = predictions.predictions.argmax(axis=1)
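Note for reviewers: the new compute_metrics helper can be sanity-checked outside a full training run. Below is a minimal sketch, assuming only numpy and scikit-learn are installed; the logits and gold labels are toy values invented for illustration, and the helper body is duplicated so the snippet is self-contained.

import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    # Same logic as the helper added in fine_tune.py above.
    predictions, labels = pred
    predictions = predictions.argmax(axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted"
    )
    acc = accuracy_score(labels, predictions)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

# Toy logits for 4 examples over 2 classes, plus gold labels (made-up data).
logits = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
labels = np.array([0, 1, 1, 1])

# argmax picks classes [0, 1, 0, 1] -> 3 of 4 correct, so accuracy is 0.75.
print(compute_metrics((logits, labels)))

Wired into the Trainer via compute_metrics=compute_metrics, the same function receives the eval logits and label ids at the end of each epoch, so the evaluation logs gain accuracy/F1/precision/recall alongside eval_loss.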
+ """ + all_dataframes = [] + + # Iterate through all files in the folder + for file_name in os.listdir(folder_path): + if file_name.endswith(".xlsx"): + file_path = os.path.join(folder_path, file_name) + try: + # Read the Excel file into a DataFrame + df = pd.read_excel(file_path) + all_dataframes.append(df) + except Exception as e: + print(f"Error reading {file_name}: {e}") + + # Combine all DataFrames into one + if all_dataframes: + combined_df = pd.concat(all_dataframes, axis=0, ignore_index=True) + + # Filter rows where paragraph_relevance_flag is 1 + filtered_df = combined_df[combined_df["paragraph_relevance_flag"] == 1] + + # Save the filtered DataFrame to an Excel file + file_name = "combined_inference.xlsx" + filtered_df.to_excel(os.path.join(output_file, file_name), index=False) + print(f"Filtered data saved to {output_file}") + else: + print("No valid .xlsx files found in the folder.") + + def validate_path_exists(path: str, which_path: str): """ Validate if the given path exists. @@ -170,3 +207,11 @@ def run_full_inference( print(f"Successfully SAVED resulting file at {output_file_path}") except Exception as e: print(f"Error saving file {excel_name}: {e}") + + combine_file_path = os.path.join(output_path, "combined_inference") + os.makedirs(combine_file_path, exist_ok=True) + combine_and_filter_xlsx_files(output_path, combine_file_path) + print( + "Successfully SAVED combined inference file for KPI DETECTION in ", + combine_file_path, + )