-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: Add functionality to calculate intra-cluster distances and compare them between original and fine-tuned models #49
Changes from all commits
8631b6f
b5c55c0
1e4ac9c
bcffd5d
20096e8
1010427
51225f2
0b452e9
856476e
70f67b6
0971d16
bc89ae8
8afe690
19a57cd
1b7c685
02bd51e
0a207e1
d3a2916
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,3 +143,4 @@ __pyc | |
FaceRec/static/Images/uploads/* | ||
Images/dbImages/* | ||
Images/Faces/* | ||
Images/ | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
from datetime import datetime | ||
from io import BytesIO | ||
from typing import List | ||
|
||
from tensorflow.keras.models import load_model | ||
from bson import ObjectId | ||
from deepface import DeepFace | ||
from dotenv import load_dotenv | ||
Comment on lines
8
to
14
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWImport statements should be grouped in the following order:
from datetime import datetime
from io import BytesIO
from typing import List
from bson import ObjectId
from dotenv import load_dotenv
from tensorflow.keras.models import load_model
from deepface import DeepFace |
||
|
@@ -20,6 +20,8 @@ | |
from matplotlib import pyplot as plt | ||
from PIL import Image | ||
from pydantic import BaseModel | ||
import numpy as np | ||
from keras.preprocessing import image | ||
|
||
from API.database import Database | ||
from API.utils import init_logging_config | ||
Comment on lines
20
to
27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWConsider organizing imports alphabetically for easier readability. It's also recommended to group standard library imports, third-party imports, and local imports separately. from PIL import Image
import numpy as np
from keras.preprocessing import image
from matplotlib import pyplot as plt
from pydantic import BaseModel
from API.database import Database
from API.utils import init_logging_config |
||
|
@@ -36,7 +38,7 @@ | |
|
||
collection = 'faceEntries' | ||
collection2 = 'ImageDB' | ||
|
||
collection3 = 'VectorDB' | ||
|
||
# Models for the data to be sent and received by the server | ||
class Employee(BaseModel): | ||
|
@@ -53,6 +55,84 @@ | |
Department: str | ||
Images: list[str] | ||
|
||
def load_and_preprocess_image(img_path, target_size=(160, 160)): | ||
|
||
img = image.load_img(img_path, target_size=target_size) | ||
img_array = image.img_to_array(img) | ||
img_array = np.expand_dims(img_array, axis=0) | ||
img_array /= 255.0 | ||
return img_array | ||
|
||
def calculate_embeddings(image_filename): | ||
|
||
""" | ||
Calculate embeddings for the provided image. | ||
|
||
Args: | ||
image_filename (str): The path to the image file. | ||
|
||
Returns: | ||
list: A list of embeddings for the image. | ||
""" | ||
|
||
face_image_data = DeepFace.extract_faces( | ||
image_filename, enforce_detection=False, | ||
) | ||
new_image_path = f'Images/Faces/tmp.jpg' | ||
|
||
if face_image_data[0]['face'] is not None: | ||
plt.imsave(new_image_path, face_image_data[0]['face']) | ||
|
||
img_array = load_and_preprocess_image(new_image_path) | ||
model=load_model('Model/embedding_trial3.h5') | ||
embedding = model.predict(img_array)[0] | ||
embedding_list = embedding.tolist() | ||
logging.info(f'Embedding created') | ||
|
||
return embedding_list | ||
|
||
@router.post('/recalculate_embeddings') | ||
async def recalculate_embeddings(): | ||
""" | ||
Recalculate embeddings for all the images in the database. | ||
|
||
Returns: | ||
dict: A dictionary with a success message. | ||
|
||
Raises: | ||
None | ||
""" | ||
logging.info('Recalculating embeddings') | ||
employees_mongo = client2.find(collection2) | ||
for employee in employees_mongo: | ||
print(employee, type(employee)) | ||
Check failure Code scanning / CodeQL Clear-text logging of sensitive information High
This expression logs
sensitive data (private) Error loading related location Loading This expression logs sensitive data (private) Error loading related location Loading |
||
embeddings = [] | ||
|
||
# In the initial version, the images were stored in the 'Image' field | ||
if 'Images' in employee: | ||
images = employee['Images'] | ||
else: | ||
images = [employee['Image']] | ||
|
||
for encoded_image in images: | ||
|
||
pil_image = Image.open(BytesIO(base64.b64decode(encoded_image))) | ||
image_filename = f'{employee["Name"]}.png' | ||
pil_image.save(image_filename) | ||
logging.debug(f'Image saved {employee["Name"]}') | ||
Check failure Code scanning / CodeQL Clear-text logging of sensitive information High
This expression logs
sensitive data (private) Error loading related location Loading This expression logs sensitive data (private) Error loading related location Loading |
||
embeddings.append(calculate_embeddings(image_filename)) | ||
# os.remove(image_filename) | ||
|
||
logging.debug(f'About to update Embeddings: {embeddings}') | ||
# Store the data in the database | ||
client2.update_one( | ||
collection2, | ||
{'EmployeeCode': employee['EmployeeCode']}, | ||
{'$set': {'embeddings': embeddings, 'Images': images}}, | ||
) | ||
|
||
return {'message': 'Embeddings Recalculated successfully'} | ||
|
||
|
||
# To create new entries of employee | ||
@router.post('/create_new_faceEntry') | ||
|
@@ -74,7 +154,7 @@ | |
'\r\n', | ||
'', | ||
).replace('\n', '') | ||
EmployeeCode = Employee.EmployeeCode.replace('\r\n', '').replace('\n', '') | ||
EmployeeCode = Employee.EmployeeCode | ||
gender = Employee.gender.replace('\r\n', '').replace('\n', '') | ||
Department = Employee.Department.replace('\r\n', '').replace('\n', '') | ||
encoded_images = Employee.Images | ||
Comment on lines
154
to
160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWConsider simplifying and enhancing code readability by removing unnecessary EmployeeCode = Employee.EmployeeCode
gender = Employee.gender
Department = Employee.Department
encoded_images = Employee.Images |
||
|
@@ -88,17 +168,13 @@ | |
image_filename = f'{Name}.png' | ||
pil_image.save(image_filename) | ||
pil_image.save(fr'Images\dbImages\{Name}.jpg') | ||
face_image_data = DeepFace.extract_faces( | ||
image_filename, detector_backend='mtcnn', enforce_detection=False, | ||
) | ||
plt.imsave(f'Images/Faces/{Name}.jpg', face_image_data[0]['face']) | ||
logging.info(f'Face saved {Name}') | ||
embedding = DeepFace.represent( | ||
image_filename, model_name='Facenet512', detector_backend='mtcnn', | ||
) | ||
embeddings.append(embedding) | ||
logging.info(f'Embedding created Embeddings for {Name}') | ||
os.remove(image_filename) | ||
# embedding = DeepFace.represent( | ||
# image_filename, model_name='Facenet512', detector_backend='mtcnn', | ||
# ) | ||
|
||
embeddings.append(calculate_embeddings(image_filename)) | ||
# os.remove(image_filename) | ||
|
||
logging.debug(f'About to insert Embeddings: {embeddings}') | ||
# Store the data in the database | ||
Comment on lines
168
to
180
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWConsider creating a separate function for calculating embeddings to improve code readability and maintainability. def calculate_embeddings(image_filename):
return DeepFace.represent(
image_filename, model_name='Facenet512', detector_backend='mtcnn',
) This allows for easier testing and potential reuse in the future. |
||
|
@@ -128,8 +204,8 @@ | |
list[Employee]: A list of Employee objects containing employee information. | ||
""" | ||
logging.info('Displaying all employees') | ||
employees_mongo = client.find(collection) | ||
employees_mongo = client2.find(collection2) | ||
logging.info(f'Employees found {employees_mongo}') | ||
Check failure Code scanning / CodeQL Clear-text logging of sensitive information High
This expression logs
sensitive data (private) Error loading related location Loading |
||
employees = [ | ||
Employee( | ||
Comment on lines
204
to
210
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWConsider adding more descriptive variable names for clarity. Utilize type hints for better code readability. employees_mongo = client.find(collection) # consider renaming to employees_mongo = client.find_employee_data(collection) |
||
EmployeeCode=int(employee.get('EmployeeCode', 0)), | ||
|
@@ -162,8 +238,8 @@ | |
logging.debug(f'Display information for {EmployeeCode}') | ||
try: | ||
logging.debug(f'Start {EmployeeCode}') | ||
items = client.find_one( | ||
collection, | ||
items = client2.find_one( | ||
collection2, | ||
filter={'EmployeeCode': EmployeeCode}, | ||
projection={ | ||
'Name': True, | ||
Comment on lines
238
to
245
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEW
logger.debug(f'Display information for {EmployeeCode}')
try:
logger.info(f'Start {EmployeeCode}')
items = client2.find_one(
collection2,
filter={'EmployeeCode': EmployeeCode},
projection={'Name': True}
) |
||
|
@@ -210,8 +286,8 @@ | |
""" | ||
logging.debug(f'Updating for EmployeeCode: {EmployeeCode}') | ||
try: | ||
user_id = client.find_one( | ||
collection, {'EmployeeCode': EmployeeCode}, projection={'_id': True}, | ||
user_id = client2.find_one( | ||
collection2, {'EmployeeCode': EmployeeCode}, projection={'_id': True}, | ||
) | ||
print(user_id) | ||
if not user_id: | ||
Comment on lines
286
to
293
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWIt's good practice to use meaningful variable names. Consider renaming previous_client = client
new_client = client2
user_id = new_client.find_one(
new_collection, {'EmployeeCode': EmployeeCode}, projection={'_id': True},
) |
||
|
@@ -229,20 +305,19 @@ | |
image_filename = f'{Employee.Name}.png' | ||
pil_image.save(image_filename) | ||
logging.debug(f'Image saved {Employee.Name}') | ||
face_image_data = DeepFace.extract_faces( | ||
image_filename, detector_backend='mtcnn', enforce_detection=False, | ||
) | ||
embedding = DeepFace.represent( | ||
image_filename, model_name='Facenet', detector_backend='mtcnn', | ||
) | ||
logging.debug(f'Embedding created {Employee.Name}') | ||
embeddings.append(embedding) | ||
os.remove(image_filename) | ||
|
||
# embedding = DeepFace.represent( | ||
# image_filename, model_name='Facenet', detector_backend='mtcnn', | ||
# ) | ||
|
||
embeddings.append(calculate_embeddings(image_filename)) | ||
# os.remove(image_filename) | ||
|
||
Employee_data['embeddings'] = embeddings | ||
|
||
try: | ||
update_result = client.update_one( | ||
collection, | ||
update_result = client2.update_one( | ||
collection2, | ||
{'_id': ObjectId(user_id['_id'])}, | ||
update={'$set': Employee_data}, | ||
) | ||
Comment on lines
305
to
323
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWConsider removing commented-out code for clarity. Simplify by extracting the logic for computing embeddings into a separate function for better separation of concerns. def calculate_embeddings(image_filename):
embedding = DeepFace.represent(
image_filename, model_name='Facenet', detector_backend='mtcnn',
)
return embedding |
||
|
@@ -285,7 +360,7 @@ | |
""" | ||
logging.info('Deleting Employee') | ||
logging.debug(f'Deleting for EmployeeCode: {EmployeeCode}') | ||
client.find_one_and_delete(collection, {'EmployeeCode': EmployeeCode}) | ||
client2.find_one_and_delete(collection2, {'EmployeeCode': EmployeeCode}) | ||
|
||
return {'Message': 'Successfully Deleted'} | ||
|
||
Comment on lines
360
to
366
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CODE REVIEWConsider abstracting the database client logic to improve modularity and maintainability. def delete_employee(client, collection, EmployeeCode):
client.find_one_and_delete(collection, {'EmployeeCode': EmployeeCode}) |
||
|
@@ -306,20 +381,38 @@ | |
""" | ||
logging.info('Recognizing Face') | ||
try: | ||
# Code to calculate embeddings via Original Facenet model | ||
|
||
img_data = await Face.read() | ||
with open('temp.png', 'wb') as f: | ||
image_filename = 'temp.png' | ||
with open(image_filename, 'wb') as f: | ||
f.write(img_data) | ||
|
||
embedding = DeepFace.represent( | ||
img_path='temp.png', model_name='Facenet512', detector_backend='mtcnn', | ||
# embedding = DeepFace.represent( | ||
# img_path='temp.png', model_name='Facenet512', detector_backend='mtcnn', | ||
# ) | ||
|
||
# Code to calculate embeddings via Finetuned Facenet model | ||
face_image_data = DeepFace.extract_faces( | ||
image_filename, detector_backend='mtcnn', enforce_detection=False, | ||
) | ||
result = client2.vector_search(collection2, embedding[0]['embedding']) | ||
logging.info(f"Result: {result[0]['Name']}, {result[0]['score']}") | ||
os.remove('temp.png') | ||
if result[0]['score'] < 0.5: | ||
return Response( | ||
status_code=404, content=json.dumps({'message': 'No match found'}), | ||
) | ||
|
||
if face_image_data and face_image_data[0]['face'] is not None: | ||
|
||
plt.imsave(f'Images/Faces/tmp.jpg', face_image_data[0]['face']) | ||
face_image_path = f'Images/Faces/tmp.jpg' | ||
img_array = load_and_preprocess_image(face_image_path) | ||
|
||
model = load_model('Model/embedding_trial3.h5') | ||
embedding_list = model.predict(img_array)[0] # Get the first prediction | ||
print(embedding_list, type(embedding_list)) | ||
embedding = embedding_list.tolist() | ||
result = client2.vector_search(collection3, embedding) | ||
logging.info(f"Result: {result[0]['Name']}, {result[0]['score']}") | ||
os.remove('temp.png') | ||
if result[0]['score'] < 0.5: | ||
return Response( | ||
status_code=404, content=json.dumps({'message': 'No match found'}), | ||
) | ||
except Exception as e: | ||
logging.error(f'Error: {e}') | ||
os.remove('temp.png') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CODE REVIEW
The changes seem to involve moving directories and adding an entire directory. It would be beneficial to provide more context and explanation behind these changes to ensure they are necessary. Consider breaking up these changes into smaller, more meaningful commits for better clarity and version control.