Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
danielvegamyhre committed Jan 8, 2025
2 parents 693159f + 96ee5ee commit ec6aa9e
Show file tree
Hide file tree
Showing 47 changed files with 5,654 additions and 2,127 deletions.
4 changes: 1 addition & 3 deletions .github/scripts/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import json
import os
import warnings

from dataclasses import dataclass
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen


GITHUB_API_URL = "https://api.github.com"


Expand Down
4 changes: 2 additions & 2 deletions .github/scripts/gitutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
)

T = TypeVar("T")
Expand Down Expand Up @@ -45,7 +45,7 @@ def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:


def _check_output(items: List[str], encoding: str = "utf-8") -> str:
from subprocess import CalledProcessError, check_output, STDOUT
from subprocess import STDOUT, CalledProcessError, check_output

try:
return check_output(items, stderr=STDOUT).decode(encoding)
Expand Down
5 changes: 2 additions & 3 deletions .github/scripts/label_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""GitHub Label Utilities."""

import json

from functools import lru_cache
from typing import Any, List, Tuple, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Any, List, Tuple, Union

from github_utils import gh_fetch_url_and_headers, GitHubComment
from github_utils import GitHubComment, gh_fetch_url_and_headers

# TODO: this is a temp workaround to avoid circular dependencies,
# and should be removed once GitHubPR is refactored out of trymerge script.
Expand Down
16 changes: 6 additions & 10 deletions .github/scripts/trymerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,41 @@
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Pattern,
Tuple,
cast,
)
from warnings import warn

import yaml
from github_utils import (
GitHubComment,
gh_fetch_json_list,
gh_fetch_merge_base,
gh_fetch_url,
gh_graphql,
gh_post_commit_comment,
gh_post_pr_comment,
gh_update_pr_state,
GitHubComment,
)

from gitutils import (
GitRepo,
are_ghstack_branches_in_sync,
get_git_remote_name,
get_git_repo_dir,
GitRepo,
patterns_to_regex,
retries_decorator,
)
from label_utils import (
gh_add_labels,
gh_remove_label,
has_required_labels,
LABEL_ERR_MSG,
)
from trymerge_explainer import get_revert_message, TryMergeExplainer
from trymerge_explainer import TryMergeExplainer, get_revert_message

# labels
MERGE_IN_PROGRESS_LABEL = "merging"
Expand Down Expand Up @@ -1477,7 +1474,7 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:


def checks_to_markdown_bullets(
checks: List[Tuple[str, Optional[str], Optional[int]]]
checks: List[Tuple[str, Optional[str], Optional[int]]],
) -> List[str]:
return [
f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
Expand Down Expand Up @@ -1716,7 +1713,7 @@ def get_readable_drci_results(drci_classifications: Any) -> str:
try:
print(f"From Dr.CI checkrun summary: {drci_summary}")
drci_classifications = json.loads(str(drci_summary))
except json.JSONDecodeError as error:
except json.JSONDecodeError:
warn("Invalid Dr.CI checkrun summary")
drci_classifications = {}

Expand Down Expand Up @@ -1887,7 +1884,6 @@ def do_revert_prs(
dry_run: bool = False,
) -> None:
# Prepare and push revert commits
commit_shas: List[str] = []
for commit_sha, pr in shas_and_prs:
revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
revert_msg += extra_msg
Expand Down
1 change: 0 additions & 1 deletion .github/scripts/trymerge_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
from typing import List, Optional, Pattern, Tuple


BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"

CIFLOW_LABEL = re.compile(r"^ciflow/.+")
Expand Down
9 changes: 3 additions & 6 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

import os
import sys
from docutils.parsers import rst

import pytorch_sphinx_theme
from docutils.parsers import rst

sys.path.append(os.path.abspath("."))

Expand Down Expand Up @@ -60,7 +61,7 @@

### TODO: Delete this when we have content
suppress_warnings = [
'toc.unlisted',
"toc.unlisted",
]
###

Expand Down Expand Up @@ -169,12 +170,8 @@
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043

from docutils import nodes
from sphinx import addnodes
from sphinx.util.docfields import TypedField

from custom_directives import CustomCardEnd, CustomCardItem, CustomCardStart
from docutils.parsers import rst

rst.directives.register_directive("customcardstart", CustomCardStart)
rst.directives.register_directive("customcarditem", CustomCardItem)
Expand Down
5 changes: 3 additions & 2 deletions docs/source/tutorials_source/template_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
# -----
#
# Example code (the output below is generated automatically):
#
#
import torch

x = torch.rand(5, 3)
print(x)

Expand All @@ -48,7 +49,7 @@
######################################################################
# Conclusion
# ----------
#
#
# Summarize the steps and concepts covered. Highlight key takeaways.
#
# Further Reading
Expand Down
73 changes: 45 additions & 28 deletions examples/sam2_amg_server/amg_example.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,68 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import torch.utils.benchmark as benchmark

from torch._inductor import config as inductorconfig

inductorconfig.triton.unique_kernel_names = True
inductorconfig.coordinate_descent_tuning = True
inductorconfig.coordinate_descent_check_all_directions = True


def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
) as prof:
result = fn(*args, **kwargs)
print(f"Saving trace under {path}")
prof.export_chrome_trace(path)
return result


def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)

img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
img[:,:,3] = 0
img = np.ones(
(
sorted_anns[0]["segmentation"].shape[0],
sorted_anns[0]["segmentation"].shape[1],
4,
)
)
img[:, :, 3] = 0
ms = []
for ann in sorted_anns:
m = ann['segmentation']
m = ann["segmentation"]
ms.append(torch.as_tensor(m))
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
ax.imshow(img)
return torch.stack(ms)

image = cv2.imread('dog.jpg')

image = cv2.imread("dog.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator
#
#
# sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
# model_type = "vit_h"
device = "cuda"
#
#
# sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2

sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
Expand All @@ -66,7 +77,7 @@ def show_anns(anns):
## TODO: Implement mIoU to allow approximations.
# torch.set_float32_matmul_precision('high')
# torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
##
##

## TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
Expand All @@ -93,24 +104,26 @@ def show_anns(anns):
)

# with torch.backends.cuda.sdp_kernel(enable_cudnn=False): #, enable_math=False, enable_mem_efficient=False):
with torch.backends.cuda.sdp_kernel(enable_cudnn=True): #, enable_math=False, enable_mem_efficient=False):
with torch.backends.cuda.sdp_kernel(
enable_cudnn=True
): # , enable_math=False, enable_mem_efficient=False):
# Run thrice for warmup
masks = mask_generator.generate(image)
masks = mask_generator.generate(image)
masks = mask_generator.generate(image)

# Save an example
plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
plt.figure(figsize=(image.shape[1] / 100.0, image.shape[0] / 100.0), dpi=100)
plt.imshow(image)
ms = show_anns(masks)
ms_ref = torch.load("dog_mask_fast.pt")
torch.testing.assert_allclose(ms, ms_ref)
print("Masks match reference")
# # torch.save(ms, "dog_mask_fast.pt")
plt.axis('off')
plt.axis("off")
plt.tight_layout()
plt.savefig('dog_mask_fast.png', format='png')
plt.savefig("dog_mask_fast.png", format="png")

# Benchmark
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
Expand All @@ -120,14 +133,18 @@ def show_anns(anns):
masks = mask_generator.generate(image)
end_event.record()
torch.cuda.synchronize()
print(start_event.elapsed_time(end_event) / 10.)
print(start_event.elapsed_time(end_event) / 10.0)

# Save a GPU trace
profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)
profiler_runner("amg_example_trace.json.gz", mask_generator.generate, image)

# Write out memory usage
max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
_, total_memory = torch.cuda.mem_get_info()
max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
max_memory_allocated_percentage = int(
100 * (max_memory_allocated_bytes / total_memory)
)
max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")
print(
f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}"
)
Loading

0 comments on commit ec6aa9e

Please sign in to comment.