From aca97664cc8f44ddbdab0b8ce8ebb5962c5d6a7f Mon Sep 17 00:00:00 2001 From: jcollopy-tulane Date: Mon, 29 Apr 2024 16:41:50 -0500 Subject: [PATCH] J --- nlp/cli.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/nlp/cli.py b/nlp/cli.py index 3e5e7f5..7072248 100644 --- a/nlp/cli.py +++ b/nlp/cli.py @@ -183,22 +183,9 @@ def train_bert(): ''' Get BERT ''' - - - gauth = GoogleAuth() - gauth.LocalWebserverAuth() # Follow the authentication steps - drive = GoogleDrive(gauth) - - # Replace 'file_id_here' with the file ID from the Google Drive link - file_id = 'https://drive.google.com/file/d/1K26N14tCziLm97ie7YeBBK8d_fz9oIg3/view?usp=sharing' - - # Download the file - downloaded_file = drive.CreateFile({'id': file_id}) - downloaded_file.GetContentFile('bert_model.pth') - - # Load the model from the downloaded file - model_state_dict = torch.load('bert_model.pth', map_location=torch.device('cpu')) + model = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-mini', num_labels=2) + model.load_state_dict(torch.load('bert_model.pth')) tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-mini') def tokenize(data, max_length=87):