Skip to content

Commit

Permalink
cli updates
Browse files Browse the repository at this point in the history
  • Loading branch information
petersapountzis committed Apr 22, 2024
1 parent e1f147f commit 873ab8b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
10 changes: 5 additions & 5 deletions nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# such as where data will be downloaded from.
# here is an example.
def write_default_config(path):
    """
    Write the default configuration file to ``path``.

    The file contains a single ``[data]`` section with two options:
    ``url`` (where the training CSV is downloaded from) and ``file``
    (where the download is stored: ``nli.csv`` in the same directory
    as the configuration file). An existing file at ``path`` is
    overwritten.
    """
    # os.path.join keeps a bare filename (empty dirname) relative instead of
    # producing "/nli.csv" at the filesystem root.
    data_file = os.path.join(os.path.dirname(path), 'nli.csv')
    with open(path, 'wt') as w:
        w.write('[data]\n')
        w.write('url = https://www.dropbox.com/scl/fi/8afm3cbr1ui1j3qrtv1u9/train.csv?rlkey=d0y73zduv1ira37d5xyd0sg2m&st=tfkqctcq&dl=1\n')
        w.write('file = %s\n' % data_file)


# Find NLP_HOME path
if 'NLP_HOME' in os.environ:
Expand Down
55 changes: 44 additions & 11 deletions nlp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import glob
import pickle
import sys
import os

import numpy as np
import pandas as pd
Expand All @@ -18,7 +19,7 @@
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

from . import clf_path, config, config_path
from . import clf_path, config, config_path, write_default_config

model_name = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
tokenizer = AutoTokenizer.from_pretrained(model_name)
Expand All @@ -41,29 +42,52 @@ def web(port):
from .app import app
app.run(host='0.0.0.0', debug=True, port=port)


@main.command('dl-data')
def dl_data():
    """
    Download training/testing data.
    """
    # Rewrite the default configuration to make sure it's updated
    write_default_config(config_path)

    # Reload so the values read below come from the file just written
    config.read(config_path)
    data_url = config.get('data', 'url')
    data_file = config.get('data', 'file')

    print("configuration content:")
    print("URL:", data_url)
    print("File:", data_file)

    # Proceed with the data download
    print('Downloading from %s to %s' % (data_url, data_file))
    # A timeout keeps the command from hanging forever on a stalled connection.
    r = requests.get(data_url, timeout=60)
    r.raise_for_status()  # Ensure successful request
    # Store the raw bytes: writing decoded text in text mode depends on the
    # guessed response encoding and on platform newline translation.
    with open(data_file, 'wb') as f:
        f.write(r.content)




def load_and_tokenize_data(file_path):
    """
    Read the CSV at ``file_path`` and tokenize its premise/hypothesis pairs.

    Prints the column names and row count, then returns a tuple
    ``(tokenized_data, labels)``: the output of the module-level
    ``tokenizer`` and a tensor built from the ``label`` column.

    Raises:
        KeyError: if 'premise', 'hypothesis' or 'label' is not a column.
    """
    frame = pd.read_csv(file_path)
    print("Columns in CSV:", frame.columns)  # Display column names
    print("Number of rows:", len(frame))  # Display number of rows

    # Collect every required column the CSV lacks, in declaration order.
    absent = []
    for name in ('premise', 'hypothesis', 'label'):
        if name not in frame.columns:
            absent.append(name)
    if absent:
        raise KeyError(f"Missing required columns: {', '.join(absent)}")

    # The tokenizer is given plain strings for both text columns.
    for text_column in ('premise', 'hypothesis'):
        frame[text_column] = frame[text_column].astype(str)

    premises = frame['premise'].tolist()
    hypotheses = frame['hypothesis'].tolist()
    encoded = tokenizer(
        premises,
        hypotheses,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    targets = torch.tensor(frame['label'].values)

    return encoded, targets

def train_model(data_file):
Expand All @@ -90,6 +114,7 @@ def train_model(data_file):


@main.command('stats')
#TODO: update stats function for my df
def stats():
"""
Read the data files and print interesting statistics.
Expand All @@ -100,9 +125,17 @@ def stats():
print(df.partisan.value_counts())

@main.command('train')
def train():
    """Train the NLI classifier."""
    config.read(config_path)  # Reload the configuration to get the correct file path

    # Get the file path for training data from the configuration.
    # fallback=None lets a missing section/option reach the friendly error
    # below instead of surfacing as a configparser exception
    # (assumes `config` is a configparser.ConfigParser - TODO confirm).
    data_file = config.get('data', 'file', fallback=None)

    if not data_file or not os.path.exists(data_file):
        raise FileNotFoundError("Training data file not found. Please run 'nlp dl-data' first.")

    # Proceed with training using the configured data file
    train_model(data_file)
    print("Training complete.")

Expand Down

0 comments on commit 873ab8b

Please sign in to comment.