-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #53 from CDU-data-science-team/0.4.1
0.4.1
- Loading branch information
Showing
45 changed files
with
42,805 additions
and
39,079 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ test_results_label/* | |
test_results_criticality/* | ||
site/ | ||
dist/ | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from fastapi import FastAPI | ||
from pxtextmining.factories.factory_predict_unlabelled_text import factory_predict_unlabelled_text | ||
import pandas as pd | ||
import mysql.connector | ||
|
||
|
||
app = FastAPI() | ||
|
||
# Define a root `/` endpoint | ||
@app.get('/')
def index():
    """Root endpoint: respond with a fixed test payload."""
    payload = {'Test': 'Hello'}
    return payload
|
||
@app.get('/predict_from_sql')
def predict(ids: str,
            target: str):
    """
    Fetch text rows by id from the SQL database and return model predictions.

    A parameterised ``SELECT ... WHERE id IN (...)`` query is built from
    'ids' and executed through ``mysql.connector`` (connection settings are
    read from the option file ``my.conf``). The rows are loaded into a pandas
    DataFrame and passed to ``factory_predict_unlabelled_text`` together with
    the pipeline file selected by 'target'.

    Args:
        ids (str): Comma-separated ids of the text data to predict on, in the
            format '1,2,3,4,5,6'. Surrounding whitespace and empty entries
            are ignored. The original documentation states that up to 5000
            ids can be supplied; that limit is not enforced here.
        target (str): Type of prediction. Either 'label' or 'criticality'.

    Returns:
        list: One dictionary per row containing 'id' and 'predictions', e.g.
            [{'id': 1, 'predictions': '3'}, {'id': 2, 'predictions': '1'}].
            An empty list is returned when no usable ids were supplied or no
            rows matched. FastAPI converts the result to JSON.
        dict: ``{'error': 'invalid target'}`` when 'target' is not recognised.
    """
    if target == 'label':
        model = 'results_label/pipeline_label.sav'
        query_template = "SELECT id , feedback FROM text_data WHERE id IN ({})"
    elif target == 'criticality':
        model = 'results_criticality_with_theme/pipeline_criticality_with_theme.sav'
        query_template = "SELECT id , label , feedback FROM text_data WHERE id IN ({})"
    else:
        return {'error': 'invalid target'}

    # Drop blank entries (e.g. from "1,2," or an empty string): an empty IN ()
    # list is a SQL syntax error, so there is nothing to query in that case.
    id_list = [part.strip() for part in ids.split(',') if part.strip()]
    if not id_list:
        return []

    # One "%s" placeholder per id keeps the values out of the SQL string.
    placeholders = ', '.join(['%s'] * len(id_list))  # "%s, %s, %s, ... %s"
    query = query_template.format(placeholders)

    db = mysql.connector.connect(option_files="my.conf", use_pure=True)
    try:
        with db.cursor() as cursor:
            cursor.execute(query, tuple(id_list))
            rows = cursor.fetchall()
            column_names = list(cursor.column_names)
    finally:
        # Always release the connection, even if the query fails.
        db.close()

    if not rows:
        # No matching rows: nothing to predict on.
        return []

    # Passing the column names to the constructor avoids a length-mismatch
    # error that assigning `.columns` afterwards raises on an empty result.
    text_data = pd.DataFrame(rows, columns=column_names)

    if target == 'label':
        predictions = factory_predict_unlabelled_text(
            dataset=text_data, predictor="feedback",
            pipe_path_or_object=model, columns_to_return=['id', 'predictions'])
    else:
        # target == 'criticality' (any other value was rejected above).
        text_data = text_data.rename(columns={'feedback': 'predictor'})
        predictions = factory_predict_unlabelled_text(
            dataset=text_data, predictor="predictor",
            theme='label', pipe_path_or_object=model,
            columns_to_return=['id', 'predictions'])
    return predictions.to_dict(orient='records')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import requests | ||
|
||
# Ask the locally served API for 'criticality' predictions on ids 1..499.
ids = ','.join(map(str, range(1, 500)))
params = dict(ids=ids, target="criticality")
response = requests.get(
    "http://127.0.0.1:8000/predict_from_sql",
    params=params,
)
print(response.json())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from pxtextmining.pipelines.text_classification_pipeline import text_classification_pipeline | ||
|
||
""" | ||
This is an example of how to train a model that predicts 'criticality' levels using the labelled dataset | ||
'datasets/text_data.csv'. | ||
The trained model is saved in the folder results_criticality_no_theme. | ||
""" | ||
|
||
# Learner names handed to the pipeline via its `learners` argument.
learners = [
    "SGDClassifier",
    "RidgeClassifier",
    "Perceptron",
    "PassiveAggressiveClassifier",
    "BernoulliNB",
    "ComplementNB",
    "MultinomialNB",
    "KNeighborsClassifier",
    "NearestCentroid",
    "RandomForestClassifier"
]

# Names of the outputs requested via the `objects_to_save` argument.
objects_to_save = [
    "pipeline",
    "tuning results",
    "predictions",
    "accuracy per class",
    "index - training data",
    "index - test data",
    "bar plot"
]

# Train on 'datasets/text_data.csv' with 'criticality' as the target and
# 'feedback' as the predictor; no theme column is used (theme=None).
(pipe, tuning_results, pred, accuracy_per_class, p_compare_models_bar,
 index_train, index_test) = text_classification_pipeline(
    filename='datasets/text_data.csv',
    target="criticality",
    predictor="feedback",
    test_size=0.33,
    ordinal=True,
    tknz="spacy",
    metric="class_balance_accuracy",
    cv=5,
    n_iter=100,
    n_jobs=5,
    verbose=3,
    learners=learners,
    objects_to_save=objects_to_save,
    save_objects_to_server=False,
    save_objects_to_disk=True,
    save_pipeline_as="pipeline_criticality_no_theme",
    results_folder_name="results_criticality_no_theme",
    reduce_criticality=True,
    theme=None,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.