diff --git a/amg_example/amg_example.py b/amg_example/amg_example.py
index a4bbcd5..13b7df5 100644
--- a/amg_example/amg_example.py
+++ b/amg_example/amg_example.py
@@ -2,6 +2,17 @@
 import torch
 import matplotlib.pyplot as plt
 import cv2
+import torch.utils.benchmark as benchmark
+
+def profiler_runner(path, fn, *args, **kwargs):
+    with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True) as prof:
+        result = fn(*args, **kwargs)
+    print(f"Saving trace under {path}")
+    prof.export_chrome_trace(path)
+    return result
 
 def show_anns(anns):
     if len(anns) == 0:
@@ -22,25 +33,46 @@ def show_anns(anns):
 
 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator
-from segment_anything_fast.tools import apply_eval_dtype_predictor
+from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator
 
 sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
 model_type = "vit_h"
-
 device = "cuda"
 
-sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
 sam.to(device=device)
+mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)
 
-mask_generator = SamAutomaticMaskGenerator(sam)
-mask_generator.predictor = apply_eval_dtype_predictor(mask_generator.predictor, torch.bfloat16)
-
+# Run thrice for warmup
+masks = mask_generator.generate(image)
+masks = mask_generator.generate(image)
 masks = mask_generator.generate(image)
 
+# Save an example
 plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
 plt.imshow(image)
 show_anns(masks)
 plt.axis('off')
 plt.tight_layout()
 plt.savefig('dog_mask_fast.png', format='png')
+
+# Benchmark
+torch.cuda.synchronize()
+start_event = torch.cuda.Event(enable_timing=True)
+end_event = torch.cuda.Event(enable_timing=True)
+start_event.record()
+for _ in range(10):
+    masks = mask_generator.generate(image)
+end_event.record()
+torch.cuda.synchronize()
+print(start_event.elapsed_time(end_event) / 10.)
+
+# Save a GPU trace
+profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)
+
+# Write out memory usage
+max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
+_, total_memory = torch.cuda.mem_get_info()
+max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
+max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
+print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")
diff --git a/amg_example/amg_example_trace.json.gz b/amg_example/amg_example_trace.json.gz
new file mode 100644
index 0000000..6d986f1
Binary files /dev/null and b/amg_example/amg_example_trace.json.gz differ
diff --git a/amg_example/dog_mask_fast.png b/amg_example/dog_mask_fast.png
index a3ea3e5..ed5682e 100644
Binary files a/amg_example/dog_mask_fast.png and b/amg_example/dog_mask_fast.png differ
diff --git a/segment_anything_fast/automatic_mask_generator.py b/segment_anything_fast/automatic_mask_generator.py
index d5a8c96..06874d1 100644
--- a/segment_anything_fast/automatic_mask_generator.py
+++ b/segment_anything_fast/automatic_mask_generator.py
@@ -24,6 +24,7 @@
     generate_crop_boxes,
     is_box_near_crop_edge,
     mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
     remove_small_regions,
     rle_to_mask,
     uncrop_boxes_xyxy,
@@ -49,6 +50,7 @@ def __init__(
         point_grids: Optional[List[np.ndarray]] = None,
         min_mask_region_area: int = 0,
         output_mode: str = "binary_mask",
+        process_batch_size: Optional[int] = None,
     ) -> None:
         """
         Using a SAM model, generates masks for the entire image.
@@ -93,6 +95,10 @@
             'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
             For large resolutions, 'binary_mask' may consume large amounts of
             memory.
+          process_batch_size (int or None): Set a batch size for the decoding step.
+            If None, all points will be batched up at once. Set a small number here
+            to decrease memory footprint. A smaller number will likely increase
+            latency, but also decrease memory usage.
         """
 
         assert (points_per_side is None) != (
@@ -132,6 +138,7 @@
         self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
         self.min_mask_region_area = min_mask_region_area
         self.output_mode = output_mode
+        self.process_batch_size = process_batch_size
 
     @torch.no_grad()
     def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
@@ -241,10 +248,13 @@ def _process_crop(
 
         # Generate masks for this crop in batches
         data = MaskData()
-        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+        all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)]
+        process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size
+        for i in range(0, len(all_points), process_batch_size):
+            some_points = all_points[i:i+process_batch_size]
+            batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
             data.cat(batch_data)
-            del batch_data
+        data["rles"] = mask_to_rle_pytorch_2(data["masks"])
         self.predictor.reset_image()
 
         # Remove duplicates within this crop.
@@ -265,24 +275,50 @@ def _process_crop(
 
     def _process_batch(
         self,
-        points: np.ndarray,
+        all_points: List[np.ndarray],
         im_size: Tuple[int, ...],
         crop_box: List[int],
         orig_size: Tuple[int, ...],
     ) -> MaskData:
         orig_h, orig_w = orig_size
-
-        # Run model on this batch
-        transformed_points = self.predictor.transform.apply_coords(points, im_size)
-        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
-        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
-        masks, iou_preds, _ = self.predictor.predict_torch(
-            in_points[:, None, :],
-            in_labels[:, None],
+        nt_in_points = []
+        for points in all_points:
+            # Run model on this batch
+            transformed_points = self.predictor.transform.apply_coords(points, im_size)
+            in_points = torch.as_tensor(transformed_points) #, device=self.predictor.device)
+            nt_in_points.append(in_points)
+
+        nt_in_points = torch.nested.nested_tensor(nt_in_points, layout=torch.jagged, pin_memory=True).to(device=self.predictor.device, non_blocking=True)
+        # The call to prod is a workaround to share jagged sizes between two NestedTensors.
+        nt_in_labels = torch.ones_like(nt_in_points, dtype=torch.int).prod(dim=-1, keepdim=True)
+        nt_in_points = nt_in_points.unsqueeze(2)
+
+        self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
+        self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
+        nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
+            point_coords=nt_in_points,
+            point_labels=nt_in_labels,
             multimask_output=True,
             return_logits=True,
         )
+        data = MaskData()
+        for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):
+            batch_data = self._process_batch_2(masks, iou_preds, points, im_size, crop_box, orig_size)
+            data.cat(batch_data)
+        return data
+
+    # TODO: Batch this up
+    def _process_batch_2(
+        self,
+        masks: torch.Tensor,
+        iou_preds: torch.Tensor,
+        points: torch.Tensor,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
 
         # Serialize predictions and store in MaskData
         data = MaskData(
             masks=masks.flatten(0, 1),
@@ -315,8 +351,10 @@ def _process_batch(
 
         # Compress to RLE
         data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
-        data["rles"] = mask_to_rle_pytorch(data["masks"])
-        del data["masks"]
+        # Doing this once at the end across all masks.
+        # data["rles"] = mask_to_rle_pytorch(data["masks"].cpu())
+        # Keeping the masks around is faster, even though it uses more memory.
+        # del data["masks"]
 
         return data
 
diff --git a/segment_anything_fast/build_sam.py b/segment_anything_fast/build_sam.py
index f684493..a9a4c99 100644
--- a/segment_anything_fast/build_sam.py
+++ b/segment_anything_fast/build_sam.py
@@ -51,7 +51,7 @@ def build_sam_vit_b(checkpoint=None):
     "vit_b": build_sam_vit_b,
 }
 
-def _apply_eval_dtype_sam(model, dtype=None):
+def _apply_eval_dtype_sam(model, dtype):
 
     def prep_model(model, dtype):
         if dtype is not None:
@@ -64,24 +64,24 @@ def prep_model(model, dtype):
 
     return model
 
-def build_sam_fast_vit_h(checkpoint=None):
+def build_sam_fast_vit_h(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_h(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam
 
 build_sam_fast = build_sam_fast_vit_h
 
-def build_sam_fast_vit_l(checkpoint=None):
+def build_sam_fast_vit_l(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_l(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam
 
-def build_sam_fast_vit_b(checkpoint=None):
+def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_b(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam
 
 sam_model_fast_registry = {
diff --git a/segment_anything_fast/modeling/prompt_encoder.py b/segment_anything_fast/modeling/prompt_encoder.py
index 168fa2c..8b7e10d 100644
--- a/segment_anything_fast/modeling/prompt_encoder.py
+++ b/segment_anything_fast/modeling/prompt_encoder.py
@@ -157,13 +157,10 @@ def forward(
           torch.Tensor: dense embeddings for the masks, in the shape
             Bx(embed_dim)x(embed_H)x(embed_W)
         """
-        return_dtype = None
         bs = self._get_batch_size(points, boxes, masks)
         if points is not None:
             coords, labels = points
             sparse_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
-            if sparse_embeddings.dtype != coords.dtype:
-                return_dtype = coords.dtype
         if boxes is not None:
             sparse_embeddings = self._embed_boxes(boxes)
 
@@ -183,10 +180,7 @@ def forward(
             dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                 bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
 
-        r0, r1 = sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings
-        if return_dtype is None:
-            return r0, r1
-        return r0.to(return_dtype), r1.to(return_dtype)
+        return sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings
 
 
 class PositionEmbeddingRandom(nn.Module):
diff --git a/segment_anything_fast/utils/amg.py b/segment_anything_fast/utils/amg.py
index be06407..dcae96e 100644
--- a/segment_anything_fast/utils/amg.py
+++ b/segment_anything_fast/utils/amg.py
@@ -72,7 +72,7 @@ def cat(self, new_stats: "MaskData") -> None:
     def to_numpy(self) -> None:
         for k, v in self._stats.items():
             if isinstance(v, torch.Tensor):
-                self._stats[k] = v.detach().cpu().numpy()
+                self._stats[k] = v.detach().cpu().float().numpy()
 
 
 def is_box_near_crop_edge(
@@ -103,6 +103,40 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
     for b in range(n_batches):
         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
 
+def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycoco tools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    a = torch.tensor([[True]]).pin_memory().cuda().expand_as(diff.narrow(1, 0, 1))
+    diff = torch.cat([a, diff, a], dim=1)
+    change_indices = diff.nonzero()
+
+    alt_lens = diff.sum(dim=1).tolist()
+
+    all_cur_idx = change_indices[:, 1]
+    all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
+    all_btw_idx = all_btw_idx.detach().cpu().tolist()
+
+    # Encode run length
+    out = []
+    counts_init = (tensor[:, 0] == 0).tolist()
+    offset = 0
+    for i, ci in zip(range(b), counts_init):
+        btw_idxs = all_btw_idx[offset:offset + alt_lens[i]][:-1]
+        offset += alt_lens[i]
+        counts = [] if ci else [0]
+        counts.extend(btw_idxs)
+        out.append({"size": [h, w], "counts": counts})
+
+    return out
+
 
 def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
     """
diff --git a/test/test_mask_to_rle.py b/test/test_mask_to_rle.py
new file mode 100644
index 0000000..b1e0212
--- /dev/null
+++ b/test/test_mask_to_rle.py
@@ -0,0 +1,18 @@
+import torch
+import itertools
+from segment_anything_fast.utils.amg import (
+    mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
+)
+
+def test_masks(masks):
+    rles_0 = mask_to_rle_pytorch(masks)
+    rles_2 = mask_to_rle_pytorch_2(masks)
+
+    for i in range(len(rles_0)):
+        torch.testing.assert_close(torch.tensor(rles_0[i]['counts']), torch.tensor(rles_2[i]['counts']))
+
+for b, w, h in itertools.product([1, 5], [50, 128], [50, 128]):
+    test_masks(torch.randn(b, w, h).clamp(min=0).bool().cuda())
+    test_masks(torch.randn(b, w, h).mul(0).bool().cuda())
+    test_masks(torch.randn(b, w, h).mul(0).add(1).bool().cuda())
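
Usage sketch (not part of the patch): with the changes above applied, the new entry points would be exercised roughly as below. The checkpoint path, the example image path, and the process_batch_size value are illustrative; compile_mode='max-autotune' and dtype=torch.bfloat16 simply restate the defaults the patch gives build_sam_fast_vit_h.

import cv2
import torch

from segment_anything_fast import sam_model_fast_registry, SamAutomaticMaskGenerator

# Build the compiled bfloat16 model; compile_mode and dtype are the optional
# arguments this patch adds to the build_sam_fast_* constructors.
sam = sam_model_fast_registry["vit_h"](
    checkpoint="checkpoints/sam_vit_h_4b8939.pth",
    compile_mode="max-autotune",
    dtype=torch.bfloat16,
)
sam.to(device="cuda")

# process_batch_size controls how many point batches are decoded per predict_torch
# call; smaller values lower peak memory at some cost in latency.
mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)

# Any HxWx3 uint8 RGB array works here; the dog image path is only an example.
image = cv2.cvtColor(cv2.imread("amg_example/images/dog.jpg"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)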