diff --git a/nlp/app/routes.py b/nlp/app/routes.py index ec8141c..1254e1a 100644 --- a/nlp/app/routes.py +++ b/nlp/app/routes.py @@ -42,7 +42,7 @@ def index(): # For CNN, assuming preprocessing is handled differently or is built-in text = basic_process(input_field) text = cnn_process(text) - predictions_proba = cnn.predict([text]) + predictions_proba = cnn.predict(text) pred = (predictions_proba > 0.5).astype(int)[0] # Assuming binary classification proba = predictions_proba[0] elif model_choice == 'bert': @@ -51,7 +51,6 @@ def index(): text = tokenizer(input_field, return_tensors="pt") with torch.no_grad(): outputs = model(**text) - predicted_class = torch.argmax(outputs.logits).item() return render_template('myform.html', title='', form=form,