From 9469a08d9b7ab8e62f8278eabf4f0ecf14703183 Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Tue, 14 May 2024 00:08:28 -0700 Subject: [PATCH] Add CPU support and update README --- README.md | 4 ++++ experiments/README.md | 21 ++++++++++++---- experiments/eval_combo.py | 44 ++++++++++++++++++++++------------ experiments/run_experiments.py | 10 +++++--- 4 files changed, 57 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index aa1bbba..fe7076a 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,10 @@ For example: ``` pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 ``` +or +``` +pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu +``` Installation instructions vary by platform. Please see the website https://pytorch.org/ diff --git a/experiments/README.md b/experiments/README.md index dadbd87..8c8fd6d 100644 --- a/experiments/README.md +++ b/experiments/README.md @@ -37,6 +37,7 @@ These experiments were run on an Amazon p4d.24xlarge instance. See the Product - 1152 GiB of RAM - Software +Meanwhile, these experiments (fp32, bf16, compile, SDPA, Triton, NT) can run on CPU platform as well. Experiment results will be shown in the near future. ### Versions @@ -47,11 +48,17 @@ These experiments were run on an Amazon p4d.24xlarge instance. See the Product ### Installation instructions ``` -$ conda create -n nightly20231117py310 -$ conda activate nightly20231117py310 +$ conda create -n nightlypy310 +$ conda activate nightlypy310 $ conda install python=3.10 -$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl -$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl +For GPU, +- $ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl +- $ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl +For CPU, +- $ pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240509%2Bcpu-cp310-cp310-linux_x86_64.whl +- $ pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240509%2Bcpu-cp310-cp310-linux_x86_64.whl +- $ pip install triton + $ git clone https://github.com/cpuhrsch/segment-anything.git $ cd segment-anything $ pip install -e . @@ -66,10 +73,16 @@ If you plan to run the scripts that run the experiments from segment-anything-fa ### How to run experiments +For GPU platform, ``` $ python run_experiments.py 16 vit_b --run-experiments --num-workers 32 ``` +For CPU platform, set SEGMENT_ANYTHING_FAST_USE_FLASH_4 as 0, since Custom flash attention kernels were written specifically for A100. +``` +$ SEGMENT_ANYTHING_FAST_USE_FLASH_4=0 python run_experiments.py 16 vit_b --run-experiments --num-workers 32 --device cpu +``` + If at any point you run into issue, please note that you can increase verbosity by adding `--capture_output False` to above command. Also, please don't hesitate to open an issue. diff --git a/experiments/eval_combo.py b/experiments/eval_combo.py index 30f6edc..d0ca231 100644 --- a/experiments/eval_combo.py +++ b/experiments/eval_combo.py @@ -5,6 +5,7 @@ from data import build_data, setup_coco_img_ids import math import segment_anything_fast +import time torch._dynamo.config.cache_size_limit = 50000 @@ -64,10 +65,13 @@ def build_results_batch_nested(predictor, batch, batch_size, pad_input_image_bat # We explicitly exclude data transfers from the timing to focus # only on the kernel performance. # Next we synchronize and set two events to start timing. - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + else: + t0 = time.time() with torch.autograd.profiler.record_function("timed region"): with torch.autograd.profiler.record_function("image encoder"): @@ -93,9 +97,12 @@ def build_results_batch_nested(predictor, batch, batch_size, pad_input_image_bat # the amount of time spent on the GPU. This is a fairly tight measurement # around the launched GPU kernels and excludes data movement from host # to device. - end_event.record() - torch.cuda.synchronize() - elapsed_time = start_event.elapsed_time(end_event) + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + elapsed_time = start_event.elapsed_time(end_event) + else: + elapsed_time = time.time() - t0 return sum(result_batch, []), orig_input_image_batch_size, elapsed_time def build_results_batch(predictor, batch, batch_size, pad_input_image_batch): @@ -123,10 +130,13 @@ def build_results_batch(predictor, batch, batch_size, pad_input_image_batch): # We explicitly exclude data transfers from the timing to focus # only on the kernel performance. # Next we synchronize and set two events to start timing. - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + else: + t0 = time.time() with torch.autograd.profiler.record_function("timed region"): with torch.autograd.profiler.record_function("image encoder"): @@ -157,9 +167,12 @@ def build_results_batch(predictor, batch, batch_size, pad_input_image_batch): # the amount of time spent on the GPU. This is a fairly tight measurement # around the launched GPU kernels and excludes data movement from host # to device. - end_event.record() - torch.cuda.synchronize() - elapsed_time = start_event.elapsed_time(end_event) + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + elapsed_time = start_event.elapsed_time(end_event) + else: + elapsed_time = time.time() - t0 return result_batch, orig_input_image_batch_size, elapsed_time @@ -290,6 +303,7 @@ def run( memory_path=None, use_local_sam_fork=False, use_compiler_settings=False, + device="cuda" ): from torch._inductor import config as inductorconfig inductorconfig.triton.unique_kernel_names = True @@ -327,7 +341,7 @@ def run( else: from segment_anything import sam_model_registry, SamPredictor checkpoint_path = model_type_to_checkpoint[sam_model_type] - sam = sam_model_registry[sam_model_type](checkpoint=checkpoint_path).cuda() + sam = sam_model_registry[sam_model_type](checkpoint=checkpoint_path).to(torch.device(device)) predictor = SamPredictor(sam) from segment_anything_fast import tools diff --git a/experiments/run_experiments.py b/experiments/run_experiments.py index ece8f1b..b37e3fa 100755 --- a/experiments/run_experiments.py +++ b/experiments/run_experiments.py @@ -45,7 +45,8 @@ def run_experiment(experiments_data, limit=None, profile_path=None, profile_top=False, - memory_path=None): + memory_path=None, + device="cuda"): root_cmd = ["python", "eval_combo.py", "--coco_root_dir", f"{experiments_data}/datasets/coco2017", @@ -84,6 +85,7 @@ def run_experiment(experiments_data, args = args + ["--memory-path", memory_path] if extra_args is None: extra_args = [] + args = args + ["--device", device] args = args + extra_args if print_header: args = args + ["--print_header", "True"] @@ -145,7 +147,8 @@ def run(batch_size, num_workers=32, print_header=True, capture_output=True, - local_fork_only=False): + local_fork_only=False, + device="cuda"): assert model == "vit_b" or model == "vit_h" @@ -155,7 +158,8 @@ def run(batch_size, model, batch_size=batch_size, num_workers=num_workers, - capture_output=capture_output) + capture_output=capture_output, + device=device) print_header = True if run_traces: