Skip to content

Commit

Permalink
Merge pull request #53 from CDU-data-science-team/0.4.1
Browse files Browse the repository at this point in the history
0.4.1
  • Loading branch information
yiwen-h authored Jan 5, 2023
2 parents 2429d8b + 24f9729 commit 436fff2
Show file tree
Hide file tree
Showing 45 changed files with 42,805 additions and 39,079 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ test_results_label/*
test_results_criticality/*
site/
dist/
.vscode/
56 changes: 56 additions & 0 deletions api/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from fastapi import FastAPI
from pxtextmining.factories.factory_predict_unlabelled_text import factory_predict_unlabelled_text
import pandas as pd
import mysql.connector


app = FastAPI()

# Define a root `/` endpoint
@app.get('/')
def index():
    """Respond to GET `/` with a fixed test payload."""
    payload = {'Test': 'Hello'}
    return payload

@app.get('/predict_from_sql')
def predict(ids: str,
            target: str):
    """
    Predict 'label' or 'criticality' for rows of the SQL table `text_data`.

    A parameterised SQL query is built from the 'ids' param and run through a MySQL connection that is
    configured by the option file 'my.conf'. The fetched rows are converted to a pandas DataFrame, which is
    passed to factory_predict_unlabelled_text together with the saved pipeline selected by 'target'.

    Args:
        ids (str): ids of the text data to be used for predictions, in the format '1,2,3,4,5,6'.
            Whitespace around each id is ignored and blank entries are dropped. The original documentation
            states that up to 5000 ids can be taken; that limit is not enforced by this function.
        target (str): type of prediction to be chosen. Can either be 'label' or 'criticality'.

    Returns:
        list: List of the predictions, each in dictionary format, containing 'id' and 'predictions'. e.g.
            [{'id': 1, 'predictions': '3'}, {'id': 2, 'predictions': '1'}]. This is converted to JSON by FastAPI.
            An empty list is returned when no ids are supplied or no rows match them.
        dict: {'error': 'invalid target'} when 'target' is neither 'label' nor 'criticality'.
    """
    # Validate the target first so an invalid request never opens a database connection.
    if target == 'label':
        model = 'results_label/pipeline_label.sav'
        query_template = "SELECT id , feedback FROM text_data WHERE id IN ({})"
    elif target == 'criticality':
        model = 'results_criticality_with_theme/pipeline_criticality_with_theme.sav'
        query_template = "SELECT id , label , feedback FROM text_data WHERE id IN ({})"
    else:
        return {'error': 'invalid target'}

    # Drop blank entries (e.g. from '' or a trailing comma) so they are not bound as ids.
    q = [item.strip() for item in ids.split(',') if item.strip()]
    if not q:
        return []

    # Only '%s' placeholders are interpolated into the SQL text; the id values
    # themselves are passed separately to cursor.execute.
    placeholders = ', '.join(['%s'] * len(q))  # "%s, %s, %s, ... %s"
    query = query_template.format(placeholders)

    db = mysql.connector.connect(option_files="my.conf", use_pure=True)
    try:
        with db.cursor() as cursor:
            cursor.execute(query, tuple(q))
            rows = cursor.fetchall()
            column_names = list(cursor.column_names)
    finally:
        # Always release the connection, even if the query fails.
        db.close()

    if not rows:
        return []

    # Passing the column names to the constructor keeps the frame well-formed
    # regardless of how many rows came back.
    text_data = pd.DataFrame(rows, columns=column_names)

    if target == 'label':
        predictions = factory_predict_unlabelled_text(dataset=text_data, predictor="feedback",
                                                      pipe_path_or_object=model,
                                                      columns_to_return=['id', 'predictions'])
    else:  # target == 'criticality' (any other value returned above)
        text_data = text_data.rename(columns={'feedback': 'predictor'})
        predictions = factory_predict_unlabelled_text(dataset=text_data, predictor="predictor",
                                                      theme='label', pipe_path_or_object=model,
                                                      columns_to_return=['id', 'predictions'])
    return predictions.to_dict(orient='records')
6 changes: 6 additions & 0 deletions api/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import requests

# Manual check against a locally running API: request 'criticality'
# predictions for ids 1..499 and print the decoded JSON response.
ids = ','.join(map(str, range(1, 500)))
params = dict(ids=ids, target="criticality")
response = requests.get(
    "http://127.0.0.1:8000/predict_from_sql",
    params=params,
)
print(response.json())
42 changes: 42 additions & 0 deletions execution/execution_criticality_no_theme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from pxtextmining.pipelines.text_classification_pipeline import text_classification_pipeline

"""
This is an example of how to train a model that predicts 'criticality' levels using the labelled dataset
'datasets/text_data.csv'.
The trained model is saved in the folder results_criticality_no_theme.
"""

# Names of the learners passed to the pipeline through `learners`.
learners = [
    "SGDClassifier",
    "RidgeClassifier",
    "Perceptron",
    "PassiveAggressiveClassifier",
    "BernoulliNB",
    "ComplementNB",
    "MultinomialNB",
    "KNeighborsClassifier",
    "NearestCentroid",
    "RandomForestClassifier",
]

# Names of the outputs passed to the pipeline through `objects_to_save`.
objects_to_save = [
    "pipeline",
    "tuning results",
    "predictions",
    "accuracy per class",
    "index - training data",
    "index - test data",
    "bar plot",
]

# Run the pipeline with target "criticality", predictor "feedback" and no theme
# (theme=None); results go to disk under "results_criticality_no_theme".
(
    pipe,
    tuning_results,
    pred,
    accuracy_per_class,
    p_compare_models_bar,
    index_train,
    index_test,
) = text_classification_pipeline(
    filename='datasets/text_data.csv',
    target="criticality",
    predictor="feedback",
    test_size=0.33,
    ordinal=True,
    tknz="spacy",
    metric="class_balance_accuracy",
    cv=5,
    n_iter=100,
    n_jobs=5,
    verbose=3,
    learners=learners,
    objects_to_save=objects_to_save,
    save_objects_to_server=False,
    save_objects_to_disk=True,
    save_pipeline_as="pipeline_criticality_no_theme",
    results_folder_name="results_criticality_no_theme",
    reduce_criticality=True,
    theme=None,
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"BernoulliNB",
"ComplementNB",
"MultinomialNB",
# "KNeighborsClassifier",
# "NearestCentroid",
"KNeighborsClassifier",
"NearestCentroid",
"RandomForestClassifier"
],
objects_to_save=[
Expand All @@ -36,7 +36,7 @@
],
save_objects_to_server=False,
save_objects_to_disk=True,
save_pipeline_as="test_pipeline_criticality",
results_folder_name="test_results_criticality",
save_pipeline_as="pipeline_criticality_with_theme",
results_folder_name="results_criticality_with_theme",
reduce_criticality=True,
theme="label")
4 changes: 2 additions & 2 deletions execution/execution_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"BernoulliNB",
"ComplementNB",
"MultinomialNB",
# "KNeighborsClassifier",
# "NearestCentroid",
"KNeighborsClassifier",
"NearestCentroid",
"RandomForestClassifier"
],
objects_to_save=[
Expand Down
2 changes: 1 addition & 1 deletion execution/execution_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
dataset = pd.read_csv('datasets/text_data.csv')
predictions = factory_predict_unlabelled_text(dataset=dataset, predictor="feedback",
pipe_path_or_object="results_label/pipeline_label.sav",
columns_to_return=['feedback', 'organization', 'question'])
columns_to_return='all_cols')
print(predictions.head())
Loading

0 comments on commit 436fff2

Please sign in to comment.