From ac4a12defe511f79b84d1af888a9aa338302191f Mon Sep 17 00:00:00 2001 From: Jonathan Sears Date: Fri, 26 Apr 2024 14:31:34 -0500 Subject: [PATCH] basic demo finished --- app/app.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/app/app.py b/app/app.py index 46cc2db..3767bfa 100644 --- a/app/app.py +++ b/app/app.py @@ -53,13 +53,12 @@ def model_predict(text, model): # Placeholder for model prediction raw_preds = [] for sentence in text: - print(sentence) - input_ids = torch.tensor(sentence['input_ids']).to(device) - attention_mask = torch.ones(sentence['attention_mask']).to(device) - token_type_ids = torch.zeros(sentence['token_type_ids']).to(device) + input_ids = torch.tensor(sentence['input_ids']).unsqueeze(0).to(device) + attention_mask = torch.tensor(sentence['attention_mask']).unsqueeze(0).to(device) + token_type_ids = torch.tensor(sentence['token_type_ids']).unsqueeze(0).to(device) with torch.no_grad(): preds = model(attention_mask = attention_mask, token_type_ids = token_type_ids, input_ids = input_ids) - raw_preds.append(preds) + raw_preds.append(preds.detach().cpu().numpy().tolist()) return raw_preds def tokenize_text(text,tokenizer): @@ -78,13 +77,13 @@ def upload(): #read the file and save it to the uploads folder data = request.get_json(force=True) text = data['text'] - print(text) + # print(text) tokenized_text = tokenize_text(text,tokenizer) #make a prediction - print(tokenized_text[0]['input_ids']) + # print(tokenized_text[0]['input_ids']) preds = model_predict(tokenized_text, model) #process the prediction to determine the output - print(preds) + # print(preds) return jsonify(prediction=preds) return None