-
Notifications
You must be signed in to change notification settings - Fork 1
/
flask_server.py
executable file
·103 lines (83 loc) · 4.04 KB
/
flask_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# File: flask_server.py
# Author: @MichaelHannalla
# Project: Trurapid COVID-19 Strips Detection Server with Python
# Description: Main Python file for the Flask server that loads all deep learning models, receives the input through the browser, then performs classification,
# and then reports the result of the submitted sample back to the browser.
# Please reference my GitHub account if you intend to use this project or part thereof for other purposes.
import imghdr
import os
import cv2
from utils import classify_crop, get_strip_crop, get_image_data, classical_classify_crop
from utils import positive, negative, labels
import torch
from datetime import datetime
from PIL import Image
from flask import Flask, render_template, request, redirect, url_for, abort, flash, session
from flask_socketio import SocketIO
from werkzeug.utils import secure_filename
from detecto.core import Dataset, Model
# Specifying flask objects and configs
app = Flask(__name__)
# NOTE(review): the session-signing key is hard-coded in source. Anyone who can
# read this file can forge session cookies; consider loading it from the
# environment instead. Left unchanged here so existing sessions stay valid.
app.secret_key = b'_5#y2L"F4Q8z\n\xec]/'
# Upper bound on the request body size: 4096 * 4096 = 16,777,216 bytes.
app.config['MAX_CONTENT_LENGTH'] = 4096 * 4096
# Every entry carries its leading dot so it can be compared directly with the
# extension returned by os.path.splitext().
app.config['UPLOAD_EXTENSIONS'] = ['.jpg', '.png', '.gif', '.jpeg']
app.config['UPLOAD_PATH'] = 'uploads'
os.environ['FLASK_APP'] = 'covid_detection_server'
os.environ['FLASK_ENV'] = 'development'
socketio = SocketIO(app)

# Load the deep learning models once at import time so every request reuses
# the same in-memory detector.
print("SERVER LOADING, PLEASE WAIT....")
strip_detection_model = Model.load('models/strip_detector_weights_pass2.pth', labels)
#strip_classifier_model = torch.load('models/strip_classifier_mini.pth') # not used now, just being loaded in early versions of code
#strip_classifier_model.eval()
print("SERVER READY")
def send_string_output(outgoing_string):
    """Replace any pending flash messages with ``outgoing_string``.

    The session's ``'_flashes'`` entry is removed first, so the message
    flashed here is the only one queued afterwards.
    """
    flashes_key = '_flashes'
    session.pop(flashes_key, None)
    flash(outgoing_string)
def send_output(result):
    """Flash the timestamped test verdict for ``result``.

    The session's ``'_flashes'`` entry is cleared first. ``result`` is then
    compared against the ``positive`` and ``negative`` markers imported from
    ``utils``; each match flashes the corresponding message stamped with the
    current local time (HH:MM:SS). If ``result`` equals neither marker, no
    message is flashed.
    """
    session.pop('_flashes', None)
    current_time = datetime.now().strftime("%H:%M:%S")

    # Both outcomes are checked independently, one after the other.
    verdict_templates = (
        (positive, "Test Result at {}: POSITIVE"),
        (negative, "Test Result at {}: NEGATIVE"),
    )
    for marker, template in verdict_templates:
        if result == marker:
            flash(template.format(current_time))
# Function for input validation
def validate_image(stream):
    """Sniff the image type of ``stream`` and return it as a file extension.

    The first 512 bytes are read and handed to ``imghdr.what``; the stream is
    rewound to position 0 before returning so it can be read again.

    Returns:
        ``'.<type>'`` for a recognised image (``'.jpg'`` is used for JPEG),
        or ``None`` when the header is not recognised.
    """
    head_bytes = stream.read(512)
    stream.seek(0)

    detected = imghdr.what(None, head_bytes)
    if detected:
        # imghdr reports 'jpeg'; the rest of the app uses the '.jpg' spelling.
        return '.' + ('jpg' if detected == 'jpeg' else detected)
    return None
# Web-page rendering
@app.route('/')
def index():
    """Render the ``index.html`` template for the root URL."""
    page = render_template('index.html')
    return page
# HTTP routing function
@app.route('/', methods=['POST'])
def upload_files():
    """Handle a POSTed strip image: detect the strip, classify it, flash the result.

    The file is taken from the ``'file'`` form field. When its sanitised
    filename is non-empty, the image is decoded with ``get_image_data``, the
    strip region is cropped with ``get_strip_crop`` using the module-level
    ``strip_detection_model``, the crop is classified with
    ``classical_classify_crop``, and the verdict is flashed via
    ``send_output``. Any ordinary exception raised along the way is logged
    with its traceback and reported to the user as a generic flash message.

    Returns:
        A redirect to the ``index`` page in every case.
    """
    uploaded_file = request.files['file']  # Request file
    # secure_filename() sanitises the client-supplied name; an empty result
    # means there is nothing usable to process.
    filename = secure_filename(uploaded_file.filename)
    if filename != '':
        # NOTE(review): extension/content validation (UPLOAD_EXTENSIONS +
        # validate_image, aborting with 400 on mismatch) is currently disabled,
        # so the upload goes to the detector unchecked.
        try:
            # Perform the detection on the incoming stream of image
            send_string_output("Recieved an input, proceeding to processing")
            img_cv = get_image_data(uploaded_file)  # Get image from flask server
            strip_crop = get_strip_crop(img_cv, strip_detection_model)  # Get area of interest (strip area)
            result = classical_classify_crop(strip_crop)  # Classify the sample
            send_output(result)  # Send the output to flask server
        except Exception:
            # Catch Exception rather than everything, so SystemExit and
            # KeyboardInterrupt still propagate, and keep the traceback in
            # the server log for diagnosis.
            app.logger.exception("Failed to process uploaded image")
            flash("Exception caught during runtime, check for invalid inputs")
    return redirect(url_for('index'))
if __name__ == "__main__":
    # Serve through the Socket.IO wrapper on all interfaces, port 5000.
    # (Disabled alternative: plain Flask server via app.run('0.0.0.0', debug=True).)
    socketio.run(app, port=5000, host='0.0.0.0')