-
Notifications
You must be signed in to change notification settings - Fork 105
/
ia_sam_manager.py
185 lines (161 loc) · 7.59 KB
/
ia_sam_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import platform
from functools import partial
import torch
from modules import devices
from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
from ia_check_versions import ia_check_versions
from ia_config import get_webui_setting
from ia_logging import ia_logging
from ia_threading import torch_default_load_cd
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
from mobile_sam import SamPredictor as SamPredictorMobile
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
def check_bfloat16_support() -> bool:
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability(torch.cuda.current_device())
if compute_capability[0] >= 8:
ia_logging.debug("The CUDA device supports bfloat16")
return True
else:
ia_logging.debug("The CUDA device does not support bfloat16")
return False
else:
ia_logging.debug("CUDA is not available")
return False
def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
def wrapper(*args, **kwargs):
updated_kwargs = {**fixed_kwargs, **kwargs}
return func(*args, *fixed_args, **updated_kwargs)
return wrapper
def rename_args(func, arg_map):
def wrapper(*args, **kwargs):
new_kwargs = {arg_map.get(k, k): v for k, v in kwargs.items()}
return func(*args, **new_kwargs)
return wrapper
arg_map = {"checkpoint": "ckpt_path"}
rename_build_sam2 = rename_args(build_sam2, arg_map)
end_kwargs = dict(device="cpu", mode="eval", hydra_overrides_extra=[], apply_postprocessing=False)
sam2_model_registry = {
"sam2_hiera_large": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_l.yaml"),
"sam2_hiera_base_plus": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_b+.yaml"),
"sam2_hiera_small": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_s.yaml"),
"sam2_hiera_tiny": partial(partial_from_end(rename_build_sam2, **end_kwargs), "sam2_hiera_t.yaml"),
}
@torch_default_load_cd()
def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
"""Get SAM mask generator.
Args:
sam_checkpoint (str): SAM checkpoint path
Returns:
SamAutomaticMaskGenerator or None: SAM mask generator
"""
points_per_batch = 64
if "_hq_" in os.path.basename(sam_checkpoint):
model_type = os.path.basename(sam_checkpoint)[7:12]
sam_model_registry_local = sam_model_registry_hq
SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorHQ
points_per_batch = 32
elif "FastSAM" in os.path.basename(sam_checkpoint):
model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0]
sam_model_registry_local = fast_sam_model_registry
SamAutomaticMaskGeneratorLocal = FastSamAutomaticMaskGenerator
points_per_batch = None
elif "mobile_sam" in os.path.basename(sam_checkpoint):
model_type = "vit_t"
sam_model_registry_local = sam_model_registry_mobile
SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorMobile
points_per_batch = 64
elif "sam2_" in os.path.basename(sam_checkpoint):
model_type = os.path.splitext(os.path.basename(sam_checkpoint))[0]
sam_model_registry_local = sam2_model_registry
SamAutomaticMaskGeneratorLocal = SAM2AutomaticMaskGenerator
points_per_batch = 128
else:
model_type = os.path.basename(sam_checkpoint)[4:9]
sam_model_registry_local = sam_model_registry
SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGenerator
points_per_batch = 64
pred_iou_thresh = 0.88 if not anime_style_chk else 0.83
stability_score_thresh = 0.95 if not anime_style_chk else 0.9
if "sam2_" in model_type:
pred_iou_thresh = round(pred_iou_thresh - 0.18, 2)
stability_score_thresh = round(stability_score_thresh - 0.03, 2)
sam2_gen_kwargs = dict(
points_per_side=64,
points_per_batch=points_per_batch,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
stability_score_offset=0.7,
crop_n_layers=1,
box_nms_thresh=0.7,
crop_n_points_downscale_factor=2)
if platform.system() == "Darwin":
sam2_gen_kwargs.update(dict(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1))
if os.path.isfile(sam_checkpoint):
sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
if platform.system() == "Darwin":
if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available:
sam.to(device=torch.device("cpu"))
else:
sam.to(device=torch.device("mps"))
else:
if get_webui_setting("inpaint_anything_sam_oncpu", False):
ia_logging.info("SAM is running on CPU... (the option has been checked)")
sam.to(device=devices.cpu)
else:
sam.to(device=devices.device)
sam_gen_kwargs = dict(
model=sam, points_per_batch=points_per_batch, pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
if "sam2_" in model_type:
sam_gen_kwargs.update(sam2_gen_kwargs)
sam_mask_generator = SamAutomaticMaskGeneratorLocal(**sam_gen_kwargs)
else:
sam_mask_generator = None
return sam_mask_generator
@ torch_default_load_cd()
def get_sam_predictor(sam_checkpoint):
"""Get SAM predictor.
Args:
sam_checkpoint (str): SAM checkpoint path
Returns:
SamPredictor or None: SAM predictor
"""
# model_type = "vit_h"
if "_hq_" in os.path.basename(sam_checkpoint):
model_type = os.path.basename(sam_checkpoint)[7:12]
sam_model_registry_local = sam_model_registry_hq
SamPredictorLocal = SamPredictorHQ
elif "FastSAM" in os.path.basename(sam_checkpoint):
raise NotImplementedError("FastSAM predictor is not implemented yet.")
elif "mobile_sam" in os.path.basename(sam_checkpoint):
model_type = "vit_t"
sam_model_registry_local = sam_model_registry_mobile
SamPredictorLocal = SamPredictorMobile
else:
model_type = os.path.basename(sam_checkpoint)[4:9]
sam_model_registry_local = sam_model_registry
SamPredictorLocal = SamPredictor
if os.path.isfile(sam_checkpoint):
sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
if platform.system() == "Darwin":
if "FastSAM" in os.path.basename(sam_checkpoint) or not ia_check_versions.torch_mps_is_available:
sam.to(device=torch.device("cpu"))
else:
sam.to(device=torch.device("mps"))
else:
if get_webui_setting("inpaint_anything_sam_oncpu", False):
ia_logging.info("SAM is running on CPU... (the option has been checked)")
sam.to(device=devices.cpu)
else:
sam.to(device=devices.device)
sam_predictor = SamPredictorLocal(sam)
else:
sam_predictor = None
return sam_predictor