diff --git a/README.md b/README.md
index 2e5c491..769d8e5 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,62 @@ cd head-segmentation
streamlit run ./scripts/apps/web_checking.py
```
+## ⏰ Inference time
+
+If inference time is critical, you can run inference on a GPU to accelerate it. Visualization also consumes some time, so if you only need the final segmentation result, you can save it directly as shown below (`image_path` and `save_path` are placeholders for your input image and output file).
+
+```python
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+import head_segmentation.segmentation_pipeline as seg_pipeline
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+segmentation_pipeline = seg_pipeline.HumanHeadSegmentationPipeline(device=device)
+
+image = np.asarray(Image.open(image_path))  # image_path: path to your input image
+segmentation_map = segmentation_pipeline.predict(image)
+
+# Zero out every pixel outside the predicted head region
+segmented_region = image * cv2.cvtColor(segmentation_map, cv2.COLOR_GRAY2RGB)
+
+pil_image = Image.fromarray(segmented_region)
+pil_image.save(save_path)  # save_path: path for the output image
+```
+
+The table below reports inference times measured on a machine with a Tesla T4 GPU (for reference only). The first prediction takes longer because of one-time model warm-up; a minimal timing sketch follows the table.
+
+| device | save visualized figure | save final result only |
+|:------:|:----------------------:|:-----------------------:|
+| CPU    | ~2.1 s                 | ~0.8 s                  |
+| GPU    | ~1.4 s                 | ~0.15 s                 |
+
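+If you want to reproduce these numbers, here is a minimal timing sketch. The dummy input and the `time.perf_counter` measurement are illustrative assumptions, not part of the library; actual timings vary with hardware and image size.
+
+```python
+import time
+
+import numpy as np
+import torch
+
+import head_segmentation.segmentation_pipeline as seg_pipeline
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+pipeline = seg_pipeline.HumanHeadSegmentationPipeline(device=device)
+
+# Dummy 512x512 RGB input, used only for timing
+image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
+
+pipeline.predict(image)  # warm-up: the first call is noticeably slower
+
+start = time.perf_counter()
+pipeline.predict(image)
+print(f"Inference took {time.perf_counter() - start:.3f} s")
+```
+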
### 🤗 Enjoy the model!
diff --git a/head_segmentation/_version.py b/head_segmentation/_version.py
index eb18b4b..e480a94 100644
--- a/head_segmentation/_version.py
+++ b/head_segmentation/_version.py
@@ -1,2 +1,3 @@
# __version__ = "1.0.0" # First version
-__version__ = "1.1.0" # Segmentation model downloaded from GDrive
+# __version__ = "1.1.0" # Segmentation model downloaded from GDrive
+__version__ = "1.3.0"
diff --git a/head_segmentation/segmentation_pipeline.py b/head_segmentation/segmentation_pipeline.py
index e28a297..943cd26 100644
--- a/head_segmentation/segmentation_pipeline.py
+++ b/head_segmentation/segmentation_pipeline.py
@@ -16,6 +16,7 @@ def __init__(
self,
model_path: str = C.HEAD_SEGMENTATION_MODEL_PATH,
model_url: str = C.HEAD_SEGMENTATION_MODEL_URL,
+        device: torch.device = torch.device("cpu"),
):
if not os.path.exists(model_path):
model_path = C.HEAD_SEGMENTATION_MODEL_PATH
@@ -26,6 +27,7 @@ def __init__(
gdown.download(model_url, model_path, quiet=False)
+ self.device = device
ckpt = torch.load(model_path, map_location=torch.device("cpu"))
hparams = ckpt["hyper_parameters"]
@@ -35,6 +37,7 @@ def __init__(
self._model = mdl.HeadSegmentationModel.load_from_checkpoint(
ckpt_path=model_path
)
+ self._model.to(self.device)
self._model.eval()
def __call__(self, image: np.ndarray) -> np.ndarray:
@@ -42,7 +45,9 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
def predict(self, image: np.ndarray) -> np.ndarray:
preprocessed_image = self._preprocess_image(image)
+        preprocessed_image = preprocessed_image.to(self.device)  # move the input tensor to the inference device
mdl_out = self._model(preprocessed_image)
+        mdl_out = mdl_out.cpu()  # move the output back to CPU before post-processing
pred_segmap = self._postprocess_model_output(mdl_out, original_image=image)
return pred_segmap