diff --git a/.github/workflows/ruff_linter.yml b/.github/workflows/ruff_linter.yml index 027279721..2e2f45bad 100644 --- a/.github/workflows/ruff_linter.yml +++ b/.github/workflows/ruff_linter.yml @@ -22,7 +22,7 @@ jobs: permissions: contents: write pull-requests: write - + strategy: matrix: python-version: ["3.9"] @@ -33,35 +33,35 @@ jobs: PR_URL=${{ github.event.inputs.pr_url }} PR_NUMBER=$(echo $PR_URL | grep -oE '[0-9]+$') echo "PR_NUMBER=$PR_NUMBER" >> $GITHUB_ENV - + - uses: actions/checkout@v3 if: github.event_name == 'workflow_dispatch' with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} - + - name: Checkout PR branch if: github.event_name == 'workflow_dispatch' run: | gh pr checkout ${{ env.PR_NUMBER }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - + - uses: actions/checkout@v3 if: github.event_name != 'workflow_dispatch' with: fetch-depth: 0 - + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - + - name: Install dependencies run: | python -m pip install --upgrade pip pip install ruff==0.6.8 - + - name: Regular lint check if: github.event_name != 'workflow_dispatch' run: | @@ -69,7 +69,7 @@ jobs: # --isolated is used to skip the allowlist at all so this applies to all files # please be careful when using this large changes means everyone needs to rebase ruff check --isolated --select F821,F823,W191 - ruff check --select F,I + ruff check ruff format --check || { echo "Ruff check failed, please try again after running 'ruff format'." exit 1 @@ -80,11 +80,11 @@ jobs: run: | git config --global user.name 'github-actions[bot]' git config --global user.email 'github-actions[bot]@users.noreply.github.com' - + # Apply fixes - ruff check --select F,I --fix + ruff check --fix ruff format . - + # Commit and push if there are changes if [[ -n "$(git status --porcelain)" ]]; then git add . diff --git a/ruff.toml b/ruff.toml index b20cab030..763a9161e 100644 --- a/ruff.toml +++ b/ruff.toml @@ -20,4 +20,5 @@ include = [ "test/prototype/low_bit_optim/**.py", ] +lint.select = ["F", "I"] lint.ignore = ["E731"] diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 205ba9129..177c35704 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -15,25 +15,25 @@ import torch from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - get_symmetric_quantization_config, XNNPACKQuantizer, + get_symmetric_quantization_config, ) from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase from torchao import quantize_ -from torchao._models.llama.model import prepare_inputs_for_model, Transformer +from torchao._models.llama.model import Transformer, prepare_inputs_for_model from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import AffineQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.quant_api import ( + Quantizer, + TwoStepQuantizer, _replace_with_custom_fn_if_matches_filter, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, - Quantizer, - TwoStepQuantizer, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.subclass import (