Skip to content

Commit

Permalink
updated embedding function
Browse files Browse the repository at this point in the history
  • Loading branch information
devansh-shah-11 committed Jul 28, 2024
1 parent 8afe690 commit 19a57cd
Showing 1 changed file with 79 additions and 26 deletions.
105 changes: 79 additions & 26 deletions API/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,80 @@ def load_and_preprocess_image(img_path, target_size=(160, 160)):
img_array /= 255.0
return img_array

def calculate_embeddings(image_filename):

"""
Calculate embeddings for the provided image.
Args:
image_filename (str): The path to the image file.
Returns:
list: A list of embeddings for the image.
"""

face_image_data = DeepFace.extract_faces(
image_filename, detector_backend='mtcnn', enforce_detection=False,
)
new_image_path = f'Images/Faces/tmp.jpg'

if face_image_data[0]['face'] is not None:
plt.imsave(new_image_path, face_image_data[0]['face'])

img_array = load_and_preprocess_image(new_image_path)
model=load_model('Model/embedding_trial3.h5')
embedding = model.predict(img_array)[0]
embedding_list = embedding.tolist()
logging.info(f'Embedding created')

return embedding_list

@router.post('/recalculate_embeddings')
async def recalculate_embeddings():
"""
Recalculate embeddings for all the images in the database.
Returns:
dict: A dictionary with a success message.
Raises:
None
"""
logging.info('Recalculating embeddings')
employees_mongo = client2.find(collection2)
for employee in employees_mongo:
print(employee, type(employee))
embeddings = []

# In the initial version, the images were stored in the 'Image' field
if 'Images' in employee:
images = employee['Images']
else:
images = employee['Image']

for encoded_image in images:
encoded_image = encoded_image.replace('data:image/png;base64,', '')
encoded_image = encoded_image.strip()
encoded_image += '=' * (-len(encoded_image) % 4)
img_recovered = base64.b64decode(encoded_image)
pil_image = Image.open(BytesIO(img_recovered))
image_filename = f'{employee["Name"]}.png'
pil_image.save(image_filename)
logging.debug(f'Image saved {employee["Name"]}')
embeddings.append(calculate_embeddings(image_filename))
# os.remove(image_filename)

logging.debug(f'About to update Embeddings: {embeddings}')
# Store the data in the database
client2.update_one(
collection2,
{'EmployeeCode': employee['EmployeeCode']},
{'$set': {'embeddings': embeddings, 'Images': images}},
)

return {'message': 'Embeddings Recalculated successfully'}


# To create new entries of employee
@router.post('/create_new_faceEntry')
async def create_new_faceEntry(Employee: Employee):
Expand Down Expand Up @@ -102,19 +176,8 @@ async def create_new_faceEntry(Employee: Employee):
# image_filename, model_name='Facenet512', detector_backend='mtcnn',
# )

face_image_data = DeepFace.extract_faces(
image_filename, detector_backend='mtcnn', enforce_detection=False,
)
if face_image_data[0]['face'] is not None:
plt.imsave(f'Images/Faces/{Name}.jpg', face_image_data[0]['face'])

img_array = load_and_preprocess_image(f'Images/Faces/{Name}.jpg')
model=load_model('Model/embedding_trial3.h5')
embedding = model.predict(img_array)[0]
embedding_list = embedding.tolist()
embeddings.append(embedding_list)
logging.info(f'Embedding created Embeddings for {Name}')
os.remove(image_filename)
embeddings.append(calculate_embeddings(image_filename))
# os.remove(image_filename)

logging.debug(f'About to insert Embeddings: {embeddings}')
# Store the data in the database
Expand Down Expand Up @@ -250,19 +313,9 @@ async def update_employees(EmployeeCode: int, Employee: UpdateEmployee):
# image_filename, model_name='Facenet', detector_backend='mtcnn',
# )

face_image_data = DeepFace.extract_faces(
image_filename, detector_backend='mtcnn', enforce_detection=False,
)
if face_image_data[0]['face'] is not None:
plt.imsave(f'Images/Faces/{Employee.Name}.jpg', face_image_data[0]['face'])

img_array = load_and_preprocess_image(f'Images/Faces/{Employee.Name}.jpg')
model=load_model('Model/embedding_trial3.h5')
embedding = model.predict(img_array)[0]
embedding_list = embedding.tolist()
embeddings.append(embedding_list)
logging.info(f'Embedding created Embeddings for {Employee.Name}')
os.remove(image_filename)
embeddings.append(calculate_embeddings(image_filename))
# os.remove(image_filename)

Employee_data['embeddings'] = embeddings

try:
Expand Down

0 comments on commit 19a57cd

Please sign in to comment.