Skip to content

Commit

Permalink
Merge branch 'main' into add_coat_optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
MirMustafaAli committed Dec 15, 2024
2 parents 1f8f153 + 46b8796 commit f9d0aa1
Show file tree
Hide file tree
Showing 116 changed files with 5,421 additions and 1,376 deletions.
2 changes: 2 additions & 0 deletions .github/pytorch-probot.yml
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
mergebot: True
ciflow_push_tags:
- ciflow/benchmark
49 changes: 49 additions & 0 deletions .github/workflows/dashboard_perf_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: A100-perf-nightly

on:
push:
tags:
- ciflow/benchmark/*
workflow_dispatch:
schedule:
- cron: 0 7 * * 0-6

jobs:
benchmark:
runs-on: linux.aws.a100
strategy:
matrix:
torch-spec:
- '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124'
steps:
- uses: actions/checkout@v3

- name: Setup miniconda
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
with:
python-version: "3.9"

- name: Run benchmark
shell: bash
run: |
set -eux
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} pip install ${{ matrix.torch-spec }}
${CONDA_RUN} pip install -r dev-requirements.txt
${CONDA_RUN} pip install .
export CHECKPOINT_PATH=checkpoints
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
${CONDA_RUN} python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
${CONDA_RUN} python scripts/convert_hf_checkpoint.py --checkpoint_dir "${CHECKPOINT_PATH}/${MODEL_REPO}"
mkdir -p ${{ runner.temp }}/benchmark-results
${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --output_json_path ${{ runner.temp }}/benchmark-results/benchmark-results.json
- name: Upload the benchmark results to OSS benchmark database for the dashboard
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
with:
benchmark-results-dir: ${{ runner.temp }}/benchmark-results
dry-run: false
schema-version: v3
github-token: ${{ secrets.GITHUB_TOKEN }}
1 change: 1 addition & 0 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

- name: CPU 2.3
runs-on: linux.4xlarge
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/ruff_linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ jobs:
# Please be careful when using this: large changes mean everyone needs to rebase.
ruff check --isolated --select F821,F823,W191
ruff check --select F,I
ruff format --check
ruff format --check || {
echo "Ruff check failed, please try again after running 'ruff format'."
exit 1
}
- name: Apply fixes to PR
if: github.event_name == 'workflow_dispatch'
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ We're also fortunate to be integrated into some of the leading open-source libra
3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes
5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization
6. [SGLang](https://github.com/sgl-project/sglang/pull/1341) for LLM inference quantization
6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).

## Videos
* [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)
Expand All @@ -205,4 +205,5 @@ If you find the torchao library useful, please cite it in your work as below.
license = {BSD-3-Clause},
month = oct,
year = {2024}
}
```
34 changes: 32 additions & 2 deletions benchmarks/float8/profile_linear_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import copy
import io
import functools
import os
import random
from contextlib import nullcontext, redirect_stdout
Expand All @@ -22,6 +23,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import (
checkpoint,
create_selective_checkpoint_contexts,
CheckpointPolicy,
)
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
Expand Down Expand Up @@ -254,6 +260,22 @@ def profile_function(
return prof


# set up AC for max(abs(tensor))
# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts
ops_to_save = [
torch.ops.aten.abs.default,
torch.ops.aten.max.default,
]

def policy_fn(ctx, op, *args, **kwargs):
if op in ops_to_save:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)


def main(
profile_path_prefix: pathlib.Path,
compile: bool = True,
Expand All @@ -265,6 +287,7 @@ def main(
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
enable_sync_amax_history: bool = True,
enable_activation_checkpointing: bool = False,
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")
Expand Down Expand Up @@ -294,6 +317,7 @@ def main(
print(f"Compile is set to | {compile}")
print(f"model_type is set to | {model_type}")
print(f"scaling_repr is set to | {scaling_repr}")
print(f"enable_activation_checkpointing is set to {enable_activation_checkpointing}")

device = "cuda"
ref_dtype = torch.bfloat16
Expand Down Expand Up @@ -338,11 +362,17 @@ def main(
convert_to_float8_training(m_float8, config=config)

def ref_forw_backward(x):
out = m_ref(x)
if enable_activation_checkpointing:
out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
else:
out = m_ref(x)
out.sum().backward()

def float8_forw(x):
out = m_float8(x)
if enable_activation_checkpointing:
out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn)
else:
out = m_float8(x)
return out

sync_amax_history = sync_float8_amax_and_scale_history
Expand Down
9 changes: 2 additions & 7 deletions docs/source/api_ref_quantization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ torchao.quantization
.. autosummary::
:toctree: generated/
:nosignatures:
autoquant


autoquant
quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
Expand All @@ -21,12 +21,9 @@ torchao.quantization
float8_static_activation_float8_weight
uintx_weight_only
fpx_weight_only

to_linear_activation_quantized

swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
Expand All @@ -37,10 +34,8 @@ torchao.quantization
choose_qparams_and_quantize_affine_hqq
fake_quantize_affine
fake_quantize_affine_cachemask

safe_int_mm
int_scaled_matmul

MappingType
ZeroPointDomain
TorchAODType
Expand Down
7 changes: 5 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
Welcome to the torchao Documentation
=======================================

`**torchao** <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials.
`torchao <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please check out the torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlights and updates. The documentation here will focus on:

1. API Reference
2. Developer Contribution Guide
3. Tutorials

..
.. grid:: 3
Expand Down Expand Up @@ -96,7 +100,6 @@ Welcome to the torchao Documentation
:glob:
:maxdepth: 1
:caption: Tutorials
:hidden:

serialization

22 changes: 21 additions & 1 deletion examples/sam2_amg_server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,32 @@ The 'ao' mode is a copy of the baseline with modifications to make the code more
### 0. Download checkpoints and install requirements

```
pip install -r requirements.txt
# From the top-level "ao" directory
# If necessary, create and activate a virtual environment
# Ex:
python -m venv venv && source venv/bin/activate
# Install requirements for this example
pip install -r examples/sam2_amg_server/requirements.txt
# If you have an older version of torch in your current environment, uninstall it first
pip uninstall torch
# Install torch nightly
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
# Build ao from source for now
python setup.py develop
# On your mark, get set...
cd examples/sam2_amg_server/
```

Download `sam2.1_hiera_large.pt` from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints and put it into `~/checkpoints/sam2`

### 1. Create a random subset of 1000 images
We suggest using images with corresponding mask annotations, such as those from the Segment Anything Video (SA-V) [Dataset](https://github.com/facebookresearch/sam2/tree/main/sav_dataset#download-the-dataset), so that you can later measure any drop in accuracy when using `--furious` (which uses `torch.float16`).
```
find sav_val -type f > sav_val_image_paths
shuf -n 1000 sav_val_image_paths > sav_val_image_paths_shuf_1000
Expand Down
16 changes: 11 additions & 5 deletions examples/sam2_amg_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from server import model_type_to_paths
from server import MODEL_TYPES_TO_MODEL
from server import set_fast
from server import set_aot_fast
from server import load_aot_fast
from server import set_furious
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
Expand All @@ -22,17 +24,20 @@ def main_docstring():
"""


def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
device = "cuda"
sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
if verbose:
print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
if fast:
set_fast(mask_generator)
if furious:
set_furious(mask_generator)
if load_fast:
load_aot_fast(mask_generator, load_fast)
if fast:
set_fast(mask_generator, load_fast)

image_tensor = file_bytes_to_image_tensor(input_bytes)
if verbose:
print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
Expand All @@ -50,7 +55,7 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102
buf.seek(0)
return buf.getvalue()

def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False):
def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
input_bytes = bytearray(open(input_path, 'rb').read())
output_bytes = main_headless(checkpoint_path,
model_type,
Expand All @@ -59,7 +64,8 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=
output_format=output_format,
verbose=verbose,
fast=fast,
furious=furious)
furious=furious,
load_fast=load_fast)
with open(output_path, "wb") as file:
file.write(output_bytes)

Expand Down
Loading

0 comments on commit f9d0aa1

Please sign in to comment.