basic demo finished
JonathanSears1 committed Apr 26, 2024
1 parent 2decb34 commit ac4a12d
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions app/app.py
Expand Up @@ -53,13 +53,12 @@ def model_predict(text, model):
# Placeholder for model prediction
raw_preds = []
for sentence in text:
print(sentence)
input_ids = torch.tensor(sentence['input_ids']).to(device)
attention_mask = torch.ones(sentence['attention_mask']).to(device)
token_type_ids = torch.zeros(sentence['token_type_ids']).to(device)
input_ids = torch.tensor(sentence['input_ids']).unsqueeze(0).to(device)
attention_mask = torch.tensor(sentence['attention_mask']).unsqueeze(0).to(device)
token_type_ids = torch.tensor(sentence['token_type_ids']).unsqueeze(0).to(device)
with torch.no_grad():
preds = model(attention_mask = attention_mask, token_type_ids = token_type_ids, input_ids = input_ids)
raw_preds.append(preds)
raw_preds.append(preds.detach().cpu().numpy().tolist())
return raw_preds

def tokenize_text(text,tokenizer):
@@ -78,13 +78,13 @@ def upload():
     #read the file and save it to the uploads folder
     data = request.get_json(force=True)
     text = data['text']
-    print(text)
+    # print(text)
     tokenized_text = tokenize_text(text,tokenizer)
     #make a prediction
-    print(tokenized_text[0]['input_ids'])
+    # print(tokenized_text[0]['input_ids'])
     preds = model_predict(tokenized_text, model)
     #process the prediction to determine the output
-    print(preds)
+    # print(preds)
     return jsonify(prediction=preds)
     return None
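
For reference, the upload() view reads a JSON body with a 'text' field and returns the raw per-sentence predictions under a 'prediction' key. A hedged client-side sketch, assuming the view is mounted at /upload on a locally running Flask dev server and that 'text' is a list of sentences (the route decorator and port are not shown in this diff):

import requests

# Hypothetical URL and route; adjust to wherever the Flask app registers upload().
resp = requests.post(
    "http://localhost:5000/upload",
    json={"text": ["First sentence to classify.", "Second sentence."]},
)
print(resp.json()["prediction"])  # one raw prediction per input sentence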

