diff --git a/README.md b/README.md
index 2e5c491..769d8e5 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,62 @@
 cd head-segmentation
 
 streamlit run ./scripts/apps/web_checking.py
 ```
 
+## ⏰ Inference time
+
+If inference time is critical for your use case, you can run the pipeline on a GPU to accelerate it. Visualization also consumes some time; if you only need the segmented image, you can skip the visualization and save just the final result, as below.
+
+```python
+import cv2
+import torch
+from PIL import Image
+
+import head_segmentation.segmentation_pipeline as seg_pipeline
+
+# Run on a GPU if one is available; otherwise fall back to the CPU.
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+segmentation_pipeline = seg_pipeline.HumanHeadSegmentationPipeline(device=device)
+
+# `image` is an RGB image as a numpy array; `save_path` is the output file path.
+segmentation_map = segmentation_pipeline.predict(image)
+
+# Zero out everything outside the predicted head region.
+segmented_region = image * cv2.cvtColor(segmentation_map, cv2.COLOR_GRAY2RGB)
+
+pil_image = Image.fromarray(segmented_region)
+pil_image.save(save_path)
+```
+
+The table below presents inference times measured on a Tesla T4 (for reference only). Note that the first image takes longer due to warm-up.
+
+|     | save visualization figure | save final result only |
+|:---:|:-------------------------:|:----------------------:|
+| cpu |        around 2.1s        |      around 0.8s       |
+| gpu |        around 1.4s        |      around 0.15s      |
+
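+If you want to reproduce these timings, a minimal sketch is shown below (the image path is just a placeholder; replace it with any RGB image containing a head):
+
+```python
+import time
+
+import cv2
+import torch
+
+import head_segmentation.segmentation_pipeline as seg_pipeline
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+pipeline = seg_pipeline.HumanHeadSegmentationPipeline(device=device)
+
+# Placeholder path; use any RGB image with a head in it.
+image = cv2.cvtColor(cv2.imread('./test_image.jpg'), cv2.COLOR_BGR2RGB)
+
+pipeline.predict(image)  # warm-up run; the first prediction is slower
+
+start = time.perf_counter()
+segmentation_map = pipeline.predict(image)
+print(f'Inference took {time.perf_counter() - start:.3f}s on {device}')
+```
+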
 ### 🤗 Enjoy the model!
diff --git a/head_segmentation/_version.py b/head_segmentation/_version.py
index eb18b4b..e480a94 100644
--- a/head_segmentation/_version.py
+++ b/head_segmentation/_version.py
@@ -1,2 +1,3 @@
 # __version__ = "1.0.0" # First version
-__version__ = "1.1.0" # Segmentation model downloaded from GDrive
+# __version__ = "1.1.0" # Segmentation model downloaded from GDrive
+__version__ = "1.3.0" # GPU inference support
diff --git a/head_segmentation/segmentation_pipeline.py b/head_segmentation/segmentation_pipeline.py
index e28a297..943cd26 100644
--- a/head_segmentation/segmentation_pipeline.py
+++ b/head_segmentation/segmentation_pipeline.py
@@ -16,6 +16,7 @@ def __init__(
         self,
         model_path: str = C.HEAD_SEGMENTATION_MODEL_PATH,
         model_url: str = C.HEAD_SEGMENTATION_MODEL_URL,
+        device: torch.device = torch.device("cpu"),
     ):
         if not os.path.exists(model_path):
             model_path = C.HEAD_SEGMENTATION_MODEL_PATH
@@ -26,6 +27,7 @@ def __init__(
 
             gdown.download(model_url, model_path, quiet=False)
 
+        self.device = device
         ckpt = torch.load(model_path, map_location=torch.device("cpu"))
         hparams = ckpt["hyper_parameters"]
 
@@ -35,6 +37,7 @@ def __init__(
         self._model = mdl.HeadSegmentationModel.load_from_checkpoint(
             ckpt_path=model_path
         )
+        self._model.to(self.device)
         self._model.eval()
 
     def __call__(self, image: np.ndarray) -> np.ndarray:
@@ -42,7 +45,9 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
 
     def predict(self, image: np.ndarray) -> np.ndarray:
         preprocessed_image = self._preprocess_image(image)
+        preprocessed_image = preprocessed_image.to(self.device)
         mdl_out = self._model(preprocessed_image)
+        mdl_out = mdl_out.cpu()
         pred_segmap = self._postprocess_model_output(mdl_out, original_image=image)
         return pred_segmap
 