diff --git a/nlp/app/forms.py b/nlp/app/forms.py index 5395597..cd45592 100644 --- a/nlp/app/forms.py +++ b/nlp/app/forms.py @@ -6,7 +6,7 @@ class MyForm(FlaskForm): class Meta: # Ignoring CSRF security feature. csrf = False - input_field = StringField(label='input headline:', id='input_field', + input_field = StringField(label='input comment:', id='input_field', validators=[DataRequired()], render_kw={'style': 'width:50%'}) submit = SubmitField('Submit') \ No newline at end of file diff --git a/nlp/app/routes.py b/nlp/app/routes.py index ec4db3f..f62628c 100644 --- a/nlp/app/routes.py +++ b/nlp/app/routes.py @@ -6,10 +6,11 @@ import pickle import sys -clf, vec = pickle.load(open(clf_path, 'rb')) -print('read clf %s' % str(clf)) -print('read vec %s' % str(vec)) -labels = ['liberal', 'conservative'] +bnb, vec_1, process_text = pickle.load(open(clf_path, 'rb')) +print('read bnb %s' % str(bnb)) +print('read vec %s' % str(vec_1)) +print('read process_text %s' % str(process_text)) +labels = ['loss', 'win'] @app.route('/', methods=['GET', 'POST']) @app.route('/index', methods=['GET', 'POST']) @@ -18,9 +19,10 @@ def index(): result = None if form.validate_on_submit(): input_field = form.input_field.data + updated_field = process_text([input_field]) X = vec.transform([input_field]) - pred = clf.predict(X)[0] - proba = clf.predict_proba(X)[0].max() + pred = bnb.predict(X)[0] + proba = bnb.predict_proba(X)[0].max() # flash(input_field) return render_template('myform.html', title='', form=form, prediction=labels[pred], confidence='%.2f' % proba) diff --git a/nlp/cli.py b/nlp/cli.py index 501a27c..825dfe0 100644 --- a/nlp/cli.py +++ b/nlp/cli.py @@ -16,7 +16,7 @@ from sklearn.linear_model import LogisticRegression from nltk.corpus import stopwords from nltk.stem import PorterStemmer -from . import config, config_path +from . import bnb_path, lr_path, config, config_path @click.group() def main(args=None): @@ -113,6 +113,7 @@ def process_text(document): y_pred = bnb.predict(X_val) f1 = f1_score(y_val, y_pred) print("F1 Score:", round(f1, 3)) + pickle.dump((bnb, vec_1, process_text), open(bnb_path, 'wb')) @main.command('train_lr') def train_lr():