diff --git a/image_recognition_face_recognition/scripts/get_face_recognition b/image_recognition_face_recognition/scripts/get_face_recognition index 360fa8f1..eb459ef6 100755 --- a/image_recognition_face_recognition/scripts/get_face_recognition +++ b/image_recognition_face_recognition/scripts/get_face_recognition @@ -50,7 +50,7 @@ if __name__ == "__main__": # Add arguments parser.add_argument("image", type=str, help="Image") - parser.add_argument("-d", "--db", type=argparse.FileType("r"), help="Load already trained faces db from file") + parser.add_argument("-d", "--db", type=argparse.FileType("rb"), help="Load already trained faces db from file") args = parser.parse_args() sys.exit(main(**vars(args))) diff --git a/image_recognition_face_recognition/scripts/train_from_images b/image_recognition_face_recognition/scripts/train_from_images index 94cbd4b1..55144a2b 100755 --- a/image_recognition_face_recognition/scripts/train_from_images +++ b/image_recognition_face_recognition/scripts/train_from_images @@ -57,7 +57,7 @@ if __name__ == '__main__': parser = ArgumentParser(description='Train openface from a database of images') parser.add_argument('modeldir', action=ReadableDir, help='Directory with folders for each category') - parser.add_argument('outfile', type=FileType('w'), help='Where to output the trained faces database') + parser.add_argument('outfile', type=FileType('wb'), help='Where to output the trained faces database') parser.add_argument('-v', '--verbose', action='store_true') diff --git a/image_recognition_face_recognition/src/image_recognition_face_recognition/face_recognizer.py b/image_recognition_face_recognition/src/image_recognition_face_recognition/face_recognizer.py index ec9be971..6684b3ff 100644 --- a/image_recognition_face_recognition/src/image_recognition_face_recognition/face_recognizer.py +++ b/image_recognition_face_recognition/src/image_recognition_face_recognition/face_recognizer.py @@ -3,7 +3,7 @@ import torch from facenet_pytorch import MTCNN, InceptionResnetV1 import numpy as np - +import pickle class TrainedFace: """ @@ -179,3 +179,9 @@ def train(self, face_representation: np.ndarray, name: str) -> None: rospy.loginfo( f"Label: {trained_face.get_label()}, Representations: {len(trained_face.get_representations())}" ) + + def save_trained_faces(self, file_name): + pickle.dump(self._trained_faces, file_name) + + def restore_trained_faces(self, file_name): + self._trained_faces = pickle.load(file_name)