Skip to content

Commit

Permalink
Merge pull request #88 from CDU-data-science-team/0.5.1
Browse files Browse the repository at this point in the history
0.5.1
  • Loading branch information
yiwen-h authored Apr 27, 2023
2 parents d2bc1fd + 1f0bc6b commit f481302
Show file tree
Hide file tree
Showing 30 changed files with 2,225 additions and 1,478 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ dist/
.vscode/
datasets/hidden/*
test_multilabel/*
.env
141 changes: 118 additions & 23 deletions api/api.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,141 @@
import os
import pickle
from typing import List

import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from pydantic import BaseModel, validator

from pxtextmining.factories.factory_predict_unlabelled_text import \
predict_multilabel_sklearn
from pxtextmining.factories.factory_predict_unlabelled_text import (
predict_multilabel_sklearn,
)

# Free-text summary of the service; passed to FastAPI(description=...).
description = """
This API is for classifying patient experience qualitative data,
utilising the models trained as part of the pxtextmining project.
"""

# One entry per tag referenced by the route decorators; passed to
# FastAPI(openapi_tags=...).
tags_metadata = [
    {
        "name": "index",
        "description": "Basic page to test if API is working.",
    },
    {
        "name": "predict",
        "description": "Generate multilabel predictions for given text.",
    },
]


class Test(BaseModel):
    # Response schema for the index route: one string field used to confirm
    # that the API is reachable.
    # (Kept as comments rather than a docstring so the generated schema
    # description is not altered.)
    test: str

    class Config:
        # Example payload attached to the generated schema for this model.
        schema_extra = {"example": {"test": "Hello"}}

class ItemIn(BaseModel):
comment_id: str
comment_text: str
question_type: str

class Config:
schema_extra = {
"example": {
"comment_id": "01",
"comment_text": "Nurses were friendly. Parking was awful.",
"question_type": "nonspecific",
}
}

app = FastAPI()
@validator("question_type")
def question_type_validation(cls, v):
if v not in ["what_good", "could_improve", "nonspecific"]:
raise ValueError(
"question_type must be one of what_good, could_improve, or nonspecific"
)
return v


class ItemOut(BaseModel):
comment_id: str
comment_text: str
labels: list

@app.get('/')
class Config:
schema_extra = {
"example": {
"comment_id": "01",
"comment_text": "Nurses were friendly. Parking was awful.",
"labels": ["Staff manner & personal attributes", "Parking"],
}
}


# Application instance. All keyword arguments are descriptive metadata handed
# to FastAPI; none of them affect the prediction logic.
app = FastAPI(
    title="pxtextmining API",
    description=description,
    version="0.0.1",
    contact={
        "name": "Patient Experience Qualitative Data Categorisation",
        "url": "https://cdu-data-science-team.github.io/PatientExperience-QDC/",
        # NOTE(review): this value looks like an email-obfuscation placeholder
        # rather than a real address — confirm against the repository.
        "email": "[email protected]",
    },
    license_info={
        "name": "MIT License",
        "url": "https://github.com/CDU-data-science-team/pxtextmining/blob/main/LICENSE",
    },
    openapi_tags=tags_metadata,
)


@app.get("/", response_model=Test, tags=['index'])
def index():
return {'Test': 'Hello'}
return {"test": "Hello"}

@app.post('/predict_multilabel')

@app.post("/predict_multilabel", response_model=List[ItemOut], tags=['predict'])
def predict(items: List[ItemIn]):
"""Accepts comment ids and comment text as JSON in a POST request. Makes predictions using
"""Accepts comment ids, comment text and question type as JSON in a POST request. Makes predictions using trained SVC model.
Args:
items (List[ItemIn]): JSON list of dictionaries with the following compulsory keys: `comment_id` (str) and `comment_text` (str). For example, `[{'comment_id': '1', 'comment_text': 'Thank you'}, {'comment_id': '2', 'comment_text': 'Food was cold'}]`
items (List[ItemIn]): JSON list of dictionaries with the following compulsory keys:
- `comment_id` (str)
- `comment_text` (str)
- `question_type` (str)
The `question_type` must be one of three values: 'nonspecific', 'what_good', or 'could_improve'.
For example, `[{'comment_id': '1', 'comment_text': 'Thank you', 'question_type': 'what_good'},
{'comment_id': '2', 'comment_text': 'Food was cold', 'question_type': 'could_improve'}]`
Returns:
(dict): Dict containing two keys. `comments_labelled` is a list of dictionaries containing the received comment ids, comment text, and predicted labels. `comment_ids_failed` is a list of the comment_ids where the text was unable to be labelled, for example due to being an empty string, or a null value.
(list): One dict per received comment, with keys `comment_id`, `comment_text`, and predicted `labels`.
"""
with open('current_best_multilabel/svc_text_only.sav', 'rb') as model:
loaded_model = pickle.load(model)

# Process received data
df = pd.DataFrame([i.dict() for i in items], dtype=str)
df_newindex = df.set_index('comment_id')
df_newindex.index.rename('Index', inplace = True)
text_to_predict = df_newindex['comment_text']
preds_df = predict_multilabel_sklearn(text_to_predict, loaded_model)
preds_df['comment_id'] = preds_df.index.astype(str)
merged = pd.merge(df, preds_df, how='left', on='comment_id')
merged['labels'] = merged['labels'].fillna('').apply(list)
for i in merged['labels'].index:
if len(merged['labels'].loc[i]) < 1:
merged['labels'].loc[i].append('Labelling not possible')
return_dict = merged[['comment_id', 'comment_text', 'labels']].to_dict(orient='records')
df_newindex = df.set_index("comment_id")
if df_newindex.index.duplicated().sum() != 0:
raise ValueError("comment_id must all be unique values")
df_newindex.index.rename("Comment ID", inplace=True)
text_to_predict = df_newindex[["comment_text", "question_type"]]
text_to_predict = text_to_predict.rename(
columns={"comment_text": "FFT answer", "question_type": "FFT_q_standardised"}
)
# Make predictions
model_path = os.path.join("api", "svc_minorcats_v5.sav")
with open(model_path, "rb") as model:
loaded_model = pickle.load(model)
preds_df = predict_multilabel_sklearn(
text_to_predict, loaded_model, additional_features=True
)
# Join predicted labels with received data
preds_df["comment_id"] = preds_df.index.astype(str)
merged = pd.merge(df, preds_df, how="left", on="comment_id")
merged["labels"] = merged["labels"].fillna("").apply(list)
for i in merged["labels"].index:
if len(merged["labels"].loc[i]) < 1:
merged["labels"].loc[i].append("Labelling not possible")
return_dict = merged[["comment_id", "comment_text", "labels"]].to_dict(
orient="records"
)
return return_dict
104 changes: 104 additions & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11"
anyio==3.6.2 ; python_version >= "3.8" and python_version < "3.11"
astunparse==1.6.3 ; python_version >= "3.8" and python_version < "3.11"
attrs==22.2.0 ; python_version >= "3.8" and python_version < "3.11"
blis==0.7.9 ; python_version >= "3.8" and python_version < "3.11"
cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11"
catalogue==2.0.8 ; python_version >= "3.8" and python_version < "3.11"
certifi==2022.12.7 ; python_version >= "3.8" and python_version < "3.11"
chardet==4.0.0 ; python_version >= "3.8" and python_version < "3.11"
click==8.1.3 ; python_version >= "3.8" and python_version < "3.11"
colorama==0.4.6 ; python_version >= "3.8" and python_version < "3.11" and sys_platform == "win32" or python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
confection==0.0.4 ; python_version >= "3.8" and python_version < "3.11"
contourpy==1.0.7 ; python_version >= "3.8" and python_version < "3.11"
cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11"
cymem==2.0.7 ; python_version >= "3.8" and python_version < "3.11"
exceptiongroup==1.1.1 ; python_version >= "3.8" and python_version < "3.11"
fastapi==0.94.1 ; python_version >= "3.8" and python_version < "3.11"
filelock==3.10.7 ; python_version >= "3.8" and python_version < "3.11"
flatbuffers==23.3.3 ; python_version >= "3.8" and python_version < "3.11"
fonttools==4.39.2 ; python_version >= "3.8" and python_version < "3.11"
gast==0.4.0 ; python_version >= "3.8" and python_version < "3.11"
google-auth-oauthlib==0.4.6 ; python_version >= "3.8" and python_version < "3.11"
google-auth==2.16.3 ; python_version >= "3.8" and python_version < "3.11"
google-pasta==0.2.0 ; python_version >= "3.8" and python_version < "3.11"
grpcio==1.53.0 ; python_version >= "3.8" and python_version < "3.11"
h11==0.14.0 ; python_version >= "3.8" and python_version < "3.11"
h5py==3.8.0 ; python_version >= "3.8" and python_version < "3.11"
httpcore==0.16.3 ; python_version >= "3.8" and python_version < "3.11"
httpx==0.23.3 ; python_version >= "3.8" and python_version < "3.11"
huggingface-hub==0.13.3 ; python_version >= "3.8" and python_version < "3.11"
idna==2.10 ; python_version >= "3.8" and python_version < "3.11"
importlib-metadata==6.1.0 ; python_version >= "3.8" and python_version < "3.10"
importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.10"
iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "3.11"
jax==0.4.7 ; python_version >= "3.8" and python_version < "3.11"
jinja2==3.1.2 ; python_version >= "3.8" and python_version < "3.11"
joblib==1.2.0 ; python_version >= "3.8" and python_version < "3.11"
keras==2.12.0 ; python_version >= "3.8" and python_version < "3.11"
kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11"
langcodes==3.3.0 ; python_version >= "3.8" and python_version < "3.11"
libclang==16.0.0 ; python_version >= "3.8" and python_version < "3.11"
markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11"
markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11"
matplotlib==3.7.1 ; python_version >= "3.8" and python_version < "3.11"
ml-dtypes==0.0.4 ; python_version >= "3.8" and python_version < "3.11"
murmurhash==1.0.9 ; python_version >= "3.8" and python_version < "3.11"
numpy==1.23.5 ; python_version < "3.11" and python_version >= "3.8"
oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11"
opt-einsum==3.3.0 ; python_version >= "3.8" and python_version < "3.11"
packaging==23.0 ; python_version >= "3.8" and python_version < "3.11"
pandas==1.5.3 ; python_version >= "3.8" and python_version < "3.11"
pathy==0.10.1 ; python_version >= "3.8" and python_version < "3.11"
pillow==9.4.0 ; python_version >= "3.8" and python_version < "3.11"
pluggy==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
preshed==3.0.8 ; python_version >= "3.8" and python_version < "3.11"
protobuf==4.22.1 ; python_version >= "3.8" and python_version < "3.11"
pyasn1-modules==0.2.8 ; python_version >= "3.8" and python_version < "3.11"
pyasn1==0.4.8 ; python_version >= "3.8" and python_version < "3.11"
pydantic==1.10.7 ; python_version >= "3.8" and python_version < "3.11"
pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11"
pytest==7.2.2 ; python_version >= "3.8" and python_version < "3.11"
python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11"
pytz==2023.2 ; python_version >= "3.8" and python_version < "3.11"
pyyaml==6.0 ; python_version >= "3.8" and python_version < "3.11"
regex==2023.3.23 ; python_version >= "3.8" and python_version < "3.11"
requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11"
requests==2.25.1 ; python_version >= "3.8" and python_version < "3.11"
rfc3986[idna2008]==1.5.0 ; python_version >= "3.8" and python_version < "3.11"
rsa==4.9 ; python_version >= "3.8" and python_version < "3.11"
scikit-learn==1.0.2 ; python_version >= "3.8" and python_version < "3.11"
scipy==1.10.1 ; python_version >= "3.8" and python_version < "3.11"
setuptools-scm==7.1.0 ; python_version >= "3.8" and python_version < "3.11"
setuptools==67.6.1 ; python_version >= "3.8" and python_version < "3.11"
six==1.16.0 ; python_version >= "3.8" and python_version < "3.11"
smart-open==6.3.0 ; python_version >= "3.8" and python_version < "3.11"
sniffio==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
spacy-legacy==3.0.12 ; python_version >= "3.8" and python_version < "3.11"
spacy-loggers==1.0.4 ; python_version >= "3.8" and python_version < "3.11"
spacy==3.5.1 ; python_version >= "3.8" and python_version < "3.11"
srsly==2.4.6 ; python_version >= "3.8" and python_version < "3.11"
starlette==0.26.1 ; python_version >= "3.8" and python_version < "3.11"
tensorboard-data-server==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11"
tensorboard==2.12.0 ; python_version >= "3.8" and python_version < "3.11"
tensorflow-estimator==2.12.0 ; python_version >= "3.8" and python_version < "3.11"
tensorflow-io-gcs-filesystem==0.31.0 ; python_version >= "3.8" and python_version < "3.11" and platform_machine != "arm64" or python_version >= "3.8" and python_version < "3.11" and platform_system != "Darwin"
tensorflow==2.12.0 ; python_version >= "3.8" and python_version < "3.11"
termcolor==2.2.0 ; python_version >= "3.8" and python_version < "3.11"
thinc==8.1.9 ; python_version >= "3.8" and python_version < "3.11"
threadpoolctl==3.1.0 ; python_version >= "3.8" and python_version < "3.11"
tokenizers==0.13.2 ; python_version >= "3.8" and python_version < "3.11"
tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11"
tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11"
transformers==4.27.3 ; python_version >= "3.8" and python_version < "3.11"
typer==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11"
urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11"
uvicorn==0.20.0 ; python_version >= "3.8" and python_version < "3.11"
wasabi==1.1.1 ; python_version >= "3.8" and python_version < "3.11"
werkzeug==2.2.3 ; python_version >= "3.8" and python_version < "3.11"
wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11"
wrapt==1.14.1 ; python_version >= "3.8" and python_version < "3.11"
zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10"
pxtextmining==0.5.1
Binary file added api/svc_minorcats_v5.sav
Binary file not shown.
19 changes: 14 additions & 5 deletions api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,24 @@ def test_json_predictions(json):
response = requests.post("http://127.0.0.1:8000/predict_multilabel", json=json)
return response


if __name__ == "__main__":
df = pd.read_csv('datasets/hidden/API_test.csv')
df = df[['row_id', 'comment_txt']].copy().set_index('row_id')[:20]
df = pd.read_csv("datasets/hidden/API_test.csv")
df = df[["row_id", "comment_txt"]].copy().set_index("row_id")[:20]
js = []
for i in df.index:
js.append({'comment_id': str(i), 'comment_text': df.loc[i]['comment_txt']})
print('The JSON that was sent looks like:')
js.append(
{
"comment_id": str(i),
"comment_text": df.loc[i]["comment_txt"],
"question_type": "nonspecific",
}
)
print("The JSON that was sent looks like:")
print(js[:5])
print("The JSON that was sent looks like:")
print(js[:5])
print('The JSON that is returned is:')
print("The JSON that is returned is:")
returned_json = test_json_predictions(js).json()
print(returned_json)
# json_object = json.dumps(returned_json, indent=4)
Expand Down
Loading

0 comments on commit f481302

Please sign in to comment.