-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #88 from CDU-data-science-team/0.5.1
0.5.1
- Loading branch information
Showing
30 changed files
with
2,225 additions
and
1,478 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,4 @@ dist/ | |
.vscode/ | ||
datasets/hidden/* | ||
test_multilabel/* | ||
.env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,141 @@ | ||
import os | ||
import pickle | ||
from typing import List | ||
|
||
import pandas as pd | ||
from fastapi import FastAPI | ||
from pydantic import BaseModel | ||
from pydantic import BaseModel, validator | ||
|
||
from pxtextmining.factories.factory_predict_unlabelled_text import \ | ||
predict_multilabel_sklearn | ||
from pxtextmining.factories.factory_predict_unlabelled_text import ( | ||
predict_multilabel_sklearn, | ||
) | ||
|
||
description = """ | ||
This API is for classifying patient experience qualitative data, | ||
utilising the models trained as part of the pxtextmining project. | ||
""" | ||
|
||
tags_metadata = [ | ||
{"name": "index", "description": "Basic page to test if API is working."}, | ||
{ | ||
"name": "predict", | ||
"description": "Generate multilabel predictions for given text.", | ||
}, | ||
] | ||
|
||
|
||
class Test(BaseModel): | ||
test: str | ||
|
||
class Config: | ||
schema_extra = { | ||
"example": { | ||
"test": "Hello" | ||
} | ||
} | ||
|
||
class ItemIn(BaseModel): | ||
comment_id: str | ||
comment_text: str | ||
question_type: str | ||
|
||
class Config: | ||
schema_extra = { | ||
"example": { | ||
"comment_id": "01", | ||
"comment_text": "Nurses were friendly. Parking was awful.", | ||
"question_type": "nonspecific", | ||
} | ||
} | ||
|
||
app = FastAPI() | ||
@validator("question_type") | ||
def question_type_validation(cls, v): | ||
if v not in ["what_good", "could_improve", "nonspecific"]: | ||
raise ValueError( | ||
"question_type must be one of what_good, could_improve, or nonspecific" | ||
) | ||
return v | ||
|
||
|
||
class ItemOut(BaseModel): | ||
comment_id: str | ||
comment_text: str | ||
labels: list | ||
|
||
@app.get('/') | ||
class Config: | ||
schema_extra = { | ||
"example": { | ||
"comment_id": "01", | ||
"comment_text": "Nurses were friendly. Parking was awful.", | ||
"labels": ["Staff manner & personal attributes", "Parking"], | ||
} | ||
} | ||
|
||
|
||
app = FastAPI( | ||
title="pxtextmining API", | ||
description=description, | ||
version="0.0.1", | ||
contact={ | ||
"name": "Patient Experience Qualitative Data Categorisation", | ||
"url": "https://cdu-data-science-team.github.io/PatientExperience-QDC/", | ||
"email": "[email protected]", | ||
}, | ||
license_info={ | ||
"name": "MIT License", | ||
"url": "https://github.com/CDU-data-science-team/pxtextmining/blob/main/LICENSE", | ||
}, | ||
openapi_tags=tags_metadata | ||
) | ||
|
||
|
||
@app.get("/", response_model=Test, tags=['index']) | ||
def index(): | ||
return {'Test': 'Hello'} | ||
return {"test": "Hello"} | ||
|
||
@app.post('/predict_multilabel') | ||
|
||
@app.post("/predict_multilabel", response_model=List[ItemOut], tags=['predict']) | ||
def predict(items: List[ItemIn]): | ||
"""Accepts comment ids and comment text as JSON in a POST request. Makes predictions using | ||
"""Accepts comment ids, comment text and question type as JSON in a POST request. Makes predictions using trained SVC model. | ||
Args: | ||
items (List[ItemIn]): JSON list of dictionaries with the following compulsory keys: `comment_id` (str) and `comment_text` (str). For example, `[{'comment_id': '1', 'comment_text': 'Thank you'}, {'comment_id': '2', 'comment_text': 'Food was cold'}]` | ||
items (List[ItemIn]): JSON list of dictionaries with the following compulsory keys: | ||
- `comment_id` (str) | ||
- `comment_text` (str) | ||
- `question_type` (str) | ||
The 'question_type' must be one of three values: 'nonspecific', 'what_good', and 'could_improve'. | ||
For example, `[{'comment_id': '1', 'comment_text': 'Thank you', 'question_type': 'what_good'}, | ||
{'comment_id': '2', 'comment_text': 'Food was cold', 'question_type': 'could_improve'}]` | ||
Returns: | ||
(dict): Dict containing two keys. `comments_labelled` is a list of dictionaries containing the received comment ids, comment text, and predicted labels. `comment_ids_failed` is a list of the comment_ids where the text was unable to be labelled, for example due to being an empty string, or a null value. | ||
(dict): Keys are: `comment_id`, `comment_text`, and predicted `labels`. | ||
""" | ||
with open('current_best_multilabel/svc_text_only.sav', 'rb') as model: | ||
loaded_model = pickle.load(model) | ||
|
||
# Process received data | ||
df = pd.DataFrame([i.dict() for i in items], dtype=str) | ||
df_newindex = df.set_index('comment_id') | ||
df_newindex.index.rename('Index', inplace = True) | ||
text_to_predict = df_newindex['comment_text'] | ||
preds_df = predict_multilabel_sklearn(text_to_predict, loaded_model) | ||
preds_df['comment_id'] = preds_df.index.astype(str) | ||
merged = pd.merge(df, preds_df, how='left', on='comment_id') | ||
merged['labels'] = merged['labels'].fillna('').apply(list) | ||
for i in merged['labels'].index: | ||
if len(merged['labels'].loc[i]) < 1: | ||
merged['labels'].loc[i].append('Labelling not possible') | ||
return_dict = merged[['comment_id', 'comment_text', 'labels']].to_dict(orient='records') | ||
df_newindex = df.set_index("comment_id") | ||
if df_newindex.index.duplicated().sum() != 0: | ||
raise ValueError("comment_id must all be unique values") | ||
df_newindex.index.rename("Comment ID", inplace=True) | ||
text_to_predict = df_newindex[["comment_text", "question_type"]] | ||
text_to_predict = text_to_predict.rename( | ||
columns={"comment_text": "FFT answer", "question_type": "FFT_q_standardised"} | ||
) | ||
# Make predictions | ||
model_path = os.path.join("api", "svc_minorcats_v5.sav") | ||
with open(model_path, "rb") as model: | ||
loaded_model = pickle.load(model) | ||
preds_df = predict_multilabel_sklearn( | ||
text_to_predict, loaded_model, additional_features=True | ||
) | ||
# Join predicted labels with received data | ||
preds_df["comment_id"] = preds_df.index.astype(str) | ||
merged = pd.merge(df, preds_df, how="left", on="comment_id") | ||
merged["labels"] = merged["labels"].fillna("").apply(list) | ||
for i in merged["labels"].index: | ||
if len(merged["labels"].loc[i]) < 1: | ||
merged["labels"].loc[i].append("Labelling not possible") | ||
return_dict = merged[["comment_id", "comment_text", "labels"]].to_dict( | ||
orient="records" | ||
) | ||
return return_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11" | ||
anyio==3.6.2 ; python_version >= "3.8" and python_version < "3.11" | ||
astunparse==1.6.3 ; python_version >= "3.8" and python_version < "3.11" | ||
attrs==22.2.0 ; python_version >= "3.8" and python_version < "3.11" | ||
blis==0.7.9 ; python_version >= "3.8" and python_version < "3.11" | ||
cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11" | ||
catalogue==2.0.8 ; python_version >= "3.8" and python_version < "3.11" | ||
certifi==2022.12.7 ; python_version >= "3.8" and python_version < "3.11" | ||
chardet==4.0.0 ; python_version >= "3.8" and python_version < "3.11" | ||
click==8.1.3 ; python_version >= "3.8" and python_version < "3.11" | ||
colorama==0.4.6 ; python_version >= "3.8" and python_version < "3.11" and sys_platform == "win32" or python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows" | ||
confection==0.0.4 ; python_version >= "3.8" and python_version < "3.11" | ||
contourpy==1.0.7 ; python_version >= "3.8" and python_version < "3.11" | ||
cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11" | ||
cymem==2.0.7 ; python_version >= "3.8" and python_version < "3.11" | ||
exceptiongroup==1.1.1 ; python_version >= "3.8" and python_version < "3.11" | ||
fastapi==0.94.1 ; python_version >= "3.8" and python_version < "3.11" | ||
filelock==3.10.7 ; python_version >= "3.8" and python_version < "3.11" | ||
flatbuffers==23.3.3 ; python_version >= "3.8" and python_version < "3.11" | ||
fonttools==4.39.2 ; python_version >= "3.8" and python_version < "3.11" | ||
gast==0.4.0 ; python_version >= "3.8" and python_version < "3.11" | ||
google-auth-oauthlib==0.4.6 ; python_version >= "3.8" and python_version < "3.11" | ||
google-auth==2.16.3 ; python_version >= "3.8" and python_version < "3.11" | ||
google-pasta==0.2.0 ; python_version >= "3.8" and python_version < "3.11" | ||
grpcio==1.53.0 ; python_version >= "3.8" and python_version < "3.11" | ||
h11==0.14.0 ; python_version >= "3.8" and python_version < "3.11" | ||
h5py==3.8.0 ; python_version >= "3.8" and python_version < "3.11" | ||
httpcore==0.16.3 ; python_version >= "3.8" and python_version < "3.11" | ||
httpx==0.23.3 ; python_version >= "3.8" and python_version < "3.11" | ||
huggingface-hub==0.13.3 ; python_version >= "3.8" and python_version < "3.11" | ||
idna==2.10 ; python_version >= "3.8" and python_version < "3.11" | ||
importlib-metadata==6.1.0 ; python_version >= "3.8" and python_version < "3.10" | ||
importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.10" | ||
iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "3.11" | ||
jax==0.4.7 ; python_version >= "3.8" and python_version < "3.11" | ||
jinja2==3.1.2 ; python_version >= "3.8" and python_version < "3.11" | ||
joblib==1.2.0 ; python_version >= "3.8" and python_version < "3.11" | ||
keras==2.12.0 ; python_version >= "3.8" and python_version < "3.11" | ||
kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11" | ||
langcodes==3.3.0 ; python_version >= "3.8" and python_version < "3.11" | ||
libclang==16.0.0 ; python_version >= "3.8" and python_version < "3.11" | ||
markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11" | ||
markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11" | ||
matplotlib==3.7.1 ; python_version >= "3.8" and python_version < "3.11" | ||
ml-dtypes==0.0.4 ; python_version >= "3.8" and python_version < "3.11" | ||
murmurhash==1.0.9 ; python_version >= "3.8" and python_version < "3.11" | ||
numpy==1.23.5 ; python_version < "3.11" and python_version >= "3.8" | ||
oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" | ||
opt-einsum==3.3.0 ; python_version >= "3.8" and python_version < "3.11" | ||
packaging==23.0 ; python_version >= "3.8" and python_version < "3.11" | ||
pandas==1.5.3 ; python_version >= "3.8" and python_version < "3.11" | ||
pathy==0.10.1 ; python_version >= "3.8" and python_version < "3.11" | ||
pillow==9.4.0 ; python_version >= "3.8" and python_version < "3.11" | ||
pluggy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" | ||
preshed==3.0.8 ; python_version >= "3.8" and python_version < "3.11" | ||
protobuf==4.22.1 ; python_version >= "3.8" and python_version < "3.11" | ||
pyasn1-modules==0.2.8 ; python_version >= "3.8" and python_version < "3.11" | ||
pyasn1==0.4.8 ; python_version >= "3.8" and python_version < "3.11" | ||
pydantic==1.10.7 ; python_version >= "3.8" and python_version < "3.11" | ||
pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11" | ||
pytest==7.2.2 ; python_version >= "3.8" and python_version < "3.11" | ||
python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11" | ||
pytz==2023.2 ; python_version >= "3.8" and python_version < "3.11" | ||
pyyaml==6.0 ; python_version >= "3.8" and python_version < "3.11" | ||
regex==2023.3.23 ; python_version >= "3.8" and python_version < "3.11" | ||
requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11" | ||
requests==2.25.1 ; python_version >= "3.8" and python_version < "3.11" | ||
rfc3986[idna2008]==1.5.0 ; python_version >= "3.8" and python_version < "3.11" | ||
rsa==4.9 ; python_version >= "3.8" and python_version < "3.11" | ||
scikit-learn==1.0.2 ; python_version >= "3.8" and python_version < "3.11" | ||
scipy==1.10.1 ; python_version >= "3.8" and python_version < "3.11" | ||
setuptools-scm==7.1.0 ; python_version >= "3.8" and python_version < "3.11" | ||
setuptools==67.6.1 ; python_version >= "3.8" and python_version < "3.11" | ||
six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" | ||
smart-open==6.3.0 ; python_version >= "3.8" and python_version < "3.11" | ||
sniffio==1.3.0 ; python_version >= "3.8" and python_version < "3.11" | ||
spacy-legacy==3.0.12 ; python_version >= "3.8" and python_version < "3.11" | ||
spacy-loggers==1.0.4 ; python_version >= "3.8" and python_version < "3.11" | ||
spacy==3.5.1 ; python_version >= "3.8" and python_version < "3.11" | ||
srsly==2.4.6 ; python_version >= "3.8" and python_version < "3.11" | ||
starlette==0.26.1 ; python_version >= "3.8" and python_version < "3.11" | ||
tensorboard-data-server==0.7.0 ; python_version >= "3.8" and python_version < "3.11" | ||
tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" | ||
tensorboard==2.12.0 ; python_version >= "3.8" and python_version < "3.11" | ||
tensorflow-estimator==2.12.0 ; python_version >= "3.8" and python_version < "3.11" | ||
tensorflow-io-gcs-filesystem==0.31.0 ; python_version >= "3.8" and python_version < "3.11" and platform_machine != "arm64" or python_version >= "3.8" and python_version < "3.11" and platform_system != "Darwin" | ||
tensorflow==2.12.0 ; python_version >= "3.8" and python_version < "3.11" | ||
termcolor==2.2.0 ; python_version >= "3.8" and python_version < "3.11" | ||
thinc==8.1.9 ; python_version >= "3.8" and python_version < "3.11" | ||
threadpoolctl==3.1.0 ; python_version >= "3.8" and python_version < "3.11" | ||
tokenizers==0.13.2 ; python_version >= "3.8" and python_version < "3.11" | ||
tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" | ||
tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11" | ||
transformers==4.27.3 ; python_version >= "3.8" and python_version < "3.11" | ||
typer==0.7.0 ; python_version >= "3.8" and python_version < "3.11" | ||
typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11" | ||
urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11" | ||
uvicorn==0.20.0 ; python_version >= "3.8" and python_version < "3.11" | ||
wasabi==1.1.1 ; python_version >= "3.8" and python_version < "3.11" | ||
werkzeug==2.2.3 ; python_version >= "3.8" and python_version < "3.11" | ||
wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11" | ||
wrapt==1.14.1 ; python_version >= "3.8" and python_version < "3.11" | ||
zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10" | ||
pxtextmining==0.5.1 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.