Faster AMG #69

Merged: 29 commits, Dec 1, 2023

Changes from all commits:
9a7a088  Faster AMG (cpuhrsch, Oct 18, 2023)
ffc037e  Merge branch 'main' into fasteramg1 (cpuhrsch, Oct 20, 2023)
4be6279  Update (cpuhrsch, Oct 20, 2023)
e3ca3cb  Update (cpuhrsch, Oct 20, 2023)
64f39db  Update (cpuhrsch, Oct 20, 2023)
81310a3  Merge remote-tracking branch 'origin' into fasteramg1 (cpuhrsch, Nov 17, 2023)
f63c7dd  Revert experiments (cpuhrsch, Nov 17, 2023)
793cc74  Merge branch 'main' of github.com:pytorch-labs/segment-anything-fast … (cpuhrsch, Nov 27, 2023)
9566500  More exploratory mess (cpuhrsch, Nov 28, 2023)
7194b86  More exploratory mess (cpuhrsch, Nov 28, 2023)
d38332a  First wave of cleanup (cpuhrsch, Nov 28, 2023)
186b788  First wave of cleanup (cpuhrsch, Nov 28, 2023)
bef9979  First wave of cleanup (cpuhrsch, Nov 28, 2023)
6433755  First wave of cleanup (cpuhrsch, Nov 28, 2023)
f834a0c  Second wave of cleanup (cpuhrsch, Nov 28, 2023)
7055eae  Trying NestedTensor and bfloat16 for AMG (cpuhrsch, Nov 29, 2023)
6bc0db0  Trying NestedTensor and bfloat16 for AMG (cpuhrsch, Nov 29, 2023)
79db9ce  First wave of cleanup (cpuhrsch, Nov 29, 2023)
130d91c  First wave of cleanup (cpuhrsch, Nov 29, 2023)
a763b06  First wave of cleanup (cpuhrsch, Nov 29, 2023)
365c11d  Code format (cpuhrsch, Nov 29, 2023)
57c71fe  More exploratory mess (cpuhrsch, Nov 29, 2023)
775afb0  More exploratory mess (cpuhrsch, Nov 29, 2023)
7ed4498  Tests for mask to rle (cpuhrsch, Nov 29, 2023)
8b6fe40  Second wave of cleanup (cpuhrsch, Nov 29, 2023)
3c2c362  Third wave of cleanup (cpuhrsch, Nov 29, 2023)
35db2a9  Last wave of cleanup (cpuhrsch, Nov 29, 2023)
ec9a2bf  Add process_batch_size argument to control memory (cpuhrsch, Nov 30, 2023)
308049b  Add process_batch_size argument to control memory (cpuhrsch, Nov 30, 2023)
46 changes: 39 additions & 7 deletions amg_example/amg_example.py
@@ -2,6 +2,17 @@
 import torch
 import matplotlib.pyplot as plt
 import cv2
+import torch.utils.benchmark as benchmark
+
+def profiler_runner(path, fn, *args, **kwargs):
+    with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA],
+            record_shapes=True) as prof:
+        result = fn(*args, **kwargs)
+    print(f"Saving trace under {path}")
+    prof.export_chrome_trace(path)
+    return result

 def show_anns(anns):
     if len(anns) == 0:
@@ -22,25 +33,46 @@ def show_anns(anns):
 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


-from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator
-from segment_anything_fast.tools import apply_eval_dtype_predictor
+from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator

 sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
 model_type = "vit_h"

 device = "cuda"

-sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
 sam.to(device=device)
+mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)

-mask_generator = SamAutomaticMaskGenerator(sam)
-mask_generator.predictor = apply_eval_dtype_predictor(mask_generator.predictor, torch.bfloat16)
-
+# Run thrice for warmup
+masks = mask_generator.generate(image)
+masks = mask_generator.generate(image)
 masks = mask_generator.generate(image)

+# Save an example
 plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
 plt.imshow(image)
 show_anns(masks)
 plt.axis('off')
 plt.tight_layout()
 plt.savefig('dog_mask_fast.png', format='png')
+
+# Benchmark
+torch.cuda.synchronize()
+start_event = torch.cuda.Event(enable_timing=True)
+end_event = torch.cuda.Event(enable_timing=True)
+start_event.record()
+for _ in range(10):
+    masks = mask_generator.generate(image)
+end_event.record()
+torch.cuda.synchronize()
+print(start_event.elapsed_time(end_event) / 10.)
+
+# Save a GPU trace
+profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)
+
+# Write out memory usage
+max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
+_, total_memory = torch.cuda.mem_get_info()
+max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
+max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
+print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")
Binary file added amg_example/amg_example_trace.json.gz
Binary file modified amg_example/dog_mask_fast.png
66 changes: 52 additions & 14 deletions segment_anything_fast/automatic_mask_generator.py
@@ -24,6 +24,7 @@
     generate_crop_boxes,
     is_box_near_crop_edge,
     mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
     remove_small_regions,
     rle_to_mask,
     uncrop_boxes_xyxy,
@@ -49,6 +50,7 @@ def __init__(
         point_grids: Optional[List[np.ndarray]] = None,
         min_mask_region_area: int = 0,
         output_mode: str = "binary_mask",
+        process_batch_size: Optional[int] = None,
     ) -> None:
         """
         Using a SAM model, generates masks for the entire image.
@@ -93,6 +95,10 @@ def __init__(
             'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
             For large resolutions, 'binary_mask' may consume large amounts of
             memory.
+          process_batch_size (int or None): Set a batch size for the decoding step.
+            If None, all points will be batched up at once. Set a small number here
+            to decrease memory footprint. A smaller number will likely increase
+            latency but decrease memory usage.
"""

assert (points_per_side is None) != (
Expand Down Expand Up @@ -132,6 +138,7 @@ def __init__(
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
self.process_batch_size = process_batch_size

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -241,10 +248,13 @@ def _process_crop(

# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)]
process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size
for i in range(0, len(all_points), process_batch_size):
some_points = all_points[i:i+process_batch_size]
batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
del batch_data
data["rles"] = mask_to_rle_pytorch_2(data["masks"])
self.predictor.reset_image()

# Remove duplicates within this crop.
Expand All @@ -265,24 +275,50 @@ def _process_crop(

def _process_batch(
self,
points: np.ndarray,
all_points: List[np.ndarray],
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size

# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
nt_in_points = []
for points in all_points:
# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points) #, device=self.predictor.device)
nt_in_points.append(in_points)

nt_in_points = torch.nested.nested_tensor(nt_in_points, layout=torch.jagged, pin_memory=True).to(device=self.predictor.device, non_blocking=True)
# The call to prod is a workaround to share jagged sizes between two NestedTensors.
nt_in_labels = torch.ones_like(nt_in_points, dtype=torch.int).prod(dim=-1, keepdim=True)
nt_in_points = nt_in_points.unsqueeze(2)

self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
point_coords=nt_in_points,
point_labels=nt_in_labels,
multimask_output=True,
return_logits=True,
)

data = MaskData()
for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):
batch_data = self._process_batch_2(masks, iou_preds, points, im_size, crop_box, orig_size)
data.cat(batch_data)
return data

# TODO: Batch this up
def _process_batch_2(
self,
masks: torch.Tensor,
iou_preds: torch.Tensor,
points: torch.Tensor,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.flatten(0, 1),
Expand Down Expand Up @@ -315,8 +351,10 @@ def _process_batch(

# Compress to RLE
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]
# Doing this once at the end across all masks.
# data["rles"] = mask_to_rle_pytorch(data["masks"].cpu())
# Keeping the masks around is faster, even though it uses more memory.
# del data["masks"]

return data

Expand Down
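The jagged-layout NestedTensor construction above is the heart of this change: all per-crop point batches are decoded in a single predict_torch call instead of one call per batch. A minimal standalone sketch of the same pattern, using the same ops as the diff (illustrative shapes; assumes a CUDA device and a PyTorch build with torch.jagged support):

import torch

# Variable-length point batches, like those produced by batch_iterator.
point_batches = [torch.randn(64, 2), torch.randn(64, 2), torch.randn(17, 2)]

# One jagged NestedTensor; pinned host memory lets the H2D copy overlap
# with compute via non_blocking=True.
nt_points = torch.nested.nested_tensor(
    point_batches, layout=torch.jagged, pin_memory=True
).to(device="cuda", non_blocking=True)

# ones_like + prod over the trailing (dense) dim yields an all-ones label
# tensor that shares nt_points' jagged structure: the "prod workaround"
# named in the comment above.
nt_labels = torch.ones_like(nt_points, dtype=torch.int).prod(dim=-1, keepdim=True)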
20 changes: 10 additions & 10 deletions segment_anything_fast/build_sam.py
@@ -51,7 +51,7 @@ def build_sam_vit_b(checkpoint=None):
     "vit_b": build_sam_vit_b,
 }

-def _apply_eval_dtype_sam(model, dtype=None):
+def _apply_eval_dtype_sam(model, dtype):

     def prep_model(model, dtype):
         if dtype is not None:
@@ -64,24 +64,24 @@ def prep_model(model, dtype):

     return model

-def build_sam_fast_vit_h(checkpoint=None):
+def build_sam_fast_vit_h(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_h(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam

 build_sam_fast = build_sam_fast_vit_h

-def build_sam_fast_vit_l(checkpoint=None):
+def build_sam_fast_vit_l(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_l(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam

-def build_sam_fast_vit_b(checkpoint=None):
+def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
     sam = build_sam_vit_b(checkpoint)
-    sam = _apply_eval_dtype_sam(sam)
-    sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
+    sam = _apply_eval_dtype_sam(sam, dtype)
+    sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
     return sam

 sam_model_fast_registry = {
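These builders now expose compile_mode and dtype instead of hard-coding them. A usage sketch (a hypothetical call site, not from the PR; 'reduce-overhead' is a standard torch.compile mode that trades peak throughput for much shorter compile times):

import torch
from segment_anything_fast import sam_model_fast_registry

sam = sam_model_fast_registry["vit_h"](
    checkpoint="checkpoints/sam_vit_h_4b8939.pth",
    compile_mode="reduce-overhead",  # default remains 'max-autotune'
    dtype=torch.bfloat16,
)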
8 changes: 1 addition & 7 deletions segment_anything_fast/modeling/prompt_encoder.py
@@ -157,13 +157,10 @@ def forward(
           torch.Tensor: dense embeddings for the masks, in the shape
             Bx(embed_dim)x(embed_H)x(embed_W)
         """
-        return_dtype = None
         bs = self._get_batch_size(points, boxes, masks)
         if points is not None:
             coords, labels = points
             sparse_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
-            if sparse_embeddings.dtype != coords.dtype:
-                return_dtype = coords.dtype
         if boxes is not None:
             sparse_embeddings = self._embed_boxes(boxes)

@@ -183,10 +180,7 @@ def forward(
             dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                 bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])

-        r0, r1 = sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings
-        if return_dtype is None:
-            return r0, r1
-        return r0.to(return_dtype), r1.to(return_dtype)
+        return sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings


 class PositionEmbeddingRandom(nn.Module):
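The return_dtype bookkeeping removed here was a workaround for mixed prompt and embedding dtypes; with build_sam.py now casting the whole model to a single eval dtype via _apply_eval_dtype_sam (see above), the dtypes presumably already agree, so forward can simply return both embeddings in dense_embeddings' dtype.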
36 changes: 35 additions & 1 deletion segment_anything_fast/utils/amg.py
@@ -72,7 +72,7 @@ def cat(self, new_stats: "MaskData") -> None:
     def to_numpy(self) -> None:
         for k, v in self._stats.items():
             if isinstance(v, torch.Tensor):
-                self._stats[k] = v.detach().cpu().numpy()
+                self._stats[k] = v.detach().cpu().float().numpy()


 def is_box_near_crop_edge(
@@ -103,6 +103,40 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
     for b in range(n_batches):
         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]

+def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycoco tools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    a = torch.tensor([[True]]).pin_memory().cuda().expand_as(diff.narrow(1, 0, 1))
+    diff = torch.cat([a, diff, a], dim=1)
+    change_indices = diff.nonzero()
+
+    alt_lens = diff.sum(dim=1).tolist()
+
+    all_cur_idx = change_indices[:, 1]
+    all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
+    all_btw_idx = all_btw_idx.detach().cpu().tolist()
+
+    # Encode run length
+    out = []
+    counts_init = (tensor[:, 0] == 0).tolist()
+    offset = 0
+    for i, ci in zip(range(b), counts_init):
+        btw_idxs = all_btw_idx[offset:offset + alt_lens[i]][:-1]
+        offset += alt_lens[i]
+        counts = [] if ci else [0]
+        counts.extend(btw_idxs)
+        out.append({"size": [h, w], "counts": counts})
+
+    return out
+
+
 def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
     """
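Unlike the original per-mask loop, mask_to_rle_pytorch_2 finds all run boundaries for the whole mask batch on the GPU in one pass, leaving only cheap list slicing on the host. A tiny worked example of the output format (a sketch; it requires CUDA, since the sentinel tensor above is hard-coded to .cuda()):

import torch
from segment_anything_fast.utils.amg import mask_to_rle_pytorch_2

# A single 4x1 mask whose column reads [0, 1, 1, 0] in Fortran order.
mask = torch.tensor([[[False], [True], [True], [False]]]).cuda()
print(mask_to_rle_pytorch_2(mask))
# -> [{'size': [4, 1], 'counts': [1, 2, 1]}]: one 0, two 1s, one 0.
# A mask that starts with 1 gets a leading 0 count, per pycocotools' convention.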
18 changes: 18 additions & 0 deletions test/test_mask_to_rle.py
@@ -0,0 +1,18 @@
+import torch
+import itertools
+from segment_anything_fast.utils.amg import (
+    mask_to_rle_pytorch,
+    mask_to_rle_pytorch_2,
+)
+
+def test_masks(masks):
+    rles_0 = mask_to_rle_pytorch(masks)
+    rles_2 = mask_to_rle_pytorch_2(masks)
+
+    for i in range(len(rles_0)):
+        torch.testing.assert_close(torch.tensor(rles_0[i]['counts']), torch.tensor(rles_2[i]['counts']))
+
+for b, w, h in itertools.product([1, 5], [50, 128], [50, 128]):
+    test_masks(torch.randn(b, w, h).clamp(min=0).bool().cuda())
+    test_masks(torch.randn(b, w, h).mul(0).bool().cuda())
+    test_masks(torch.randn(b, w, h).mul(0).add(1).bool().cuda())
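These tests check the new encoder against the original count-for-count, over random, all-zero, and all-one masks. A stricter variant (a sketch, not part of the PR) would also round-trip through rle_to_mask, which the same module exports:

from segment_anything_fast.utils.amg import rle_to_mask

def test_roundtrip(masks):
    # Decoding the RLE should reproduce the input mask exactly.
    for mask, rle in zip(masks, mask_to_rle_pytorch_2(masks)):
        restored = torch.from_numpy(rle_to_mask(rle)).to(mask.device)
        torch.testing.assert_close(mask, restored)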