From 873ab8bfe5d48f6c6bb0088f76f2268c71c019ff Mon Sep 17 00:00:00 2001 From: petersapountzis Date: Mon, 22 Apr 2024 09:51:40 -0500 Subject: [PATCH] cli updates --- nlp/__init__.py | 10 ++++----- nlp/cli.py | 55 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/nlp/__init__.py b/nlp/__init__.py index 3c59c3c..3653ff3 100644 --- a/nlp/__init__.py +++ b/nlp/__init__.py @@ -14,11 +14,11 @@ # such as where data will be downloaded from. # here is an example. def write_default_config(path): - w = open(path, 'wt') - w.write('[data]\n') - w.write('url = https://drive.google.com/drive/folders/1gF0E9E8w1x-yz5FvxS8zFZlSNIivYfhT/train.csv\n') - w.write('file = %s%s%s\n' % (nlp_path, os.path.sep, 'train.csv')) - w.close() + with open(path, 'wt') as w: + w.write('[data]\n') + w.write('url = https://www.dropbox.com/scl/fi/8afm3cbr1ui1j3qrtv1u9/train.csv?rlkey=d0y73zduv1ira37d5xyd0sg2m&st=tfkqctcq&dl=1\n') # Corrected URL + w.write('file = %s%s%s\n' % (os.path.dirname(path), os.path.sep, 'nli.csv')) # Corrected 'file' option + # Find NLP_HOME path if 'NLP_HOME' in os.environ: diff --git a/nlp/cli.py b/nlp/cli.py index a273e40..df2982d 100644 --- a/nlp/cli.py +++ b/nlp/cli.py @@ -5,6 +5,7 @@ import glob import pickle import sys +import os import numpy as np import pandas as pd @@ -18,7 +19,7 @@ from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report -from . import clf_path, config, config_path +from . import clf_path, config, config_path, write_default_config model_name = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -41,29 +42,52 @@ def web(port): from .app import app app.run(host='0.0.0.0', debug=True, port=port) + @main.command('dl-data') def dl_data(): """ Download training/testing data. """ - print("Config file path:", config_path) - config.read(config_path) # Reload the configuration - # data_url = config.get('data', 'url') - # data_file = config.get('data', 'file') - data_url = 'https://www.dropbox.com/scl/fi/8afm3cbr1ui1j3qrtv1u9/train.csv?rlkey=d0y73zduv1ira37d5xyd0sg2m&dl=0' - data_file = '/Users/petersapountzis/.nlp/nli_train.csv' - print('downloading from %s to %s' % (data_url, data_file)) + # Rewrite the default configuration to make sure it's updated + write_default_config(config_path) + + # Now reload the configuration to check the new values + config.read(config_path) + data_url = config.get('data', 'url') + data_file = config.get('data', 'file') + + print("configuration content:") + print("URL:", data_url) + print("File:", data_file) + + # Proceed with the data download + print('Downloading from %s to %s' % (data_url, data_file)) r = requests.get(data_url) + r.raise_for_status() # Ensure successful request with open(data_file, 'wt') as f: f.write(r.text) - + + def load_and_tokenize_data(file_path): df = pd.read_csv(file_path) + print("Columns in CSV:", df.columns) # Display column names + print("Number of rows:", len(df)) # Display number of rows + + + # Check if required columns are present + required_columns = ['premise', 'hypothesis', 'label'] + missing_columns = [col for col in required_columns if col not in df.columns] + + if missing_columns: + raise KeyError(f"Missing required columns: {', '.join(missing_columns)}") + df['premise'] = df['premise'].astype(str) df['hypothesis'] = df['hypothesis'].astype(str) + tokenized_data = tokenizer(df['premise'].tolist(), df['hypothesis'].tolist(), padding=True, truncation=True, return_tensors="pt") labels = torch.tensor(df['label'].values) + return tokenized_data, labels def train_model(data_file): @@ -90,6 +114,7 @@ def train_model(data_file): @main.command('stats') +#TODO: update stats function for my df def stats(): """ Read the data files and print interesting statistics. @@ -100,9 +125,17 @@ def stats(): print(df.partisan.value_counts()) @main.command('train') -@click.argument('data_file', type=click.Path(exists=True)) -def train(data_file): +def train(): """Train the NLI classifier.""" + config.read(config_path) # Reload the configuration to get the correct file path + + # Get the file path for training data from the configuration + data_file = config.get('data', 'file') + + if not data_file or not os.path.exists(data_file): + raise FileNotFoundError("Training data file not found. Please run ' nlp dl-data' first.") + + # Proceed with training using the correct data file train_model(data_file) print("Training complete.")