diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 3008e2e18..1f52b1cb6 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -23,10 +23,8 @@ jobs: pip install --upgrade pip pip install -r docs/requirements.txt pip install types-docutils types-setuptools tqdm types-tabulate - if [ -f requirements.txt ]; then pip install -r requirements.txt --index-url https://download.pytorch.org/whl/cpu; fi - pip install torchvision --index-url https://download.pytorch.org/whl/cpu - pip install git+https://github.com/pbelevich/transformers.git@compatible_with_pt_master - pip install "black<23" pylint==v3.0.0a5 mypy==v0.960 flake8==3.8.2 pyre-check==0.9.15 ufmt==2.1.0 + if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + pip install "black<23" pylint==v3.0.0a5 mypy==v0.981 flake8==3.8.2 pyre-check==0.9.15 ufmt==2.1.0 - name: Static Analysis Checks if: always() - run: ./check.sh --keep-going + run: ./check.sh diff --git a/.github/workflows/pippy_gpu_tests.sh b/.github/workflows/pippy_gpu_tests.sh deleted file mode 100755 index 33f9d487f..000000000 --- a/.github/workflows/pippy_gpu_tests.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash - -set -x - -# Print test options -echo "REPLICATE: ${REPLICATE}" -echo "SCHEDULE: ${SCHEDULE}" - -nvidia-smi -nvcc --version -which python3 -python3 --version -which pip3 -pip3 --version - -# Install git -apt-get update -apt-get install git -y - -# Install dependencies -# Turn off progress bar to save logs -pip3 config set global.progress_bar off -pip3 install flake8 pytest pytest-cov numpy -if [ -f requirements.txt ]; then pip3 install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html; fi - -# Install pavel's huggingface fork -pip3 install git+https://github.com/huggingface/transformers.git@main sentencepiece - -# Install pippy -python3 setup.py install - -set -ex - -# Run all integration tests -python3 test/local_test_forward.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_forward_backward.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_compile.py -s ${SCHEDULE} -python3 examples/hf/gpt2/pippy_gpt2.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 examples/gspmd/pippy_gspmd.py -s ${SCHEDULE} - -# Run flaky integration tests -python3 test/local_test_ddp.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_forward_hf_gpt2.py --replicate ${REPLICATE} -s ${SCHEDULE} -python3 test/local_test_forward_hf_bert.py --replicate ${REPLICATE} -s ${SCHEDULE} diff --git a/.github/workflows/pippy_tests.yaml b/.github/workflows/pippy_tests.yaml index a0f9978c2..b79122078 100644 --- a/.github/workflows/pippy_tests.yaml +++ b/.github/workflows/pippy_tests.yaml @@ -21,28 +21,26 @@ concurrency: jobs: - pytest_tests: - runs-on: linux.4xlarge - strategy: - matrix: - python-version: ["3.8", "3.9"] - container: - image: python:${{ matrix.python-version }} + # pytest_tests: + # runs-on: linux.4xlarge + # strategy: + # matrix: + # python-version: ["3.8", "3.9"] + # container: + # image: python:${{ matrix.python-version }} - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov pytest-xdist numpy - if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - - name: Install pippy - run: "python setup.py install" - - name: Test with pytest - run: | - pytest --cov=pippy --ignore=test/hf_test.py --ignore=test/test_fx.py --ignore=test/test_fx_experimental.py --ignore=test/fx test/ + # steps: + # - uses: actions/checkout@v2 + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # pip install flake8 pytest pytest-cov pytest-xdist numpy + # if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi + # - name: Install pippy + # run: "python setup.py install" + # - name: Test with pytest + # run: | + # pytest --cov=pippy test/ # hf_model_tests: # runs-on: linux.12xlarge @@ -76,10 +74,8 @@ jobs: runs-on: linux.4xlarge strategy: matrix: - python-version: ["3.8", "3.9"] - replicate: ["0", "1"] - schedule: ["FillDrain", "1F1B"] - checkpoint: [ "0", "1" ] + python-version: ["3.9"] + schedule: ["FillDrain"] env: OMP_NUM_THREADS: "1" container: @@ -92,30 +88,26 @@ jobs: python -m pip install --upgrade pip pip install flake8 pytest pytest-cov numpy datasets evaluate scikit-learn sacrebleu if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pavel's huggingface fork - run: pip install git+https://github.com/huggingface/transformers.git@main sentencepiece six sacremoses - name: Install pippy run: "python setup.py install" + - name: Test forward pipe generation + run: python test/test_pipe.py + - name: Test backward pipe generation + run: python test/test_pipe_bwd.py - name: Run forward-only integration test - run: python test/local_test_forward.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run forward-only-auto-parallel integration test - run: python test/local_test_forward_auto_parallel.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + run: torchrun --nproc-per-node 4 test/test_fwd.py - name: Run forward-loss-backward integration test - run: python test/local_test_forward_backward.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run null_coalesce_accumulate integration test - run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - - name: Run PP + DDP test - run: python test/local_test_ddp.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + run: torchrun --nproc-per-node 4 test/test_bwd.py --schedule ${{ matrix.schedule }} + # - name: Run null_coalesce_accumulate integration test + # run: python test/local_test_null_coalesce_accumulate.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} + # - name: Run PP + DDP test + # run: python test/local_test_ddp.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} #- name: Run HF BERT forward-only integration test - # run: python test/local_test_forward_hf_bert.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run HF GPT2 forward-only integration test - run: python test/local_test_forward_hf_gpt2.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} - - name: Run visualizer test - run: python test/local_test_visualizer.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - - name: Run auto-split test - run: python test/local_test_autosplit.py --replicate ${{ matrix.replicate }} -s ${{ matrix.schedule }} - - name: Run compile test - run: python test/local_test_compile.py -s ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + # run: python test/local_test_forward_hf_bert.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + # - name: Run HF GPT2 forward-only integration test + # run: python test/local_test_forward_hf_gpt2.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} --checkpoint ${{ matrix.checkpoint }} + # - name: Run auto-split test + # run: python test/local_test_autosplit.py --replicate ${{ matrix.replicate }} --schedule ${{ matrix.schedule }} # hf_examples_set1: # runs-on: linux.12xlarge @@ -145,11 +137,11 @@ jobs: # git submodule update --init test/minGPT # python test/min_gpt_tracing.py # - name: Run GPT2 example - # run: python examples/hf/gpt2/pippy_gpt2.py -s ${{ matrix.schedule }} + # run: python examples/hf/gpt2/pippy_gpt2.py --schedule ${{ matrix.schedule }} # - name: Run BERT example - # run: python examples/hf/bert/pippy_bert.py -s ${{ matrix.schedule }} + # run: python examples/hf/bert/pippy_bert.py --schedule ${{ matrix.schedule }} # - name: Run T5 example - # run: python examples/hf/t5/pippy_t5.py -s ${{ matrix.schedule }} + # run: python examples/hf/t5/pippy_t5.py --schedule ${{ matrix.schedule }} # - name: "HF Translation: fine-tune T5 model translation English to Romanian" # run: > # python examples/hf/translation/run_translation.py --model_name_or_path t5-small --do_train --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --dataset_name wmt16 --dataset_config_name ro-en --output_dir /tmp/tst-translation --per_device_train_batch_size=8 --per_device_eval_batch_size=8 --overwrite_output_dir --predict_with_generate --max_steps=10 --dp_group_size=1 --pp_group_size=8 @@ -186,84 +178,6 @@ jobs: # - name: "HF Text classification: fine-tune BERT on the GLUE benchmark" # run: python examples/hf/text-classification/run_glue.py --dp_group_size=2 --pp_group_size=8 --model_name_or_path bert-base-cased --task_name mrpc --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3 --output_dir /tmp/mrpc/ --max_steps=3 --overwrite_output_dir - integration_test_gpu: - runs-on: linux.16xlarge.nvidia.gpu - strategy: - matrix: - python-version: ["3.8"] - replicate: ["0", "1"] - schedule: ["FillDrain", "1F1B"] - env: - DOCKER_IMAGE: qts8n/cuda-python:devel - PIPPY_ROOT: /PiPPy - OMP_NUM_THREADS: "1" - REPLICATE: ${{ matrix.replicate }} - SCHEDULE: ${{ matrix.schedule }} - - steps: - - name: Clean working directory - shell: bash - run: | - sudo rm -rf /home/ec2-user/actions-runner/_work/PiPPy/PiPPy/* || true - - uses: actions/checkout@v2 - - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: pytorch/test-infra/.github/actions/setup-nvidia@main - - name: Pull Docker image - run: | - retry () { - "$@" || (sleep 1 && "$@") || (sleep 2 && "$@") - } - retry docker pull "${DOCKER_IMAGE}" - - name: Test docker run - run: | - set -x - # shellcheck disable=SC2086,SC2090 - container_name=$(docker run \ - --gpus all \ - --shm-size=1g --ulimit memlock=-1 \ - -e OMP_NUM_THREADS \ - -e REPLICATE \ - -e SCHEDULE \ - --tty \ - --detach \ - -v "$(pwd):${PIPPY_ROOT}" \ - -w "${PIPPY_ROOT}" \ - "${DOCKER_IMAGE}" - ) - # Run GPU tests and return error signal from docker - docker exec -t -w "${PIPPY_ROOT}" "${container_name}" bash -c "bash .github/workflows/pippy_gpu_tests.sh; exit \$?" - - name: Chown workspace - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd):${PIPPY_ROOT}" -w "${PIPPY_ROOT}" "${DOCKER_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Kill containers, clean up images - if: always() - run: | - # ignore expansion of "docker ps -q" since it could be empty - # shellcheck disable=SC2046 - docker stop $(docker ps -q) || true - # Prune all of the docker images - docker system prune -af - - programming_model_tests: - runs-on: linux.4xlarge - strategy: - matrix: - python-version: ["3.9"] - container: - image: python:${{ matrix.python-version }} - - steps: - - uses: actions/checkout@v2 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install numpy datasets evaluate scikit-learn sacrebleu - if [ -f requirements.txt ]; then pip install --pre -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi - - name: Install pippy - run: "python setup.py install" - - name: Test PiPPy + Dynamo example - run: python examples/TorchDynamo/pippy_dynamo.py - - name: Run PiPPy in GSPMD style - run: python examples/gspmd/pippy_gspmd.py + # TODO: + # Update GPU test to use template in: + # https://github.com/pytorch/test-infra/wiki/Writing-generic-CI-jobs diff --git a/check.sh b/check.sh index 6f74be023..c20c968bf 100755 --- a/check.sh +++ b/check.sh @@ -4,7 +4,7 @@ function usage() { echo 2>&1 < torch.fx.GraphModule: + logger.info("[PiPPy] Tracing model ...") + try: + torch._dynamo.allow_in_graph(pipe_split) + traced: torch.fx.GraphModule = torch._export._export_to_torch_ir( + mod, + example_args, + example_kwargs, + constraints, + ) + if split_policy is not None: + traced = split_policy(traced) + finally: + torch._dynamo.disallow_in_graph(pipe_split) + return traced + @staticmethod def from_tracing( mod: torch.nn.Module, - multi_use_param_spec: Optional[MultiUseParamSpec] = None, - tracer=None, - output_loss_value_spec=None, - deep_copy_module=False, + num_chunks: int, + example_args: Tuple[Any, ...], + example_kwargs: Optional[Dict[str, Any]] = None, split_policy: Optional[ - Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule] + Callable[[fx.GraphModule], fx.GraphModule] ] = None, - return_to_0: bool = True, - **kwargs, + args_chunk_spec: Optional[Tuple[Any, ...]] = None, + kwargs_chunk_spec: Optional[Dict[str, Any]] = None, + output_chunk_spec=None, + constraints: Optional[List[Constraint]] = None, ): - # TODO: abstract partitioning policy + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + if output_chunk_spec is not None: + output_loss_value_spec = fx.node.map_aggregate( + output_chunk_spec, lambda v: isinstance(v, LossReducer) + ) - global _pipeline_tracer - old__pipeline_tracer = _pipeline_tracer - _pipeline_tracer = tracer or pippy.fx.Tracer() - try: - # TODO: tracing policy - if deep_copy_module: - mod = copy.deepcopy( - mod - ) # because further pipe building activities can modify mod - graph = _pipeline_tracer.trace(mod, **kwargs) - if isinstance(graph, torch_fx.Graph): - # HACK to convert torch.fx.Graph to pippy.fx.Graph - g_new = pippy.fx.Graph() - val_map: Dict[pippy.fx.Node, pippy.fx.Node] = {} - out = g_new.graph_copy(graph, val_map, False) - g_new.output(out) - - # `pippy.fx.map_arg` doesn't work on torch.fx.Node instances; - # do it here - def remap_vals(n): - return val_map[n] - - for node in g_new.nodes: - node.args = torch_fx.map_arg(node.args, remap_vals) - node.kwargs = torch_fx.map_arg(node.kwargs, remap_vals) - graph = g_new - - traced = pippy.fx.GraphModule(mod, graph) - finally: - _pipeline_tracer = old__pipeline_tracer + # Get split example inputs + if example_kwargs is None: + # Needed by `split_args_kwargs_into_chunks` + example_kwargs = {} + + args_split, kwargs_split = split_args_kwargs_into_chunks( + example_args, + example_kwargs, + num_chunks, + args_chunk_spec, + kwargs_chunk_spec, # TODO: merge into args_chunk_spec + ) - if split_policy is not None: - traced = split_policy(traced) + # Trace with export + traced = Pipe._trace_with_export( + mod, + example_args=args_split[0], + example_kwargs=kwargs_split[0], + constraints=constraints, + split_policy=split_policy, + ) - return Pipe._from_traced( + pipe = Pipe._from_traced( mod, traced, multi_use_param_spec, output_loss_value_spec=output_loss_value_spec, - return_to_0=return_to_0, ) + logger.info(pipe.split_gm) + if PIPPY_VERBOSITY == "DEBUG": + pipe.split_gm.graph.print_tabular() + + pipe.num_chunks = num_chunks + pipe.args_chunk_spec = args_chunk_spec + pipe.kwargs_chunk_spec = kwargs_chunk_spec + pipe.output_chunk_spec = output_chunk_spec + + # Shape propagation to get shapes of all tensors + PipeFakeTensorProp(pipe.split_gm).run() + for node in pipe.split_gm.graph.nodes: + logger.debug( + f"{node.name}, " + f"{node.meta['example_value'] if 'example_value' in node.meta else 'None'}", + ) + + return pipe + def __str__(self): return self.split_gm.__str__() def __repr__(self): return self.split_gm.__repr__() - # Conditoinal variable to ensure `defer_stage_init` is called before other callers call `materialize_stage` - # TODO: cleaner approach - _stage_init_lock = threading.Lock() - stage_init_cv = threading.Condition(_stage_init_lock) - - def defer_stage_init( - self, - device: torch.device, - index_filename: Union[str, os.PathLike] = None, - dtype: torch.dtype = None, - checkpoint_prefix: str = None, - ): - def materialize_stage(target: str) -> torch.nn.Module: - logging.info(f"Materializing {target} on {device}") - submodule = self.split_gm.get_submodule(target) - if index_filename is not None: - submodule = load_checkpoint( - model=submodule, - index_filename=index_filename, - device=device, - dtype=dtype, - checkpoint_prefix=checkpoint_prefix, - ) - try: - submodule.to(device) - except Exception: - # Usually `to(device)` fails because there is still some meta - # tensor in submodule, potentially because the checkpoint load - # did not cover that parameter. And the reason is often that - # that parameter shares weight with another parameter. - for name, param in submodule.named_parameters(): - if param.device == torch.device("meta"): - logging.warning(f"{name} is a meta tensor") - # Re-throw the original exception - raise - return submodule - - with Pipe.stage_init_cv: - setattr(Pipe, "materialize_stage", materialize_stage) - Pipe.stage_init_cv.notify() - - @staticmethod - def is_stage_init_deferred(): - return hasattr(Pipe, "materialize_stage") - - def export(self, stage_id: int) -> torch.nn.Module: - split_gm_children = list(self.split_gm.children()) - submod = split_gm_children[stage_id] - - # HACK: reusing defer init path in PipelineDriver - def exported_stage(target: str) -> torch.nn.Module: - logging.info(f"Retrieving exported {target}") - assert self.split_gm.get_submodule(target) is submod - return submod - - with Pipe.stage_init_cv: - if not hasattr(Pipe, "materialize_stage"): - setattr(Pipe, "materialize_stage", exported_stage) - Pipe.stage_init_cv.notify() - - return submod - class PipeSplitWrapper(torch.nn.Module): class SplitPoint(Enum): @@ -1198,18 +1157,30 @@ def annotate_split_points( setattr(predecessor_module, atoms[-1], wrapped_mod) -class PiPPyShapeProp(shape_prop.ShapeProp): +class PipeFakeTensorProp(Interpreter): def __init__( - self, module: pippy.fx.GraphModule, garbage_collect_values: bool = True + self, module: fx.GraphModule, garbage_collect_values: bool = True ): super().__init__(module, garbage_collect_values) self.stop_prop = False - def run_node(self, n: pippy.fx.Node) -> Any: - if (n.op, n.target) == ("call_function", stage_backward): + def run(self): + inp = tuple( + node.meta["val"] + for node in self.module.graph.nodes + if node.op == "placeholder" + ) + super().run(*inp) + + def run_node(self, node): + # Do not propagate through the stage backward call because it won't work + if (node.op, node.target) == ("call_function", stage_backward): self.stop_prop = True if self.stop_prop: return None - return super().run_node(n) + res = super().run_node(node) + node.meta["example_value"] = res + node.meta["val"] = res + return res diff --git a/pippy/ModelSplit.py b/pippy/ModelSplit.py index 74490b86c..1fd5e32fe 100644 --- a/pippy/ModelSplit.py +++ b/pippy/ModelSplit.py @@ -3,8 +3,8 @@ from typing import Callable, Dict, List, Tuple import torch +import torch.fx as fx -import pippy.fx from pippy.IR import pipe_split """ @@ -16,13 +16,13 @@ def _analyze_node_size( - gm: pippy.fx.GraphModule, -) -> Dict[pippy.fx.Node, Dict[str, int]]: + gm: fx.GraphModule, +) -> Dict[fx.Node, Dict[str, int]]: # state_dict helps us to get parameter sizes state_dict = gm.state_dict() # Function Parameter Usage - node_param_sizes: Dict[pippy.fx.Node, Dict[str, int]] = {} + node_param_sizes: Dict[fx.Node, Dict[str, int]] = {} for node in gm.graph.nodes: if node.op == "get_attr": # a parameter node param_name = node.target @@ -53,7 +53,7 @@ def _analyze_node_size( """ Split a model based on a maximum number of parameter and buffer elements a pipeline stage can have Input: - gm: `pippy.fx.GraphModule` to split + gm: `fx.GraphModule` to split threshold: maximum number of parameter and buffer elements a stage can have max_stages: maximum number of stages; default = -1, no limit Output: @@ -64,15 +64,15 @@ def _analyze_node_size( def _split_on_size_threshold_with_max_stages( - gm: pippy.fx.GraphModule, + gm: fx.GraphModule, threshold: int, max_stages: int = -1, -) -> Tuple[pippy.fx.GraphModule, int]: +) -> Tuple[fx.GraphModule, int]: # Analyze size of parameters/buffers used by each node in the graph node_param_sizes = _analyze_node_size(gm) # Record split positions - insert_before_nodes: List[pippy.fx.Node] = [] + insert_before_nodes: List[fx.Node] = [] def new_stage_before(node): insert_before_nodes.append(node) @@ -150,10 +150,10 @@ def new_stage_before(node): def split_on_size_threshold( threshold: int, -) -> Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule]: +) -> Callable[[fx.GraphModule], fx.GraphModule]: def _split_on_size_threshold( - gm: pippy.fx.GraphModule, - ) -> pippy.fx.GraphModule: + gm: fx.GraphModule, + ) -> fx.GraphModule: gm, _ = _split_on_size_threshold_with_max_stages(gm, threshold) return gm @@ -172,10 +172,10 @@ def _split_on_size_threshold( def split_into_equal_size( nstages: int = 1, -) -> Callable[[pippy.fx.GraphModule], pippy.fx.GraphModule]: +) -> Callable[[fx.GraphModule], fx.GraphModule]: def _split_into_nstages_equal_size( - gm: pippy.fx.GraphModule, - ) -> pippy.fx.GraphModule: + gm: fx.GraphModule, + ) -> fx.GraphModule: param_size = 0 for param in gm.parameters(): param_size += param.numel() diff --git a/pippy/PipelineDriver.py b/pippy/PipelineDriver.py deleted file mode 100644 index 26dcda518..000000000 --- a/pippy/PipelineDriver.py +++ /dev/null @@ -1,2281 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import logging -import operator -import threading -import time -import warnings -from enum import Enum -from inspect import Parameter, Signature -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -import torch.distributed.rpc as rpc - -import pippy.fx -from pippy.backward import ( - _null_coalesce_accumulate, - stage_backward, - sync_barrier, -) -from pippy.events import Allocator, Event, EventRecorder, EventsContext -from pippy.fx.passes import shape_prop - -from pippy.IR import Pipe -from pippy.microbatch import ( - gen_output_chunk_spec, - LossReducer, - merge_chunks, - split_args_kwargs_into_chunks, - sum_reducer, -) -from pippy.utils import flatten_args_detach - -# TODO: Define the strategy for replicating the computation. In particular, we will likely make the assumption -# that the operations in the program are batch-wise commutative (my term), i.e. we can guarantee equivalence -# with splitting up the operation along the batch dimension, applying the computation to those sub-batches, -# then merging them back together via concatenation. We should provide a crisp contract surrounding this - -# ===== Questions to Answer ===== -# 1. When does each stage happen? -# micro-batch splitting: per-invocation or with one fixed chunk size? -# physical compilation: this depends on micro-batch splitting (for e.g. scheduling -# so it would have to be ordered after micro-batch splitting -# runtime: obviously needs to happen at runtime -# -# Conceptually: -# -# replicated_programs : List[IR] = replicate(chunks) -# schedule : List[IR] = schedule(replicated_programs) -# for device_schedule in schedule: -# for instruction in device_schedule: -# invoke(rank, instruction) -# -# `chunks` is the only external dependency that could potentially be used per-invocation. -# Do we want to: -# a) Take it as a per-invocation parameter and re-do compilation each time? (-overhead) -# b) Take it as a one-time initialization parameter and consistently split each -# batch into a single `chunks` value (-flexibility) -# c) Allow it to be dynamic but cache compiled policies? -# -# Decision: We can easily convert (a) to (c), so let's go with (a). - -DEBUG = False - - -class Phase(Enum): - FORWARD = 0 - BACKWARD = 1 - ACCUMULATE_GRAD = 2 - SYNC_BARRIER = 3 - - -# TODO: do we need this? -class SchedState(Enum): - WAITING = 0 - READY = 1 - RUNNING = 2 - DONE = 3 - - -def event_name(ph, stage_id, mbid): - phase_to_short_str = { - Phase.FORWARD: "F", - Phase.BACKWARD: "B", - Phase.ACCUMULATE_GRAD: "A", - Phase.SYNC_BARRIER: "S", - } - return f"{phase_to_short_str[ph]}_{stage_id},{mbid}" - - -def event_id(ph, stage_id, mbid, bid): - return f"{event_name(ph, stage_id, mbid)},{bid}" - - -def prev_event_name(ph: Any, all_stages: List[int], stage_id: int, mbid: Any): - i = all_stages.index(stage_id) - if ph == Phase.FORWARD and i > 0: - prev_stage = all_stages[i - 1] - return event_name(ph, prev_stage, mbid) - elif ph == Phase.BACKWARD and i < len(all_stages) - 1: - next_stage = all_stages[i + 1] - return event_name(ph, next_stage, mbid) - else: - return None - - -def next_event_name(ph: Any, all_stages: List[int], stage_id: int, mbid: Any): - i = all_stages.index(stage_id) - if ph == Phase.FORWARD and i < len(all_stages) - 1: - next_stage = all_stages[i + 1] - return event_name(ph, next_stage, mbid) - elif ph == Phase.BACKWARD and i > 0: - prev_stage = all_stages[i - 1] - return event_name(ph, prev_stage, mbid) if stage_id > 0 else None - else: - return None - - -class WorkItem: - def __init__( - self, - stage_id, - phase, - args, - kwargs, - future, - microbatch_id, - blocked_args_count, - ready_args, - batch_id, - num_microbatches, - state=SchedState.WAITING, - debug_str="", - ): - args_to_fwd = [ - "stage_id", - "phase", - "args", - "kwargs", - "future", - "microbatch_id", - "blocked_args_count", - "ready_args", - "batch_id", - "num_microbatches", - "state", - "debug_str", - ] - - for arg in args_to_fwd: - setattr(self, arg, locals()[arg]) - - stage_id: int - phase: Phase - args: Tuple[Any] - kwargs: Dict[str, Any] - future: torch.futures.Future - microbatch_id: int - - blocked_args_count: int - ready_args: Dict[int, Any] - state: SchedState - debug_str: str - - batch_id: int - num_microbatches: int - - def __str__(self): - return f"WorkItem({self.debug_str})" - - -class ValueReference: - def __init__(self, stage_id, unique_key): - self.stage_id = stage_id - self.unique_key = unique_key - self.meta: Dict[str, Any] = {} - - stage_id: int - unique_key: str - - def __repr__(self): - return f"ValueReference({self.stage_id}, {self.unique_key})" - - -class RefcountedFuture: - future: torch.futures.Future - refcount: int - - def __init__(self, future, refcount): - self.future, self.refcount = future, refcount - - def release(self): - """ - Decrement refcount by 1. Return True if this instance should be freed - """ - assert ( - self.refcount != 0 - ), "Detected reference counting inconsistency. Please report a bug to PiPPy" - self.refcount -= 1 - return self.refcount == 0 - - -class RankWorker(EventRecorder): - """ - RankWorker is the underlying WorkItem processing engine for pipeline stages - resident on this rank. WorkItems of multiple stages would share the same - queue in the RankWorker. RankWorker will also maintain states like the - number of outstanding WorkItems. - - * TODO: in-order execution - * Queueing of jobs and execution schedule, e.g. - * Static Schedules - * Fill-drain (GPipe) pipeline by serializing jobs - * TODO: 1F1B scheduling by serializing jobs and stalling for a specific - phase to come through - * TODO: Interleaved 1F1B (TODO: how to set up these data dependencies) - * Dynamic Schedules - * TODO: Varuna dynamic schedule - * TODO: dynamic scheduling via registers and back-pressure (TODO: how to - specify resource limits and how to implement backpressure?) - """ - - def __init__( - self, - rank, - all_stages, - max_outstanding=None, - pp_rank=None, - _record_mem_dumps=False, - checkpoint=False, - ): - logging.info(f"[{rank}] Instantiating RankWorker") - self.rank = rank - self.all_stages = all_stages - self.rank = rank - self.pp_rank = pp_rank - self._record_mem_dumps = _record_mem_dumps - self.checkpoint = checkpoint - - # Maximum outstanding micro-batches of the pipeline schedule - self.max_outstanding = max_outstanding - # Keeps track of the outstanding micro-batches in current rank executor - self.outstanding = 0 - self.stage_executors: Dict[int, PipeStageExecutor] = {} - self.events: List[Event] = [] - - self.waiting_runlist_lock = threading.Lock() - # self.waiting_runlist (*and the contained WorkItems*) are guarded by - # self.waiting_runlist_lock - self.waiting_runlist: Dict[str, WorkItem] = {} - - self.ready_runlist_lock = threading.Lock() - self.ready_runlist_cv = threading.Condition(self.ready_runlist_lock) - self.ready_runlist: Dict[str, WorkItem] = {} - - self.worker_thread = threading.Thread( - target=self.worker_loop, name=f"worker_{self.rank}", daemon=True - ) - self.worker_thread.start() - - def create_stage_executor(self, stage_id, mod, mod_name): - if stage_id in self.stage_executors: - raise AssertionError( - f"Rank {self.rank} already has stage {stage_id}" - ) - - assert ( - mod is not None or mod_name is not None - ), "PipeStageExecutor requires mod or mod_name" - - if mod is None: - with Pipe.stage_init_cv: - defer_called = Pipe.stage_init_cv.wait_for( - Pipe.is_stage_init_deferred, - timeout=100, # stop waiting after 100s - ) - if not defer_called: - raise AssertionError( - f"Rank {self.rank} did not defer stage {stage_id} initialization " - f"though pipeline driver expect it to do so." - ) - - self.stage_executors[stage_id] = PipeStageExecutor( - stage_id=stage_id, - mod=mod or Pipe.materialize_stage(mod_name), # type: ignore[attr-defined] - rank_worker=self, - _record_mem_dumps=self._record_mem_dumps, - ) - return self.stage_executors[stage_id] - - def enqueue_ready_runlist(self, unique_key, work_item): - with self.ready_runlist_cv: - logging.debug( - f"[{self.rank}] Current ready runlist keys: {self.ready_runlist.keys()}" - ) - self.ready_runlist[unique_key] = work_item - self.ready_runlist_cv.notify() - - def enqueue_waiting_runlist(self, unique_key, work_item): - with self.waiting_runlist_lock: - logging.debug( - f"[{self.rank}] Current waiting runlist keys: {self.waiting_runlist.keys()}" - ) - assert ( - unique_key not in self.waiting_runlist - ), f"key {unique_key} already in waiting runlist {self.waiting_runlist}" - self.waiting_runlist[unique_key] = work_item - - def worker_loop(self): - batch_id_to_remaining_backward_microbatches: Dict[int, int] = {} - while True: - work_item = None - with self.ready_runlist_cv: - while len(self.ready_runlist) == 0: - self.ready_runlist_cv.wait() - - logging.debug( - f"[{self.rank}] Dequeueing workitem from set of {len(self.ready_runlist)}" - ) - # TODO: extra priorities - for key in iter(self.ready_runlist.keys()): - # Skip forward work items if we hit the max outstanding limit - # If there are no other READY WorkItems, the runloop wraps around to the beginning and blocks again, - # waiting for another scheduled WorkItem to wake it back up. This works because the only condition - # that can schedule a WAITING Workitem is if another backward WorkItem executes and reduces the number - # of outstanding mciro-batches; - # If there are other READY WorkItems, the runloop executes as normally processing those - if ( - self.ready_runlist[key].phase == Phase.FORWARD - and self.max_outstanding is not None - and self.outstanding >= self.max_outstanding - ): - continue - work_item = self.ready_runlist.pop(key) - break - - # We may not fetch any actionable work item in the above loop, go - # back to the loop in this case - if work_item is None: - continue - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Got WorkItem {work_item}" - ) - - work_item.state = SchedState.RUNNING - args_value_refs = work_item.args - kwargs_value_refs = work_item.kwargs - future = work_item.future - microbatch_id = work_item.microbatch_id - ready_args = work_item.ready_args - phase = work_item.phase - try: - stage_executor = self.stage_executors[work_item.stage_id] - except KeyError: - raise RuntimeError( - f"Rank {self.rank} does not have stage {work_item.stage_id}" - f"Current keys {self.stage_executors.keys()}" - ) - - batch_id = work_item.batch_id - num_microbatches = work_item.num_microbatches - - if batch_id not in batch_id_to_remaining_backward_microbatches: - batch_id_to_remaining_backward_microbatches[ - batch_id - ] = num_microbatches - - start_ts = time.time() - name = event_name( - work_item.phase, work_item.stage_id, work_item.microbatch_id - ) - id = event_id( - work_item.phase, - work_item.stage_id, - work_item.microbatch_id, - work_item.batch_id, - ) - if self._record_mem_dumps: - stage_executor._record_dumps_on_all_peer_executors( - f"M{id}_start", start_ts - ) - - value_ref_arg_idx = 0 - - def retrieve_value_ref_args_by_idx(a): - if isinstance(a, ValueReference) and a.unique_key != "noop": - nonlocal value_ref_arg_idx - val = ready_args[value_ref_arg_idx] - value_ref_arg_idx += 1 - return val - else: - return a - - args = pippy.fx.node.map_aggregate( - args_value_refs, retrieve_value_ref_args_by_idx - ) - kwargs = pippy.fx.node.map_aggregate( - kwargs_value_refs, retrieve_value_ref_args_by_idx - ) - - def forward(args, kwargs, no_grad): - args, flat_args = flatten_args_detach(args) - kwargs, flat_kwargs = flatten_args_detach(kwargs) - # Contains all tensors from args and kwargs, in flattened form - flat_args += flat_kwargs - - logging.info( - f"[{self.rank}] Running forward module for microbatch {work_item.microbatch_id}" # type: ignore[union-attr] - ) - - def forward_maybe_with_ddp(args, kwargs): - if isinstance( - stage_executor.mod, - torch.nn.parallel.distributed.DistributedDataParallel, - ): - with stage_executor.mod.no_sync(): # type: ignore[operator] - out_val = stage_executor.mod(*args, **kwargs) - else: - out_val = stage_executor.mod(*args, **kwargs) - return out_val - - def set_requires_grad(a): - if isinstance(a, torch.Tensor) and a.is_floating_point(): - a.requires_grad_(True) - return a - - def dont_traverse_size(a): - return type(a) != torch.Size - - if no_grad: - with torch.no_grad(): - out_val = forward_maybe_with_ddp(args, kwargs) - out_val = pippy.fx.node.map_aggregate( - out_val, set_requires_grad, dont_traverse_size - ) - else: - with torch.enable_grad(): - out_val = forward_maybe_with_ddp(args, kwargs) - - return out_val, flat_args - - if phase == Phase.BACKWARD: - if self.checkpoint: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running backward phase. " - f"Rerunning forward because of checkpointing" - ) - f_args, f_kwargs = stage_executor.fwd_cache.pop( - microbatch_id - ) - out_val, flat_tensor_args = forward( - f_args, f_kwargs, no_grad=False - ) - kwargs = dict(kwargs) - kwargs["stage_output"], kwargs["input_values"] = ( - out_val if isinstance(out_val, tuple) else (out_val,), - flat_tensor_args, - ) - else: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running backward phase. " - f"Retrieving stashed values" - ) - # HACK: here we are directly accessing the saved tensor outputs - # for closed-over outputs so that they still have the grad_fn - # from local autograd. Can we solve this more elegantly? - kwargs = dict(kwargs) - ( - kwargs["stage_output"], - kwargs["input_values"], - ) = stage_executor.fwd_cache.pop(microbatch_id) - - if work_item.phase == Phase.FORWARD: - self.outstanding += 1 - out_val, flat_tensor_args = forward( - args, kwargs, no_grad=self.checkpoint - ) - if self.checkpoint: - stage_executor.fwd_cache[microbatch_id] = args, kwargs - else: - stage_executor.fwd_cache[microbatch_id] = ( - out_val if isinstance(out_val, tuple) else (out_val,), - flat_tensor_args, - ) - - elif work_item.phase == Phase.BACKWARD: - logging.info( - f"[{self.rank}] Running backward for microbatch {work_item.microbatch_id}" - ) - - batch_id_to_remaining_backward_microbatches[batch_id] -= 1 - - if ( - isinstance( - stage_executor.mod, - torch.nn.parallel.distributed.DistributedDataParallel, - ) - and batch_id_to_remaining_backward_microbatches[batch_id] - == 0 - ): - # HACK: reaching into DDP implementation details here. Is there a better way? - stage_executor.mod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] - list( - torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] - kwargs["stage_output"] - ) - ) - ) - - out_val = stage_backward(*args, **kwargs) - - # Schedule forward stage of a new micro-batch - self.outstanding -= 1 - elif work_item.phase == Phase.ACCUMULATE_GRAD: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running accumulate grad" - ) - out_val = _null_coalesce_accumulate(*args, **kwargs) - elif work_item.phase == Phase.SYNC_BARRIER: - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Running sync_barrier" - ) - out_val = sync_barrier(*args, **kwargs) - else: - assert ( - False - ), f"Unrecognized phase {work_item.phase} encountered in execution" - - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] Populating result of type {type(out_val)} " - f"for {key}" - ) - future.set_result(out_val) - work_item.state = SchedState.DONE - - prev_name = prev_event_name( - work_item.phase, - self.all_stages, - work_item.stage_id, - work_item.microbatch_id, - ) - next_name = next_event_name( - work_item.phase, - self.all_stages, - work_item.stage_id, - work_item.microbatch_id, - ) - finish_ts = time.time() - self.record_event( - rank=self.rank, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type=work_item.phase, - mbid=work_item.microbatch_id, - ) - self.record_event_dependency( - from_id=prev_name, to_id=name, type="transfer" - ) - self.record_event_dependency( - from_id=name, to_id=next_name, type="transfer" - ) - - if self._record_mem_dumps: - stage_executor._record_dumps_on_all_peer_executors( - f"M{id}_finish", finish_ts - ) - - # For work item marked with runlist_key, update its operand list with value - def update_run_list(self, runlist_key, arg_idx, value): - with self.waiting_runlist_lock: - work_item = self.waiting_runlist[runlist_key] - work_item.ready_args[arg_idx] = value - work_item.blocked_args_count -= 1 - if work_item.blocked_args_count == 0: - with self.ready_runlist_cv: - work_item.state = SchedState.READY - self.ready_runlist[runlist_key] = self.waiting_runlist.pop( - runlist_key - ) - self.ready_runlist_cv.notify() - logging.debug( - f"[{self.rank}][{work_item.microbatch_id}] all operands ready: {runlist_key}" - ) - - -class PipeStageExecutor(EventRecorder): - """ - PipeStageExecutor encapsulates the execution semantics of a fragment of - code on a pipeline stage. PipeStageExecutor handles: - - * Ownership of the stage's module and its recursive submodules/parameters - * Serving as an entrypoint for the driver to push jobs into RankWorker's queue - * TODO: gradient checkpointing - """ - - def __init__(self, stage_id, mod, rank_worker, _record_mem_dumps=False): - logging.info(f"Instantiating PipeStageExecutor for stage {stage_id}") - self.stage_id = stage_id - self.mod = mod - self.rank_worker = rank_worker - # map microbatch ID to list of forward tensor args - self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {} - - self.value_store_lock = threading.Lock() - self.value_store_cv = threading.Condition(self.value_store_lock) - self.value_store: Dict[str, RefcountedFuture] = {} - - self.peer_executors: Dict[int, torch._C._distributed_rpc.PyRRef] = None # type: ignore[assignment] - self._record_mem_dumps = _record_mem_dumps - - self.optimizer = None - # Used to ensure optimizer is created before we create learning rate scheduler - self.optim_init_lock = threading.Lock() - self.optim_init_cv = threading.Condition(self.optim_init_lock) - - self.lr_scheduler = None - self.device = self._find_mod_device() - - # Send/recv order normalization - self.callee_send_tag: Dict[int, int] = {} # callee stage: tag seq num - self.caller_recv_tag: Dict[int, int] = {} # caller stage: tag seq num - self.callee_send_tag_lock = threading.Lock() - self.caller_recv_tag_lock = threading.Lock() - self.caller_recv_tag_cv = threading.Condition(self.caller_recv_tag_lock) - - def _find_mod_device(self): - # We assume that all parameters in the module are on the same device - # HACK: we assume the module has at least one parameter - param = next(self.mod.parameters(), None) - buffer = next(self.mod.buffers(), None) - if param is not None: - device = param.device - elif buffer is not None: - device = buffer.device - else: - logging.warning( - f"Module of stage {self.stage_id} has no parameter or buffer, " - f"cannot figure out device. Setting it to cpu" - ) - device = torch.device("cpu") - return device - - def __getstate__(self): - # Adding an empty __getstate__ function here to work around the DDP pickling issue (#153) that occurs when the - # PipelineDiver asks PipeStageExecutors to install_peer_executor(a list of RRefs) - # More elegant solution is needed in CUDAFuture or RPC to avoid pickling when users do not need to transfer - # tensors - pass - - def install_peer_executors(self, peer_executors): - assert self.peer_executors is None - self.peer_executors = peer_executors - return None - - def init_data_parallel(self, n_stages, dp_group_size, dp_pg_cb=None): - worker_rank = self.rank_worker.rank - if dp_pg_cb is not None: - logging.info( - f"Rank[{worker_rank}] stage[{self.stage_id}] Initializing data parallel: " - f"using DP process groups provided by user" - ) - self.mod = torch.nn.parallel.DistributedDataParallel( - self.mod, process_group=dp_pg_cb(self.stage_id) - ) - return - - logging.debug( - f"Rank[{worker_rank}] stage[{self.stage_id}] Initializing data parallel: " - f"creating DP process groups internally" - ) - # Discover DP peers via Store - # HACK: using the Store coming with the default process group - _store = torch.distributed.distributed_c10d._get_default_store() - # Wrap default store by adding a prefix to each key inserted so as not to step into default store's space - store = torch.distributed.PrefixStore("PiPPy", _store) - # TODO: figure out the unique global "stage rank" for Interleaved 1F1B - my_rank = str(worker_rank) - my_stage = str(self.stage_id) - # Each stage rank checks in with their stage id in respective pipe - store.set(my_rank, my_stage) - - # Create a mapping from stage id to DP ranks - stage_to_dp_ranks: Dict[int, List[int]] = {} - for stage in range(n_stages): - stage_to_dp_ranks.setdefault(stage, []) - - # Wait for all stages to check in - world_size = n_stages * dp_group_size - all_ranks = [str(i) for i in range(world_size)] - store.wait(all_ranks) - logging.debug( - f"Rank[{worker_rank}] stage[{self.stage_id}] Initializing data parallel: all stages have checked in" - ) - - # Fill the mapping - for rank in all_ranks: - stage = store.get(rank) # type: ignore[assignment] - stage_to_dp_ranks[int(stage)].append(int(rank)) - - # Create DP process group for each stage - # Note: even if a rank is not in the DP group of another stage, it must still participate in the new_group call of - # that stage; this is required by c10d - for stage in range(n_stages): - dp_group_ranks = stage_to_dp_ranks[stage] - dp_pg_for_stage = torch.distributed.new_group(dp_group_ranks) - if stage == self.stage_id: - logging.info( - f"Rank[{worker_rank}] stage[{self.stage_id}] " - f"DP group {dp_group_ranks} -- init complete" - ) - - # Wrap stage module with DDP using the DP group corresponding to own stage - if self.stage_id == stage: - self.mod = torch.nn.parallel.DistributedDataParallel( - self.mod, process_group=dp_pg_for_stage - ) - - def create_future(self): - # Future constructor does not accept CPU device, must set to None - return torch.futures.Future( - devices=None if self.device.type == "cpu" else [self.device] - ) - - def invoke( - self, - output_unique_key: str, - phase: Phase, - args, - kwargs, - cur_microbatch: int, - debug_str: str, - output_refcount: int, - batch_id: int, - num_microbatches: int, - ): - start_ts = time.time() - target_name = event_name(phase, self.stage_id, cur_microbatch) - target_id = event_id(phase, self.stage_id, cur_microbatch, batch_id) - name = f"R{target_name}" - id = f"R{target_id}" - if self._record_mem_dumps: - self._record_dumps_on_all_peer_executors(f"M{id}_invoke", start_ts) - # TODO: do we need to serialize calls to invoke() to preserve the order in which WorkItems appear for - # static schedules? - - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Received invoke call for {debug_str}" - ) - # Extract all ValueRef arguments so we can spawn asynchronous data transfers - # for each of them - value_ref_args: List[ValueReference] = [] - - def extract_value_ref_args(arg): - if isinstance(arg, ValueReference) and arg.unique_key != "noop": - value_ref_args.append(arg) - - pippy.fx.node.map_aggregate(args, extract_value_ref_args) - pippy.fx.node.map_aggregate(kwargs, extract_value_ref_args) - - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Invoke call found {len(value_ref_args)} ValueReference arguments" - ) - - # Construct WorkItem for this microbatch+phase and record it in the - # waiting runlist - - # We provide device to the Future constructor so that between - # future.set_result() and future.wait() correct dependencies can be - # captured - # We assume the output value is on the same device as the stage's parameters - - # Future constructor does not accept CPU device, must set to None - future: torch.futures.Future = self.create_future() - - # TODO: increase blocked_args_count for extra things like scheduling - work_item = WorkItem( - stage_id=self.stage_id, - phase=phase, - args=args, - kwargs=kwargs, - future=future, - microbatch_id=cur_microbatch, - blocked_args_count=len(value_ref_args), - ready_args={}, - batch_id=batch_id, - num_microbatches=num_microbatches, - debug_str=debug_str, - ) - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Invoke instantiated WorkItem {work_item} with key {output_unique_key}" - ) - if len(value_ref_args) == 0: - # TODO: convert initial input into ValueRef? - # We always put this work item into the ready queue, though we mark - # it with different state flags depending on whether the schedule - # would hold it based on max outstanding allowed - work_item.state = SchedState.READY - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] No RRef arguments. " - f"Scheduling directly as READY workitem" - ) - self.rank_worker.enqueue_ready_runlist(output_unique_key, work_item) - else: - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Scheduling WorkItem as WAITING workitem" - ) - work_item.state = SchedState.WAITING - self.rank_worker.enqueue_waiting_runlist( - output_unique_key, work_item - ) - - # Group Value Ref Args based on source stage - # `callee_stage_dict` has the following structure: - # Dict[callee_stage, Dict[my_arg_idx, value_ref]] - callee_stage_dict: Dict[int, Dict[int, ValueReference]] = {} - for arg_idx, value_ref_arg in enumerate(value_ref_args): - # Check if the ValRef corresponds to a tensor - if "tensor_meta" in value_ref_arg.meta: - callee_stage = value_ref_arg.stage_id - batch_refs = callee_stage_dict.setdefault(callee_stage, {}) - batch_refs[arg_idx] = value_ref_arg - else: - # For non-tensor (e.g. a value or a size vector), we use RPC to spawn asynchronous data transfer - logging.debug( - f"[{self.stage_id}][{cur_microbatch}] Launching RPC data transfer for " - f"ValueReference {arg_idx} {value_ref_arg}" - ) - self.async_transfer( - cur_microbatch, value_ref_arg, arg_idx, output_unique_key - ) - - # For tensors, we use c10d two-sided send/recv - # Batch call per source stage to reduce number of RPC threads - with self.callee_send_tag_lock: - for callee_stage, batch_refs in callee_stage_dict.items(): - value_ref_executor_rref = self.peer_executors[callee_stage] - tag = self.callee_send_tag.setdefault(callee_stage, 0) - self.callee_send_tag[callee_stage] += 1 - value_ref_executor_rref.rpc_async().batch_send( - self.stage_id, - output_unique_key, - cur_microbatch, - batch_refs, - tag, - ) - self.batch_recv( - cur_microbatch, - output_unique_key, - callee_stage, - batch_refs, - tag, - ) - - with self.value_store_cv: - assert output_unique_key not in self.value_store, ( - f"[{self.stage_id}] Output key {output_unique_key} " - f"already exists or is not consumed from previous batch" - ) - self.value_store[output_unique_key] = RefcountedFuture( - future, output_refcount - ) - self.value_store_cv.notify_all() - - finish_ts = time.time() - self.record_event( - rank=self.rank_worker.rank, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="received", - mbid=cur_microbatch, - ) - self.record_event_dependency( - from_id=name, to_id=target_name, type="waiting" - ) - - return ValueReference(self.stage_id, output_unique_key) - - def coalesced_index_value( - self, indices: List[Tuple[str, int, ValueReference, int]] - ): - for index_tuple in indices: - # `output_unique_key` is the key for the indexed output value (single) - # `value_ref.unique_key` is the key for the overall output of current stage (can have multiple values) - (output_unique_key, output_refcount, value_ref, idx) = index_tuple - logging.debug( - f"[{self.stage_id}] Received getitem call: {(output_unique_key, output_refcount, value_ref, idx)}" - ) - with self.value_store_cv: - # TODO: investigate why value reference in the last batch has not been fully consumed - if output_unique_key in self.value_store: - logging.debug( - f"[{self.stage_id}] Indexed value already in store: {(output_unique_key, output_refcount, value_ref, idx)}" - ) - # raise RuntimeError(f'Repeated index value call detected, potentially due to getitem calls not consumed in previous batch') - - # Wait for the future representing the stage output to be created - while value_ref.unique_key not in self.value_store: - self.value_store_cv.wait() - # Now the stage output future is created - refcounted_future = self.value_store[value_ref.unique_key] - - # For the purposes of refcounting, decrement this use - if refcounted_future.release(): - self.value_store.pop(value_ref.unique_key) - - # Create an indexed future that represents a specific output arg - # Here we use an attach functon so that the index passed to the lambda changes in every loop - def attach_index(fut, index): - indexed_fut = fut.then(lambda f: f.value()[index]) - return indexed_fut - - indexed = attach_index(refcounted_future.future, idx) - - # Enqueue the indexed future - # And notify places that may be waiting for it to be created, such as get_value - self.value_store[output_unique_key] = RefcountedFuture( - indexed, output_refcount - ) - self.value_store_cv.notify_all() - - def get_value( - self, - caller_stage, - runlist_key, - microbatch, - value_ref_arg, - ): - callee_stage = value_ref_arg.stage_id - logging.debug( - f"[{callee_stage}][{microbatch}] Executing transfer of value " - f"{value_ref_arg} initiated by stage {caller_stage} for {runlist_key}" - ) - assert ( - callee_stage == self.stage_id - ), "Mismatch between ValueRef and stage executor" - - with self.value_store_cv: - # Waiting for the indexed future for this arg to be created - while value_ref_arg.unique_key not in self.value_store: - self.value_store_cv.wait() - # Now the indexed future is created - refcounted_future = self.value_store[value_ref_arg.unique_key] - - value = refcounted_future.future.wait() - - with self.value_store_lock: - if refcounted_future.release(): - self.value_store.pop(value_ref_arg.unique_key) - - return value - - def async_transfer(self, microbatch, value_ref_arg, arg_idx, runlist_key): - logging.debug( - f"[{self.stage_id}][{microbatch}] Requesting transfer of value {value_ref_arg} " - f"for runlist item {runlist_key} arg_idx {arg_idx}" - ) - callee_stage = value_ref_arg.stage_id - value_ref_executor_rref = self.peer_executors[callee_stage] - - fut = value_ref_executor_rref.rpc_async().get_value( - self.stage_id, - runlist_key, - microbatch, - value_ref_arg, - ) - - def bottom_half(fut): - logging.debug( - f"[{self.stage_id}][{microbatch}] Completing transfer of value {value_ref_arg} " - f"for runlist item {runlist_key} arg_idx {arg_idx}" - ) - value = fut.value() - self.rank_worker.update_run_list(runlist_key, arg_idx, value) - - return fut.then(bottom_half) - - def batch_send( - self, - caller_stage, - runlist_key, - microbatch, - batch_refs, - tag, - ): - # Wait till this batch's turn to send - with self.caller_recv_tag_cv: - self.caller_recv_tag.setdefault(caller_stage, 0) - while self.caller_recv_tag[caller_stage] < tag: - self.caller_recv_tag_cv.wait() - - logging.debug( - f"[{self.stage_id}][{microbatch}] Sending batch {tag} of " - f"{len(batch_refs)} values initiated by stage {caller_stage} for {runlist_key}" - ) - - for _, value_ref_arg in batch_refs.items(): - with self.value_store_cv: - # Waiting for the indexed future for this arg to be created - while value_ref_arg.unique_key not in self.value_store: - self.value_store_cv.wait() - # Now the indexed future is created - refcounted_future = self.value_store[value_ref_arg.unique_key] - - value = refcounted_future.future.wait() - - with self.value_store_lock: - if refcounted_future.release(): - self.value_store.pop(value_ref_arg.unique_key) - - # Instead of return value let's do a send call - if torch.distributed.get_backend() == "gloo": - # Gloo P2P does not support work.get_future, so we use send instead - torch.distributed.send(value, caller_stage, tag=tag) - else: - torch.distributed.isend(value, caller_stage, tag=tag) - - # Notify next send that's potentially waiting - with self.caller_recv_tag_cv: - self.caller_recv_tag[caller_stage] += 1 - self.caller_recv_tag_cv.notify_all() - - def batch_recv( - self, microbatch, runlist_key, callee_stage, batch_refs, tag - ): - logging.debug( - f"[{self.stage_id}][{microbatch}] Receiving batch {tag} of {len(batch_refs)} values " - f"for runlist item {runlist_key} from stage {callee_stage}" - ) - futures = [] - - for arg_idx, value_ref_arg in batch_refs.items(): - tm = value_ref_arg.meta["tensor_meta"] - recv_buff = torch.empty( - tm.shape, dtype=tm.dtype, device=self.device - ) - - if torch.distributed.get_backend() == "gloo": - # Gloo P2P does not support work.get_future, so we need to: - # - manually create the Future, - # - use recv instead, and - # - manually set_result to the Future - fut: torch.futures.Future = self.create_future() - torch.distributed.recv(recv_buff, callee_stage, tag=tag) - fut.set_result(recv_buff) - else: - work = torch.distributed.irecv(recv_buff, callee_stage, tag=tag) - fut = work.get_future() # type: ignore[attr-defined] - - def bottom_half(fut): - logging.debug( - f"[{self.stage_id}][{microbatch}] Completing transfer of value {value_ref_arg} " - f"for runlist item {runlist_key} arg_idx {arg_idx}" - ) - value = fut.value() - # It is awkward that the Work class in PyTorch fixes the result return to a List: - # def result(self) -> List[Tensor]: ... - # See torch/_C/_distributed_c10d.pyi - # We don't expect P2P operations to actually result in a List, hence unpacking and getting the first and - # only tensor out - if isinstance(value, List): - value = value[0] - self.rank_worker.update_run_list(runlist_key, arg_idx, value) - - futures.append(fut.then(bottom_half)) - - return futures - - def get_grad(self, qualname): - mod = self.mod - if isinstance(mod, torch.nn.parallel.DistributedDataParallel): - mod = mod.module - return mod.get_parameter(qualname).grad - - def set_grad(self, qualname, value): - mod = self.mod - if isinstance(mod, torch.nn.parallel.DistributedDataParallel): - mod = mod.module - param = mod.get_parameter(qualname) - param.grad = value - - def train(self, mode=True): - self.mod.train(mode=mode) - - def _should_instantiate_optim(self): - return len(list(self.mod.parameters())) > 0 - - def instantiate_optimizer(self, optim_class, *args, **kwargs): - assert self._should_instantiate_optim() - with self.optim_init_cv: - self.optimizer = optim_class(self.mod.parameters(), *args, **kwargs) - self.optim_init_cv.notify() - return self.optimizer - - def instantiate_lr_scheduler(self, lr_sched_class, *args, **kwargs): - # Make sure optimizer has been created - with self.optim_init_cv: - while self.optimizer is None: - self.optim_init_cv.wait() - - logging.info(f"[{self.stage_id}] Creating learning rate scheduler") - self.lr_scheduler = lr_sched_class(self.optimizer, *args, **kwargs) - return self.lr_scheduler - - def step_lr_scheduler(self, *args, **kwargs): - self.lr_scheduler.step(*args, **kwargs) # type: ignore[union-attr] - - def _check_cleanup(self) -> bool: - if len(self.value_store): - logging.warning( - f"[{self.stage_id}] Unclean value store: {self.value_store}" - ) - return False - return True - - def _record_dump(self, dump_id, ts): - first_param = next(self.mod.parameters(), None) - device: torch.device = ( - first_param.device - if first_param is not None - else torch.device("cpu") - ) - if device.type == "cuda": - alloc = torch.cuda.memory_allocated() - max_alloc = torch.cuda.max_memory_allocated() - rsrvd = torch.cuda.memory_reserved() - max_rsrvd = torch.cuda.max_memory_reserved() - assert ( - alloc <= max_alloc - ), f"alloc = {alloc} max_alloc = {max_alloc}" - assert ( - rsrvd <= max_rsrvd - ), f"rsrvd = {rsrvd} max_rsrvd = {max_rsrvd}" - assert ( - max_alloc <= max_rsrvd - ), f"max_alloc = {max_alloc} max_rsrvd = {max_rsrvd}" - self.record_dump( - rank=self.rank_worker.rank, - ts=ts, - id=dump_id, - name=dump_id, - type="dump", - allocators={ - "cuda.4.alloc": Allocator( - f"alloc_{self.rank_worker.rank}", - { - "size": alloc, - }, - ), - "cuda.3.max_alloc-alloc": Allocator( - f"max_alloc-alloc_{self.rank_worker.rank}", - { - "size": max_alloc - alloc, - }, - ), - "cuda.2.rsrvd-max_alloc": Allocator( - f"rsrvd-max_alloc_{self.rank_worker.rank}", - { - "size": max(rsrvd - max_alloc, 0), - }, - ), - "cuda.1.max_rsrvd-max_alloc_or_rsrvd": Allocator( - f"max_rsrvd-max_alloc_or_rsrvd_{self.rank_worker.rank}", - { - "size": max_rsrvd - - (max_alloc if max_alloc > rsrvd else rsrvd), - }, - ), - }, - ) - - def _record_dumps_on_all_peer_executors(self, id, ts): - for peer_executor_rref in self.peer_executors.values(): - peer_executor_rref.rpc_sync()._record_dump(f"{id}", ts) - - -def _wait_for_all(rpc_futs): - # Stolen from DistributedOptimizer implementation - # TODO: improve error propagation - exception = None - results = [] - for fut in rpc_futs: - try: - results.append(fut.wait()) - except Exception as e: - results.append(e) - exception = e - if exception is not None: - raise exception - return results - - -class PipelineOptimizer(torch.optim.Optimizer): - def __init__(self, remote_optims): - self.remote_optims = remote_optims - - # TODO: enable this - # self._hook_for_profile() - - # TODO: enable this - # self.state = defaultdict(dict) - - self.param_groups = [] - - # Collect RRefs to remote parameters - param_group = {"params": []} # type: ignore[var-annotated] - - for optim in self.remote_optims: - remote_state = optim.rpc_sync().__getstate__() - assert isinstance(remote_state, dict) - for group in remote_state["param_groups"]: - param_group["params"].extend(group["params"]) - for k in group: - if k != "params": - param_group.setdefault(k, group[k]) - - self.param_groups = [param_group] - - def __getstate__(self): - raise NotImplementedError() - - def __setstate__(self, state): - raise NotImplementedError() - - def _hook_for_profile(self): - raise NotImplementedError() - - def state_dict(self): - raise NotImplementedError() - - def load_state_dict(self, state_dict): - raise NotImplementedError() - - # PyTorch type annotation for this function is wrong. See - # https://github.com/pytorch/pytorch/pull/76998 for proposed fix - def zero_grad(self, set_to_none: bool = False): # type: ignore - futs = [] - for optim in self.remote_optims: - futs.append(optim.rpc_async().zero_grad(set_to_none)) - _wait_for_all(futs) - - def step(self, closure=None): - futs = [] - for optim in self.remote_optims: - futs.append(optim.rpc_async().step(closure)) - _wait_for_all(futs) - - def add_param_group(self, param_group): - raise NotImplementedError() - - -class PipelineLRScheduler(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, stage_to_scheds, stage_to_executor): - # A dict from stage id to LR schedulers - self.stage_to_scheds = stage_to_scheds - self.stage_to_executor = stage_to_executor - self.new_step_called = False - self.last_lr = [] - - def step(self, *args, **kwargs): - futs = [] - # Step all remote LR schedulers - - # We use the executor block below because calling scheduler.step() - # remotely might cause pickling nested functions, where these nested - # functions are usually defined inside user's lr scheduler constructor - # as lambda functions to be used by the lr scheduler - # See https://github.com/pytorch/PiPPy/issues/404 - """ - for scheduler in self.stage_to_scheds.values(): - futs.append(scheduler.rpc_async().step(*args, **kwargs)) - """ - for executor in self.stage_to_executor.values(): - futs.append(executor.rpc_async().step_lr_scheduler(*args, **kwargs)) - - _wait_for_all(futs) - # Mark new step (invalidates last_lr) - self.new_step_called = True - - def get_last_lr(self): - """Return last computed learning rate by remote schedulers.""" - # No need to involve remote schedulers if no new step calls - if not self.new_step_called: - return self.last_lr - - # Ask LR scheduler of stage 0 to return new learning rate as representation of all stages, because: - # (i) we do not support multiple parameter groups yet (neither PipelineOptimizer nor PipelineLRScheduler does), - # so there are not param group specific LR's; and - # (ii) current LRS implementations do not relies on state within the optimizer, so the LR's of different stages - # will not diverge - assert self.stage_to_scheds, "No learning rate scheduler" - self.last_lr = self.stage_to_scheds[0].remote().get_last_lr().to_here() - self.new_step_called = False - return self.last_lr - - def state_dict(self): - """Returns the state of the remote schedulers as a :class:`dict`""" - # Ask LR scheduler of stage 0 to return state_dict as representation of all stages, for the same reason as - # stated in get_last_lr() - rv: Dict = {} - assert self.stage_to_scheds, "No learning rate scheduler" - rv = self.stage_to_scheds[0].remote().state_dict().to_here() - return rv - - def load_state_dict(self, state_dict): - """Loads the scheduler state. - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - futs = [] - for scheduler in self.stage_to_scheds.values(): - futs.append(scheduler.rpc_async().load_state_dict(state_dict)) - - _wait_for_all(futs) - - def get_lr(self): - # Even in single scheduler setting, get_lr is more of an internal method to be called by step() - # See: pytorch/torch/optim/lr_scheduler.py - warnings.warn( - "To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`." - ) - raise NotImplementedError - - def print_lr(self, is_verbose, group, lr, epoch=None): - """Display the current learning rate.""" - # This is more of an internal method of native scheduler - # See: pytorch/torch/optim/lr_scheduler.py - raise NotImplementedError - - -class PipelineDriverBase(torch.nn.Module): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - _debug_mask_minibatches: bool = False, - max_outstanding=None, - interleave_stages=False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - super().__init__() - self.pipe = pipe - self.chunks = chunks - self.world_size = world_size - self.all_ranks = all_ranks - self.args_chunk_spec = args_chunk_spec - self.kwargs_chunk_spec = kwargs_chunk_spec - self.loss_reducer = loss_reducer - self.output_chunk_spec = ( - output_chunk_spec - if output_chunk_spec - else gen_output_chunk_spec(pipe.loss_spec, loss_reducer) - ) - - # Maximum outstanding micro-batches allowed by the pipeline schedule - # None means no limit - self.max_outstanding: Optional[int] = max_outstanding - self._debug_mask_minibatches = _debug_mask_minibatches - self.interleave_stages = interleave_stages - - self.microbatch_interpreters: List[RemoteInterpreter] = [] - self.batch_id = 0 - self._record_mem_dumps = _record_mem_dumps - self.optimizer_inited = False - self.checkpoint = checkpoint - self.use_c10d = use_c10d - - def _init_remote_executors(self): - self.rank_worker_rrefs: Dict[int, torch.distributed.rpc.RRef] = {} - self.remote_stage_executor_rrefs: Dict[ # type: ignore[syntax] - str, (int, torch.distributed.rpc.RRef) - ] = {} - - if self.all_ranks is not None: - assert ( - len(self.all_ranks) == self.world_size - ), "Explicitly specified ranks must match world_size" - else: - self.all_ranks = list(range(self.world_size)) - logging.info( - f"[root] Creating pipeline driver with {self.world_size} workers: {self.all_ranks}" - ) - - class ExecutorDescriptor: - name: str - mod: Optional[torch.nn.Module] - has_backward: bool = False - - split_gm = self.pipe.split_gm - - executor_descriptors = [] - bw_idx = -1 - for node in split_gm.graph.nodes: - if node.op == "call_module": - descr = ExecutorDescriptor() - descr.name = node.target - if Pipe.is_stage_init_deferred(): - descr.mod = None - else: - descr.mod = split_gm.get_submodule(node.target) - executor_descriptors.append(descr) - elif (node.op, node.target) == ("call_function", stage_backward): - executor_descriptors[bw_idx].has_backward = True - node.meta["fw_stage"] = executor_descriptors[bw_idx].name - bw_idx -= 1 - elif (node.op, node.target) == ( - "call_function", - _null_coalesce_accumulate, - ): - node.meta["fw_stage"] = executor_descriptors[bw_idx].name - - assert all(d.has_backward for d in executor_descriptors) or all( - not d.has_backward for d in executor_descriptors - ) - - if len(executor_descriptors) > self.world_size: - if not self.interleave_stages: - raise RuntimeError( - f"Tried to run pipeline with {len(executor_descriptors)} stages with a world size of " - f"{self.world_size}. Please ensure world_size is large enough to accommodate your pipeline." - ) - - ranks_to_launch = self.world_size - n_stages = len(executor_descriptors) - if n_stages < self.world_size: - ranks_to_launch = n_stages - warnings.warn( - f"Running pipeline with {n_stages} stages on world_size of {self.world_size}. " - f"Remaining ranks will be idle." - ) - - if self.interleave_stages and n_stages <= ranks_to_launch: - self.interleave_stages = False - warnings.warn( - "Falling back from Interleaved 1F1B to 1F1B " - "since there are enough ranks to support one stage per rank" - ) - - # Fire up rank workers - all_stages = list(range(n_stages)) - pp_rank = 0 - for rank in self.all_ranks[:ranks_to_launch]: - kwargs = { - "rank": rank, - "all_stages": all_stages, - "max_outstanding": self.max_outstanding, - "pp_rank": pp_rank, - "_record_mem_dumps": self._record_mem_dumps, - "checkpoint": self.checkpoint, - } - self.rank_worker_rrefs[rank] = rpc.remote( - rank, RankWorker, args=(), kwargs=kwargs - ) - pp_rank += 1 - - self.stage_to_executor: Dict = {} - - # Ask each RankWorker to create stage thereon - # This can involve checkpoint loading in deferred init case - for stage_id, descr in enumerate(executor_descriptors): - # Assign stages to rank workers in a round-robin fashion - rank = self.all_ranks[stage_id % self.world_size] - logging.debug(f"[root] Sending stage_id = {stage_id} mod to worker") - self.remote_stage_executor_rrefs[descr.name] = ( - stage_id, - self.rank_worker_rrefs[rank] - .remote() - .create_stage_executor( - stage_id=stage_id, - mod=descr.mod, - mod_name=descr.name, - ), - ) - - # Check that each RankWorker has completed stage init - for stage_id, descr in enumerate(executor_descriptors): - logging.debug( - f"[root] Waiting stage_id = {stage_id} mod to be confirmed by worker" - ) - while not self.remote_stage_executor_rrefs[descr.name][ - 1 - ].confirmed_by_owner(): - pass - - self.stage_to_executor[stage_id] = self.remote_stage_executor_rrefs[ - descr.name - ][1] - - # Inform executors of their peers - for stage_id, executor in self.stage_to_executor.items(): - executor.rpc_sync().install_peer_executors(self.stage_to_executor) - - """ - Method for creating a data parallel clique for each stage, across multiple pipelines - dp_group_size: size of each data parallel group, equals to the number of pipelines - dp_pg_cb: optional Callable taking pipeline stage as argument and returning corresponding data parallel group; - user can use this Callable to pass in prepared data parallel groups - """ - - def init_data_parallel(self, dp_group_size, dp_pg_cb=None): - if dp_group_size <= 1: - logging.info( - "[root] Data parallel group size <= 1, skipping data parallel initialization" - ) - return - - n_stages = len(self.stage_to_executor) - logging.info( - f"[root] Initializing {n_stages} data parallel groups, each of size {dp_group_size}" - ) - futs = [] - # Asks all stage executors to participate in DP process group init - # These must be async calls because otherwise there will be deadlocks - for executor in self.stage_to_executor.values(): - futs.append( - executor.rpc_async().init_data_parallel( - n_stages, dp_group_size, dp_pg_cb - ) - ) - - # Here we wait for all DP process groups to be initialized before the user can ask the PipeDriver to run - _wait_for_all(futs) - - def forward(self, *args, **kwargs): - raise NotImplementedError( - "PipelineDriverBase is an abstract base class, please use a concrete " - "implementation class." - ) - - def train(self, mode=True): - for executor in self.stage_to_executor.values(): - executor.rpc_sync().train(mode=mode) - - def eval(self): - self.train(mode=False) - - def instantiate_optimizer(self, optim_class, *args, **kwargs): - remote_optims = [] - # Keeps track of stage to optimizer mapping - self.stage_to_optim: Dict = {} - for stage, executor in self.stage_to_executor.items(): - if executor.rpc_sync()._should_instantiate_optim(): - remote_optim = executor.remote().instantiate_optimizer( - optim_class, *args, **kwargs - ) - remote_optims.append(remote_optim) - self.stage_to_optim[stage] = remote_optim - - self.optimizer_inited = True - return PipelineOptimizer( - [optim for optim in remote_optims if optim is not None] - ) - - """ - Create learning rate scheduler for the optimizer of the pipeline. - Note: this API cannot be called before instantiate_optimizer is called. - """ - - def instantiate_lr_scheduler(self, lr_sched_class, *args, **kwargs): - if not self.optimizer_inited: - raise RuntimeError( - "[root] instantiate_optimizer must be called before instantiate_lr_scheduler" - ) - - stage_to_scheds: Dict = {} - for stage, optim in self.stage_to_optim.items(): - if optim is not None: - executor = self.stage_to_executor[stage] - remote_lr_sched = executor.remote().instantiate_lr_scheduler( - lr_sched_class, *args, **kwargs - ) - stage_to_scheds[stage] = remote_lr_sched - - return PipelineLRScheduler(stage_to_scheds, self.stage_to_executor) - - def _sync_replicated_params(self): - logging.debug( - f"[root] Synchronizing gradients for {len(self.pipe.replicated_params)} sets of replicated parameters" - ) - for param_set in self.pipe.replicated_params: - grad_values = [] - for module_name, param_qualname in param_set.items(): - assert module_name in self.remote_stage_executor_rrefs - stage_id, module_rref = self.remote_stage_executor_rrefs[ - module_name - ] - grad_value = module_rref.rpc_sync().get_grad(param_qualname) - grad_values.append(grad_value) - - synced_value = torch.sum(torch.stack(grad_values), dim=0) - - for module_name, param_qualname in param_set.items(): - assert module_name in self.remote_stage_executor_rrefs - stage_id, module_rref = self.remote_stage_executor_rrefs[ - module_name - ] - module_rref.rpc_sync().set_grad(param_qualname, synced_value) - - def _retrieve_output_values(self, microbatch_interpreters, last_nodes): - logging.debug( - f"[root] Retrieving output values from {len(microbatch_interpreters)} chunks" - ) - output_vals = [] - for interp, last_node in zip(microbatch_interpreters, last_nodes): - interp.run_until(lambda n: False) - output_vals.append(interp.env[last_node]) - - # First kick of async transfers to retrieve ValueReference values - def initiate_async_transfer(a): - if isinstance(a, ValueReference): - value_ref_executor_rref = self.stage_to_executor[a.stage_id] - return value_ref_executor_rref.rpc_async().get_value( - "root", "collect", -1, a - ) - else: - return a - - output_vals = pippy.fx.node.map_aggregate( - output_vals, initiate_async_transfer - ) - - # Then wait for futures to be ready - return pippy.fx.node.map_aggregate( - output_vals, - lambda a: a.wait() if isinstance(a, torch._C.Future) else a, - ) - - def retrieve_events(self) -> EventsContext: - events_context = EventsContext() - for rank, worker_rref in self.rank_worker_rrefs.items(): - events_context.update(worker_rref.rpc_sync().retrieve_events()) - for interp in self.microbatch_interpreters: - events_context.update(interp.retrieve_events()) - for _, executor_rref in self.remote_stage_executor_rrefs.values(): - events_context.update(executor_rref.rpc_sync().retrieve_events()) - events_context.events.sort(key=lambda e: e.start_ts) - return events_context - - def _check_stages_cleanup(self) -> bool: - clean = True - for executor in self.stage_to_executor.values(): - clean &= executor.rpc_sync()._check_cleanup() - return clean - - -class RemoteInterpreter(pippy.fx.Interpreter, EventRecorder): - def __init__( - self, - remote_stage_executor_rrefs, - stage_to_executor, - module, - cur_microbatch: int, - args, - kwargs, - batch_id: int, - num_microbatches: int, - garbage_collect_values=True, - ): - super().__init__(module, garbage_collect_values) - self.remote_stage_executor_rrefs = remote_stage_executor_rrefs - self.stage_to_executor = stage_to_executor - self.cur_microbatch = cur_microbatch - self.pc = 0 - self.node_list = list(self.module.graph.nodes) - logging.debug( - f"[root] RemoteInterpreter created with {len(self.node_list)} nodes" - ) - - # Process args/kwargs - - # TODO: replace this with GraphModule.signature() when it lands - parameters = [] - for node in self.module.graph.nodes: - if node.op != "placeholder": - continue - default = next(iter(node.args)) if node.args else Parameter.empty - parameters.append( - Parameter( - node.name, Parameter.POSITIONAL_OR_KEYWORD, default=default - ) - ) - - # We are building a safety net here in case user passes in extra arguments than those defined as variable - # arguments (i.e. non-concrete args) at the tracing phase - # TODO: Remove this safety net - traced_args = [p.name for p in parameters] - filtered_kwargs = {k: v for k, v in kwargs.items() if k in traced_args} - if len(filtered_kwargs) != len(kwargs): - extra_args = kwargs.keys() - filtered_kwargs.keys() - warnings.warn( - f"Received extra arguments: {extra_args}. " - f"They might have already been given a concrete value during pipeline compilation via `concrete_args`. " - f"We will ignore the current inputs and use the values given during compilation." - ) - - sig = Signature(parameters) - bound_args = sig.bind(*args, **filtered_kwargs) - bound_args.apply_defaults() - self.args = bound_args.args - self.args_iter = iter(self.args) - self.batch_id = batch_id - self.num_microbatches = num_microbatches - # Dict from stage id to a list holding the coalesced getitem indices - self.stage_output_indices: Dict[ - int, List[Tuple[str, int, ValueReference, int]] - ] = {} - - def call_module(self, target, args, kwargs): - assert isinstance(target, str) - node = self.node_list[self.pc] - # if PipelineDriver is running inside `torch.no_grad()` context manager then `stage_backward*` nodes - # are excluded from execution, so we need exclude `stage_backward*` from reference count, otherwise - # it will cause memory leak. - users = ( - list( - filter( - lambda user: not user.name.startswith("stage_backward"), - node.users.keys(), - ) - ) - if not torch.is_grad_enabled() - else node.users.keys() - ) - if target in self.remote_stage_executor_rrefs: - stage_id, stage_executor = self.remote_stage_executor_rrefs[target] - logging.debug( - f"[root][{self.cur_microbatch}] Issuing {Phase.FORWARD} " - f"invocation for target {target} on stage {stage_id}" - ) - invocation_key = f"{self.cur_microbatch}_{node.name}" - start_ts = time.time() - forward_name = event_name( - Phase.FORWARD, stage_id, self.cur_microbatch - ) - forward_id = event_id( - Phase.FORWARD, stage_id, self.cur_microbatch, self.batch_id - ) - name = f"I{forward_name}" - id = f"I{forward_id}" - stage_executor.rpc_async().invoke( - invocation_key, - Phase.FORWARD, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - finish_ts = time.time() - self.record_event( - rank=0, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="invoke", - mbid=self.cur_microbatch, - ) - self.record_event_dependency( - from_id=name, to_id=f"R{forward_name}", type="invoke" - ) - return ValueReference(stage_id, invocation_key) - else: - logging.debug( - f"[root][{self.cur_microbatch}] Running local operation {target} from driver" - ) - return super().call_module(target, args, kwargs) - - def call_function(self, target, args, kwargs): - node = self.node_list[self.pc] - invocation_key = f"{self.cur_microbatch}_{node.name}" - # if PipelineDriver is running inside `torch.no_grad()` context manager then `stage_backward*` nodes - # are excluded from execution, so we need exclude `stage_backward*` from reference count, otherwise - # it will cause memory leak. - users = ( - list( - filter( - lambda user: not user.name.startswith("stage_backward"), - node.users.keys(), - ) - ) - if not torch.is_grad_enabled() - else node.users.keys() - ) - if target is operator.getitem and isinstance(args[0], ValueReference): - val_ref = args[0] - stage_id = val_ref.stage_id - num_users = len(users) - if not torch.is_grad_enabled() and val_ref.unique_key == "noop": - return ValueReference(stage_id, "noop") - elif num_users == 0: - # TODO: investigate why there are getitem calls with 0 users - return ValueReference(stage_id, "noop") - else: - indices = self.stage_output_indices.setdefault(stage_id, []) - arg_idx = args[1] - index_tuple = (invocation_key, num_users, val_ref, arg_idx) - logging.debug( - f"[root][{self.cur_microbatch}] Appending getitem tuple to stage {stage_id}: {index_tuple}" - ) - indices.append(index_tuple) - return ValueReference(stage_id, invocation_key) - elif target is stage_backward: - assert "fw_stage" in node.meta - stage_id, stage_executor = self.remote_stage_executor_rrefs[ - node.meta["fw_stage"] - ] - if torch.is_grad_enabled(): - logging.debug( - f"[root][{self.cur_microbatch}] Issuing BW invocation " - f'for target {node.meta["fw_stage"]} on stage {stage_id}' - ) - start_ts = time.time() - backward_name = event_name( - Phase.BACKWARD, stage_id, self.cur_microbatch - ) - backward_id = event_id( - Phase.BACKWARD, stage_id, self.cur_microbatch, self.batch_id - ) - name = f"I{backward_name}" - id = f"I{backward_id}" - stage_executor.rpc_async().invoke( - invocation_key, - Phase.BACKWARD, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - finish_ts = time.time() - self.record_event( - rank=0, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="invoke", - mbid=self.cur_microbatch, - ) - self.record_event_dependency( - from_id=name, to_id=backward_name, type="invoke" - ) - return ValueReference(stage_id, invocation_key) - else: - return ValueReference(stage_id, "noop") - elif target is sync_barrier: - executor_keys = list(self.remote_stage_executor_rrefs.keys()) - stage_id, stage_executor = self.remote_stage_executor_rrefs[ - executor_keys[0] - ] - logging.debug( - f"[root][{self.cur_microbatch}] Issuing sync invocation " - f"on stage {stage_id}" - ) - stage_executor.rpc_async().invoke( - invocation_key, - Phase.SYNC_BARRIER, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - return ValueReference(stage_id, invocation_key) - elif target is _null_coalesce_accumulate: - assert "fw_stage" in node.meta - stage_id, stage_executor = self.remote_stage_executor_rrefs[ - node.meta["fw_stage"] - ] - if torch.is_grad_enabled(): - logging.debug( - f"[root][{self.cur_microbatch}] Issuing accumulate grad invocation " - f'for target {node.meta["fw_stage"]} on stage {stage_id}' - ) - stage_executor.rpc_async().invoke( - invocation_key, - Phase.ACCUMULATE_GRAD, - args, - kwargs, - self.cur_microbatch, - debug_str=node.format_node(), - output_refcount=len(users), - batch_id=self.batch_id, - num_microbatches=self.num_microbatches, - ) - return ValueReference(stage_id, invocation_key) - else: - return ValueReference(stage_id, "noop") - else: - raise AssertionError(f"Unknown operator {torch.typename(target)}") - - def issue_coalesced_getitem_calls(self): - if len(self.stage_output_indices) == 0: - return - logging.debug( - f"[root][{self.cur_microbatch}] Issuing getitem calls to stage: {self.stage_output_indices.keys()}" - ) - - for stage_id in self.stage_output_indices: - stage_executor = self.stage_to_executor[stage_id] - name = f"G{stage_id},{self.cur_microbatch}" - id = f"{name},{self.batch_id}" - start_ts = time.time() - stage_executor.rpc_async().coalesced_index_value( - self.stage_output_indices[stage_id] - ) - finish_ts = time.time() - self.record_event( - rank=0, - start_ts=start_ts, - finish_ts=finish_ts, - id=id, - name=name, - type="invoke", - mbid=self.cur_microbatch, - ) - - self.stage_output_indices.clear() - - def run_until( - self, predicate: Callable[[pippy.fx.Node], bool] - ) -> Optional[pippy.fx.Node]: - while self.pc < len(self.node_list): - node = self.node_list[self.pc] - - if predicate(node): - # Issue coalesced getitem calls as we pause issuing stage calls - self.issue_coalesced_getitem_calls() - return node - - self.run_one(node) - - # Have run through the entire node_list, using None to mean no node left to run - return None - - def run_one(self, node): - # TODO: hoist run() implementation - logging.debug( - f"[{self.cur_microbatch}] Issue command to run {node.format_node()}" - ) - self.env[node] = super().run_node(node) - - # TODO: we could potentially move this waiting to the use sites for an RRef - # (i.e. during Interpreter.map_nodes_to_values or when we pass args/kwargs - # to the callees) as an optimization - # TODO: is it possible for there to be a blocking version of this API? - def wait_for_confirmation(n): - # The following if will not be true as we are using our own ValueRef - # instead of RPC's RRef - if isinstance(n, torch._C._distributed_rpc.PyRRef): - while not n.confirmed_by_owner(): - pass - - pippy.fx.node.map_aggregate(self.env[node], wait_for_confirmation) - - if DEBUG and isinstance( - self.env[node], torch._C._distributed_rpc.PyRRef - ): - print(node, self.env[node]) - self.env[node].to_here() - - # Insert tensor meta to ValueReference returned by node call - # TODO: there is some problem with "call_function", disabling for now - if node.op == "call_module": - if "tensor_meta" in node.meta and isinstance( - node.meta["tensor_meta"], - shape_prop.TensorMetadata, - ): - val_ref: ValueReference = self.env[node] - val_ref.meta.setdefault("tensor_meta", node.meta["tensor_meta"]) - - self.pc += 1 - return node - - def propagate_shape(self, args, kwargs): - logging.info("Propagating shape across split GraphModule") - sp = shape_prop.ShapeProp(self.module) - # Not sure why FX's propagate API takes only args. Hence we unpack kwargs.values() without keys here - sp.propagate(*args, *kwargs.values()) - for node in self.node_list: - logging.debug(f"Node: {node.name}, outputs: ") - if "tensor_meta" in node.meta: - if isinstance( - node.meta["tensor_meta"], shape_prop.TensorMetadata - ): - logging.debug(f"- {node.meta['tensor_meta']}") - else: - # Multiple output tensors - for t_meta in node.meta["tensor_meta"]: - logging.debug(f"- {t_meta}") - - -class _run_until_criteria: - def __init__(self): - self.seen_stages = 0 - - # Run the node we start with including all nodes that are tuple - # indexing, then stop - def hitting_next_stage(self, node): - if node.op == "output": - return True - - if ( - node.target != operator.getitem - and node.target != _null_coalesce_accumulate - ): - self.seen_stages += 1 - - if self.seen_stages > 1: - return True - elif self.seen_stages == 1 and node.target == _null_coalesce_accumulate: - # We are hitting the accumulate call of the next (backward) stage, stop - return True - else: - return False - - -class PipelineDriverFillDrain(PipelineDriverBase): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - single_loss: bool = False, - _debug_mask_minibatches: bool = False, - max_outstanding=None, - interleave_stages=False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - super().__init__( - pipe, - chunks, - world_size, - all_ranks, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - _debug_mask_minibatches, - max_outstanding=max_outstanding, - interleave_stages=interleave_stages, - _record_mem_dumps=_record_mem_dumps, - checkpoint=checkpoint, - use_c10d=use_c10d, - loss_reducer=loss_reducer, - ) - self.single_loss = single_loss - - self.last_grads = None - - self._init_remote_executors() - - def forward(self, *args, **kwargs): - if self.single_loss: - raise NotImplementedError("Single minibatch loss not implemented") - - # Roadmap: - # 1) Micro-batch splitting - divide input arguments out into concrete chunk values - # 2) Interpreter tiling - one interpreter per micro-batch - # 3) Scheduling - Use control logic to advance interpreters to issue round-robin - # forward work items, then round-robin losses, then round-robin backwards - - args_split, kwargs_split = split_args_kwargs_into_chunks( - args, - kwargs, - self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, - self._debug_mask_minibatches, - ) - - real_num_chunks = self.chunks - if len(args_split) < self.chunks: - real_num_chunks = len(args_split) - warnings.warn( - f"Reducing micro-batch numbers from {self.chunks} to " - f"{real_num_chunks}." - ) - - logging.info( - f"[root] Running pipeline with {real_num_chunks} micro-batches" - ) - - self.microbatch_interpreters = [] - - batch_id = self.batch_id - self.batch_id += 1 - - for chunk in range(real_num_chunks): - logging.debug( - f"[root] Instantiating microbatch interpreter for chunk {chunk}" - ) - interp = RemoteInterpreter( - remote_stage_executor_rrefs=self.remote_stage_executor_rrefs, - stage_to_executor=self.stage_to_executor, - module=self.pipe.split_gm, - cur_microbatch=chunk, - args=args_split[chunk], - kwargs=kwargs_split[chunk], - batch_id=batch_id, - num_microbatches=real_num_chunks, - ) - # If user wants to use c10d for P2P, we would perform the shape propagation here. The shape prop is - # performed per batch, thus supporting dynamic shape in batch dimension. Dynamic shape in microbatch - # dimension is not yet supported, because all RemoteInterpreters share the same shape info (since they share - # the same split_gm) - if self.use_c10d and chunk == 0: - interp.propagate_shape(args_split[chunk], kwargs_split[chunk]) - - self.microbatch_interpreters.append(interp) - - logging.debug( - f"[root] {len(self.microbatch_interpreters)} instantiated" - ) - - # Deterministic clock cycle - see torchgpipe paper section 3.2.1 for details - - # Advance past placeholders - for interp in self.microbatch_interpreters: - interp.run_until(lambda n: n.op != "placeholder") - - # Ramp-up, admit diagonal wavefront until we get to a full diagonal - # location in the matrix - - for ramp_up_idx in range(len(self.microbatch_interpreters)): - for i in range(ramp_up_idx + 1): - interp = self.microbatch_interpreters[i] - criteria = _run_until_criteria() - interp.run_until(criteria.hitting_next_stage) - - # Steady-state. We have a full diagonal in the matrix; keep dispatching - # across the diagonal - - any_valid = True - while any_valid: - any_valid = False - for interp in self.microbatch_interpreters: - start_node = interp.node_list[ - min(interp.pc, len(interp.node_list) - 1) - ] - criteria = _run_until_criteria() - interp.run_until(criteria.hitting_next_stage) - - any_valid |= interp.node_list[interp.pc] != start_node - - last_nodes = [ - interp.node_list[interp.pc] - for interp in self.microbatch_interpreters - ] - assert all(node.op == "output" for node in last_nodes) - - local_results_and_last_grads = self._retrieve_output_values( - self.microbatch_interpreters, last_nodes - ) - - if self.pipe.has_loss_and_backwards: - # Shared parameter sync - # At this point, all of the gradient jobs should have been run - # (by way of the synchronization dependency earlier) - self._sync_replicated_params() - - if DEBUG: - self._check_stages_cleanup() - - if self.pipe.has_loss_and_backwards: - local_results = [] - last_grads = [] - for local_result in local_results_and_last_grads: - local_results.append(local_result[0]) - last_grads.append(local_result[1]) - - self.last_grads = last_grads # type: ignore[assignment] - else: - local_results = local_results_and_last_grads - - return merge_chunks( - local_results, self.output_chunk_spec, self._debug_mask_minibatches - ) - - -class PipelineDriver1F1B(PipelineDriverFillDrain): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - single_loss: bool = False, - _debug_mask_minibatches: bool = False, - interleave_stages=False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - # In 1F1B with backward stages, the maximum number of outstanding - # micro-batches equals the number of pipeline stages - max_outstanding = ( - pipe.num_stages if pipe.has_loss_and_backwards else None - ) - - super().__init__( - pipe, - chunks, - world_size, - all_ranks, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - single_loss, - _debug_mask_minibatches, - max_outstanding=max_outstanding, - interleave_stages=interleave_stages, - _record_mem_dumps=_record_mem_dumps, - checkpoint=checkpoint, - use_c10d=use_c10d, - loss_reducer=loss_reducer, - ) - - -class PipelineDriverInterleaved1F1B(PipelineDriver1F1B): - def __init__( - self, - pipe: Pipe, - chunks: int, - world_size: int, - all_ranks: List[int] = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - single_loss: bool = False, - _debug_mask_minibatches: bool = False, - _record_mem_dumps=False, - checkpoint=False, - use_c10d=False, - loss_reducer: LossReducer = sum_reducer, - ): - super().__init__( - pipe, - chunks, - world_size, - all_ranks, - args_chunk_spec, - kwargs_chunk_spec, - output_chunk_spec, - single_loss, - _debug_mask_minibatches, - interleave_stages=True, - _record_mem_dumps=_record_mem_dumps, - checkpoint=checkpoint, - use_c10d=use_c10d, - loss_reducer=loss_reducer, - ) diff --git a/pippy/PipelineStage.py b/pippy/PipelineStage.py index 5f5293af2..a3b1903e9 100644 --- a/pippy/PipelineStage.py +++ b/pippy/PipelineStage.py @@ -5,25 +5,26 @@ import torch import torch.distributed as dist +import torch.fx as fx +from torch._subclasses.fake_tensor import FakeTensor from torch.nn.parallel import DistributedDataParallel -import pippy -import pippy.fx -from pippy.backward import stage_backward, sync_barrier +from pippy.backward import stage_backward from pippy.debug import map_debug_info - -from pippy.fx.passes import shape_prop from pippy.IR import Pipe from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks from pippy.utils import flatten_args +logger = logging.getLogger(__name__) + + def _make_tensor_from_meta( - tensor_meta: shape_prop.TensorMetadata, + example_value: FakeTensor, device: torch.device, ) -> torch.Tensor: return torch.empty( - tensor_meta.shape, dtype=tensor_meta.dtype, device=device + example_value.size(), dtype=example_value.dtype, device=device ) @@ -39,7 +40,7 @@ def __init__( self.buffer = buffer def __repr__(self): - return f"RecvInfo(input={self.input_name}, source={self.source}, buffer={self.buffer.size()})" + return f"RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" class StageArgPlaceholder: @@ -51,24 +52,20 @@ def __init__( self, pipe: Pipe, stage_index: int, - nstages: int, - chunks: int, device: torch.device, group: dist.ProcessGroup = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, ): super().__init__() self.pipe = pipe self.stage_index = stage_index - self.nstages = nstages - self.chunks = chunks + self.nstages = pipe.num_stages + self.chunks = pipe.num_chunks self.device = device self.group = group - self.args_chunk_spec = args_chunk_spec - self.kwargs_chunk_spec = kwargs_chunk_spec - self.output_chunk_spec = output_chunk_spec + if dist.get_world_size(self.group) > self.nstages: + raise RuntimeError( + "Number of ranks is larger than number of stages, some ranks are unused" + ) # `group_rank` is rank in process group `group`. self.group_rank = dist.get_rank(group) @@ -90,8 +87,8 @@ def __init__( self.split_gm = self.pipe.split_gm named_children = list(self.split_gm.named_children()) self.name, self.submod = named_children[stage_index] - logging.info( - f"[{self.group_rank}][{self.name}] " + logger.info( + f"[{self.group_rank}] " f"Creating PipelineStage:\n" f"{self.submod}" ) @@ -131,7 +128,7 @@ def __init__( # In interleaved case, `group_rank` is stage index % group size. self.stage_index_to_group_rank: Dict[int, int] = {} pg_world_size = dist.get_world_size(group) - for i in range(nstages): + for i in range(self.nstages): # We only support wrapped-around interleaving peer_rank = i % pg_world_size self.stage_index_to_group_rank.setdefault(i, peer_rank) @@ -209,7 +206,7 @@ def create_recv_tensor( # real source e.g. getitem1 = submod0[1] # Here `submod0` is args[0], 1 is args[1] if input_node.target is operator.getitem: - if "tensor_meta" in input_node.meta: + if "example_value" in input_node.meta: real_input_node = input_node.args[0] out_idx = input_node.args[1] return create_recv_tensor(real_input_node, out_idx) @@ -220,20 +217,20 @@ def create_recv_tensor( ) if output_idx is not None: - # If a node has multiple output values, "tensor_meta" is a list + # If a node has multiple output values, "example_value" is a list # of tensor meta - tensor_meta = input_node.meta["tensor_meta"][output_idx] + example_value = input_node.meta["example_value"][output_idx] else: - tensor_meta = input_node.meta["tensor_meta"] + example_value = input_node.meta["example_value"] - logging.info( - f"[{self.group_rank}][{self.name}] " + logger.info( + f"[{self.group_rank}] " f"Creating recv buffer for input '{input_node.name}' " - f"value index {output_idx}: {tensor_meta.shape}" + f"value index {output_idx}: {example_value.size()}" ) src_rank = self.get_stage_index_of_submod(input_node.name) - buffer = _make_tensor_from_meta(tensor_meta, self.device) + buffer = _make_tensor_from_meta(example_value, self.device) # Enable gradient in training mode if self.pipe.has_loss_and_backwards: buffer.requires_grad_(True) @@ -245,25 +242,20 @@ def create_recv_tensor( # `args` is a Tuple, hence we will have: # Tuple[RecvInfo] - args_recv_info = pippy.fx.node.map_arg( - self.node.args, create_recv_tensor - ) + args_recv_info = fx.node.map_arg(self.node.args, create_recv_tensor) # `kwargs` is a Dict, hence we will have: # Dict[keyword, RecvInfo] - kwargs_recv_info = pippy.fx.node.map_arg( - self.node.kwargs, create_recv_tensor - ) + kwargs_recv_info = fx.node.map_arg(self.node.kwargs, create_recv_tensor) - logging.info( - f"[{self.group_rank}][{self.name}] " - f"Activation recv info: {args_recv_info}" + logger.info( + f"[{self.group_rank}] " f"Activation recv info: {args_recv_info}" ) return args_recv_info, kwargs_recv_info def find_dst_rank( self, - user: pippy.fx.Node, + user: fx.Node, ) -> Optional[int]: """ Find the destination rank of a `user` node. @@ -272,9 +264,6 @@ def find_dst_rank( if user.op == "call_module": # User is a stage (`call_module`) return self.get_stage_index_of_submod(user.name) - elif user.target is sync_barrier: - # Send result back to pp rank 0 - return 0 else: # - If user.op == "output": # No need to send back to rank 0 @@ -305,9 +294,7 @@ def _create_act_send_info(self): if dst_rank is not None: dsts.append(dst_rank) - logging.info( - f"[{self.group_rank}][{self.name}] " f"Send info: {act_send_info}" - ) + logger.info(f"[{self.group_rank}] " f"Send info: {act_send_info}") return act_send_info def _create_grad_recv_info( @@ -316,7 +303,7 @@ def _create_grad_recv_info( ) -> Dict[int, RecvInfo]: # Dict[output_index, RecvInfo] grad_recv_info: Dict = {} - my_tensor_meta = self.node.meta["tensor_meta"] + my_example_value = self.node.meta["example_value"] for out_idx, dst_list in act_send_info.items(): if not dst_list: @@ -325,9 +312,9 @@ def _create_grad_recv_info( # TODO: clean way if len(act_send_info) > 1: - tensor_meta = my_tensor_meta[out_idx] + example_value = my_example_value[out_idx] else: - tensor_meta = my_tensor_meta + example_value = my_example_value # TODO: otherwise needs grad accumulation assert len(dst_list) == 1 @@ -335,13 +322,10 @@ def _create_grad_recv_info( grad_recv_info[out_idx] = RecvInfo( f"{grad_src}", grad_src, - _make_tensor_from_meta(tensor_meta, self.device), + _make_tensor_from_meta(example_value, self.device), ) - logging.info( - f"[{self.group_rank}][{self.name}] " - f"Grad recv info: {grad_recv_info}" - ) + logger.info(f"[{self.group_rank}] " f"Grad recv info: {grad_recv_info}") return grad_recv_info def _create_grad_send_info( @@ -359,19 +343,16 @@ def map_recv_to_send(a): grad_send_info.append(None) return None - pippy.fx.node.map_aggregate(args_recv_info, map_recv_to_send) + fx.node.map_aggregate(args_recv_info, map_recv_to_send) - pippy.fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) + fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send) - logging.info( - f"[{self.group_rank}][{self.name}] " - f"Grad send info: {grad_send_info}" - ) + logger.info(f"[{self.group_rank}] " f"Grad send info: {grad_send_info}") return grad_send_info def _recv_tensor(self, info, recv_reqs): - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Receiving tensor '{info.input_name}' from Rank {info.source}: " f"{info.buffer.size()}" ) @@ -401,8 +382,8 @@ def split_inputs(self, args, kwargs): args, kwargs, self.chunks, - self.args_chunk_spec, - self.kwargs_chunk_spec, + self.pipe.args_chunk_spec, + self.pipe.kwargs_chunk_spec, ) def _recv_and_fill_inputs( @@ -424,7 +405,7 @@ def recv_args(info): else: return chunk_args_list.pop(0) # type: ignore[has-type] - composite_args = pippy.fx.node.map_aggregate( + composite_args = fx.node.map_aggregate( self.args_recv_info[chunk], recv_args, ) @@ -439,7 +420,7 @@ def recv_kwargs(info): k = next(iter(chunk_kwargs)) # type: ignore[has-type] return chunk_kwargs.pop(k) # type: ignore[has-type] - composite_kwargs = pippy.fx.node.map_aggregate( + composite_kwargs = fx.node.map_aggregate( self.kwargs_recv_info[chunk], recv_kwargs, ) @@ -462,8 +443,8 @@ def _send_activations( for dst in dst_stages: if dst is None: continue - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Sending tensor to Rank {dst}: {out.size()}" ) peer_rank = self.stage_index_to_group_rank[dst] @@ -488,7 +469,7 @@ def _recv_grads( recv_grad = self.recv_tensor_fn(grad_recv_reqs) # Receive gradients - grads = pippy.fx.node.map_aggregate( + grads = fx.node.map_aggregate( self.grad_recv_info[bwd_chunk], recv_grad, ) @@ -496,8 +477,8 @@ def _recv_grads( for work in grad_recv_reqs: work.wait() - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Received output grads of chunk {bwd_chunk}: {map_debug_info(grads)}" ) return grads @@ -511,8 +492,8 @@ def _send_grads( for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: - logging.debug( - f"[{self.group_rank}][{self.name}] " + logger.debug( + f"[{self.group_rank}] " f"Sending gradient to Rank {grad_recv_stage}: {grad.size()}" ) peer_rank = self.stage_index_to_group_rank[grad_recv_stage] @@ -579,6 +560,11 @@ def forward_one_chunk( """ raise RuntimeError(exc_msg) from e + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + logger.debug(map_debug_info(output)) # Unify output form to tuple for easy correspondance with # `act_send_info` output_tuple = output if type(output) is tuple else (output,) @@ -643,7 +629,7 @@ def clear_runtime_states(self): def merge_output_chunks(self): return merge_chunks( self.output_chunks, - self.output_chunk_spec, + self.pipe.output_chunk_spec, ) def forward(self, *args, **kwargs): @@ -656,16 +642,18 @@ def forward(self, *args, **kwargs): # Forward pass of all chunks for chunk in range(self.chunks): self.forward_one_chunk(chunk) - - # Wait for all sends to finish - # TODO: okay to delay the sync till completion of all chunks? - for work in self.all_act_send_reqs: - work.wait() + logger.debug(f"[{self.group_rank}] Forwarded chunk {chunk}") # Backward starts here for bwd_chunk in range(self.chunks): self.backward_one_chunk(bwd_chunk) + logger.debug(f"[{self.group_rank}] Backwarded chunk {bwd_chunk}") + + # Wait for all sends to finish + # TODO: okay to delay the sync till completion of all chunks? + for work in self.all_act_send_reqs: + work.wait() # Wait for all sends to finish # TODO: okay to delay the sync till completion of all chunks? @@ -684,24 +672,14 @@ def __init__( self, pipe: Pipe, rank: int, - nstages: int, - chunks: int, device: torch.device, group: dist.ProcessGroup = None, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, ): super().__init__( pipe, rank, - nstages, - chunks, device, group=group, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, ) def forward(self, *args, **kwargs): diff --git a/pippy/__init__.py b/pippy/__init__.py index 49dc0493a..75de5f7bd 100644 --- a/pippy/__init__.py +++ b/pippy/__init__.py @@ -1,10 +1,4 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.compile import ( - all_compile, - compile, - compile_stage, - create_default_args, -) from pippy.IR import ( annotate_split_points, LossWrapper, @@ -15,8 +9,6 @@ TrivialLossWrapper, ) from pippy.ModelSplit import split_into_equal_size, split_on_size_threshold -from pippy.PipelineDriver import PipelineDriver1F1B, PipelineDriverFillDrain -from pippy.utils import run_pippy __all__ = [ @@ -25,15 +17,8 @@ "TrivialLossWrapper", "Pipe", "pipe_split", - "run_pippy", "PipeSplitWrapper", "annotate_split_points", - "PipelineDriverFillDrain", - "PipelineDriver1F1B", "split_into_equal_size", "split_on_size_threshold", - "compile", - "all_compile", - "create_default_args", - "compile_stage", ] diff --git a/pippy/auto_parallelization.py b/pippy/auto_parallelization.py index 94e026807..a22cc81f1 100644 --- a/pippy/auto_parallelization.py +++ b/pippy/auto_parallelization.py @@ -18,7 +18,7 @@ import numpy as np -import pippy.fx +from torch import fx from pippy import pipe_split @@ -272,7 +272,7 @@ class AutoParallelConfig: def dp_auto_parallel(config: AutoParallelConfig): - def _dp_auto_parallel(fx_mod: pippy.fx.GraphModule): + def _dp_auto_parallel(fx_mod: fx.GraphModule): n_graph_nodes = len(fx_mod.graph.nodes) submesh_shapes = get_possible_submesh_shapes( n_compute_nodes=config.n_compute_nodes, diff --git a/pippy/compile.py b/pippy/compile.py deleted file mode 100644 index 5ddf074c6..000000000 --- a/pippy/compile.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -import logging -from typing import Any, Callable, List, Optional - -import torch -import torch.distributed as dist -from torch._subclasses.fake_tensor import FakeTensorMode - -import pippy.fx as fx -from pippy.debug import PIPPY_VERBOSITY -from pippy.IR import MultiUseParameterConfig, Pipe, PiPPyShapeProp -from pippy.microbatch import ( - gen_output_chunk_spec, - LossReducer, - split_args_kwargs_into_chunks, - sum_reducer, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) -from pippy.PipelineStage import PipelineStage, PipelineStage1F1B -from pippy.utils import get_device, get_pp_rank, get_rank - - -PIPELINE_SCHEDULE_DRIVERS = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - - -def create_default_args( - mod: torch.nn.Module, - except_keys: List = None, -): - if except_keys is None: - except_keys = [] - sig = inspect.signature(mod.forward) - default_kwargs = { - p.name: p.default - for p in sig.parameters.values() - if p.name not in except_keys and p.default is not inspect._empty - } - return default_kwargs - - -def _compile( - all_compile: bool, - mod: torch.nn.Module, - num_ranks: int, - num_chunks: int, - schedule: Optional[str] = "FillDrain", - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - ranks: List[int] = None, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - checkpoint=False, - _debug_mask_minibatches: bool = False, - index_filename=None, - checkpoint_prefix: str = None, - **kwargs, -): - if ranks is None: - ranks = list(range(num_ranks)) - - if all_compile: - rank = get_rank() - pp_rank = get_pp_rank(rank, ranks) - else: - pp_rank = 0 - - # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across - # stages instead of TRANSMIT'ting it - multi_use_param_spec = MultiUseParameterConfig.REPLICATE - - # Figure out which output is loss from output_chunk_spec - output_loss_value_spec: Any = None - if output_chunk_spec is not None: - output_loss_value_spec = fx.node.map_aggregate( - output_chunk_spec, lambda v: isinstance(v, LossReducer) - ) - - logging.info("[PiPPy] Tracing model ...") - pipe_model = Pipe.from_tracing( - mod, - multi_use_param_spec=multi_use_param_spec, - tracer=tracer, - output_loss_value_spec=output_loss_value_spec, - split_policy=split_policy, - **kwargs, - ) - - # In all_compile mode, each rank calls pippy.all_compile, hence they will all have the pipe. - # We can hence ask each rank to get its own stage from the pipe, and materialize it locally. - if all_compile: - device = get_device() - - # `None` means self.dtype, i.e. no change - dtype = None - # TODO: generalize this - if hasattr(mod, "config") and hasattr(mod.config, "torch_dtype"): - dtype = mod.config.torch_dtype # type: ignore[union-attr] - - pipe_model.defer_stage_init( - device, - index_filename, - dtype, - checkpoint_prefix, - ) - stage_mod = pipe_model.export(pp_rank) - - if pp_rank == 0: - logging.info(pipe_model.split_gm) - - logging.info("[PiPPy] Creating pipeline driver ...") - if schedule not in PIPELINE_SCHEDULE_DRIVERS: - raise ValueError( - f"Unknown pipeline schedule: {schedule}. " - f"Please select from {PIPELINE_SCHEDULE_DRIVERS.keys()}" - ) - pipeline_driver = PIPELINE_SCHEDULE_DRIVERS[schedule]( - pipe_model, - num_chunks, - num_ranks, - all_ranks=ranks, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - checkpoint=checkpoint, - loss_reducer=loss_reducer, - _debug_mask_minibatches=_debug_mask_minibatches, - ) - - if not all_compile: - return pipeline_driver - - if pp_rank == 0: - return pipeline_driver, stage_mod - else: - return None, stage_mod - - -def compile( - mod: torch.nn.Module, - num_ranks: int, - num_chunks: int, - schedule: Optional[str] = "FillDrain", - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - ranks: List[int] = None, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - checkpoint=False, - _debug_mask_minibatches: bool = False, - **kwargs, -): - return _compile( - False, - mod, - num_ranks, - num_chunks, - schedule=schedule, - split_policy=split_policy, - ranks=ranks, - tracer=tracer, - loss_reducer=loss_reducer, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - checkpoint=checkpoint, - _debug_mask_minibatches=_debug_mask_minibatches, - **kwargs, - ) - - -def all_compile( - mod: torch.nn.Module, - num_ranks: int, - num_chunks: int, - schedule: Optional[str] = "FillDrain", - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - ranks: List[int] = None, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - checkpoint=False, - _debug_mask_minibatches: bool = False, - **kwargs, -): - return _compile( - True, - mod, - num_ranks, - num_chunks, - schedule=schedule, - split_policy=split_policy, - ranks=ranks, - tracer=tracer, - loss_reducer=loss_reducer, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - checkpoint=checkpoint, - _debug_mask_minibatches=_debug_mask_minibatches, - **kwargs, - ) - - -def compile_stage( - mod: torch.nn.Module, - stage_index: int, - num_stages: int, - num_chunks: int, - device: torch.device, - group: dist.ProcessGroup, - example_inputs: List[torch.Tensor], - split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, - return_to_0: bool = False, - tracer=None, - loss_reducer: LossReducer = sum_reducer, - args_chunk_spec=None, - kwargs_chunk_spec=None, - output_chunk_spec=None, - schedule="FillDrain", - **kwargs, -) -> PipelineStage: - # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across - # stages instead of TRANSMIT'ting it - multi_use_param_spec = MultiUseParameterConfig.REPLICATE - - # Figure out which output is loss from output_chunk_spec - output_loss_value_spec: Any = None - if output_chunk_spec is not None: - output_loss_value_spec = fx.node.map_aggregate( - output_chunk_spec, lambda v: isinstance(v, LossReducer) - ) - - logging.info("[PiPPy] Tracing model ...") - pipe = Pipe.from_tracing( - mod, - multi_use_param_spec=multi_use_param_spec, - tracer=tracer, - output_loss_value_spec=output_loss_value_spec, - split_policy=split_policy, - return_to_0=return_to_0, - **kwargs, - ) - - gm = pipe.split_gm - if stage_index == 0: - logging.info(gm) - if PIPPY_VERBOSITY == "INFO": - gm.graph.print_tabular() - - # Get shape of chunked arguments - args_split, _ = split_args_kwargs_into_chunks( - example_inputs, - {}, # kwargs included in `example_inputs` - num_chunks, - args_chunk_spec, - kwargs_chunk_spec, # TODO: merge into args_chunk_spec - ) - - # Use fake tensor for shape propagation - # Since model itself may have been materialized, we need to use - # `allow_non_fake_inputs` - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) - # In reality, the fake input should be created from shape info (potentially - # broadcast from Rank 0) - fake_args_split = fx.node.map_aggregate( - args_split, lambda a: fake_mode.from_tensor(a) - ) - - # Use 1st chunk of args for shape propagation - chunk0 = fake_args_split[0] - - sp = PiPPyShapeProp(gm) - sp.propagate(*chunk0) - - # Prepare output chunk/reduce spec for merging/reducing final outputs - output_chunk_spec = ( - output_chunk_spec - if output_chunk_spec - else gen_output_chunk_spec(pipe.loss_spec, loss_reducer) - ) - - # Create pipeline stage based on schedule - if schedule == "1F1B": - return PipelineStage1F1B( - pipe, - stage_index, - num_stages, - num_chunks, - device, - group=group, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - ) - else: - return PipelineStage( - pipe, - stage_index, - num_stages, - num_chunks, - device, - group=group, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_chunk_spec=output_chunk_spec, - ) diff --git a/pippy/debug.py b/pippy/debug.py index c393581a9..4e96cf7d1 100644 --- a/pippy/debug.py +++ b/pippy/debug.py @@ -4,19 +4,19 @@ import torch -import pippy.fx - PIPPY_VERBOSITY = os.environ.get("PIPPY_VERBOSITY", "OFF") if PIPPY_VERBOSITY == "DEBUG": - logging.getLogger().setLevel(logging.DEBUG) + logging.getLogger("pippy").setLevel(logging.DEBUG) elif PIPPY_VERBOSITY == "INFO": - logging.getLogger().setLevel(logging.INFO) + logging.getLogger("pippy").setLevel(logging.INFO) elif PIPPY_VERBOSITY == "OFF": pass else: - print(f"Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") + print(f"[PiPPy] Unsupported PIPPY_VERBOSITY level: {PIPPY_VERBOSITY}") + +print(f"[PiPPy] Setting logging level to: {PIPPY_VERBOSITY}") def friendly_debug_info(v): @@ -27,4 +27,4 @@ def friendly_debug_info(v): def map_debug_info(a): - return pippy.fx.node.map_aggregate(a, friendly_debug_info) + return torch.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/pippy/fx/OVERVIEW.md b/pippy/fx/OVERVIEW.md deleted file mode 100644 index f2995eb7a..000000000 --- a/pippy/fx/OVERVIEW.md +++ /dev/null @@ -1,134 +0,0 @@ -# FX Technical Overview (WIP) - -FX is a toolkit for pass writers to facilitate Python-to-Python transformation of `nn.Module` instances. This toolkit aims to support a subset of Python language semantics—rather than the whole Python language—to facilitate ease of implementation of transforms. Currently, this feature is under a Beta release and its API may change. - -## Table of Contents - - - -- [Introduction](#introduction) - - [Motivation](#motivation) - - [Use Cases](#use-cases) - - [Technical Details](#technical-details) -- [Internal Structure](#internal-structure) - - [Graph](#graph) - - [GraphModule](#graphmodule) -- [Symbolic Tracing](#symbolic-tracing) - - [Tracer](#tracer) - - [Proxy](#proxy) -- [The FX IR](#the-fx-ir) -- [Transformation and Codegen](#transformation-and-codegen) - - - -# Introduction - -## Motivation ## - -TODO - -## Use Cases ## - -FX should be used by pass writers to provide functionality for capturing and constructing nn.Module code in a structured way. We do not expect end users to utilize FX directly. A useful property of framing FX in this way is that passes can be seen as functions of the form `pass(in_mod : nn.Module) -> nn.Module`. This means we can create composable pipelines of transformations. - -![An image of a sample nn.Module transformation pipeline that starts with a Quantize transformation, which is then composed with a Split transformation, then a Lower to Accelerator transformation](https://i.imgur.com/TzFIYMi.png "nn.Module transformation pipeline") - -In this example pipeline, we have a Quantize transformation, which is then composed with a Split transformation, then a Lower to Accelerator transformation. Finally, the transformed Modules are compiled with TorchScript for deployment. This last point emphasizes that not only should FX transforms be composable with each other, but their products are composable with other systems like TorchScript compilation or tracing. - -By using `nn.Module` as the interface between passes, FX transforms are interoperable with each other, and the resulting model can be used anywhere an `nn.Module` can be used. - -## Technical Details ## - -The following sections will walk us through the components that transform from original `torch.nn.Module` to FX IR and finally to generated Python code and a GraphModule instance: - -FX’s front-end makes use of the dynamic nature of Python to intercept call-sites for various entities (PyTorch operators, Module invocations, and Tensor method invocations). This functionality is exposed through an API called `torch.fx.symbolic_trace`. We can see how this works by way of an example: - -```python -import torch - -class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter( - torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - -from torch.fx import symbolic_trace -module = MyModule() -symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) - -input = torch.rand(3, 4) -torch.testing.assert_allclose(symbolic_traced(input), module(input)) -``` - -Here, we set up a simple Module that exercises different language features: fetching a parameter, applying an arithmetic operator, applying a submodule (linear), and applying a Tensor method. `symbolic_trace` returns an instance of GraphModule, which is in itself a subclass of `nn.Module`. We can see that the `symbolic_traced` instance runs and returns the same result as the original module instance module. - -# Internal Structure - -## [Graph](https://pytorch.org/docs/master/fx.html#torch.fx.Graph) ## -TODO - -## [GraphModule](https://pytorch.org/docs/master/fx.html#torch.fx.GraphModule) ## -TODO - -# Symbolic Tracing - -## [Tracer](https://pytorch.org/docs/master/fx.html#torch.fx.Tracer) ## - -`Tracer` is the class that implements the symbolic tracing functionality of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent to `Tracer().trace(m)`. Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on the class. - -In the default implementation of `Tracer().trace`, the tracer first creates Proxy objects for all arguments in the `forward` function. (This happens in the call to `create_args_for_root`.) Next, the `forward` function is called with the new Proxy arguments. As the Proxies flow through the program, they record all the operations (`torch` function calls, method calls, and operators) that they touch into the growing FX Graph as Nodes. - -## Proxy ## - -Proxy objects are Node wrappers used by the Tracer to record operations seen during symbolic tracing. The mechanism through which Proxy objects record computation is [`__torch_function__`](https://pytorch.org/docs/stable/notes/extending.html#extending-torch). If any custom Python type defines a method named `__torch_function__`, PyTorch will invoke that `__torch_function__` implementation when an instance of that custom type is passed to a function in the `torch` namespace. In FX, when operations on Proxy are dispatched to the `__torch_function__` handler, the `__torch_function__` handler records the operation in the Graph as a Node. The Node that was recorded in the Graph is then itself wrapped in a Proxy, facilitating further application of ops on that value. - -Consider the following example: - -```python - class M(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - m = M() - traced = symbolic_trace(m) -``` - -During the call to `symbolic_trace`, the parameter `x` is transformed into a Proxy object and the corresponding Node (a Node with op = “placeholder” and target = “x”) is added to the Graph. Then, the Module is run with Proxies as inputs, and recording happens via the `__torch_function__` dispatch path. - -If you're doing graph transforms, you can wrap your own Proxy method around a raw Node so that you can use the overloaded operators to add additional things to a Graph. - -# The FX IR - -Symbolic tracing captures an intermediate representation (IR), which is represented as a doubly-linked list of Nodes. - -Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows: - -- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. -- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care -- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention -- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. -- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* -- `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout. - -To facilitate easier analysis of data dependencies, Nodes have read-only properties `input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are represented as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such. - -# Transformation and Codegen - -An invocation of `symbolic_traced` above requires a valid `forward()` method to be defined on the Module instance. How does this work? GraphModule actually generates valid Python source code based on the IR it is instantiated with. This can be seen by accessing the code attribute on the GraphModule: `print(symbolic_traced.code)`. - -After symbolic tracing, the code given under [Technical Details](#technical-details) is represented as follows: - -```python -def forward(self, x): - param = self.param - add_1 = x + param; x = param = None - linear_1 = self.linear(add_1); add_1 = None - clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None - return clamp_1 -``` - -This is the core of why FX is a Python-to-Python translation toolkit. Outside users can treat the results of FX transformations as they would any other `nn.Module` instance. diff --git a/pippy/fx/__init__.py b/pippy/fx/__init__.py deleted file mode 100644 index e52cc2619..000000000 --- a/pippy/fx/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -r''' -FX is a toolkit for developers to use to transform ``nn.Module`` -instances. FX consists of three main components: a **symbolic tracer,** -an **intermediate representation**, and **Python code generation**. A -demonstration of these components in action: - -:: - - import torch - # Simple module for demonstration - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - module = MyModule() - - from pippy.fx import symbolic_trace - # Symbolic tracing frontend - captures the semantics of the module - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(module) - - # High-level intermediate representation (IR) - Graph representation - print(symbolic_traced.graph) - """ - graph(): - %x : [#users=1] = placeholder[target=x] - %param : [#users=1] = get_attr[target=param] - %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) - %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {}) - %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) - return clamp - """ - - # Code generation - valid Python code - print(symbolic_traced.code) - """ - def forward(self, x): - param = self.param - add = x + param; x = param = None - linear = self.linear(add); add = None - clamp = linear.clamp(min = 0.0, max = 1.0); linear = None - return clamp - """ - -The **symbolic tracer** performs "symbolic execution" of the Python -code. It feeds fake values, called Proxies, through the code. Operations -on theses Proxies are recorded. More information about symbolic tracing -can be found in the :func:`symbolic_trace` and :class:`Tracer` -documentation. - -The **intermediate representation** is the container for the operations -that were recorded during symbolic tracing. It consists of a list of -Nodes that represent function inputs, callsites (to functions, methods, -or :class:`torch.nn.Module` instances), and return values. More information -about the IR can be found in the documentation for :class:`Graph`. The -IR is the format on which transformations are applied. - -**Python code generation** is what makes FX a Python-to-Python (or -Module-to-Module) transformation toolkit. For each Graph IR, we can -create valid Python code matching the Graph's semantics. This -functionality is wrapped up in :class:`GraphModule`, which is a -:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a -``forward`` method generated from the Graph. - -Taken together, this pipeline of components (symbolic tracing -> -intermediate representation -> transforms -> Python code generation) -constitutes the Python-to-Python transformation pipeline of FX. In -addition, these components can be used separately. For example, -symbolic tracing can be used in isolation to capture a form of -the code for analysis (and not transformation) purposes. Code -generation can be used for programmatically generating models, for -example from a config file. There are many uses for FX! - -Several example transformations can be found at the -`examples `__ -repository. -''' - -from .graph_module import GraphModule -from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta -from .graph import Graph, CodeGen -from .node import Node, map_arg -from .proxy import Proxy -from .interpreter import Interpreter as Interpreter, Transformer as Transformer -from .subgraph_rewriter import replace_pattern diff --git a/pippy/fx/__init__.pyi b/pippy/fx/__init__.pyi deleted file mode 100644 index 2faf3b021..000000000 --- a/pippy/fx/__init__.pyi +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .graph import Graph as Graph -from .graph_module import GraphModule as GraphModule -from .node import Node as Node, map_arg as map_arg -from .proxy import Proxy as Proxy -from ._symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap -from .interpreter import Interpreter as Interpreter, Transformer as Transformer -from .subgraph_rewriter import replace_pattern as replace_pattern diff --git a/pippy/fx/_compatibility.py b/pippy/fx/_compatibility.py deleted file mode 100644 index 559232ce2..000000000 --- a/pippy/fx/_compatibility.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Any, Dict -import textwrap - -_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} -_MARKED_WITH_COMATIBLITY : Dict[Any, None] = {} - -def compatibility(is_backward_compatible : bool): - if is_backward_compatible: - - def mark_back_compat(fn): - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') - docstring += """ -.. note:: - Backwards-compatibility for this API is guaranteed. -""" - fn.__doc__ = docstring - _BACK_COMPAT_OBJECTS.setdefault(fn) - _MARKED_WITH_COMATIBLITY.setdefault(fn) - return fn - - return mark_back_compat - else: - - def mark_not_back_compat(fn): - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') - docstring += """ -.. warning:: - This API is experimental and is *NOT* backward-compatible. -""" - fn.__doc__ = docstring - _MARKED_WITH_COMATIBLITY.setdefault(fn) - return fn - - return mark_not_back_compat diff --git a/pippy/fx/_pytree.py b/pippy/fx/_pytree.py deleted file mode 100644 index be8a61af2..000000000 --- a/pippy/fx/_pytree.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Callable, Any, Tuple, List, Dict, Type, NamedTuple -from torch.utils._pytree import PyTree, TreeSpec, LeafSpec -from collections import namedtuple - -FlattenFuncSpec = Callable[[PyTree, TreeSpec], List] - -SUPPORTED_NODES: Dict[Type[Any], Any] = {} -def register_pytree_flatten_spec(typ: Any, flatten_fn_spec: FlattenFuncSpec) -> None: - SUPPORTED_NODES[typ] = flatten_fn_spec - -def tree_flatten_spec(pytree: PyTree, spec: TreeSpec) -> List[Any]: - if isinstance(spec, LeafSpec): - return [pytree] - if spec.type not in SUPPORTED_NODES: - raise RuntimeError( - f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with" - "pippy.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make" - "sure that any custom pytrees have been registered before loading it.") - flatten_fn_spec = SUPPORTED_NODES[spec.type] - child_pytrees = flatten_fn_spec(pytree, spec) - result = [] - for child, child_spec in zip(child_pytrees, spec.children_specs): - flat = tree_flatten_spec(child, child_spec) - result += flat - return result - -def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: - return list([d[k] for k in spec.context]) - -def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: - return [d[i] for i in range(len(spec.children_specs))] - -def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: - return [d[i] for i in range(len(spec.children_specs))] - -def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: - return [d[i] for i in range(len(spec.children_specs))] - -register_pytree_flatten_spec(dict, _dict_flatten_spec) -register_pytree_flatten_spec(list, _list_flatten_spec) -register_pytree_flatten_spec(tuple, _tuple_flatten_spec) -register_pytree_flatten_spec(namedtuple, _tuple_flatten_spec) diff --git a/pippy/fx/_symbolic_trace.py b/pippy/fx/_symbolic_trace.py deleted file mode 100644 index 00937803a..000000000 --- a/pippy/fx/_symbolic_trace.py +++ /dev/null @@ -1,1080 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import builtins -import copy -import functools -import inspect -import math -import os -import warnings -from itertools import chain -from types import CodeType, FunctionType, ModuleType -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - Tuple, - Type, - Union, -) - -import torch -import torch.utils._pytree as pytree -from torch._C import ScriptObject # type: ignore[attr-defined] - -from ._compatibility import compatibility -from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph -from .graph_module import GraphModule -from .node import Argument, base_types, map_aggregate # pylint: disable=unused-import -from .proxy import ParameterProxy, Proxy, TracerBase - -HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS - -# These need to run in global scope to handle nested calls correctly -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ - -_proxyable_classes: Dict[Type, None] = {} - -_is_fx_tracing_flag = False - - -def is_fx_tracing(): - return _is_fx_tracing_flag - - -@compatibility(is_backward_compatible=True) -class ProxyableClassMeta(type): - """ - ProxyableClassMeta allows you to make construction of a given Python class - symbolically traceable. For example:: - - import torch - import pippy.fx - - class TensorPair(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): - s = x.add(TensorPair(y, y)) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = torch.randn(5, 3) - ref_out = use_tensor_pair_ctor(x, y) - - traced = pippy.fx.symbolic_trace(use_tensor_pair_ctor) - print(traced.code) - ''' - def forward(self, x : __main___TensorPair, y : torch.Tensor): - tensor_pair = __main___TensorPair(y, y); y = None - add = x.add(tensor_pair); tensor_pair = None - mul = add.mul(x); add = x = None - return mul - ''' - - From this example, we can see that contruction of a class (``TensorPair``) - defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic - tracing. - """ - - def __init__(cls, name, bases, attrs): - _proxyable_classes.setdefault(cls) - super().__init__(name, bases, attrs) - - def __call__(cls, *args, **kwargs): - instance = cls.__new__(cls) # type: ignore[call-overload] - - found_proxies = [] - - def check_proxy(a): - if isinstance(a, Proxy): - found_proxies.append(a) - - map_aggregate(args, check_proxy) - map_aggregate(kwargs, check_proxy) - - if len(found_proxies) != 0: - tracer = found_proxies[0].tracer - return tracer.create_proxy("call_function", cls, args, kwargs) - else: - cls.__init__(instance, *args, **kwargs) # type: ignore[misc] - return instance - - -def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: - co = fn.__code__ - co_flags = co.co_flags & ~HAS_VARSTUFF - co_args: tuple - if hasattr(co, "co_posonlyargcount"): - co_args = ( - nargs, - 0, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_freevars, - co.co_cellvars, - ) - else: - co_args = ( - nargs, - 0, - co.co_nlocals, - co.co_stacksize, - co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_freevars, - co.co_cellvars, - ) - new_code = CodeType(*co_args) # type: ignore[arg-type] - return FunctionType( - new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ - ) - - # we need to insert placeholder nodes for *args and **kwargs - # we can't call this function normally, otherwise it would try to unpack them - # instead, let's make python think that args and kwargs are normal variables - - -@compatibility(is_backward_compatible=False) -class PHBase(object): - """ - Object representing an input placeholder to `concrete_args` - """ - - def __repr__(self): - return "PH" - - -PH = PHBase() - - -@compatibility(is_backward_compatible=True) -class Tracer(TracerBase): - # Reference: https://github.com/pytorch/pytorch/issues/54354 - # The first line of this docstring overrides the one Sphinx generates for the - # documentation. We need it so that Sphinx doesn't leak `math`s path from the - # build environment (e.g. ` None: - # This method's signature is overridden by the first line of this class' - # docstring. If this method's signature is modified, the signature that - # overrides it also should be modified accordingly. - - """ - Construct a Tracer object. - - Args: - - autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, - Python modules whose functions should be wrapped automatically - without needing to use fx.wrap(). Backward-compatibility for - this parameter is guaranteed. - - autowrap_function (Tuple[Callable, ...]): defaults to `()`, - Python functions that should be wrapped automatically without - needing to use fx.wrap(). Backward compabilibility for this - parameter is guaranteed. - - param_shapes_constant (bool): When this flag is set, calls to shape, - size and a few other shape like attributes of a module's parameter - will be evaluted directly, rather than returning a new Proxy value - for an attribute access. Backward compatibility for this parameter - is guaranteed. - """ - - super().__init__() - - # Functions we will eagerly wrap when we see them while tracing - # this captures both `math.sqrt()` and `from math import sqrt` automatically - self._autowrap_function_ids: Set[int] = { - id(value) - for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) - if not name.startswith("_") and callable(value) - } - self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions])) - - # Python modules to apply autowrap to at the start, in addition to - # modules we see while tracing - self._autowrap_search: List[ModuleType] = list(autowrap_modules) - self.param_shapes_constant = param_shapes_constant - - self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None - - @compatibility(is_backward_compatible=True) - def create_arg(self, a: Any) -> "Argument": - """ - A method to specify the behavior of tracing when preparing values to - be used as arguments to nodes in the ``Graph``. - - By default, the behavior includes: - - #. Iterate through collection types (e.g. tuple, list, dict) and recursively - call ``create_args`` on the elements. - #. Given a Proxy object, return a reference to the underlying IR ``Node`` - #. Given a non-Proxy Tensor object, emit IR for various cases: - - * For a Parameter, emit a ``get_attr`` node referring to that Parameter - * For a non-Parameter Tensor, store the Tensor away in a special - attribute referring to that attribute. - - This method can be overridden to support more types. - - Args: - - a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. - - - Returns: - - The value ``a`` converted into the appropriate ``Argument`` - """ - # The base tracer is used to construct Graphs when there is no associated - # module hierarchy, so it can never create parameter references. - # The default tracer adds the ability to refer to parameters when - # tracing modules. - if isinstance(a, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if a is p: - return self.create_node("get_attr", n, (), {}) - raise NameError("parameter is not a member of this module") - elif isinstance(a, torch.Tensor): - for n_, p_ in self.root.named_buffers(): - if a is p_: - return self.create_node("get_attr", n_, (), {}) - elif isinstance(a, torch.nn.Module): - for n_, p_ in self.root.named_modules(): - if a is p_: - return self.create_node("get_attr", n_, (), {}) - # For NamedTuple instances that appear literally as args, we emit - # a node to construct the NamedTuple and use that Node as the argument. - if isinstance(a, tuple) and hasattr(a, "_fields"): - args = tuple(self.create_arg(elem) for elem in a) - return self.create_node("call_function", a.__class__, args, {}) - - # Tensors do not have a reliable string repr() from which they can be - # constructed (and we probably don't want to rely on that, either), so - # for any constant Tensor values we encounter, first search for if they - # are an attribute of some module in the module hierarchy. If so, emit - # a get_attr to retrieve that tensor. Otherwise, we'll store away the - # tensor value into a special attribute on the Module s.t. we can - # retrieve it with a get_attr. - if isinstance(a, (torch.Tensor, ScriptObject)): - qualname: Optional[str] = self.tensor_attrs.get(a) - - # Tensor was not found in the Module hierarchy, stow it away in a - # special attribute and set the qualname to refer to that - if not qualname: - i = 0 - while True: - qualname = f"_tensor_constant{i}" - if not hasattr(self.root, qualname): - break - i += 1 - self.tensor_attrs[a] = qualname - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - - if type(a) in _proxyable_classes: - # This is an instance of a proxyable class for which we did not - # witness its construction. Intern this as a constant attribute - - # TODO: binary search - i = 0 - while True: - qualname = f"_{a.__class__.__name__}_constant_{i}" - if not hasattr(self.root, qualname): - break - i += 1 - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - - return super().create_arg(a) - - @compatibility(is_backward_compatible=True) - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: - """ - A method to specify whether a given ``nn.Module`` is a "leaf" module. - - Leaf modules are the atomic units that appear in - the IR, referenced by ``call_module`` calls. By default, - Modules in the PyTorch standard library namespace (torch.nn) - are leaf modules. All other modules are traced through and - their constituent ops are recorded, unless specified otherwise - via this parameter. - - Args: - - m (Module): The module being queried about - module_qualified_name (str): The path to root of this module. For example, - if you have a module hierarchy where submodule ``foo`` contains - submodule ``bar``, which contains submodule ``baz``, that module will - appear with the qualified name ``foo.bar.baz`` here. - """ - return ( - (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) - and not isinstance(m, torch.nn.Sequential) - ) - - @compatibility(is_backward_compatible=True) - def path_of_module(self, mod: torch.nn.Module) -> str: - """ - Helper method to find the qualified name of ``mod`` in the Module hierarchy - of ``root``. For example, if ``root`` has a submodule named ``foo``, which has - a submodule named ``bar``, passing ``bar`` into this function will return - the string "foo.bar". - - Args: - - mod (str): The ``Module`` to retrieve the qualified name for. - """ - # Prefer the O(1) algorithm - if self.submodule_paths: - path = self.submodule_paths.get(mod) - if path is None: - raise NameError("module is not installed as a submodule") - assert isinstance(path, str) - return path - # O(N^2) fallback in the case that we didn't store the submodule - # paths. - else: - for n, p in self.root.named_modules(): - if mod is p: - return n - raise NameError("module is not installed as a submodule") - - @compatibility(is_backward_compatible=True) - def call_module( - self, - m: torch.nn.Module, - forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - ) -> Any: - """ - Method that specifies the behavior of this ``Tracer`` when it encounters - a call to an ``nn.Module`` instance. - - By default, the behavior is to check if the called module is a leaf module - via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to - ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through - the operations in its ``forward`` function. - - This method can be overridden to--for example--create nested traced - GraphModules, or any other behavior you would want while tracing across - ``Module`` boundaries. - - Args: - - m (Module): The module for which a call is being emitted - forward (Callable): The forward() method of the ``Module`` to be invoked - args (Tuple): args of the module callsite - kwargs (Dict): kwargs of the module callsite - - Return: - - The return value from the Module call. In the case that a ``call_module`` - node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever - value was returned from the ``Module`` invocation. - """ - module_qualified_name = self.path_of_module(m) - if not self.is_leaf_module(m, module_qualified_name): - return forward(*args, **kwargs) - return self.create_proxy("call_module", module_qualified_name, args, kwargs) - - @compatibility(is_backward_compatible=False) - def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): - """ - Method that specifies the behavior of this ``Tracer`` when we call getattr - on a call to an ``nn.Module`` instance. - - By default, the behavior is to return a proxy value for the attribute. It - also stores the proxy value in the ``parameter_proxy_cache``, so that future - calls will reuse the proxy rather than creating a new one. - - This method can be overridden to --for example-- not return proxies when - querying parameters. - - Args: - - attr (str): The name of the attribute being queried - attr_val (Any): The value of the attribute - parametr_proxy_cache (Dict[str, Any]): A cache of attr names to proxies - - Return: - - The return value from the getattr call. - """ - def maybe_get_proxy_for_attr( - attr_val, collection_to_search, parameter_proxy_cache - ): - for n, p in collection_to_search: - if attr_val is p: - if n not in parameter_proxy_cache: - kwargs = {} - if ( - "proxy_factory_fn" - in inspect.signature(self.create_proxy).parameters - ): - kwargs["proxy_factory_fn"] = ( - None - if not self.param_shapes_constant - else lambda node: ParameterProxy( - self, node, n, attr_val - ) - ) - val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] - parameter_proxy_cache[n] = val_proxy - return parameter_proxy_cache[n] - return None - - if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr( - attr_val, self.root.named_parameters(), parameter_proxy_cache - ) - if maybe_parameter_proxy is not None: - return maybe_parameter_proxy - - if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr( - attr_val, self.root.named_buffers(), parameter_proxy_cache - ) - if maybe_buffer_proxy is not None: - return maybe_buffer_proxy - - return attr_val - - # This method will be refactored - @compatibility(is_backward_compatible=False) - def create_args_for_root(self, root_fn, is_module, concrete_args=None): - """ - Create ``placeholder`` nodes corresponding to the signature of the ``root`` - Module. This method introspects root's signature and emits those - nodes accordingly, also supporting ``*args`` and ``**kwargs``. - """ - # In some cases, a function or method has been decorated with a wrapper - # defined via ``functools.wraps``. In this case, the outer code object - # will likely not contain the actual parameters we care about, so unwrap - # the function to get to the innermost callable. - fn_for_analysis = inspect.unwrap(root_fn) - co = fn_for_analysis.__code__ - total_args = co.co_argcount + co.co_kwonlyargcount - orig_args = list(co.co_varnames) - names_iter = iter(co.co_varnames) - args: List[Any] = [] - skip_arg_idx = 0 - if is_module: - if total_args == 0: - raise RuntimeError( - "``self`` argument cannot be part of *args expansion!" - ) - skip_arg_idx = 1 - next(names_iter) # skip self - args.append(self.root) - - sig = inspect.signature(fn_for_analysis) - - def proxy_placeholder(name: str): - if concrete_args is not None and name in concrete_args: - cnt = 0 - - def replace_ph(x): - nonlocal cnt - cnt += 1 - param = sig.parameters[name] - default = ( - () - if param.default is inspect.Parameter.empty - else (param.default,) - ) - out = self.create_proxy( - "placeholder", f"{name}_{str(cnt)}", default, {} - ) - if x == PH: - return out - # Union[int, bool] == bool in Python <= 3.6 - if ( - type(x) == bool - or type(x) in base_types - and type(x) != torch.Tensor - ): - torch._assert( - out == x, - f"{name} has been specialized to have value {x} but got another value", - ) - elif type(x) == type(None): - args = ( - out, - f"{name} has been specialized to have value None but got another value", - ) - self.create_proxy("call_function", _assert_is_none, args, {}) - else: - warnings.warn( - f"Was not able to add assertion to guarantee correct input {name} to " - f"specialized function. It is up to the user to make sure that your inputs match the " - f"inputs you specialized the function with." - ) - - return x - - return pytree.tree_map(replace_ph, concrete_args[name]) - if name[0] == "*": - default = () - else: - param = sig.parameters[name] - default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] - return self.create_proxy( - "placeholder", - name, - default, - {}, - type_expr=fn_for_analysis.__annotations__.get(name, None), - ) - - arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] - if isinstance(concrete_args, tuple): - if len(arg_names) != len(concrete_args): - raise RuntimeError( - f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" - ) - concrete_args = {name: val for name, val in zip(arg_names, concrete_args)} - args.extend(proxy_placeholder(names) for names in arg_names) - - if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: - # TODO: type annotations for *args and **kwargs - if co.co_flags & inspect.CO_VARARGS: - args.append(proxy_placeholder("*" + next(names_iter))) - if co.co_flags & inspect.CO_VARKEYWORDS: - args.append(proxy_placeholder("**" + next(names_iter))) - root_fn = _patch_function(root_fn, len(args)) - - flat_args, in_spec = pytree.tree_flatten(tuple(args)) - if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs): - # In the case that we have pytree-flattened inputs in - # `concrete_args`, generate a flattening wrapper around the - # original root function and return that. - self.graph._codegen = _PyTreeCodeGen( - _PyTreeInfo(orig_args[:total_args], in_spec, None) - ) - - def flatten_fn(*args): - tree_args = pytree.tree_unflatten(list(args), in_spec) - tree_out = root_fn(*tree_args) - out_args, out_spec = pytree.tree_flatten(tree_out) - assert isinstance(self.graph._codegen, _PyTreeCodeGen) - self.graph._codegen.pytree_info = ( - self.graph._codegen.pytree_info._replace(out_spec=out_spec) - ) - return out_args - - return flatten_fn, flat_args - return root_fn, args - - @compatibility(is_backward_compatible=True) - def trace( - self, - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, - ) -> Graph: - """ - Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` - can either be an ``nn.Module`` instance or a Python callable. - - Note that after this call, ``self.root`` may be different from the ``root`` passed - in here. For example, when a free function is passed to ``trace()``, we will - create an ``nn.Module`` instance to use as the root and add embedded constants - to. - - - Args: - - root (Union[Module, Callable]): Either a ``Module`` or a function to be - traced through. Backwards-compatibility for this parameter is - guaranteed. - concrete_args (Optional[Dict[str, any]]): Concrete arguments that should - not be treated as Proxies. This parameter is experimental and - its backwards-compatibility is *NOT* guaranteed. - - Returns: - - A ``Graph`` representing the semantics of the passed-in ``root``. - """ - global _is_fx_tracing_flag - old_is_fx_tracing_flag = _is_fx_tracing_flag - _is_fx_tracing_flag = True - try: - if isinstance(root, torch.nn.Module): - self.root = root - - assert hasattr( - type(root), self.traced_func_name - ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" - - fn = getattr(type(root), self.traced_func_name) - self.submodule_paths = {mod: name for name, mod in root.named_modules()} - else: - self.root = torch.nn.Module() - fn = root - - tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None) - self.graph = Graph(tracer_cls=tracer_cls) - - # When we encounter a Tensor value that's not a parameter, we look if it - # is some other attribute on the model. Construct a dict mapping Tensor - # values to the qualified name here for efficiency. This is used downstream - # in create_arg - self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} - - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): - for k, v in m.__dict__.items(): - if isinstance(v, (torch.Tensor, ScriptObject)): - self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) - for k, v in m.named_children(): - collect_tensor_attrs(v, prefix_atoms + [k]) - - collect_tensor_attrs(self.root, []) - - assert isinstance(fn, FunctionType) - - fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root( - fn, isinstance(root, torch.nn.Module), concrete_args - ) - - parameter_proxy_cache: Dict[ - str, Proxy - ] = {} # Reduce number of get_attr calls - - # Method dispatch on parameters is not recorded unless it's directly used. - # Thus, we need to insert a proxy when __getattr__ requests a parameter. - @functools.wraps(_orig_module_getattr) - def module_getattr_wrapper(mod, attr): - attr_val = _orig_module_getattr(mod, attr) - return self.getattr(attr, attr_val, parameter_proxy_cache) - - @functools.wraps(_orig_module_call) - def module_call_wrapper(mod, *args, **kwargs): - def forward(*args, **kwargs): - return _orig_module_call(mod, *args, **kwargs) - - _autowrap_check( - patcher, - getattr(getattr(mod, "forward", mod), "__globals__", {}), - self._autowrap_function_ids, - ) - return self.call_module(mod, forward, args, kwargs) - - with _Patcher() as patcher: - # allow duplicate patches to support the case of nested calls - patcher.patch_method( - torch.nn.Module, - "__getattr__", - module_getattr_wrapper, - deduplicate=False, - ) - patcher.patch_method( - torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False - ) - _patch_wrapped_functions(patcher) - _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) - for module in self._autowrap_search: - _autowrap_check( - patcher, module.__dict__, self._autowrap_function_ids - ) - self.create_node( - "output", - "output", - (self.create_arg(fn(*args)),), - {}, - type_expr=fn.__annotations__.get("return", None), - ) - - self.submodule_paths = None - finally: - _is_fx_tracing_flag = old_is_fx_tracing_flag - return self.graph - - def __deepcopy__(self, memo): - # _autowrap_search contains modules, which cannot be deepcopied. - new_tracer = Tracer.__new__(Tracer) - - for k, v in self.__dict__.items(): - if k in {'_autowrap_search'}: - new_obj = copy.copy(v) - else: - new_obj = copy.deepcopy(v, memo) - - new_tracer.__dict__[k] = new_obj - - return new_tracer - - -# List of pairs of (global dict, function name) functions -# to patch for the purposes of the wrap() API. -_wrapped_fns_to_patch: List[Tuple[dict, str]] = [] - -# List of methods on classes to wrap (class type, function name) -# this currently only works for Tensor.* methods that aren't traced properly -_wrapped_methods_to_patch: List[Tuple[type, str]] = [] - -if os.environ.get("FX_PATCH_GETITEM") == "1": - # This change is needed to trace models like PositionalEmbedding from BERT: - # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py - # but causes issues in quantization documented here: - # https://github.com/pytorch/pytorch/issues/50710 - # once that is fixed we can make this the default behavior. - _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) - - -def _find_proxy(*objects_to_search): - """ - Recursively search a data structure for a Proxy() and return it, - return None if not found. - """ - proxy = None - - def find_proxy(x): - nonlocal proxy - if isinstance(x, Proxy): - proxy = x - - map_aggregate(objects_to_search, find_proxy) - return proxy - - -def _create_wrapped_func(orig_fn): - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Given an closed-over ``orig_function`` to invoke, search the args and kwargs for - a Proxy object. If there is one, emit a ``call_function`` node to preserve the - call to this leaf function directly. Otherwise, just return the results of - this function call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return_proxy = proxy.tracer.create_proxy( - "call_function", orig_fn, args, kwargs - ) - return_proxy.node.meta["is_wrapped"] = True - return return_proxy - return orig_fn(*args, **kwargs) - - return wrapped - - -def _create_wrapped_method(cls, name): - orig_fn = getattr(cls, name) - - @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): - """ - Search the args and kwargs for a Proxy object. If there is one, - emit a ``call_method`` node to preserve the call to this method - directly. Otherwise, just return the results of this function - call, as this function is not being traced. - """ - proxy = _find_proxy(args, kwargs) - if proxy is not None: - return proxy.tracer.create_proxy("call_method", name, args, kwargs) - return orig_fn(*args, **kwargs) - - return wrapped - - -class _PatchedFn(NamedTuple): - frame_dict: Any - fn_name: str - orig_fn: Any - - def revert(self): - raise NotImplementedError() - - -class _PatchedFnSetItem(_PatchedFn): - def revert(self): - self.frame_dict[self.fn_name] = self.orig_fn - - -class _PatchedFnDel(_PatchedFn): - def revert(self): - del self.frame_dict[self.fn_name] - - -class _PatchedFnSetAttr(_PatchedFn): - def revert(self): - setattr(self.frame_dict, self.fn_name, self.orig_fn) - - -class _Patcher(object): - def __init__(self): - super(_Patcher, self).__init__() - self.patches_made: List[_PatchedFn] = [] - self.visited: Set[int] = set() - - def patch( - self, - frame_dict: Dict[str, Any], - name: str, - new_fn: Callable, - deduplicate: bool = True, - ): - """ - Replace frame_dict[name] with new_fn until we exit the context manager. - """ - new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] - if name not in frame_dict and hasattr(builtins, name): - self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) - elif getattr(frame_dict[name], "__fx_already_patched", False): - return # already patched, no need to do it again - else: - self.patches_made.append( - _PatchedFnSetItem(frame_dict, name, frame_dict[name]) - ) - frame_dict[name] = new_fn - - def patch_method( - self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True - ): - """ - Replace object_or_dict.name with new_fn until we exit the context manager. - """ - new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] - orig_fn = getattr(cls, name) - if getattr(orig_fn, "__fx_already_patched", False): - return # already patched, no need to do it again - self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) - setattr(cls, name, new_fn) - - def visit_once(self, thing: Any): - """Return True on the first call to with thing, otherwise false""" - idx = id(thing) - if idx in self.visited: - return False - self.visited.add(idx) - return True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Undo all the changes made via self.patch() and self.patch_method() - """ - while self.patches_made: - # unpatch in reverse order to handle duplicates correctly - self.patches_made.pop().revert() - self.visited.clear() - - -def _patch_wrapped_functions(patcher: _Patcher): - """ - Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap - the listed global functions in the `_create_wrapped_func` wrapper. - """ - for frame_dict, name in _wrapped_fns_to_patch: - if name not in frame_dict and hasattr(builtins, name): - orig_fn = getattr(builtins, name) - else: - orig_fn = frame_dict[name] - patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) - - for cls, name in _wrapped_methods_to_patch: - patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) - - -def _autowrap_check( - patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] -): - """ - Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. - This method searches a scope for them and patches them if found. - """ - if patcher.visit_once(frame_dict): - for name, value in frame_dict.items(): - if ( - not name.startswith("_") - and callable(value) - and id(value) in function_ids - ): - patcher.patch(frame_dict, name, _create_wrapped_func(value)) - - -@compatibility(is_backward_compatible=True) -def wrap(fn_or_name: Union[str, Callable]): - """ - This function can be called at module-level scope to register fn_or_name as a "leaf function". - A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being - traced through:: - - # foo/bar/baz.py - def my_custom_function(x, y): - return x * x + y * y - - pippy.fx.wrap('my_custom_function') - - def fn_to_be_traced(x, y): - # When symbolic tracing, the below call to my_custom_function will be inserted into - # the graph rather than tracing it. - return my_custom_function(x, y) - - This function can also equivalently be used as a decorator:: - - # foo/bar/baz.py - @pippy.fx.wrap - def my_custom_function(x, y): - return x * x + y * y - - A wrapped function can be thought of a "leaf function", analogous to the concept of - "leaf modules", that is, they are functions that are left as calls in the FX trace - rather than traced through. - - Args: - - fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the - graph when it's called - """ - if not callable(fn_or_name) and not isinstance(fn_or_name, str): - raise RuntimeError( - "Unsupported type for global function! Must be either a callable or " - "string name" - ) - - if callable(fn_or_name): - assert not isinstance(fn_or_name, str) # to make mypy happy - fn_name = fn_or_name.__name__ - else: - assert isinstance( - fn_or_name, str - ), "fn_or_name must be a global function or string name" - fn_name = fn_or_name - - currentframe = inspect.currentframe() - assert currentframe is not None - f = currentframe.f_back - assert f is not None - if f.f_code.co_name != "": - raise NotImplementedError("wrap must be called at the top level of a module") - - # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search - # semantics would be slightly different, but would add support `from x import wrapped_function` - _wrapped_fns_to_patch.append((f.f_globals, fn_name)) - return fn_or_name - - -@compatibility(is_backward_compatible=True) -def symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, -) -> GraphModule: - """ - Symbolic tracing API - - Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` - constructed by recording operations seen while tracing through ``root``. - - ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. - - For example:: - - def f(a, b): - if b == True: - return a - else: - return a*2 - - FX can typically not trace through this due to the presence of control - flow. However, we can use `concrete_args` to specialize on the value of - `b` to trace through this. - - f = fx.symbolic_trace(f, concrete_args={'b': False}) - assert f(3, False) == 6 - - Note that although you can still pass in different values of `b`, they will be ignored. - - We can also use `concrete_args` to eliminate data-structure handling from - our function. This will use pytrees to flatten your input. To avoid - overspecializing, pass in `fx.PH` for values that shouldn't be - specialized. For example:: - - def f(x): - out = 0 - for v in x.values(): - out += v - return out - f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) - assert f({'a': 1, 'b': 2, 'c': 4}) == 7 - - - Args: - root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted - into a Graph representation. - concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized - - Returns: - GraphModule: a Module created from the recorded operations from ``root``. - """ - tracer = Tracer() - graph = tracer.trace(root, concrete_args) - name = ( - root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - ) - return GraphModule(tracer.root, graph, name) - - -@wrap -def _assert_is_none(value, msg): - assert value is None, msg diff --git a/pippy/fx/annotate.py b/pippy/fx/annotate.py deleted file mode 100644 index 906b9b811..000000000 --- a/pippy/fx/annotate.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.proxy import Proxy -from ._compatibility import compatibility - -@compatibility(is_backward_compatible=False) -def annotate(val, type): - # val could be either a regular value (not tracing) - # or fx.Proxy (tracing) - if isinstance(val, Proxy): - if val.node.type: - raise RuntimeError(f"Tried to annotate a value that already had a type on it!" - f" Existing type is {val.node.type} " - f"and new type is {type}. " - f"This could happen if you tried to annotate a function parameter " - f"value (in which case you should use the type slot " - f"on the function signature) or you called " - f"annotate on the same value twice") - else: - val.node.type = type - return val - else: - return val diff --git a/pippy/fx/experimental/__init__.py b/pippy/fx/experimental/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/experimental/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/experimental/accelerator_partitioner.py b/pippy/fx/experimental/accelerator_partitioner.py deleted file mode 100644 index a3254cb45..000000000 --- a/pippy/fx/experimental/accelerator_partitioner.py +++ /dev/null @@ -1,1083 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import operator -from typing import Dict, List, Set, NamedTuple, Tuple - -import torch -from pippy.fx.passes.graph_manipulation import get_size_of_all_nodes -from pippy.fx.experimental.partitioner_utils import ( - Partition, - Device, - PartitionerConfig, - get_partition_to_latency_mapping, - get_latency_of_partitioned_graph, - NodeLatency, - get_extra_size_of, - PartitionMode, -) -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import Node, map_arg -from pippy.fx.passes.split_module import split_module - - -class DAGNode: - """DAGNode class maintains useful information for a partition (submodule), - and its input submodules and output submodules. - """ - - def __init__( - self, - submodule_node: Node, - input_nodes: List[Node], - output_nodes: List[Node], - logical_device_ids: List[int], - size_bytes: int, - ) -> None: - self.submodule_node: Node = submodule_node - self.input_nodes: List[Node] = input_nodes - self.output_nodes: List[Node] = output_nodes - self.logical_device_ids: List[int] = logical_device_ids - self.size_bytes = size_bytes - - def __str__(self) -> str: - return str(self.submodule_node) - - -class DAG: - """DAG class contains all the DAG nodes""" - - def __init__(self) -> None: - self.nodes: List[DAGNode] = [] - - def create_node( - self, - submodule_node: Node, - input_nodes: List[Node], - output_nodes: List[Node], - logical_devices: List[int], - size_bytes: int, - ) -> None: - node = DAGNode( - submodule_node, input_nodes, output_nodes, logical_devices, size_bytes - ) - self.nodes.append(node) - - -class PartitionResult(NamedTuple): - """NameTuple used for returning DAG and a new fx module""" - - dag: DAG - module_with_submodules: GraphModule - - -"""Followings are some helper functions for partition manipulation""" - - -def reset_partition_device(partitions): - for partition in partitions: - partition.logical_device_ids = [] - - -def combine_two_partitions( - partition_0: Partition, partition_1: Partition, partitions: List[Partition] -) -> None: - """Given a list of partitions and its two partitions, - combine these two partitions into a new one appending to the partitions - and remove the previous two partitions from the list of partitions - """ - partition = Partition(len(partitions)) - partition.nodes = partition_0.nodes.union(partition_1.nodes) - partition.recalculate_mem_size() - partitions.append(partition) - partitions.remove(partition_0) - partitions.remove(partition_1) - reorganize_partitions(partitions) - return - - -def set_parents_and_children(partitions: List[Partition]) -> None: - """Given a list of partitions, mark parents and children for each partition""" - # Go through all nodes in a partition. - # If a node's user is in other partition, - # then the other partition is this partition's children. - # This partition is the other partition's parent - for partition in partitions: - partition.children = set() - partition.parents = set() - for partition in partitions: - for node in partition.nodes: - # For each node in the current partition, find its users - users = node.users - for n in users: - # Find which the partition the user node belongs to. - # Note that if the node itself is also belongs to that partition, - # that partition is not the child of the current partition - for p in partitions: - if p != partition and n in p.nodes and node not in p.nodes: - partition.children.add(p) - p.parents.add(partition) - return - - -def reorganize_partitions(partitions: List[Partition]) -> None: - """Given a list of partitions, reorganzie partiton id, - its parents and its children for each partition - """ - # Rearrange partition ids - for i, partition in enumerate(partitions): - partition.partition_id = i - set_parents_and_children(partitions) - return - - -def get_bfs_level_partition(partitions: List[Partition]) -> None: - """Given a list of partitions, - mark the bfs level for each partition - """ - current_level: Set[Partition] = set() - visited: Set[Partition] = set() - for partition in partitions: - # If a partition has no parent, it should be in root level - if len(partition.parents) == 0: - current_level.add(partition) - next_level: Set[Partition] = set() - level = 0 - # bfs - while current_level: - partition = current_level.pop() - partition.bfs_level = level - visited.add(partition) - children = partition.children - for child in children: - if child not in next_level: - next_level.add(child) - if not current_level: - current_level = next_level.copy() - next_level = set() - level += 1 - return - - -def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]: - """Given a list of partitions,return node to partition mapping""" - node_to_partition: Dict[Node, int] = {} - for partition in partitions: - for node in partition.nodes: - node_to_partition[node] = partition.partition_id - return node_to_partition - - -def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]: - """Get a mapping from device logical ID to Device object.""" - logical_id_to_device: Dict[int, Device] = {} - for d in devices: - logical_id_to_device[d.logical_id] = d - return logical_id_to_device - - -def get_device_partition_stats( - partitions: List[Partition], devices: List[Device] -) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]: - """Given a list of partitions and a list of devices, returns: - 1. A mapping from device to partitions on it; - 2. A mapping from device to its remaining memory size; - 3. A list of partitions that do not have a device. - """ - # logical id to device - logical_id_to_device = get_logical_id_to_device(devices) - # Track partitions on device - device_to_partitions: Dict[Device, List[Partition]] = {} - # Track device's left mem size - device_to_left_mem_bytes: Dict[Device, int] = {} - for d in devices: - device_to_partitions[d] = [] - device_to_left_mem_bytes[d] = d.available_mem_bytes - - # Deal with the partitions that already have a device - # and also collect all partitions without a device (no_device_partitions) - no_device_partitions = [] - for partition in partitions: - if partition.logical_device_ids != []: - for logical_id in partition.logical_device_ids: - device = logical_id_to_device[logical_id] - device_to_partitions[device].append(partition) - device_to_left_mem_bytes[device] -= partition.used_mem_bytes - else: - no_device_partitions.append(partition) - - return ( - device_to_partitions, - device_to_left_mem_bytes, - no_device_partitions, - ) - - -def get_device_to_partitions_mapping( - partitions: List[Partition], devices: List[Device] -): - """Given a list of partitions and a list of devices, - map each partition into a device. - """ - - def calculate_extra_mem_bytes_needed_for( - partition: Partition, partitions: List[Partition] - ): - all_nodes: Set[Node] = set() - for p in partitions: - all_nodes = all_nodes.union(p.nodes) - if len(all_nodes) == 0: - return partition.used_mem_bytes - all_nodes = all_nodes.union(partition.nodes) - extra_size_needed = 0 - for node in partition.nodes: - extra_size_needed += get_extra_size_of(node, all_nodes) - return extra_size_needed - - def find_device_for(partition: Partition): - """Given a partition, find a logical device for the partition - The algorithm is to put the partition on the device - that has just enough mem left for that partition. - device_to_left_mem_bytes is a dictionary between device and its left mem size - sorted by its left mem size - """ - for d in device_to_left_mem_bytes: - extra_size_needed = calculate_extra_mem_bytes_needed_for( - partition, device_to_partitions[d] - ) - if extra_size_needed < device_to_left_mem_bytes[d]: - device_to_partitions[d].append(partition) - partition.logical_device_ids.append(d.logical_id) - device_to_left_mem_bytes[d] -= extra_size_needed - return True - return False - - ( - device_to_partitions, - device_to_left_mem_bytes, - no_device_partitions, - ) = get_device_partition_stats(partitions, devices) - - # Find devices for all the partitions without a device - found_device = True - for partition in no_device_partitions: - device_to_left_mem_bytes = { - d: left_mem_bytes - for d, left_mem_bytes in sorted( - device_to_left_mem_bytes.items(), key=lambda item: item[1] - ) - } - found_device = find_device_for(partition) - if not found_device: - break - return found_device - - -def check_dependency(partition): - """Given a partition,check if there is a circular dependency on - this partition using bfs - """ - visited: Set[Partition] = set([partition]) - queue: List[Partition] = [partition] - while queue: - p = queue.pop(0) - for child in p.children: - if child == partition: - return True - else: - if child not in visited: - visited.add(child) - queue.append(child) - return False - - -class Partitioner: - """A fx module may not fit into one device. - Partitioner class helps partition one fx module into submodules (partitions), - so that the submodules can be executed crossing different accelerators. - The main function of this class is self.partition_graph. - It partitions the fx module based on the scheme specified in partition_config - A DAG structure is returned - along with a new fx module with submodule nodes. - """ - - def __init__(self) -> None: - self.partitions: List[Partition] = [] - self.node_to_partition: Dict[Node, int] = {} - self.devices: List[Device] = [] - - def partition_graph( - self, - fx_module: GraphModule, - torch_module: torch.nn.Module, - partitioner_config: PartitionerConfig, - ) -> PartitionResult: - """Given the fx module, torch module and partitioner_config, - find the partitions, do the partitions, - and then return a DAG and a new fx module with submodule nodes (partitions) - """ - self.graph_module = fx_module - self.torch_module = torch_module - self.devices = partitioner_config.devices - if len(self.devices) == 0: - raise RuntimeError("No devices") - # Tag the size in bytes to all nodes in the graph_module. - get_size_of_all_nodes(self.graph_module) - # Check if there are op nodes in the fx module - nodes = self.graph_module.graph.nodes - if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): - raise RuntimeError("No Partition since no operations in the module") - # Calculate total size of the fx module - total_size_of_graph = 0 - for node in nodes: - if node.op == "output": - break - total_size_of_graph += node.size_bytes.total_size - # Find the device with the max mem size - device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) - # AOT based partition - if partitioner_config.mode == PartitionMode.aot_based: - self.aot_based_partition( - partitioner_config.node_to_partition_mapping, - partitioner_config.partition_to_logical_device_mapping, - ) - # Single partition if the whole module can be fit into one device - elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: - self.find_single_partition( - total_size_of_graph, logical_device_id=device_with_max_mem.logical_id - ) - elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): - raise RuntimeError("Devices have no enough memory for the module") - else: - # Sparse nn based partition - if partitioner_config.mode == PartitionMode.sparse_nn: - available_mem_bytes = self.devices[0].available_mem_bytes - if not all( - device.available_mem_bytes == available_mem_bytes - for device in self.devices - ): - raise RuntimeError("All devices must have same memory size!") - # sparse_nn_partition only support same memory size - # TODO: add different size support for sparse_nn_partition - self.sparse_nn_partition(available_mem_bytes) - # Cost aware partition - elif partitioner_config.mode == PartitionMode.cost_aware: - self.cost_aware_partition( - partitioner_config.transfer_rate_bytes_per_sec, - partitioner_config.node_to_latency_mapping, - ) - # KL based partition - elif partitioner_config.mode == PartitionMode.kl_based: - self.kl_based_partition( - partitioner_config.transfer_rate_bytes_per_sec, - partitioner_config.node_to_latency_mapping, - ) - else: - self.size_based_partition() - - # Saturate host if possible. - if partitioner_config.saturate_host: - self.saturate_host() - - # Partition the graph module based on the partition assignment. - module_with_submodules = self.do_partition() - - # The DAG contains DAGNodes with info of each partition's input nodes, output nodes - # and how partitions are connected. - dag = self.dump_dag(module_with_submodules) - ret = PartitionResult(dag, module_with_submodules) - return ret - - def find_single_partition( - self, total_size_of_graph, logical_device_id: int = 0 - ) -> None: - """Fit the whole fx module into one device""" - partition_0 = self.create_partition() - for node in self.graph_module.graph.nodes: - if node.op == "output": - # Skip the output node, but there can - # be nodes after the output in certain cases. - continue - partition_0.nodes.add(node) - partition_0.used_mem_bytes = total_size_of_graph - partition_0.logical_device_ids = [logical_device_id] - # Get the node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - return - - def size_based_partition(self) -> None: - """This method is to partition the fx module based on memory size. - It uses greedy approach. The result may not be the best. - The basic idea is: - Step 1: - Find a device which has enough memory to fit the current node, create a empty partition - with the size of that device. - Then keep adding the following nodes into the partition until the partition is full. - Step 2: - Repeat Step 1 until no device left - Step 3: - If some nodes are left, create a partition for each left node (single node partition). - and then try to map those partitions into logical devices with enough mem left. - """ - - def find_device_based_on_size(node) -> Device: - """Given a node, this function is to find a logical device - that could fit the node. - """ - mem_size_needed = get_extra_size_of(node, set()) - device = Device("", -1, -1) - for d in self.devices: - if ( - d not in occupied_devices - and d.available_mem_bytes >= mem_size_needed - ): - device = d - break - if device.available_mem_bytes < 0: - raise RuntimeError(str(node) + "is too large to fit any device") - occupied_devices.append(device) - return device - - # Track partition and its left mem size - partition_to_left_mem_bytes: Dict[Partition, int] = {} - # Track all the devices that have been used - occupied_devices: List[Device] = [] - partition = self.create_partition() - for node in self.graph_module.graph.nodes: - if node.op in {"call_module", "call_method", "call_function"}: - # Check if there are devices left - if len(self.partitions) <= len(self.devices): - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - # Check if the current partition is the very first partition - if partition.used_mem_bytes == 0: - # Find a device to fit the first node, return available mem size - device = find_device_based_on_size(node) - occupied_devices.append(device) - # Update partition and its left mem size - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes - # Update available mem for the current partition - partition.logical_device_ids.append(device.logical_id) - else: - # The current partition is not the first partition - # Check if the current node can fit into current partition - if ( - partition_to_left_mem_bytes[partition] - < total_size_of_input_nodes - ): - # Check if no device is left - if len(self.partitions) == len(self.devices): - # No device is left - # Put the previous partitions into a list (non_single_node_partitions) - non_single_node_partitions = self.partitions[:] - # Create the first single node partition for the current node - self.create_single_node_partition(node) - continue - # Some devices are still left - # Create a new partition with a mem size that is enough for the current node - device = find_device_based_on_size(node) - partition = self.create_partition() - total_size_of_input_nodes = get_extra_size_of( - node, partition.nodes - ) - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes - partition.logical_device_ids.append(device.logical_id) - partition.add_node(node) - partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes - # Create single node partitions if no device is left - else: - self.create_single_node_partition(node) - reorganize_partitions(self.partitions) - # Get the node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - # Mapping all partitions into device - found_partition_to_device_mapping = get_device_to_partitions_mapping( - self.partitions, self.devices - ) - if not found_partition_to_device_mapping: - raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") - return - - def saturate_host(self) -> None: - """Saturate host by assigning replicates to unused devices with enough memory. - It uses a greedy approach to find a next available set of devices to place all split - partitions: For each used device, it searches for an idle device with minimal memory - size that can hold all the partition located on that device; If the search is successful - for all used devices, it then assigns the new devices' logical ID to the corresponding - partition. - """ - ( - device_to_partitions, - device_to_left_mem_bytes, - no_device_partitions, - ) = get_device_partition_stats(self.partitions, self.devices) - - assert ( - len(no_device_partitions) == 0 - ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" - - # Devices that hold partitions - used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] - # Track replicates of the assigned devices - replicated_device_to_used_device: Dict[Device, Device] = {} - - while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( - self.devices - ): - # Success flag for this round - success = True - # Devices that have not been assigned - idle_devices = [ - d - for d in self.devices - if d not in used_devices and d not in replicated_device_to_used_device - ] - # Temporary mapping from replicated device to original device - temp_replicate_mapping = {} - - # Find a new device to replicate all partitions on an used device - for used_device in used_devices: - # Idle devices that have enough memory - available_devices = [ - d - for d in idle_devices - if d.available_mem_bytes - >= used_device.available_mem_bytes - - device_to_left_mem_bytes[used_device] - ] - if len(available_devices) == 0: - success = False - break - new_device = min(available_devices, key=lambda d: d.available_mem_bytes) - idle_devices.remove(new_device) - temp_replicate_mapping[new_device] = used_device - - if not success: - break - replicated_device_to_used_device.update(temp_replicate_mapping) - - # Update logical device IDs assigned to the partitions - for ( - replicate_device, - original_device, - ) in replicated_device_to_used_device.items(): - logical_id = replicate_device.logical_id - for partition in device_to_partitions[original_device]: - partition.logical_device_ids.append(logical_id) - for p in self.partitions: - print(p.logical_device_ids) - - def do_partition(self) -> GraphModule: - """Return a new fx module with submodule nodes (partitions).""" - module_with_submodules = split_module( - self.graph_module, - self.torch_module, - lambda node: self.node_to_partition[node], - ) - return module_with_submodules - - def dump_dag(self, module_with_submodules: GraphModule) -> DAG: - """Return the dag structure and the new fx module with submodules.""" - dag = DAG() - for node in module_with_submodules.graph.nodes: - if node.op == "output": - break - if node.op in {"placeholder", "get_attr"}: - continue - if node.target == operator.__getitem__: - continue - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # When a node has two or more output nodes, - # it outputs its result to 'getitem' nodes. - # Those 'getitem' nodes are the output node for this node. - # Otherwise, the output node is this node itself. - if len(node.users) > 1: - output_nodes = list(node.users) - else: - output_nodes = [node] - partition_id = int(node.name.rsplit("_", 1)[-1]) - device_ids = self.partitions[partition_id].logical_device_ids - size_bytes = self.partitions[partition_id].used_mem_bytes - dag.create_node( - node, list(input_nodes), output_nodes, device_ids, size_bytes - ) - return dag - - def create_partition(self) -> Partition: - """Create a partition and append it to self.partitions.""" - partition_id = len(self.partitions) - partition = Partition(partition_id) - self.partitions.append(partition) - return partition - - def create_single_node_partition(self, node): - """Create a partition for a single node""" - partition = self.create_partition() - partition.add_node(node) - return - - def sparse_nn_partition(self, available_mem_bytes: int) -> None: - """This method partition a sparse nn module. - It is size based partition but different from size_based_partition, - it only works when all the devices have same memory size (available_mem_bytes). - In the future, devices with different mem sizes will be supported like size_based_partition. - It first traverse all the nodes and do the partitions based on the same memory size. - If the current partition has no enough memory left for a new op node - (call_module, call_method, call_function), a new partition is created. - When crossing the boundary between non-embedding nodes and embedding nodes, - a new partition is created regardlessly. - For example, if the current node is a non-embedding node but the next node is an - embedding node, a new partition is created for the next node. - After the partition, the partitions are combined as much as possible. - The rule is that a non-embedding partition only - combines with another non-embedding one. - So as the embedding partitions. - """ - - def combine_partitions_based_on_size( - partitions: List[Partition], available_mem_bytes: int - ) -> None: - """Combining small partitions together to keep as less partitions as possible. - Here is an example of the algorithm to do this: - Assume some partitions, we first sort them based on partiiton used memory size. - [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] - The available memory is 10. - step 1: self.find_partition_to_combine_based_on_size() - First, mark bfs level for each partition - Second, look the smallest partition, partition_4: 10 - 1 = 9 - It means any partition has a used memory equal or less than 9 could combine this partition - We go from the largest and selection partition_0. - Check the bfs level for two partitions, if the level difference is less than 2, - it can be combined. - step 2: repeat step 1 until no partitions can be combined - """ - find_combination = True - while find_combination: - # Sort partitions based on memory size - sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) - # Mark bfs level - get_bfs_level_partition(self.partitions) - find_combination, partitions = find_partition_to_combine_based_on_size( - sorted_partitions, available_mem_bytes, partitions - ) - return - - def calculate_mem_bytes_needed(p1, p2): - """Given two partitions, calculate how many mem bytes - are needed if two partitions are combined - """ - nodes = p1.nodes.union(p2.nodes) - mem_bytes_needed = 0 - for node in nodes: - mem_bytes_needed += get_extra_size_of(node, nodes) - return mem_bytes_needed - - def find_partition_to_combine_based_on_size( - sorted_partitions: List[Partition], - available_mem_bytes: int, - partitions: List[Partition], - ) -> Tuple[bool, List[Partition]]: - """step 1 in combine_partition_based_on_size()""" - find_combination = False - smallest_partition = sorted_partitions.pop(0) - for p in sorted_partitions[::-1]: - if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: - # Calculate how many bytes needed if combined - mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) - if mem_bytes_needed <= available_mem_bytes: - combine_two_partitions(p, smallest_partition, self.partitions) - partitions.remove(smallest_partition) - partitions.remove(p) - partitions.append(self.partitions[-1]) - find_combination = True - break - return find_combination, partitions - - def reset_partition_in_sparse_nn(partition, new_partition=True): - """If crossing the boudary between non-embedding nodes and - embedding nodes, create a new partition - """ - if in_embedding_region: - embedding_partitions.append(partition) - else: - non_embedding_partitions.append(partition) - if new_partition: - partition = self.create_partition() - partition.left_mem_bytes = available_mem_bytes - return partition - return None - - def is_embedding_node(node: Node) -> bool: - """Check if a node is an embedding node""" - if node.op == "call_module": - submodule = self.graph_module - for atom in str(node.target).split("."): - if not hasattr(submodule, atom): - raise RuntimeError( - f"Module {submodule} has no attribute {atom}" - ) - submodule = getattr(submodule, atom) - if "Embedding" in str(submodule): - return True - return False - - # Track embedding partitons and non-embedding partitions separately - embedding_partitions: List[Partition] = [] - non_embedding_partitions: List[Partition] = [] - # A Flag to check the boundary - in_embedding_region: bool = False - partition = self.create_partition() - for node in self.graph_module.graph.nodes: - if node.op in {"call_module", "call_method", "call_function"}: - # Check if crossing the boundary between embedding nodes and non embedding nodes - if is_embedding_node(node) != in_embedding_region: - # Crossing the boundary - # Check if the current partition is an empty partition - if partition.used_mem_bytes != 0: - # The current partition isn't an empty partition. Create a new one. - partition = reset_partition_in_sparse_nn(partition) - in_embedding_region = not in_embedding_region - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - if ( - total_size_of_input_nodes + partition.used_mem_bytes - > available_mem_bytes - ): - partition = reset_partition_in_sparse_nn(partition) - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - if total_size_of_input_nodes > available_mem_bytes: - raise RuntimeError( - node.target + "is too large to fit into a device" - ) - partition.add_node(node) - reset_partition_in_sparse_nn(partition, new_partition=False) - # Set parents and children for partitions - set_parents_and_children(self.partitions) - # Combining non-embedding partitions - combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) - # Combining embedding partitions - combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) - total_size_of_non_embedding_partitions = 0 - for partition in non_embedding_partitions: - total_size_of_non_embedding_partitions += partition.used_mem_bytes - # Check if devices are enough for all partitions - if len(embedding_partitions) > len(self.devices): - msg = ( - "Need " - + str(len(embedding_partitions)) - + " devices, but only " - + str(len(self.devices)) - + " provided" - ) - raise RuntimeError(msg) - occupied_devices = [] - for i, partition in enumerate(embedding_partitions): - # Check if all non-embedding partitions can fit into embedding partition devices - if ( - total_size_of_non_embedding_partitions + partition.used_mem_bytes - > available_mem_bytes - ): - raise RuntimeError( - "partition_" - + str(partition.partition_id) - + "(embedding partition) and non embedding partitions can not fit into one device" - ) - else: - # Add logical device to the partition - partition.logical_device_ids = [self.devices[i].logical_id] - occupied_devices.append(self.devices[i].logical_id) - # Add logical devices to the non_embedding_partitions - for partition in non_embedding_partitions: - partition.logical_device_ids = occupied_devices - # Get the node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - return - - def cost_aware_partition( - self, - transfer_rate_bytes_per_sec: float, - node_to_latency_mapping: Dict[Node, NodeLatency], - ) -> None: - """This method is to partition the fx module based on the cost. - The cost is the total latency of running the whole fx module. - In partitioner_utils.py, the cost model is built. - The cost aware partition algorithm is: - #1. At every begining, each node is a partition. - Then we map all the partitions to the devices - and calculate the cost - #2. Then try to pre-combine any two of the partitions if the two - partitions can be combined. - (the bfs level is less than 2 or two partitions are connected and - can find partition to device mapping) - See if any partition pair could reduce the current cost. - Choose the pair that shows the minimum cost and then combine them - #3. Repeat #2 until the cost cannot be reduced. - """ - - def try_combining_partitions(p0_index, p1_index, partitions) -> float: - """Given two partitions and a list of partitions, combine these two partitions - and see what is the cost of the modified partition list - """ - p0 = partitions[p0_index] - p1 = partitions[p1_index] - """If two partitions' bfs level are less than 2 or two partitions are connected to each other, - then they can be combined - """ - if ( - (abs(p0.bfs_level - p1.bfs_level) <= 1) - or (p0 in p1.parents) - or p0 in (p1.children) - ): - combine_two_partitions(p0, p1, partitions) - # Check if a circular dependency exists after combining - if check_dependency(partitions[-1]): - return float("inf") - # Check if the modified partition list can be mapped to devices after combination - reset_partition_device(partitions) - found_deivce = get_device_to_partitions_mapping( - partitions, self.devices - ) - if not found_deivce: - return float("inf") - # Calculate the new cost - partition_to_latency_mapping = get_partition_to_latency_mapping( - partitions, node_to_latency_mapping - ) - cost = get_latency_of_partitioned_graph( - partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - return cost - # If two partition can not be combined, the cost is inf - return float("inf") - - def search_combination( - transfer_rate_bytes_per_sec, node_to_latency_mapping - ) -> bool: - """Given transfer rate between partitions and each node's latency, - find two partitions to combine so the cost of the partitions can - be reduced. - The algorithm is : - 1. Go through all the partition pairs and see - if any pair of partitions can be combined. - 2. Calculate the cost after the combination. - 3. Select the minimum cost and combine its cooresponding partition pair. - """ - partition_to_latency_mapping = get_partition_to_latency_mapping( - self.partitions, node_to_latency_mapping - ) - cost = get_latency_of_partitioned_graph( - self.partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - if len(self.partitions) == 1: - return False - partition_pair: List[int] = [] - for i in range(len(self.partitions) - 1): - for j in range(i + 1, len(self.partitions)): - # Try to combine the partition pair - # and see the new cost after combination - new_cost = try_combining_partitions(i, j, self.partitions[:]) - if new_cost <= cost: - partition_pair = [i, j] - cost = new_cost - reorganize_partitions(self.partitions) - # If a partition pair is found, combine them - if len(partition_pair) != 0: - p0 = self.partitions[partition_pair[0]] - p1 = self.partitions[partition_pair[1]] - combine_two_partitions(p0, p1, self.partitions) - get_bfs_level_partition(self.partitions) - reset_partition_device(self.partitions) - get_device_to_partitions_mapping(self.partitions, self.devices) - return len(partition_pair) != 0 - - for node in self.graph_module.graph.nodes: - if node.op not in {"placeholder", "get_attr", "output"}: - self.create_single_node_partition(node) - # Set up parent partitions and children partitions for each partition - set_parents_and_children(self.partitions) - # Get bfs level for each partition - get_bfs_level_partition(self.partitions) - find_combination = True - while find_combination: - # Search for a pair partition to generate the minimum new cost, - # then combine them - find_combination = search_combination( - transfer_rate_bytes_per_sec, node_to_latency_mapping - ) - # Make sure all partitions are set up correctly - reorganize_partitions(self.partitions) - # Set up node to partition mapping - self.node_to_partition = get_node_to_partition_mapping(self.partitions) - return - - def kl_based_partition( - self, - transfer_rate_bytes_per_sec: float, - node_to_latency_mapping: Dict[Node, NodeLatency], - ) -> None: - """This function is a cost aware partition based - on Kernighan-Lin algorithm. - First, the graph is partitioned using size_based_partition. - Then, each node is swapped with any other node in a different - partition, and at the same time, the cost is estimated after - the swapping. - For example, we have nodes n0, n1, n2, n3 and n4. - Using size_based_partition, n0 and n1 are in Partition p0. - n2, n3 and n4 in Partition p1. The current cost is esimated. - We first tried using n0 to swap with n2 from the other partiton. - Then we see that swapping n0 and n2 shows a lower cost - than the current cost and it is the minimum among other pairs like - (n0, None)(This means moving n0 to Partition without swapping other nodes), - (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost - as the current cost. - Then We repeat this process for all the other nodes until all swapping pairs - are tried. - """ - - def swap_nodes(n0, n1, p0, p1): - # Either n0 or n1 could be None - # That means we simply move the node - # to another partition - if n0 is not None: - p0.remove_node(n0) - p1.add_node(n0) - if n1 is not None: - p0.add_node(n1) - p1.remove_node(n1) - - def try_swap_nodes( - n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ): - cost = float("inf") - swap_nodes(n0, n1, p0, p1) - # Reorganize partitions after swapping - reorganize_partitions(self.partitions) - # Check if there is a circular dependency after swapping - if (not check_dependency(p0)) and (not check_dependency(p1)): - reset_partition_device(self.partitions) - partition_to_latency_mapping = get_partition_to_latency_mapping( - self.partitions, node_to_latency_mapping - ) - # Check if all partitions can be mapped to logical devices after swapping - found_device = get_device_to_partitions_mapping( - self.partitions, self.devices - ) - if not found_device: - cost = float("inf") - else: - cost = get_latency_of_partitioned_graph( - self.partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - # Swap back and reset all partitions back to original - swap_nodes(n1, n0, p0, p1) - reorganize_partitions(self.partitions) - reset_partition_device(self.partitions) - get_device_to_partitions_mapping(self.partitions, self.devices) - return cost - - def swap_node_to_partition( - node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ): - """This function helps to swap one node from partition p0 - with all the nodes in another partition p1 - """ - p1_nodes = list(p1.nodes) + [None] - min_cost = float("inf") - node_pair: List[Node] = [] - for n1 in p1_nodes: - # Ignore the node if it is not a op node - if n1 is not None and n1.op in {"placeholder", "get_attr"}: - continue - # Try swapping node in p0 with n1 in p1 - cost = try_swap_nodes( - node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ) - if cost < min_cost: - node_pair = [node, n1] - min_cost = cost - return cost, node_pair - - # First use size_base_partition - self.size_based_partition() - partition_to_latency_mapping = get_partition_to_latency_mapping( - self.partitions, node_to_latency_mapping - ) - # Calculate the cost of the partitions - cost = get_latency_of_partitioned_graph( - self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec - ) - # Keep tracking the node pair that shows the better cost - node_pair: List[Node] = [] - # Keep tracking the partition pair of node pair - partition_pair: List[Partition] = [] - # Collect all the op nodes from the graph - op_nodes = [] - for n in self.graph_module.graph.nodes: - if n.op not in {"placeholder", "get_attr", "output"}: - op_nodes.append(n) - for node in op_nodes: - # Find which partition the current node belongs - p0_index = self.node_to_partition[node] - p0 = self.partitions[p0_index] - # Go through all the other partitions to swap - # with other nodes from those partitions - for p1_index, _ in enumerate(self.partitions): - if p0_index != p1_index: - p1 = self.partitions[p1_index] - new_cost, new_node_pair = swap_node_to_partition( - node, - p0, - p1, - node_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - # Update the cost - # Track the swapped node pair and their partitions - if new_cost < cost: - cost = new_cost - node_pair = new_node_pair - partition_pair = [p0, p1] - # Do the swapping after trying all the nodes from a partition - if len(node_pair) != 0: - swap_nodes( - node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] - ) - reorganize_partitions(self.partitions) - get_device_to_partitions_mapping(self.partitions, self.devices) - reorganize_partitions(self.partitions) - # Mapping the device to the partition - get_device_to_partitions_mapping(self.partitions, self.devices) - return - - def aot_based_partition( - self, node_to_partition_mapping, partition_to_logical_device_mapping - ): - """This function helps to rebuild the partitions given the nodes and its - corresponding partition id - """ - partition_id_to_partition_mapping: Dict[int, Partition] = {} - self.node_to_partition = node_to_partition_mapping - for node in self.node_to_partition: - partition_id = self.node_to_partition[node] - # If the requested partition has not been created, create the partition - if partition_id not in partition_id_to_partition_mapping: - partition = Partition(partition_id) - self.partitions.append(partition) - partition_id_to_partition_mapping[partition_id] = partition - partition.logical_device_ids = partition_to_logical_device_mapping[ - partition_id - ] - else: - partition = partition_id_to_partition_mapping[ - self.node_to_partition[node] - ] - # Add the current node into the partition - partition.add_node(node) diff --git a/pippy/fx/experimental/const_fold.py b/pippy/fx/experimental/const_fold.py deleted file mode 100644 index f0cf5433e..000000000 --- a/pippy/fx/experimental/const_fold.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import re -from typing import Callable, Dict, Optional, Set, Union - -import torch -import pippy.fx -from pippy.fx.node import map_arg -from pippy.fx.passes.split_module import split_module - - -class FoldedGraphModule(pippy.fx.GraphModule): - """ - FoldedGraphModule is a GraphModule which also contains another - `const_subgraph_module` representing a subgraph which has all const attr - inputs and which can be run once before running the main standard - `graph`. The `const_output_names` are the ordered list names of attrs which - represent what each respective output from the const_subgraph should be set - on which attrs. - """ - - def __init__( - self, - root: torch.nn.Module, - graph: pippy.fx.Graph, - const_subgraph: Optional[pippy.fx.Graph] = None, - fx_const_folded_attrs_name: str = None, - device_for_folded_attrs: str = "cuda", - ): - # In init, we set graph's owning module to root which will make graph's - # owning module be None because graph already have a owning module. We - # need owning module to run DCE. To work around we set the number of - # graph's owners to 0. - graph._owners = 0 - super().__init__(root, graph) - self.const_subgraph_module = ( - None - if const_subgraph is None - else pippy.fx.GraphModule(root, const_subgraph) - ) - self.has_folding_been_run = False - self.fx_const_folded_attrs_name = fx_const_folded_attrs_name - self.device_for_folded_attrs = device_for_folded_attrs - - def __call__(self, *args, **kwargs): - if not self.has_folding_been_run: - self.run_folding() - return super().__call__(*args) - - def run_folding(self): - # If there's no const subgraph module or attr output names to use, return - # early as there is no const folding to perform. - if ( - self.const_subgraph_module is None - or self.fx_const_folded_attrs_name is None - ): - return - - assert not self.has_folding_been_run - self.has_folding_been_run = True - - # Actually run const folding subgraph. Note that single attr const fold - # subgraphs output a single Tensor while multiple outputs are returned as - # Tuple[Tensor,]. - folded_attrs = self.const_subgraph_module() - - def _create_param(i): - return torch.nn.Parameter( - i - if not isinstance(i, int) - else torch.Tensor([i]).to(device=self.device_for_folded_attrs), - requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, - ) - - params = ( - torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) - if isinstance(folded_attrs, tuple) - else _create_param(folded_attrs) - ) - setattr(self, self.fx_const_folded_attrs_name, params) - - -def _inline_module(gm: pippy.fx.GraphModule, inline_mod_name: str): - """ - Given `gm` and some graph module which is called with target name `inline_mod_name`, - this helper will inline all of the nodes from that called graph module into `gm`. - """ - # Fetch the inner graph module that we want to inline inside `gm`. - inline_mod = dict(gm.named_modules())[inline_mod_name] - assert isinstance(inline_mod, pippy.fx.GraphModule) - call_mod_node_to_replace = None - for node in gm.graph.nodes: - if node.op == "call_module" and node.target == inline_mod_name: - call_mod_node_to_replace = node - break - assert call_mod_node_to_replace is not None - - # Now actually do the swap. Note that we have to keep track of new nodes that are - # copied into `gm` -- we do this via replacement_mapping. - call_mod_args = call_mod_node_to_replace.args - replacement_mapping: Dict[pippy.fx.Node, pippy.fx.Node] = {} - ph_count = 0 - - def replacement_fn(node): - new_node = replacement_mapping[node] - new_node.meta = node.meta.copy() - return new_node - - for inline_node in inline_mod.graph.nodes: - if inline_node.op == "placeholder": - replacement_mapping[inline_node] = call_mod_args[ph_count] - ph_count += 1 - continue - - if inline_node.op == "output": - outputs = inline_node.args[0] - output_replacements = map_arg(outputs, replacement_fn) - call_mod_node_to_replace.replace_all_uses_with(output_replacements) - continue - - with gm.graph.inserting_before(call_mod_node_to_replace): - new_node = gm.graph.node_copy(inline_node, replacement_fn) - replacement_mapping[inline_node] = new_node - - gm.graph.eliminate_dead_code() - - -def get_unique_attr_name_in_module(mod_traced: pippy.fx.GraphModule, name: str) -> str: - """ - Make sure the name is unique (in a module) and can represents an attr. - """ - # Delete all characters that are illegal in a Python identifier. - name = re.sub("[^0-9a-zA-Z_]+", "_", name) - if name[0].isdigit(): - name = f"_{name}" - # Now make sure it is in fact unique to the module by incrementing suffix value. - while hasattr(mod_traced, name): - match = re.match(r"(.*)_(\d+)$", name) - if match is None: - name = name + "_1" - else: - base, num = match.group(1, 2) - name = f"{base}_{int(num) + 1}" - - return name - - -def split_const_subgraphs( - module: Union[torch.nn.Module, pippy.fx.GraphModule], - skip_folding_node_fn: Optional[Callable[[pippy.fx.Node], bool]] = None, - device_for_folded_attrs: str = "cpu", -) -> FoldedGraphModule: - """ - Looks through `module` for any nodes that have all constant attribute inputs - and separates them out into their own constant subgraph, and returns a - FoldedGraphModule which runs that constant subgraph on the first run to set - attributes on the module prior to running the non-constant portion of the - graph. - """ - if not isinstance(module, pippy.fx.GraphModule): - mod_traced = pippy.fx.symbolic_trace(module) - else: - mod_traced = module - - # Build up a list of const_nodes, defined as nodes that are themselves - # get_attrs, or have all get_attr or other constant node inputs. - const_nodes: Set[pippy.fx.Node] = set() - found_const_folding = False - for node in mod_traced.graph.nodes: - # Skip over placeholders/outputs because they can't be const folded and - # we don't want to add tags to them. - if node.op in {"placeholder", "output"}: - continue - - # If the node itself is constant, or all of its inputs are constant, - # then tag it as constant. - if node.op != "get_attr" and not set(node.all_input_nodes).issubset( - const_nodes - ): - continue - - # If provided skip folding function says to skip, then skip. - if skip_folding_node_fn and skip_folding_node_fn(node): - continue - - # Skip folding side-effectful functions - if node.is_impure(): - continue - - # Must be a constant foldable node at this point. - const_nodes.add(node) - if node.op != "get_attr": - found_const_folding = True - - # If we did not find any const folding then return early without a const fold subgraph. - if not found_const_folding: - return FoldedGraphModule(mod_traced, mod_traced.graph) - - # Partition the module into two: submod_0 for constant folding subgraph, and - # submod_1 for the rest. - def mod_partition(node: pippy.fx.Node): - return 0 if node in const_nodes else 1 - - split = split_module(mod_traced, module, mod_partition) - - const_gm, non_const_gm = split.submod_0, split.submod_1 - const_mod_name, non_const_mod_name = "submod_0", "submod_1" - - # The module that a call_module node refers to gets copied to submodules during split. - # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to - # attach inlined modules to `split` as it's the owning module now. - for node in non_const_gm.graph.nodes: - if node.op == "call_module": - setattr(split, node.target, getattr(non_const_gm, node.target)) - for node in const_gm.graph.nodes: - if node.op == "call_module": - setattr(split, node.target, getattr(const_gm, node.target)) - - # split_module currently does not use get_attrs for attrs. Instead it passes - # them in as args from the parent module, which used get_attrs. Here we set - # them as get_attrs inside const_gm, allowing for running folding without - # somehow a priori knowing the attrs that should be passed as args. We can - # unconditionally do this for all placeholders because we know all - # placeholders to const_gm must be constants accessible via get_attr. - call_const_gm_args = None - for node in split.graph.nodes: - if node.op == "call_module": - if node.target == const_mod_name: - call_const_gm_args = node.args - break - assert call_const_gm_args is not None - - # Here we do the actual replacement of placeholders to get_attrs. Note that here we - # set the const_gm.graph into a new root_const_gm with split as the root module, - # because we are fetching attributes directly from the root module, instead of - # fetching them from const_gm. Example: The const_gm must have some format like: - # graph(): - # %inp : [#users=1] = placeholder[target=const_inp] - # %add : [#users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) - # return add - # We replace that with the following, which does not have any placeholders: - # graph(): - # %inp_1 : [#users=1] = get_attr[target=const_inp] - # %add : [#users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) - # return add - root_const_gm = pippy.fx.GraphModule(split, const_gm.graph) - for node in root_const_gm.graph.nodes: - if node.op == "output": - multiple_outputs = isinstance(node.args[0], tuple) - continue - if node.op != "placeholder": - continue - in_node = next(n for n in call_const_gm_args if n.name == node.target) - assert in_node.op == "get_attr" - with root_const_gm.graph.inserting_before(node): - new_node = root_const_gm.graph.get_attr(in_node.target) - new_node.meta = node.meta.copy() - node.replace_all_uses_with(new_node) - root_const_gm.graph.erase_node(node) - assert "multiple_outputs" in locals() - - # Now find the call to const_gm inside split, and replace it with a getattr to the - # folded tensor(s) that result from constant folding. Note that we don't need to - # worry about whether this is one or more tensors because the original graph - # correctly uses getitem to extract individual tensors if there are multiple folded. - fx_const_folded_attrs_name = get_unique_attr_name_in_module( - split, "_FX_CONST_FOLDED_ATTRS" - ) - setattr( - split, - fx_const_folded_attrs_name, - torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), - ) - for node in split.graph.nodes: - if node.op == "call_module" and node.target == const_mod_name: - with node.graph.inserting_before(node): - folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) - folded_attrs.meta = node.meta.copy() - node.replace_all_uses_with(folded_attrs) - break - - split.graph.eliminate_dead_code() - - # Finally, inline the non-constant submod into the split submod. This is so that the - # original caller who may have passed in a graph module will get back out a graph - # module whose graph is traced to the same granularity. - _inline_module(split, non_const_mod_name) - - return FoldedGraphModule( - split, - split.graph, - root_const_gm.graph, - fx_const_folded_attrs_name, - device_for_folded_attrs, - ) diff --git a/pippy/fx/experimental/debug.py b/pippy/fx/experimental/debug.py deleted file mode 100644 index 916c605f8..000000000 --- a/pippy/fx/experimental/debug.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import pippy.fx as fx - -def set_trace(gm: fx.GraphModule) -> fx.GraphModule: - """ - Sets a breakpoint in `gm`'s generated python code. It drops into pdb when - `gm` gets run. - - Args: - gm: graph module to insert breakpoint. It is then recompiled for it to - take effect. - - Returns: - the `gm` with breakpoint inserted. - """ - def insert_pdb(body): - return ["import pdb; pdb.set_trace()\n", *body] - - with gm.graph.on_generate_code( - make_transformer=lambda cur_transform: ( - # new code transformer to register - lambda body: ( - insert_pdb( - cur_transform(body) if cur_transform - else body - ) - ) - ) - ): - gm.recompile() - - return gm diff --git a/pippy/fx/experimental/graph_gradual_typechecker.py b/pippy/fx/experimental/graph_gradual_typechecker.py deleted file mode 100644 index a3fe8cf37..000000000 --- a/pippy/fx/experimental/graph_gradual_typechecker.py +++ /dev/null @@ -1,927 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import itertools -import operator -from functools import reduce -from typing import Callable, Dict - -import torch -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.conv import Conv2d - -import pippy -from pippy.fx.experimental.refinement_types import Equality -from pippy.fx.experimental.unification import Var # type: ignore[attr-defined] -from pippy.fx.node import Target, Node -from pippy.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise - -try: - import sympy # type: ignore[import] - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False - -_INFERENCE_RULES: Dict[Target, Callable] = {} -_REFINEMENT_RULES: Dict[Target, Callable] = {} -_RULES: Dict[Target, Callable] = {} - - -def expand_to_tensor_dim(t, n): - """ - Expand a type to the desired tensor dimension if possible - Raise an error otherwise. - - t is the given type - - n is a number of dimensions to expand to - """ - if t == Dyn: - dims = [Dyn] * n - return TensorType(tuple(dims)) - elif isinstance(t, TensorType): - if len(t.__args__) != n: - raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}') - return t - else: - raise TypeError(f'Cannot match the type {t}') - - -def broadcast_types(t1, t2): - """ - Applies broadcasting to both given types such that they - become consistent with eachother and returns two new - resulting types - """ - - # if either type is Dyn, do nothing since the types are already consistent - if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): - return t1, t2 - - if isinstance(t1, TensorType) and isinstance(t2, TensorType): - s1 = len(t1.__args__) - s2 = len(t2.__args__) - - new_t1 = list(t1.__args__) - new_t2 = list(t2.__args__) - - # We make the types the same length which is the first requirement - # for consistency - if s1 > s2: - for i in range(s1 - s2): - new_t2.insert(0, 1) - - elif s2 > s1: - for i in range(s2 - s1): - new_t1.insert(0, 1) - - # we replace occurrences of "1" with each tensor with - # the corresponding type from the other tensor - for i, (x, y) in enumerate(zip(new_t1, new_t2)): - if x == 1: - new_t1[i] = y - elif y == 1: - new_t2[i] = x - - # at this point our tensors should be consistent - # and we can apply the element-wise operation and find the right dimension - # for the output of the operation - (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) - return (t1, t2) - else: - raise TypeError(f'Cannot broadcast types {t1} and {t2}') - -def register_inference_rule(call_target): - def register(fn): - if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') - _INFERENCE_RULES[call_target] = fn - return fn - return register - -def register_refinement_rule(call_target): - def register(fn): - if call_target in _REFINEMENT_RULES: - raise RuntimeError(f'Refinement rule already registered for {call_target}!') - _REFINEMENT_RULES[call_target] = fn - return fn - return register - -def register_algebraic_expressions_inference_rule(call_target): - def register(fn): - if call_target in _RULES: - raise RuntimeError(f'Rule already registered for {call_target}!') - _RULES[call_target] = fn - return fn - return register - -@register_inference_rule(torch.add) -@register_inference_rule(operator.add) -def add_inference_rule(n: Node): - """ - Apply the addition inference rule. This includes: - - scalar addition - - broadcasting semantics - - Note that we always return the least precise type between - the operands (after applying broadcasting) to be the final type of the operation - - Note that we do not modify the operand types themselves after applying broadcasting - to them. We only use them to calculate the final type - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - t1 = n.args[0].type - t2 = n.args[1].type - - # handle scalar addition - if t1 == int and isinstance(t2, TensorType): - n.type = t2 - return n.type - - # handle scalar addition - elif t2 == int and isinstance(t1, TensorType): - n.type = t1 - return n.type - - # we bring the new types to the point where - # we can check for consistency - # any inconsistency would not have been caused - # by broadcasting at this point - (new_t1, new_t2) = broadcast_types(t1, t2) - - if new_t1 != t1 or new_t2 != t2: - n.meta['broadcast'] = True - n.meta[str(n.args[0])] = new_t1 - n.meta[str(n.args[1])] = new_t2 - - else: - n.meta['broadcast'] = False - - new_t1 = t1 if not n.meta['broadcast'] else new_t1 - new_t2 = t2 if not n.meta['broadcast'] else new_t2 - - # we check for consistency between the new types - if is_consistent(new_t1, new_t2): - # we return the less precise type because - # broadcasting may have happened - # for operands with shape [1,2,Dyn] and [1,2,1] - # we have to assign the node [1,2,Dyn] - if is_more_precise(new_t1, new_t2): - n.type = new_t2 - else: - n.type = new_t1 - return n.type - else: - raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.' - f' Types should match ') - -@register_inference_rule(getattr) -def get_attr_inference_rule(n: Node, traced): - """ - The current getattr rule only handles the shape attribute - Can be extended to other attributes - The most representitive type we have is "Dyn" but the system - can be extended with more types, such as a type to represent shapes - """ - attr_node = n.args[0] - attr_name = n.args[1] - - if attr_name == "shape": - n.type = Dyn - else: - raise TypeError("Not yet implelemted") - - # TODO. We leave it like this till we add a type to represent tensor sizes - return n.type - -@register_inference_rule(torch.transpose) -def transpose_inference_rule(n: Node): - """ - We check that dimentions for the transpose operations - are within range of the tensor type of the node - """ - if n.target == torch.transpose: - assert isinstance(n.args[0], Node) - t = n.args[0].type - - assert isinstance(n.args[1], int) - assert isinstance(n.args[2], int) - dim1, dim2 = n.args[1], n.args[2] - - if t == Dyn: - n.type = Dyn - return n.type - - elif isinstance(t, TensorType): - if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): - new_type = list(t.__args__) - new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] - final = TensorType(new_type) - n.type = get_greatest_upper_bound(n.type, final) - return n.type - else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') - else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') - - -@register_inference_rule(torch.reshape) -def reshape_inference_rule(n: Node): - """ - Without dynamism, the rule checks that the - product of the elements of the argument tensor - type is equal to the product of the elements - of the required shape. We gradualize this rule - by adding a case to handle fully dynamic input - as well as input where some of the tensor dimensions - are unknown. In this case we check for divisibility - """ - assert isinstance(n.args[0], Node) - t1 = n.args[0].type - - assert isinstance(n.args[1], list) - t2 = n.args[1] - t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) - - # if we do not know the original tensor dimension, - # we return the required dimension - if t1 == Dyn: - n.type = t2_type - return t2_type - - # if any of the dimensions are unknown, - # we check for divisibility - elif isinstance(t1, TensorType): - assert isinstance(t1, TensorType) - a = [e if e != Dyn else 1 for e in t1.__args__] - p1 = reduce(lambda x, y: x * y, a) - p2 = reduce(lambda x, y: x * y, t2) - if p1 % p2 == 0 or p2 % p1 == 0: - n.type = t2_type - return t2_type - else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') - else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') - -@register_inference_rule(BatchNorm2d) -def bn2d_inference_rule(n: Node, module_instance): - """ - Given a BatchNorm2D instance and a node check the following conditions: - - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) - - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') - - t is consistent with t' - - x_2 is consistent with the module's num_features - - x_2' is consistent with the module's num_features - output type: the more precise type of t and t' - """ - assert isinstance(n.args[0], Node) - n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) - arg_type = n.args[0].type - n.type = expand_to_tensor_dim(n.type, 4) - - # we check the conditions on the incoming argument - # and any existing annotation - # we also check for consistency between both annotations - if is_consistent(arg_type.__args__[1], module_instance.num_features) and \ - is_consistent(n.type.__args__[1], module_instance.num_features) and \ - is_consistent(arg_type, n.type): - - # we choose the more precise type - # to be the node type - # so if an incoming argument has more type information - # we set this node's type to be the argument type - n.type = get_greatest_upper_bound(arg_type, n.type) - return n.type - else: - raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}') - - -def calculate_out_dimension(d_in, module_instance, index): - """ - For calculating h_in and w_out according to the conv2D documentation - """ - padding = (module_instance.padding, module_instance.padding) \ - if isinstance(module_instance.padding, int) else module_instance.padding - kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \ - if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size - stride = (module_instance.stride, module_instance.stride) \ - if isinstance(module_instance.stride, int) else module_instance.stride - dilation = (module_instance.dilation, module_instance.dilation) \ - if isinstance(module_instance.dilation, int) else module_instance.dilation - - DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,) - - if d_in == Dyn: - return Dyn - - elif isinstance(d_in, DIMENSION_TYPES): - n = d_in + 2 * padding[index] - \ - dilation[index] * \ - (kernel_size[index] - 1) - 1 - - return (n // stride[0]) + 1 - - else: - raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') - - -def get_greatest_upper_bound(type1, type2): - """ - Get the most precise type that's consistent with the given types - """ - if type1 == Dyn: - return type2 - elif type2 == Dyn: - return type1 - elif isinstance(type1, TensorType) and isinstance(type2, TensorType): - if not is_consistent(type1, type2): - raise TypeError(f'Inconsistent types {type1}, {type2}') - gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)] - return TensorType(tuple(gub)) - - -@register_inference_rule(Conv2d) -def conv2d_inference_rule(n: Node, module_instance): - """ - Given a Conv2D instance and a node check the following conditions: - - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) - - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') - - x_2 is consistent with the module's in_channels - - let o = (x_1, out_channels, H_out, W_out) - then the output is the greatest upper bound of o and the existing node type t'. - """ - assert isinstance(n.args[0], Node) - n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) - arg_type = n.args[0].type - curr_node_type = expand_to_tensor_dim(n.type, 4) - - if is_consistent(arg_type.__args__[1], module_instance.in_channels): - w_in = arg_type.__args__[3] - h_in = arg_type.__args__[2] - h_out = calculate_out_dimension(h_in, module_instance, 0) - w_out = calculate_out_dimension(w_in, module_instance, 1) - new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out)) - gub = get_greatest_upper_bound(new_type, curr_node_type) - n.type = gub - return n.type - else: - raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}') - - -@register_inference_rule(torch.nn.ReLU) -def relu_inference_rule(n: Node, module_instance): - """ - Input and output shapes should be equal. - """ - assert isinstance(n.args[0], Node) - - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - - if isinstance(n.args[0].type, TensorType): - n.type = get_greatest_upper_bound(n.args[0].type, n.type) - return n.type - - -def maxpool2d_check(typ, module_instance): - """ - Applies the maxpool2d shape information to the input - this affects the last two dimensions - """ - new_type_list = list(typ.__args__) - if len(new_type_list) == 4 or len(new_type_list) == 3: - w_in = new_type_list[-1] - h_in = new_type_list[-2] - - h_out = calculate_out_dimension(h_in, module_instance, 0) - w_out = calculate_out_dimension(w_in, module_instance, 1) - - new_type_list[-1] = w_out - new_type_list[-2] = h_out - return TensorType(tuple(new_type_list)) - - else: - raise TypeError(f'Wrong size {typ} for {module_instance}') - - -@register_inference_rule(torch.nn.MaxPool2d) -def maxpool2d_inference_rule(n: Node, module_instance): - """ - Given a MaxPool2D instance and a node check the following conditions: - - Input size matches size 3 or 4 - - Current node type is consistent with the output type we will calculate - - Input size matches output size and the last two dimensions of the output - are w_out and h_out. The remaining dimensions are the same as the input - - Our final result is the greatest upper bound of the output we calculate - and the current node type. - """ - assert isinstance(n.args[0], Node) - - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - if isinstance(n.args[0].type, TensorType): - output = maxpool2d_check(n.args[0].type, module_instance) - n.type = get_greatest_upper_bound(output, n.type) - return n.type - - - -def linear_check(tensor_type, module_instance): - """ - Checks that an input tensor type satisfies the conditions for linear operation - and returns the output type based on in and out features given by module_instance - """ - if len(tensor_type.__args__) >= 2: - if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): - new_type_args = list(tensor_type.__args__) - new_type_args[-1] = module_instance.out_features - return TensorType(tuple(new_type_args)) - else: - raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}') - else: - raise TypeError(f'Type {tensor_type} must have rank 2 or more.') - - -@register_inference_rule(torch.nn.Linear) -def linear_inference_rule(n: Node, module_instance): - """ - Applies the shape information to the input then gets the greatest upper bound - of the resulting type and the existing type - """ - assert isinstance(n.args[0], Node) - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - if isinstance(n.args[0].type, TensorType): - output_type = linear_check(n.args[0].type, module_instance) - n.type = get_greatest_upper_bound(output_type, n.type) - return n.type - - -def adaptiveavgpool2d_check(tensor_type, module_instance): - output_size = module_instance.output_size - if isinstance(output_size, int): - output_size = [output_size, output_size] - elif isinstance(output_size, tuple): - output_size = list(output_size) - if output_size[0] is None: - output_size[0] = output_size[1] - if output_size[1] is None: - output_size[1] = output_size[0] - - new_type_list = list(tensor_type.__args__) - - if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: - new_type_list[-1] = output_size[1] - new_type_list[-2] = output_size[0] - - return TensorType(tuple(new_type_list)) - - else: - raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}') - -@register_inference_rule(torch.nn.AdaptiveAvgPool2d) -def adaptiveavgpool2d_inference_rule(n: Node, module_instance): - """ - The input and output sizes should be the same except for the last - two dimensions taken from the input, which represent width and height - """ - assert isinstance(n.args[0], Node) - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - if isinstance(n.args[0].type, TensorType): - output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) - n.type = get_greatest_upper_bound(n.type, output_type) - return n.type - -def flatten_check(tensor_type, start_dim, end_dim): - l = len(tensor_type.__args__) - - start_dim = l if start_dim == -1 else abs(start_dim) - end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 - - if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: - my_args = list(tensor_type.__args__) - lhs = my_args[0:start_dim] - rhs = my_args[end_dim:] - mid = my_args[start_dim:end_dim] - if Dyn in mid: - mid = [Dyn] - else: - mid = [reduce(lambda x, y: x * y, my_args[start_dim:end_dim])] - new_type_list = lhs + mid + rhs - return TensorType(tuple(new_type_list)) - else: - raise TypeError(f'Incompatable dimentions {start_dim}, {end_dim - 1} in type {tensor_type}') - -@register_inference_rule(torch.flatten) -def flatten_inference_rule(n: Node): - """ - Applies the flatten shape information to the input then gets the - greatest upper bound of the resulting type and the existing type - """ - assert isinstance(n.args[0], Node) - - # set the default start and end dims - start_dim = 1 - end_dim = -1 - - if len(n.args) > 1: - assert isinstance(n.args[1], int) - start_dim = n.args[1] - - if len(n.args) > 2: - assert isinstance(n.args[2], int) - end_dim = n.args[2] - - if n.args[0].type == Dyn and isinstance(n.type, TensorType): - n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) - - if isinstance(n.args[0].type, TensorType): - output_type = flatten_check(n.args[0].type, start_dim, end_dim) - n.type = get_greatest_upper_bound(output_type , n.type) - - return n.type - -class GraphTypeChecker: - def __init__(self, env, traced): - self.env = env - self.traced = traced - - def type_check(self): - """ - A gradual type checker for graphs - Effect: every node's field type will be - populated with a type after type-checking is done - """ - graph = self.traced.graph - - # type check every node with gradual type rules - # if any node does not type check return false - for n in graph.nodes: - self.type_check_node(n) - return True - - def type_check_node(self, n: Node): - """ - Type check a given fx node. - Current operations: - - Reshape - - Transpose - - Add - - Relu - - conv2d - - batchnorm2d - - flatten - - maxpool2d - - adaptiveavgpool2d - - linear - """ - if n.type is None: - n.type = Dyn - - if n.op == 'placeholder': - return n.type - - elif n.op == 'get_attr': - t = get_parameter(self.traced, n.target) # type: ignore[arg-type] - if isinstance(t.data, torch.Tensor): - n.type = TensorType(t.data.shape) - return n.type - - elif n.op == 'call_function': - if n.target == getattr: - assert getattr in _INFERENCE_RULES - return _INFERENCE_RULES[n.target](n, self.traced) - - elif n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n) - else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _INFERENCE_RULES: - return _INFERENCE_RULES[type(module_instance)](n, module_instance) - else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') - - elif n.op == 'output': - def get_node_type(a): - return a.type - n.type = pippy.fx.node.map_arg(n.args[0], get_node_type) - return n.type - - else: - raise NotImplementedError(f"Method {n.op} not yet implemented") - - -@register_refinement_rule(Conv2d) -def conv_refinement_rule(n: Node): - """ - The equality constraints are between the first dimension of - the input and output - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - res = [Equality(arg_type.__args__[0], n.type.__args__[0])] - return res - - -@register_refinement_rule(torch.nn.Linear) -def linear_refinement_rule(n: Node): - """ - The equality constraints are between the first dimension of - the input and output - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - res = [Equality(arg_type.__args__[0], n.type.__args__[0])] - return res - -@register_refinement_rule(BatchNorm2d) -@register_refinement_rule(torch.nn.ReLU) -def all_eq(n: Node): - """ - For operations where the input shape is equal to the output shape - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - args1 = arg_type.__args__ - args2 = n.type.__args__ - res = [Equality(args1[i], args2[i]) for i in range(len(args1))] - return res - - -@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) -@register_refinement_rule(torch.nn.MaxPool2d) -def first_two_eq(n: Node): - """ - For operations where the first two dimensions of the input and output shape - are equal - """ - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - args1 = arg_type.__args__ - args2 = n.type.__args__ - res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] - return res - - -@register_refinement_rule(torch.add) -@register_refinement_rule(operator.add) -def element_wise_eq(n: Node): - """ - For element-wise operations and handles broadcasting. - Note that after applying broadcasting to the arguments - we are able to determine if certain dimensions have not been broadcast - if they are symbolicallu equal. - - in this case, we can establish equality between those dimensions and the - corresponding output dimensions. - - Note that it takes two iterations for this result. One iteration to establish - equality between certain dimensions of the operands (requiring the whole solver - including unification) and another iteration to establish equality between the operands - and the resulting type, requiring another round of constraint generation and unificaiton. - """ - res = [] - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - arg_type1 = n.args[0].type - arg_type2 = n.args[1].type - if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): - args1, args2 = broadcast_types(arg_type1, arg_type2) - # by this point, we know that args1 and args2 are the same size. - a1 = args1.__args__ - a2 = args2.__args__ - a3 = n.type.__args__ - - # we would be here in the second iteration where we establish equality - # between operand type dimensions and the resulting type dimensions - r = [] - for x, y, z in zip(a1, a2, a3): - if x == y: - r.append(Equality(x, z)) - res = r - return res - - -@register_refinement_rule(torch.flatten) -def flatten_refinement_rule(n: Node): - """ - Generates equality constraints between the dimensions of the input and output - that will not be involved in the flatten operation - """ - assert isinstance(n.args[0], Node) - - eq_const = [] - - start_dim = 1 - end_dim = -1 - - if len(n.args) > 1: - assert isinstance(n.args[1], int) - start_dim = n.args[1] - - if len(n.args) > 2: - assert isinstance(n.args[2], int) - end_dim = n.args[2] - - if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): - l = len(n.type.__args__) - arg_type = n.args[0].type - start_dim = l if start_dim == -1 else start_dim - end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 - - for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): - eq_const.append(Equality(t1, t2)) - - for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): - eq_const.append(Equality(t1, t2)) - return eq_const - - -@register_algebraic_expressions_inference_rule(Conv2d) -def conv_rule(n: Node, module_instance): - """ - Represents the outout in terms of an algrbraic expression w.r.t - the input when possible - """ - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - w_in = arg_type.__args__[3] - h_in = arg_type.__args__[2] - h_out = calculate_out_dimension(h_in, module_instance, 0) - w_out = calculate_out_dimension(w_in, module_instance, 1) - new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) - n.type = new_type - return new_type - -class Refine: - """ - Symbolic shape inference. - Generates constraints over type variables. - Currently all constraints are equality constraints. - """ - def __init__(self, traced): - self.constraints = [] - self.traced = traced - self.symbol_iter = itertools.count(start=0, step=1) - - def refine(self): - """ - Generates constraints for - every node in the graph based on - the operation. - """ - graph = self.traced.graph - for n in graph.nodes: - self.refine_node(n) - return True - - def symbolic_relations(self): - """ - Infers algebraic relations - """ - graph = self.traced.graph - for n in graph.nodes: - self.infer_symbolic_relations(n) - return True - - def replace_dyn_with_fresh_var(self, typ): - """ - Replace all unknown types with fresh type variables. - """ - if typ == Dyn: - new_symbol = Var(next(self.symbol_iter)) - return new_symbol - elif isinstance(typ, TensorType): - new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] - return TensorType(tuple(new_args)) - elif isinstance(typ, list): - return [self.replace_dyn_with_fresh_var(t) for t in typ] - elif isinstance(typ, tuple): - return (self.replace_dyn_with_fresh_var(t) for t in typ) - else: - return typ - - - def convert_to_sympy_symbols(self, typ): - """ - Replace all unknown types with fresh type variables. - """ - if HAS_SYMPY: - if isinstance(typ, Var): - return sympy.symbols(str(typ)) - elif isinstance(typ, TensorType): - new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] - return TensorType(tuple(new_args)) - elif isinstance(typ, list): - return [self.convert_to_sympy_symbols(t) for t in typ] - elif isinstance(typ, tuple): - return (self.convert_to_sympy_symbols(t) for t in typ) - else: - return typ - else: - return typ - - def refine_node(self, n: Node): - """ - Returns a list of equality constraints for - call_module and call_function nodes. - Models the relation between input and output dimensions - using constraints in case they are both tensors. - All operations used in resnet50 are defined. - """ - if n.type is None: - n.type = Dyn - - n.type = self.replace_dyn_with_fresh_var(n.type) - - if n.op == 'call_function': - if n.target in _REFINEMENT_RULES: - self.constraints += _REFINEMENT_RULES[n.target](n) - else: - pass - - if n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _REFINEMENT_RULES: - self.constraints += _REFINEMENT_RULES[type(module_instance)](n) - else: - pass - - if n.op == 'output': - def get_node_type(a): - return a.type - n.type = pippy.fx.node.map_arg(n.args[0], get_node_type) - return n.type - - else: - pass - - def infer_symbolic_relations(self, n: Node): - if HAS_SYMPY: - n.type = self.convert_to_sympy_symbols(n.type) - if n.op == 'call_function': - if n.target in _RULES: - return _RULES[n.target](n) - else: - pass - - if n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _RULES: - return _RULES[type(module_instance)](n, module_instance) - else: - pass - - if n.op == 'output': - def get_node_type(a): - return a.type - n.type = pippy.fx.node.map_arg(n.args[0], get_node_type) - return n.type - - else: - pass - else: - pass - -def get_parameter(traced, target: str): - """ - Returns the parameter given by ``target`` if it exists, - otherwise throws an error. - - See the docstring for ``get_submodule`` for a more detailed - explanation of this method's functionality as well as how to - correctly specify ``target``. - - Args: - target: The fully-qualified string name of the Parameter - to look for. (See ``get_submodule`` for how to specify a - fully-qualified string.) - - Returns: - torch.nn.Parameter: The Parameter referenced by ``target`` - - Raises: - AttributeError: If the target string references an invalid - path or resolves to something that is not an - ``nn.Parameter`` - """ - module_path, _, param_name = target.rpartition(".") - - mod: torch.nn.Module = traced.get_submodule(module_path) - - if not hasattr(mod, param_name): - raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") - - param: torch.nn.Parameter = getattr(mod, param_name) - - return param diff --git a/pippy/fx/experimental/merge_matmul.py b/pippy/fx/experimental/merge_matmul.py deleted file mode 100644 index f53ea9c9f..000000000 --- a/pippy/fx/experimental/merge_matmul.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import itertools -import operator -from typing import Dict, List - -import torch - -import pippy -import pippy.fx -from pippy.fx._symbolic_trace import symbolic_trace -from pippy.fx.node import Node -from pippy.fx.passes.tools_common import legalize_graph - - -def split_result_tensors(result: torch.Tensor, inputs: List[torch.Tensor]) -> List[torch.Tensor]: - """ - A free function for use in the merge_matmul graph transformation below that - splits the output from a merged matmul into the individual results for each - input tensor. - - Arguments: - result: The merged matmul result tensor. - inputs: The list of inputs that were merged into one for the matmul. - - Returns: - List of matmul results for each input tensor. - """ - # When fx tracer is running, x.shape[0] will be pippy.fx.Attribute but we - # need an int even when tracing - if isinstance(result, pippy.fx.Proxy): - splits = [0] * len(inputs) - else: - splits = [x.shape[0] for x in inputs] - - return torch.split(result, splits) - - -def may_depend_on(a: Node, b: Node, search_depth: int = 6): - """ - Determine if one node depends on another in a pippy.fx.Graph. - - Arguments: - a: The node that may have a dependency on b. - b: The node that a may have a dependency on. - search_depth: In the case of an indirect dependency, this function - searches upto this many nodes away in search of a - data dependency. If none is found, the function - makes the conservative assumption that there is a - dependency. - - Returns: - True if a may depend on b, False if it definitely does not. - """ - # Equivalence is defined as dependence. - if a == b: - return True - - # If a has no inputs, it cannot depend on b. - if len(a.all_input_nodes) == 0: - return False - - # If the search depth has been exhausted and no conclusion has been - # reached, assume that there is a data dependency. - if search_depth == 0: - return True - - # Recursively check all inputs of a. - for inp in a.all_input_nodes: - if may_depend_on(inp, b, search_depth - 1): - return True - - return False - - -def are_nodes_independent(nodes: List[Node]): - """ - Check if all of the given nodes are pairwise-data independent. - - Arguments: - nodes: The nodes to check for data dependencies. - - Returns: - True if any pair in nodes has a data dependency. - """ - # For each pair in nodes: - for i, j in itertools.combinations(nodes, 2): - if may_depend_on(i, j) or may_depend_on(j, i): - return False - - return True - - -def merge_matmul(in_mod: torch.nn.Module): - """ - A graph transformation that merges matrix multiplication operations that share the same right-hand - side operand into one large matrix multiplication. - ____ _________ _________ - ---- | | | | M| A * C | - M| A | T| B | * K| C | = |---------| - ---- , | | | | T| B * C | - K ---- --------- --------- - K R R - """ - gm = symbolic_trace(in_mod) - - rhs_users: Dict[Node, List[Node]] = {} - lhs_users: Dict[Node, List[Node]] = {} - - # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to - # the matmul of which they are the LHS/RHS. - for node in gm.graph.nodes: - if node.op != "call_function" or node.target is not torch.matmul: - continue - - lhs, rhs = node.args - - # TODO: Properly handle aliasing caused by get_attr. For now, - # use the attribute name as the operand if the node is a - # get_attr. - lhs = lhs.target if lhs.op == "get_attr" else lhs - rhs = rhs.target if rhs.op == "get_attr" else rhs - - lhs_users.setdefault(lhs, []).append(node) - rhs_users.setdefault(rhs, []).append(node) - - for rhs, mms in rhs_users.items(): - # There must be at least matmuls for a merge to make sense. - if len(mms) < 2: - continue - - # All matmuls must not depend on each other directly or indirectly - # in order for the merge to be possible. - if not are_nodes_independent(mms): - continue - - lhs_vals = [mm.args[0] for mm in mms] - - # Merge the matmul. - # Collect a list of LHS operands and the single RHS operand. - lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] - rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs - - # Concatenate all the LHS operands. - merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) - - # Multiply the concatenated LHS operands with the one RHS. This will produce - # the same results as all the individual matmuls involving rhs in the original graph, - # but they will all be concatenated together. - merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) - - # Split the result of the merged matmul using the shapes of the LHS operands - # to ascertain how large each chunk should be. - merge_mm_split = gm.graph.call_function( - split_result_tensors, (merge_mm, lhs), {} - ) - merge_mm_res = [ - gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) - for out in range(len(lhs)) - ] - - # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. - for old, new in zip(mms, merge_mm_res): - old.replace_all_uses_with(new) - gm.graph.erase_node(old) - - # All of the new nodes created above were inserted at the end, so we need to sort - # the nodes topologically to make sure all definitions precede uses. - legalize_graph(gm) - - gm.recompile() - gm.graph.lint() - return gm diff --git a/pippy/fx/experimental/meta_tracer.py b/pippy/fx/experimental/meta_tracer.py deleted file mode 100644 index 3f5caa599..000000000 --- a/pippy/fx/experimental/meta_tracer.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import pippy.fx -import warnings -import functools -import builtins - -from typing import Any, Callable, Dict, Optional, Union - -def embedding_override(self, input): - return torch.empty(*input.shape, self.weight.shape[-1], device='meta') - - -def nn_layernorm_override(self, input): - return input - - -def torch_relu_override(x): - return x - - -def torch_nn_relu_override(self, x): - return x - - -def functional_relu_override(x, inplace=False): - assert not inplace, 'dont support inplace functional.relu for metatensor analysis' - return x - - -def torch_where_override(condition, x, y): - # torch.where returns the broadcasted tensor of condition, x, and y, - # so hack it by using addition - return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta') - - -def torch_abs_override(input, *, out=None): - assert out is None, 'Dont support in-place abs for MetaTensor analysis' - return input - -manual_meta_overrides : Dict[Callable, Callable] = { - torch.nn.Embedding: embedding_override, - torch.nn.LayerNorm: nn_layernorm_override, - torch.relu: torch_relu_override, - torch.nn.functional.relu: functional_relu_override, - torch.nn.ReLU: torch_nn_relu_override, - torch.where: torch_where_override, - torch.abs: torch_abs_override, -} - -def gen_constructor_wrapper(target): - @functools.wraps(target) - def wrapper(*args, **kwargs): - proxy = None - - def check_has_proxy(v): - if isinstance(v, pippy.fx.Proxy): - nonlocal proxy - proxy = v - pippy.fx.node.map_aggregate(args, check_has_proxy) - pippy.fx.node.map_aggregate(kwargs, check_has_proxy) - - if proxy is not None: - return proxy.tracer.create_proxy('call_function', target, args, kwargs) - else: - return target(*args, **kwargs) - return wrapper, target - -class MetaProxy(pippy.fx.Proxy): - def install_tensor_meta(self, tensor_meta): - self._tensor_meta = tensor_meta - - def size(self, dim=None): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.size(*[dim] if dim else []) - return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {}) - - def dim(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.dim() - return self.tracer.create_proxy('call_method', 'dim', (self,), {}) - - @property - def shape(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.shape - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {}) - - @property - def dtype(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: - return self._tensor_meta.dtype - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {}) - - @property - def device(self): - # Hack so we can track when devices are used. During meta-tensor propagation, - # replace these values with a constant 'meta' - return MetaDeviceAttribute(self, 'device') - - def __getattr__(self, k): - if k == '_tensor_meta': - return self.__getattribute__(k) - # note: not added to the graph yet, if this is a method call - # we peephole optimize to the method invocation - return MetaAttribute(self, k) - -class MetaAttribute(MetaProxy): - def __init__(self, root, attr: str): - - self.root = root - self.attr = attr - self.tracer = root.tracer - self._node = None - - @property - def node(self): - # the node for attributes is added lazily, since most will just be method calls - # which do not rely on the getitem call - if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node - return self._node - - def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) - -class MetaDeviceAttribute(MetaAttribute): - pass - -def proxys_to_metas(v): - if isinstance(v, MetaDeviceAttribute): - return 'meta' - if isinstance(v, pippy.fx.Proxy): - assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}' - assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta' - return v._tensor_meta - return v - -class MetaTracer(pippy.fx.Tracer): - allow_insert_stateless_mods : bool = True - - _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye'] - - def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): - rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) - - if kind == 'placeholder' and target in self.meta_args: - rv.install_tensor_meta(self.meta_args[target]) - return rv - - if target in self.orig_fns: - # NOTE: tensor constructors in PyTorch define the `device` argument as - # *kwargs-only*. That is why this works. If you add methods to - # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, - # this will break and you will likely see issues where we cannot infer - # the size of the output. - if 'device' in kwargs: - kwargs['device'] = 'meta' - - try: - args_metas = pippy.fx.node.map_aggregate(args, proxys_to_metas) - kwargs_metas = pippy.fx.node.map_aggregate(kwargs, proxys_to_metas) - - if kind == 'call_function': - meta_target = manual_meta_overrides.get(target, target) - meta_out = meta_target(*args_metas, **kwargs_metas) - elif kind == 'call_method': - meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) - elif kind == 'call_module': - assert hasattr(self, 'orig_forward') - self._disable_module_getattr = True - try: - mod = self.root.get_submodule(target) - mod_type = type(mod) - if mod_type in manual_meta_overrides: - meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) - else: - meta_out = self.orig_forward(*args_metas, **kwargs_metas) - finally: - self._disable_module_getattr = False - elif kind == 'get_attr': - self._disable_module_getattr = True - try: - attr_itr = self.root - atoms = target.split('.') - for atom in atoms: - attr_itr = getattr(attr_itr, atom) - assert isinstance(attr_itr, torch.Tensor) - meta_out = attr_itr.to(device='meta') - finally: - self._disable_module_getattr = False - else: - return rv - - # TODO - assert isinstance(rv, pippy.fx.Proxy), 'Dont support composite output yet' - rv.install_tensor_meta(meta_out) - except Exception as e: - warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}') - - return rv - - def getattr(self, attr, attr_val, parameter_proxy_cache): - if getattr(self, '_disable_module_getattr', False): - return attr_val - else: - return super().getattr(attr, attr_val, parameter_proxy_cache) - - def call_module(self, m, forward, args, kwargs): - self.orig_forward = forward - return super().call_module(m, forward, args, kwargs) - - def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: - """ - Helper method which tries to insert a module that was not declared as submodule. - """ - idx = 0 - mod_name = mod.__class__.__name__.lower() - path = f"{mod_name}_{idx}" - while hasattr(self.root, path): - path = f"{mod_name}_{idx}" - idx += 1 - - self.root.add_module(path, mod) - return path - - def path_of_module(self, mod: torch.nn.Module) -> str: - try: - return super().path_of_module(mod) - except NameError as e: - if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: - path = self._insert_module_as_submodule(mod) - self.prev_module = path - return path - raise - - def proxy(self, node): - return MetaProxy(node, self) - - def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): - assert isinstance(meta_args, dict) - self.meta_args = meta_args - - self.patched_torch_methods = { - target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH - } - self.orig_fns = set() - - for name, (wrapper, orig) in self.patched_torch_methods.items(): - setattr(torch, name, wrapper) - self.orig_fns.add(orig) - - try: - graph = super().trace(root, concrete_args) - graph._tracer_extras = {'meta_args': meta_args} - return graph - finally: - for name, (_, orig) in self.patched_torch_methods.items(): - setattr(torch, name, orig) - - -def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], - meta_args : Dict[str, torch.Tensor] = None, - concrete_args: Optional[Dict[str, Any]] = None) -> pippy.fx.GraphModule: - tracer = MetaTracer() - graph = tracer.trace(root, meta_args, concrete_args) - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - gm = pippy.fx.GraphModule(tracer.root, graph, name) - return gm diff --git a/pippy/fx/experimental/migrate_gradual_types/__init__.py b/pippy/fx/experimental/migrate_gradual_types/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/experimental/migrate_gradual_types/constraint.py b/pippy/fx/experimental/migrate_gradual_types/constraint.py deleted file mode 100644 index 9188e8346..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/constraint.py +++ /dev/null @@ -1,559 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# -*- coding: utf-8 -*- -from pippy.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ - op_mod, op_gt, op_lt, op_neq, op_eq -from pippy.fx.tensor_type import TensorType, Dyn - - -class Constraint: - pass - - -class Conj(Constraint): - def __init__(self, conjuncts): - """ - :param conjuncts: Conjuction of constraints - """ - self.conjucts = conjuncts - - def __eq__(self, other): - if isinstance(other, Conj): - return self.conjucts == other.conjucts and self.conjucts == other.conjucts - else: - return False - - def __repr__(self): - return f'And({self.conjucts})' - - -class Disj(Constraint): - def __init__(self, disjuncts): - """ - :param disjuncts: Disjunction of constraints - """ - self.disjuncts = disjuncts - - def __eq__(self, other): - if isinstance(other, Disj): - return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts - else: - return False - - def __repr__(self): - return f'Or({self.disjuncts})' - - -class Prod(Constraint): - def __init__(self, products): - """ - :param products: lists of dimensions to multiply - """ - self.products = products - - def __eq__(self, other): - if isinstance(other, Prod): - return self.products == other.products and self.products == other.products - else: - return False - - def __repr__(self): - return f'Product({self.products})' - - -class T(Constraint): - """ - True - """ - def __init__(self): - pass - - def __eq__(self, other): - return isinstance(other, T) - - def __repr__(self): - return 'True' - -class F(Constraint): - """ - False - """ - def __init__(self): - pass - - def __eq__(self, other): - return isinstance(other, F) - - def __repr__(self): - return 'False' - - -class BinaryConstraint(Constraint): - """ - Represents all binary operations - """ - def __init__(self, lhs, rhs, op): - """ - :param lhs: lhs of the constraint - :param rhs: rhs of the constraint - :param op: string reprsenting the operation - """ - self.lhs = lhs - self.rhs = rhs - self.op = op - - def __eq__(self, other): - if isinstance(other, BinaryConstraint): - return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op - else: - return False - - def __repr__(self): - return f'({self.lhs} {self.op} {self.rhs})' - - -class BinConstraintT(BinaryConstraint): - """ - Binary constraints about tensors - """ - def __init__(self, lhs, rhs, op): - assert (isinstance(lhs, TVar) or isinstance(lhs, TensorType) or isinstance(lhs, int) or lhs == Dyn) and \ - (isinstance(rhs, TVar) or isinstance(rhs, TensorType) or isinstance(rhs, int) or rhs == Dyn) - super().__init__(lhs, rhs, op) - - def __eq__(self, other): - return super().__eq__(other) - - -class BinConstraintD(BinaryConstraint): - """ - Binary constraints about dimensions - """ - def __init__(self, lhs, rhs, op): - assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) - assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) - - super().__init__(lhs, rhs, op) - - def __eq__(self, other): - return super().__eq__(other) - - - -class TGreatestUpperBound(Constraint): - """ - Greatest Upper bound for tensors with dynamic type - """ - def __init__(self, res, rhs1, rhs2): - """ - :param res: tensor variable that stores the result of the outout - :param rhs1: tensor or tensor variable - :param rhs2: tensor or tensor variabke - """ - self.res = res - self.rhs1 = rhs1 - self.rhs2 = rhs2 - - def __repr__(self): - return f'{self.res} = {self.rhs1}⊔*{self.rhs2}' - - def __eq__(self, other): - if isinstance(other, TGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 - else: - return False - - -class DGreatestUpperBound(Constraint): - """ - Greatest Upper bound for dimensions - """ - def __init__(self, res, rhs1, rhs2): - """ - :param res: Dimension variable to store the result - :param rhs1: dimension variable 1 - :param rhs2: dimension variable 2 - """ - assert is_dim(res) - assert is_dim(rhs1) - assert is_dim(rhs2) - - self.res = res - self.rhs1 = rhs1 - self.rhs2 = rhs2 - - def __repr__(self): - return f'{self.res} = {self.rhs1}⊔{self.rhs2}' - - def __eq__(self, other): - if isinstance(other, DGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 - else: - return False - - -class CanReshape(Constraint): - """ - can_reshape constraint - """ - def __init__(self, src, target): - """ - :param src: tensor variable - :param target: tensor - """ - self.src = src - self.target = target - - def __repr__(self): - return f'can-reshape({self.src}, {self.target})' - - def __eq__(self, other): - if isinstance(other, CanReshape): - return self.src == other.src and self.target == other.target - else: - return False - - -class IndexSelect(Constraint): - - def __init__(self, tensor_size, input_var, dim_replace, index, output): - """ - Args: - input_var: input to index_select - tensor_size: tensor size we are considering - dim_replace: the dimension of the output at "index" - index: location of the dimensions to replace in the input - outut: variable to store the result - """ - assert isinstance(input_var, TVar) - assert isinstance(output, TVar) - assert isinstance(dim_replace, DVar) or dim_replace == Dyn - assert isinstance(index, int) - - self.input_var = input_var - self.tensor_size = tensor_size - self.dim_replace = dim_replace - self.index = index - self.output = output - - def __repr__(self): - - return f' {self.output} = ' \ - f'IndexSelect({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.dim_replace}, ' \ - f'{self.index})' - - def __eq__(self, other): - if isinstance(other, IndexSelect): - return self.tensor_size == other.tensor_size and \ - self.dim_replace == other.dim_replace and \ - self.index == other.index and \ - self.output == other.output and \ - self.input_var == other.input_var - else: - return False - - -class Transpose(Constraint): - - def __init__(self, tensor_size, input_var, index1, index2, output): - """ - Args: - tensor_size: current tensor size - input_var: variable to hold input - index1: dimension 1 - index2: dimension 2 - output: output that stores result - """ - assert isinstance(input_var, TVar) - assert isinstance(output, TVar) - assert isinstance(index1, int) - assert isinstance(index2, int) - - self.input_var = input_var - self.tensor_size = tensor_size - self.index1 = index1 - self.index2 = index2 - self.output = output - - def __repr__(self): - - return f' {self.output} = ' \ - f'Transpose({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.index1}, ' \ - f'{self.index2})' - - def __eq__(self, other): - if isinstance(other, Transpose): - return self.tensor_size == other.tensor_size and \ - self.index1 == other.index1 and \ - self.index2 == other.index2 and \ - self.output == other.output and \ - self.input_var == other.input_var - else: - return False - - -class GetItem(Constraint): - - def __init__(self, tensor_size, index, res, input_var): - """ - Constraint for getting item given a tensor size - :param tensor_size: actual number - :param index: actual number representing the index - :param res: dimension variable to carry the item we get - :param input_var: a tensor variable from which we will get item - """ - assert isinstance(res, DVar) - - self.res = res - self.tensor_size = tensor_size - self.index = index - self.input_var = input_var - - def __repr__(self): - return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' - - def __eq__(self, other): - if isinstance(other, GetItem): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index == other.index and \ - self.input_var == other.input_var - else: - return False - -class GetItemTensor(Constraint): - - def __init__(self, tensor_size, index_tuple, res, input_var): - """ - Constraint for getting item given a tensor size - However, when the argument is a tuple, we will - expect a tensor - :param tensor_size: actual number representing the rank - :param index_tuple: tuple for indexing - :param res: tensor variable to carry the item we get - :param input_var: a tensor variable from which we will get item - """ - assert isinstance(res, TVar) - - self.res = res - self.tensor_size = tensor_size - self.index_tuple = index_tuple - self.input_var = input_var - - def __repr__(self): - return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' - - def __eq__(self, other): - if isinstance(other, GetItemTensor): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index_tuple == other.index_tuple and \ - self.input_var == other.input_var - else: - return False - -class CalcConv(Constraint): - - def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): - """ - :param conv_result: the convolution result - :param input_var: input to convolution - :param c_out: output chanel type - :param kernel: kernel tuple - """ - self.conv_result = conv_result - self.input_var = input_var - self.c_out = c_out - self.kernel = kernel - self.padding = padding - self.stride = stride - self.dilation = dilation - self.matching_constraint = matching_constraint_vars - - def __repr__(self): - return f'{self.conv_result} =' \ - f' calc-conv({self.input_var},' \ - f' {self.c_out}, {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' - - def __eq__(self, other): - if isinstance(other, CalcConv): - return self.conv_result == other.conv_result and self.input_var == other.input_var and \ - self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ - and self.matching_constraint == other.matching_constraint - else: - return False - - -class CalcMaxPool(Constraint): - - def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): - """ - :param maxpool_result: the result of maxpool - :param input_var: input to convolution - :param kernel: kernel tuple - """ - self.maxpool_result = maxpool_result - self.input_var = input_var - self.kernel = kernel - self.padding = padding - self.stride = stride - self.dilation = dilation - self.matching_constraint = matching_constraint_vars - - def __repr__(self): - return f'{self.maxpool_result} =' \ - f' calc-maxpool({self.input_var},' \ - f' {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' - - def __eq__(self, other): - if isinstance(other, CalcMaxPool): - return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ - and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ - and self.matching_constraint == other.matching_constraint - else: - return False - - -class ApplyBroadcasting(Constraint): - def __init__(self, res1, res2, input1, input2): - """ - :param res1: resulting tensor 1 - :param res2: resulting tensor 2 - :param input1: tensor variable 1 - :param input2: tensor variable 2 - """ - self.res1 = res1 - self.res2 = res2 - self.input1 = input1 - self.input2 = input2 - - def __eq__(self, other): - if isinstance(other, ApplyBroadcasting): - return self.res1 == other.res1 \ - and self.res2 == other.res2 \ - and self.input1 == other.input1 \ - and self.input2 == other.input2 - else: - return False - - def __repr__(self): - return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' - - -class CalcProduct(Constraint): - """ - Given correct dimensions, calculate the product for flatten accounting for Dyn - """ - def __init__(self, start, end, flattened, dims_to_flatten): - """ - :param start: start index - :param end: end index - :param theta: variable to store the product - :param dims_to_flatten: the type which we will flatten - """ - assert isinstance(dims_to_flatten, list) - assert isinstance(flattened, TVar) - assert isinstance(start, int) - assert isinstance(end, int) - - self.start = start - self.end = end - self.dims_to_flatten = dims_to_flatten - self.flattened = flattened - - def __eq__(self, other): - if isinstance(other, CalcProduct): - return self.start == other.start and self.end == other.end and \ - self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened - - else: - return False - - def __repr__(self): - return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' - - -class TVar: - """ - Tensor variable with no tensor constructor - """ - def __init__(self, tvar): - """ - :param tvar: tensor variable - """ - self.tvar = tvar - - def __repr__(self): - return f'TV({self.tvar})' - - def __eq__(self, other): - if isinstance(other, TVar): - return self.tvar == other.tvar - else: - return False - - -class DVar: - """ - Dimension variable - """ - def __init__(self, c): - """ - :param c: character or number - """ - self.c = c - - def __repr__(self): - return f'DV({self.c})' - - def __eq__(self, other): - if isinstance(other, DVar): - return self.c == other.c - else: - return False - - -class BVar: - """ - Boolean variable - """ - def __init__(self, c): - """ - :param c: character or number - """ - self.c = c - - def __repr__(self): - return f'BV({self.c})' - - def __eq__(self, other): - if isinstance(other, BVar): - return self.c == other.c - else: - return False - - -def is_algebraic_expression(constraint): - if isinstance(constraint, BinConstraintD): - return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] - else: - return isinstance(constraint, Prod) - - -def is_bool_expr(constraint): - if isinstance(constraint, BinConstraintD): - return constraint.op in [op_gt, op_lt, op_neq, op_eq] - else: - return isinstance(constraint, BVar) or isinstance(constraint, Conj) or isinstance(constraint, Disj) - -def is_dim(d): - return isinstance(d, DVar) or isinstance(d, int) or d == Dyn diff --git a/pippy/fx/experimental/migrate_gradual_types/constraint_generator.py b/pippy/fx/experimental/migrate_gradual_types/constraint_generator.py deleted file mode 100644 index b47f0160f..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/constraint_generator.py +++ /dev/null @@ -1,1282 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import operator -import warnings -from typing import Callable, Dict, Iterable - -from pippy.fx._symbolic_trace import _assert_is_none -from pippy.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ - Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ - TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound -from pippy.fx.experimental.migrate_gradual_types.operation import \ - op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul -from pippy.fx.node import Target, Node -from pippy.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ - gen_bvar - -from pippy.fx.tensor_type import Dyn, TensorType -from torch.nn.modules.conv import Conv2d -from torch.nn.modules.batchnorm import BatchNorm2d - -_INFERENCE_RULES: Dict[Target, Callable] = {} - -MAX_TENSOR_RANK = 4 - -def register_inference_rule(call_target): - def register(fn): - if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') - _INFERENCE_RULES[call_target] = fn - return fn - return register - - -def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): - d, counter = gen_tensor_dims(n, counter) - c1 = BinConstraintT(input, TensorType(d), op_eq) - start_dim = n if start_dim == -1 else abs(start_dim) - end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 - c2 = CalcProduct(start_dim, end_dim, flattened, d) - nat_constraints = gen_nat_constraints(d) - return Conj([c1, c2, *nat_constraints]), counter - - -@register_inference_rule(getattr) -def get_attr_inference_rule(n: Node, symbols, constraints, counter): - """ - If the attribute is "device" then the tensor shape is preserved - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], str) - output, counter = gen_tvar(counter) - symbols[n] = output - - input = symbols[n.args[0]] - attr = n.args[1] - - if attr == 'device': - return [BinConstraintT(input, output, op_eq)], counter - else: - raise NotImplementedError('Not yet implemented') - -@register_inference_rule(torch.bmm) -def bmm_inference_rule(n: Node, symbols, constraints, counter): - """ - Constraints that match the input to a size 3 tensor - and switch the dimensions according to the rules - of batch multiplication - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - - bmm_output, counter = gen_tvar(counter) - symbols[n] = bmm_output - - bmm_input1 = symbols[n.args[0]] - bmm_input2 = symbols[n.args[1]] - - dims_input1, counter = gen_tensor_dims(3, counter) - dims_input2, counter = gen_tensor_dims(3, counter) - - inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_output, Dyn, op_eq)]) - - input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) - - input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) - - consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] - - batch_size, counter = gen_dvar(counter) - - inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), - *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) - - return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter - - -@register_inference_rule("index_select") -def index_select_inference_rule(n: Node, symbols, constraints, counter): - """ - We constrain the second argument to a vector or Dyn. - The output replaces the input with the shape of the vector - at the position given by the index (first argument) - """ - # print(n.args) - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], int) - assert isinstance(n.args[2], Node) - - - - index_select, counter = gen_tvar(counter) - symbols[n] = index_select - - dims, counter = gen_tensor_dims(1, counter) - - # equality constraint - is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) - is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) - - c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) - c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) - - return [Disj([c2, c3])], counter - - -@register_inference_rule("expand") -def expand_inference_rule(n: Node, symbols, constraints, counter): - """ - We generate the exact constraints as we do for tensor additions but we constraint - the rank of this expression to be equal to len(n.args[1:]) so that only - those cases get considered for the output - """ - assert isinstance(n.args[0], Node) - - # define the output for expand - expand, counter = gen_tvar(counter) - symbols[n] = expand - - # since we do not have two nodes here, we will construct an argument variable - e1 = symbols[n.args[0]] - e2, counter = gen_tvar(counter) - - e2_nat_constraints = [] - for arg in n.args[1:]: - assert isinstance(arg, Node) or isinstance(arg, int) - if isinstance(arg, Node): - assert isinstance(symbols[arg], DVar) - e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) - - e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) - - constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) - - # constraint the output size - dims, counter = gen_tensor_dims(len(n.args[1:]), counter) - nat_constraints = gen_nat_constraints(dims) - c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] - constraints += c - - return constraints, counter - - -@register_inference_rule(torch.nn.functional.gelu) -@register_inference_rule(torch.nn.functional.dropout) -@register_inference_rule(torch.nn.functional.softmax) -@register_inference_rule("detach") -@register_inference_rule("to") -@register_inference_rule("int") -@register_inference_rule("long") -@register_inference_rule("contiguous") -@register_inference_rule(torch.ones) -@register_inference_rule(torch.zeros) -def equality_inference_rule(n: Node, symbols, constraints, counter): - """ - We generate the constraint: input = output - """ - output, counter = gen_tvar(counter) - symbols[n] = output - - if isinstance(n.args[0], Node): - input = symbols[n.args[0]] - if isinstance(input, TVar): - return [BinConstraintT(input, output, op_eq)], counter - - # then we have dimension variables - else: - for arg in n.args: - assert isinstance(symbols[arg], DVar) - my_size = [symbols[arg] for arg in n.args] - return [BinConstraintT(output, TensorType(my_size), op_eq)], counter - - elif isinstance(n.args[0], tuple): - # then the tuple is the size - assert len(n.args[0]) <= 4 - my_size = [symbols[arg] for arg in n.args[0]] - return [BinConstraintT(output, TensorType(my_size), op_eq)], counter - else: - raise NotImplementedError('Method not yet implemented') - - -@register_inference_rule("transpose") -def transpose_inference_rule(n: Node, symbols, constraints, counter): - """ - Can be considered as a sequence of two index selects, so we generate constraints accordingly - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], int) - assert isinstance(n.args[2], int) - - output, counter = gen_tvar(counter) - symbols[n] = output - - from_arg = symbols[n.args[0]] - assert isinstance(from_arg, TVar) - - # input and output are dyn - is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) - - # or input is a tensor and we actually do the replacement - c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) - - return [Disj([is_dyn, c3])], counter - - -@register_inference_rule("type_as") -def type_inference_rule(n: Node, symbols, constraints, counter): - """ - We generate the constraint: input = output - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - - output, counter = gen_tvar(counter) - symbols[n] = output - - from_arg = symbols[n.args[0]] - to_arg = symbols[n.args[1]] - - assert isinstance(from_arg, TVar) - assert isinstance(to_arg, TVar) - - return [BinConstraintT(from_arg, to_arg, op_consistency), - BinConstraintT(output, to_arg, op_eq)], counter - -@register_inference_rule("masked_fill_") -def masked_fill_inference_rule(n: Node, symbols, constraints, counter): - """ - Similar to addition. For now we implemenent the constraints when - the argument is a boolean tensor. There is also a case for when - it is a condition. We will leave this out for now. - """ - - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], Node) - - # We will retrieve the type variables from the symbol table - # and confirm they are tensor variables - - e1 = symbols[n.args[0]] - e2 = symbols[n.args[1]] - - if isinstance(e1, TVar) and isinstance(e2, TVar): - masked_fill_tensor, counter = gen_tvar(counter) - symbols[n] = masked_fill_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) - else: - raise NotImplementedError('Not yet implemented') - - -@register_inference_rule(torch.nn.functional.embedding) -def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - embedding_dim_weights = symbols[n.args[1]] - - # will treat this as a static shape. So we will not use matching. - weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq) - embedding_dim = weight_dims[1] - constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) - return [equality_constraint] + constraints, counter - - -@register_inference_rule(torch.nn.modules.sparse.Embedding) -def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - The output shape differs from the input shape in the last dimension - """ - assert isinstance(n.args[0], Node) - return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) - - -def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): - - embedding_output, counter = gen_tvar(counter) - symbols[n] = embedding_output - embedding_input = symbols[n.args[0]] - - input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) - output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) - - c1 = Conj([input_dyn, output_dyn]) - c2 = [] - - for i in range(1, MAX_TENSOR_RANK): - new_dims, counter = gen_tensor_dims(i, counter) - nat_constraints = gen_nat_constraints(new_dims) - - # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases - c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), - BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + - nat_constraints) - c2.append(c_tensor_i) - - return [Disj([c1, Disj(c2)])], counter - - -@register_inference_rule(torch.tensor) -def tensor_inference_rule(n: Node, symbols, constraints, counter): - """ - If the tensor is a scalar, we will skip it since we - do not support scalars yet. We will add support in the future - if it's needed. For our examples so far, scalars are not needed. - """ - return [], counter - - -@register_inference_rule("reshape") -@register_inference_rule("view") -def view_inference_rule(n: Node, symbols, constraints, counter): - """ - Similar to reshape but with an extra condition on the strides - """ - assert isinstance(n.args[0], Node) - - # generate the new variable - my_view, counter = gen_tvar(counter) - symbols[n] = my_view - - - src_var = symbols[n.args[0]] - t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape - t2_type = [] - num_constraints = [] - - for t in t2: - if t == -1: - var, counter = gen_dvar(counter) - t2_type.append(var) - num_constraints.append(BinConstraintD(var, Dyn, op_neq)) - - else: - num_constraints.append(BinConstraintD(t, Dyn, op_neq)) - t2_type.append(t) - - t2_type = TensorType(t2_type) # type: ignore[assignment] - - c1 = BinConstraintT(my_view, t2_type, op_eq) - c2 = CanReshape(src_var, t2_type) - - # TODO: add the extra check mentioned here: - # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view - - return [c1, c2] + num_constraints, counter # type: ignore[operator] - - -@register_inference_rule("size") -def size_inference_rule(n: Node, symbols, constraints, counter): - """ - The constraint is just lhs = rhs. - Ex: size = input_ids.size() - """ - - - if len(n.args) == 1: - # generate the new variable - size, counter = gen_tvar(counter) - symbols[n] = size - input = symbols[n.args[0]] - c = BinConstraintT(input, size, op_eq) - return [c], counter - - elif len(n.args) == 2: - # TODO: review this rule; should input = dyn; output = dyn be included here? - if isinstance(n.args[1], int): - # generate the new variable - size_index, counter = gen_dvar(counter) - symbols[n] = size_index - input = symbols[n.args[0]] - c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] - c3 = BinConstraintD(0, size_index, op_leq) - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintD(size_index, Dyn, op_eq) - c1 = Conj([input_dyn, output_dyn]) - - return [Disj([c1, Conj([Disj(c2), c3])])], counter - - else: - raise NotImplementedError - - else: - raise NotImplementedError - - -def range_check(i, n): - """ - Checks if an index i is within range of a size n list - Args: - i: index - n: list size - - Returns: Boolean - """ - if i >= 0: - return T() if i < n else F() - else: - return T() if i >= n else F() - - -@register_inference_rule(torch.cumsum) -def cumsum_inference_rule(n: Node, symbols, constraints, counter): - """ - Input and output shapes should be equal - We should verify that the index is valid - """ - assert isinstance(n.args[0], Node) - arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] - assert isinstance(arg_1, int) - - output, counter = gen_tvar(counter) - symbols[n] = output - input = symbols[n.args[0]] - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintT(output, Dyn, op_eq) - c1 = Conj([input_dyn, output_dyn]) - c2 = [] - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(new_dims) - - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), - BinConstraintT(output, TensorType(new_dims), op_eq)] + - [range_check(arg_1, i)] + nat_constraints) - - c2.append(c_tensor_i) - dyn_or_tensor = Disj([c1, Disj(c2)]) - return [dyn_or_tensor], counter - - -@register_inference_rule(_assert_is_none) -def assert_inference_rule(n: Node, symbols, constraints, counter): - assert len(n.users) == 0 - return [], counter - - -@register_inference_rule(operator.getitem) -def getitem_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # dimension output case - if isinstance(n.args[1], int): - # create and store the new dimension variable - get_item_output, counter = gen_dvar(counter) - symbols[n] = get_item_output - - # retreive arg variables - get_item_arg = symbols[n.args[0]] - assert isinstance(get_item_arg, TVar) - - - # if the input is dynamic, we accept any index and return - # a dynamic dimension as output - input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) - output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) - c1 = Conj([input_dyn, output_dyn]) - - # if the input is a tensor, - # generate a getItem constraint which will be expanded based on the - # tensor dimension. - - c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] - - - # since the output is a dimension, we make sure it's a natural number - # added as a conjunction to the disjuction of c2 - c3 = BinConstraintD(0, get_item_output, op_leq) - return [Disj([c1, Conj([Disj(c2), c3])])], counter - - # tensor output case - elif isinstance(n.args[1], tuple): - # create and store the new tensor variable - get_item_output, counter = gen_tvar(counter) - symbols[n] = get_item_output - - # retreive arg variables - if n.args[0] in symbols: - get_item_arg = symbols[n.args[0]] - assert isinstance(get_item_arg, TVar) - - input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) - output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] - c1 = Conj([input_dyn, output_dyn]) - - c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] - for i in range(MAX_TENSOR_RANK)] - else: - # TODO: we should figure out why there is a key-error here. - return [], counter - - return [Disj([c1, *c2])], counter - - else: - raise RuntimeError('Method not yet implemented') - - -@register_inference_rule(operator.gt) -def gt_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) - assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) - - # We make sure this node will not be used again. We do not - # generate a constraint about that node. Only about the operands. - - e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] - e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(e1, TVar) and isinstance(e2, TVar): - gt_tensor, counter = gen_tvar(counter) - symbols[n] = gt_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) - - elif isinstance(e1, DVar) and isinstance(e2, DVar): - # This is meant to be used for flow analysis only - gt_constraint = BinConstraintD(e1, e2, op_gt) - - my_gt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) - return [equality_constraint], counter - - else: - raise RuntimeError('Sort Mismatch') - - elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): - if isinstance(e1, DVar): - # This is meant to be used for flow analysis only - gt_constraint = BinConstraintD(e1, e2, op_gt) - - my_gt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) - return [equality_constraint], counter - - elif isinstance(e1, TVar) and isinstance(e2, int): - # then we made the wrong assumption about the argument being a tensor - # so we should fix the assumption - warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.') - - new_e1, counter = gen_dvar(counter) - symbols[n.args[0]] = new_e1 - symbols[n.args[0]] - - gt_constraint = BinConstraintD(new_e1, e2, op_gt) - - my_gt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) - return [equality_constraint], counter - - else: - raise NotImplementedError('Method not yet implemented') - - else: - raise NotImplementedError('Method not yet implemented') - - -@register_inference_rule(operator.eq) -def eq_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) - assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) - - e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] - e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(e1, TVar) and isinstance(e2, TVar): - eq_tensor, counter = gen_tvar(counter) - symbols[n] = eq_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) - - elif isinstance(e1, DVar) and isinstance(e2, DVar): - # This is meant to be used for flow analysis only - eq_constraint = BinConstraintD(e1, e2, op_eq) - - my_eq, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) - return [equality_constraint], counter - - else: - raise RuntimeError('Sort Mismatch') - - elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): - if isinstance(e1, DVar): - # This is meant to be used for flow analysis only - eq_constraint = BinConstraintD(e1, e2, op_eq) - - my_eq, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) - return [equality_constraint], counter - else: - raise NotImplementedError('Method not yet implemented') - else: - raise NotImplementedError('Method not yet implemented') - -@register_inference_rule(operator.ne) -def neq_inference_rule(n: Node, symbols, constraints, counter): - """ - Translates to inconsistent in gradual types. - To prove inequality, we should prove that - tensors are either different sizes or - disagree on at least one dimension - - This is a WIP (works when the condition - is false. We are working on making this operation work - when the condition is true as well) - """ - assert isinstance(n.args[0], Node) - assert isinstance(n.args[1], tuple) - - # implementing for size 3 and 4 - if len(n.args[1]) == 3: - - assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) - assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) - assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) - - lhs = symbols[n.args[0]] - - b, counter = gen_tensor_dims(4, counter) - input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) - - d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] - d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] - d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] - - # dimensions not equal - my_ne, counter = gen_bvar(counter) - neq_1 = BinConstraintD(d1, b[0], op_neq) - neq_2 = BinConstraintD(d2, b[1], op_neq) - neq_3 = BinConstraintD(d3, b[2], op_neq) - - # dimensions inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) - - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) - - # we are covering size 3 and 4 only for now - ne_constraint = Conj([input_is_size3, dims_inconsistent]) - - my_ne, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) - - elif len(n.args[1]) == 4: - - assert isinstance(n.args[1][0], Node) or isinstance(n.args[1][0], int) - assert isinstance(n.args[1][1], Node) or isinstance(n.args[1][1], int) - assert isinstance(n.args[1][2], Node) or isinstance(n.args[1][2], int) - assert isinstance(n.args[1][3], Node) or isinstance(n.args[1][3], int) - - lhs = symbols[n.args[0]] - - b1, counter = gen_dvar(counter) - b2, counter = gen_dvar(counter) - b3, counter = gen_dvar(counter) - b4, counter = gen_dvar(counter) - - input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) - - d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] - d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] - d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] - d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] - - # dimensions not equal - my_ne, counter = gen_bvar(counter) - neq_1 = BinConstraintD(d1, b1, op_neq) - neq_2 = BinConstraintD(d2, b2, op_neq) - neq_3 = BinConstraintD(d3, b3, op_neq) - neq_4 = BinConstraintD(d4, b4, op_neq) - - # dimensions to inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) - dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) - - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) - - ne_constraint = Conj([input_is_size4, dims_inconsistent]) - - my_ne, counter = gen_bvar(counter) - - equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) - - else: - raise NotImplementedError('Method not yet implemented') - - return [equality_constraint], counter - - -@register_inference_rule(operator.lt) -def lt_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) or isinstance(n.args[0], int) - assert isinstance(n.args[1], Node) or isinstance(n.args[1], int) - - # We make sure this node will not be used again. We do not - # generate a constraint about that node. Only about the operands. - - e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] - e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(e1, TVar) and isinstance(e2, TVar): - lt_tensor, counter = gen_tvar(counter) - symbols[n] = lt_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) - - elif isinstance(e1, DVar) and isinstance(e2, DVar): - # This is meant to be used for flow analysis only - lt_constraint = BinConstraintD(e1, e2, op_lt) - - my_lt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) - return [equality_constraint], counter - - else: - raise RuntimeError('Sort Mismatch') - - elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): - if isinstance(e1, DVar): - # This is meant to be used for flow analysis only - lt_constraint = BinConstraintD(e1, e2, op_lt) - - my_lt, counter = gen_bvar(counter) - equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) - return [equality_constraint], counter - else: - raise NotImplementedError('Method not yet implemented') - - else: - raise NotImplementedError('Method not yet implemented') - - -@register_inference_rule(torch.full) -def full_inference_rule(n: Node, symbols, constraints, counter): - full, counter = gen_tvar(counter) - symbols[n] = full - res = [] - - assert isinstance(n.args[0], Iterable) - for arg in n.args[0]: - dim = arg if isinstance(arg, int) else symbols[arg] - res.append(dim) - c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] - return [c], counter - - -# TODO normalize index -@register_inference_rule(torch.arange) -def arange_inference_rule(n: Node, symbols, constraints, counter): - start = 0 - step = 1 - - if len(n.args) == 1: - end = symbols[n.args[0]] - else: - raise NotImplementedError('Not yet implemented') - - # int((end - start) / step) - d1, counter = gen_dvar(counter) - size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) - arange, counter = gen_tvar(counter) - symbols[n] = arange - - # either the a parameter is a number or it is Dyn - c1 = Disj([BinConstraintD(end, Dyn, op_eq), - BinConstraintD(start, Dyn, op_eq), - BinConstraintD(step, Dyn, op_eq)]) - c2 = BinConstraintD(d1, Dyn, op_eq) - both_dyn = Conj([c1, c2]) - - c11 = Conj([BinConstraintD(end, Dyn, op_neq), - BinConstraintD(start, Dyn, op_neq), - BinConstraintD(step, Dyn, op_neq)]) - c22 = BinConstraintD(d1, Dyn, op_neq) - both_numbers = Conj([c11, c22, size_constraint]) - - return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter - -def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): - # additional vars that don't correspond to expressions - e11, counter = gen_tvar(counter) - e22, counter = gen_tvar(counter) - - # generate constraints - c1 = TGreatestUpperBound(output_var, e11, e22) - c2 = ApplyBroadcasting(e11, e22, e1, e2) - c3 = BinConstraintT(e11, e22, op_consistency) - return [c1, c2, c3], counter - - -@register_inference_rule(operator.mul) -@register_inference_rule(torch.ne) -@register_inference_rule("ne") -@register_inference_rule(torch.add) -@register_inference_rule(operator.add) -def broadcasting_inference_rule(n: Node, symbols, constraints, counter): - - op_code = None - if n.target == operator.add or n.target == torch.add: - op_code = op_add - elif n.target == operator.mul: - op_code = op_mul - - if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): - my_output, counter = gen_tvar(counter) - symbols[n] = my_output - e1 = symbols[n.args[0]] - e2 = symbols[n.args[1]] - - return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) - else: - raise NotImplementedError('Method not yet implemented') - - elif isinstance(n.args[0], Node) and (isinstance(n.args[1], int) or isinstance(n.args[1], float)): - if isinstance(symbols[n.args[0]], TVar): - my_output, counter = gen_tvar(counter) - symbols[n] = my_output - e1 = symbols[n.args[0]] - return [BinConstraintT(my_output, e1, op_eq)], counter - elif isinstance(symbols[n.args[0]], DVar): - my_output, counter = gen_dvar(counter) - symbols[n] = my_output - e1 = symbols[n.args[0]] - - # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) - return [c], counter - - elif isinstance(n.args[1], Node) and (isinstance(n.args[0], int) or isinstance(n.args[1], float)): - if isinstance(symbols[n.args[1]], TVar): - my_output, counter = gen_tvar(counter) - symbols[n] = my_output - e2 = symbols[n.args[1]] - return [BinConstraintT(my_output, e2, op_eq)], counter - elif isinstance(symbols[n.args[1]], DVar): - my_output, counter = gen_dvar(counter) - symbols[n] = my_output - e2 = symbols[n.args[1]] - - # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) - return [c], counter - - else: - raise NotImplementedError('Method not yet implemented') - - else: - # TODO generate add constraints for scalar addition - raise NotImplementedError('Addition not yet implemented') - - -@register_inference_rule(torch.flatten) -def flatten_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # generate the new variable - flattened, counter = gen_tvar(counter) - symbols[n] = flattened - - input = symbols[n.args[0]] - - # set the default start and end dims - start_dim = 1 - end_dim = -1 - - if len(n.args) > 1: - assert isinstance(n.args[1], int) - start_dim = n.args[1] - - if len(n.args) > 2: - assert isinstance(n.args[2], int) - end_dim = n.args[2] - - c1 = BinConstraintT(input, Dyn, op_eq) - c2 = BinConstraintT(flattened, Dyn, op_eq) - both_dyn = Conj([c1, c2]) - - const = [] - for i in range(1, MAX_TENSOR_RANK + 1): - c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) - const.append(c) - - return [Disj([both_dyn, *const])], counter - - -@register_inference_rule(torch.nn.functional.layer_norm) -def layer_norm_functional(n: Node, symbols, constraints, counter): - """ - We generate the constraint: input = output - """ - assert isinstance(n.args[0], Node) - return gen_layer_norm_constraints(n, n.args[1], symbols, counter) - - -@register_inference_rule(torch.nn.LayerNorm) -def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - Input and output shapes should be equal. - Input should be consistent with the normalized_shape - """ - assert isinstance(n.args[0], Node) - return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter) - - -def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): - output, counter = gen_tvar(counter) - symbols[n] = output - input = symbols[n.args[0]] - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintT(output, Dyn, op_eq) - - c1 = Conj([input_dyn, output_dyn]) - - c2 = [] - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs, counter = gen_tensor_dims(i, counter) - nat_constraints = gen_nat_constraints(new_dims_rhs) - - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), - BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + - add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + - nat_constraints) - c2.append(c_tensor_i) - return [Disj([c1, Disj(c2)])], counter - -@register_inference_rule(torch.nn.Dropout) -@register_inference_rule(torch.nn.ReLU) -def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - Input and output shapes should be equal. - """ - assert isinstance(n.args[0], Node) - output, counter = gen_tvar(counter) - symbols[n] = output - input = symbols[n.args[0]] - assert isinstance(input, TVar) - return [BinConstraintT(input, output, op_eq)], counter - - -@register_inference_rule(torch.nn.Linear) -def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): - """ - Input and output sizes should be the same except for the last dimension - If the input is Dyn, then so should the output - """ - assert isinstance(n.args[0], Node) - return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter) - - -@register_inference_rule("dim") # type: ignore[attr-defined] -def torch_dim_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - my_dim, counter = gen_dvar(counter) - symbols[n] = my_dim - input = symbols[n.args[0]] - - input_dyn = BinConstraintT(input, Dyn, op_eq) - output_dyn = BinConstraintD(my_dim, Dyn, op_eq) - - c1 = [] - - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintD(my_dim, i, op_eq)]) - c1.append(c_tensor_i) - - return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter - - -@register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined] -def torch_linear_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq) - constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter) - return [equality_constraint] + constraints, counter - - -def linear_constraints(n: Node, in_features, out_features, symbols, counter): - linear_output, counter = gen_tvar(counter) - symbols[n] = linear_output - linear_input = symbols[n.args[0]] - - input_dyn = BinConstraintT(linear_input, Dyn, op_eq) - output_dyn = BinConstraintT(linear_output, Dyn, op_eq) - - c1 = Conj([input_dyn, output_dyn]) - - c2 = [] - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - new_dims_rhs_2, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - - c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + - add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) + - nat_constraints) - c2.append(c_tensor_i) - return [Disj([c1, Disj(c2)])], counter - -def add_layer_norm_constraints(input_dim, normalized_dim): - """ - The constraints say that the type has te form: [*, 1024, 1024] - while the normalized_dim have the form [1024, 1024] - Args: - input_dim: Input shape of layer norm - normalized_dim: normalized_dim parameter of the module instance - - """ - - # in this case we return false since there's a pattern mismatch - if len(normalized_dim) > len(input_dim): - return [F()] - - else: - constraints = [] - for i, n in zip(reversed(input_dim), reversed(normalized_dim)): - constraints.append(BinConstraintD(i, n, op_consistency)) - return constraints - - -def add_linear_constraints(dims1, dims2, in_features, out_features): - assert len(dims1) == len(dims2) - constraints = [] - for i in range(len(dims1)): - if i == len(dims1) - 1: - constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) - constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) - else: - constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) - - return constraints - - -@register_inference_rule(torch.reshape) -def reshape_inference_rule(n: Node, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # generate the new variable - my_reshape, counter = gen_tvar(counter) - symbols[n] = my_reshape - - src_var = symbols[n.args[0]] - t2 = n.args[1] - t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] - c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] - c2 = CanReshape(src_var, t2_type) - - return [c1, c2], counter - - -@register_inference_rule(BatchNorm2d) -def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - # generate the new variable - batchnorm_output, counter = gen_tvar(counter) - symbols[n] = batchnorm_output - batchnorm_input = symbols[n.args[0]] - - # dim vars - d1, counter = gen_dvar(counter) - d2, counter = gen_dvar(counter) - d3, counter = gen_dvar(counter) - d4, counter = gen_dvar(counter) - - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - - c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) - c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) - return [c1, c2, *nat_constraints], counter - - -@register_inference_rule(torch.nn.AdaptiveAvgPool2d) -def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - avg_pool, counter = gen_tvar(counter) - - symbols[n] = avg_pool - input_var = symbols[n.args[0]] - - # dim vars - d1, counter = gen_dvar(counter) - d2, counter = gen_dvar(counter) - d3, counter = gen_dvar(counter) - d4, counter = gen_dvar(counter) - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) - - return [c1, c2, *nat_constraints], counter - - -@register_inference_rule(Conv2d) -def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - - my_conv, counter = gen_tvar(counter) - symbols[n] = my_conv - input_var = symbols[n.args[0]] - - # dim vars - [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) - - # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) - c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - - # c2 = DConsistency(module_instance.in_channels, d2) - c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) - - c3 = CalcConv(my_conv, input_var, - module_instance.out_channels, - module_instance.kernel_size, - module_instance.padding, - module_instance.stride, - module_instance.dilation, [d1, d2, d3, d4]) - - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - - return [c1, c2, c3, *nat_constraints], counter - - -@register_inference_rule(torch.nn.MaxPool2d) -def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): - assert isinstance(n.args[0], Node) - maxpool, counter = gen_tvar(counter) - symbols[n] = maxpool - input_var = symbols[n.args[0]] - - # dim vars - [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) - - c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - - c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, - module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) - - nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) - - return [c1, c2, *nat_constraints], counter - - -class ConstraintGenerator: - def __init__(self, traced, graph=None): - self.traced = traced # traced or tracer.root - self.traced_params = dict(self.traced.named_parameters()) - self.constraints = [] - self.symbol_dict = {} - self.graph = traced.graph if hasattr(traced, 'graph') else graph - - - def generate_constraints(self, counter=0): - """ - Iterate through every node and generate constraints - Effect: self.constraints will be populated with the final constraints - """ - graph = self.graph - - all_constraints = [] - - for n in graph.nodes: - (constraints, counter) = self.generate_constraints_node(n, counter) - all_constraints += constraints - - return Conj(all_constraints), counter - - def generate_constraints_node(self, n: Node, counter): - """ - Generate constraints the given node: - Currently supported operations: - - Reshape - - Add - - conv2d - """ - - if n.op == 'placeholder': - x, counter = gen_tvar(counter) - self.symbol_dict[n] = x - - my_type = n.type - - if n.type != Dyn and (not isinstance(n.type, TensorType)): - if n.type == torch.nn.parameter.Parameter: - # since we have a parameter, the shape must be static - assert 'example_value' in n.meta - my_type = TensorType(n.meta['example_value'].size()) - else: - my_type = Dyn - - c1 = BinConstraintT(my_type, x, op_precision) - c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) - return [c1, c2], counter - - elif n.op == 'call_function': - if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) - else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'call_module': - - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _INFERENCE_RULES: - return _INFERENCE_RULES[type(module_instance)](n, - module_instance, - self.symbol_dict, - self.constraints, counter) - else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') - - elif n.op == 'call_method': - if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) - else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'get_attr': - t = self.traced_params.get(n.target, None) - - if isinstance(t, torch.Tensor): - if len(t.shape) > 0: - res = [] - for t in t.shape: - res.append(t) - attr_type = TensorType(res) - output, counter = gen_tvar(counter) - self.symbol_dict[n] = output - return [BinConstraintT(output, attr_type, op_eq)], counter - else: - # scalar? - return [], counter - else: - return [], counter - - elif n.op == 'output': - return [], counter - - else: - raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py b/pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py deleted file mode 100644 index ec40a41f6..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ /dev/null @@ -1,1041 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# mypy: ignore-errors -import copy -import itertools -from pippy.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK -from pippy.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ - Transpose -from pippy.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound -from pippy.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound -from pippy.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool -from pippy.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape -from pippy.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect -from pippy.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching -from pippy.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq -from pippy.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod -from pippy.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar -from pippy.fx.tensor_type import TensorType, Dyn -from typing import Callable, Dict, List - -_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} - - -def register_transformation_rule(call_target): - def register(fn): - if call_target in _TRANSFORMATION_RULES: - raise RuntimeError(f'Transformation rule already registered for {call_target}!') - _TRANSFORMATION_RULES[call_target] = fn - return fn - return register - - -def valid_index(index, dims): - """ - Given a list of dimensions, checks if an index is valid in the list - """ - try: - dims[index] - return T() - except IndexError: - return F() - - -@register_transformation_rule(Transpose) -def transform_transpose(constraint, counter): - """ - Similar to a sequence of two index-selects - """ - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - is_valid_index1 = valid_index(constraint.index1, dims) - is_valid_index2 = valid_index(constraint.index2, dims) - new_dims = copy.deepcopy(dims) - nat_constraints = gen_nat_constraints(dims) - - if is_valid_index1 == T() and is_valid_index2 == T(): - new_dims[constraint.index1] = dims[constraint.index2] - new_dims[constraint.index2] = dims[constraint.index1] - - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index1, is_valid_index2, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) - return transformed_constraint, counter - - -@register_transformation_rule(IndexSelect) -def transform_index_select(constraint, counter): - """ - The constraints consider the given tensor size, checks if the index is valid - and if so, generates a constraint for replacing the input dimension - with the required dimension - """ - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - is_valid_index = valid_index(constraint.index, dims) - nat_constraints = gen_nat_constraints(dims) - - # if the index is valid then replace the input dimension with the new dimension - # otherwise the dimension will not be replaced and the clause will contain False - if is_valid_index == T(): - new_dims = copy.deepcopy((dims)) - new_dims[constraint.index] = constraint.dim_replace - - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) - - # print(constraints) - return transformed_constraint, counter - - -@register_transformation_rule(GetItem) -def transform_get_item(constraint, counter): - """ - generate an equality of the form: - t = [a1, ..., an] - then generate constraints that check if the given index is valid - given this particular tensor size. - If the index is valid, generate a constraint to get the item - Note that we already handled the Dyn input case in the previous - step. - Args: - constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) - counter: variable tracking - Returns: simplified constraints for GetItem - - """ - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - nat_constraints = gen_nat_constraints(dims) - - - is_valid_index = valid_index(constraint.index, dims) - - all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index] - - # if the index is valid, we generate a constraint for getting an item - # otherwise this clause will have been UNSAT due to the wrong index - if is_valid_index == T(): - all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) - - return Conj(all_constraints), counter - -def valid_index_tensor(index, dims): - """ - if the slice instances exceed the length of the dimensions - then this is a type error so we return False - """ - slice_count = 0 - for s in index: - if isinstance(s, slice): - slice_count += 1 - if slice_count > len(dims): - return F() - else: - return T() - -@register_transformation_rule(GetItemTensor) -def transform_get_item_tensor(constraint, counter): - """ - When the index is a tuple, then the output will be a tensor - TODO: we have to check if this is the case for all HF models - - The cases we are covrering here are a tuple with one of: - - slice with default argument - - None - - None appends 1 to the input tensor dimensions - so each occurrence of 'None' increases the rank by 1 - - slice with default arguments does not change the rank - """ - assert isinstance(constraint.index_tuple, tuple) - - - # generate a result tensor of the expected size - dims, counter = gen_tensor_dims(constraint.tensor_size, counter) - nat_constraints = gen_nat_constraints(dims) - - # generate a place-holder list of the right rank - # where "slice" does not contribute to the rank and "None" does - none_c = constraint.index_tuple.count(None) - resulting_tensor_dims = (none_c + len(dims)) * [None] - - dim_index = 0 - for i in range(len(constraint.index_tuple)): - - # append 1 to the right location of the resulting tensor - if constraint.index_tuple[i] is None: - resulting_tensor_dims[i] = 1 - - elif constraint.index_tuple[i] == slice(None, None, None): - pass - - else: - raise NotImplementedError('Method not yet implemented') - - # append the remaining dimensions to the right location - dim_index = 0 - for i in range(len(resulting_tensor_dims)): - if resulting_tensor_dims[i] is None: - resulting_tensor_dims[i] = dims[dim_index] - dim_index += 1 - - # check if the index is valid - is_valid_index = valid_index_tensor(constraint.index_tuple, dims) - - # check if the resulting tensor is within bounds - if len(resulting_tensor_dims) > 4: - return F(), counter - - else: - constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), - *nat_constraints, - is_valid_index] - return Conj(constraints), counter - - -@register_transformation_rule(BinConstraintT) -def generate_binconstraint_t(constraint, counter): - """ - Transform binary constraints for tensors - """ - - # precision constraints - if constraint.op == op_precision: - if constraint.lhs == Dyn: - return T(), counter - elif isinstance(constraint.lhs, TensorType): - is_fully_static = all([d != Dyn for d in constraint.lhs.__args__]) - if is_fully_static: - return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter - else: - new_dims = [] - - for _ in range(len(constraint.lhs.__args__)): - dim, counter = gen_dvar(counter) - new_dims.append(dim) - - new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for - new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ - [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ - [BinConstraintD(1, new_dim, op_leq) for - new_dim in new_dims] - return Conj(new_dim_constraints), counter - - # matching - elif constraint.op == op_matching: - assert isinstance(constraint.rhs, TensorType) - d1 = constraint.rhs.__args__[0] - d2 = constraint.rhs.__args__[1] - d3 = constraint.rhs.__args__[2] - d4 = constraint.rhs.__args__[3] - - conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), - BinConstraintD(d1, Dyn, op_eq), - BinConstraintD(d2, Dyn, op_eq), - BinConstraintD(d3, Dyn, op_eq), - BinConstraintD(d4, Dyn, op_eq)] - return Disj([Conj(conj), - BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter - - elif constraint.op == op_consistency: - c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) - [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) - - return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter - - elif constraint.op == op_leq: - assert isinstance(constraint.rhs, int) - disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] - for i in range(1, constraint.rhs + 1): - dims = [] - for j in range(1, i + 1): - dim_var, counter = gen_dvar(counter) - dims.append(dim_var) - disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) - return Disj(disj), counter - else: - return constraint, counter - - -@register_transformation_rule(BinConstraintD) -def generate_binconstraint_d(constraint, counter): - """ - Transform binary constraints for dimensions - """ - if constraint.op == op_precision: - if isinstance(constraint.lhs, int): - return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter - elif constraint.lhs == Dyn: - return T(), counter - - elif constraint.op == op_consistency: - return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), - BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter - - else: - return constraint, counter - - -@register_transformation_rule(Conj) -def generate_conj(constraint, counter): - """ - Transform conjunctions - """ - new = [] - for c in constraint.conjucts: - new_c, counter = transform_constraint(c, counter) - new.append(new_c) - return Conj(new), counter - - -@register_transformation_rule(Disj) -def generate_disj(constraint, counter): - """ - Transform disjunctions - """ - new = [] - for c in constraint.disjuncts: - new_c, counter = transform_constraint(c, counter) - new.append(new_c) - return Disj(new), counter - - -@register_transformation_rule(TGreatestUpperBound) -def generate_gub(constraint, counter): - """ - Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound - on dimensions - """ - c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), - BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) - - [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) - - return Disj([c1, c2, c3, c4, c5]), counter - - -@register_transformation_rule(DGreatestUpperBound) -def generate_d_gub(constraint, counter): - """ - Transform greatest upper bound for dimensions into equality constraints - """ - c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) - c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) - c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) - return Disj([c1, c2, c3]), counter - - -@register_transformation_rule(CalcConv) -def generate_calc_conv(constraint, counter): - d, counter = gen_tensor_dims(4, counter) - conv_result = TensorType([d[0], d[1], d[2], d[3]]) - - # the convolution result is a tensor of size 4 - c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) - - # the second dimension of the output is equal to the output channels - c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) - - # the input corresponds to the output in the first dimension of the convolution - c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) - - c4, c5 = calc_last_two_dims(constraint, d) - - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) - - return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter - - -@register_transformation_rule(CalcMaxPool) -def generate_calc_maxpool(constraint, counter): - """ - Transform maxpool constraints - """ - d, counter = gen_tensor_dims(4, counter) - maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) - - # the maxpool result is a tensor of size 4 - c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) - - # the input corresponds to the output in the first and second dimension of maxpool - c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) - c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) - c4, c5 = calc_last_two_dims(constraint, d) - - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) - - return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter - - -@register_transformation_rule(CalcProduct) -def generate_calc_product(constraint, counter): - """ - Transform flatten constraints - """ - start = constraint.start - end = constraint.end - dims = constraint.dims_to_flatten - flattened = constraint.flattened - n = len(constraint.dims_to_flatten) - - # this will be evaluated right here - boundary_check = (0 <= start and start < end and end <= n) - - c_boundary = T() if boundary_check else F() - - lhs = dims[0:start] - rhs = dims[end:] - mid = dims[start:end] - - all_possibilities = generate_all_int_dyn_dim_possibilities(mid) - - all_constraints = [] - - for p in all_possibilities: - p = list(p) - # this tells us there is a dynamic variable - contains_dyn = not(all([constraint.op == op_neq for constraint in p])) - if contains_dyn: - mid_var = [Dyn] - total_constraints = lhs + mid_var + rhs - if len(total_constraints) > 4: - all_constraints.append(F()) - else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p)) - else: - new_var, counter = gen_dvar(counter) - mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) - mid_var = [new_var] - total_constraints = lhs + mid_var + rhs - if len(total_constraints) > 4: - all_constraints.append(F()) - else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) - - return Conj([Disj(all_constraints), c_boundary]), counter - - -@register_transformation_rule(CanReshape) -def generate_reshape(constraint, counter): - """ - Transform reshape constraints - """ - d, counter = gen_tensor_dims(4, counter) - - d1 = d[0] - d2 = d[1] - d3 = d[2] - d4 = d[3] - - target = constraint.target.__args__ - - is_fully_static = all([d != Dyn for d in target]) - - # dynamic tensor - c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) - c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) - c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) - c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) - c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) - - d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) - d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) - - d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) - d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) - - d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) - d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) - - d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) - d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) - - nat_d1 = BinConstraintD(0, d1, op_leq) - nat_d2 = BinConstraintD(0, d2, op_leq) - nat_d3 = BinConstraintD(0, d3, op_leq) - nat_d4 = BinConstraintD(0, d4, op_leq) - - if is_fully_static: - # size 1 tensor - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - BinConstraintD(d1, Prod(target), op_eq)]))]) - all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) - - # size 2 tensor - all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) - - # size 3 tensor - all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) - - # size 4 tensor - all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) - - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter - - # then there must be exactly one occurrence of dyn - else: - new_target = [] - - for n in target: - if n != Dyn: - new_target.append(n) - - # tensor 1 - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - is_dim_div_by_target(new_target, d1)]))]) - all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) - - # tensor 2 - c21 = Disj([d1_eq_dyn, d2_eq_dyn]) - c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) - all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) - - # tensor 3 - c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) - c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) - all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) - - # tensor 4 - c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) - c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) - all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) - - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter - - -@register_transformation_rule(ApplyBroadcasting) -def generate_broadcasting(constraint, counter): - """ - Transform broadcasting constraints - """ - e11, e12 = constraint.res1, constraint.res2 - e1, e2 = constraint.input1, constraint.input2 - - e1_dyn = BinConstraintT(e1, Dyn, op_eq) - e2_dyn = BinConstraintT(e2, Dyn, op_eq) - - # Introduce dimensions - e1_equal_e11 = BinConstraintT(e1, e11, op_eq) - e2_equal_e12 = BinConstraintT(e2, e12, op_eq) - - # dyn possibility - e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) - e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) - - # tensor possibility - # generate dimensions to create tensors of size 1 - final_tensor_1_constraint, _, _, nat_dims_1, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) - - # generate dimensions to create tensors of size 2 - final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ - final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) - - # generate dimensions to create tensors of size 3 - final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ - final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) - - # generate dimensions to create tensors of size 4 - final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ - final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) - - final_result = Disj([ - e1_dyn_constraint, - e2_dyn_constraint, - final_tensor_1_constraint, - final_tensor_2_constraint_no_padding, - final_tensor_2_constraint_padding_arg1, - final_tensor_2_constraint_padding_arg2, - final_tensor_3_constraint_no_padding, - final_tensor_3_constraint_padding_arg1, - final_tensor_3_constraint_padding_arg2, - final_tensor_4_constraint_no_padding, - final_tensor_4_constraint_padding_arg1, - final_tensor_4_constraint_padding_arg2 - ]) - - return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter - - -def transform_constraint(constraint: Constraint, counter: int): - """ - Transforms a constraint into a simpler constraint. - Ex: precision and consistency are transformed to equality - Args: - constraint: constraint to be transformed - counter: for variable tracking - - Returns: Constraint - - """ - if type(constraint) in _TRANSFORMATION_RULES: - return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) - - else: - return constraint, counter - - - - -def calc_last_two_dims(constraint, d: List[DVar]): - """ - Generates constraints for the last two dimensions of a convolution or a maxpool output - Args: - constraint: CalcConv or CalcMaxPool - d: The list of output dimensions - - Returns: Constraints for calculating the last two dimensions of the output - - """ - - assert isinstance(constraint, CalcConv) or isinstance(constraint, CalcMaxPool) - - b3 = constraint.matching_constraint[2] - b4 = constraint.matching_constraint[3] - - b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) - b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) - - d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) - d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) - - # transform parameters into tuples incase they are not already - padding = (constraint.padding, constraint.padding) \ - if isinstance(constraint.padding, int) else constraint.padding - kernel = (constraint.kernel, constraint.kernel) \ - if isinstance(constraint.kernel, int) else constraint.kernel - stride = (constraint.stride, constraint.stride) \ - if isinstance(constraint.stride, int) else constraint.stride - dilation = (constraint.dilation, constraint.dilation) \ - if isinstance(constraint.dilation, int) else constraint.dilation - - f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) - f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) - f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) - f4 = BinConstraintD(f3, 1, op_add) - - c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) - - f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) - f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) - f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) - f44 = BinConstraintD(f33, 1, op_add) - - c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) - - return c4, c5 - - -def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): - """ - Generate all possibilities of being equal or not equal to dyn for my_list - Args: - my_list: List of tensor dimensions - - Returns: A list of a list of constraints. Each list of constraints corresponds to - one possibility about the values of the dimension variables - """ - # generate all possibilities of being equal or not equal to dyn for my_list - eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] - neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] - d_possibilities = [] - - for i in zip(eq_possibilities, neq_possibilities): - d_possibilities.append(list(i)) - all_possibilities = list(itertools.product(*d_possibilities)) - return all_possibilities - - -def is_target_div_by_dim(target: List[int], dim: List[DVar]): - """ - Generate constraints to check if the target dimensions are divisible by the input dimensions - Args: - target: Target dimensions - dim: Input dimensions - - Returns: Constraints to check divisibility - - """ - return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) - - -def is_dim_div_by_target(target: List[int], dim: List[DVar]): - """ - Generate constraints to check if the input dimensions is divisible by the target dimensions - Args: - target: Target dimensions - dim: Input dimensions - - Returns: Constraints to check divisibility - - """ - return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) - - -def gen_all_reshape_possibilities(list_of_dims, target): - """ - Consider all possibilities what the input dimensions could be (number or dynamic) - Then generate the appropriate constraints using multiplication or mod depending on the possibility - The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn - for the input. Target is fixed because at most one dimension could be dyn. - We have different cases for this. - - Args: - list_of_dims: The input list of dimensions - target: The tensor we want to reshape to - - Returns: A disjuncition of transformed reshape constraints - - """ - all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) - - all_constraints = [] - - for p in all_possibilities: - to_multiply = [] - - p = list(p) - - for constraint in p: - assert isinstance(constraint, BinConstraintD) - if constraint.op == op_neq: - to_multiply.append(constraint.lhs) - - if not to_multiply: - all_constraints.append(Conj(p)) - - elif len(to_multiply) < len(list_of_dims): - all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) - else: - all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), - Prod(target), op_eq)])) - - return Disj(all_constraints) - - -def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): - """ - Apply broadcasting to the 'index' dimension of tensor_input1. - Args: - tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 - tensor_input2: represents the second input - res1: broadcasted result 1 - res2: broadcasted result 2 - index: the index to broadcast - padding: If padding was used, then tensor_input1[index] does not exist - - Returns: - - """ - if tensor_input1[index] is None: - assert padding - - - if not padding: - # then the inputs are the same length so they all have dimensions at "index" - return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), - BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) - - else: - # we don't set the input dimension to 1, since it doesn't exist. - return Conj([BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) - - -def apply_padding(e1_var: TVar, - e11: BinConstraintT, - e2: BinConstraintT, - e12: BinConstraintT, - d2: List[DVar], - d11: List[DVar], - d12: List[DVar], - counter: int): - """ - We are considering the possibility where one input has less dimensions than - another input, so we apply padding to the broadcasted results - - Args: - e1_var: Variable representing the first input where padding will be - e11: constraint of the form e11 = Tensortype[d1, ..., dn] - e2: constraint of the form e2 = Tensortype[d1, ..., dn] - e12: constraint of the form e11 = Tensortype[d1, ..., dn] - d2: Tensor variables for the second input - d11: Tensor variables for the broadcasted first input - d12: Tensor variables for the broadcasted second input - counter: variable tracking - - Returns: A new constraint whose goal is to apply padding to the broadcasted result - - """ - - res = [] - - # pad the shorter input with None so we can pass it to the broadcasting helper function - for i in range(1, len(d2)): - - d1, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) - - e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) - - simulate_padding = [None] * (len(d2) - i) - - assert len(simulate_padding + d1) == len(d2) - - broadcast_padding = [] - - # for every padding size, we also consider broadcasting - for j in range((len(d2) - i)): - broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) - - # we consider the possibilities for broadcasting for every dimension. Since we already - # padded d1, we do not consider it while broadcasting - all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, - d2[(len(d2) - i):], - d11[(len(d2) - i):], - d12[(len(d2) - i):]) - # combine all constraints into a conjunction - c = Conj([e1, e11, e2, e12, - *broadcast_padding, - all_broadcasting_possibilities, - *nat_constraints - ]) - res.append(c) - - return Disj(res), counter - - -def no_broadcast_dim_with_index(d1: List[DVar], - d2: List[DVar], - d3: List[DVar], - d4: List[DVar], - i: int): - """ - Args: - d1: inpput 1 - d2: inpput 2 - d3: simulated broadcasting for input 1 - d4: simulated broadcasting for input 2 - i: the rank of the resulting tensor addition - - Returns: Constraints for when no broadcasting occurs - """ - return Conj([ - Disj([ - Conj([BinConstraintD(d1[i], 1, op_eq), - BinConstraintD(d2[i], 1, op_eq)]), - - Conj([BinConstraintD(d1[i], 1, op_neq), - BinConstraintD(d2[i], 1, op_neq)])]), - - BinConstraintD(d1[i], d3[i], op_eq), - BinConstraintD(d2[i], d4[i], op_eq)]) - - - -def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): - """ - Generate lists of DVar to represent tensor dimensions - Args: - num_tensors: the required number of tensors - dim_size: the number of dimensions for each tensor - counter: variable tracking - - Returns: A list of a list of tensor dimensions - - """ - res = [] - - for _ in range(num_tensors): - dims, counter = gen_tensor_dims(dim_size, counter) - res.append(dims) - - return res, counter - - -def create_equality_constraints_for_broadcasting(e1: TVar, - e2: TVar, - e11: TVar, - e12: TVar, - d1: List[DVar], - d2: List[DVar], - d11: List[DVar], - d12: List[DVar]): - """ - Create equality constraints for when no broadcasting occurs - Args: - e1: Input 1 - e2: Input 2 - e11: Broadcasted input 1 - e12: Broadcasted input 2 - d1: Variables that store dimensions for e1 - d2: Variables that store dimensions for e2 - d11: Variables that store dimensions for e11 - d12: Variables that store dimensions for e22 - - Returns: Four equality constraints - - """ - - e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) - e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) - e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) - e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) - return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] - - -def gen_consistency_constraints(constraint: Constraint, counter: int): - """ - Args: - constraint: Consistency constraint on tensors - counter: for variable tracking - - Returns: Equality and consistency constraints on dimensions - - """ - - all_constraints = [] - - for i in range(1, MAX_TENSOR_RANK + 1): - new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - new_dims_rhs_2, counter = gen_tensor_dims(i, counter) - - nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - - c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + - [BinConstraintD(d1, d2, op_consistency) for - d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) - - all_constraints.append(c_tensor_i) - - return all_constraints, counter - - -def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): - """ - Args: - constraint: Greatest upper bound on tensors - counter: variable tracking - - Returns: A set of equality constraints and DGreatestUpperBound constraints - - """ - - all_constraints = [] - - for i in range(1, MAX_TENSOR_RANK + 1): - c = [] - dims1, counter = gen_tensor_dims(i, counter) - c1tensor = TensorType(dims1) - - dims2, counter = gen_tensor_dims(i, counter) - c2tensor = TensorType(dims2) - - dims3, counter = gen_tensor_dims(i, counter) - c3tensor = TensorType(dims3) - - c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), - BinConstraintT(constraint.rhs2, c2tensor, op_eq), - BinConstraintT(constraint.res, c3tensor, op_eq)] + \ - gen_nat_constraints(dims1 + dims2 + dims3) - - assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) - for i in range(len(c3tensor.__args__)): - c.append(DGreatestUpperBound(c3tensor.__args__[i], - c1tensor.__args__[i], - c2tensor.__args__[i])) - - all_constraints.append(Conj(c)) - return all_constraints, counter - - -def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): - """ - Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. - We look at all combinations for all dimendions in d1 and d2 - Args: - d1: input1 dimensions - d2: input2 dimensions - d11: broadcasted input1 dimensions - d12: broadcasted input2 dimensions - - Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions - - """ - - size = len(d1) - - res2 = [] - - for i in range(size): - t1 = broadcast_dim(d1, d2, d11, d12, i) - t2 = broadcast_dim(d2, d1, d12, d11, i) - t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) - - res2.append(Disj([t1, t2, t3])) - - return Conj(res2) - - -def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): - """ - Simulates broadcasting on e1 and e2 and returns the results - respectively in e11 and e12. Because of gradual types, - e1 and e2 may not be equal. Similarly, e11 and e12 may not - be equal. e11 and e12 should be guaranteed to be consistent - as they represent the shapes of the tensors to be added after - broadcasting. - Args: - e1: TVar representing the type of input 1 - e2: TVar representing the type of input 2 - e11: TVar representing the representing broadcasted input 1 - e12: TVar representing the representing broadcasted input 2 - i: The rank of the resulting type of addition - counter: for variable tracking - - Returns: Simplified broadcasting constraints - - """ - dims, counter = gen_lists_of_dims(4, i, counter) - [d1, d2, d3, d4] = dims - nat_dims_i = gen_nat_constraints(list(itertools.chain(*dims))) - - initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, - d1, d2, d3, d4) - - [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints - - # without padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, - generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) - - # with padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_padding_arg1, counter = \ - apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) - - final_tensor_constraint_padding_arg2, counter = \ - apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) - - return final_tensor_constraint_no_padding, \ - final_tensor_constraint_padding_arg1, \ - final_tensor_constraint_padding_arg2, nat_dims_i, counter diff --git a/pippy/fx/experimental/migrate_gradual_types/operation.py b/pippy/fx/experimental/migrate_gradual_types/operation.py deleted file mode 100644 index ef7c670bf..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/operation.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# -*- coding: utf-8 -*- -op_add = '+' -op_sub = '-' -op_mul = '*' -op_div = '/' -op_eq = '=' -op_neq = '!=' -op_imp = '=>' -op_matching = '⊳' -op_consistency = '~' -op_precision = '⊑' -op_leq = '≤' -op_lt = '<' -op_gt = '>' -op_mod = '%' diff --git a/pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py b/pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py deleted file mode 100644 index dbf1b4d4c..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr -from pippy.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar -from pippy.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim -from pippy.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator -from pippy.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint -from pippy.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt -from pippy.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod -from pippy.fx.tensor_type import TensorType, Dyn - -try: - import z3 # type: ignore[import] - from pippy.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D - HAS_Z3 = True - - def transform_to_z3(constraint, counter, dimension_dict): - if isinstance(constraint, Conj): - conjuncts = [] - for c in constraint.conjucts: - new_c, counter = transform_to_z3(c, counter, dimension_dict) - conjuncts.append(new_c) - return z3.And(conjuncts), counter - - elif isinstance(constraint, Disj): - disjuncts = [] - for c in constraint.disjuncts: - new_c, counter = transform_to_z3(c, counter, dimension_dict) - disjuncts.append(new_c) - return z3.Or(disjuncts), counter - - elif isinstance(constraint, T): - return True, counter - - elif isinstance(constraint, F): - return False, counter - - elif isinstance(constraint, BinConstraintT): - if constraint.op == op_eq: - lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) - return (lhs == rhs), counter - - else: - raise NotImplementedError('Method not yet implemented') - - elif isinstance(constraint, BinConstraintD): - if constraint.op == op_eq: - - if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): - transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) - transformed_lhs = z3.Bool(constraint.lhs.c) - return transformed_lhs == transformed_rhs, counter - - elif is_dim(constraint.lhs) and is_dim(constraint.rhs): - # with dimension tranformations we consider the encoding - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) - return lhs == rhs, counter - - else: - # then we have an algebraic expression which means that we disregard the - # first element of the encoding - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs == rhs, counter - - # The assumption here is that the LHS and RHS must be dimensions - elif constraint.op == op_neq: - assert is_dim(constraint.lhs) - assert is_dim(constraint.rhs) - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) - if constraint.rhs == Dyn or constraint.lhs == Dyn: - if constraint.rhs == Dyn: - return lhs.arg(0) == 1, counter - elif constraint.lhs == Dyn: - return rhs.arg(0) == 1, counter - - # if one of the instances is a number - elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): - if isinstance(constraint.lhs, int): - return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter - - elif isinstance(constraint.rhs, int): - return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter - - else: - return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter - - - elif constraint.op == op_leq: - # if the dimensions are not dyn, this will come into effect - # there would have been another constraint specifying if a given dimension - # is dyn or not - assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs <= rhs, counter - - elif constraint.op == op_gt: - assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs > rhs, counter - - elif constraint.op == op_lt: - assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) - return lhs < rhs, counter - - else: - raise NotImplementedError('operation not yet implemented') - - else: - raise NotImplementedError('Operation not yet implemented') - - - def transform_var(tensor, counter, dimension_dict): - """ - Transforms tensor variables to a format understood by z3 - Args: - tensor: Tensor variable or a tensor type potentially with variable dimensions - Returns: Transformed variable to a z3 format - - """ - if isinstance(tensor, TensorType): - res = [] - for t in tensor.__args__: - transformed, counter = transform_dimension(t, counter, dimension_dict) - res.append(transformed) - - assert len(res) <= 4 - if len(tensor.__args__) == 1: - return tensor_type.tensor1(res[0]), counter - elif len(tensor.__args__) == 2: - return tensor_type.tensor2(res[0], res[1]), counter - elif len(tensor.__args__) == 3: - return tensor_type.tensor3(res[0], res[1], res[2]), counter - elif len(tensor.__args__) == 4: - return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter - - elif tensor == Dyn: - return z3_dyn, counter - - elif isinstance(tensor, TVar): - return z3.Const(tensor.tvar, tensor_type), counter - - def transform_dimension(dimension, counter, dimension_dict): - """ - Takes a dimension variable or a number and transforms it to a tuple - according to our scheme - Args: - dimension: The dimension to be transformed - counter: variable tracking - - Returns: tuple and the current counter - - """ - if dimension == Dyn: - counter += 1 - return D(0, z3.Int(counter)), counter - elif isinstance(dimension, int): - return D(1, dimension), counter - elif isinstance(dimension, DVar): - if dimension.c in dimension_dict: - return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter - else: - counter += 1 - dimension_dict[dimension.c] = counter - return D(z3.Int(counter), z3.Int(dimension.c)), counter - - - def transform_algebraic_expression(expr, counter, dimension_dict): - """ - Transforms an algebraic expression to z3 format - Args: - expr: An expression is either a dimension variable or an algebraic-expression - - - Returns: the transformed expression - - """ - assert is_algebraic_expression(expr) or is_dim(expr) - - if is_dim(expr): - transformed, counter = transform_dimension(expr, counter, dimension_dict) - return transformed.arg(1), counter - - elif isinstance(expr, Prod): - - dims = [] - for dim in expr.products: - assert is_dim(dim) - d, counter = transform_dimension(dim, counter, dimension_dict) - dims.append(d.arg(1)) - return z3.Product(dims), counter - - elif is_algebraic_expression(expr): - - lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) - - if expr.op == op_sub: - c = lhs - rhs - - elif expr.op == op_add: - c = lhs + rhs - - elif expr.op == op_div: - c = lhs / rhs - - elif expr.op == op_mul: - c = lhs * rhs - - elif expr.op == op_mod: - c = lhs % rhs - - else: - raise NotImplementedError('operation not yet implemented') - - return c, counter - - else: - raise RuntimeError - - - def transform_all_constraints(traced, counter=0): - """ - Given a trace, generates constraints and transforms them to z3 format - - """ - dimension_dict = {} # type: ignore[var-annotated] - - generator = ConstraintGenerator(traced) - new_constraints, counter = generator.generate_constraints(counter) - - # print(new_constraints.conjucts[0]) - # print(*new_constraints.conjucts, sep='\n') - - # transform precision, matching, consistency till obtaining a fixed point - new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - # print(new_constraints) - # print(new_constraints.conjucts) - # new_constraints.conjucts = new_constraints.conjucts[:-1] - # print(*new_constraints.conjucts, sep='\n') - - transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) - # print(transformed) - return transformed - - def iterate_till_fixed_point(constraints, counter): - """ - Transform constraints till reaching a fixed point - """ - old_c = None - while old_c != constraints: - old_c = constraints - constraints, counter = transform_constraint(constraints, counter) - return constraints, counter - - def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): - """ - Takes a node and a graph and generates two sets of constraints. - One set constraints the node's constraints and another set - constraints the negation of the node's constraints - Args: - tracer_root: the root for getting the module instances - graph: the graph so far in the tracing process - node: node that represents a conditional - counter: variable tracking - - Returns: Two sets of constraints. One with a conjunction with the - the conditional constraint and the other with a conjunction with - its negation. - - """ - dimension_dict = {} # type: ignore[var-annotated] - - generator = ConstraintGenerator(tracer_root, graph) - new_constraints, counter = generator.generate_constraints(counter) - - condition_constraint = new_constraints.conjucts[-1] - - # we know the constraint is a conjunction where the last constraint is about the conditional - # so remove the last constraint - new_constraints.conjucts = new_constraints.conjucts[:-1] - - # transform precision, matching, consistency till obtaining a fixed point - new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - - - # since the function returns a list of one element, we get the first element - # we are only interested in the RHS in this case because the LHS just stores - # the result - - # we make sure the constraint is of the form: - # c = b where b is a boolean expression - # and we consider b (constraint.rhs) for transformation - assert isinstance(condition_constraint.lhs, BVar) - assert is_bool_expr(condition_constraint.rhs) - condition_constraint_rhs = condition_constraint.rhs - - # transform the condition constraint - condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) - - transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) - - transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) - - negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) - - return z3.And([transformed, transformed_condition_constraint]),\ - z3.And([transformed, negation_transformed_condition_constraint]) - - - def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): - """ - Given an IR and a node representing a conditional, evaluate the conditional - and its negation - Args: - tracer_root: Tracer root for module instances - node: The node to be evaluated - - Returns: the results of evaluating the condition and the negation with - the rest of the constraints - - """ - - transformed_positive, transformed_negative = \ - transform_all_constraints_trace_time(tracer_root, graph, node, counter) - - s = z3.Solver() - s.add(transformed_positive) - if user_constraints is not None: - s.add(user_constraints) - condition = s.check() - - s = z3.Solver() - s.add(transformed_negative) - if user_constraints is not None: - s.add(user_constraints) - negation = s.check() - return condition, negation - -except ImportError: - HAS_Z3 = False diff --git a/pippy/fx/experimental/migrate_gradual_types/util.py b/pippy/fx/experimental/migrate_gradual_types/util.py deleted file mode 100644 index 89ab32648..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/util.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ - BVar -from pippy.fx.experimental.migrate_gradual_types.operation import op_leq - - -def gen_tvar(curr): - """ - Generate a tensor variable - :param curr: The current counter - :return: a tensor variable and the updated counter - """ - curr += 1 - return TVar(curr), curr - - -def gen_dvar(curr): - """ - Generate a dimension variable - :param curr: the current counter - :return: a dimension variable and an updated counter - """ - curr += 1 - return DVar(curr), curr - -def gen_bvar(curr): - """ - Generate a boolean variable - :param curr: the current counter - :return: a boolean variable and an updated counter - """ - curr += 1 - return BVar(curr), curr - -def gen_tensor_dims(n, curr): - """ - Generate a list of tensor dimensions - :param n: the number of dimensions - :param curr: the current counter - :return: a list of dimension variables and an updated counter - """ - dims = [] - for _ in range(n): - dvar, curr = gen_dvar(curr) - dims.append(dvar) - return dims, curr - - -def gen_nat_constraints(list_of_dims): - """ - Generate natural number constraints for dimensions - """ - return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/pippy/fx/experimental/migrate_gradual_types/z3_types.py b/pippy/fx/experimental/migrate_gradual_types/z3_types.py deleted file mode 100644 index 851e4bc89..000000000 --- a/pippy/fx/experimental/migrate_gradual_types/z3_types.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -try: - import z3 # type: ignore[import] - HAS_Z3 = True - # dynamic type - dyn = z3.DeclareSort('Dyn') - dyn_type = z3.Const('dyn', dyn) - - # dimension - dim = z3.Datatype('dim') - dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) - dim = dim.create() - - # tensors - tensor_type = z3.Datatype('TensorType') - tensor_type.declare('Dyn', ('dyn', dyn)) - tensor_type.declare('tensor1', ('0', dim)) - tensor_type.declare('tensor2', ('0', dim), ('1', dim)) - tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) - tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) - tensor_type = tensor_type.create() - - # create dimension - D = dim.dim - - z3_dyn = tensor_type.Dyn(dyn_type) - - -except ImportError: - HAS_Z3 = False diff --git a/pippy/fx/experimental/normalize.py b/pippy/fx/experimental/normalize.py deleted file mode 100644 index c92dbf973..000000000 --- a/pippy/fx/experimental/normalize.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import operator -from typing import Any, Callable, Dict, Tuple, Optional - -import torch -import pippy.fx -import pippy.fx as fx -from pippy.fx import Transformer, Proxy -from pippy.fx.node import Argument, Target, Node, map_aggregate -from pippy.fx.operator_schemas import ( - normalize_module, - normalize_function, - create_type_hint, -) - -from .schema_type_annotation import AnnotateTypesWithSchema - - -class NormalizeArgs(Transformer): - """ - Normalize arguments to Python targets. This means that - `args/kwargs` will be matched up to the module/functional's - signature and rewritten to exclusively kwargs in positional order - if `normalize_to_only_use_kwargs` is true. Also populates default - values. Does not support positional-only parameters or varargs - parameters (*args, **kwargs). - - If the nodes have 'type' metadata, it will use it to disambiguate - overloads. Otherwise, it will throw an error. - - Example usage: - m = torchvision.models.resnet18() - traced = pippy.fx.symbolic_trace(m) - traced = NormalizeArgs(traced).transform() - """ - - def __init__( - self, module: pippy.fx.GraphModule, normalize_to_only_use_kwargs: bool = True - ): - super().__init__(module) - self.node_map: Dict[Proxy, Node] = {} - self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs - - def run_node(self, n: Node) -> Any: - args, kwargs = self.fetch_args_kwargs_from_env(n) - - def get_type(arg): - if isinstance(arg, fx.Node): - return n.meta["type"] if "type" in n.meta else None - return type(arg) - - arg_types = map_aggregate(n.args, get_type) - assert isinstance(arg_types, tuple) - arg_types = tuple([create_type_hint(i) for i in arg_types]) - kwarg_types = {k: get_type(v) for k, v in kwargs.items()} - if n.op == "call_function": - out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) - else: - out = super().run_node(n) - if n.op != "output": - self.node_map[out] = n - out.node.meta = n.meta - out.node.type = n.type - return out - - def call_function( - self, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - arg_types: Optional[Tuple[Any, ...]] = None, - kwarg_types: Optional[Dict[str, Any]] = None, - ): - assert callable(target) - new_args_and_kwargs = normalize_function( - target, - args, # type: ignore[arg-type] - kwargs, - arg_types, # type: ignore[arg-type] - kwarg_types, - self.normalize_to_only_use_kwargs, - ) - if new_args_and_kwargs: - new_args, new_kwargs = new_args_and_kwargs - return self.tracer.create_proxy( - "call_function", target, new_args, new_kwargs - ) - else: - return super().call_function(target, args, kwargs) - - def call_module( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] - ): - assert isinstance(target, str) - new_args_and_kwargs = normalize_module( - self.module, - target, - args, # type: ignore[arg-type] - kwargs, - self.normalize_to_only_use_kwargs, - ) - if new_args_and_kwargs: - new_args, new_kwargs = new_args_and_kwargs - return super().call_module(target, new_args, new_kwargs) - else: - return super().call_module(target, args, kwargs) - - -class NormalizeOperators(AnnotateTypesWithSchema): - """ - Normalize callsites that are different ways of "spelling" the same - invocation into a single, canonical call. Currently supports: - - 1. Normalize operators (e.g. operator.add) to the `torch` ops they - ultimately invoke (e.g. torch.add) when it is possible to statically - reason that - - Example usage: - - m = torchvision.models.resnet18() - - traced = pippy.fx.symbolic_trace(m) - - traced = NormalizeOperators(traced).transform() - """ - - binary_magic_method_remap: Dict[ - Callable[[Any, Any], Any], Callable[[Any, Any], Any] - ] = { - torch.add: operator.add, - torch.mul: operator.mul, - torch.sub: operator.sub, - torch.div: operator.truediv, - torch.floor_divide: operator.floordiv, - torch.remainder: operator.mod, - torch.eq: operator.eq, - torch.ne: operator.ne, - torch.lt: operator.lt, - torch.le: operator.le, - torch.gt: operator.gt, - torch.ge: operator.ge, - } - - def call_function( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] - ): - # Normalize operators according to the magic methods implemented on tensors here: - # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 - - assert callable(target) - - if target in self.binary_magic_method_remap: - if len(args) != 2: - return super().call_function(target, args, kwargs) - lhs, rhs = args - - return super().call_function( - target=self.binary_magic_method_remap[target], - args=(lhs, rhs), - kwargs={}, - ) - - return super().call_function(target, args, kwargs) diff --git a/pippy/fx/experimental/optimization.py b/pippy/fx/experimental/optimization.py deleted file mode 100644 index 2f9eb07d8..000000000 --- a/pippy/fx/experimental/optimization.py +++ /dev/null @@ -1,406 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import pippy.fx as fx -from pippy.fx.node import Argument, Target -from torch.nn.utils.fusion import fuse_conv_bn_eval -from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast -import torch -import torch.nn as nn -import torch.nn.functional as F -from pippy.fx.passes.shape_prop import ShapeProp -import copy -from collections import defaultdict -import torch.utils.mkldnn as th_mkldnn -import operator -import time -import logging -from enum import Enum - -def _parent_name(target : str) -> Tuple[str, str]: - """ - Splits a qualname into parent path and last atom. - For example, `foo.bar.baz` -> (`foo.bar`, `baz`) - """ - *parent, name = target.rsplit('.', 1) - return parent[0] if parent else '', name - -# Works for length 2 patterns with 2 modules -def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): - if len(node.args) == 0: - return False - nodes: Tuple[Any, fx.Node] = (node.args[0], node) - for expected_type, current_node in zip(pattern, nodes): - if not isinstance(current_node, fx.Node): - return False - if current_node.op != 'call_module': - return False - if not isinstance(current_node.target, str): - return False - if current_node.target not in modules: - return False - if type(modules[current_node.target]) is not expected_type: - return False - return True - - -def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): - assert(isinstance(node.target, str)) - parent_name, name = _parent_name(node.target) - modules[node.target] = new_module - setattr(modules[parent_name], name, new_module) - -def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module: - """ - Fuses convolution/BN layers for inference purposes. Will deepcopy your - model by default, but can modify the model inplace as well. - """ - patterns = [(nn.Conv1d, nn.BatchNorm1d), - (nn.Conv2d, nn.BatchNorm2d), - (nn.Conv3d, nn.BatchNorm3d)] - if not inplace: - model = copy.deepcopy(model) - fx_model = fx.symbolic_trace(model) - modules = dict(fx_model.named_modules()) - new_graph = copy.deepcopy(fx_model.graph) - - for pattern in patterns: - for node in new_graph.nodes: - if matches_module_pattern(pattern, node, modules): - if len(node.args[0].users) > 1: # Output of conv is used by other nodes - continue - conv = modules[node.args[0].target] - bn = modules[node.target] - if not bn.track_running_stats: - continue - fused_conv = fuse_conv_bn_eval(conv, bn) - replace_node_module(node.args[0], modules, fused_conv) - node.replace_all_uses_with(node.args[0]) - new_graph.erase_node(node) - return fx.GraphModule(fx_model, new_graph) - -def remove_dropout(model: nn.Module) -> nn.Module: - """ - Removes all dropout layers from the module. - """ - fx_model = fx.symbolic_trace(model) - - class DropoutRemover(fx.Transformer): - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if isinstance(self.submodules[target], nn.Dropout): - assert len(args) == 1 - return args[0] - else: - return super().call_module(target, args, kwargs) - return DropoutRemover(fx_model).transform() - -def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): - """ - Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. - """ - new_graph = fx.Graph() - env: Dict[fx.Node, fx.Node] = {} - for input in inputs: - new_node = new_graph.placeholder(input.name) - env[input] = new_node - for node in nodes: - new_node = new_graph.node_copy(node, lambda x: env[x]) - env[node] = new_node - new_graph.output([env[output] for output in outputs]) - new_graph.lint() - return fx.GraphModule(orig_module, new_graph) - -mkldnn_supported = [ - nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, - torch.relu, torch.transpose, torch.sigmoid, - F.relu, F.avg_pool2d, F.adaptive_avg_pool2d -] -# These are operators that may not be convertible into MKLDNN ops (e.g. the -# args are scalar values). Thus, we only include them in the subgraph if their -# arguments are already in MKLDNN. -# TODO: Determine whether this can be removed after type inference. -mkldnn_supported_unknown = [operator.add, operator.mul] -mkldnn_map = { - nn.Conv2d: th_mkldnn.MkldnnConv2d, - nn.Linear: th_mkldnn.MkldnnLinear, - nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) -} - - -def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): - """ - For each node, if it's a module that can be preconverted into MKLDNN, - then we do so and create a mapping to allow us to convert from the MKLDNN - version of the module to the original. - """ - old_modules: Dict[nn.Module, nn.Module] = {} - for node in nodes: - if node.op == 'call_module': - assert(isinstance(node.target, str)) - cur_module = modules[node.target] - if type(cur_module) in mkldnn_map: - new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) - assert(isinstance(new_module, nn.Module)) - old_modules[new_module] = copy.deepcopy(cur_module) - replace_node_module(node, modules, new_module) - return old_modules - -def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): - """ - Maps each module that's been changed with `modules_to_mkldnn` back to its - original. - """ - for node in nodes: - if node.op == 'call_module': - assert(isinstance(node.target, str)) - cur_module = modules[node.target] - if cur_module in old_modules: - replace_node_module(node, modules, old_modules[cur_module]) - -class MklSubgraph: - def __init__(self, fx_graph: fx.Graph): - self.fx_graph = fx_graph - self.nodes: List[fx.Node] = [] - self.start_nodes: List[fx.Node] = [] - self.end_nodes: List[fx.Node] = [] - -def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): - """ - This generates a heuristic that can be passed into `optimize_for_inference` that - determines whether a subgraph should be run in MKL by running it with the example_inputs. - - Example usage: - heuristic = gen_mkl_autotuner(example_inputs, iters=10) - fast_model = optimization.optimize_for_inference(model, heuristic) - """ - fx_model = None - old_modules = None - - def use_mkl_heuristic(graph: MklSubgraph) -> bool: - nonlocal fx_model, old_modules - input_nodes = graph.start_nodes - if fx_model is None: - fx_model = graph.fx_graph.owning_module - old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] - ShapeProp(fx_model).propagate(example_inputs) - sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] - output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) - submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) - - def benchmark(f): - for _ in range(warmup): - f() - begin = time.time() - for _ in range(iters): - out = f() - return time.time() - begin - - mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) - - reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) - no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) - return mkl_time < no_mkl_time - return use_mkl_heuristic - -def use_mkl_length(graph: MklSubgraph) -> bool: - """ - This is a heuristic that can be passed into `optimize_for_inference` that - determines whether a subgraph should be run in MKL by checking if there - are more than 2 nodes in it - """ - return len(graph.nodes) > 2 - -class UnionFind: - def __init__(self, n): - self.parent: List[Optional[int]] = [None] * n - self.size: List[int] = [0] * n - - def make_set(self, v: int): - self.parent[v] = v - self.size[v] = 1 - - def find(self, v: int) -> int: - par = self.parent[v] - if v == par: - return v - assert(par is not None) - self.parent[v] = self.find(par) - return cast(int, self.parent[v]) - - def join(self, a: int, b: int): - a, b = self.find(a), self.find(b) - if a == b: - return a - if self.size[a] < self.size[b]: - a, b = b, a - self.parent[b] = a - self.size[a] += self.size[b] - -def optimize_for_inference( - model: torch.nn.Module, - pass_config: Optional[Dict[str, Any]] = None, - tracer: Type[fx.Tracer] = fx.Tracer -) -> torch.nn.Module: - """ - Performs a set of optimization passes to optimize a model for the - purposes of inference. Specifically, the passes that are run are: - 1. Conv/BN fusion - 2. Dropout removal - 3. MKL layout optimizations - - The third optimization takes a function `use_mkl_heuristic` that's used - to determine whether a subgraph should be explicity run in MKL layout. - - Note: As FX does not currently handle aliasing, this pass currently - assumes nothing aliases. If that isn't true, use at your own risk. - """ - default_pass_config = { - "conv_bn_fuse": True, - "remove_dropout": True, - "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, - } - if pass_config is None: - pass_config = {} - default_pass_config.update(pass_config) - - if default_pass_config["conv_bn_fuse"]: - model = fuse(model) - if default_pass_config["remove_dropout"]: - model = remove_dropout(model) - if default_pass_config["mkldnn_layout_optimize"] is False: - return model - if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): - raise RuntimeError("mkldnn_layout_optimize config is not a dict") - if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: - raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") - use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] - - cur_tracer = tracer() - fx_graph = cur_tracer.trace(copy.deepcopy(model)) - fx_model = fx.GraphModule(cur_tracer.root, fx_graph) - modules: Dict[str, nn.Module] = dict(model.named_modules()) - - class MklSupport(Enum): - NO = 1 - YES = 2 - UNKNOWN = 3 - - # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. - # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. - # However, if it's in `mkldnn_supported_unknown`, then we only treat it as - # a MKLDNN node if its inputs are MKLDNN nodes. - for node in list(fx_graph.nodes): - supports_mkldnn = MklSupport.NO - if node.op == 'call_module': - cur_module = modules[node.target] - if type(cur_module) in mkldnn_supported: - supports_mkldnn = MklSupport.YES - sample_parameter = next(cur_module.parameters(), None) - if sample_parameter is not None: - assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules" - assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules" - elif node.op == 'call_function': - if node.target in mkldnn_supported: - supports_mkldnn = MklSupport.YES - elif node.target in mkldnn_supported_unknown: - supports_mkldnn = MklSupport.UNKNOWN - - if supports_mkldnn != MklSupport.NO: - if supports_mkldnn == MklSupport.UNKNOWN: - if not any([arg.target == 'to_dense' for arg in node.args]): - continue - with fx_graph.inserting_before(node): - mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) - - node.args = cast(Tuple[fx.node.Argument], mkldnn_args) - - with fx_graph.inserting_after(node): - dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) - node.replace_all_uses_with(dense_x) - dense_x.args = (node,) - - # Does pre-conversion of all modules into MKLDNN (when possible) - old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) - fx_graph.old_modules = old_modules # type: ignore[attr-defined] - - # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b - for node in fx_graph.nodes: - if node.op == 'call_method' and node.target == 'to_dense': - prv_node = node.args[0] - users = list(node.users) - for user in users: - if user.op == 'call_method' and user.target == 'to_mkldnn': - user.replace_all_uses_with(prv_node) - fx_graph.erase_node(user) - if len(node.users) == 0: - fx_graph.erase_node(node) - - - num_nodes = len(fx_graph.nodes) - uf = UnionFind(num_nodes) - - def get_color(n): - if hasattr(n, 'color'): # Current node is part of a MKL subgraph - return uf.find(n.color) - if hasattr(n, 'start_color'): # Current node is input to MKL subgraph - return uf.find(n.start_color) - return None - - - # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists - # of input nodes (which are only `to_mkldnn` calls), output nodes - # (`to_dense` calls), and intermediate nodes, which are run entirely on - # MKLDNN layout tensors. - # - # Specifically, this code does a flood fill on a directed acyclic graph - # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). - # If every node only had one input, this would be sufficient. However, in - # the case that a node has multiple inputs coming from different start - # nodes (i.e. colors), we need to join these 2 colors into 1. That's done - # using a Disjoint Set Union. - for cur_idx, node in enumerate(fx_graph.nodes): - if node.op == 'call_method' and node.target == 'to_mkldnn': - node.start_color = cur_idx - uf.make_set(cur_idx) - elif node.op == 'call_method' and node.target == 'to_dense': - assert(get_color(node.args[0]) is not None) - node.end_color = get_color(node.args[0]) - else: - cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] - - if len(cur_colors) == 0: - continue - assert(not any(i is None for i in cur_colors)) - cur_colors = sorted(cur_colors) - node.color = cur_colors[0] - for other_color in cur_colors[1:]: - uf.join(cur_colors[0], other_color) - - - mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) - for node in fx_graph.nodes: - if hasattr(node, 'color'): - mkldnn_graphs[uf.find(node.color)].nodes.append(node) - if hasattr(node, 'start_color'): - mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) - if hasattr(node, 'end_color'): - mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) - - - # Now that we have all the subgraphs, we need to decide which MKLDNN - # subgraphs we actually want to keep in MKLDNN. - for graph in mkldnn_graphs.values(): - if not use_mkl_heuristic(graph): - for node in graph.start_nodes + graph.end_nodes: - prv = node.args[0] - node.replace_all_uses_with(prv) - fx_graph.erase_node(node) - reset_modules(graph.nodes, modules, old_modules) - - mkldnn_conversions = 0 - for node in fx_graph.nodes: - if node.target == 'to_mkldnn' or node.target == 'to_dense': - mkldnn_conversions += 1 - - logging.getLogger(__name__).info(f"mkldnn conversions: {mkldnn_conversions}") - fx_graph.lint() - result = fx.GraphModule(model, fx_graph) - return result diff --git a/pippy/fx/experimental/partitioner_utils.py b/pippy/fx/experimental/partitioner_utils.py deleted file mode 100644 index 334ef5d94..000000000 --- a/pippy/fx/experimental/partitioner_utils.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from enum import Enum -from typing import NamedTuple, Dict, List, Set - -from pippy.fx.node import Node, map_arg - - -class Partition: - """Partition class contains all the information about an individual partition. - It also provides necessary methods for manipulation the partition. - """ - - def __init__(self, partition_id: int) -> None: - self.nodes: Set[Node] = set() - self.partition_id = partition_id - self.parents: Set["Partition"] = set() - self.children: Set["Partition"] = set() - self.bfs_level: int = -1 - self.used_mem_bytes: int = 0 - self.logical_device_ids: List[int] = [] - - def __str__(self): - return str(self.partition_id) - - def recalculate_mem_size(self): - self.used_mem_bytes = 0 - for node in self.nodes: - self.used_mem_bytes += get_extra_size_of(node, self.nodes) - - def add_node(self, node): - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Add current node's input nodes if they are placeholder or constants - for n in input_nodes: - if n.op in {"placeholder", "get_attr"}: - self.nodes.add(n) - self.nodes.add(node) - self.recalculate_mem_size() - - def remove_node(self, node): - # Remove a node only if the node is in the partition - if node in self.nodes: - self.nodes.remove(node) - # Collect the node's input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Check if an input node is a placeholder or get_attr, - # and this input node is not used by some other nodes in this partition, - # the remove this input node - for input_node in input_nodes: - if all( - [n not in self.nodes for n in input_node.users] - ) and input_node.op in {"placeholder", "get_attr"}: - self.nodes.remove(input_node) - self.recalculate_mem_size() - - -class Device(NamedTuple): - name: str - available_mem_bytes: int - logical_id: int - - -class NodeLatency(NamedTuple): - # Latency due to the memory bandwidth - mem_latency_sec: float - # Latency due to the computation - computer_latency_sec: float - - -class PartitionLatency(NamedTuple): - # Sum of all nodes' memory latency on the critical path - mem_latency_sec: float - # Sum of all nodes' compute latency on the critical path - computer_latency_sec: float - # Latency of the critical path - overall_latency_sec: float - - -class PartitionMode(Enum): - size_based = 0 - sparse_nn = 1 - cost_aware = 2 - kl_based = 3 - aot_based = 4 - - -class PartitionerConfig(NamedTuple): - devices: List[Device] - mode: PartitionMode = PartitionMode.size_based - transfer_rate_bytes_per_sec: float = 0.0 - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - node_to_partition_mapping: Dict[Node, int] = {} - partition_to_logical_device_mapping: Dict[int, List[int]] = {} - # Saturate host by replicating partitions to the remaining idle devices. - saturate_host: bool = False - - -def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: - """Given a node and a set of nodes, - this function return the extra size that needed - if this node is included in this set. - """ - # Find all its input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Calculate total size of related nodes - total_size_of_input_nodes = 0 - for n in input_nodes: - # Make sure this node hasn't been in this set yet - if n not in nodes: - size_bytes = getattr(n, "size_bytes", None) - if size_bytes: - total_size_of_input_nodes += size_bytes.output_size - else: - raise RuntimeError("node has no size_bytes attr") - # Don't forget the op node itself - size_bytes = getattr(node, "size_bytes", None) - if size_bytes: - total_size_of_input_nodes += size_bytes.total_size - else: - raise RuntimeError("node has no size_bytes attr") - return total_size_of_input_nodes - - -def get_latency_of_one_partition( - partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] -) -> PartitionLatency: - """Given a partiton and its nodes' latency, return a PartitionLatency for this partition""" - - def get_top_nodes(partition: Partition) -> List[Node]: - """Given a partition, return a list of nodes on the top bfs level""" - top_nodes: List[Node] = [] - for node in partition.nodes: - # Skip placeholder and get_attr nodes - if node.op in {"placeholder", "get_attr"}: - continue - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # If a node has no input nodes in this partition, - # or its input nodes in this partition are placeholders and get_attrs - # this node is on the top bfs level in this partition - if not any( - [ - n in partition.nodes and n.op not in {"placeholder", "get_attr"} - for n in input_nodes - ] - ): - top_nodes.append(node) - return top_nodes - - def dfs_helper(node: Node, partition_latency) -> PartitionLatency: - """Given a top node of a partition, this function returns - the latency of the critical path in the partition - """ - node_latency = node_to_latency_mapping[node] - # Calculate the current overall latency of the partition - overall_latency_sec = partition_latency.overall_latency_sec + max( - node_latency.computer_latency_sec, node_latency.mem_latency_sec - ) - # Update the mem latency of this path - mem_latency_sec = ( - partition_latency.mem_latency_sec + node_latency.mem_latency_sec - ) - # Update the compute latency of this path - computer_latency_sec = ( - partition_latency.computer_latency_sec + node_latency.computer_latency_sec - ) - # Get all users of this node that are in this partition - users = set(node.users).intersection(partition.nodes) - if users: - max_latency = PartitionLatency( - mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 - ) - for n in users: - # Get new partition latency recursively - new_partition_latency = dfs_helper( - n, - PartitionLatency( - mem_latency_sec, computer_latency_sec, overall_latency_sec - ), - ) - if ( - new_partition_latency.overall_latency_sec - > max_latency.overall_latency_sec - ): - max_latency = new_partition_latency - return max_latency - # If there is no user, the node is at bottom of the partition - return PartitionLatency( - mem_latency_sec, computer_latency_sec, overall_latency_sec - ) - - # Main part starts - # Get all top level nodes of this partition - top_nodes = get_top_nodes(partition) - critical_path_latency = PartitionLatency( - mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 - ) - # Go through all top nodes and find the largest latency (critical pass latency) - for node in top_nodes: - partition_latency = dfs_helper( - node, - PartitionLatency( - mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 - ), - ) - if ( - partition_latency.overall_latency_sec - > critical_path_latency.overall_latency_sec - ): - critical_path_latency = partition_latency - return critical_path_latency - - -def get_partition_to_latency_mapping( - partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] -) -> Dict[Partition, PartitionLatency]: - """Given all the partitions and node_to_latency_mapping dictionary, - return a mapping dictionary of each partition to its overall latency - """ - partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} - # Go through each partition and get its latency - for partition in partitions: - partition_latency = get_latency_of_one_partition( - partition, node_to_latency_mapping - ) - partition_to_latency_mapping[partition] = partition_latency - return partition_to_latency_mapping - - -def get_comm_latency_between( - parent_partition: Partition, - child_partition: Partition, - transfer_rate_bytes_per_sec: float, -): - """Given two partitions (parent and child), - calculate the communication latency between the two. - """ - # If two partitions are on the same device, the comm latency is 0. - if ( - parent_partition.logical_device_ids != [] - and child_partition.logical_device_ids != [] - and parent_partition.logical_device_ids == child_partition.logical_device_ids - ): - return 0.0 - # Keep tracking the communication size between parent and child - comm_size = 0 - # Keep tracking all the counted node - visited_nodes = set() - # Go through all nodes in the child partition - # If a node has input nodes from the parent partition, - # the output size of those input nodes will be counted - # and added to comm_size - for node in child_partition.nodes: - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - for n in input_nodes: - if n in parent_partition.nodes and n not in visited_nodes: - size_bytes = getattr(n, "size_bytes", None) - if size_bytes is not None: - comm_size += size_bytes.output_size - visited_nodes.add(n) - return comm_size / transfer_rate_bytes_per_sec - - -def get_latency_of_partitioned_graph( - partitions: List[Partition], - partition_to_latency_mapping: Dict[Partition, PartitionLatency], - transfer_rate_bytes_per_sec: float, -): - """Given all paritions in a graph, find the critical path among all partitions - and return its latency as the latency of the whole graph - """ - - def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: - """This function helps to recursively get the latency of a path of partitions""" - # Update latency by adding current partition's latency - latency_so_far_sec += partition_to_latency_mapping[ - partition - ].overall_latency_sec - children = partition.children - if partition.children: - max_latency_sec = 0.0 - for child in partition.children: - # Calculate latency between - comm_latency_sec = get_comm_latency_between( - partition, child, transfer_rate_bytes_per_sec - ) - new_latency_sec = dfs_helper( - child, latency_so_far_sec + comm_latency_sec - ) - if new_latency_sec > max_latency_sec: - max_latency_sec = new_latency_sec - return max_latency_sec - return latency_so_far_sec - - def get_top_partitions(partitions: List[Partition]) -> List[Partition]: - """This function is to return all the partitions without parents - as the starting points of all the paths - """ - top_partitions = [] - for partition in partitions: - # If a partition has no parents, then it is a top partition - if len(partition.parents) == 0: - top_partitions.append(partition) - return top_partitions - - top_partitions = get_top_partitions(partitions) - critical_path_latency_sec = 0.0 - for partition in top_partitions: - latency_sec = dfs_helper(partition, 0.0) - if latency_sec > critical_path_latency_sec: - critical_path_latency_sec = latency_sec - return critical_path_latency_sec diff --git a/pippy/fx/experimental/proxy_tensor.py b/pippy/fx/experimental/proxy_tensor.py deleted file mode 100644 index f223d1290..000000000 --- a/pippy/fx/experimental/proxy_tensor.py +++ /dev/null @@ -1,683 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import contextlib -import functools -import inspect -import operator -import weakref -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.utils._pytree as pytree -from torch._dispatch.python import enable_python_dispatcher -from torch._subclasses import FakeTensor -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode - -import pippy -import pippy.fx as fx -from pippy.fx import Proxy -from pippy.fx import Tracer, GraphModule -from pippy.fx.passes.shape_prop import _extract_tensor_metadata -from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat - -__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "get_proxy", "has_proxy"] -aten = torch.ops.aten -prim = torch.ops.prim - -CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {} - -CONSTANT_NUMEL_LIMIT = 1 - - -def fake_signature(fn, nargs): - """FX gets confused by varargs, de-confuse it""" - argnames = ",".join(f"arg{i}" for i in range(nargs)) - return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) - -@contextmanager -def decompose(decomposition_table): - global CURRENT_DECOMPOSITION_TABLE - old_decomposition_table = CURRENT_DECOMPOSITION_TABLE - CURRENT_DECOMPOSITION_TABLE = decomposition_table - try: - yield CURRENT_DECOMPOSITION_TABLE - finally: - CURRENT_DECOMPOSITION_TABLE = old_decomposition_table - -# ensure we cannot collide with other properties -proxy_slot = object() -no_default = object() - -def set_proxy_slot(obj, tracer, proxy): - d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) - assert isinstance(d, weakref.WeakKeyDictionary) - d[tracer] = proxy - -def has_proxy_slot(obj, tracer): - return get_proxy_slot(obj, tracer, False, lambda _: True) - -# the default argument is what to return if the slot is not set. -# the transform argument is handy if you need to extract a subfield from -# the successfully looked up result (but NOT the default.) -def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): - d = obj.__dict__.get(proxy_slot) - if not d: - if default is no_default: - raise KeyError(f"{obj} is not tracked with proxy for {tracer}") - return default - assert isinstance(d, weakref.WeakKeyDictionary) - if tracer not in d: - if default is no_default: - raise KeyError(f"{obj} is not tracked with proxy for {tracer}") - else: - return default - return transform(d[tracer]) - - -def get_proxy_slots(obj): - return obj.__dict__.get(proxy_slot) - - -# Gets the proxy for a tensor, if it exists. -def get_proxy(obj): - res = get_proxy_slots(obj) - if res is None: - return None - vals = tuple(res.values()) - assert len(vals) == 1 - return vals[0] - -def has_proxy(obj): - return get_proxy(obj) is not None - -def set_meta(proxy, val): - if isinstance(val, FakeTensor): - proxy.node.meta['val'] = val - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) - elif isinstance(val, PySymInt): - proxy.node.meta['val'] = val - elif isinstance(val, torch.Tensor): - if not val.is_sparse: - proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) - return proxy - -def thunkify(f, *args, **kwargs): - """ - Delays computation of f until it's called again - Also caches the result - """ - return functools.lru_cache(1)(functools.partial(f, *args, **kwargs)) - -def track_tensor(tensor, proxy, *, constant, tracer): - def try_set_proxy_slot(outer_s, proxy_callable, *args): - assert callable(proxy_callable) - if isinstance(outer_s, SymInt): - inner_s = outer_s.get_pyobj() - assert isinstance(inner_s, PySymInt) - - set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, inner_s, *args)) - - # The basic idea is that we need to associate each tensor/SymInt - # with a Proxy. How do we setup this association? We just store - # the proxy on the proxy slot of the object, keyed on the tracer - # (so that if we have multiple tracers at the same time, they - # don't clobber each other.) - for i, s in enumerate(tensor.shape): - try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i) - - for i, s in enumerate(tensor.stride()): - try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i) - - try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x)) - try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x)) - set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) - -def track_tensor_tree(inner_res, proxy_res, *, constant, tracer): - def wrap_with_proxy(e, proxy, constant): - if isinstance(e, torch.Tensor): - track_tensor(e, proxy, tracer=tracer, constant=constant) - set_meta(proxy, e) - elif isinstance(e, list): - # example use case: allreduce_ returns ([tensor], work) - for idx, ee in enumerate(e): - wrap_with_proxy(ee, proxy[idx], get_constant(idx)) - - def get_constant(idx): - if constant is None: - return None - else: - return constant[idx] - - # Unfortunately, tree_map cannot directly be used here. As the resulting - # object may be a proxy that represents a tuple, we may need to - # explicitly unwrap the proxy by simulating the flattening operations. - if isinstance(inner_res, tuple) or isinstance(inner_res, list): - for idx, e in enumerate(inner_res): - wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) - elif isinstance(inner_res, torch.Tensor): - wrap_with_proxy(inner_res, proxy_res, constant) - - return inner_res - - -def maybe_disable_fake_tensor_mode(): - # TODO: figure out if this API generally makes sense and bake it into the - # library - mb_fake_mode = _get_current_dispatch_mode() - if isinstance(mb_fake_mode, FakeTensorMode): - return _pop_mode_temporarily() - else: - return nullcontext() - - -@dataclass -class _ProxyTensor: - proxy: Proxy - constant: Optional[torch.Tensor] - - -def fetch_sym_proxy(tracer): - def inner(e): - n = e.get_pyobj() - if n.constant is not None: - return n.constant - else: - # NB: we REQUIRE all symints to be tracked - return get_proxy_slot(n, tracer)() - return inner - - -def fetch_tensor_proxy(tracer): - return lambda t: get_proxy_slot(t, tracer, t) - -HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter) - -def proxy_call(proxy_mode, func, args, kwargs): - def can_handle_tensor(x): - return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) - - # If there are any tensor subclasses, we need to handle those tensor subclasses first - # TODO: we could use types to test this - if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)): - return NotImplemented - - if func in CURRENT_DECOMPOSITION_TABLE: - with proxy_mode: - r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs) - if r is not NotImplemented: - return r - - with proxy_mode: - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: - return r - - tracer = proxy_mode.tracer - f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs)) - - # If there are SymInts, we also should not consider this constant. - # However, fake tensor handling of SymInts is sufficiently broken that - # I couldn't write a test for this case - all_constant = ( - pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) - # TODO: maybe constant SymInts should also be allowed? Not sure if - # this can happen - and pytree.tree_all_only((SymInt, SymFloat), lambda _: False, (args, kwargs)) - ) - - if torch.Tag.data_dependent_output in func.tags: # type: ignore[attr-defined] - # Check if all of the Tensor inputs are constants - if all_constant: - const_args, const_kwargs = pytree.tree_map_only( - _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs) - ) - with maybe_disable_fake_tensor_mode(): - return func(*const_args, **const_kwargs) - raise RuntimeError( - "It appears that you're trying to get value out of a tracing tensor - erroring out! " - "It's likely that this is caused by data-dependent control flow or similar." - ) - proxy_args, proxy_kwargs = pytree.tree_map_only( - (SymInt, SymFloat), - fetch_sym_proxy(proxy_mode.tracer), - pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs)) - ) - - # When we trace through a torch.tensor invocation, you never actually - # see a torch.ops.aten.tensor call. Instead, the way this function is - # implemented internally is that we allocate a plain tensor (this is - # *guaranteed* to be a plain tensor, we disable all modes when doing - # so), and then call at::lift_fresh on it (to give modes a chance to do - # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed - # to be freshly allocated, so we want lift_fresh to be a no-op (directly - # returning the input argument). - # - # Here is the basic problem: when we trace this sequence of executions - # into an FX graph, what happens to this call sequence? Traditionally, - # tensor constants get interned as buffers on the FX GraphModule. But - # this is dangerous. Consider: - # - # x = torch.tensor(1) - # x.add_(2) - # - # Naively, this traces into: - # - # t = self._tensor_constant0 # initialized to torch.tensor(1) - # x = torch.ops.aten.lift_fresh(t) - # x.add_(2) - # - # If lift_fresh returns t directly, the subsequent add_ call will - # modify the tensor constant. Really, the problem is we've violated - # the invariant the the argument to lift is fresh. So what we should - # preserve the invariant by replacing lift_fresh with lift_fresh_copy: - # - # t = self._tensor_constant0 # initialized to torch.tensor(1) - # x = torch.ops.aten.lift_fresh_copy(t) - # x.add_(2) - # - # This is what the overload modification does. - if func is torch.ops.aten.lift_fresh.default: - func = torch.ops.aten.lift_fresh_copy.default - - proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs, - name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__)) - - # This makes DCE marginally less likely to DCE inplace operations. - # It is not strictly necessary - # Kind of a hacky way to test if an op is in-place or not - if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_": - if isinstance(args[0], List): - # e.g., c10d::allreduce_ returns a list of tensors as the first element - # in the output. - for i, a in enumerate(args[0]): - a.proxy = proxy_out[0][i] - else: - args[0].proxy = proxy_out - - out = func(*args, **kwargs) - - # In some circumstances, we will be tracing in a situation where a tensor - # is *statically* known to be a constant (currently, this only happens if - # you run torch.tensor; deterministic factory functions like torch.arange - # don't get this treatment). When the tensor in question is small, it's - # helpful to due constant propagation in case we call item() (in which - # case we can return the constant value that is known, rather than give - # an error.) The logic here tests if constant propagation is possible - # (because all of the inputs are constant). If so, we disable fake tensor - # mode (if it is on) and do true compute on the constant. - # - # It's worth highlighting that we're making a policy decision here. - # There is a potential that the tensor is actually quite large, and we - # don't actually want to run the compute. The tensor being quite large - # is one of the reasons why factory functions don't get this treatment - # (since they can be quite large; if a parameter is initialized to a - # constant value it will be!) Similarly, there is also a potential - # to run an operator that blows up the size of a small tensor; we don't - # protect against this case, but we could force, e.g., only single - # element constant computation by testing the numel of the result before - # propagating const-ness. Similarly, we don't require the constant to - # live on CPU, but we could. - any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) - - constant = None - - # If this is a lift, the input tensor is guaranteed to be a - # constant, so we keep a copy of the original argument along so - # we can query it if we're asked to item() it at some later point - if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT: - with maybe_disable_fake_tensor_mode(): - constant = args[0].clone() - elif ( - torch.Tag.nondeterministic_seeded not in func.tags # type: ignore[attr-defined] - and all_constant - and any_constant - and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out) - ): - # NB: do NOT include factories as constants - with maybe_disable_fake_tensor_mode(): - const_args, const_kwargs = pytree.tree_map_only( - _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs) - ) - constant = func(*const_args, **const_kwargs) - else: - constant = None - - track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) - return out - - -class PythonKeyTracer(Tracer): - def __init__(self): - super().__init__() - - # In general, we don't want to make modules leaves. In principle, users of - # this tracer might want to override this in order to turn a couple specific - # modules into leaves in the traced graph. - def call_module( - self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] - ) -> Any: - return forward(*args, **kwargs) - - # We don't want to turn getattr calls into proxies. So we just return the actual value. - def getattr(self, attr, attr_val, parameter_proxy_cache): - return attr_val - - def create_arg(self, a: Any): - if isinstance(a, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if a is p: - return self.create_node('get_attr', n, (), {}) - qualname: Optional[str] = None - - if not qualname: - i = 0 - while True: - qualname = f'_param_constant{i}' - if not hasattr(self.root, qualname): - break - i += 1 - setattr(self.root, qualname, a) - - return self.create_node('get_attr', qualname, (), {}) - elif isinstance(a, (SymInt, SymFloat)): - assert a.get_pyobj().constant is not None - return a.get_pyobj().constant - return super().create_arg(a) - - -def dispatch_trace( - root: Union[torch.nn.Module, Callable], - tracer: Tracer, - concrete_args: Optional[Tuple[Any, ...]] = None, -) -> GraphModule: - graph = tracer.trace(root, concrete_args) - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ - return GraphModule(tracer.root, graph, name) - - -def wrap_key(f, tensors, tracer): - flat_tensors, tensors_spec = pytree.tree_flatten(tensors) - - @functools.wraps(f) - def wrapped(*proxies): - flat_proxies, proxies_spec = pytree.tree_flatten(proxies) - assert len(flat_proxies) == len(flat_tensors) - track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) - - out = f(*tensors) - return pytree.tree_map_only( - torch.Tensor, - lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy), - out - ) - - return wrapped - - -class ProxyTorchDispatchMode(TorchDispatchMode): - def __init__(self, tracer): - self.tracer = tracer - self.enable_tracing = True - self.sym_mode = ProxySymDispatchMode(tracer) - self.trace_state = {} - self._managers = [] - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - with self.sym_mode.enable(False): - return self.inner_torch_dispatch(func, types, args, kwargs) - - def __enter__(self): - # sym mode first, then us... - m = self.sym_mode.enable(True) - self._managers.append(m) - m.__enter__() - return super().__enter__() - - def __exit__(self, exc_type, exc_value, traceback): - m = self._managers.pop() - # ...exit us first, then sym mode - b = super().__exit__(exc_type, exc_value, traceback) - if not b: - return m.__exit__(exc_type, exc_value, traceback) - else: - return m.__exit__(None, None, None) - - def inner_torch_dispatch(self, func, types, args=(), kwargs=None): - if not self.enable_tracing: - return func(*args, **kwargs) - - if func in [prim.device.default]: - return func(*args, **kwargs) - - out = proxy_call(self, func, args, kwargs) - return out - - -SymInt = torch.SymIntNode -SymFloat = torch.SymFloatNode - - -class ProxySymDispatchMode(SymDispatchMode): - def __init__(self, tracer): - super().__init__() - self.tracer = tracer - # When false, we don't trace operations. If you do this, you MUST - # call track_tensor/track_tensor_tree on all results of the operation - # to ensure we can adeduately track the results - self.enable_tracing = True - - @contextmanager - def enable(self, b): - old = self.enable_tracing - self.enable_tracing = b - try: - yield - finally: - self.enable_tracing = old - - def _compute_proxy(self, func, args, out): - n_args = tuple( - get_proxy_slot(a, self.tracer)().node if a.constant is None else a.constant - if isinstance(a, (PySymInt, PySymFloat)) else a - for a in args - ) - - # func doesn't have a __torch_function__ that Proxy can interpose, so - # we gotta do it manually - n_out = self.tracer.create_node("call_function", func, n_args, {}) - p_out = fx.Proxy(n_out, self.tracer) - set_meta(p_out, out) - return p_out - - def __sym_dispatch__(self, func, types, args, kwargs): - if not self.enable_tracing: - return func(*args, **kwargs) - - # Peephole optimize multiply by one - if func == operator.mul: - if isinstance(args[1], PySymInt) and args[1].constant == 1: - return args[0] - elif isinstance(args[0], PySymInt) and args[0].constant == 1: - return args[1] - - # For speed, we assume there are no nested data structures - # (otherwise we could use tree_map) - # We also assume there are no keyword arguments. - assert not kwargs - out = func(*args, **kwargs) - assert isinstance(out, (PySymInt, PySymFloat)), f"{func}(*{args}, **{kwargs}) = {out}" - - # Delays tracing out the proxies on this op until we actually need it - p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out) - set_proxy_slot(out, self.tracer, p_out_thunk) - return out - - -# TODO: I'm not sure what the point of this class is; you can just -# make_fx through a regular Interpreter -class DecompositionInterpreter(pippy.fx.Interpreter): - def __init__(self, module: pippy.fx.GraphModule, new_graph: pippy.fx.Graph, decomposition_table=None, **kwargs): - super().__init__(module, **kwargs) - self.new_graph = new_graph - self.tracer = pippy.fx.proxy.GraphAppendingTracer(self.new_graph) - self.decomposition_table = decomposition_table - if self.decomposition_table is None: - self.decomposition_table = {} - self.mode = ProxyTorchDispatchMode(self.tracer) - - def placeholder(self, target, args, kwargs): - out = super().placeholder(target, args, kwargs) - proxy = pippy.fx.Proxy(self.new_graph.placeholder(target), self.tracer) - track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) - # TODO handle case where the first character of target is '*' - return out - - def get_attr(self, target, args, kwargs): - out = super().get_attr(target, args, kwargs) - proxy = pippy.fx.Proxy(self.new_graph.get_attr(target), self.tracer) - track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) - return out - - # call_function, call_method, call_module get traced automatically by the outer mode. - - def output(self, target, args, kwargs): - out = super().output(target, args, kwargs) - - def unwrap(e): - return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node) - self.new_graph.output(pytree.tree_map(unwrap, out)) - return out - - def run(self, *args, **kwargs): - # Should enter the mode at least once for being able to restore it later - # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 - with decompose(self.decomposition_table), self.mode: - return super().run(*args, **kwargs) - - -def wrapper_and_args_for_make_fx(func, args, kwargs): - # make_fx doesn't support kwargs, so we need to do this flattening - # and then unflatten the args before calling func - flat_args, spec = pytree.tree_flatten((args, kwargs)) - - def wrapped(flat_args): - fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) - return func(*fn_args, **fn_kwargs) - return wrapped, flat_args - -@contextmanager -def disable_autocast_cache(): - old_value = torch.is_autocast_cache_enabled() - torch.set_autocast_cache_enabled(False) - try: - yield - finally: - torch.set_autocast_cache_enabled(old_value) - - -def make_fx(f, decomposition_table=None, tracing_mode="real"): - assert tracing_mode in ["real", "fake", "symbolic"] - - if decomposition_table is None: - decomposition_table = {} - - @functools.wraps(f) - def wrapped(*args): - phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] - fx_tracer = PythonKeyTracer() - fake_tensor_mode: Any = nullcontext() - if tracing_mode == "real": - fake_tensor_mode = nullcontext() - elif tracing_mode == "fake": - fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) - elif tracing_mode == "symbolic": - fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) - else: - raise AssertionError(f"Unexpected tracing type: {tracing_mode}") - - python_dispatcher_mode: Any = nullcontext() - if tracing_mode == "symbolic": - python_dispatcher_mode = enable_python_dispatcher() - - proxy_mode = ProxyTorchDispatchMode(fx_tracer) - - def wrap_fake_concrete(x): - if isinstance(x, torch.Tensor): - return fake_tensor_mode.from_tensor(x) # type: ignore[attr-defined] - - return x - - shape_env = ShapeEnv() - sym_mode = proxy_mode.sym_mode - - # todo: Figure out a more informative name for symints - def wrap_fake_symbolic(x): - if isinstance(x, torch.Tensor): - return fake_tensor_mode.from_tensor(x, shape_env=shape_env) - return x - - wrap_fn_map = { - "real": lambda x: x, - "fake": wrap_fake_concrete, - "symbolic": wrap_fake_symbolic, - } - args = pytree.tree_map(wrap_fn_map[tracing_mode], args) - - if not hasattr(inspect.unwrap(f), '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS: - # FX doesn't support varargs, so we gotta fake up a wrapper - # TODO: Would be nice to fix this at the source... - func = fake_signature(f, len(phs)) - else: - func = f - - # We disable the autocast cache as the autocast cache causes type conversions on parameters to - # check a cache, which introduces untracked tensors into the graph - with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \ - sym_mode, proxy_mode, disable_autocast_cache(): # type: ignore[attr-defined] - t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs)) - - # TODO: kind of a bad way to do it, should maybe figure out a better way - t.shape_env = shape_env # type: ignore[assignment] - return t - - return wrapped - - -def get_torch_dispatch_modes(): - return torch.utils._python_dispatch._get_current_dispatch_mode_stack() - - -@contextlib.contextmanager -def disable_proxy_modes_tracing(): - # TODO: This probably doesn't correctly also disable ProxySymDispatchMode - modes = get_torch_dispatch_modes() - proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)] - olds = [m.enable_tracing for m in proxy_tensor_modes] - for proxy_mode in proxy_tensor_modes: - proxy_mode.enable_tracing = False - try: - yield - finally: - for proxy_mode, old in zip(proxy_tensor_modes, olds): - proxy_mode.enable_tracing = old - - -def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"): - """A helper function used to get the GraphModule for the given func. - - It's expected to be used in the ProxyTensor tracing context. - It detaches the args and kwargs from the current tracer so that the trace of - the current graph module can be created without any side-effects. - """ - wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) - - with disable_proxy_modes_tracing(): - gm = make_fx(wrapped, tracing_mode=tracing_mode)(all_args) - return gm diff --git a/pippy/fx/experimental/refinement_types.py b/pippy/fx/experimental/refinement_types.py deleted file mode 100644 index 665c9d0d6..000000000 --- a/pippy/fx/experimental/refinement_types.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -class Equality: - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs - - def __str__(self): - return f'{self.lhs} = {self.rhs}' - - def __repr__(self): - return f'{self.lhs} = {self.rhs}' - - def __eq__(self, other): - if isinstance(other, Equality): - return self.lhs == other.lhs and self.rhs == other.rhs - else: - return False diff --git a/pippy/fx/experimental/rewriter.py b/pippy/fx/experimental/rewriter.py deleted file mode 100644 index d09eba545..000000000 --- a/pippy/fx/experimental/rewriter.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import ast -import inspect -import textwrap -import copy -import functools -from types import FunctionType -from typing import cast, Union, Callable, Dict, Optional, Any -from pippy.fx._symbolic_trace import Tracer -from pippy.fx.graph import Graph -from torch._sources import normalize_source_lines -import torch - -class AST_Rewriter(ast.NodeTransformer): - """ - Take a FunctionType object representing a `forward` method, then - perform an AST rewrite to swap out nodes that are not symbolically - traceable with a callsite to the FX alternative. - - To support swapping out an AST node, define a new `visit` method on - that node. For more details, see: - https://docs.python.org/3/library/ast.html#ast.NodeTransformer - """ - - def rewrite(self, fn: FunctionType): - - # Normalize the source lines - sourcelines, _ = inspect.getsourcelines(fn) - sourcelines = normalize_source_lines(sourcelines) - source = ''.join(sourcelines) - normalized_str = textwrap.dedent(source) - - # Rewrite the original AST - source_ast = ast.parse(normalized_str) - dest_ast = ast.fix_missing_locations(self.visit(source_ast)) - - # Pull out the compiled fucntion from the newly-created Module - code = compile(dest_ast, "", "exec") - globals_dict = copy.copy(fn.__globals__) - keys_before = set(globals_dict.keys()) - exec(code, globals_dict) - new_keys = list(set(globals_dict.keys()) - keys_before) - assert len(new_keys) == 1 - fn_compiled = globals_dict[new_keys[0]] - - # return the compiled function with the original globals - def change_func_globals(f, globals): - """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" - # __globals__ is a private member of the function class - # so we have to copy the function, f, all of its member, except f.__globals__ - g = FunctionType( - f.__code__, - globals, - name=f.__name__, - argdefs=f.__defaults__, - closure=f.__closure__, - ) - g = functools.update_wrapper(g, f) - g.__kwdefaults__ = copy.copy(f.__kwdefaults__) - return g - # Return the correct FunctionType object - return change_func_globals(fn_compiled, globals=fn.__globals__) - - def visit_Assert(self, node): - """ - Swap out the Assert node (Python's `assert`) with a callsite to the - symbolically-traceable torch._assert function - """ - # Create the Call node - n = ast.parse('torch._assert()', mode='eval') - assert isinstance(n, ast.Expression) - call_node = n.body - assert isinstance(call_node, ast.Call) - msg = node.msg if node.msg else ast.Constant(value="", kind=None) - call_node.args = [node.test, msg] - - # Ensure that the new node conforms to the Python AST grammar - expr_wrapper = ast.Expr(value=call_node) - - # Return the new Call node to signify that we want to use it as - # a replacement for the original _assert node - return ast.copy_location(expr_wrapper, node) - - def visit_AnnAssign(self, node): - """ - Swap out Python's AnnAssign with an Assign node where the annotation function is called. - Example: - Original: - y: Tensor_Type(1,2,3, Dyn) = f2(x) - Output: - y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) - """ - return ast.Assign(targets=[node.target], value=ast.Call( - func=ast.Name(id='annotate', ctx=ast.Load()), - args=[node.value, node.annotation], keywords=[])) - - -class RewritingTracer(Tracer): - def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: - return super().trace(_rewrite(root), concrete_args) - - -def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: - if isinstance(fn, torch.nn.Module): - # Rewrite this module's `forward` as well as the `forward`s of - # all of this module's recursive descendents. Return the new, - # rewritten module hierarchy. - def rewrite_module(m : torch.nn.Module): - class RewrittenModule(torch.nn.Module): - def __init__(self, orig): - super().__init__() - for k, v in orig.__dict__.items(): - if isinstance(v, torch.nn.Module): - self.__dict__[k] = copy.copy(rewrite_module(v)) - else: - self.__dict__[k] = copy.copy(v) - RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) - return RewrittenModule(m) - return rewrite_module(fn) - else: - # Rewrite this single free function - return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/pippy/fx/experimental/schema_type_annotation.py b/pippy/fx/experimental/schema_type_annotation.py deleted file mode 100644 index 93102a6b5..000000000 --- a/pippy/fx/experimental/schema_type_annotation.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import pippy.fx -import inspect -from typing import Any, Dict, Optional, Tuple -from pippy.fx.node import Argument, Target -from torch._jit_internal import boolean_dispatched -from pippy.fx.operator_schemas import _torchscript_type_to_python_type - -from pippy.fx import Transformer - -class AnnotateTypesWithSchema(Transformer): - """ - Use Python function signatures to annotate types for `Nodes` within an FX graph. - This pulls out Python function signatures for: - - 1. Standard `torch.nn` Module calls - 2. `torch.nn.functional` calls - 3. Attribute fetches via `get_attr` - - Example usage: - - m = torchvision.models.resnet18() - - traced = pippy.fx.symbolic_trace(m) - - traced = AnnotateTypesWithSchema(traced).transform() - - """ - def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True, - annotate_modules : bool = True, annotate_get_attrs : bool = True): - super().__init__(module) - self.annotate_functionals = annotate_functionals - self.annotate_modules = annotate_modules - self.annotate_get_attrs = annotate_get_attrs - - def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): - python_ret_type = None - if self.annotate_functionals and target.__module__ == 'torch.nn.functional': - target_for_analysis = target - if target in boolean_dispatched: - # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have - # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` - # branches of the dispatch have exactly the same signature. If they do, use the `true` - # branch signature for analysis. Otherwise, leave this un-normalized - assert not isinstance(target, str) - dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] - # TODO: can we emit the union of these? What are the implications on TorchScript - # compilation? - if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation: - return super().call_function(target, args, kwargs) - target_for_analysis = if_true - - python_ret_type = self._extract_python_return_type(target_for_analysis) - - return_proxy = super().call_function(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type - return return_proxy - - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): - python_ret_type = None - assert isinstance(target, str) - submod = self.fetch_attr(target) - if self.annotate_modules and hasattr(submod.__class__, '__name__'): - classname = submod.__class__.__name__ - if getattr(torch.nn, classname, None) == submod.__class__: - python_ret_type = self._extract_python_return_type(submod.forward) - return_proxy = super().call_module(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type - return return_proxy - - def get_attr(self, target : pippy.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): - attr_proxy = super().get_attr(target, args, kwargs) - - if self.annotate_get_attrs: - module_itr = self.module - assert isinstance(target, str) - atoms = target.split('.') - for i, atom in enumerate(atoms): - if not hasattr(module_itr, atom): - raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!') - module_itr = getattr(module_itr, atom) - - maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) - if maybe_inferred_ts_type.success(): - python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type()) - attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type - - return attr_proxy - - def _extract_python_return_type(self, target : Target) -> Optional[Any]: - """ - Given a Python call target, try to extract the Python return annotation - if it is available, otherwise return None - - Args: - - target (Callable): Python callable to get return annotation for - - Returns: - - Optional[Any]: Return annotation from the `target`, or None if it was - not available. - """ - assert callable(target) - try: - sig = inspect.signature(target) - except (ValueError, TypeError): - return None - - return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None diff --git a/pippy/fx/experimental/symbolic_shapes.py b/pippy/fx/experimental/symbolic_shapes.py deleted file mode 100644 index 5817194e5..000000000 --- a/pippy/fx/experimental/symbolic_shapes.py +++ /dev/null @@ -1,472 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import torch.utils._pytree as pytree -from typing import Set, Dict, List, Type, Optional, cast # pylint: disable=unused-import -import operator -import functools -from functools import lru_cache, partial -import traceback -import collections -import textwrap -from torch._subclasses.meta_utils import MetaConverter - -try: - import sympy # type: ignore[import] - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False - -aten = torch.ops.aten # type: ignore[has-type] - -__all__ = [ - "has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv", - "SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv" -] - -SYM_FUNCTION_MODE = None - -# We don't bother with the metaclass as all of the dispatching logic happens -# entirely from Python -# -# Didn't bother with ancestors for now, unlikely to have multiple modes for -# symints right now - - -# SymDispatchMode gets invoked whenever an operation is processed on -# a PySymInt. When this occurs, you get called at __sym_dispatch__ -# with the operation in question. This is symmetric to TorchDispatchMode -# but with some caveats: -# -# - In TorchDispatchMode, you get the same arguments as what a user -# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b), -# you get (a, b) as args to your call. In SymDispatchMode, if -# you call a + b (where a and b are SymInts), you will get -# (a.get_pyobj(), b.get_pyobj()) as your args (these are PySymInts) -# -# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor). -# So you have to manually call Tracer/create_node to write into -# the graph. See ProxySymDispatchMode for an example -# -class SymDispatchMode: - def __sym_dispatch__(self, func, types, args, kwargs): - raise NotImplementedError() - - def __enter__(self): - global SYM_FUNCTION_MODE - old = SYM_FUNCTION_MODE - if hasattr(self, "inner"): - raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version") - else: - self.inner = old - SYM_FUNCTION_MODE = self - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - global SYM_FUNCTION_MODE - SYM_FUNCTION_MODE = self.inner - -def has_symbolic_sizes_strides(elem): - return elem._has_symbolic_sizes_strides - -def create_contiguous(shape): - strides = [1] - for dim in reversed(shape[:-1]): - strides.append(dim * strides[-1]) - return list(reversed(strides)) - -def _handle_sym_dispatch(func, args, kwargs): - global SYM_FUNCTION_MODE - mode = SYM_FUNCTION_MODE - assert mode - SYM_FUNCTION_MODE = mode.inner - try: - # TODO: properly compute types - types: List[Type] = [] - return mode.__sym_dispatch__(func, types, args, kwargs) - finally: - SYM_FUNCTION_MODE = mode - -def sym_float(a): - if hasattr(a, '__sym_float__'): - return a.__sym_float__() - elif isinstance(a, torch._C.SymFloatNode): - return a - return float(a) - -# TODO: An incomplete list -# 1. Set variables to be equal when we do equality -# 2. Specialize on 0/1 when we do subtraction -class PySymInt(object): - """ - PySymInt objects are the primary "symbolic shape" objects that flow through - our program. They're what sit under FakeTensor, and contains our primary - implementation of symbolic shapes. - """ - def __init__(self, expr, shape_env, constant=None): - self.expr = expr - self.shape_env = shape_env - self.constant = constant - - def wrap(self, num): - return PySymInt(sympy.Integer(num), self.shape_env, constant=num) - - def __str__(self): - return f"{self.expr}" - - def __repr__(self): - return f"{self.expr}" - - # Today we error on calling int on a symbolic shape, as this is a very accessible footgun. - def __int__(self): - raise RuntimeError("Trying to extract a concrete int out of a symbolic int") - - # You can manually trigger a guard with this function - def guard_int(self, file, line): - # TODO: use the file/line for some useful diagnostic on why a - # guard occurred - return int(self.shape_env.evaluate_expr(self.expr)) - - def __sym_float__(self): - if SYM_FUNCTION_MODE: - return _handle_sym_dispatch(sym_float, (self,), {}) - # TODO: consider constant prop here - # TODO: wrapping the expr with sympy.Float doesn't seem to work, why - # not? - return PySymFloat(self.expr, self.shape_env) - - def __bool__(self): - return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))) - -class PySymFloat: - def __init__(self, expr, shape_env, constant=None): - self.expr = expr - self.shape_env = shape_env - self.constant = constant - - def wrap(self, num): - return PySymFloat(sympy.Float(num), self.shape_env, constant=num) - - def __str__(self): - return f"{self.expr}" - -if HAS_SYMPY: - class FloorDiv(sympy.Function): - """ - We maintain this so that: - 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. - 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) - """ - nargs = (2,) - - @classmethod - def eval(cls, base, divisor): - if base == 0: - return sympy.Integer(0) - if divisor == 1: - return base - if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - gcd = sympy.gcd(base, divisor) - if gcd != 1: - return FloorDiv( - sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) - ) - -# Methods that have a `__foo__` as well as `__rfoo__` -reflectable_magic_methods = { - 'add': lambda a, b: a + b, - 'sub': lambda a, b: a - b, - 'mul': lambda a, b: a * b, - 'mod': lambda a, b: a % b, - 'truediv': lambda a, b: a / b, - 'floordiv': lambda a, b: FloorDiv(a, b) -} - -magic_methods = { - **reflectable_magic_methods, - 'eq': lambda a, b: sympy.Eq(a, b), - 'gt': lambda a, b: sympy.Gt(a, b), - 'lt': lambda a, b: sympy.Lt(a, b), - 'le': lambda a, b: sympy.Le(a, b), - 'ge': lambda a, b: sympy.Ge(a, b), -} - -float_magic_methods = {"add", "sub", "mul", "truediv"} - -def _make_magic(method, func, py_type): - func = lru_cache(256)(func) - - def magic_impl(self, other): - if SYM_FUNCTION_MODE: - return _handle_sym_dispatch(getattr(operator, method), (self, other), {}) - if isinstance(other, py_type): - other = other.expr - # TODO: consider constant prop here - expr = self.shape_env.replace(self.expr) - other = self.shape_env.replace(other) - out = func(expr, other) - out = sympy.expand(out) - if method in ["truediv"]: - return PySymFloat(out, self.shape_env) - else: - # TODO: relational operators actually technically return a - # PySymBool, this is a type error - return py_type(out, self.shape_env) - - # this should be wrapped transparently into torch.SymIntNode - setattr(py_type, method, magic_impl) - setattr(py_type, f"__{method}__", magic_impl) - if method in reflectable_magic_methods: - setattr(py_type, f"__r{method}__", magic_impl) - -for method, func in magic_methods.items(): - _make_magic(method, func, PySymInt) - -for method, func in magic_methods.items(): - if method not in float_magic_methods: - continue - _make_magic(method, func, PySymFloat) - -del method -del func - -def _lru_cache(fn, maxsize=None): - """ - Wrapper around lru_cache that clears when new info about shapes has been - updated. - - Use lru_cache if the output is always the same, regardless of the - constraints we know now (i.e. evaluate_expr) - - Use _lru_cache otherwise. - """ - fn_cache = lru_cache(maxsize)(fn) - prior_key = None - - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - nonlocal prior_key - if prior_key != self._get_key(): - prior_key = self._get_key() - fn_cache.cache_clear() - return fn_cache(self, *args, **kwargs) - - wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] - return wrapper - - - -class ShapeEnv(object): - def __init__(self): - self.guards = [] - # Maps symbolic ints to their original concrete values - # Currently populated from tensors - self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {} - # Maps from sympy ints to expressions representing them - # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) - self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} # - # Set holds a % b expressions that evaluate to 0. - self.divisible: Set["sympy.Expr"] = set() - # Duck-shaping says that if two input tensors have the same size, - # they get assigned the same symbolic variable - self.val_to_symint: Dict[int, torch.SymIntNode] = {} - - def _get_key(self): - """ - Defines the current "state" of the guards we've accumulated in this ShapeEnv. - Determines when we need to invalidate our cache - """ - return (len(self.replacements), len(self.divisible)) - - # NB: This is only called for input symbolic sizes; intermediate symbolic - # sizes are allocated via a different mechanism - def create_symint(self, name, val): - assert val >= 0 - if not HAS_SYMPY: - raise RuntimeError("Need sympy installed to create symbolic shapes") - - # TODO: Put 0/1 specialization in guards - if val == 0 or val == 1: - return val - # This implements duck-shaping: input sizes that match are assigned - # the same symint - # TODO: Create a guard whenever this happens - # TODO: But how do I represent the guard in this case? - if val in self.val_to_symint: - return self.val_to_symint[val] - sympy_expr = sympy.Symbol(name, positive=True, integer=True) - py_sym_int = PySymInt(sympy_expr, self) - cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined] - self.var_to_val[sympy_expr] = sympy.Integer(val) - self.val_to_symint[val] = cpp_sym_int - return cpp_sym_int - - def evaluate_guards_for_args(self, *args): - new_env = ShapeEnv() - # NB: This must be kept in sync with create_aot_dispatcher_function - # and wrap_fake_symbolic - meta_converter = MetaConverter() - pytree.tree_map_only(torch.Tensor, partial(meta_converter, shape_env=new_env), args) - return all(guard.xreplace(new_env.var_to_val) == value for guard, value, _ in self.guards) - - def get_nontrivial_guards(self): - return [(self.simplify(guard), val) for guard, val, _ in self.guards if self._maybe_evaluate_static(guard) is None] - - def format_guards(self, verbose=False): - def format_val(guard, val): - if val is sympy.true: - return str(guard) - elif val is sympy.false: - return f"Not({guard})" - else: - return f"Eq({guard}, {val})" - - def format_tb(tb): - if not verbose: - return "" - return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}" - - return '\n'.join(f" - {format_val(guard, val)}{format_tb(tb)}" for guard, val, tb in self.guards) - - def get_shape_groups(self): - shape_groups = collections.defaultdict(list) - for k, v in self.replacements.items(): - shape_groups[v].append(k) - return shape_groups - - @_lru_cache - def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]": - """ - Tries to evaluate expr without introducing guards - """ - expr = self.simplify(expr) - # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values) - symbols = list(expr.free_symbols) - new_shape_env = { - k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1 - for idx, k in enumerate(symbols) - } - new_expr = expr.xreplace(new_shape_env) - floor_div_replace = {} - for atom in new_expr.atoms(FloorDiv): - floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) - new_expr = sympy.expand(new_expr.xreplace(floor_div_replace)) - if len(list(new_expr.free_symbols)) == 0: - return new_expr - return None - - @_lru_cache - def replace(self, expr: "sympy.Expr") -> "sympy.Expr": - replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} - return sympy.expand(expr.xreplace(replacements)) - - @_lru_cache - def _update_divisible(self): - new_divisible = set() - for k in self.divisible: - res = self.replace(k) - if len(res.free_symbols) > 0: - new_divisible.add(k) - - self.divisible = new_divisible - - @_lru_cache - def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": - expr = self.replace(expr) - if expr.has(FloorDiv): - self._update_divisible() - div_replacements = {} - for atom in expr.atoms(FloorDiv): - base, divisor = atom.args - if self.replace(base % divisor) in self.divisible: - div_replacements[atom] = base / divisor - expr = expr.xreplace(div_replacements) - expr = sympy.expand(expr) - return expr - - @lru_cache(256) - def size_hint(self, expr: "sympy.Expr"): - """ - Gets a size hint for a given expression from the underlying shapes we had. - Does not introduce a guard, so only use this when you can guarantee that - your code is still valid for arbitrary shapes (such as optimization decisions) - """ - result_expr = sympy.expand(expr).xreplace(self.var_to_val) - assert len(result_expr.free_symbols) == 0, "Size hint has variables we don't have underlying values for" - return result_expr - - @_lru_cache - def _find(self, a: "sympy.Symbol") -> "sympy.Expr": - """ - Implements a DSU-like algorithm to find the variable that represents a - Also handles transitive non-identity replacements. - - a: b + c - c: d - """ - if a not in self.replacements: - return a - res = self.replacements[a] - cur_replace = {s: self._find(s) for s in res.free_symbols} - self.replacements[a] = self.replacements[a].xreplace(cur_replace) - return self.replacements[a] - - @lru_cache(256) - def _maybe_guard_eq(self, expr: "sympy.Eq") -> None: - """ - Evaluates the result of an eq call. If true, uses information to - simplify shapes (i.e. a == b or a % 5 == 0) - """ - concrete_bool = bool(self.size_hint(expr)) - if not concrete_bool: - return - free = list(expr.free_symbols) - - assert len(free) > 0, "The expression should not be static by this point" - # In case of really gnarly expression, we don't blow up - if len(free) > 5: - return - free = sorted(free, key=lambda x: (self.size_hint(x), x.name), reverse=True) # type: ignore[attr-defined] - lhs = expr.lhs - rhs = expr.rhs - try: - solutions = sympy.solve(lhs - rhs, free[0], dict=True) - if len(solutions) != 1: - return - solution = solutions[0][free[0]] - if all(t.is_integer for t in sympy.preorder_traversal(solution)): - new_var = self._find(solution) - self.replacements[cast(sympy.Symbol, free[0])] = new_var - except NotImplementedError: - if expr.has(sympy.Mod): - mod_expr = tuple(expr.atoms(sympy.Mod))[0] - try: - solutions = sympy.solve(lhs - rhs, mod_expr, dict=True) - if len(solutions) == 1 and solutions[0][mod_expr] == 0: - self.divisible.add(mod_expr) - except NotImplementedError: - pass - return - - @lru_cache(256) - def evaluate_expr(self, expr: "sympy.Expr"): - """ - Given an expression, evaluates it, adding guards if necessary - """ - if len(expr.free_symbols) == 0: - return expr - expr = self.simplify(expr) - static_expr = self._maybe_evaluate_static(expr) - if static_expr is not None: - return static_expr - - if isinstance(expr, sympy.Eq): - self._maybe_guard_eq(expr) - concrete_val = self.size_hint(expr) - - # TODO: optimize this; avoid formatting traces until we need them - # NB: drop two frames; evaluate_expr and the Sym* function that - # actually called us - stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2])) - self.guards.append((expr, concrete_val, stack)) - return concrete_val diff --git a/pippy/fx/experimental/unification/LICENSE.txt b/pippy/fx/experimental/unification/LICENSE.txt deleted file mode 100644 index 775eca52c..000000000 --- a/pippy/fx/experimental/unification/LICENSE.txt +++ /dev/null @@ -1,28 +0,0 @@ -Copyright (c) 2014 Matthew Rocklin - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - a. Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - b. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - c. Neither the name of Unification nor the names of its contributors - may be used to endorse or promote products derived from this software - without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -DAMAGE. diff --git a/pippy/fx/experimental/unification/__init__.py b/pippy/fx/experimental/unification/__init__.py deleted file mode 100644 index 5e1477089..000000000 --- a/pippy/fx/experimental/unification/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# type: ignore[attr-defined] -from .core import unify, reify # noqa: F403 -from .more import unifiable # noqa: F403 -from .variable import var, isvar, vars, variables, Var # noqa: F403 diff --git a/pippy/fx/experimental/unification/core.py b/pippy/fx/experimental/unification/core.py deleted file mode 100644 index c1eb2b3cf..000000000 --- a/pippy/fx/experimental/unification/core.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from collections.abc import Iterator # type: ignore[import] -from functools import partial - -from .unification_tools import assoc # type: ignore[import] -from .utils import transitive_get as walk -from .variable import isvar -from .dispatch import dispatch - -__all__ = ["reify", "unify"] - -################ -# Reificiation # -################ - -@dispatch(Iterator, dict) -def _reify(t, s): - return map(partial(reify, s=s), t) - # return (reify(arg, s) for arg in t) -_reify - -@dispatch(tuple, dict) # type: ignore[no-redef] -def _reify(t, s): - return tuple(reify(iter(t), s)) -_reify - -@dispatch(list, dict) # type: ignore[no-redef] -def _reify(t, s): - return list(reify(iter(t), s)) -_reify - -@dispatch(dict, dict) # type: ignore[no-redef] -def _reify(d, s): - return dict((k, reify(v, s)) for k, v in d.items()) -_reify - -@dispatch(object, dict) # type: ignore[no-redef] -def _reify(o, s): - return o # catch all, just return the object - -def reify(e, s): - """ Replace variables of expression with substitution - >>> # xdoctest: +SKIP - >>> x, y = var(), var() - >>> e = (1, x, (3, y)) - >>> s = {x: 2, y: 4} - >>> reify(e, s) - (1, 2, (3, 4)) - >>> e = {1: x, 3: (y, 5)} - >>> reify(e, s) - {1: 2, 3: (4, 5)} - """ - if isvar(e): - return reify(s[e], s) if e in s else e - return _reify(e, s) - -############### -# Unification # -############### - -seq = tuple, list, Iterator - -@dispatch(seq, seq, dict) -def _unify(u, v, s): - if len(u) != len(v): - return False - for uu, vv in zip(u, v): # avoiding recursion - s = unify(uu, vv, s) - if s is False: - return False - return s -# -# @dispatch((set, frozenset), (set, frozenset), dict) -# def _unify(u, v, s): -# i = u & v -# u = u - i -# v = v - i -# return _unify(sorted(u), sorted(v), s) -# -# -# @dispatch(dict, dict, dict) -# def _unify(u, v, s): -# if len(u) != len(v): -# return False -# for key, uval in iteritems(u): -# if key not in v: -# return False -# s = unify(uval, v[key], s) -# if s is False: -# return False -# return s -# -# -# @dispatch(object, object, dict) -# def _unify(u, v, s): -# return False # catch all - - -@dispatch(object, object, dict) -def unify(u, v, s): # no check at the moment - """ Find substitution so that u == v while satisfying s - >>> x = var('x') - >>> unify((1, x), (1, 2), {}) - {~x: 2} - """ - u = walk(u, s) - v = walk(v, s) - if u == v: - return s - if isvar(u): - return assoc(s, u, v) - if isvar(v): - return assoc(s, v, u) - return _unify(u, v, s) -unify - -@dispatch(object, object) # type: ignore[no-redef] -def unify(u, v): - return unify(u, v, {}) diff --git a/pippy/fx/experimental/unification/dispatch.py b/pippy/fx/experimental/unification/dispatch.py deleted file mode 100644 index fd9b37188..000000000 --- a/pippy/fx/experimental/unification/dispatch.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from functools import partial -from .multipledispatch import dispatch # type: ignore[import] - -namespace = {} # type: ignore[var-annotated] - -dispatch = partial(dispatch, namespace=namespace) diff --git a/pippy/fx/experimental/unification/match.py b/pippy/fx/experimental/unification/match.py deleted file mode 100644 index 09ca3ad5d..000000000 --- a/pippy/fx/experimental/unification/match.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .core import unify, reify # type: ignore[attr-defined] -from .variable import isvar -from .utils import _toposort, freeze -from .unification_tools import groupby, first # type: ignore[import] - - -class Dispatcher(object): - def __init__(self, name): - self.name = name - self.funcs = {} - self.ordering = [] - - def add(self, signature, func): - self.funcs[freeze(signature)] = func - self.ordering = ordering(self.funcs) - - def __call__(self, *args, **kwargs): - func, s = self.resolve(args) - return func(*args, **kwargs) - - def resolve(self, args): - n = len(args) - for signature in self.ordering: - if len(signature) != n: - continue - s = unify(freeze(args), signature) - if s is not False: - result = self.funcs[signature] - return result, s - raise NotImplementedError("No match found. \nKnown matches: " - + str(self.ordering) + "\nInput: " + str(args)) - - def register(self, *signature): - def _(func): - self.add(signature, func) - return self - return _ - -class VarDispatcher(Dispatcher): - """ A dispatcher that calls functions with variable names - >>> d = VarDispatcher('d') - >>> # xdoctest: +SKIP - >>> x = var('x') - >>> @d.register('inc', x) - ... def f(x): - ... return x + 1 - >>> @d.register('double', x) - ... def f(x): - ... return x * 2 - >>> d('inc', 10) - 11 - >>> d('double', 10) - 20 - """ - def __call__(self, *args, **kwargs): - func, s = self.resolve(args) - d = dict((k.token, v) for k, v in s.items()) - return func(**d) - - - - -global_namespace = {} # type: ignore[var-annotated] - - -def match(*signature, **kwargs): - namespace = kwargs.get('namespace', global_namespace) - dispatcher = kwargs.get('Dispatcher', Dispatcher) - - def _(func): - name = func.__name__ - - if name not in namespace: - namespace[name] = dispatcher(name) - d = namespace[name] - - d.add(signature, func) - - return d - return _ - - -def supercedes(a, b): - """ ``a`` is a more specific match than ``b`` """ - if isvar(b) and not isvar(a): - return True - s = unify(a, b) - if s is False: - return False - s = dict((k, v) for k, v in s.items() if not isvar(k) or not isvar(v)) - if reify(a, s) == a: - return True - if reify(b, s) == b: - return False - - -# Taken from multipledispatch -def edge(a, b, tie_breaker=hash): - """ A should be checked before B - Tie broken by tie_breaker, defaults to ``hash`` - """ - if supercedes(a, b): - if supercedes(b, a): - return tie_breaker(a) > tie_breaker(b) - else: - return True - return False - - -# Taken from multipledispatch -def ordering(signatures): - """ A sane ordering of signatures to check, first to last - Topoological sort of edges as given by ``edge`` and ``supercedes`` - """ - signatures = list(map(tuple, signatures)) - edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] - edges = groupby(first, edges) - for s in signatures: - if s not in edges: - edges[s] = [] - edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[attr-defined, assignment] - return _toposort(edges) diff --git a/pippy/fx/experimental/unification/more.py b/pippy/fx/experimental/unification/more.py deleted file mode 100644 index 81e72821f..000000000 --- a/pippy/fx/experimental/unification/more.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .core import unify, reify # type: ignore[attr-defined] -from .dispatch import dispatch - - -def unifiable(cls): - """ Register standard unify and reify operations on class - This uses the type and __dict__ or __slots__ attributes to define the - nature of the term - See Also: - >>> class A(object): - ... def __init__(self, a, b): - ... self.a = a - ... self.b = b - >>> # xdoctest: +SKIP - >>> unifiable(A) - - >>> x = var('x') - >>> a = A(1, 2) - >>> b = A(1, x) - >>> unify(a, b, {}) - {~x: 2} - """ - _unify.add((cls, cls, dict), unify_object) - _reify.add((cls, dict), reify_object) - - return cls - - -######### -# Reify # -######### - - -def reify_object(o, s): - """ Reify a Python object with a substitution - >>> class Foo(object): - ... def __init__(self, a, b): - ... self.a = a - ... self.b = b - ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> # xdoctest: +SKIP - >>> x = var('x') - >>> f = Foo(1, x) - >>> print(f) - Foo(1, ~x) - >>> print(reify_object(f, {x: 2})) - Foo(1, 2) - """ - if hasattr(o, '__slots__'): - return _reify_object_slots(o, s) - else: - return _reify_object_dict(o, s) - - -def _reify_object_dict(o, s): - obj = object.__new__(type(o)) - d = reify(o.__dict__, s) - if d == o.__dict__: - return o - obj.__dict__.update(d) - return obj - - -def _reify_object_slots(o, s): - attrs = [getattr(o, attr) for attr in o.__slots__] - new_attrs = reify(attrs, s) - if attrs == new_attrs: - return o - else: - newobj = object.__new__(type(o)) - for slot, attr in zip(o.__slots__, new_attrs): - setattr(newobj, slot, attr) - return newobj - - -@dispatch(slice, dict) -def _reify(o, s): - """ Reify a Python ``slice`` object """ - return slice(*reify((o.start, o.stop, o.step), s)) - - -######### -# Unify # -######### - - -def unify_object(u, v, s): - """ Unify two Python objects - Unifies their type and ``__dict__`` attributes - >>> class Foo(object): - ... def __init__(self, a, b): - ... self.a = a - ... self.b = b - ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> # xdoctest: +SKIP - >>> x = var('x') - >>> f = Foo(1, x) - >>> g = Foo(1, 2) - >>> unify_object(f, g, {}) - {~x: 2} - """ - if type(u) != type(v): - return False - if hasattr(u, '__slots__'): - return unify([getattr(u, slot) for slot in u.__slots__], - [getattr(v, slot) for slot in v.__slots__], - s) - else: - return unify(u.__dict__, v.__dict__, s) - -@dispatch(slice, slice, dict) -def _unify(u, v, s): - """ Unify a Python ``slice`` object """ - return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/pippy/fx/experimental/unification/multipledispatch/__init__.py b/pippy/fx/experimental/unification/multipledispatch/__init__.py deleted file mode 100644 index 26039e4ce..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .core import dispatch -from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, - MDNotImplementedError) diff --git a/pippy/fx/experimental/unification/multipledispatch/conflict.py b/pippy/fx/experimental/unification/multipledispatch/conflict.py deleted file mode 100644 index 4ed1da4b0..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/conflict.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .utils import _toposort, groupby -from .variadic import isvariadic - -__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature", - "edge", "ordering"] - -class AmbiguityWarning(Warning): - pass - - -def supercedes(a, b): - """ A is consistent and strictly more specific than B """ - if len(a) < len(b): - # only case is if a is empty and b is variadic - return not a and len(b) == 1 and isvariadic(b[-1]) - elif len(a) == len(b): - return all(map(issubclass, a, b)) - else: - # len(a) > len(b) - p1 = 0 - p2 = 0 - while p1 < len(a) and p2 < len(b): - cur_a = a[p1] - cur_b = b[p2] - if not (isvariadic(cur_a) or isvariadic(cur_b)): - if not issubclass(cur_a, cur_b): - return False - p1 += 1 - p2 += 1 - elif isvariadic(cur_a): - assert p1 == len(a) - 1 - return p2 == len(b) - 1 and issubclass(cur_a, cur_b) - elif isvariadic(cur_b): - assert p2 == len(b) - 1 - if not issubclass(cur_a, cur_b): - return False - p1 += 1 - return p2 == len(b) - 1 and p1 == len(a) - - -def consistent(a, b): - """ It is possible for an argument list to satisfy both A and B """ - - # Need to check for empty args - if not a: - return not b or isvariadic(b[0]) - if not b: - return not a or isvariadic(a[0]) - - # Non-empty args check for mutual subclasses - if len(a) == len(b): - return all(issubclass(aa, bb) or issubclass(bb, aa) - for aa, bb in zip(a, b)) - else: - p1 = 0 - p2 = 0 - while p1 < len(a) and p2 < len(b): - cur_a = a[p1] - cur_b = b[p2] - if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): - return False - if not (isvariadic(cur_a) or isvariadic(cur_b)): - p1 += 1 - p2 += 1 - elif isvariadic(cur_a): - p2 += 1 - elif isvariadic(cur_b): - p1 += 1 - # We only need to check for variadic ends - # Variadic types are guaranteed to be the last element - return (isvariadic(cur_a) and p2 == len(b) or - isvariadic(cur_b) and p1 == len(a)) - - -def ambiguous(a, b): - """ A is consistent with B but neither is strictly more specific """ - return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) - - -def ambiguities(signatures): - """ All signature pairs such that A is ambiguous with B """ - signatures = list(map(tuple, signatures)) - return set((a, b) for a in signatures for b in signatures - if hash(a) < hash(b) - and ambiguous(a, b) - and not any(supercedes(c, a) and supercedes(c, b) - for c in signatures)) - - -def super_signature(signatures): - """ A signature that would break ambiguities """ - n = len(signatures[0]) - assert all(len(s) == n for s in signatures) - - return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] - for i in range(n)] - - -def edge(a, b, tie_breaker=hash): - """ A should be checked before B - Tie broken by tie_breaker, defaults to ``hash`` - """ - # A either supercedes B and B does not supercede A or if B does then call - # tie_breaker - return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)) - - -def ordering(signatures): - """ A sane ordering of signatures to check, first to last - Topoological sort of edges as given by ``edge`` and ``supercedes`` - """ - signatures = list(map(tuple, signatures)) - edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] - edges = groupby(lambda x: x[0], edges) - for s in signatures: - if s not in edges: - edges[s] = [] - edges = dict((k, [b for a, b in v]) for k, v in edges.items()) # type: ignore[assignment, attr-defined] - return _toposort(edges) diff --git a/pippy/fx/experimental/unification/multipledispatch/core.py b/pippy/fx/experimental/unification/multipledispatch/core.py deleted file mode 100644 index ca79fcadb..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/core.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -import sys - -from .dispatcher import Dispatcher, MethodDispatcher - -global_namespace = {} # type: ignore[var-annotated] - -__all__ = ["dispatch", "ismethod"] - -def dispatch(*types, **kwargs): - """ Dispatch function on the types of the inputs - Supports dispatch on all non-keyword arguments. - Collects implementations based on the function name. Ignores namespaces. - If ambiguous type signatures occur a warning is raised when the function is - defined suggesting the additional method to break the ambiguity. - Examples - -------- - >>> @dispatch(int) - ... def f(x): - ... return x + 1 - >>> @dispatch(float) - ... def f(x): - ... return x - 1 - >>> f(3) - 4 - >>> f(3.0) - 2.0 - >>> # Specify an isolated namespace with the namespace keyword argument - >>> my_namespace = {} - >>> @dispatch(int, namespace=my_namespace) - ... def foo(x): - ... return x + 1 - >>> # Dispatch on instance methods within classes - >>> class MyClass(object): - ... @dispatch(list) - ... def __init__(self, data): - ... self.data = data - ... @dispatch(int) - ... def __init__(self, datum): - ... self.data = [datum] - >>> MyClass([1, 2, 3]).data - [1, 2, 3] - >>> MyClass(3).data - [3] - """ - namespace = kwargs.get('namespace', global_namespace) - - types = tuple(types) - - def _df(func): - name = func.__name__ - - if ismethod(func): - dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr] - name, # type: ignore[union-attr] - MethodDispatcher(name), - ) - else: - if name not in namespace: - namespace[name] = Dispatcher(name) - dispatcher = namespace[name] - - dispatcher.add(types, func) - return dispatcher - return _df - - -def ismethod(func): - """ Is func a method? - Note that this has to work as the method is defined but before the class is - defined. At this stage methods look like functions. - """ - if hasattr(inspect, "signature"): - signature = inspect.signature(func) - return signature.parameters.get('self', None) is not None - else: - if sys.version_info.major < 3: - spec = inspect.getargspec(func) - else: - spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] - return spec and spec.args and spec.args[0] == 'self' diff --git a/pippy/fx/experimental/unification/multipledispatch/dispatcher.py b/pippy/fx/experimental/unification/multipledispatch/dispatcher.py deleted file mode 100644 index 7427aebe5..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/dispatcher.py +++ /dev/null @@ -1,433 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from warnings import warn -import inspect -from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning -from .utils import expand_tuples -from .variadic import Variadic, isvariadic -import itertools as itl - -__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter", - "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"] - -class MDNotImplementedError(NotImplementedError): - """ A NotImplementedError for multiple dispatch """ - - -def ambiguity_warn(dispatcher, ambiguities): - """ Raise warning when ambiguity is detected - Parameters - ---------- - dispatcher : Dispatcher - The dispatcher on which the ambiguity was detected - ambiguities : set - Set of type signature pairs that are ambiguous within this dispatcher - See Also: - Dispatcher.add - warning_text - """ - warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) - - -def halt_ordering(): - """Deprecated interface to temporarily disable ordering. - """ - warn( - 'halt_ordering is deprecated, you can safely remove this call.', - DeprecationWarning, - ) - - -def restart_ordering(on_ambiguity=ambiguity_warn): - """Deprecated interface to temporarily resume ordering. - """ - warn( - 'restart_ordering is deprecated, if you would like to eagerly order' - 'the dispatchers, you should call the ``reorder()`` method on each' - ' dispatcher.', - DeprecationWarning, - ) - - -def variadic_signature_matches_iter(types, full_signature): - """Check if a set of input types matches a variadic signature. - Notes - ----- - The algorithm is as follows: - Initialize the current signature to the first in the sequence - For each type in `types`: - If the current signature is variadic - If the type matches the signature - yield True - Else - Try to get the next signature - If no signatures are left we can't possibly have a match - so yield False - Else - yield True if the type matches the current signature - Get the next signature - """ - sigiter = iter(full_signature) - sig = next(sigiter) - for typ in types: - matches = issubclass(typ, sig) - yield matches - if not isvariadic(sig): - # we're not matching a variadic argument, so move to the next - # element in the signature - sig = next(sigiter) - else: - try: - sig = next(sigiter) - except StopIteration: - assert isvariadic(sig) - yield True - else: - # We have signature items left over, so all of our arguments - # haven't matched - yield False - - -def variadic_signature_matches(types, full_signature): - # No arguments always matches a variadic signature - assert full_signature - return all(variadic_signature_matches_iter(types, full_signature)) - - -class Dispatcher(object): - """ Dispatch methods based on type signature - Use ``dispatch`` to add implementations - Examples - -------- - >>> # xdoctest: +SKIP("bad import name") - >>> from multipledispatch import dispatch - >>> @dispatch(int) - ... def f(x): - ... return x + 1 - >>> @dispatch(float) - ... def f(x): - ... return x - 1 - >>> f(3) - 4 - >>> f(3.0) - 2.0 - """ - __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' - - def __init__(self, name, doc=None): - self.name = self.__name__ = name - self.funcs = {} - self.doc = doc - - self._cache = {} - - def register(self, *types, **kwargs): - """ register dispatcher with new implementation - >>> f = Dispatcher('f') - >>> @f.register(int) - ... def inc(x): - ... return x + 1 - >>> @f.register(float) - ... def dec(x): - ... return x - 1 - >>> @f.register(list) - ... @f.register(tuple) - ... def reverse(x): - ... return x[::-1] - >>> f(1) - 2 - >>> f(1.0) - 0.0 - >>> f([1, 2, 3]) - [3, 2, 1] - """ - def _df(func): - self.add(types, func, **kwargs) # type: ignore[call-arg] - return func - return _df - - @classmethod - def get_func_params(cls, func): - if hasattr(inspect, "signature"): - sig = inspect.signature(func) - return sig.parameters.values() - - @classmethod - def get_func_annotations(cls, func): - """ get annotations of function positional parameters - """ - params = cls.get_func_params(func) - if params: - Parameter = inspect.Parameter - - params = (param for param in params - if param.kind in - (Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD)) - - annotations = tuple( - param.annotation - for param in params) - - if all(ann is not Parameter.empty for ann in annotations): - return annotations - - def add(self, signature, func): - """ Add new types/method pair to dispatcher - >>> D = Dispatcher('add') - >>> D.add((int, int), lambda x, y: x + y) - >>> D.add((float, float), lambda x, y: x + y) - >>> D(1, 2) - 3 - >>> D(1, 2.0) - Traceback (most recent call last): - ... - NotImplementedError: Could not find signature for add: - >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback - >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs - >>> # as inputs. See ``ambiguity_warn`` for an example. - """ - # Handle annotations - if not signature: - annotations = self.get_func_annotations(func) - if annotations: - signature = annotations - - # Handle union types - if any(isinstance(typ, tuple) for typ in signature): - for typs in expand_tuples(signature): - self.add(typs, func) - return - - new_signature = [] - - for index, typ in enumerate(signature, start=1): - if not isinstance(typ, (type, list)): - str_sig = ', '.join(c.__name__ if isinstance(c, type) - else str(c) for c in signature) - raise TypeError("Tried to dispatch on non-type: %s\n" - "In signature: <%s>\n" - "In function: %s" % - (typ, str_sig, self.name)) - - # handle variadic signatures - if isinstance(typ, list): - if index != len(signature): - raise TypeError( - 'Variadic signature must be the last element' - ) - - if len(typ) != 1: - raise TypeError( - 'Variadic signature must contain exactly one element. ' - 'To use a variadic union type place the desired types ' - 'inside of a tuple, e.g., [(int, str)]' - ) - new_signature.append(Variadic[typ[0]]) - else: - new_signature.append(typ) - - self.funcs[tuple(new_signature)] = func - self._cache.clear() - - try: - del self._ordering - except AttributeError: - pass - - @property - def ordering(self): - try: - return self._ordering - except AttributeError: - return self.reorder() - - def reorder(self, on_ambiguity=ambiguity_warn): - self._ordering = od = ordering(self.funcs) - amb = ambiguities(self.funcs) - if amb: - on_ambiguity(self, amb) - return od - - def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) - try: - func = self._cache[types] - except KeyError: - func = self.dispatch(*types) - if not func: - raise NotImplementedError( - 'Could not find signature for %s: <%s>' % - (self.name, str_signature(types))) - self._cache[types] = func - try: - return func(*args, **kwargs) - - except MDNotImplementedError: - funcs = self.dispatch_iter(*types) - next(funcs) # burn first - for func in funcs: - try: - return func(*args, **kwargs) - except MDNotImplementedError: - pass - - raise NotImplementedError( - "Matching functions for " - "%s: <%s> found, but none completed successfully" % ( - self.name, str_signature(types),),) - - def __str__(self): - return "" % self.name - __repr__ = __str__ - - def dispatch(self, *types): - """Deterimine appropriate implementation for this type signature - This method is internal. Users should call this object as a function. - Implementation resolution occurs within the ``__call__`` method. - >>> # xdoctest: +SKIP - >>> from multipledispatch import dispatch - >>> @dispatch(int) - ... def inc(x): - ... return x + 1 - >>> implementation = inc.dispatch(int) - >>> implementation(3) - 4 - >>> print(inc.dispatch(float)) - None - See Also: - ``multipledispatch.conflict`` - module to determine resolution order - """ - - if types in self.funcs: - return self.funcs[types] - - try: - return next(self.dispatch_iter(*types)) - except StopIteration: - return None - - def dispatch_iter(self, *types): - - n = len(types) - for signature in self.ordering: - if len(signature) == n and all(map(issubclass, types, signature)): - result = self.funcs[signature] - yield result - elif len(signature) and isvariadic(signature[-1]): - if variadic_signature_matches(types, signature): - result = self.funcs[signature] - yield result - - def resolve(self, types): - """ Deterimine appropriate implementation for this type signature - .. deprecated:: 0.4.4 - Use ``dispatch(*types)`` instead - """ - warn("resolve() is deprecated, use dispatch(*types)", - DeprecationWarning) - - return self.dispatch(*types) - - def __getstate__(self): - return {'name': self.name, - 'funcs': self.funcs} - - def __setstate__(self, d): - self.name = d['name'] - self.funcs = d['funcs'] - self._ordering = ordering(self.funcs) - self._cache = {} - - @property - def __doc__(self): - docs = ["Multiply dispatched method: %s" % self.name] - - if self.doc: - docs.append(self.doc) - - other = [] - for sig in self.ordering[::-1]: - func = self.funcs[sig] - if func.__doc__: - s = 'Inputs: <%s>\n' % str_signature(sig) - s += '-' * len(s) + '\n' - s += func.__doc__.strip() - docs.append(s) - else: - other.append(str_signature(sig)) - - if other: - docs.append('Other signatures:\n ' + '\n '.join(other)) - - return '\n\n'.join(docs) - - def _help(self, *args): - return self.dispatch(*map(type, args)).__doc__ - - def help(self, *args, **kwargs): - """ Print docstring for the function corresponding to inputs """ - print(self._help(*args)) - - def _source(self, *args): - func = self.dispatch(*map(type, args)) - if not func: - raise TypeError("No function found") - return source(func) - - def source(self, *args, **kwargs): - """ Print source code for the function corresponding to inputs """ - print(self._source(*args)) - - -def source(func): - s = 'File: %s\n\n' % inspect.getsourcefile(func) - s = s + inspect.getsource(func) - return s - - -class MethodDispatcher(Dispatcher): - """ Dispatch methods based on type signature - See Also: - Dispatcher - """ - __slots__ = ('obj', 'cls') - - @classmethod - def get_func_params(cls, func): - if hasattr(inspect, "signature"): - sig = inspect.signature(func) - return itl.islice(sig.parameters.values(), 1, None) - - def __get__(self, instance, owner): - self.obj = instance - self.cls = owner - return self - - def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) - func = self.dispatch(*types) - if not func: - raise NotImplementedError('Could not find signature for %s: <%s>' % - (self.name, str_signature(types))) - return func(self.obj, *args, **kwargs) - - -def str_signature(sig): - """ String representation of type signature - >>> str_signature((int, float)) - 'int, float' - """ - return ', '.join(cls.__name__ for cls in sig) - - -def warning_text(name, amb): - """ The text for ambiguity warnings """ - text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) - text += "The following signatures may result in ambiguous behavior:\n" - for pair in amb: - text += "\t" + \ - ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" - text += "\n\nConsider making the following additions:\n\n" - text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) - + ')\ndef %s(...)' % name for s in amb]) - return text diff --git a/pippy/fx/experimental/unification/multipledispatch/utils.py b/pippy/fx/experimental/unification/multipledispatch/utils.py deleted file mode 100644 index 3e427d2f4..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/utils.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from collections import OrderedDict - -__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] - -def raises(err, lamda): - try: - lamda() - return False - except err: - return True - - -def expand_tuples(L): - """ - >>> expand_tuples([1, (2, 3)]) - [(1, 2), (1, 3)] - >>> expand_tuples([1, 2]) - [(1, 2)] - """ - if not L: - return [()] - elif not isinstance(L[0], tuple): - rest = expand_tuples(L[1:]) - return [(L[0],) + t for t in rest] - else: - rest = expand_tuples(L[1:]) - return [(item,) + t for t in rest for item in L[0]] - - -# Taken from theano/theano/gof/sched.py -# Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) - inputs: - edges - a dict of the form {a: {b, c}} where b and c depend on a - outputs: - L - an ordered list of nodes that satisfy the dependencies of edges - >>> _toposort({1: (2, 3), 2: (3, )}) - [1, 2, 3] - >>> # Closely follows the wikipedia page [2] - >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", - >>> # Communications of the ACM - >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms - """ - incoming_edges = reverse_dict(edges) - incoming_edges = OrderedDict((k, set(val)) - for k, val in incoming_edges.items()) - S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) - L = [] - - while S: - n, _ = S.popitem() - L.append(n) - for m in edges.get(n, ()): - assert n in incoming_edges[m] - incoming_edges[m].remove(n) - if not incoming_edges[m]: - S[m] = None - if any(incoming_edges.get(v, None) for v in edges): - raise ValueError("Input has cycles") - return L - - -def reverse_dict(d): - """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} - >>> reverse_dict(d) # doctest: +SKIP - {1: ('a',), 2: ('a', 'b'), 3: ('b',)} - :note: dict order are not deterministic. As we iterate on the - input dict, it make the output of this function depend on the - dict order. So this function output order should be considered - as undeterministic. - """ - result = OrderedDict() # type: ignore[var-annotated] - for key in d: - for val in d[key]: - result[val] = result.get(val, tuple()) + (key, ) - return result - - -# Taken from toolz -# Avoids licensing issues because this version was authored by Matthew Rocklin -def groupby(func, seq): - """ Group a collection by a key function - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] - >>> groupby(len, names) # doctest: +SKIP - {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} - >>> iseven = lambda x: x % 2 == 0 - >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP - {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} - See Also: - ``countby`` - """ - - d = OrderedDict() # type: ignore[var-annotated] - for item in seq: - key = func(item) - if key not in d: - d[key] = list() - d[key].append(item) - return d - - -def typename(type): - """Get the name of `type`. - Parameters - ---------- - type : Union[Type, Tuple[Type]] - Returns - ------- - str - The name of `type` or a tuple of the names of the types in `type`. - Examples - -------- - >>> typename(int) - 'int' - >>> typename((int, float)) - '(int, float)' - """ - try: - return type.__name__ - except AttributeError: - if len(type) == 1: - return typename(*type) - return '(%s)' % ', '.join(map(typename, type)) diff --git a/pippy/fx/experimental/unification/multipledispatch/variadic.py b/pippy/fx/experimental/unification/multipledispatch/variadic.py deleted file mode 100644 index 5802302ee..000000000 --- a/pippy/fx/experimental/unification/multipledispatch/variadic.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import six - -from .utils import typename - -__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] - -class VariadicSignatureType(type): - # checking if subclass is a subclass of self - def __subclasscheck__(cls, subclass): - other_type = (subclass.variadic_type if isvariadic(subclass) - else (subclass,)) - return subclass is cls or all( - issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined] - ) - - def __eq__(cls, other): - """ - Return True if other has the same variadic type - Parameters - ---------- - other : object (type) - The object (type) to check - Returns - ------- - bool - Whether or not `other` is equal to `self` - """ - return (isvariadic(other) and - set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined] - - def __hash__(cls): - return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] - - -def isvariadic(obj): - """Check whether the type `obj` is variadic. - Parameters - ---------- - obj : type - The type to check - Returns - ------- - bool - Whether or not `obj` is variadic - Examples - -------- - >>> isvariadic(int) - False - >>> isvariadic(Variadic[int]) - True - """ - return isinstance(obj, VariadicSignatureType) - - -class VariadicSignatureMeta(type): - """A metaclass that overrides ``__getitem__`` on the class. This is used to - generate a new type for Variadic signatures. See the Variadic class for - examples of how this behaves. - """ - def __getitem__(cls, variadic_type): - if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): - raise ValueError("Variadic types must be type or tuple of types" - " (Variadic[int] or Variadic[(int, float)]") - - if not isinstance(variadic_type, tuple): - variadic_type = variadic_type, - return VariadicSignatureType( - 'Variadic[%s]' % typename(variadic_type), - (), - dict(variadic_type=variadic_type, __slots__=()) - ) - - -class Variadic(six.with_metaclass(VariadicSignatureMeta)): - """A class whose getitem method can be used to generate a new type - representing a specific variadic signature. - Examples - -------- - >>> Variadic[int] # any number of int arguments - >>> # xdoctest: +SKIP - - >>> Variadic[(int, str)] # any number of one of int or str arguments - - >>> issubclass(int, Variadic[int]) - True - >>> issubclass(int, Variadic[(int, str)]) - True - >>> issubclass(str, Variadic[(int, str)]) - True - >>> issubclass(float, Variadic[(int, str)]) - False - """ diff --git a/pippy/fx/experimental/unification/unification_tools.py b/pippy/fx/experimental/unification/unification_tools.py deleted file mode 100644 index d2ddc1df3..000000000 --- a/pippy/fx/experimental/unification/unification_tools.py +++ /dev/null @@ -1,393 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import collections -import operator -from functools import reduce -from collections.abc import Mapping - -__all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap', - 'valfilter', 'keyfilter', 'itemfilter', - 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') - -def _get_factory(f, kwargs): - factory = kwargs.pop('factory', dict) - if kwargs: - raise TypeError("{}() got an unexpected keyword argument " - "'{}'".format(f.__name__, kwargs.popitem()[0])) - return factory - - -def merge(*dicts, **kwargs): - """ Merge a collection of dictionaries - - >>> merge({1: 'one'}, {2: 'two'}) - {1: 'one', 2: 'two'} - - Later dictionaries have precedence - - >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) - {1: 2, 3: 3, 4: 4} - - See Also: - merge_with - """ - if len(dicts) == 1 and not isinstance(dicts[0], Mapping): - dicts = dicts[0] - factory = _get_factory(merge, kwargs) - - rv = factory() - for d in dicts: - rv.update(d) - return rv - - -def merge_with(func, *dicts, **kwargs): - """ Merge dictionaries and apply function to combined values - - A key may occur in more than one dict, and all values mapped from the key - will be passed to the function as a list, such as func([val1, val2, ...]). - - >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) - {1: 11, 2: 22} - - >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP - {1: 1, 2: 2, 3: 30} - - See Also: - merge - """ - if len(dicts) == 1 and not isinstance(dicts[0], Mapping): - dicts = dicts[0] - factory = _get_factory(merge_with, kwargs) - - result = factory() - for d in dicts: - for k, v in d.items(): - if k not in result: - result[k] = [v] - else: - result[k].append(v) - return valmap(func, result, factory) - - -def valmap(func, d, factory=dict): - """ Apply function to values of dictionary - - >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} - >>> valmap(sum, bills) # doctest: +SKIP - {'Alice': 65, 'Bob': 45} - - See Also: - keymap - itemmap - """ - rv = factory() - rv.update(zip(d.keys(), map(func, d.values()))) - return rv - - -def keymap(func, d, factory=dict): - """ Apply function to keys of dictionary - - >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} - >>> keymap(str.lower, bills) # doctest: +SKIP - {'alice': [20, 15, 30], 'bob': [10, 35]} - - See Also: - valmap - itemmap - """ - rv = factory() - rv.update(zip(map(func, d.keys()), d.values())) - return rv - - -def itemmap(func, d, factory=dict): - """ Apply function to items of dictionary - - >>> accountids = {"Alice": 10, "Bob": 20} - >>> itemmap(reversed, accountids) # doctest: +SKIP - {10: "Alice", 20: "Bob"} - - See Also: - keymap - valmap - """ - rv = factory() - rv.update(map(func, d.items())) - return rv - - -def valfilter(predicate, d, factory=dict): - """ Filter items in dictionary by value - - >>> iseven = lambda x: x % 2 == 0 - >>> d = {1: 2, 2: 3, 3: 4, 4: 5} - >>> valfilter(iseven, d) - {1: 2, 3: 4} - - See Also: - keyfilter - itemfilter - valmap - """ - rv = factory() - for k, v in d.items(): - if predicate(v): - rv[k] = v - return rv - - -def keyfilter(predicate, d, factory=dict): - """ Filter items in dictionary by key - - >>> iseven = lambda x: x % 2 == 0 - >>> d = {1: 2, 2: 3, 3: 4, 4: 5} - >>> keyfilter(iseven, d) - {2: 3, 4: 5} - - See Also: - valfilter - itemfilter - keymap - """ - rv = factory() - for k, v in d.items(): - if predicate(k): - rv[k] = v - return rv - - -def itemfilter(predicate, d, factory=dict): - """ Filter items in dictionary by item - - >>> def isvalid(item): - ... k, v = item - ... return k % 2 == 0 and v < 4 - - >>> d = {1: 2, 2: 3, 3: 4, 4: 5} - >>> itemfilter(isvalid, d) - {2: 3} - - See Also: - keyfilter - valfilter - itemmap - """ - rv = factory() - for item in d.items(): - if predicate(item): - k, v = item - rv[k] = v - return rv - - -def assoc(d, key, value, factory=dict): - """ Return a new dict with new key value pair - - New dict has d[key] set to value. Does not modify the initial dictionary. - - >>> assoc({'x': 1}, 'x', 2) - {'x': 2} - >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP - {'x': 1, 'y': 3} - """ - d2 = factory() - d2.update(d) - d2[key] = value - return d2 - - -def dissoc(d, *keys, **kwargs): - """ Return a new dict with the given key(s) removed. - - New dict has d[key] deleted for each supplied key. - Does not modify the initial dictionary. - - >>> dissoc({'x': 1, 'y': 2}, 'y') - {'x': 1} - >>> dissoc({'x': 1, 'y': 2}, 'y', 'x') - {} - >>> dissoc({'x': 1}, 'y') # Ignores missing keys - {'x': 1} - """ - factory = _get_factory(dissoc, kwargs) - d2 = factory() - - if len(keys) < len(d) * .6: - d2.update(d) - for key in keys: - if key in d2: - del d2[key] - else: - remaining = set(d) - remaining.difference_update(keys) - for k in remaining: - d2[k] = d[k] - return d2 - - -def assoc_in(d, keys, value, factory=dict): - """ Return a new dict with new, potentially nested, key value pair - - >>> purchase = {'name': 'Alice', - ... 'order': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP - {'credit card': '5555-1234-1234-1234', - 'name': 'Alice', - 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} - """ - return update_in(d, keys, lambda x: value, value, factory) - - -def update_in(d, keys, func, default=None, factory=dict): - """ Update value in a (potentially) nested dictionary - - inputs: - d - dictionary on which to operate - keys - list or tuple giving the location of the value to be changed in d - func - function to operate on that value - - If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the - original dictionary with v replaced by func(v), but does not mutate the - original dictionary. - - If k0 is not a key in d, update_in creates nested dictionaries to the depth - specified by the keys, with the innermost value set to func(default). - - >>> inc = lambda x: x + 1 - >>> update_in({'a': 0}, ['a'], inc) - {'a': 1} - - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP - {'credit card': '5555-1234-1234-1234', - 'name': 'Alice', - 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} - - >>> # updating a value when k0 is not in d - >>> update_in({}, [1, 2, 3], str, default="bar") - {1: {2: {3: 'bar'}}} - >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) - {1: 'foo', 2: {3: {4: 1}}} - """ - ks = iter(keys) - k = next(ks) - - rv = inner = factory() - rv.update(d) - - for key in ks: - if k in d: - d = d[k] - dtemp = factory() - dtemp.update(d) - else: - d = dtemp = factory() - - inner[k] = inner = dtemp - k = key - - if k in d: - inner[k] = func(d[k]) - else: - inner[k] = func(default) - return rv - - -def get_in(keys, coll, default=None, no_default=False): - """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. - - If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless - ``no_default`` is specified, then it raises KeyError or IndexError. - - ``get_in`` is a generalization of ``operator.getitem`` for nested data - structures such as dictionaries and lists. - - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> get_in(['purchase', 'items', 0], transaction) - 'Apple' - >>> get_in(['name'], transaction) - 'Alice' - >>> get_in(['purchase', 'total'], transaction) - >>> get_in(['purchase', 'items', 'apple'], transaction) - >>> get_in(['purchase', 'items', 10], transaction) - >>> get_in(['purchase', 'total'], transaction, 0) - 0 - >>> get_in(['y'], {}, no_default=True) - Traceback (most recent call last): - ... - KeyError: 'y' - - See Also: - itertoolz.get - operator.getitem - """ - try: - return reduce(operator.getitem, keys, coll) - except (KeyError, IndexError, TypeError): - if no_default: - raise - return default - -def getter(index): - if isinstance(index, list): - if len(index) == 1: - index = index[0] - return lambda x: (x[index],) - elif index: - return operator.itemgetter(*index) - else: - return lambda x: () - else: - return operator.itemgetter(index) - -def groupby(key, seq): - """ Group a collection by a key function - - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] - >>> groupby(len, names) # doctest: +SKIP - {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} - - >>> iseven = lambda x: x % 2 == 0 - >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP - {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} - - Non-callable keys imply grouping on a member. - - >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'}, - ... {'name': 'Bob', 'gender': 'M'}, - ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP - {'F': [{'gender': 'F', 'name': 'Alice'}], - 'M': [{'gender': 'M', 'name': 'Bob'}, - {'gender': 'M', 'name': 'Charlie'}]} - - Not to be confused with ``itertools.groupby`` - - See Also: - countby - """ - if not callable(key): - key = getter(key) - d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] - for item in seq: - d[key(item)](item) - rv = {} - for k, v in d.items(): - rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] - return rv - -def first(seq): - """ The first element in a sequence - - >>> first('ABC') - 'A' - """ - return next(iter(seq)) diff --git a/pippy/fx/experimental/unification/utils.py b/pippy/fx/experimental/unification/utils.py deleted file mode 100644 index a54ad565d..000000000 --- a/pippy/fx/experimental/unification/utils.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] -def hashable(x): - try: - hash(x) - return True - except TypeError: - return False - - -def transitive_get(key, d): - """ Transitive dict.get - >>> d = {1: 2, 2: 3, 3: 4} - >>> d.get(1) - 2 - >>> transitive_get(1, d) - 4 - """ - while hashable(key) and key in d: - key = d[key] - return key - - -def raises(err, lamda): - try: - lamda() - return False - except err: - return True - - -# Taken from theano/theano/gof/sched.py -# Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) - inputs: - edges - a dict of the form {a: {b, c}} where b and c depend on a - outputs: - L - an ordered list of nodes that satisfy the dependencies of edges - >>> _toposort({1: (2, 3), 2: (3, )}) - >>> # xdoctest: +SKIP - [1, 2, 3] - Closely follows the wikipedia page [2] - [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", - Communications of the ACM - [2] http://en.wikipedia.org/wiki/Toposort#Algorithms - """ - incoming_edges = reverse_dict(edges) - incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items()) - S = set((v for v in edges if v not in incoming_edges)) - L = [] - - while S: - n = S.pop() - L.append(n) - for m in edges.get(n, ()): - assert n in incoming_edges[m] - incoming_edges[m].remove(n) - if not incoming_edges[m]: - S.add(m) - if any(incoming_edges.get(v, None) for v in edges): - raise ValueError("Input has cycles") - return L - - -def reverse_dict(d): - """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} - >>> reverse_dict(d) # doctest: +SKIP - {1: ('a',), 2: ('a', 'b'), 3: ('b',)} - :note: dict order are not deterministic. As we iterate on the - input dict, it make the output of this function depend on the - dict order. So this function output order should be considered - as undeterministic. - """ - result = {} # type: ignore[var-annotated] - for key in d: - for val in d[key]: - result[val] = result.get(val, tuple()) + (key, ) - return result - - -def xfail(func): - try: - func() - raise Exception("XFailed test passed") # pragma:nocover - except Exception: - pass - - -def freeze(d): - """ Freeze container to hashable form - >>> freeze(1) - 1 - >>> freeze([1, 2]) - (1, 2) - >>> freeze({1: 2}) # doctest: +SKIP - frozenset([(1, 2)]) - """ - if isinstance(d, dict): - return frozenset(map(freeze, d.items())) - if isinstance(d, set): - return frozenset(map(freeze, d)) - if isinstance(d, (tuple, list)): - return tuple(map(freeze, d)) - return d diff --git a/pippy/fx/experimental/unification/variable.py b/pippy/fx/experimental/unification/variable.py deleted file mode 100644 index e836d7653..000000000 --- a/pippy/fx/experimental/unification/variable.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from contextlib import contextmanager -from .utils import hashable -from .dispatch import dispatch - -_global_logic_variables = set() # type: ignore[var-annotated] -_glv = _global_logic_variables - - -class Var(object): - """ Logic Variable """ - - _id = 1 - - def __new__(cls, *token): - if len(token) == 0: - token = "_%s" % Var._id # type: ignore[assignment] - Var._id += 1 - elif len(token) == 1: - token = token[0] - - obj = object.__new__(cls) - obj.token = token # type: ignore[attr-defined] - return obj - - def __str__(self): - return "~" + str(self.token) # type: ignore[attr-defined] - __repr__ = __str__ - - def __eq__(self, other): - return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined] - - def __hash__(self): - return hash((type(self), self.token)) # type: ignore[attr-defined] - - -def var(): - return lambda *args: Var(*args) - -def vars(): - return lambda n: [var() for i in range(n)] - - -@dispatch(Var) -def isvar(v): - return True - -isvar - -@dispatch(object) # type: ignore[no-redef] -def isvar(o): - return not not _glv and hashable(o) and o in _glv - - -@contextmanager -def variables(*variables): - """ Context manager for logic variables - >>> from __future__ import with_statement - >>> with variables(1): - ... print(isvar(1)) - True - >>> print(isvar(1)) - False - >>> # xdoctest: +SKIP("undefined vars") - >>> # Normal approach - >>> from unification import unify - >>> x = var('x') - >>> unify(x, 1) - {~x: 1} - >>> # Context Manager approach - >>> with variables('x'): - ... print(unify('x', 1)) - {'x': 1} - """ - old_global_logic_variables = _global_logic_variables.copy() - _global_logic_variables.update(set(variables)) - try: - yield - finally: - _global_logic_variables.clear() - _global_logic_variables.update(old_global_logic_variables) diff --git a/pippy/fx/experimental/unify_refinements.py b/pippy/fx/experimental/unify_refinements.py deleted file mode 100644 index 07f9f4aca..000000000 --- a/pippy/fx/experimental/unify_refinements.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.graph_gradual_typechecker import Refine -from pippy.fx.tensor_type import TensorType -from pippy.fx.experimental.unification import Var, unify # type: ignore[attr-defined] - - -def infer_symbolic_types_single_pass(traced): - """ - Calls our symbolic inferencer once. - """ - r = Refine(traced) - r.refine() - mgu = unify_eq(r.constraints) - substitute_all_types(traced.graph, mgu) - -def infer_symbolic_types(traced): - """ - Calls our symbolic inferencer twice. - This is useful when one pass is not enough - to infer all the information such as the case - for braodcasting. - """ - r = Refine(traced) - r.refine() - mgu = unify_eq(r.constraints) - substitute_all_types(traced.graph, mgu) - - r = Refine(traced) - r.refine() - mgu = unify_eq(r.constraints) - substitute_all_types(traced.graph, mgu) - - r.symbolic_relations() - -def convert_eq(list_of_eq): - """ - Convert equality constraints in the right format - to be used by unification library. - """ - lhs = [] - rhs = [] - for eq in list_of_eq: - lhs.append(eq.lhs) - rhs.append(eq.rhs) - return tuple(lhs), tuple(rhs) - - -def unify_eq(list_of_eq): - """ - Apply unification to a set of - equality constraints - """ - lhs, rhs = convert_eq(list_of_eq) - return unify(lhs, rhs) - - -def substitute_solution_one_type(mapping, t): - """ - Apply the most general unifier to a type - """ - if isinstance(t, Var): - if t in mapping.keys(): - return mapping[t] - else: - return t - - elif isinstance(t, TensorType): - new_type = [] - for typ in t.__args__: - if typ in mapping.keys(): - new_type.append(mapping[typ]) - else: - new_type.append(typ) - return TensorType(tuple(new_type)) - - elif isinstance(t, list): - new_type = [] - for typ in t: - new_type.append(substitute_solution_one_type(mapping, typ)) - return new_type - - elif isinstance(t, tuple): - new_type = [] - for typ in t: - new_type.append(substitute_solution_one_type(mapping, typ)) - return tuple(new_type) - - else: - return t - - -def substitute_all_types(graph, mapping): - """ - Apply the most general unifier to all types in a graph - till reaching a fixed point. If the input and output graph - are the same, we converge. - """ - flag = True - while flag: - flag = False - for k in mapping: - old_mapping_val = mapping[k] - if mapping[k] in mapping.keys(): - new_key = mapping[k] - mapping[k] = mapping[new_key] - if old_mapping_val != mapping[k]: - flag = True - - for n in graph.nodes: - n.type = substitute_solution_one_type(mapping, n.type) - -def check_for_type_equality(g1, g2): - """ - A check equality to be used in fixed points. - We do not use graph equality but instead type - equality. - """ - for n, m in zip(g1.nodes, g2.nodes): - if n.type != m.type: - return False - return True diff --git a/pippy/fx/graph.py b/pippy/fx/graph.py deleted file mode 100644 index 2f30a750b..000000000 --- a/pippy/fx/graph.py +++ /dev/null @@ -1,1507 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import builtins -import contextlib -import copy -import inspect -import keyword -import math -import re -import warnings -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type - -import torch -import torch.utils._pytree as pytree - -import pippy -import pippy.fx -from . import _pytree as fx_pytree -from ._compatibility import compatibility -from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name - -__all__ = ["PythonCode", "CodeGen", "Graph"] - -if TYPE_CHECKING: - from .graph_module import GraphModule # noqa: F401 - from ._symbolic_trace import Tracer # noqa: F401 - - -# Mapping of builtins to their `typing` equivalent. -_origin_type_map = { - list: List, - dict: Dict, - set: Set, - frozenset: FrozenSet, - tuple: Tuple, -} - - -# Signature for functions thattransforms the body (`list[str]`) of the -# generated code -TransformCodeFunc = Callable[[List[str]], List[str]] - - -class _CustomBuiltin(NamedTuple): - """Additional objs that we add to every graph's globals. - - The repr() for some standard library objects is not valid Python code without - an import. For common objects of this sort, we bundle them in the globals of - every FX graph. - """ - # How to import this object from the standard library. - import_str: str - # The actual object, produced from that import string. - obj: Any - -_custom_builtins: Dict[str, _CustomBuiltin] = {} - - -def _register_custom_builtin(name: str, import_str: str, obj: Any): - _custom_builtins[name] = _CustomBuiltin(import_str, obj) - - -_register_custom_builtin('inf', 'from math import inf', math.inf) -_register_custom_builtin('nan', 'from math import nan', math.nan) -_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) -_register_custom_builtin('torch', 'import torch', torch) -_register_custom_builtin('pippy', 'import pippy', pippy) -_register_custom_builtin('device', 'from torch import device', torch.device) -_register_custom_builtin('fx_pytree', 'import pippy.fx._pytree as fx_pytree', fx_pytree) -_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) - - -def _is_magic(x: str) -> bool: - return x.startswith('__') and x.endswith('__') - - -def _snake_case(s: str) -> str: - """ - Transforms the given string ``s`` to a Python-style variable name - - Examples: - ``mod.snake_case`` -> ``mod.snake_case`` - ``mod.pascalCase``-> ``mod.pascal_case`` - ``mod.ALL_CAPS`` -> ``mod.all_caps`` - """ - chars = [] - prev_lower = False - for c in s: - if prev_lower and c.isupper(): - chars.append('_') - chars.append(c.lower()) - prev_lower = c.islower() - return ''.join(chars) - - -def _is_from_torch(obj: Any) -> bool: - module_name = getattr(obj, '__module__', None) - if module_name is not None: - if module_name.startswith('pippy.fx'): - return True - - base_module = module_name.partition('.')[0] - return base_module == 'torch' - - name = getattr(obj, '__name__', None) - # exclude torch because torch.torch.torch.torch works. idk mang - if name is not None and name != 'torch': - for guess in [torch, torch.nn.functional]: - if getattr(guess, name, None) is obj: - return True - - return False - - -class _Namespace: - """A context for associating names uniquely with objects. - - The following invariants are enforced: - - Each object gets a single name. - - Each name is unique within a given namespace. - - Names generated do not shadow builtins, unless the object is indeed that builtin. - """ - def __init__(self): - self._obj_to_name: Dict[Any, str] = {} - self._unassociated_names = set() - self._used_names: Dict[str, int] = {} - - self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') - self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") - - def create_name(self, candidate: str, obj: Optional[Any]) -> str: - """Create a unique name. - - Arguments: - candidate: used as the basis for the unique name, relevant to the user. - obj: If not None, an object that will be associated with the unique name. - """ - if obj is not None and obj in self._obj_to_name: - return self._obj_to_name[obj] - - # delete all characters that are illegal in a Python identifier - candidate = self._illegal_char_regex.sub('_', candidate) - - if candidate[0].isdigit(): - candidate = f'_{candidate}' - - match = self._name_suffix_regex.match(candidate) - if match is None: - base = candidate - num = None - else: - base, num_str = match.group(1, 2) - num = int(num_str) - - candidate = base if num is None else f'{base}_{num}' - num = num if num else 0 - - while candidate in self._used_names or self._is_illegal_name(candidate, obj): - num += 1 - candidate = f'{base}_{num}' - - self._used_names.setdefault(candidate, 0) - if obj is None: - self._unassociated_names.add(candidate) - else: - self._obj_to_name[obj] = candidate - return candidate - - def associate_name_with_obj(self, name: str, obj: Any): - """Associate a unique name with an object. - - Neither `name` nor `obj` should be associated already. - """ - assert obj not in self._obj_to_name - assert name in self._unassociated_names - self._obj_to_name[obj] = name - self._unassociated_names.remove(name) - - def _is_illegal_name(self, name: str, obj: Any) -> bool: - # 1. keywords are never allowed as names. - if name in keyword.kwlist: - return True - - # 2. Can't shadow a builtin name, unless you *are* that builtin. - if name in builtins.__dict__: - return obj is not builtins.__dict__[name] - - # 3. Can't shadow our custom builtins either - if name in _custom_builtins: - return obj is not _custom_builtins[name].obj - - return False - - -@compatibility(is_backward_compatible=True) -@dataclass -class PythonCode: - """ - Represents all the information necessary to exec or save a graph as Python code. - """ - # Python source code for the forward function definition. - src: str - # Values in global scope during exection of `src_def`. - globals: Dict[str, Any] - - -def _format_target(base: str, target: str) -> str: - elems = target.split('.') - r = base - for e in elems: - if not e.isidentifier(): - r = f'getattr({r}, "{e}")' - else: - r = f'{r}.{e}' - return r - -class _InsertPoint: - def __init__(self, graph, new_insert): - self.graph = graph - self.orig_insert, graph._insert = graph._insert, new_insert - - def __enter__(self): - pass - - def __exit__(self, type, value, tb): - self.graph._insert = self.orig_insert - -class _node_list: - def __init__(self, graph: 'Graph', direction: str = '_next'): - assert direction in ['_next', '_prev'] - self.graph = graph - self.direction = direction - - def __len__(self): - return self.graph._len - - def __iter__(self): - root, direction = self.graph._root, self.direction - cur = getattr(root, direction) - while cur is not root: - if not cur._erased: - yield cur - cur = getattr(cur, direction) - - def __reversed__(self): - return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') - -class _PyTreeInfo(NamedTuple): - """ - Contains extra info stored when we're using Pytrees - """ - orig_args: List[str] - in_spec: pytree.TreeSpec - out_spec: Optional[pytree.TreeSpec] - -@compatibility(is_backward_compatible=False) -class CodeGen(object): - def __init__(self): - self._body_transformer: Optional[TransformCodeFunc] = None - - def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: - """ - Given the free variables and a return annotation, generates the beginning of the FX function. - By default, `gen_fn_def(['a', 'b'], '') == 'def forward(a, b):'` - """ - # If the original function didn't have self as its first argument, we - # would have added it. - if len(free_vars) == 0 or free_vars[0] != 'self': - free_vars.insert(0, 'self') - return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:" - - def generate_output(self, output_args: Argument) -> str: - """ - Given the output arguments, generates the return statement of the FX function. - Note: The returned statement should not be indented. - """ - return f'return {repr(output_args)}' - - def process_inputs(self, *args: Any) -> Any: - """ - Transforms the inputs so that the graph can take them as arguments, as - non-default codegen may result in the inputs to the function being - different from the inputs to the graph. - - If the graph was directly runnable, this invariant should hold true - `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` - """ - return args - - def process_outputs(self, outputs: Any) -> Any: - """ - Transforms the outputs of the graph to be identical to the codegen. - - See ``process_inputs`` for more details. - """ - return outputs - - def additional_globals(self) -> List[Tuple[str, Any]]: - """ - If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. - For example, return ['List', typing.List] if you need ``List`` in the global context. - """ - return [] - - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} - - # Wrap string in list to pass by reference - maybe_return_annotation : List[str] = [''] - - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. - - We call this for names that reference objects external to the - Graph, like functions or types. - - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj - return global_name - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) - - def type_repr(o : Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' - - typename = _type_repr(o) - - if hasattr(o, '__origin__'): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) - - if hasattr(o, '__args__'): - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] - - if len(args) == 0: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 - return origin_typename - - return f'{origin_typename}[{",".join(args)}]' - else: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ - return origin_typename - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) - if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use : Dict[Node, Node] = {} - user_to_last_uses : Dict[Node, List[Node]] = {} - - def register_last_uses(n : Node, user : Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - def delete_unused_values(user : Node): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == 'placeholder': - return - if user.op == 'output': - body.append('\n') - return - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') - else: - body.append('\n') - - prev_stacktrace = None - - def append_stacktrace_summary(node : Node): - """ - Append a summary of the stacktrace to the generated code. This is - useful for debugging. - """ - nonlocal prev_stacktrace - pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") - - if node.op not in {'placeholder', 'output'}: - if node.stack_trace: - if node.stack_trace != prev_stacktrace: - prev_stacktrace = node.stack_trace - - lines = node.stack_trace.strip().split('\n') - idx = 0 - context_lines = [] - while idx < len(lines): - line = lines[idx].strip() - if line.startswith('File '): - break - context_lines.append(line) - idx += 1 - - summary_lines = [] - if context_lines: - summary_lines.append(', '.join(context_lines)) - - if idx + 1 < len(lines): - matches = pattern.match(lines[idx].strip()) - if matches: - file = matches.group(1) - lineno = matches.group(2) - lineage = f'File: {file}:{lineno}' - summary_lines.append(lineage) - - code = f"code: {lines[idx + 1].strip()}" - summary_lines.append(code) - - summary_str = ', '.join(summary_lines) - body.append(f'\n# {summary_str}\n') - elif prev_stacktrace != "": - prev_stacktrace = "" - body.append('\n# No stacktrace found for following nodes \n') - - def emit_node(node : Node): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': - assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') - if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') - return - elif node.op == 'call_method': - assert isinstance(node.target, str) - body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') - return - elif node.op == 'call_function': - assert callable(node.target) - # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: - assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') - return - - # pretty print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') - return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') - return - body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): - wrapped_fns.setdefault(global_name) - return - elif node.op == 'call_module': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') - return - elif node.op == 'get_attr': - assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') - return - elif node.op == 'output': - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f'node: {node.op} {node.target}') - - for node in nodes: - # NOTE: emit_node does not emit a string with newline. It depends - # on delete_unused_values to append one - if verbose: - append_stacktrace_summary(node) - emit_node(node) - delete_unused_values(node) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append('pass\n') - - - - if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', pippy.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) - else: - wrap_stmts = '' - - if self._body_transformer: - body = self._body_transformer(body) - - for name, value in self.additional_globals(): - add_global(name, value) - - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) - fn_code = f""" -{wrap_stmts} - -{prologue} -{code}""" - return PythonCode(fn_code, globals_) - - -# Ideally, we'd like to refactor all of the pytree logic into this codegen -# class. Unfortunately, there are 3 areas we currently need extra logic in FX. -# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. -# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. -# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. -# 3. We currently can't register the pytree imports with `add_global` - not sure why. -class _PyTreeCodeGen(CodeGen): - def __init__(self, pytree_info: _PyTreeInfo): - super().__init__() - self.pytree_info: _PyTreeInfo = pytree_info - - def process_inputs(self, *inputs: Any) -> Any: - flat_args, _ = pytree.tree_flatten(inputs) - return flat_args - - def process_outputs(self, out: Any) -> Any: - if self.pytree_info is None: - return out - if not isinstance(out, list): - out = [out] - assert(self.pytree_info.out_spec is not None) - return pytree.tree_unflatten(out, self.pytree_info.out_spec) - - def gen_fn_def(self, free_vars, maybe_return_annotation): - if self.pytree_info is None: - return super().gen_fn_def(free_vars, maybe_return_annotation) - function_args = self.pytree_info.orig_args - has_orig_self = (function_args[0] == 'self') - if has_orig_self: - free_vars.insert(0, 'self') - function_definition = super().gen_fn_def(function_args[:], maybe_return_annotation) - if len(free_vars) > 0: # pytree has placeholders in it - function_definition += f""" - {', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(function_args)}], self._in_spec)""" - return function_definition - - def generate_output(self, output_args): - if self.pytree_info: - return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' - else: - return super().generate_output(output_args) - -@compatibility(is_backward_compatible=True) -class Graph: - """ - ``Graph`` is the main data structure used in the FX Intermediate Representation. - It consists of a series of ``Node`` s, each representing callsites (or other - syntactic constructs). The list of ``Node`` s, taken together, constitute a - valid Python function. - - For example, the following code - - .. code-block:: python - - import torch - import pippy.fx - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - Will produce the following Graph:: - - print(gm.graph) - - .. code-block:: text - - graph(x): - %linear_weight : [#users=1] = self.linear.weight - %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) - %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) - %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) - %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) - %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) - return topk_1 - - For the semantics of operations represented in the ``Graph``, please see :class:`Node`. - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, - tracer_extras: Optional[Dict[str, Any]] = None): - """ - Construct an empty Graph. - """ - self._root : Node = Node(self, '', 'root', '', (), {}) - self._used_names : Dict[str, int] = {} # base name -> number - self._insert = self._root.prepend - self._len = 0 - self._graph_namespace = _Namespace() - self._owners = 0 - self._owning_module = owning_module - self._tracer_cls = tracer_cls - self._tracer_extras = tracer_extras - self._codegen = CodeGen() - - @property - def owning_module(self): - """ - Return the module that owns this ``GraphModule``, if there is one, - ``None`` if there is no owning module or if there are multiple owning - modules. - """ - return self._owning_module - - @owning_module.setter - def owning_module(self, mod: Optional["GraphModule"]): - if mod: - self._owning_module = mod if not self._owners else None - self._owners += 1 - - @property - def nodes(self) -> _node_list: - """ - Get the list of Nodes that constitute this Graph. - - Note that this ``Node`` list representation is a doubly-linked list. Mutations - during iteration (e.g. delete a Node, add a Node) are safe. - - Returns: - - A doubly-linked list of Nodes. Note that ``reversed`` can be called on - this list to switch iteration order. - """ - return _node_list(self) - - @compatibility(is_backward_compatible=True) - def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': - """ - Copy all nodes from a given graph into ``self``. - - Args: - - g (Graph): The source graph from which to copy Nodes. - - val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping - from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed - in with values in it already to override copying of certain values. - - Returns: - - The value in ``self`` that is now equivalent to the output value in ``g``, - if ``g`` had an ``output`` node. ``None`` otherwise. - """ - for node in g.nodes: - if node in val_map: - continue - if node.op == 'output': - rv = map_arg(node.args[0], lambda n: val_map[n]) - return rv if not return_output_node else (rv, node) - val_map[node] = self.node_copy(node, lambda n : val_map[n]) - return None - - def __deepcopy__(self, memo=None) -> 'Graph': - """ - Explicitly implement __deepcopy__ to prevent excessive recursion depth - from the default implementation. This uses graph_copy to copy the nodes - in an iterative way, rather than recursive. It also populates the - memoization table to prevent unnecessary copies (e.g. references to - nodes or other parts of the Graph from a custom GraphModule implementation. - """ - memo = memo if memo else {} - g = Graph(tracer_cls=self._tracer_cls) - output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) - g._codegen = copy.deepcopy(self._codegen) - assert isinstance(output_vals, tuple) - output_val, old_output_val = output_vals - g.output(output_val, type_expr=getattr(old_output_val, 'type', None)) - return g - - @compatibility(is_backward_compatible=True) - def create_node(self, op: str, target: 'Target', - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - name: Optional[str] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Create a ``Node`` and add it to the ``Graph`` at the current insert-point. - Note that the current insert-point can be set via :meth:`Graph.inserting_before` - and :meth:`Graph.inserting_after`. - - Args: - op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', - 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are - described in the ``Graph`` docstring. - - args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. - - kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node - - name (Optional[str]): an optional string name for the ``Node``. - This will influence the name of the value assigned to in the - Python generated code. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly-created and inserted node. - """ - assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') - args = () if args is None else args - kwargs = {} if kwargs is None else kwargs - assert isinstance(args, tuple), "args must be a tuple" - assert isinstance(kwargs, dict), "kwargs must be a dict" - - candidate = name if name is not None else self._target_to_str(target) - name = self._graph_namespace.create_name(candidate, None) - n = Node(self, name, op, target, args, kwargs, type_expr) - - self._graph_namespace.associate_name_with_obj(name, n) - - self._insert(n) - self._len += 1 - return n - - @compatibility(is_backward_compatible=False) - def process_inputs(self, *args): - """ - Processes args so that they can be passed to the FX graph. - """ - return self._codegen.process_inputs(*args) - - @compatibility(is_backward_compatible=False) - def process_outputs(self, out): - return self._codegen.process_outputs(out) - - - @compatibility(is_backward_compatible=True) - def erase_node(self, to_erase : Node) -> None: - """ - Erases a ``Node`` from the ``Graph``. Throws an exception if - there are still users of that node in the ``Graph``. - - Args: - - to_erase (Node): The ``Node`` to erase from the ``Graph``. - """ - if len(to_erase.users) > 0: - raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' - f'users in the graph: {to_erase.users}!') - - to_erase._remove_from_list() - to_erase._erased = True # iterators may retain handles to erased nodes - self._len -= 1 - - # Null out this Node's argument nodes so that the Nodes referred to - # can update their ``users`` accordingly - new_args = map_arg(to_erase.args, lambda n: None) - assert isinstance(new_args, tuple) - to_erase.args = new_args - new_kwargs = map_arg(to_erase.kwargs, lambda n: None) - assert isinstance(new_kwargs, dict) - to_erase.kwargs = new_kwargs - - @compatibility(is_backward_compatible=True) - def inserting_before(self, n: Optional[Node] = None): - """Set the point at which create_node and companion methods will insert into the graph. - When used within a 'with' statement, this will temporary set the insert point and - then restore it when the with statement exits:: - - with g.inserting_before(n): - ... # inserting before node n - ... # insert point restored to what it was previously - g.inserting_before(n) # set the insert point permanently - - Args: - - n (Optional[Node]): The node before which to insert. If None this will insert before - the beginning of the entire graph. - - Returns: - A resource manager that will restore the insert point on ``__exit__``. - """ - if n is None: - return self.inserting_after(self._root) - assert n.graph == self, "Node to insert before is not in graph." - return _InsertPoint(self, n.prepend) - - @compatibility(is_backward_compatible=True) - def inserting_after(self, n: Optional[Node] = None): - """Set the point at which create_node and companion methods will insert into the graph. - When used within a 'with' statement, this will temporary set the insert point and - then restore it when the with statement exits:: - - with g.inserting_after(n): - ... # inserting after node n - ... # insert point restored to what it was previously - g.inserting_after(n) # set the insert point permanently - - Args: - - n (Optional[Node]): The node before which to insert. If None this will insert after - the beginning of the entire graph. - - Returns: - A resource manager that will restore the insert point on ``__exit__``. - """ - if n is None: - return self.inserting_before(self._root) - assert n.graph == self, "Node to insert after is not in graph." - return _InsertPoint(self, n.append) - - @compatibility(is_backward_compatible=True) - def placeholder(self, name: str, type_expr: Optional[Any] = None, - default_value : Any = inspect.Signature.empty) -> Node: - """ - Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents - a function input. - - Args: - - name (str): A name for the input value. This corresponds to the name - of the positional argument to the function this ``Graph`` represents. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. This is needed in some - cases for proper code generation (e.g. when the function is used - subsequently in TorchScript compilation). - - default_value (Any): The default value this function argument should take - on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` - should be passed as this argument to specify that the parameter does _not_ - have a default value. - - .. note:: - The same insertion point and type expression rules apply for this method - as ``Graph.create_node``. - """ - args = () if default_value is inspect.Signature.empty else (default_value,) - return self.create_node('placeholder', name, args=args, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the - fetch of an attribute from the ``Module`` hierarchy. - - Args: - - qualified_name (str): the fully-qualified name of the attribute to be retrieved. - For example, if the traced Module has a submodule named ``foo``, which has a - submodule named ``bar``, which has an attribute named ``baz``, the qualified - name ``foo.bar.baz`` should be passed as ``qualified_name``. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - - Returns: - - The newly-created and inserted ``get_attr`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as ``Graph.create_node``. - """ - def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: - module_path, _, name = qualified_name.rpartition(".") - - try: - submod: torch.nn.Module = mod.get_submodule(module_path) - except AttributeError: - warnings.warn(f"Failed to fetch module {module_path}!") - return False - - if not hasattr(submod, name): - return False - - res = getattr(submod, name) - - if (not isinstance(res, torch.nn.Module) - and not isinstance(res, torch.nn.Parameter) - and name not in submod._buffers): - return False - - return True - - if (self.owning_module and - not _get_attr_reference_exists(self.owning_module, qualified_name)): - warnings.warn("Attempted to insert a get_attr Node with no " - "underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule, " - "GraphModule.add_parameter to add the " - "necessary Parameter, or " - "nn.Module.register_buffer to add the " - "necessary buffer", stacklevel=2) - return self.create_node('get_attr', qualified_name, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def call_module(self, - module_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node - represents a call to the forward() function of a ``Module`` in the ``Module`` - hierarchy. - - Args: - - module_name (str): The qualified name of the ``Module`` in the ``Module`` - hierarchy to be called. For example, if the traced ``Module`` has a - submodule named ``foo``, which has a submodule named ``bar``, the - qualified name ``foo.bar`` should be passed as ``module_name`` to - call that module. - - args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed - to the called method. Note that this should *not* include a ``self`` argument. - - kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed - to the called method - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly-created and inserted ``call_module`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as :meth:`Graph.create_node`. - """ - if (self.owning_module and - self.owning_module.get_submodule(module_name) is None): - warnings.warn("Attempted to insert a call_module Node with " - "no underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule") - return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def call_method(self, - method_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node - represents a call to a given method on the 0th element of ``args``. - - Args: - - method_name (str): The name of the method to apply to the self argument. - For example, if args[0] is a ``Node`` representing a ``Tensor``, - then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. - - args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed - to the called method. Note that this *should* include a ``self`` argument. - - kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed - to the called method - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly created and inserted ``call_method`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as :meth:`Graph.create_node`. - """ - return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def call_function(self, - the_function: Callable[..., Any], - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: - """ - Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node - represents a call to a Python callable, specified by ``the_function``. - - Args: - - the_function (Callable[..., Any]): The function to be called. Can be any PyTorch - operator, Python function, or member of the ``builtins`` or ``operator`` - namespaces. - - args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed - to the called function. - - kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed - to the called function - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - Returns: - - The newly created and inserted ``call_function`` node. - - .. note:: - The same insertion point and type expression rules apply for this method - as :meth:`Graph.create_node`. - """ - return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) - - @compatibility(is_backward_compatible=True) - def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: - """ - Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from - the graph of node to the graph of self. Example:: - - # Copying all the nodes in `g` into `new_graph` - g : pippy.fx.Graph = ... - new_graph = pippy.fx.graph() - value_remap = {} - for node in g.nodes: - value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) - - Args: - - node (Node): The node to copy into ``self``. - - arg_transform (Callable[[Node], Argument]): A function that transforms - ``Node`` arguments in node's ``args`` and ``kwargs`` into the - equivalent argument in ``self``. In the simplest case, this should - retrieve a value out of a table mapping Nodes in the original - graph to ``self``. - """ - args = map_arg(node.args, arg_transform) - kwargs = map_arg(node.kwargs, arg_transform) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) - result_node.meta = copy.copy(node.meta) - return result_node - - @compatibility(is_backward_compatible=True) - def output(self, result: 'Argument', type_expr: Optional[Any] = None): - """ - Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents - a ``return`` statement in Python code. ``result`` is the value that should - be returned. - - Args: - - result (Argument): The value to be returned. - - type_expr (Optional[Any]): an optional type annotation representing the - Python type the output of this node will have. - - .. note:: - - The same insertion point and type expression rules apply for this method - as ``Graph.create_node``. - """ - return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) - - def _target_to_str(self, target : Target) -> str: - if callable(target): - op = target.__name__ - else: - assert isinstance(target, str) - op = target - if _is_magic(op): - op = op[2:-2] - op = _snake_case(op) - return op - - @compatibility(is_backward_compatible=True) - def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode: - """ - Turn this ``Graph`` into valid Python code. - - Args: - - root_module (str): The name of the root module on which to look-up - qualified name targets. This is usually 'self'. - - Returns: - - A PythonCode object, consisting of two fields: - src: the Python source code representing the object - globals: a dictionary of global names in `src` -> the objects that they reference. - """ - # NOTE: [Graph Namespaces] - # - # There are two types of symbols in generated Python source code: - # locals and globals. - # Locals are locally defined by the output of a node in the Graph. - # Globals are references to external objects, like functions or types. - # - # When generating Python code, we need to make sure to name things - # appropriately. In particular: - # - All names should be unique, to avoid weird shadowing bugs. - # - These names need to be consistent, e.g. a object should always be - # referenced by the same name. - # - # To do this, we create a new namespace just for this source. All names - # that get printed must come from this namespace. - # - # Why can't we re-use node.name? Because it was generated within the - # namespace `self._graph_namespace`. In order to provide uniqueness - # over both locals (node.name) *and* globals, we create a completely - # new namespace to put all identifiers in. - namespace = _Namespace() - - # Override Node's repr to generate a valid name within our namespace. - # Since repr() is designed to produce a valid Python expression, it - # makes sense to re-use it. This way, it's easy to print something like - # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is - # implemented cooperatively to allow this. - def node_repr(n: Node): - return namespace.create_name(n.name, n) - - @contextmanager - def override_node_repr(graph: Graph): - orig_repr_fns = {} - for node in graph.nodes: - orig_repr_fns[node] = node._repr_fn - node._repr_fn = node_repr - try: - yield None - finally: - # restore the original repr functions - for node in graph.nodes: - node._repr_fn = orig_repr_fns[node] - - with override_node_repr(self): - return self._python_code(root_module, namespace, verbose=verbose) - - def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode: - return self._codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose) - - - def __str__(self) -> str: - """ - Return a human-readable (not machine-readable) string representation - of this Graph - """ - placeholder_names : List[str] = [] - # This is a one-element array just so ``format_node`` can modify the closed - # over value - maybe_return_typename : List[str] = [''] - - node_strs = [node.format_node(placeholder_names) for node in self.nodes] - param_str = ', '.join(placeholder_names) - s = f'graph({param_str}){maybe_return_typename[0]}:' - for node_str in node_strs: - if node_str: - s += '\n ' + node_str - return s - - @compatibility(is_backward_compatible=True) - def print_tabular(self): - """ - Prints the intermediate representation of the graph in tabular - format. Note that this API requires the ``tabulate`` module to be - installed. - """ - try: - from tabulate import tabulate - except ImportError: - print("`print_tabular` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") - node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] - for n in self.nodes] - print(tabulate(node_specs, - headers=['opcode', 'name', 'target', 'args', 'kwargs'])) - - @compatibility(is_backward_compatible=True) - def lint(self): - """ - Runs various checks on this Graph to make sure it is well-formed. In - particular: - - Checks Nodes have correct ownership (owned by this graph) - - Checks Nodes appear in topological order - - If this Graph has an owning GraphModule, checks that targets - exist in that GraphModule - """ - - # Check topo order - def check_arg(arg : Node, n : Optional[Node] = None) -> None: - context_str = f' of Node \'{n}\' ' if n else ' ' - if arg.graph is not self: - raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' - f'but was used as an argument! If you are copying nodes from another graph, make ' - f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') - if arg not in seen_values: - raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' - f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') - - seen_names : Set[str] = set() - seen_values : Set[Node] = set() - for node in self.nodes: - if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: - raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') - if node.graph is not self: - raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') - map_arg(node.args, lambda arg: check_arg(arg, node)) - map_arg(node.kwargs, lambda arg: check_arg(arg, node)) - seen_values.add(node) - - if node.name in seen_names: - raise RuntimeError(f'Node redefined name {node.name}!') - seen_names.add(node.name) - - # Check targets are legit - if self.owning_module: - for node in self.nodes: - if node.op == 'call_function': - if not callable(node.target): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a Callable is expected') - else: - if not isinstance(node.target, str): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a str is expected') - if node.op in ['get_attr', 'call_module']: - target_atoms = node.target.split('.') - m_itr = self.owning_module - for i, atom in enumerate(target_atoms): - new_m_itr = getattr(m_itr, atom, None) - seen_qualname = '.'.join(target_atoms[:i]) - if new_m_itr is None: - raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' - f'{atom} of {seen_qualname}') - if (node.op == "call_module" - and not isinstance(new_m_itr, torch.nn.Module)): - raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module') - elif (node.op == "get_attr" - and not isinstance(new_m_itr, torch.nn.Module) - and not isinstance(new_m_itr, torch.nn.Parameter) - and atom not in m_itr._buffers): - warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module, nn.Parameter, or buffer, which is ' - 'what \'get_attr\' Nodes typically target') - else: - m_itr = new_m_itr - - @compatibility(is_backward_compatible=True) - def eliminate_dead_code(self): - """ - Remove all dead code from the graph, based on each node's number of - users, and whether the nodes have any side effects. The graph must be - topologically sorted before calling. - - Returns: - bool: Whether the graph was changed as a result of the pass. - - Example: - - Before dead code is eliminated, `a` from `a = x + 1` below has no users - and thus can be eliminated from the graph without having an effect. - - .. code-block:: python - - def forward(self, x): - a = x + 1 - return x + self.attr_1 - - After dead code is eliminated, `a = x + 1` has been removed, and the rest - of `forward` remains. - - .. code-block:: python - - def forward(self, x): - return x + self.attr_1 - - .. warning:: - - Dead code elimination has some heuristics to avoid removing - side-effectful nodes (see Node.is_impure) but in general coverage - is very bad, so you should assume that this method is not sound - to call unless you know that your FX graph consists entirely - of functional operations. - """ - # Lint the graph first to make sure its topologically sorted, otherwise - # DCE below will not behave as expected. - self.lint() - - # Reverse iterate so that when we remove a node, any nodes used as an - # input to that node have an updated user count that no longer reflects - # the removed node. - changed = False - for node in reversed(self.nodes): - if not node.is_impure() and len(node.users) == 0: - self.erase_node(node) - changed = True - - return changed - - @compatibility(is_backward_compatible=False) - def set_codegen(self, codegen: CodeGen): - self._codegen = codegen - - @compatibility(is_backward_compatible=False) - def on_generate_code( - self, - make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] - ): - """Register a transformer function when python code is generated - - Args: - make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): - a function that returns a code transformer to be registered. - This function is called by `on_generate_code` to obtain the - code transformer. - - This function is also given as its input the currently - registered code transformer (or None if nothing is registered), - in case it is not desirable to overwrite it. This is useful to - chain code transformers together. - - Returns: - a context manager that when used in a `with` statement, to automatically - restore the previously registered code transformer. - - Example: - - .. code-block:: python - - - gm: fx.GraphModule = ... - - # This is a code transformer we want to register. This code - # transformer prepends a pdb import and trace statement at the very - # beginning of the generated pippy.fx code to allow for manual - # debugging with the PDB library. - def insert_pdb(body): - return ["import pdb; pdb.set_trace()\\n", *body] - - # Registers `insert_pdb`, and overwrites the current registered - # code transformer (given by `_` to the lambda): - gm.graph.on_generate_code( - lambda _: insert_pdb - ) - - # Or alternatively, registers a code transformer which first - # runs `body` through existing registered transformer, then - # through `insert_pdb`: - gm.graph.on_generate_code( - lambda current_trans: ( - lambda body: insert_pdb( - current_trans(body) if current_trans - else body - ) - ) - ) - - gm.recompile() - gm(*inputs) # drops into pdb - - - This function can also be used as a context manager, with the benefit to - automatically restores the previously registered code transformer: - - .. code-block:: python - - # ... continue from previous example - - with gm.graph.on_generate_code(lambda _: insert_pdb): - # do more stuff with `gm`... - gm.recompile() - gm(*inputs) # drops into pdb - - # now previous code transformer is restored (but `gm`'s code with pdb - # remains - that means you can run `gm` with pdb here too, until you - # run next `recompile()`). - """ - on_gen_code_old = self._codegen._body_transformer - self._codegen._body_transformer = make_transformer(on_gen_code_old) - - @contextlib.contextmanager - def on_generate_code_context_manager(): - try: - yield - finally: - self._codegen._body_transformer = on_gen_code_old - - return on_generate_code_context_manager() - - -reflectable_magic_methods = { - 'add': '{} + {}', - 'sub': '{} - {}', - 'mul': '{} * {}', - 'floordiv': '{} // {}', - 'truediv': '{} / {}', - 'div': '{} / {}', - 'mod': '{} % {}', - 'pow': '{} ** {}', - 'lshift': '{} << {}', - 'rshift': '{} >> {}', - 'and_': '{} & {}', - 'or_': '{} | {}', - 'xor': '{} ^ {}', - 'getitem': '{}[{}]', - 'matmul': '{} @ {}', -} - -magic_methods = dict({ - 'eq': '{} == {}', - 'ne': '{} != {}', - 'lt': '{} < {}', - 'gt': '{} > {}', - 'le': '{} <= {}', - 'ge': '{} >= {}', - 'pos': '+{}', - 'neg': '-{}', - 'invert': '~{}'}, **reflectable_magic_methods) - -inplace_methods = { - 'iadd': '{} += {}', - 'iand': '{} &= {}', - 'ifloordiv': '{} //= {}', - 'ilshift': '{} <<= {}', - 'imod': '{} %= {}', - 'imul': '{} *= {}', - 'imatmul': '{} @= {}', - 'ior': '{} |= {}', - 'ipow': '{} **= {}', - 'irshift': '{} >>= {}', - 'isub': '{} -= {}', - 'itruediv': '{} /= {}', - 'ixor': '{} ^= {}', - 'setitem': '{}[{}] = {}', -} diff --git a/pippy/fx/graph_module.py b/pippy/fx/graph_module.py deleted file mode 100644 index f8f46d917..000000000 --- a/pippy/fx/graph_module.py +++ /dev/null @@ -1,759 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import copy -import itertools -import linecache -import os -import sys -import traceback -import warnings -from pathlib import Path -from typing import Type, Dict, List, Any, Union, Optional, Set # pylint: disable=unused-import - -import torch -import torch.nn as nn -import torch.overrides -from torch.nn.modules.module import _addindent -from torch.package import Importer, sys_importer -from torch.package import PackageImporter, PackageExporter - -from ._compatibility import compatibility -from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode - - -# Normal exec loses the source code, however we can work with -# the linecache module to recover it. -# Using _exec_with_source will add it to our local cache -# and then tools like TorchScript will be able to get source info. -class _EvalCacheLoader(object): - def __init__(self): - self.eval_cache = {} - self.next_id = 0 - - def cache(self, src: str, globals: Dict[str, Any]): - """Store the source in a private cache, and add a lazy entry in linecache - that allows the source to be retrieved by 'filename'. - - Args: - src (str): The module source to cache - globals (dict): The module globals - - Returns: - str: The cache key (and dummy filename) generated for src. - """ - - key = self._get_key() - self.eval_cache[key] = src - - # Don't mutate globals so that this loader is only used - # to populate linecache, and doesn't interact with other modules - # that might check `__loader__` - globals_copy = globals.copy() - globals_copy['__file__'] = key - globals_copy['__name__'] = key - globals_copy['__loader__'] = self - linecache.lazycache(key, globals_copy) - - return key - - # Part of the loader protocol (PEP 302) - # linecache will use this method when trying to find source code - def get_source(self, module_name) -> Optional[str]: - if module_name in self.eval_cache: - return self.eval_cache[module_name] - return None - - def _get_key(self): - key = f'.{self.next_id}' - self.next_id += 1 - return key - -_loader = _EvalCacheLoader() - - -def _exec_with_source(src: str, globals: Dict[str, Any]): - key = _loader.cache(src, globals) - exec(compile(src, key, 'exec'), globals) - - -def _forward_from_src(src: str, globals: Dict[str, Any]): - # avoid mutating the passed in dict - globals_copy = globals.copy() - _exec_with_source(src, globals_copy) - forward_fn = globals_copy['forward'] - del globals_copy['forward'] - return forward_fn - - -def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: - if name in _custom_builtins: - return _custom_builtins[name].import_str - if _is_from_torch(name): - return 'import torch' - module_name, attr_name = importer.get_name(obj) - return f'from {module_name} import {attr_name} as {name}' - - -def _format_import_block(globals: Dict[str, Any], importer: Importer): - import_strs: Set[str] = set() - for name, obj in globals.items(): - import_strs.add(_format_import_statement(name, obj, importer)) - return '\n'.join(import_strs) - - -@compatibility(is_backward_compatible=True) -def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: - # BC: attribute name was changed from `code` to `_code` to facilitate - # making `code` into a property and adding a docstring to it - fn_src = body.get('_code') or body['code'] - forward = _forward_from_src(import_block + fn_src, {}) - return _deserialize_graph_module(forward, body) - - -@compatibility(is_backward_compatible=True) -def reduce_package_graph_module( - importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str -) -> torch.nn.Module: - forward = importer.import_module(generated_module_name).forward - return _deserialize_graph_module(forward, body) - -@compatibility(is_backward_compatible=True) -def reduce_deploy_graph_module( - importer: PackageImporter, body: Dict[Any, Any], import_block: str -) -> torch.nn.Module: - ns = {} - ns["__builtins__"] = importer.patched_builtins - fn_src = body.get('_code') - assert fn_src is not None - forward = _forward_from_src(import_block + fn_src, ns) - return _deserialize_graph_module(forward, body) - - -def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module: - """ - Deserialize a GraphModule given the dictionary of the original module, - using the code to reconstruct the graph. We delete the actual graph before - saving the dictionary so that changes to the in-memory graph format do not - get serialized. - """ - # We create a dummy class here because symbolic_trace pulls the forward() - # function off of the class, rather than the instance - class CodeOnlyModule(torch.nn.Module): - def __init__(self, body): - super().__init__() - self.__dict__ = body - - # Try to retrieve the forward source in a backward-compatible way - CodeOnlyModule.forward = forward - - tracer_cls = body.get('_tracer_cls') - if tracer_cls is None: - from ._symbolic_trace import Tracer - tracer_cls = Tracer - - graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule') - - # This is a workaround for a mypy linter issue related to - # passing base class as an argument - https://github.com/python/mypy/issues/5865. - cls_tracer : Any = tracer_cls - - class KeepModules(cls_tracer): - # we shouldn't trace into any of the submodules, - # because they were not traced in the original GraphModule - def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: - return True - - com = CodeOnlyModule(body) - - tracer_extras = body.get('_tracer_extras', {}) - graph = KeepModules().trace(com, **tracer_extras) - - # Manually set Tracer class on the reconstructed Graph, to avoid - # referencing the private local subclass KeepModules. - graph._tracer_cls = tracer_cls - gm = GraphModule(com, graph, class_name=graphmodule_cls_name) - - # The GraphModule constructor only retains attributes referenced by the graph. - # In this case, our goal is return a GraphModule as close to identical as the one - # put into the package. If any additional attributes were present in body, - # we should keep them. - for k, v in body.items(): - if not hasattr(gm, k): - setattr(gm, k, v) - return gm - -# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' -# This installs empty Modules where none exist yet if they are subpaths of target -def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): - *prefix, field = target.split('.') - for item in prefix: - f = getattr(from_module, item) - t = getattr(to_module, item, None) - if f is t: - # we have already installed one of its parents - # (e.g. target = root.linear.weight, but we have already installed root.linear) - # once we install a parent, we no longer need to copy the children - # since all the needed properties will already be present - return - - if t is None: - t = torch.nn.Module() - setattr(to_module, item, t) - from_module, to_module = f, t - - orig = getattr(from_module, field) - # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. - # So, we register it as a named buffer in the target module. - if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): - to_module.register_buffer(field, orig) - else: - setattr(to_module, field, orig) - -# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module -# This installs empty Modules where none exist yet if they are subpaths of target -def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): - *prefix, field = target.split('.') - for item in prefix: - t = getattr(to_module, item, None) - - if t is None: - t = torch.nn.Module() - setattr(to_module, item, t) - to_module = t - - # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. - # So, we register it as a named buffer in the target module. - if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter): - to_module.register_buffer(field, from_obj) - else: - setattr(to_module, field, from_obj) - -class _WrappedCall: - def __init__(self, cls, cls_call): - self.cls = cls - self.cls_call = cls_call - - # Previously, if an error occurred when valid - # symbolically-traced code was run with an invalid input, the - # user would see the source of the error as coming from - # `File "`, where N is some number. We use - # this function to generate a more informative error message. We - # return the traceback itself, a message explaining that the - # error occurred in a traced Module's generated forward - # function, and five lines of context surrounding the faulty - # line - @staticmethod - def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: - # auxiliary variables (for readability) - err_lineno = frame_summary.lineno - assert err_lineno is not None - line = frame_summary.line - assert line is not None - err_line_len = len(line) - all_src_lines = linecache.getlines(frame_summary.filename) - - # constituent substrings of the error message - tb_repr = traceback.format_exc() - custom_msg = ("Call using an FX-traced Module, " - f"line {err_lineno} of the traced Module's " - "generated forward function:") - before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) - marker = "~" * err_line_len + "~~~ <--- HERE" - err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) - - # joined message - return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) - - def __call__(self, obj, *args, **kwargs): - try: - if self.cls_call is not None: - return self.cls_call(obj, *args, **kwargs) - else: - return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] - except Exception as e: - assert e.__traceback__ - topmost_framesummary: traceback.FrameSummary = \ - traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] - if "eval_with_key" in topmost_framesummary.filename: - print(_WrappedCall._generate_error_message(topmost_framesummary), - file=sys.stderr) - raise e.with_traceback(None) - else: - raise e - -@compatibility(is_backward_compatible=True) -class GraphModule(torch.nn.Module): - """ - GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a - ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated - from that ``graph``. - - .. warning:: - - When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically - regenerated. However, if you edit the contents of the ``graph`` without reassigning - the ``graph`` attribute itself, you must call ``recompile()`` to update the generated - code. - """ - def __new__(cls: 'Type[GraphModule]', *args, **kwargs): - # each instance of a graph module needs its own forward method - # so create a new singleton class for each instance. - # it is a subclass of the user-defined class, the only difference - # is an extra layer to install the forward method - - # address issue described at https://github.com/pytorch/pytorch/issues/63883 - # in other words, traverse class hierarchy to fix the redundant class definition problem - for t in cls.__mro__: - c = t.__qualname__.split('.')[-1] - if c != 'GraphModuleImpl': - cls = t - break - - class GraphModuleImpl(cls): # type: ignore[misc, valid-type] - pass - return super().__new__(GraphModuleImpl) - - @compatibility(is_backward_compatible=True) - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: Graph, - class_name: str = 'GraphModule'): - """ - Construct a GraphModule. - - Args: - - root (Union[torch.nn.Module, Dict[str, Any]): - ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. - In the case that ``root`` is a Module, any references to Module-based objects (via qualified - name) in the Graph's Nodes' ``target`` field will be copied over from the respective place - within ``root``'s Module hierarchy into the GraphModule's module hierarchy. - In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be - looked up directly in the dict's keys. The object mapped to by the Dict will be copied - over into the appropriate place within the GraphModule's module hierarchy. - - graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation - - class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all - error messages will report as originating from ``GraphModule``. It may be helpful to set this - to ``root``'s original name or a name that makes sense within the context of your transform. - """ - super().__init__() - self.__class__.__name__ = class_name - if isinstance(root, torch.nn.Module): - if hasattr(root, 'training'): - self.training = root.training - for node in graph.nodes: - if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) - _copy_attr(root, self, node.target) - elif isinstance(root, dict): - targets_to_copy = [] - for node in graph.nodes: - if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) - if node.target not in root: - raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target + - ' but that target was not provided in ``root``!') - targets_to_copy.append(node.target) - # Sort targets in ascending order of the # of atoms. - # This will ensure that less deeply nested attributes are assigned - # before more deeply nested attributes. For example, foo.bar - # will be assigned before foo.bar.baz. Otherwise, we might assign - # the user-provided ``foo.bar`` and wipe out the previously-assigned - # ``foo.bar.baz`` - targets_to_copy.sort(key=lambda t: t.count('.')) - for target_to_copy in targets_to_copy: - _assign_attr(root[target_to_copy], self, target_to_copy) - else: - raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!') - - self.graph = graph - - # Store the Tracer class responsible for creating a Graph separately as part of the - # GraphModule state, except when the Tracer is defined in a local namespace. - # Locally defined Tracers are not pickleable. This is needed because torch.package will - # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer - # to re-create the Graph during deserialization. - self._tracer_cls = None - if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: - self._tracer_cls = self.graph._tracer_cls - - self._tracer_extras = {} - if self.graph._tracer_extras: - self._tracer_extras = self.graph._tracer_extras - - # Dictionary to store metadata - self.meta : Dict[str, Any] = {} - - # TorchScript breaks trying to compile the graph setter because of the - # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 - # - # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway - __jit_unused_properties__ = ['graph'] - - @property - def graph(self) -> Graph: - """ - Return the ``Graph`` underlying this ``GraphModule`` - """ - return self._graph - - @graph.setter - def graph(self, g : Graph) -> None: - """ - Set the underlying ``Graph`` for this ``GraphModule``. This will internally - recompile the ``GraphModule`` so that the generated ``forward()`` function - corresponds to ``g`` - """ - assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}' - self._graph = g - g.owning_module = self - self.recompile() - - @compatibility(is_backward_compatible=False) - def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"): - """Dumps out module to ``folder`` with ``module_name`` so that it can be - imported with ``from import `` - - Args: - - folder (Union[str, os.PathLike]): The folder to write the code out to - - module_name (str): Top-level name to use for the ``Module`` while - writing out the code - """ - folder = Path(folder) - Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') - tab = " " * 4 - custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()]) - model_str = f""" -import torch -{custom_builtins} - -from torch.nn import * -class {module_name}(torch.nn.Module): - def __init__(self): - super().__init__() -""" - - def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: - safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] - if type(module) in safe_reprs: - return f"{module.__repr__()}" - else: - return None - - blobified_modules = [] - for module_name, module in self.named_children(): - module_str = _gen_model_repr(module_name, module) - if module_str is None: - module_file = folder / f'{module_name}.pt' - torch.save(module, module_file) - blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') - module_str = f"torch.load(r'{module_file}') # {module_repr}" - model_str += f"{tab*2}self.{module_name} = {module_str}\n" - - for buffer_name, buffer in self._buffers.items(): - if buffer is None: - continue - model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" - - for param_name, param in self._parameters.items(): - if param is None: - continue - model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" - - model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" - model_str += f"{_addindent(self.code, 4)}\n" - - module_file = folder / 'module.py' - module_file.write_text(model_str) - - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') - - if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") - - @compatibility(is_backward_compatible=True) - def add_submodule(self, target: str, m: torch.nn.Module) -> bool: - """ - Adds the given submodule to ``self``. - - This installs empty Modules where none exist yet if they are - subpaths of ``target``. - - Args: - target: The fully-qualified string name of the new submodule - (See example in ``nn.Module.get_submodule`` for how to - specify a fully-qualified string.) - m: The submodule itself; the actual object we want to - install in the current Module - - Return: - bool: Whether or not the submodule could be inserted. For - this method to return True, each object in the chain - denoted by ``target`` must either a) not exist yet, - or b) reference an ``nn.Module`` (not a parameter or - other attribute) - """ - *prefix, field = target.split('.') - mod: torch.nn.Module = self - - for item in prefix: - - submod = getattr(mod, item, None) - - if submod is None: - submod = torch.nn.Module() - setattr(mod, item, submod) - - if not isinstance(submod, torch.nn.Module): - return False - - mod = submod - - mod.add_module(field, m) - return True - - @compatibility(is_backward_compatible=True) - def delete_submodule(self, target: str) -> bool: - """ - Deletes the given submodule from ``self``. - - The module will not be deleted if ``target`` is not a valid - target. - - Args: - target: The fully-qualified string name of the new submodule - (See example in ``nn.Module.get_submodule`` for how to - specify a fully-qualified string.) - - Returns: - bool: Whether or not the target string referenced a - submodule we want to delete. A return value of ``False`` - means that the ``target`` was not a valid reference to - a submodule. - """ - atoms = target.split(".") - path, target_submod = atoms[:-1], atoms[-1] - mod: torch.nn.Module = self - - # Get the parent module - for item in path: - - if not hasattr(mod, item): - return False - - mod = getattr(mod, item) - - if not isinstance(mod, torch.nn.Module): - return False - - if not hasattr(mod, target_submod): - return False - - if not isinstance(getattr(mod, target_submod), torch.nn.Module): - return False - - delattr(mod, target_submod) - return True - - @compatibility(is_backward_compatible=True) - def delete_all_unused_submodules(self) -> None: - """ - Deletes all unused submodules from ``self``. - - A Module is considered "used" if any one of the following is - true: - 1. It has children that are used - 2. Its forward is called directly via a ``call_module`` node - 3. It has a non-Module attribute that is used from a - ``get_attr`` node - - This method can be called to clean up an ``nn.Module`` without - manually calling ``delete_submodule`` on each unused submodule. - """ - used: List[str] = [] - - for node in self.graph.nodes: - - if node.op == "call_module" or node.op == "get_attr": - - # A list of strings representing the different parts - # of the path. For exmaple, `foo.bar.baz` gives us - # ["foo", "bar", "baz"] - fullpath = node.target.split(".") - - # If we're looking at multiple parts of a path, join - # join them with a dot. Otherwise, return that single - # element without doing anything to it. - def join_fn(x: str, y: str) -> str: - return '.'.join([x, y] if y else [x]) - - # Progressively collect all the names of intermediate - # modules. For example, if we have the target - # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and - # `foo.bar.baz` to the list. - for path in itertools.accumulate(fullpath, join_fn): - used.append(path) - - # For a `call_module` node, also register all recursive submodules - # as used - if node.op == "call_module": - try: - submod = self.get_submodule(node.target) - - for submod_name, _ in submod.named_modules(): - if submod_name != '': - used.append('.'.join([node.target, submod_name])) - except AttributeError: - # Node referenced nonexistent submodule, don't need to - # worry about GCing anything - pass - - to_delete = [name for name, _ in self.named_modules() - if name not in used] - - for name in to_delete: - self.delete_submodule(name) - - @property - def code(self) -> str: - """ - Return the Python code generated from the ``Graph`` underlying this - ``GraphModule``. - """ - if not hasattr(self, '_code'): - raise RuntimeError('Code has not been generated! Please report a bug to PyTorch') - return self._code - - @compatibility(is_backward_compatible=True) - def recompile(self) -> PythonCode: - """ - Recompile this GraphModule from its ``graph`` attribute. This should be - called after editing the contained ``graph``, otherwise the generated - code of this ``GraphModule`` will be out of date. - """ - if isinstance(self._graph._codegen, _PyTreeCodeGen): - self._in_spec = self._graph._codegen.pytree_info.in_spec - self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') - self._code = python_code.src - - cls = type(self) - cls.forward = _forward_from_src(self._code, python_code.globals) - - # Determine whether this class explicitly defines a __call__ implementation - # to wrap. If it does, save it in order to have wrapped_call invoke it. - # If it does not, wrapped_call can use a dynamic call to super() instead. - # In most cases, super().__call__ should be torch.nn.Module.__call__. - # We do not want to hold a reference to Module.__call__ here; doing so will - # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. - cls_call = cls.__call__ if "__call__" in vars(cls) else None - - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] - - def call_wrapped(self, *args, **kwargs): - return self._wrapped_call(self, *args, **kwargs) - - cls.__call__ = call_wrapped - - return python_code - - # Passing Tracer as argument allows subclasses extending fx.GraphModule - # define their own Tracer (extending fx.Tracer). - def __reduce_deploy__(self, importer: Importer): - dict_without_graph = self.__dict__.copy() - dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__ - del dict_without_graph['_graph'] - - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, importer) - return (reduce_deploy_graph_module, (dict_without_graph, import_block)) - - def __reduce_package__(self, exporter: PackageExporter): - dict_without_graph = self.__dict__.copy() - dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__ - del dict_without_graph['_graph'] - - generated_module_name = f'fx-generated._{exporter.get_unique_id()}' - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, exporter.importer) - module_code = import_block + self.code - exporter.save_source_string(generated_module_name, module_code) - return (reduce_package_graph_module, (dict_without_graph, generated_module_name)) - - def __reduce__(self): - """ - Serialization of GraphModule. We serialize only the generated code, not - the underlying ``Graph``. This is because ``Graph`` does not have on-disk - backward-compatibility guarantees, whereas Python source code does. - On the deserialization side, we symbolically trace through the generated - code to regenerate the underlying ``Graph`` - """ - dict_without_graph = self.__dict__.copy() - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, sys_importer) - del dict_without_graph['_graph'] - return (reduce_graph_module, (dict_without_graph, import_block)) - - # because __reduce__ is defined for serialization, - # we need to define deepcopy otherwise it will call __reduce__ - # and cause symbolic tracing to occur every time we try to copy the object - def __deepcopy__(self, memo): - fake_mod = torch.nn.Module() - fake_mod.__dict__ = copy.deepcopy(self.__dict__) - return GraphModule(fake_mod, fake_mod.__dict__['_graph']) - - def __copy__(self): - return GraphModule(self, self.graph) - - @compatibility(is_backward_compatible=False) - def print_readable(self): - """ - Return the Python code generated for current GraphModule and its children GraphModules - """ - verbose_python_code = self._graph.python_code(root_module='self', verbose=True) - module_code = verbose_python_code.src - module_code = module_code.lstrip('\n') - module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code - module_code = _addindent(module_code, 4) - - submodule_code_list = [""] - for submodule in self.children(): - if isinstance(submodule, GraphModule): - submodule_code_list.append(submodule.__nested_code()) - submodule_code = "\n".join(submodule_code_list) - submodule_code = _addindent(submodule_code, 4) - - print(module_code + submodule_code) - - def __str__(self) -> str: - orig_str = super().__str__() - print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`" - return '\n'.join([orig_str, self._code, print_readable_reminder]) - - def _replicate_for_data_parallel(self): - new_gm = self.__copy__() - new_gm._is_replica = True - return new_gm - -# workarounds for issues in __torch_function__ - -# WAR for __torch_function__ not handling tensor lists, -# fix is in https://github.com/pytorch/pytorch/pull/34725 -# orig_cat = torch.cat -# def patched_cat(*args, **kwargs): -# tensors = args[0] -# for t in tensors: -# if isinstance(t, Proxy): -# return t.__torch_function__(patched_cat, (), args, kwargs) -# return orig_cat(*args, **kwargs) -# patched_cat.__module__ = 'torch' -# patched_cat.__name__ = 'cat' -# torch.cat = patched_cat diff --git a/pippy/fx/immutable_collections.py b/pippy/fx/immutable_collections.py deleted file mode 100644 index de884c205..000000000 --- a/pippy/fx/immutable_collections.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Any, Dict, Tuple, List - -from ._compatibility import compatibility -from torch.utils._pytree import Context, _register_pytree_node - -__all__ = ["immutable_list", "immutable_dict"] - -_help_mutation = """\ -If you are attempting to modify the kwargs or args of a pippy.fx.Node object, -instead create a new copy of it and assign the copy to the node: - new_args = ... # copy and mutate args - node.args = new_args -""" - -def _no_mutation(self, *args, **kwargs): - raise NotImplementedError(f"'{type(self).__name__}' object does not support mutation. {_help_mutation}") - -def _create_immutable_container(base, mutable_functions): - container = type('immutable_' + base.__name__, (base,), {}) - for attr in mutable_functions: - setattr(container, attr, _no_mutation) - return container - -immutable_list = _create_immutable_container(list, - ['__delitem__', '__iadd__', '__imul__', '__setitem__', 'append', - 'clear', 'extend', 'insert', 'pop', 'remove']) -immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) - -compatibility(is_backward_compatible=True)(immutable_list) - -immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update']) -immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) -compatibility(is_backward_compatible=True)(immutable_dict) - - -# Register immutable collections for PyTree operations - -def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - -def _immutable_dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return immutable_dict({key: value for key, value in zip(context, values)}) - -def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: - return d, None - -def _immutable_list_unflatten(values: List[Any], context: Context) -> List[Any]: - return immutable_list(values) - - -_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten) -_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten) diff --git a/pippy/fx/interpreter.py b/pippy/fx/interpreter.py deleted file mode 100644 index 7a7cf7fc0..000000000 --- a/pippy/fx/interpreter.py +++ /dev/null @@ -1,481 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .graph_module import GraphModule -from .graph import Graph -from .node import Argument, Node, Target, map_arg, map_aggregate # pylint: disable=unused-import -from .proxy import Proxy -from ._symbolic_trace import Tracer -from ._compatibility import compatibility -import pippy.fx.traceback as fx_traceback -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union -import inspect -from contextlib import contextmanager - -__all__ = ['Interpreter', 'Transformer'] - -@compatibility(is_backward_compatible=True) -class Interpreter: - """ - An Interpreter executes an FX graph Node-by-Node. This pattern - can be useful for many things, including writing code - transformations as well as analysis passes. - - Methods in the Interpreter class can be overridden to customize - the behavior of execution. The map of overrideable methods - in terms of call hierarchy:: - - run() - +-- run_node - +-- placeholder() - +-- get_attr() - +-- call_function() - +-- call_method() - +-- call_module() - +-- output() - - Example: - - Suppose we want to swap all instances of ``torch.neg`` with - ``torch.sigmoid`` and vice versa (including their ``Tensor`` - method equivalents). We could subclass Interpreter like so:: - - class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: - if target == 'neg': - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - input = torch.randn(3, 4) - result = NegSigmSwapInterpreter(gm).run(input) - torch.testing.assert_allclose(result, torch.neg(input).sigmoid()) - - Args: - module (GraphModule): The module to be executed - garbage_collect_values (bool): Whether to delete values after their last - use within the Module's execution. This ensures optimal memory usage during - execution. This can be disabled to, for example, examine all of the intermediate - values in the execution by looking at the ``Interpreter.env`` attribute. - """ - @compatibility(is_backward_compatible=True) - def __init__(self, module : GraphModule, garbage_collect_values : bool = True): - assert isinstance(module, GraphModule) - self.module = module - self.submodules = dict(self.module.named_modules()) - self.env : Dict[Node, Any] = {} - - self.garbage_collect_values = garbage_collect_values - - if self.garbage_collect_values: - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use : Dict[Node, Node] = {} - self.user_to_last_uses : Dict[Node, List[Node]] = {} - - def register_last_uses(n : Node, user : Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - self.user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(self.module.graph.nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - @compatibility(is_backward_compatible=True) - def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: - """ - Run `module` via interpretation and return the result. - - Args: - *args: The arguments to the Module to run, in positional order - initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. - This is a dict mapping `Node` to any value. This can be used, for example, to - pre-populate results for certain `Nodes` so as to do only partial evaluation within - the interpreter. - enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and - process_outputs function first before using them. - - Returns: - Any: The value returned from executing the Module - """ - self.env = initial_env if initial_env else {} - - # Positional function args are consumed left-to-right by - # `placeholder` nodes. Use an iterator to keep track of - # position and extract those values. - if enable_io_processing: - args = self.module.graph.process_inputs(*args) - self.args_iter : Iterator[Any] = iter(args) - - for node in self.module.graph.nodes: - if node in self.env: - # Short circuit if we have this value. This could - # be used, for example, for partial evaluation - # where the caller has pre-populated `env` with - # values for a subset of the program. - continue - - try: - self.env[node] = self.run_node(node) - except Exception as e: - msg = f"While executing {node.format_node()}" - msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg) - msg += f"\nOriginal traceback:\n{node.stack_trace}" - e.args = (msg,) + e.args[1:] - if isinstance(e, KeyError): - raise RuntimeError(*e.args) - raise - - if self.garbage_collect_values: - for to_delete in self.user_to_last_uses.get(node, []): - del self.env[to_delete] - - if node.op == 'output': - output_val = self.env[node] - return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val - - @contextmanager - def _set_current_node(self, node): - with fx_traceback.append_stack_trace(node.stack_trace): - yield - - @compatibility(is_backward_compatible=True) - def run_node(self, n : Node) -> Any: - """ - Run a specific node ``n`` and return the result. - Calls into placeholder, get_attr, call_function, - call_method, call_module, or output depending - on ``node.op`` - - Args: - n (Node): The Node to execute - - Returns: - Any: The result of executing ``n`` - """ - with fx_traceback.append_stack_trace(n.stack_trace): - args, kwargs = self.fetch_args_kwargs_from_env(n) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - return getattr(self, n.op)(n.target, args, kwargs) - - # Main Node running APIs - @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``placeholder`` node. Note that this is stateful: - ``Interpreter`` maintains an internal iterator over - arguments passed to ``run`` and this method returns - next() on that iterator. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Returns: - Any: The argument value that was retrieved. - """ - assert isinstance(target, str) - if target.startswith('*'): - # For a starred parameter e.g. `*args`, retrieve all - # remaining values from the args list. - return list(self.args_iter) - else: - try: - return next(self.args_iter) - except StopIteration as si: - if len(args) > 0: - return args[0] - else: - raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') - - @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``get_attr`` node. Will retrieve an attribute - value from the ``Module`` hierarchy of ``self.module``. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return: - Any: The value of the attribute that was retrieved - """ - assert isinstance(target, str) - return self.fetch_attr(target) - - @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``call_function`` node and return the result. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return - Any: The value returned by the function invocation - """ - assert not isinstance(target, str) - - # Execute the function and return the result - return target(*args, **kwargs) - - @compatibility(is_backward_compatible=True) - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``call_method`` node and return the result. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return - Any: The value returned by the method invocation - """ - # args[0] is the `self` object for this method call - self_obj, *args_tail = args - - # Execute the method and return the result - assert isinstance(target, str) - return getattr(self_obj, target)(*args_tail, **kwargs) - - @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute a ``call_module`` node and return the result. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return - Any: The value returned by the module invocation - """ - # Retrieve executed args and kwargs values from the environment - - # Execute the method and return the result - assert isinstance(target, str) - submod = self.fetch_attr(target) - - return submod(*args, **kwargs) - - @compatibility(is_backward_compatible=True) - def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - """ - Execute an ``output`` node. This really just retrieves - the value referenced by the ``output`` node and returns it. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - - Return: - Any: The return value referenced by the output node - """ - return args[0] - - # Helper methods - @compatibility(is_backward_compatible=True) - def fetch_attr(self, target : str): - """ - Fetch an attribute from the ``Module`` hierarchy of ``self.module``. - - Args: - target (str): The fully-qualfiied name of the attribute to fetch - - Return: - Any: The value of the attribute. - """ - target_atoms = target.split('.') - attr_itr = self.module - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") - attr_itr = getattr(attr_itr, atom) - return attr_itr - - @compatibility(is_backward_compatible=True) - def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: - """ - Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` - from the current execution environment. - - Args: - n (Node): The node for which ``args`` and ``kwargs`` should be fetched. - - Return: - Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. - """ - args = self.map_nodes_to_values(n.args, n) - assert isinstance(args, tuple) - kwargs = self.map_nodes_to_values(n.kwargs, n) - assert isinstance(kwargs, dict) - return args, kwargs - - @compatibility(is_backward_compatible=True) - def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: - """ - Recursively descend through ``args`` and look up the concrete value - for each ``Node`` in the current execution environment. - - Args: - args (Argument): Data structure within which to look up concrete values - - n (Node): Node to which ``args`` belongs. This is only used for error reporting. - """ - def load_arg(n_arg : Node) -> Any: - if n_arg not in self.env: - raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' - f'to diagnose such issues') - return self.env[n_arg] - return map_arg(args, load_arg) - -@compatibility(is_backward_compatible=True) -class Transformer(Interpreter): - """ - ``Transformer`` is a special type of interpreter that produces a - new ``Module``. It exposes a ``transform()`` method that returns - the transformed ``Module``. ``Transformer`` does not require - arguments to run, as ``Interpreter`` does. ``Transformer`` works - entirely symbolically. - - Example: - - Suppose we want to swap all instances of ``torch.neg`` with - ``torch.sigmoid`` and vice versa (including their ``Tensor`` - method equivalents). We could subclass ``Transformer`` like so:: - - class NegSigmSwapXformer(Transformer): - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if target == 'neg': - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - - transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() - input = torch.randn(3, 4) - torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid()) - - Args: - module (GraphModule): The ``Module`` to be transformed. - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, module): - super().__init__(module) - self.new_graph = Graph() - self.new_graph.set_codegen(module.graph._codegen) - - class TransformerTracer(Tracer): - def __init__(self, graph: Graph): - super().__init__() - self.graph = graph - - def is_leaf_module(self, _, __) -> bool: - return True - - self.tracer = TransformerTracer(self.new_graph) - self.tracer.root = module - - @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: - """ - Execute a ``placeholder`` node. In ``Transformer``, this is - overridden to insert a new ``placeholder`` into the output - graph. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - """ - assert isinstance(target, str) - default_value = next(iter(args)) if args else inspect.Signature.empty - return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) - - @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: - """ - Execute a ``get_attr`` node. In ``Transformer``, this is - overridden to insert a new ``get_attr`` node into the output - graph. - - Args: - target (Target): The call target for this node. See - `Node `__ for - details on semantics - args (Tuple): Tuple of positional args for this invocation - kwargs (Dict): Dict of keyword arguments for this invocation - """ - assert isinstance(target, str) - return Proxy(self.new_graph.get_attr(target), self.tracer) - - @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - # Override so that the leaf module policy from `self.tracer` is respected. - assert isinstance(target, str) - submod = self.fetch_attr(target) - return self.tracer.call_module(submod, submod.forward, args, kwargs) - - @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - # Override so that functions that were wrapped are still wrapped. - return self.tracer.create_proxy('call_function', target, args, kwargs) - - @compatibility(is_backward_compatible=True) - def transform(self) -> GraphModule: - """ - Transform ``self.module`` and return the transformed - ``GraphModule``. - """ - with fx_traceback.override_stack_trace(): - result = super().run(enable_io_processing=False) - if result is not None: - def strip_proxy(a : Union[Argument, Proxy]) -> Any: - return a.node if isinstance(a, Proxy) else a - self.new_graph.output(map_aggregate(result, strip_proxy)) - return GraphModule(self.module, self.new_graph) diff --git a/pippy/fx/node.py b/pippy/fx/node.py deleted file mode 100644 index 5e4600319..000000000 --- a/pippy/fx/node.py +++ /dev/null @@ -1,627 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Nodes represent a definition of a value in our graph of operators. -from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set -from ._compatibility import compatibility -from .immutable_collections import immutable_dict, immutable_list -import torch -import builtins -import types -import warnings -from pippy.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair - -if TYPE_CHECKING: - from .graph import Graph - -__all__ = ['Node', 'map_arg', 'map_aggregate'] - -BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, - torch.Tensor, torch.device, torch.memory_format, torch.layout] -base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] - -Target = Union[Callable[..., Any], str] - -Argument = Optional[Union[ - Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - 'Node', - BaseArgumentTypes -]] - -_side_effectful_functions: Set[Callable] = { - torch._assert, - torch.ops.profiler._record_function_enter, - torch.ops.profiler._record_function_enter_new, - torch.ops.profiler._record_function_exit} - -# this is fixed on master, WAR for 1.5 -def _find_module_of_method(orig_method: Callable[..., Any]) -> str: - name = orig_method.__name__ - module = orig_method.__module__ - if module is not None: - return module - for guess in [torch, torch.nn.functional]: - if getattr(guess, name, None) is orig_method: - return guess.__name__ - raise RuntimeError(f'cannot find module for {orig_method}') - -# Borrowed from CPython typing module -# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 -def _type_repr(obj): - """Return the repr() of an object, special-casing types (internal helper). - If obj is a type, we return a shorter version than the default - type.__repr__, based on the module and qualified name, which is - typically enough to uniquely identify a type. For everything - else, we fall back on repr(obj). - """ - if isinstance(obj, type): - if obj.__module__ == 'builtins': - return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' - if obj is ...: - return('...') - if isinstance(obj, types.FunctionType): - return obj.__name__ - return repr(obj) - -def _get_qualified_name(func: Callable[..., Any]) -> str: - # things like getattr just appear in builtins - if getattr(builtins, func.__name__, None) is func: - return func.__name__ - name = func.__name__ - module = _find_module_of_method(func) - module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module - return f'{module}.{name}' - -def _format_arg(arg, max_list_len=float('inf')) -> str: - if hasattr(arg, '_custom_fx_repr_fn'): - return arg._custom_fx_repr_fn() - elif isinstance(arg, list): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - return f'[{items}{maybe_len}]' - elif isinstance(arg, tuple): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - maybe_comma = ',' if len(arg) == 1 else '' - return f'({items}{maybe_comma}{maybe_len})' - elif isinstance(arg, dict): - items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) - return f'{{{items_str}}}' - - if isinstance(arg, Node): - return '%' + str(arg) - else: - return str(arg) - -@compatibility(is_backward_compatible=True) -class Node: - """ - ``Node`` is the data structure that represents individual operations within - a ``Graph``. For the most part, Nodes represent callsites to various entities, - such as operators, methods, and Modules (some exceptions include nodes that - specify function inputs and outputs). Each ``Node`` has a function specified - by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: - - - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. - ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument - denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to - the function parameters (e.g. ``x``) in the graph printout. - - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the - fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. - ``args`` and ``kwargs`` are don't-care - - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign - to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, - following the Python calling convention - - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is - as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. - ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. - - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method - to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, - *including the self argument* - - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement - in the Graph printout. - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', - args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], - return_type : Optional[Any] = None) -> None: - """ - Instantiate an instance of ``Node``. Note: most often, you want to use the - Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather - than instantiating a ``Node`` directly. - - Args: - graph (Graph): The ``Graph`` to which this ``Node`` should belong. - - name (str): The name to which the output of this ``Node`` should be assigned - - op (str): The opcode for this ``Node``. Can be one of 'placeholder', - 'call_method', 'call_module', 'call_function', 'get_attr', - 'output' - - target ('Target'): The target this op should call. See the broader - ``Node`` docstring for more details. - - args (Tuple['Argument']): The args to be passed to ``target`` - - kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` - - return_type (Optional[Any]): The python type expression representing the - type of the output of this node. This field can be used for - annotation of values in the generated code or for other types - of analyses. - """ - self.graph = graph - self.name = name # unique name of value being created - assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] - self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr - if op == 'call_function': - if not callable(target): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a Callable is expected') - else: - if not isinstance(target, str): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a str is expected') - self.target = target # for method/module/function, the name of the method/module/function/attr - # being invoked, e.g add, layer1, or torch.add - - # All `Node`-valued inputs. Key is the Node, value is don't-care. - # The public API for this is `all_input_nodes`, this private attribute - # should not be accessed directly. - self._input_nodes : Dict[Node, None] = {} - self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore[arg-type] - - # All of the nodes that use the value produced by this Node - # Note one user may correspond to several uses, e.g. the node fo ``x + x`` - # would appear once here, but represents two uses. - # - # Is a dict to act as an "ordered set". Keys are significant, value dont-care - self.users : Dict['Node', None] = {} - # Type expression representing the output value of this node. - # This should contain the same class of Type objects that would appear - # as type annotations for function inputs/outputs. - # - # For placeholder nodes, this value will be used to type-annotate the - # generated function parameters. - # For the return node, this value will be used to type-annotate the - # generated function return type. (Note this is a special case. ``return`` - # does not produce a value, it's more of a notation. Thus, this value - # describes the type of args[0] in the ``return`` node. - self.type : Optional[Any] = return_type - self._prev = self - self._next = self - self._erased = False - - # If set, use this fn to print this node - self._repr_fn : Optional[Callable[[Node], str]] = None - - # Dictionary to store metadata passes need to do their - # transformations. This metadata is preserved across node copies - self.meta : Dict[str, Any] = {} - - @property - def next(self) -> 'Node': - """ - Returns the next ``Node`` in the linked list of Nodes. - - Returns: - - The next ``Node`` in the linked list of Nodes. - """ - return self._next - - @property - def prev(self) -> 'Node': - """ - Returns the previous ``Node`` in the linked list of Nodes. - - Returns: - - The previous ``Node`` in the linked list of Nodes. - """ - return self._prev - - @compatibility(is_backward_compatible=True) - def prepend(self, x: 'Node') -> None: - """ - Insert x before this node in the list of nodes in the graph. Example:: - - Before: p -> self - bx -> x -> ax - After: p -> x -> self - bx -> ax - - Args: - x (Node): The node to put before this node. Must be a member of the same graph. - """ - assert self.graph == x.graph, "Attempting to move a Node into a different Graph" - if self == x: - warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") - return - x._remove_from_list() - p = self._prev - p._next, x._prev = x, p - x._next, self._prev = self, x - - @compatibility(is_backward_compatible=True) - def append(self, x: 'Node') -> None: - """ - Insert ``x`` after this node in the list of nodes in the graph. - Equivalent to ``self.next.prepend(x)`` - - Args: - x (Node): The node to put after this node. Must be a member of the same graph. - """ - self._next.prepend(x) - - def _remove_from_list(self): - p, n = self._prev, self._next - p._next, n._prev = n, p - - @property - def args(self) -> Tuple[Argument, ...]: - """ - The tuple of arguments to this ``Node``. The interpretation of arguments - depends on the node's opcode. See the :class:`Node` docstring for more - information. - - Assignment to this property is allowed. All accounting of uses and users - is updated automatically on assignment. - """ - return self._args - - @args.setter - def args(self, a : Tuple[Argument, ...]): - """ - Set the tuple of arguments to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more - information. - """ - # DO NOT CALL `__update_args_kwargs` directly. The correct way to - # set `args` is via direct assignment, i.e. `node.args = new_args` - self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore[arg-type] - - @property - def kwargs(self) -> Dict[str, Argument]: - """ - The dict of keyword arguments to this ``Node``. The interpretation of arguments - depends on the node's opcode. See the :class:`Node` docstring for more - information. - - Assignment to this property is allowed. All accounting of uses and users - is updated automatically on assignment. - """ - return self._kwargs - - @kwargs.setter - def kwargs(self, k : Dict[str, Argument]): - """ - Set the dict of kwargs to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more - information. - """ - # DO NOT CALL `__update_args_kwargs` directly. The correct way to - # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` - self.__update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore[arg-type] - - @property - def all_input_nodes(self) -> List['Node']: - """ - Return all Nodes that are inputs to this Node. This is equivalent to - iterating over ``args`` and ``kwargs`` and only collecting the values that - are Nodes. - - Returns: - - List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this - ``Node``, in that order. - """ - return list(self._input_nodes.keys()) - - @compatibility(is_backward_compatible=True) - def update_arg(self, idx : int, arg : Argument) -> None: - """ - Update an existing positional argument to contain the new value - ``arg``. After calling, ``self.args[idx] == arg``. - - Args: - - idx (int): The index into ``self.args`` of the element to update - arg (Argument): The new argument value to write into ``args`` - """ - args = list(self.args) - args[idx] = arg - self.args = tuple(args) - - @compatibility(is_backward_compatible=True) - def update_kwarg(self, key : str, arg : Argument) -> None: - """ - Update an existing keyword argument to contain the new value - ``arg``. After calling, ``self.kwargs[key] == arg``. - - Args: - - key (str): The key in ``self.kwargs`` of the element to update - arg (Argument): The new argument value to write into ``kwargs`` - """ - kwargs = dict(self.kwargs) - kwargs[key] = arg - self.kwargs = kwargs - - @property - def stack_trace(self) -> Optional[str]: - """ - Return the Python stack trace that was recorded during tracing, if any. - This property is usually populated by `Tracer.create_proxy`. To record - stack traces during tracing for debug purposes, set - `record_stack_traces = True` on the `Tracer` instance. - """ - return self.meta.get("stack_trace", None) - - @stack_trace.setter - def stack_trace(self, trace : Optional[str]): - self.meta["stack_trace"] = trace - - def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']): - """ - This API is internal. Do *not* call it directly. - """ - self._args = new_args - self._kwargs = new_kwargs - - for old_use in self._input_nodes.keys(): - old_use.users.pop(self) - - self._input_nodes = {} - map_arg(self._args, lambda n: self._input_nodes.setdefault(n)) - map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n)) - - for new_use in self._input_nodes.keys(): - new_use.users.setdefault(self) - - def __repr__(self) -> str: - if self._repr_fn: - return self._repr_fn(self) - return self.name - - def _pretty_print_target(self, target): - """ - Make target printouts more user-friendly. - 1) builtins will be printed as `builtins.xyz` - 2) operators will be printed as `operator.xyz` - 3) other callables will be printed with qualfied name, e.g. torch.add - """ - if isinstance(target, str): - return target - if hasattr(target, '__module__'): - if not hasattr(target, '__name__'): - # Just to be defensive, if we don't have `__name__`, get the - # qualname. Not sure if this happens for any members of `operator` - # or `builtins`. This fallback path is not as good, since e.g. - # things in `operator` have `_operator` as their __module__. - return _get_qualified_name(target) - if target.__module__ == 'builtins': - return f'builtins.{target.__name__}' - elif target.__module__ == '_operator': - return f'operator.{target.__name__}' - return _get_qualified_name(target) - - @compatibility(is_backward_compatible=True) - def format_node(self, - placeholder_names: Optional[List[str]] = None, - maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: - """ - Return a descriptive string representation of ``self``. - - This method can be used with no arguments as a debugging - utility. - - This function is also used internally in the ``__str__`` method - of ``Graph``. Together, the strings in ``placeholder_names`` - and ``maybe_return_typename`` make up the signature of the - autogenerated ``forward`` function in this Graph's surrounding - GraphModule. ``placeholder_names`` and ``maybe_return_typename`` - should not be used otherwise. - - Args: - placeholder_names: A list that will store formatted strings - representing the placeholders in the generated - ``forward`` function. Internal use only. - maybe_return_typename: A single-element list that will store - a formatted string representing the output of the - generated ``forward`` function. Internal use only. - - Returns: - str: If 1) we're using ``format_node`` as an internal helper - in the ``__str__`` method of ``Graph``, and 2) ``self`` - is a placeholder Node, return ``None``. Otherwise, - return a descriptive string representation of the - current Node. - """ - if self.op == 'placeholder': - assert isinstance(self.target, str) - arg_str = self.target - arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' - if placeholder_names: - placeholder_names.append(arg_str) - return None - maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' - default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' - return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' - elif self.op == 'get_attr': - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}]' - elif self.op == 'output': - if self.type and maybe_return_typename: - maybe_return_typename[0] = f' -> {_type_repr(self.type)}' - return f'return {self.args[0]}' - else: - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}](' \ - f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' - - @compatibility(is_backward_compatible=True) - def replace_all_uses_with(self, - replace_with : 'Node', - delete_user_cb: Callable[['Node'], bool] = lambda user: True - ) -> List['Node']: - """ - Replace all uses of ``self`` in the Graph with the Node ``replace_with``. - - Args: - - replace_with (Node): The node to replace all uses of ``self`` with. - delete_user_cb (Callable): Callback that is called to determine - whether a given user of the self node should be removed. - - Returns: - - The list of Nodes on which this change was made. - """ - to_process = list(self.users) - skipped = [] - for use_node in to_process: - if not delete_user_cb(use_node): - skipped.append(use_node) - continue - - def maybe_replace_node(n : Node) -> Node: - if n == self: - return replace_with - else: - return n - - new_args = map_arg(use_node.args, maybe_replace_node) - new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - use_node.__update_args_kwargs(new_args, new_kwargs) - - assert len(self.users) - len(skipped) == 0 - return [n for n in to_process if n not in skipped] - - @compatibility(is_backward_compatible=False) - def is_impure(self): - """ - Returns whether this op is impure, i.e. if its op is a placeholder or - output, or if a call_function or call_module which is impure. - - Returns: - - bool: If the op is impure or not. - """ - if self.op in {"placeholder", "output"}: - return True - - # Check if an impure function. - if self.op == "call_function": - return self.target in _side_effectful_functions - - # Check if an impure module. - if self.op == "call_module": - assert ( - self.graph.owning_module is not None - ), "self.graph.owning_module not set for purity check" - target_mod = self.graph.owning_module.get_submodule(self.target) - assert ( - target_mod is not None - ), f"Did not find expected submodule target {self.target}" - return getattr(target_mod, "_is_impure", False) - - return False - - @compatibility(is_backward_compatible=False) - def normalized_arguments( - self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: - """ - Returns normalized arguments to Python targets. This means that - `args/kwargs` will be matched up to the module/functional's - signature and return exclusively kwargs in positional order - if `normalize_to_only_use_kwargs` is true. - Also populates default values. Does not support positional-only - parameters or varargs parameters. - - Supports module calls. - - May require `arg_types` and `kwarg_types` in order to disambiguate overloads. - - Args: - root (torch.nn.Module): Module upon which to resolve module targets. - arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args - kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Returns NamedTuple ArgsKwargsPair, or `None` if not successful. - """ - if self.op == 'call_function': - assert callable(self.target) - return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] - elif self.op == 'call_module': - assert isinstance(self.target, str) - return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] - - return None - - @compatibility(is_backward_compatible=True) - def replace_input_with(self, old_input: 'Node', new_input: 'Node'): - """ - Loop through input nodes of ``self``, and replace all instances of - ``old_input`` with ``new_input``. - - Args: - - old_input (Node): The old input node to be replaced. - new_input (Node): The new input node to replace ``old_input``. - """ - def maybe_replace_node(n : Node) -> Node: - return new_input if n == old_input else n - - new_args = map_arg(self.args, maybe_replace_node) - new_kwargs = map_arg(self.kwargs, maybe_replace_node) - assert isinstance(new_args, tuple) - assert isinstance(new_kwargs, dict) - self.__update_args_kwargs(new_args, new_kwargs) - - -@compatibility(is_backward_compatible=True) -def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: - """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. - """ - assert callable(fn), "pippy.fx.map_arg(a, fn): fn must be a callable" - return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) - - -@compatibility(is_backward_compatible=True) -def map_aggregate(a: Argument, fn: Callable[[Argument], Argument], - should_traverse_fn: Optional[Callable[[Argument], bool]] = None) -> Argument: - """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. - Traverses list, tuple, slice, or dict if ``should_traverse_fn`` is either None or returns True for supplied argument - """ - if should_traverse_fn and not should_traverse_fn(a): - return fn(a) - - if isinstance(a, tuple): - t = tuple(map_aggregate(elem, fn, should_traverse_fn) for elem in a) - # Support NamedTuple (if it has `_fields`) by repacking into original type. - return t if not hasattr(a, '_fields') else type(a)(*t) - elif isinstance(a, list): - return immutable_list(map_aggregate(elem, fn, should_traverse_fn) for elem in a) - elif isinstance(a, dict): - return immutable_dict((k, map_aggregate(v, fn, should_traverse_fn)) for k, v in a.items()) - elif isinstance(a, slice): - return slice(map_aggregate(a.start, fn, should_traverse_fn), map_aggregate(a.stop, fn, should_traverse_fn), - map_aggregate(a.step, fn, should_traverse_fn)) - else: - return fn(a) diff --git a/pippy/fx/operator_schemas.py b/pippy/fx/operator_schemas.py deleted file mode 100644 index eccabf917..000000000 --- a/pippy/fx/operator_schemas.py +++ /dev/null @@ -1,409 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import inspect -import numbers -import types -import typing -import enum -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING -from torch._jit_internal import boolean_dispatched -from ._compatibility import compatibility -from torch._ops import OpOverloadPacket, OpOverload - -if TYPE_CHECKING: - from .node import Argument - -@compatibility(is_backward_compatible=False) -class ArgsKwargsPair(NamedTuple): - """ - Simple named tuple for wrapping args/kwargs pairs. - """ - args: Tuple[Any, ...] - kwargs: Dict[str, Any] - -_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} - -def _nonzero_schemas(): - signatures = [] - - def nonzero(self): - pass - signatures.append(inspect.signature(nonzero)) - - def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] - pass - signatures.append(inspect.signature(nonzero)) - - return signatures - -_manual_overrides[torch.nonzero] = _nonzero_schemas() - -class _FakeGlobalNamespace: - def __getattr__(self, name): - if name == 'torch': - return torch - raise RuntimeError('Expected a torch namespace lookup') - -_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, - 'number' : numbers.Number, 'Future' : torch.jit.Future, - 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, - '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), - 't': typing.TypeVar('t')} -for k in dir(typing): - _type_eval_globals[k] = getattr(typing, k) - -def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: - """ - Convert a TorchScript type to a Python type (including subtypes) via - eval'ing the annotation_str. _type_eval_globals sets up expressions - like "List" and "Future" to map to actual types (typing.List and jit.Future) - """ - return eval(ts_type.annotation_str, _type_eval_globals) - -def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: - parameters : List[inspect.Parameter] = [] - for arg in ts_schema.arguments: - arg_type = _torchscript_type_to_python_type(arg.type) - default = arg.default_value if arg.has_default_value() else inspect.Parameter.empty - # TODO: Figure out if this is safe. It seems like when generating the type signatures for - # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor - # argument name. Downstream, if someone converts that positional argument to a keyword - # argument, the name mismatch will break things, so here we're going to normalize the - # name to "input" - name = arg.name if arg.name != 'self' else 'input' - kind = inspect.Parameter.KEYWORD_ONLY if arg.kwarg_only else inspect.Parameter.POSITIONAL_OR_KEYWORD - parameters.append(inspect.Parameter(name=name, kind=kind, default=default, annotation=arg_type)) - return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] - if len(return_types) == 0: - return_type = None - elif len(return_types) == 1: - return_type = return_types[0] - else: - return_type = tuple(return_types) - - return inspect.Signature(parameters, return_annotation=return_type) - -@compatibility(is_backward_compatible=False) -def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): - signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) - - if signatures and schemas: - matched_schemas = [] - - # Iterate through all of the schema until we find one that matches - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs - # values. If none matches, `new_args_and_kwargs` will be None - for candidate_signature, schema in zip(signatures, schemas): - try: - candidate_signature.bind(*args, **kwargs) - matched_schemas.append((candidate_signature, schema)) - except TypeError as e: - continue - - def throw_if_mutable(schema): - if schema.is_mutable: - raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' - f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' - f'are not supported') - - if len(matched_schemas) == 0: - # Did not match any schema. Cannot check for mutation - pass - elif len(matched_schemas) == 1: - # Matched exactly one schema, unambiguous - _, schema_to_check = matched_schemas[0] - throw_if_mutable(schema_to_check) - pass - else: - # Ambiguous schema match. Since mutability checking is best effort, - # do nothing. - pass - -@compatibility(is_backward_compatible=False) -def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): - """ - Given an operator on the `torch` namespace, return a list of `inspect.Signature` - objects corresponding to the overloads of that op.. May return `None` if a signature - could not be retrieved. - - Args: - op (Callable): An operator on the `torch` namespace to look up a signature for - - Returns: - Optional[List[inspect.Signature]]: A list of signatures for the overloads of this - operator, or None if the operator signatures could not be retrieved. If - return_schemas=True, returns a tuple containing the optional Python signatures - and the optional TorchScript Function signature - """ - if isinstance(op, OpOverload): - schemas = [op._schema] - elif isinstance(op, OpOverloadPacket): - schemas = [getattr(op, overload)._schema for overload in op.overloads()] - else: - override = _manual_overrides.get(op) - if override: - return (override, None) if return_schemas else None - - aten_fn = torch.jit._builtins._find_builtin(op) - - if aten_fn is None: - return (None, None) if return_schemas else None - schemas = torch._C._jit_get_schemas_for_operator(aten_fn) - - signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] - return (signatures, schemas) if return_schemas else signatures - -@compatibility(is_backward_compatible=False) -def create_type_hint(x): - try: - if isinstance(x, list) or isinstance(x, tuple): - # todo(chilli): Figure out the right way for mypy to handle this - if isinstance(x, list): - def ret_type(x): - return List[x] # type: ignore[valid-type] - else: - def ret_type(x): - return Tuple[x, ...] - if len(x) == 0: - return ret_type(Any) - base_type = x[0] - for t in x: - if issubclass(t, base_type): - continue - elif issubclass(base_type, t): - base_type = t - else: - return ret_type(Any) - return ret_type(base_type) - except Exception as e: - # We tried to create a type hint for list but failed. - warnings.warn(f"We were not able to successfully create type hint from the type {x}") - pass - return x - -@compatibility(is_backward_compatible=False) -def type_matches(signature_type : Any, argument_type : Any): - sig_origin_type = getattr(signature_type, '__origin__', signature_type) - - if signature_type is argument_type: - return True - - # Union types in signature. Given type needs to match one of the - # contained types in the Union - if sig_origin_type is typing.Union and signature_type != argument_type: - sig_contained = signature_type.__args__ - return any(type_matches(c, argument_type) for c in sig_contained) - - if signature_type is List[int] and argument_type is int: - # int can be promoted to List[int] - return True - - if getattr(signature_type, '__origin__', None) in {list, List}: - sig_el_type = signature_type.__args__[0] - if not inspect.isclass(sig_el_type): - warnings.warn( - f"Does not support nested parametric types, got {signature_type}. Please file a bug.") - return False - if getattr(argument_type, '__origin__', None) in {list, List}: - return issubclass(argument_type.__args__[0], sig_el_type) - - def is_homogeneous_tuple(t): - if not getattr(t, '__origin__', None) in {tuple, Tuple}: - return False - contained = t.__args__ - if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason - return True - return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) - - # Tuple[T] is accepted for List[T] parameters - return is_homogeneous_tuple(argument_type) - - # Dtype is an int in schemas - if signature_type is int and argument_type is torch.dtype: - return True - - if signature_type is numbers.Number and argument_type in {int, float}: - return True - if inspect.isclass(argument_type) and inspect.isclass(signature_type): - return issubclass(argument_type, signature_type) - - return False - -@compatibility(is_backward_compatible=False) -def normalize_function( - target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: - """ - Returns normalized arguments to PyTorch functions. This means that - `args/kwargs` will be matched up to the functional's - signature and return exclusively kwargs in positional order if - `normalize_to_only_use_kwargs` is True. - Also populates default values. Does not support positional-only - parameters or varargs parameters (*args, **kwargs). Does not support modules. - - May require `arg_types` and `kwarg_types` in order to disambiguate overloads. - - Args: - target (Callable): Function that we are normalizing - args (Tuple[Any]): Tuple of args to the function - kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function - arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args - kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Returns normalized_args_and_kwargs, or `None` if not successful. - """ - if kwargs is None: - kwargs = {} - new_args_and_kwargs = None - if not isinstance(target, types.BuiltinFunctionType) and not ( - isinstance(target, OpOverloadPacket) or isinstance(target, OpOverload) - ): - target_for_analysis = target - if target in boolean_dispatched: - # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have - # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` - # branches of the dispatch have exactly the same signature. If they do, use the `true` - # branch signature for analysis. Otherwise, leave this un-normalized - assert not isinstance(target, str) - dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] - if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: - return None - target_for_analysis = if_true - - assert callable(target_for_analysis) - sig = inspect.signature(inspect.unwrap(target_for_analysis)) - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) - else: - assert callable(target) - torch_op_schemas = get_signature_for_torch_op(target) - matched_schemas = [] - if torch_op_schemas: - # Iterate through all of the schema until we find one that matches - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs - # values. If none matches, `new_args_and_kwargs` will be None - for candidate_signature in torch_op_schemas: - try: - candidate_signature.bind(*args, **kwargs) - matched_schemas.append(candidate_signature) - except TypeError as e: - continue - - if len(matched_schemas) == 0: - # Did not match any schema. Cannot normalize - pass - elif len(matched_schemas) == 1: - # Matched exactly one schema, unambiguous - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, - normalize_to_only_use_kwargs) - else: - if arg_types is not None or kwarg_types is not None: - arg_types = arg_types if arg_types else cast(Tuple[Any], ()) - kwarg_types = kwarg_types if kwarg_types else {} - for candidate_signature in torch_op_schemas: - sig_matches = True - try: - bound_types = candidate_signature.bind(*arg_types, **kwarg_types) - for arg_name, arg_type in bound_types.arguments.items(): - param = candidate_signature.parameters[arg_name] - sig_matches = sig_matches and type_matches(param.annotation, arg_type) - except TypeError as e: - sig_matches = False - if sig_matches: - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, - normalize_to_only_use_kwargs) - break - else: - # Matched more than one schema. In this situation, the caller must provide the types of - # the arguments of the overload they expect. - schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) - raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' - f'the schema match was ambiguous! Please provide argument types to ' - f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') - - return new_args_and_kwargs - -@compatibility(is_backward_compatible=False) -def normalize_module( - root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: - """ - Returns normalized arguments to PyTorch modules. This means that - `args/kwargs` will be matched up to the functional's - signature and return exclusively kwargs in positional order if - `normalize_to_only_use_kwargs` is True. - Also populates default values. Does not support positional-only - parameters or varargs parameters (*args, **kwargs). - - Args: - root (nn.Module): root module upon which we query modules - target (Callable): Function that we are normalizing - args (Tuple[Any]): Tuple of args to the function - kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Returns normalized_args_and_kwargs, or `None` if not successful. - """ - try: - submod = root.get_submodule(target) - except AttributeError: - raise RuntimeError(f"Tried to normalize node with target {target} but root did not " - f"have that target!") - if hasattr(submod.__class__, '__name__'): - classname = submod.__class__.__name__ - if getattr(torch.nn, classname, None) == submod.__class__: - sig = inspect.signature(inspect.unwrap(submod.forward)) - if kwargs is None: - kwargs = {} - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, - normalize_to_only_use_kwargs) - return new_args_and_kwargs - return None - -def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], - kwargs : Dict[str, Any], - normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: - """ - Given a call target, args, and kwargs, return the arguments normalized into - an ArgsKwargsPair, or None if the type signature is not supported by - this normalization. - - Args: - - target (inspect.Signature): Signature object for the target - args (Tuple): Arguments that appear at the callsite for `target` - kwargs (Dict): Keyword arguments that appear at the callsite for `target` - normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. - - Returns: - - Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if - this target is not supported. - """ - - # Don't currently support positional-only - # or varargs (*args, **kwargs) signatures - supported_parameter_types = { - inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} - if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): - return None - - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - - new_kwargs : Dict[str, Any] = {} - new_args : List[Any] = [] - for i, param in enumerate(sig.parameters): - if not normalize_to_only_use_kwargs and i < len(args): - new_args.append(bound_args.arguments[param]) - else: - new_kwargs[param] = bound_args.arguments[param] - - return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/pippy/fx/passes/README.md b/pippy/fx/passes/README.md deleted file mode 100644 index a29968487..000000000 --- a/pippy/fx/passes/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## FX Pass Infrastructure -This folder contains the pass infarstructure and passes for transforming fx.Graph. - - -## Code Structure - -* [infra](infra) - Common infrastructure, such as PassManager, PassBase - * [partitioner.py](infra/partitioner.py) - backend agnostic FX graph partitioner -* [utils](utils) - Utility classes and functions - * [common.py](utils/common.py) - common utility functions - * [fuser_utis.py](utils/fuser_utils.py) - utility functions for fusing list of nodes into a single node -* [dialect](dialect) - dialect specific passes - * [common](dialect/common) - common passes that can be shared by all dialects - * [cse_pass.py](dialect/common/cse_pass.py) - a CSE pass - * [aten](dialect/aten) - aten dialect specific passes - * [prims](dialect/prims) - prim dialect specific passes -* [backends](backends) - Backend specific passes - * [nvfuser](backends/nvfuser) - passes for nvfuser - * [operator_support.py](backends/nvfuser/operator_support.py) - nvFuser supported ops -* [conversion](conversion) - Conversion passes between dialects diff --git a/pippy/fx/passes/__init__.py b/pippy/fx/passes/__init__.py deleted file mode 100644 index d20580680..000000000 --- a/pippy/fx/passes/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from . import graph_drawer -from . import graph_manipulation -from . import net_min_base -from . import operator_support -from . import param_fetch -from . import reinplace -from . import shape_prop -from . import split_module -from . import split_utils -from . import splitter_base -from . import tools_common diff --git a/pippy/fx/passes/backends/__init__.py b/pippy/fx/passes/backends/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/passes/backends/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/passes/backends/cudagraphs.py b/pippy/fx/passes/backends/cudagraphs.py deleted file mode 100644 index 3898a2e74..000000000 --- a/pippy/fx/passes/backends/cudagraphs.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -from pippy.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from pippy.fx.passes.operator_support import OperatorSupport -from pippy.fx.passes.tools_common import CALLABLE_NODE_OPS -from pippy.fx.passes.fake_tensor_prop import FakeTensorProp -from torch.utils._pytree import tree_map - -import operator - -class CudaGraphsSupport(OperatorSupport): - # TODO: why is submodules passed here - def is_node_supported(self, submodules, node: pippy.fx.Node) -> bool: - if node.op not in CALLABLE_NODE_OPS: - return False - - if node.target in [torch.ops.aten.embedding_dense_backward.default]: - return False - - if node.target in [operator.getitem]: - return True - - found_not_cuda = False - - def meta_fk(meta): - return meta["val"] if "val" in meta else meta["fake_result"] - - def find_not_cuda(t): - nonlocal found_not_cuda - if isinstance(t, torch.Tensor) and t.device.type != 'cuda': - found_not_cuda = True - - for n in node.all_input_nodes: - tree_map(find_not_cuda, meta_fk(n.meta)) - - tree_map(find_not_cuda, meta_fk(node.meta)) - - # NB: factory function is accounted for because the result would be - # cpu or cuda - - return not found_not_cuda - -def partition_cudagraphs(gm, inputs): - """ - Partition an FX graph into sub-GraphModules that can be validly run under - CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations - must involve CUDA tensors only/ - """ - - FakeTensorProp(gm).propagate(*inputs) - supported_ops = CudaGraphsSupport() - # TODO: single node partition may be wrong due to the pessimization - # from copying in and out the data. Check in benchmarks, perhaps - partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True) - partitions = partitioner.propose_partitions() - fused_graph = partitioner.fuse_partitions(partitions) - return fused_graph diff --git a/pippy/fx/passes/backends/nvfuser.py b/pippy/fx/passes/backends/nvfuser.py deleted file mode 100644 index 689ab8432..000000000 --- a/pippy/fx/passes/backends/nvfuser.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict - -import torch -from torch.nn import Module -from torch._ops import OpOverload - -from pippy.fx import GraphModule -from pippy.fx.node import Node, _get_qualified_name -from pippy.fx.passes.operator_support import OperatorSupport -from pippy.fx.passes.tools_common import CALLABLE_NODE_OPS -from pippy.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from torch._prims.executor import execute -from pippy.fx.experimental.proxy_tensor import DecompositionInterpreter -from torch._decomp import decomposition_table - -import typing as t - -import logging - -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -def aten_to_dtype(self, dtype: torch.dtype, **kwargs): - if len(kwargs) > 0 or not dtype: - raise RuntimeError("No support for other to.dtype() formats other than to.dtype(self, dtype)") - return torch._prims.convert_element_type(self, dtype) - -# decomposition_table currently contains both aten2aten and aten2prim decomposition -# this is a hack to separate them, as we only need aten2prim decomposition for nvfuser-supported aten graph lowering -aten2aten_decomp = {} -aten2prim_decomp = {} - -for op, decomp_fn in decomposition_table.items(): - if "torch._refs" in decomp_fn.__module__: - aten2prim_decomp[op] = decomp_fn - else: - aten2aten_decomp[op] = decomp_fn - -aten2aten_decomp_skips = { - "aten.native_layer_norm_backward.default", - "aten.embedding_dense_backward.default", # This is hurting nvfuser's perf - "aten.addmm.default" -} - -for op, decomp_fn in decomposition_table.items(): - if "torch._refs" in decomp_fn.__module__: - aten2prim_decomp[op] = decomp_fn - else: - if str(op) not in aten2aten_decomp_skips: - aten2aten_decomp[op] = decomp_fn - - -aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype - - -class NvFuserOperatorSupport(OperatorSupport): - """ - Operator support for nvFuser backend. - - Currently, partitioning is based on FX ATen graph. The fused subgraph will latter be decomposed into prims. - To determine if an ATen ops is supported by nvFuser, we shall check the prim ops used in its ref decomposition. - Only if all the prim ops in the ref has a nvfuser_impl, we say this Aten op is suppported by nvFuser. - - Note: When adding a rule, please add it to the corresponding section and follow the - alphabetical order. - """ - - def __init__(self): - - # TODO: current list copied from torch/csrc/jit/codegen/cuda/parser.cpp is incorrect, - # as that file is solely for TorchScript and doesn't represent the actual status - # whether operation would be runnable by primTorch+nvFuser. - # We will iterate on this list to reflect the the reality. - support_dict = { - # =============================================================== - # call_function aten - # =============================================================== - # Following supported aten ops is copied from torch/csrc/jit/codegen/cuda/parser.cpp - # TODO: might need to update according to supported input types - "torch.ops.aten.add": None, - "torch.ops.aten.sub": None, - # "torch.ops.aten.rsub": None, # rsub decomp is supported at aten2aten level - "torch.ops.aten.div": None, - "torch.ops.aten.atan2": None, - "torch.ops.aten.mul": None, - "torch.ops.aten.max": None, - "torch.ops.aten.min": None, - "torch.ops.aten.pow": None, - "torch.ops.aten.remainder": None, - "torch.ops.aten.fmod": None, - "torch.ops.aten.bitwise_and": None, - "torch.ops.aten.__and__": None, - "torch.ops.aten.bitwise_or": None, - "torch.ops.aten.__or__": None, - "torch.ops.aten.bitwise_xor": None, - "torch.ops.aten.__xor__": None, - "torch.ops.aten.bitwise_left_shift": None, - "torch.ops.aten.__lshift__": None, - "torch.ops.aten.bitwise_right_shift": None, - "torch.ops.aten.__rshift__": None, - "torch.ops.aten.eq": None, - "torch.ops.aten.ne": None, - "torch.ops.aten.ge": None, - "torch.ops.aten.gt": None, - "torch.ops.aten.le": None, - "torch.ops.aten.lt": None, - "torch.ops.aten.abs": None, - "torch.ops.aten.bitwise_not": None, - "torch.ops.aten.ceil": None, - "torch.ops.aten.floor": None, - "torch.ops.aten.frac": None, - "torch.ops.aten.neg": None, - "torch.ops.aten.relu": None, - "torch.ops.aten.round": None, - "torch.ops.aten.silu": None, - "torch.ops.aten.trunc": None, - "torch.ops.aten.log": None, - "torch.ops.aten.log10": None, - "torch.ops.aten.log1p": None, - "torch.ops.aten.log2": None, - "torch.ops.aten.lgamma": None, - "torch.ops.aten.exp": None, - "torch.ops.aten.expm1": None, - "torch.ops.aten.erf": None, - "torch.ops.aten.erfc": None, - "torch.ops.aten.cos": None, - "torch.ops.aten.acos": None, - "torch.ops.aten.cosh": None, - "torch.ops.aten.sin": None, - "torch.ops.aten.asin": None, - "torch.ops.aten.sinh": None, - "torch.ops.aten.tan": None, - "torch.ops.aten.atan": None, - "torch.ops.aten.tanh": None, - "torch.ops.aten.atanh": None, - "torch.ops.aten.sqrt": None, - "torch.ops.aten.rsqrt": None, - "torch.ops.aten.reciprocal": None, - "torch.ops.aten.sigmoid": None, - "torch.ops.aten.isfinite": None, - "torch.ops.aten.isinf": None, - "torch.ops.aten.isnan": None, - "torch.ops.aten.isneginf": None, - "torch.ops.aten.isposinf": None, - "torch.ops.aten.isreal": None, - # "torch.ops.aten.rand_like": None, # causing Node empty_like_default does not support nvfuser - "torch.ops.aten.softplus": None, - "torch.ops.aten.threshold": None, - # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.new_zero op - # "torch.ops.aten.threshold_backward": None, - "torch.ops.aten.clamp": None, - # "torch.ops.aten.clone": None, - # Failing with where(): incompatible function arguments: \ - # [aten->prim decomp, aten2aten is using unsupported aten.div - # "torch.ops.aten.native_layer_norm_backward": None, - "torch.ops.aten.softmax.int": None, - "torch.ops.aten.log_softmax.int": None, - # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.amax - # "torch.ops.aten._softmax": None, - "torch.ops.aten._log_softmax_backward_data": None, - # "torch.ops.aten._softmax_backward_data": None, # Node _softmax_backward_data_default does not support nvfuser - # "torch.ops.aten.var.dim": None, # missing refs - "torch.ops.aten.std.dim": None, - "torch.ops.aten.sum": None, - # "torch.ops.aten.mean.dim": None, # missing refs - "torch.ops.aten._grad_sum_to_size": None, - "torch.ops.aten.sum_to_size": None, - "torch.ops.aten._autocast_to_reduced_precision": None, - "torch.ops.aten._autocast_to_full_precision": None, - # "torch.ops.aten.to.dtype": None, # causing segfault - # "torch.ops.aten.type_as": None, # missing refs - "torch.ops.aten.linear": None, - "torch.ops.aten.gelu": None, - # "torch.ops.aten.gelu_backward": None, # gelu_backward is handled at aten2aten decomp - # "torch.ops.aten.hardtanh": None, # has functional ref, using unsupported aten.clamp - "torch.ops.aten.leaky_relu": None, - "torch.ops.aten.square": None, - # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.conj_physical - "torch.ops.aten.tanh_backward": None, - # "torch.ops.aten.amax": None, # missing prim decomp - # "torch.ops.aten.amin": None, # missing prim decomp - # "torch.ops.aten.reshape": None, - # "torch.ops.aten.view": None, # missing prim decomp - "torch.ops.aten.flatten.using_ints": None, - - # =============================================================== - # call_function builtins and operator - # =============================================================== - "getattr": None, - "_operator.getitem": None, - } - - super().__init__(support_dict) - - def is_node_supported( - self, submodules: t.Mapping[str, Module], node: Node - ) -> bool: - - # nvFuser FX subgraph should be purely functional - if node.op not in CALLABLE_NODE_OPS: - return False - - # ops in supported_dict doesn't have overload name - # use overloadpacket's qualified_name for OpOverload - if isinstance(node.target, OpOverload): - target = _get_qualified_name(node.target.overloadpacket) - if target in self._support_dict: - return True - - return super().is_node_supported(submodules, node) - - -class NvFuserBackend: - def __init__(self): - self.supported_ops = NvFuserOperatorSupport() - - # TODO: this is a naive implementation of cache without proper guard - self.partitioner_cache: Dict[GraphModule, GraphModule] = {} - - # TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs - self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {} - - def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs): - # `graph_module` is an Aten-Fx graph - # "lowering to prims" and "trace execution" are grouped into this function, as they are both input dependent - - if graph_module in self.prim_decomp_cache: - logging.debug("prim_decomp_cache hit!") - prim_module = self.prim_decomp_cache[graph_module] - else: - prim_graph = pippy.fx.Graph() - DecompositionInterpreter(graph_module, prim_graph, decomposition_table=aten2prim_decomp).run(*args, **kwargs) - prim_module = pippy.fx.GraphModule(graph_module, prim_graph) - self.prim_decomp_cache[graph_module] = prim_module - - logging.debug("Lower to prims graph: ", prim_module.code) - - # invokes trace executor for running the prim graph - return execute(prim_module, *args, executor="nvfuser") - - def compile(self, graph_module: GraphModule) -> GraphModule: - # entry function for nvFuser backend - logging.debug("Compiling graph_module: ", graph_module.code) - - # FX graph based partitioning based on nvfuser supported ops - if graph_module in self.partitioner_cache: - logging.debug("partitioner_cache hit!") - fused_graph_module = self.partitioner_cache[graph_module] - else: - partitioner = CapabilityBasedPartitioner( - graph_module, self.supported_ops, allows_single_node_partition=False) - fused_graph_module = partitioner.partition_and_fuse() - - self.partitioner_cache[graph_module] = fused_graph_module - - # Overriding fused_module's __call__() function with lower_to_prims_and_execute() - for node in fused_graph_module.graph.nodes: - # TODO: use a better way to identify fused submodule - if node.op == "call_module" and "fused_" in node.name: - fused_module = getattr(fused_graph_module, node.name) - fused_module._wrapped_call = self.lower_to_prims_and_execute - - return fused_graph_module - - def __call__(self, graph_module: GraphModule, _) -> GraphModule: - # wrap self.compile as __call__ function to fit the interface for AOTAutograd's fw_compiler - return self.compile(graph_module) diff --git a/pippy/fx/passes/dialect/__init__.py b/pippy/fx/passes/dialect/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pippy/fx/passes/dialect/common/__init__.py b/pippy/fx/passes/dialect/common/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pippy/fx/passes/dialect/common/cse_pass.py b/pippy/fx/passes/dialect/common/cse_pass.py deleted file mode 100644 index 365781794..000000000 --- a/pippy/fx/passes/dialect/common/cse_pass.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, Tuple, Any - -import torch -from torch.utils._pytree import tree_flatten - -import pippy -from pippy.fx import GraphModule, Graph -from pippy.fx import Node -from pippy.fx.passes.infra.pass_base import PassBase, PassResult - -aten = torch.ops.aten - - -# stateful ops are banned from CSE -rand_ops = set([aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]) # noqa: E501 - -inplace_ops = set([aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_]) # noqa: E501 - - -@pippy.fx._compatibility.compatibility(is_backward_compatible=False) -def get_CSE_banned_ops(): - return rand_ops.union(inplace_ops) - - -@pippy.fx._compatibility.compatibility(is_backward_compatible=False) -class CSEPass(PassBase): - - def __init__(self, banned_ops=None): - """ - This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. - - For functional dialects, user would only need to specify the random ops in ban list. - - Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. - If your dialect contains stateful operators, please customized the banned_ops. - - """ - if banned_ops is None: - banned_ops = set() - self.banned_ops = banned_ops - super().__init__() - - def call(self, graph_module: GraphModule) -> PassResult: - """ - Return a new copy of pippy.fx.GraphModule with CSE applied to the input graph - - Example usage: - - from pippy.fx.experimental.proxy_tensor import make_fx - def f(a): - b = a * a - c = a * a - return b+c - - p = CSEPass() - traced_graph = make_fx(f)(torch.tensor(1)) - print(traced_graph) - result = p(traced_graph) - print(result.graph_module) - """ - def get_aten_target(node): - if hasattr(node.target, 'overloadpacket'): - return node.target.overloadpacket - return node.target - - modified = False - new_graph = Graph() - env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph - hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph - token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token - for n in graph_module.graph.nodes: - # The placeholder, output, and get_attr nodes are copied to the new grpah without change - # do not CSE away random operations - if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: - new_node = new_graph.node_copy(n, lambda x: env[x]) - env[n] = new_node - else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' - # substitute args and kwargs memebrs to their mapping in env if exists - # specs can be used to reconstruct nested list/dictionaries - def substitute(arg_list): - arg_list, spec = tree_flatten(arg_list) - for i in range(len(arg_list)): - v = arg_list[i] - if isinstance(v, Node) and v in env: - arg_list[i] = env[v] - return tuple(arg_list), spec - args, args_spec = substitute(n.args) - kwargs, kwargs_spec = substitute(n.kwargs) - - # each token corresponds to a unique node - # nodes with the same token can be substituted - token = {"target": n.target, "args": args, "args_spec": args_spec, - "kwargs": kwargs, "kwargs_spec": kwargs_spec} - - # hash substituted args to a number, do not hash specs because specs are not hashable - hash_arg = hash((args, kwargs)) - hash_val = (n.target, hash_arg) - - # check if a node has a substitute and can be eliminated - hash_val_in_hash_env = hash_val in hash_env - if hash_val_in_hash_env and token_map[hash_val] == token: - modified = True # substition happens and the graph is modified - env[n] = hash_env[hash_val] - continue - - new_node = new_graph.node_copy(n, lambda x: env[x]) - env[n] = new_node - if not hash_val_in_hash_env: - hash_env[hash_val] = new_node - token_map[hash_val] = token - - csed_gm = GraphModule(graph_module, new_graph) - return PassResult(csed_gm, modified) diff --git a/pippy/fx/passes/fake_tensor_prop.py b/pippy/fx/passes/fake_tensor_prop.py deleted file mode 100644 index bf30bd3f6..000000000 --- a/pippy/fx/passes/fake_tensor_prop.py +++ /dev/null @@ -1,30 +0,0 @@ -import pippy.fx -from pippy.fx import Node -from pippy.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode - -__all__ = ['FakeTensorProp'] - -@compatibility(is_backward_compatible=False) -class FakeTensorProp(pippy.fx.Interpreter): - """ - Execute an FX graph Node-by-Node and record a fake tensor representing - the metadata for the node. Unlike ShapeProp, (1) this propagation - is cheap--it does the propagation with meta tensors which do not actually - store data, and (2) the fake tensors have much more fine grained information, - e.g., they have accurate alias information that can be consulted by looking - at the storages. - - Args: - module (GraphModule): The module to be executed - """ - - def run_node(self, n: Node): - result = super().run_node(n) - n.meta['val'] = result - return result - - def propagate(self, *args): - with FakeTensorMode.push() as mode: - fake_args = [mode.from_tensor(a) for a in args] - return super().run(*fake_args) diff --git a/pippy/fx/passes/graph_drawer.py b/pippy/fx/passes/graph_drawer.py deleted file mode 100644 index cddd6d99f..000000000 --- a/pippy/fx/passes/graph_drawer.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from __future__ import absolute_import, division, print_function, unicode_literals - -import hashlib -import torch -import pippy.fx -from typing import Dict, Any, TYPE_CHECKING -from pippy.fx.node import _get_qualified_name, _format_arg -from pippy.fx.passes.shape_prop import TensorMetadata -from pippy.fx._compatibility import compatibility -from itertools import chain - -__all__ = ['FxGraphDrawer'] -try: - import pydot - HAS_PYDOT = True -except ImportError: - HAS_PYDOT = False - -_COLOR_MAP = { - "placeholder": '"AliceBlue"', - "call_module": "LemonChiffon1", - "get_param": "Yellow2", - "get_attr": "LightGrey", - "output": "PowderBlue", -} - -_HASH_COLOR_MAP = [ - "CadetBlue1", - "Coral", - "DarkOliveGreen1", - "DarkSeaGreen1", - "GhostWhite", - "Khaki1", - "LavenderBlush1", - "LightSkyBlue", - "MistyRose1", - "MistyRose2", - "PaleTurquoise2", - "PeachPuff1", - "Salmon", - "Thistle1", - "Thistle3", - "Wheat1", -] - -_WEIGHT_TEMPLATE = { - "shape": "record", - "fillcolor": "Salmon", - "style": '"filled,rounded"', - "fontcolor": "#000000", -} - -if HAS_PYDOT: - @compatibility(is_backward_compatible=False) - class FxGraphDrawer: - """ - Visualize a pippy.fx.Graph with graphviz - Basic usage: - g = FxGraphDrawer(symbolic_traced, "resnet18") - with open("a.svg", "w") as f: - f.write(g.get_dot_graph().create_svg()) - """ - - def __init__( - self, - graph_module: pippy.fx.GraphModule, - name: str, - ignore_getattr: bool = False, - ignore_parameters_and_buffers: bool = False, - skip_node_names_in_args: bool = True, - ): - self._name = name - self._dot_graphs = { - name: self._to_dot( - graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args - ) - } - - for node in graph_module.graph.nodes: - if node.op != "call_module": - continue - - leaf_node = self._get_leaf_node(graph_module, node) - - if not isinstance(leaf_node, pippy.fx.GraphModule): - continue - - self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( - leaf_node, - f"{name}_{node.target}", - ignore_getattr, - ignore_parameters_and_buffers, - skip_node_names_in_args, - ) - - def get_dot_graph(self, submod_name=None) -> pydot.Dot: - if submod_name is None: - return self.get_main_dot_graph() - else: - return self.get_submod_dot_graph(submod_name) - - def get_main_dot_graph(self) -> pydot.Dot: - return self._dot_graphs[self._name] - - def get_submod_dot_graph(self, submod_name) -> pydot.Dot: - return self._dot_graphs[f"{self._name}_{submod_name}"] - - def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: - return self._dot_graphs - - def _get_node_style(self, node: pippy.fx.Node) -> Dict[str, str]: - template = { - "shape": "record", - "fillcolor": "#CAFFE3", - "style": '"filled,rounded"', - "fontcolor": "#000000", - } - if node.op in _COLOR_MAP: - template["fillcolor"] = _COLOR_MAP[node.op] - else: - # Use a random color for each node; based on its name so it's stable. - target_name = node._pretty_print_target(node.target) - target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) - template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] - return template - - def _get_leaf_node( - self, module: torch.nn.Module, node: pippy.fx.Node - ) -> torch.nn.Module: - py_obj = module - assert isinstance(node.target, str) - atoms = node.target.split(".") - for atom in atoms: - if not hasattr(py_obj, atom): - raise RuntimeError( - str(py_obj) + " does not have attribute " + atom + "!" - ) - py_obj = getattr(py_obj, atom) - return py_obj - - def _typename(self, target: Any) -> str: - if isinstance(target, torch.nn.Module): - ret = torch.typename(target) - elif isinstance(target, str): - ret = target - else: - ret = _get_qualified_name(target) - - # Escape "{" and "}" to prevent dot files like: - # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc - # which triggers `Error: bad label format (...)` from dot - return ret.replace("{", r"\{").replace("}", r"\}") - - def _get_node_label( - self, - module: pippy.fx.GraphModule, - node: pippy.fx.Node, - skip_node_names_in_args: bool, - ) -> str: - def _get_str_for_args_kwargs(arg): - if isinstance(arg, tuple): - prefix, suffix = r"|args=(\l", r",\n)\l" - arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] - elif isinstance(arg, dict): - prefix, suffix = r"|kwargs={\l", r",\n}\l" - arg_strs_list = [ - f"{k}: {_format_arg(v, max_list_len=8)}" - for k, v in arg.items() - ] - else: # Fall back to nothing in unexpected case. - return "" - - # Strip out node names if requested. - if skip_node_names_in_args: - arg_strs_list = [a for a in arg_strs_list if "%" not in a] - if len(arg_strs_list) == 0: - return "" - arg_strs = prefix + r",\n".join(arg_strs_list) + suffix - return arg_strs.replace("{", r"\{").replace("}", r"\}") - - - label = "{" + f"name=%{node.name}|op_code={node.op}\n" - - if node.op == "call_module": - leaf_module = self._get_leaf_node(module, node) - label += r"\n" + self._typename(leaf_module) + r"\n|" - extra = "" - if hasattr(leaf_module, "__constants__"): - extra = r"\n".join( - [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] - ) - label += extra + r"\n" - else: - label += f"|target={self._typename(node.target)}" + r"\n" - if len(node.args) > 0: - label += _get_str_for_args_kwargs(node.args) - if len(node.kwargs) > 0: - label += _get_str_for_args_kwargs(node.kwargs) - label += f"|num_users={len(node.users)}" + r"\n" - - tensor_meta = node.meta.get('tensor_meta') - label += self._tensor_meta_to_label(tensor_meta) - - return label + "}" - - def _tensor_meta_to_label(self, tm) -> str: - if tm is None: - return "" - elif isinstance(tm, TensorMetadata): - return self._stringify_tensor_meta(tm) - elif isinstance(tm, list): - result = "" - for item in tm: - result += self._tensor_meta_to_label(item) - return result - elif isinstance(tm, dict): - result = "" - for k, v in tm.items(): - result += self._tensor_meta_to_label(v) - return result - elif isinstance(tm, tuple): - result = "" - for item in tm: - result += self._tensor_meta_to_label(item) - return result - else: - raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") - - def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: - result = "" - if not hasattr(tm, "dtype"): - print("tm", tm) - result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" - result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" - result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" - result += "|" + "stride" + "=" + str(tm.stride) + r"\n" - if tm.is_quantized: - assert tm.qparams is not None - assert "qscheme" in tm.qparams - qscheme = tm.qparams["qscheme"] - if qscheme in { - torch.per_tensor_affine, - torch.per_tensor_symmetric, - }: - result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" - elif qscheme in { - torch.per_channel_affine, - torch.per_channel_symmetric, - torch.per_channel_affine_float_qparams, - }: - result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" - result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" - else: - raise RuntimeError(f"Unsupported qscheme: {qscheme}") - result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" - return result - - def _get_tensor_label(self, t: torch.Tensor) -> str: - return str(t.dtype) + str(list(t.shape)) + r"\n" - - def _to_dot( - self, - graph_module: pippy.fx.GraphModule, - name: str, - ignore_getattr: bool, - ignore_parameters_and_buffers: bool, - skip_node_names_in_args: bool, - ) -> pydot.Dot: - """ - Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. - If ignore_parameters_and_buffers is True, the parameters and buffers - created with the module will not be added as nodes and edges. - """ - dot_graph = pydot.Dot(name, rankdir="TB") - - for node in graph_module.graph.nodes: - if ignore_getattr and node.op == "get_attr": - continue - - style = self._get_node_style(node) - dot_node = pydot.Node( - node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args), **style - ) - dot_graph.add_node(dot_node) - - def get_module_params_or_buffers(): - for pname, ptensor in chain( - leaf_module.named_parameters(), leaf_module.named_buffers() - ): - pname1 = node.name + "." + pname - label1 = ( - pname1 + "|op_code=get_" + "parameter" - if isinstance(ptensor, torch.nn.Parameter) - else "buffer" + r"\l" - ) - dot_w_node = pydot.Node( - pname1, - label="{" + label1 + self._get_tensor_label(ptensor) + "}", - **_WEIGHT_TEMPLATE, - ) - dot_graph.add_node(dot_w_node) - dot_graph.add_edge(pydot.Edge(pname1, node.name)) - - if node.op == "call_module": - leaf_module = self._get_leaf_node(graph_module, node) - - if not ignore_parameters_and_buffers and not isinstance(leaf_module, pippy.fx.GraphModule): - get_module_params_or_buffers() - - for node in graph_module.graph.nodes: - if ignore_getattr and node.op == "get_attr": - continue - - for user in node.users: - dot_graph.add_edge(pydot.Edge(node.name, user.name)) - - return dot_graph - -else: - if not TYPE_CHECKING: - @compatibility(is_backward_compatible=False) - class FxGraphDrawer: - def __init__(self, graph_module: pippy.fx.GraphModule, name: str, ignore_getattr: bool = False): - raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' - 'pydot through your favorite Python package manager.') diff --git a/pippy/fx/passes/graph_manipulation.py b/pippy/fx/passes/graph_manipulation.py deleted file mode 100644 index c4c6716e6..000000000 --- a/pippy/fx/passes/graph_manipulation.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Any, Dict, List, NamedTuple, Optional - -import torch -from pippy.fx._compatibility import compatibility -from pippy.fx.graph import Graph -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import ( - map_arg, - Node, - Target, -) -from pippy.fx.passes.shape_prop import ShapeProp - -__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', - 'get_size_of_node'] - -@compatibility(is_backward_compatible=False) -def replace_target_nodes_with( - fx_module: GraphModule, - old_op: str, - old_target: Target, - new_op: str, - new_target: Target, -): - """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, - and updates them to match the new op code and target""" - new_graph = Graph() - val_map: Dict[Node, Node] = {} - for node in fx_module.graph.nodes: - if node.op == old_op and node.target == old_target: - args = map_arg(node.args, lambda n: val_map[n]) - kwargs = map_arg(node.kwargs, lambda n: val_map[n]) - assert isinstance(args, tuple) - assert isinstance(kwargs, dict) - val_map[node] = new_graph.create_node( - new_op, new_target, args, kwargs, node.name - ) - else: - val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) - fx_module.graph = new_graph - - -@compatibility(is_backward_compatible=False) -class size_bytes(NamedTuple): - output_size: int - total_size: int - - -@compatibility(is_backward_compatible=False) -def get_size_of_all_nodes( - fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None -) -> None: - """Given a fx graph module, update each node with its total size (weights + bias + output) - and its output_size(output). For a non-module node, the total size is the output size. - return total size""" - if args is not None: - # Mark shape and dtype for each node (node.shape and node.dtype) - ShapeProp(fx_module).propagate(*args) - # Calculate the total size of the whole fx graph - total_size_of_graph = 0.0 - for node in fx_module.graph.nodes: - if node.op == "output": - break - node.size_bytes = get_size_of_node(fx_module, node) - return - - -@compatibility(is_backward_compatible=False) -def get_tensor_meta(node: Node) -> Any: - tensor_meta = node.meta.get("tensor_meta") - - if not tensor_meta: - raise RuntimeError( - f"Node {node} has no tensor metadata associated with it! " - f"Check that shape propagation has run." - ) - - return tensor_meta - - -@compatibility(is_backward_compatible=False) -def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: - """Given a node with node.dtype and node.shape, return its total size and its output size. - total_size = weights + bias + output_size - """ - # Total num of elements - total_num_of_elems = 0 - # For a module, conside all parameters - if node.op == "call_module": - submodule_dict = dict(fx_module.named_modules()) - submodule = submodule_dict[node.target] - parameters = submodule.named_parameters() - # Parameters are named tuples - for name, p in parameters: - total_num_of_elems += p.numel() - # Don't forget the output size - # node.shape is the shape of this node's output - tensor_meta = get_tensor_meta(node) - output_elem = tensor_meta.shape.numel() - total_num_of_elems += output_elem - # Assume for now if it's quantized then it's qint8 or quint8 - if tensor_meta.is_quantized: - size_per_elem_bytes = torch._empty_affine_quantized( - [], dtype=tensor_meta.dtype - ).element_size() - else: - size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() - total_size = size_per_elem_bytes * total_num_of_elems - output_size = size_per_elem_bytes * output_elem - return size_bytes(output_size, total_size) diff --git a/pippy/fx/passes/infra/__init__.py b/pippy/fx/passes/infra/__init__.py deleted file mode 100644 index c53c3b3f7..000000000 --- a/pippy/fx/passes/infra/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from . import pass_manager diff --git a/pippy/fx/passes/infra/partitioner.py b/pippy/fx/passes/infra/partitioner.py deleted file mode 100644 index b5a39fb2b..000000000 --- a/pippy/fx/passes/infra/partitioner.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, List, Set, Iterable, Optional - -from pippy.fx.passes.utils.fuser_utils import fuse_by_partitions -from pippy.fx.passes.tools_common import NodeList - -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import Node, _get_qualified_name -from pippy.fx.passes.operator_support import OperatorSupportBase - -from collections import defaultdict -import logging -import itertools - -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -class Partition: - def __init__(self, id: int = None, nodes: Iterable[Node] = None): - self.id = id - self.nodes: Set[Node] = set(nodes) if nodes is not None else set() - - def __repr__(self) -> str: - return str(self.nodes) - - def add_node(self, node: Node): - self.nodes.add(node) - - def remove_node(self, node: Node): - self.nodes.remove(node) - - def size(self): - return len(self.nodes) - -class CapabilityBasedPartitioner: - - def __init__(self, - graph_module: GraphModule, - operator_support: OperatorSupportBase, - allows_single_node_partition: bool = False - ) -> None: - self.graph_module = graph_module - self.operator_support = operator_support - self.allows_single_node_partition = allows_single_node_partition - - # map of node to it's upstream dependency nodes - # if A is found in dependency_map[B], then B depends on A (or a is an upstream depedency of b) - self.dependency_map = self.__build_dependency_map() - - def __build_dependency_map(self) -> Dict[Node, Set[Node]]: - dependency_map = defaultdict(set) - - # assumptions: nodes in graph are sorted in topological order - for node in self.graph_module.graph.nodes: - for input_node in node.all_input_nodes: - # add input_node and input_node's upstream dependency - dependency_map[node].add(input_node) - dependency_map[node].update(dependency_map[input_node]) - - return dependency_map - - def __node_depends_on(self, a: Node, b: Node) -> int: - # Returns - # 1 if b depends on a (,or equivalently a is an upstream depedency of b) - # -1 if a depends on b (,or equivalently b is an upstream depedency of a) - # 0 if a and b doesn't have dependency between each other - - if a in self.dependency_map[b]: - return 1 - elif b in self.dependency_map[a]: - return -1 - else: - return 0 - - def __partition_depends_on(self, partition_a: Partition, partition_b: Partition) -> int: - # Returns - # 1 if b depends on a (,or equivalently a is an upstream depedency of b) - # -1 if a depends on b (,or equivalently b is an upstream depedency of a) - # 0 if a and b doesn't have dependency between each other - - # TODO: build a cache here to speedup the query - - for node_a in partition_a.nodes: - for node_b in partition_b.nodes: - dependency = self.__node_depends_on(node_a, node_b) - if dependency != 0: - return dependency - return 0 - - def __get_supported_nodes(self) -> NodeList: - logging.debug("Collecting supported nodes...") - supported_nodes = [] - for node in self.graph_module.graph.nodes: - if self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node): - supported_nodes.append(node) - return supported_nodes - - def propose_partitions(self) -> List[Partition]: - candidates: NodeList = self.__get_supported_nodes() - - # assumptions: nodes in candidate list is sorted in topological order - assignment: Dict[Node, int] = {} # maping from node to partition_id - partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition - new_partition_id = itertools.count() - - def assign(node: Node, id: Optional[int] = None): - # If id is None, remove the node from original assigment - - # node has been assigned before, clean up and re-assign - if node in assignment: - original_id = assignment[node] - del assignment[node] - partitions_by_id[original_id].remove_node(node) - if partitions_by_id[original_id].size() == 0: - del partitions_by_id[original_id] - - if id is not None: - assignment[node] = id - if id not in partitions_by_id: - partitions_by_id[id] = Partition(id=id, nodes=[node]) - else: - partitions_by_id[id].add_node(node) - - logging.debug("Proposing partitions...") - - # visit candidates in reversed topological order - for node in reversed(candidates): - # use Dict as an ordered set to ensure deterministic partitioning result, don't care value - user_partitions: Dict[Partition, None] = {} - for user_node in node.users: - if user_node in assignment: - id = assignment[user_node] - user_partitions[partitions_by_id[id]] = None - else: - user_partitions[Partition(nodes=[user_node])] = None - - # Filter out all the partitions that has dependency on other users - # TODO: find a better way to do this, rather than pair-wise comparision - user_partitions_list = list(user_partitions.keys()) - for i in range(len(user_partitions_list)): - for j in range(i + 1, len(user_partitions_list)): - pi = user_partitions_list[i] - pj = user_partitions_list[j] - dependency = self.__partition_depends_on(pi, pj) - if dependency == 1 and pj in user_partitions: - del user_partitions[pj] - elif dependency == -1 and pi in user_partitions: - del user_partitions[pi] - - # We use the following rules for partition assignment: - # 1. If none of the candidates has been assigned to a partition, create a new partition - # 2. If there is one partition candidate, assign to the partition - # 3. If there are more than one partition candidates, assign current node to the first partition and - # merge the other partitions with first partition, since user_partitions doesn't have depedency between - # each other. - - assigned_candidate_partition_ids = [partition.id for partition in user_partitions if partition.id is not None] - - if len(assigned_candidate_partition_ids) == 0: - # create a new partition - assign(node, next(new_partition_id)) - elif len(assigned_candidate_partition_ids) == 1: - id = assigned_candidate_partition_ids[0] - assign(node, id) - else: - # users are assigned to more than one partition, since user_partitions doesn't have - # dependency on each other, they can be fused into a single partition - id = assigned_candidate_partition_ids[0] - assign(node, id) - - reassignment: Dict[Node, int] = {} - for other_id in assigned_candidate_partition_ids[1:]: - for other_node in partitions_by_id[other_id].nodes: - reassignment[other_node] = id - for other_node in reassignment: - assign(other_node, id) - - # post processing to re-assign "getitem" nodes into upstream partition - logger.debug("Reassigning getitem nodes to its producer node's partition...") - nodes_reassignment: Dict[Node, int] = {} - for node in self.graph_module.graph.nodes: - is_tuple_output = True - for user in node.users: - if user.op != "call_function" or \ - _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] - is_tuple_output = False - break - - # node has tuple outputs, re-assign all following getitem node into node's partition - if is_tuple_output: - id = assignment.get(node, None) # type: ignore[arg-type] - for user in node.users: - if assignment.get(user, None) != id: # type: ignore[arg-type] - nodes_reassignment[user] = id - for node, id in nodes_reassignment.items(): - assign(node, id) - - # filter out single node partitions - if not self.allows_single_node_partition: - logger.debug("Filtering out single node partitions...") - non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} - partitions_to_remove: List[int] = [] - for id, partition in partitions_by_id.items(): - compute_node_count = 0 - for node in partition.nodes: - if node.op == "call_function" and \ - _get_qualified_name(node.target) not in non_compute_ops: # type: ignore[arg-type] - compute_node_count += 1 - if compute_node_count <= 1: - partitions_to_remove.append(id) - for id in partitions_to_remove: - del partitions_by_id[id] - - logging.debug("Partitions proposed:") - for id, partition in partitions_by_id.items(): - logging.debug(f"partition #{id}", [node.name for node in partition.nodes]) - - return list(partitions_by_id.values()) - - def fuse_partitions(self, partitions: List[Partition]) -> GraphModule: - logging.debug("Fusing partitions...") - # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] - return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions]) - - def partition_and_fuse(self) -> GraphModule: - partitions = self.propose_partitions() - fused_gm = self.fuse_partitions(partitions) - return fused_gm diff --git a/pippy/fx/passes/infra/pass_base.py b/pippy/fx/passes/infra/pass_base.py deleted file mode 100644 index 711a1a1ca..000000000 --- a/pippy/fx/passes/infra/pass_base.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import abc -from collections import namedtuple -from typing import Optional - -from pippy.fx.graph_module import GraphModule -from pippy.fx._compatibility import compatibility - - -__all__ = ['PassResult', 'PassBase'] - -@compatibility(is_backward_compatible=False) -class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): - """ - Result of a pass: - graph_module: The modified graph module - modified: A flag for if the pass has modified the graph module - """ - def __new__(cls, graph_module, modified): - return super().__new__(cls, graph_module, modified) - -@compatibility(is_backward_compatible=False) -class PassBase(abc.ABC): - """ - Base interface for implementing passes. - - It is required to implement the `call` function so that we can directly - pass instances of the Pass directly to the PassManager and call them as a - function. - - We can directly pass an instance of a class implementing this interface into - the PassManager's `passes` attribute. - """ - - def __init__(self) -> None: - pass - - def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: - """ - Runs the precondition check, the pass itself, and the postcondition check. - """ - - self.requires(graph_module) - res = self.call(graph_module) - self.ensures(graph_module) - return res - - @abc.abstractmethod - def call(self, graph_module: GraphModule) -> Optional[PassResult]: - """ - The pass that is run through the given graph module. To implement a - pass, it is required to implement this function. - - Args: - graph_module: The graph module we will run a pass on - """ - pass - - def requires(self, graph_module: GraphModule) -> None: - """ - This function will be called before the pass is run and will check that - the given graph module contains the preconditions needed to run the - pass. It is not required to implement this function. - - Args: - graph_module: The graph module we will run checks on - """ - pass - - def ensures(self, graph_module: GraphModule) -> None: - """ - This function will be called after the pass is run and will check that - the given graph module contains the postconditions needed to run the - pass. It is not required to implement this function. - - Args: - graph_module: The graph module we will run checks on - """ - pass diff --git a/pippy/fx/passes/infra/pass_manager.py b/pippy/fx/passes/infra/pass_manager.py deleted file mode 100644 index 27fc6618a..000000000 --- a/pippy/fx/passes/infra/pass_manager.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -from queue import Queue -from functools import wraps -from typing import Callable, Dict, List - -import torch.nn as nn -from pippy.fx.graph_module import GraphModule -from pippy.fx._compatibility import compatibility -from pippy.fx.passes.infra.pass_base import PassResult - -__all__ = ['inplace_wrapper', 'pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] - -@compatibility(is_backward_compatible=False) -def inplace_wrapper(fn: Callable) -> Callable: - """ - Convenience wrapper for passes which modify an object inplace. This - wrapper makes them return a PassResult containing the modified object and - True for the "modified" flag. - - Args: - fn (Callable[Module, Any]) - - Returns: - wrapped_fn (Callable[Module, PassResult]) - """ - if fn is None: - return None - - @wraps(fn) - def wrapped_fn(gm): - return fn(gm) or PassResult(gm, True) - - if wrapped_fn.__name__ == 'wrapped_fn': - wrapped_fn.__name__ = str(fn) - return wrapped_fn - -@compatibility(is_backward_compatible=False) -def pass_result_wrapper(fn: Callable) -> Callable: - """ - Wrapper for passes which currently do not return a PassResult. - This wrapper makes them return a PassResult containing the modified object - and True for the "modified" flag. - - Args: - fn (Callable[Module, Any]) - - Returns: - wrapped_fn (Callable[Module, PassResult]) - """ - if fn is None: - return None - - @wraps(fn) - def wrapped_fn(gm): - gm = fn(gm) - return PassResult(gm, True) - - return wrapped_fn - -def _validate_pass_schedule_constraint( - constraint: Callable[[Callable, Callable], bool], passes: List[Callable] -) -> None: - for i, a in enumerate(passes): - for j, b in enumerate(passes[i + 1 :]): - if constraint(a, b): - continue - raise RuntimeError( - f"pass schedule constraint violated. Expected {a} before {b}" - f" but found {a} at index {i} and {b} at index{j} in pass" - f" list." - ) - -def _topological_sort_passes( - passes: List[Callable], constraints: List[Callable] -) -> List[Callable]: - """ - Args - passes: Passes that we are ordering - constraints: Constraints applied on these passes - - Returns - A sorted list of callables and a boolean of if a circular dependency - existed - """ - if len(constraints) == 0: - return passes - - # Contruct a graph mapping nodes to a list of their users - graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} - indegree_map: Dict[Callable, int] = {p : 0 for p in passes} - candidates: Queue = Queue() - for a in passes: - for b in passes: - if a == b: - continue - - for constraint in constraints: - if not constraint(a, b): - graph[b].append(a) - indegree_map[a] += 1 - - if indegree_map[a] == 0: - candidates.put(a) - - visited: Dict[Callable, bool] = {p : False for p in passes} - sorted_passes: List[Callable] = [] - - while not candidates.empty(): - p = candidates.get() - sorted_passes.append(p) - visited[p] = True - - for n in graph[p]: - if not visited[n]: - indegree_map[n] -= 1 - if indegree_map[n] == 0: - candidates.put(n) - - # Check if there are unvisited nodes (aka cycles in the graph) - cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) - if len(cycle_passes) != 0: - error = f"Circular dependency detected within the following passes: {cycle_passes}" - raise RuntimeError(error) - - return sorted_passes - -@compatibility(is_backward_compatible=False) -def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: - """ - Defines a partial order ('depends on' function) where `this` must occur - before `that`. - - For example, the following pass list and constraint list would be invalid. - ``` - passes = [pass_b, pass_a] - - constraints = [ - this_before_that_pass_constraint(pass_a, pass_b) - ] - ``` - - Args: - this (Callable): pass which should occur first - that (Callable): pass which should occur later - - Returns: - depends_on (Callable[[Object, Object], bool] - """ - - def depends_on(a: Callable, b: Callable): - if a == that and b == this: - return False - return True - - return depends_on - - -@compatibility(is_backward_compatible=False) -class PassManager: - """ - Construct a PassManager. - - Collects passes and constraints. This defines the pass schedule, manages - pass constraints and pass execution. - - Args: - passes (Optional[List[Callable]]): List of passes. A pass is a - callable which modifies an object and returns a PassResult - constraint (Optional[List[Callable]]): List of constraints. A - constraint is a callable which takes two passes (A, B) and returns - True if A depends on B and False otherwise. See implementation of - `this_before_that_pass_constraint` for example. - steps (int): Max number of times we run the passes (default = 1). - run_checks_after_each_pass (bool): Whether to run checks and linting - after each pass - suppress_check_failures (bool): Whether to raise errors when running - checks - """ - - passes: List[Callable[[nn.Module], PassResult]] = [] - constraints: List[Callable[[Callable, Callable], bool]] = [] - _validated: bool = False - steps: int = 1 - - def __init__( - self, - passes=None, - constraints=None, - steps=None, - run_checks_after_each_pass: bool = False, - suppress_check_failures: bool = False, - debug: bool = False, - ): - if passes: - self.passes = passes - if constraints: - self.constraints = constraints - if steps: - self.steps = steps - - self.run_checks_after_each_pass = run_checks_after_each_pass - self.suppress_check_failures = suppress_check_failures - self.debug = debug - - def add_pass(self, _pass: Callable): - """ - Adds a pass into the current list of passes. - """ - self.passes.append(_pass) - self._validated = False - - def add_constraint(self, constraint: Callable): - """ - Adds a constraint into the current list of constraints. - """ - self.constraints.append(constraint) - self._validated = False - - def validate_constraints(self): - """ - Validates that current pass schedule defined by `self.passes` is valid - according to all constraints in `self.constraints` - """ - if self._validated: - return - for constraint in self.constraints: - _validate_pass_schedule_constraint(constraint, self.passes) - self._validated = True - - def solve_constraints(self): - """ - Finds a valid traversal order based on the given constraints and orders - the passes based on this order. - - If a circular dependency exists between the constraints and steps = 1, - then we will raise an error because if steps != 1 this means that we - will re-run the passes, allowing for circular dependencies. - """ - self.passes = _topological_sort_passes(self.passes, self.constraints) - self._validated = True - - def add_checks(self, check: Callable) -> None: - """ - Adds a function which takes runs various checks on a given graph module. - This function is run before and after each pass if the - `run_checks_after_each_pass` flag is enabled. - """ - sig = inspect.signature(check) - - if len(list(sig.parameters.values())) != 1: - raise TypeError("PassManager check function should only take in one variable, a module") - - setattr(self, "check", check) # noqa: B010 - - def check(self, module: nn.Module) -> None: - pass - - def __call__(self, module: nn.Module) -> PassResult: - """ - Runs a list of passes in the order based on `self.passes` on the given - graph module. Each time a pass is run, checks and linting will be run on - the graph module if `run_checks_after_each_pass` is set. - - If the module is a graph module, we will run the list of passes until - the graph stops changing, or until `steps` number of times. - """ - # Order the passes based on the constraints - if not self._validated: - self.solve_constraints() - - # Check graph invariants - self.check(module) - - # Run the set of passes `steps` number of times or until the graph stops - # changing - overall_modified = False - for _ in range(self.steps): - modified = False - - # Run the set of passes on the graph module - for i, fn in enumerate(self.passes): - if self.debug: - print(f"Running pass \'{fn.__name__}\'") - - try: - res = fn(module) - except Exception as e: - prev_pass_names = [p.__name__ for p in self.passes[:i]] - msg = f"An error occurred when running the \'{fn.__name__}\' pass after the following passes: {prev_pass_names}" - raise type(e)(msg) from e - - module = res.graph_module - modified = modified or res.modified - - if isinstance(module, GraphModule): - module.recompile() - - # Check graph invariants - if self.run_checks_after_each_pass: - self.check(module) - - # If the graph no longer changes, then we can stop running these passes - overall_modified = overall_modified or modified - if not modified: - break - - return PassResult(module, overall_modified) diff --git a/pippy/fx/passes/net_min_base.py b/pippy/fx/passes/net_min_base.py deleted file mode 100644 index da18d980e..000000000 --- a/pippy/fx/passes/net_min_base.py +++ /dev/null @@ -1,619 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import logging -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -import pippy.fx -from pippy.fx._compatibility import compatibility -from pippy.fx.node import map_arg - -from .shape_prop import ShapeProp -from .split_utils import split_by_tags -from .tools_common import ( - CALLABLE_NODE_OPS, - FxNetAccFusionsFinder, - Names, - NodeList, - NodeSet, - TensorOrTensors, - Tensors, -) - -__all__ = [ - "FxNetMinimizerBadModuleError", - "FxNetMinimizerRunFuncError", - "FxNetMinimizerResultMismatchError", -] - -_LOGGER = logging.getLogger(__name__) - - -@compatibility(is_backward_compatible=False) -class FxNetMinimizerBadModuleError(Exception): - """ - Raised if failed to split out a minimize module - """ - - pass - - -@compatibility(is_backward_compatible=False) -class FxNetMinimizerRunFuncError(Exception): - """ - Raised if error occurs during run_a or run_b functions - """ - - pass - - -@compatibility(is_backward_compatible=False) -class FxNetMinimizerResultMismatchError(Exception): - """ - Raised if comparing function thinks the results are mismatching. - """ - - pass - - -@dataclass -class _MinimizerSettingBase: - """ - Args: - `accumulate_error`: Instead of using a's input for both converted module to verify - , use the previous outputs of each converted module as input to accumulate the - errors. - - `traverse_method`: "sequential" or "binary" or "accumulate" - Determine the way of traverse the nodes in FX module. - - `find_all`: Minimizer will go through the entire model and return all problematic nodes. - - `return_intermediate`: If true, when using `run_nodes()` function to run the - model, intermediate results of all the ops will be returned as output. - """ - - accumulate_error: bool = False - traverse_method: str = "sequential" - find_all: bool = False - return_intermediate: bool = False - - def __str__(self): - settings_str = "FX Minimizer Settings:\n" - - for k, v in vars(self).items(): - settings_str += f"\t{k}: {v}\n" - - return settings_str - - -class _MinimizerBase: - """ - This class is used to automatically find problematic nodes in a model. It takes a FX - graphmodule and generate some submodules while traverse the graph. Then two functions - `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn` - will be used to compare the results. - - Currently we provides two ways to traverse the graph and generate submodules. - 1. Sequential traversal: this will traverse the graph node by node and generate - one submodule with one sigle node. - 2. Binary searching: this will do a binary search style traversal on the graph. - - For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. - """ - - def __init__( - self, - module: pippy.fx.GraphModule, - sample_input: Tensors, - compare_fn: Callable[ - [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool] - ], - settings: _MinimizerSettingBase, - ): - assert isinstance(module, pippy.fx.GraphModule) - - self.module = module - self.sample_input = sample_input - self.compare_fn = compare_fn - self.settings = settings - - # Stores outputs of run_a function - self.a_outputs: Dict[str, Any] = {} - - # Stores outputs of run_b function - self.b_outputs: Dict[str, Any] = {} - - # Stores the results of compare_fn - self.results: Dict[Any, Any] = {} - - # Stores the report for the runs - self.reports: List[List[str]] = [] - - # Current iteration - self.iteration: int = 0 - - callable_nodes = { - node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS - } - ShapeProp(self.module).propagate(*self.sample_input) - self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)() - - # Check if number of input in sample_input matches the number of placeholders - placeholders = [ - node.name for node in self.module.graph.nodes if node.op == "placeholder" - ] - assert len(placeholders) == len(self.sample_input) - - # Store sample_input - for i, name in enumerate(placeholders): - self.a_outputs[name] = sample_input[i] - self.b_outputs[name] = sample_input[i] - - def run_a(self, mod: pippy.fx.GraphModule, inputs: Tensors) -> TensorOrTensors: - """ - Run `mod` with `inputs` and generate output. The output will be compared with - output of run_b(). - """ - raise RuntimeError("run_a() is not implemented.") - - def run_b(self, mod: pippy.fx.GraphModule, inputs: Tensors) -> TensorOrTensors: - """ - Run `mod` with `inputs` and generate output. The output will be compared with - output of run_a(). - """ - raise RuntimeError("run_b() is not implemented.") - - def _store_outputs( - self, - a_result: TensorOrTensors, - b_result: TensorOrTensors, - submodule: pippy.fx.GraphModule, - ): - """ - Store the outputs of self.run_a() and self.run_b() into self.a_outputs and - self.b_outputs, so that we can use them when execute preceding nodes that - use those outputs as inputs. - - Args: - a_result: Output of self.run_a(). Could be a tensor or tensors. - b_result: Output of self.run_b(). Could be a tensor or tensors. - submodule: The module that generates a_result and b_result. - """ - output_node = next( - node for node in submodule.graph.nodes if node.op == "output" - ) - - # Only one output - if isinstance(output_node.args[0], pippy.fx.Node): - self.a_outputs[output_node.args[0].name] = a_result - self.b_outputs[output_node.args[0].name] = b_result - # Multiple outputs - else: - for i, arg in enumerate(output_node.args[0]): - self.a_outputs[arg.name] = a_result[i] - self.b_outputs[arg.name] = b_result[i] - - def _get_submod_inputs( - self, main_module: pippy.fx.GraphModule, submod_path: str - ) -> Tuple[Tensors, Tensors]: - """ - Try get submodule inputs from stored outputs. If not found then use - torch_glow.get_submod_inputs to get the inputs. - - If accumulate_error is False, use a_input for run_a() and run_b() - otherwise use a_input for run_a and b_input for run_b. - - Args: - main_module: Top-levlel fx module. - submod_path: Path to the submodule we want to run and compare results. - - Returns: - a_input: List of tensor(s) that will be used by run_a() as submodule inputs. - b_input: List of tensor(s) that will be used by run_b() as submodule inputs. - """ - a_input = [] - b_input = [] - submodule = getattr(main_module, submod_path) - placeholders = [ - node.name for node in submodule.graph.nodes if node.op == "placeholder" - ] - - # If all placeholder can be found in stored outputs, use stored - # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs` - # to get the inputs. - if set(placeholders) <= self.a_outputs.keys(): - for name in placeholders: - a_input.append(self.a_outputs[name]) - b_input.append(self.b_outputs[name]) - else: - if self.settings.accumulate_error: - print(f"Can't find previous stored outputs named {placeholders}!") - - def get_inputs(self: torch.nn.Module, inputs: Any): - nonlocal a_input - a_input = inputs - - # Use forward hook to get the inputs to the submodule - handle = submodule.register_forward_pre_hook(get_inputs) - main_module(*self.sample_input) - handle.remove() - - b_input = a_input - - if not self.settings.accumulate_error: - return a_input, a_input - - return a_input, b_input - - def _tag_nodes(self, selected_nodes: NodeSet): - """ - Tag selected nodes with tag "minimize". Nodes with the same tags will - be split to the same submodule afterwards. - - Args: - selected_nodes: Nodes that we want to minimize. We will tag those nodes - with "minimize", all preceding nodes with "main_0" and all following - nodes with "main_1". - """ - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - if node in selected_nodes: - node.tag = "minimize" - elif any( - n.tag in {"minimize", "main_1"} - for n in node.all_input_nodes - if n.op in CALLABLE_NODE_OPS - ): - node.tag = "main_1" - else: - node.tag = "main_0" - - def _build_submodule(self, nodes: NodeSet) -> Tuple[pippy.fx.GraphModule, str]: - """ - Split self.module so that one submodule consists of `nodes` and only `nodes`. - - Args: - nodes: Nodes that we want to include in the minimize submodule. - - Returns: - split_module (pippy.fx.GraphModule): the module after split. - submodule_name (str): the name of the submodule that consists of `nodes`. - """ - # Color provided nodes - self._tag_nodes(nodes) - - # Split module based on coloring - split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"]) - - # Find submodule containing colored nodes - submodule_name: str = "" - for child_name, _ in split_module.named_children(): - # Skip submodules we're not interested in at the moment - if "minimize" not in child_name: - continue - - if submodule_name == "": - submodule_name = child_name - else: - raise FxNetMinimizerBadModuleError( - f"Expected only one minimize submodule with nodes {nodes}" - ) - - if submodule_name == "": - raise FxNetMinimizerBadModuleError( - f"Minimize submodule was not found with nodes {nodes}" - ) - - return split_module, submodule_name - - def _run_and_compare( - self, split_module: pippy.fx.GraphModule, submod_name: str, output_names: Names - ): - """ - Run the submodule in `split_module` that has name `submod_name` - using `self.run_a` and `self.run_b` and compare their results. - - Args: - split_module: Main module that contains the minimize submodule. - submod_name: Name of the minimize submodule. - output_names: Names of the node we want to output. If None, we - will use the original output. - """ - submodule = getattr(split_module, submod_name) - a_input, b_input = self._get_submod_inputs(split_module, submod_name) - - if len(self.reports) == 0: - self.reports.append([]) - self.iteration = 1 - - report = self.reports[self.iteration - 1] - report.append("Run and compare ...") - - if output_names: - output_nodes: NodeList = [] - for node in submodule.graph.nodes: - if node.op == "output": - submodule.graph.erase_node(node) - - if node.name in output_names: - output_nodes.append(node) - - submodule.graph.output( - output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes) - ) - submodule.graph.lint() - submodule.recompile() - - # Use name of args in output node as key to store comparison result - for node in submodule.graph.nodes: - if node.op == "output": - result_key = map_arg(node.args, lambda x: x.name) - - a_result = self.run_a(submodule, a_input) - b_result = self.run_b(submodule, b_input) - self._store_outputs(a_result, b_result, submodule) - - # Compare results - names: Names = output_names - if output_names is None: - names = [str(v) for v in result_key] - - numeric_result, bool_result = self.compare_fn(a_result, b_result, names) - - self.results[result_key] = numeric_result - report.append(f"Numerical accuracy = {numeric_result}") - if not bool_result: - report.append(f"Result mismatch for {result_key}") - raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") - - def _binary_search_impl( - self, all_nodes: NodeList, start_idx: int, end_idx: int - ) -> NodeSet: - """ - Recursive binary search implementation. - """ - nodes: NodeList = all_nodes[start_idx:end_idx] - - report: List[str] = [] - self.reports.append(report) - self.iteration += 1 - report.append(f"Binary search iteration {self.iteration}.") - report.append( - f"From node index {start_idx} to {end_idx-1}. " - f"Size of the interested node list is {len(nodes)}" - ) - - cur_nodes: NodeSet = set(nodes) - - for node in nodes: - if node in self.fusions: - cur_nodes.update(self.fusions[node]) - - try: - split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, []) - except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): - - if len(nodes) == 1: - report.append( - f"This is the last node in the sub-module. " - f"Search in the current branch is successful with culprit = {cur_nodes}." - ) - self.print_report(report) - return cur_nodes - - report.append( - "Proceed to split and lower the halves of the current " - "sub-module individually." - ) - self.print_report(report) - - mid = len(nodes) // 2 - culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid) - - if len(culprits) != 0 and not self.settings.find_all: - return culprits - - culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx) - - if len(culprits) == 0: - report.append( - f"Further split and lowering found no errors. " - f"Unable to minimize the submodule with list of nodes: {nodes}" - ) - self.print_report(report) - - return culprits - else: - report.append("No discrepancy found.") - self.print_report(report) - return set() - - def _binary_traverse(self, nodes: NodeList) -> NodeSet: - """ - Binary search on `nodes` for culprit. - """ - return self._binary_search_impl(nodes, 0, len(nodes)) - - def _sequential_traverse(self, nodes: NodeList) -> NodeSet: - """ - Traverse `nodes` one by one and determine if any of them is a culprit. - """ - culprits: NodeSet = set() - - for node in nodes: - report: List[str] = [] - self.reports.append(report) - self.iteration += 1 - report.append(f"Sequential traverse iteration {self.iteration}.") - report.append(f"Visit node: {node.name}") - - _LOGGER.info(f"Visit node: {node.name}") - cur_nodes: NodeSet = {node} - - if node in self.fusions: - cur_nodes = self.fusions[node] - - try: - split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, [node.name]) - self.print_report(report) - except (FxNetMinimizerResultMismatchError): - culprits.add(node) - report.append(f"Found culprit from numeric error: {node}") - self.print_report(report) - if not self.settings.find_all: - return culprits - except (FxNetMinimizerRunFuncError): - culprits.update(cur_nodes) - report.append(f"Found culprit from run error: {node}") - self.print_report(report) - if not self.settings.find_all: - return culprits - - return culprits - - def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: - culprits: NodeSet = set() - nodes_to_run: NodeSet = set() - - # find_all is not supported for accumulate traversal because all the - # ops run on NNPI. So we return after the first op that raises error. - if self.settings.find_all: - print("'Find All' mode is not supported in accumulate traversal.") - return culprits - - for node in nodes: - report: List[str] = [] - self.reports.append(report) - self.iteration += 1 - report.append(f"Accumulate traverse iteration {self.iteration}.") - - nodes_to_run.add(node) - - node_name = node.name - if node_name is not None and isinstance(node_name, tuple): - node_name = node_name[0] - assert node_name is not None and isinstance( - node_name, str - ), f"minimize: node_name: {node_name}" - - report.append(f"Add node: {node_name}") - - try: - split_module, submod_name = self._build_submodule(nodes_to_run) - self._run_and_compare(split_module, submod_name, [node_name]) - self.print_report(report) - except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): - culprits.add(node) - report.append(f"Found culprit {node}") - self.print_report(report) - return culprits - - return culprits - - def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: - """ - Collect nodes in the model that between nodes with name of `start` and `end`. - These two nodes are also included. - """ - nodes: NodeList = [] - add_node = start is None - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - if node.name == start: - add_node = True - - if add_node: - nodes.append(node) - - if node.name == end: - break - - return nodes - - def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None): - """ - Run part of the model from `start` node to `end` node. If `start` is None - then we start from the beginning of the model. If `end` is None then we - stop at the end of the model. - - Args: - start: The name of the node which is the first node of the submodule - we want to run. If set to None, then we'll start with the first - node of the model. - end: The name of the node which is the last node of the submodule we - want to run. If set to None, we'll end with the last node of the - model. - """ - nodes = self._collect_nodes(start, end) - cur_nodes = set(nodes) - - for node in nodes: - if node in self.fusions: - cur_nodes.update(self.fusions[node]) - - output_names = [] - if self.settings.return_intermediate: - output_names = [node.name for node in nodes] - - try: - split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, output_names) - except ( - FxNetMinimizerRunFuncError, - FxNetMinimizerResultMismatchError, - ) as e: - print(e) - - def print_report(self, report: List[str]): - for i in range(len(report)): - if i > 0: - print(" . " + report[i]) - else: - print(report[i]) - - def print_reports(self): - for report in self.reports: - self.print_report(report) - - def minimize( - self, start: Optional[str] = None, end: Optional[str] = None - ) -> NodeSet: - """ - Minimizing the model from node with name `start` to node with name `end` base - on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or - FxNetMinimizerResultMismatchError errors. - - Args: - start: The name of the node where we want to start minimizing. If set - to None, then we'll start with the first node of the model. - end: The name of the node where we want to terminate minimizing. If - set to None, we'll end with the last node of the model. - - Returns: - nodes: A list of nodes that causes FxNetMinimizerRunFuncError or - FxNetMinimizerResultMismatchError errors during minimizing. - """ - - print(self.settings) - print(self.module.graph) - - nodes = self._collect_nodes(start, end) - - if self.settings.traverse_method == "sequential": - return self._sequential_traverse(nodes) - - if self.settings.traverse_method == "binary": - return self._binary_traverse(nodes) - - if self.settings.traverse_method == "accumulate": - return self._accumulate_traverse(nodes) - - raise RuntimeError(f"Unknow traverse method {self.settings.traverse_method}!") diff --git a/pippy/fx/passes/operator_support.py b/pippy/fx/passes/operator_support.py deleted file mode 100644 index 62aa708a7..000000000 --- a/pippy/fx/passes/operator_support.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import abc -import typing as t - -import torch -import pippy.fx -from pippy.fx._compatibility import compatibility -from .shape_prop import TensorMetadata -from .tools_common import get_node_target, CALLABLE_NODE_OPS - - -__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports'] - -# fx.Node.target typename, as returned by `get_node_target()` -TargetTypeName = str - -# Arguments' dtypes for a given node, see `OperatorSupport` -SupportedArgumentDTypes = t.Optional[ - t.Tuple[ - t.Sequence[t.Sequence[torch.dtype]], - t.Dict[str, t.Sequence[torch.dtype]], - ] -] - -SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] - - -@compatibility(is_backward_compatible=False) -class OperatorSupportBase(abc.ABC): - """Interface for determining if a fx.Node is supported by a backend""" - @abc.abstractmethod - def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: pippy.fx.Node - ) -> bool: - raise NotImplementedError() - - -@compatibility(is_backward_compatible=False) -class OperatorSupport(OperatorSupportBase): - """ - `_support_dict` maps node.target typename to supported inputs dtypes. - - node.target typename is retrieved using helper function `get_node_target()` - - If supported inputs dtypes is None, it means any dtype is supported, else - we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}). - - The first tuple ([dtypes], ...) indicates what dtypes are supported for - inputs in node.args and the second dict {"name": [dtypes], ...} indicates - what dtypes are supported for inputs in node.kwargs. - - For inputs in args, if we don't want to check it, we can put None there, - e.g. (None, [torch.float]) indicates that we don't care about the type of - the first input in args. And for inputs in kwargs, if not listed, will not - be checked. - """ - - _support_dict: SupportDict - - def __init__( - self, - support_dict: t.Optional[SupportDict] = None - ): - self._support_dict = support_dict or {} - - def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: pippy.fx.Node - ) -> bool: - """ - Args: - `sumodules`: mapping from module name to the module. This can be - retrieved by calling model.named_modules(). - - `node`: a Fx node that we want to determine whether it's supported. - - Returns: - `is_supported`: whether the arg `node` is supported. - """ - if node.op not in CALLABLE_NODE_OPS: - return True - - target = get_node_target(submodules, node) - - # Target not found in _support_dict meaning that we don't support this op at all - if target not in self._support_dict: - return False - - # The rule for target is None meaning that we accept any dtype - if self._support_dict[target] is None: - return True - - args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc] - - # Check args dtypes - for i, dtypes in enumerate(args_dtypes): - if len(node.args) <= i: - break - - # None indicates we don't care about the dtype of args[i] - if dtypes is None: - continue - - # If arg is not a node then we don't check it - if not isinstance(node.args[i], pippy.fx.Node): - continue - - arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type] - if arg_dtype not in dtypes: - return False - - # Check kwargs dtypes - for k, dtypes in kwargs_dtypes.items(): - if k not in node.kwargs: - continue - - # If arg is not a node then we don't check it - if not isinstance(node.kwargs[k], pippy.fx.Node): - continue - - kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type] - if kwarg_dtype not in dtypes: - return False - - return True - - -# ====================================================================== -# Functional interfaces and utils for defining basic operator support logic -# and composing them into more complex ones -# ====================================================================== - -IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], pippy.fx.Node], bool] - - -@compatibility(is_backward_compatible=False) -def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase: - """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance - - `IsNodeSupported` has the same call signature as - `OperatorSupportBase.is_node_supported` - """ - class FunctionalOperatorSupport(OperatorSupportBase): - def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: pippy.fx.Node - ) -> bool: - return is_node_supported(submodules, node) - return FunctionalOperatorSupport() - - -@compatibility(is_backward_compatible=False) -def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: - """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` - instance by evaluating each input `OperatorSupportBase` instance, and returns False if - any of it reports False. - """ - def _chain(submods, node) -> bool: - return all( - x.is_node_supported(submods, node) - for x in op_support - ) - return create_op_support(_chain) - - -@compatibility(is_backward_compatible=False) -class OpSupports: - """A set of atomic `OperatorSupportBase` instances that can be combined together - to form more complex operator support logic. - """ - @classmethod - def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: - """Report a node as non-supported, if any of its arguments is of dtype""" - - def _decline_if_input_dtype( - submodules: t.Mapping[str, torch.nn.Module], - node: pippy.fx.Node, - ) -> bool: - for arg in node.all_input_nodes: - # escape dtype check for get_attr node - if arg.op == "get_attr": - continue - arg_dtype = _get_arg_dtype(arg) - if arg_dtype == dtype: - return False - return True - return create_op_support(_decline_if_input_dtype) - - @classmethod - def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase: - """ - If a node has a name that is in the disallow set, reported it as non-supported. - """ - def _decline_if_node_in_names( - submodules: t.Mapping[str, torch.nn.Module], - node: pippy.fx.Node, - ) -> bool: - if node.name in disallow_set: - return False - else: - return True - return create_op_support(_decline_if_node_in_names) - - -def _get_arg_dtype(arg: pippy.fx.Node) -> t.Any: - assert isinstance(arg, pippy.fx.Node) - tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] - dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"] - return dtype diff --git a/pippy/fx/passes/param_fetch.py b/pippy/fx/passes/param_fetch.py deleted file mode 100644 index 411134f7b..000000000 --- a/pippy/fx/passes/param_fetch.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.graph_module import GraphModule -from typing import Any, Callable, Dict, List, Tuple, Type -import torch -import torch.nn as nn - -from pippy.fx._compatibility import compatibility - -__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] - -# Matching method matches the attribute name of current version to the attribute name of `target_version` -@compatibility(is_backward_compatible=False) -def default_matching(name: str, target_version: int) -> str: - """Default matching method - """ - return name - -# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. -# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. -# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. -module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { - torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), - torch.nn.modules.conv.Conv2d: ( - 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching - ), - torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), - torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), - torch.nn.modules.pooling.MaxPool2d: ( - 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching - ), - torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), -} - -@compatibility(is_backward_compatible=False) -def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: - """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` - after checking module's version is compatible with the `module_fetch_book`. - """ - attrs_for_lowering: Dict[str, Any] = {} - attrs_for_lowering["name"] = torch.typename(mod) - - if type(mod) in module_fetch_book: - version, param_to_fetch, matching_method = module_fetch_book[type(mod)] - if version < mod._version: - raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " - "please upgrade the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") - for attr in param_to_fetch: - attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) - else: - raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " - "please add it to the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") - return attrs_for_lowering - -@compatibility(is_backward_compatible=False) -def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: - """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. - """ - submodules = dict(fx_module.named_modules()) - - for node in fx_module.graph.nodes: - if node.op == "call_module": - if isinstance(submodules[node.target], GraphModule): - lift_lowering_attrs_to_nodes(submodules[node.target]) - else: - node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) diff --git a/pippy/fx/passes/pass_manager.py b/pippy/fx/passes/pass_manager.py deleted file mode 100644 index 3bdde31b5..000000000 --- a/pippy/fx/passes/pass_manager.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from functools import wraps -from inspect import unwrap -from typing import Callable, List -import logging - -logger = logging.getLogger(__name__) - - -# for callables which modify object inplace and return something other than -# the object on which they act -def inplace_wrapper(fn: Callable) -> Callable: - """ - Convenience wrapper for passes which modify an object inplace. This - wrapper makes them return the modified object instead. - - Args: - fn (Callable[Object, Any]) - - Returns: - wrapped_fn (Callable[Object, Object]) - """ - - @wraps(fn) - def wrapped_fn(gm): - val = fn(gm) - return gm - - return wrapped_fn - -def log_hook(fn: Callable, level=logging.INFO) -> Callable: - """ - Logs callable output. - - This is useful for logging output of passes. Note inplace_wrapper replaces - the pass output with the modified object. If we want to log the original - output, apply this wrapper before inplace_wrapper. - - - ``` - def my_pass(d: Dict) -> bool: - changed = False - if 'foo' in d: - d['foo'] = 'bar' - changed = True - return changed - - pm = PassManager( - passes=[ - inplace_wrapper(log_hook(my_pass)) - ] - ) - ``` - - Args: - fn (Callable[Type1, Type2]) - level: logging level (e.g. logging.INFO) - - Returns: - wrapped_fn (Callable[Type1, Type2]) - """ - @wraps(fn) - def wrapped_fn(gm): - val = fn(gm) - logger.log(level, f"Ran pass {fn}\t Return value: {val}",) - return val - - return wrapped_fn - - - -def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None): - """ - Convenience wrapper for passes which need to be applied multiple times. - - Exactly one of `n_iter`or `predicate` must be specified. - - Args: - base_pass (Callable[Object, Object]): pass to be applied in loop - n_iter (int, optional): number of times to loop pass - predicate (Callable[Object, bool], optional): - - """ - assert (n_iter is not None) ^ ( - predicate is not None - ), "Exactly one of `n_iter`or `predicate` must be specified." - - @wraps(base_pass) - def new_pass(source): - output = source - if n_iter is not None and n_iter > 0: - for _ in range(n_iter): - output = base_pass(output) - elif predicate is not None: - while predicate(output): - output = base_pass(output) - else: - raise RuntimeError( - f"loop_pass must be given positive int n_iter (given " - f"{n_iter}) xor predicate (given {predicate})" - ) - return output - - return new_pass - - -# Pass Schedule Constraints: -# -# Implemented as 'depends on' operators. A constraint is satisfied iff a list -# has a valid partial ordering according to this comparison operator. -def _validate_pass_schedule_constraint( - constraint: Callable[[Callable, Callable], bool], passes: List[Callable] -): - for i, a in enumerate(passes): - for j, b in enumerate(passes[i + 1 :]): - if constraint(a, b): - continue - raise RuntimeError( - f"pass schedule constraint violated. Expected {a} before {b}" - f" but found {a} at index {i} and {b} at index{j} in pass" - f" list." - ) - - -def this_before_that_pass_constraint(this: Callable, that: Callable): - """ - Defines a partial order ('depends on' function) where `this` must occur - before `that`. - """ - - def depends_on(a: Callable, b: Callable): - if a == that and b == this: - return False - return True - - return depends_on - - -def these_before_those_pass_constraint(these: Callable, those: Callable): - """ - Defines a partial order ('depends on' function) where `these` must occur - before `those`. Where the inputs are 'unwrapped' before comparison. - - For example, the following pass list and constraint list would be invalid. - ``` - passes = [ - loop_pass(pass_b, 3), - loop_pass(pass_a, 5), - ] - - constraints = [ - these_before_those_pass_constraint(pass_a, pass_b) - ] - ``` - - Args: - these (Callable): pass which should occur first - those (Callable): pass which should occur later - - Returns: - depends_on (Callable[[Object, Object], bool] - """ - - def depends_on(a: Callable, b: Callable): - if unwrap(a) == those and unwrap(b) == these: - return False - return True - - return depends_on - - -class PassManager: - """ - Construct a PassManager. - - Collects passes and constraints. This defines the pass schedule, manages - pass constraints and pass execution. - - Args: - passes (Optional[List[Callable]]): list of passes. A pass is a - callable which modifies an object and returns modified object - constraint (Optional[List[Callable]]): list of constraints. A - constraint is a callable which takes two passes (A, B) and returns - True if A depends on B and False otherwise. See implementation of - `this_before_that_pass_constraint` for example. - """ - - passes: List[Callable] = [] - constraints: List[Callable] = [] - _validated: bool = False - - def __init__( - self, - passes=None, - constraints=None, - ): - if passes: - self.passes = passes - if constraints: - self.constraints = constraints - - @classmethod - def build_from_passlist(cls, passes): - pm = PassManager(passes) - # TODO(alexbeloi): add constraint management/validation - return pm - - def add_pass(self, _pass: Callable): - self.passes.append(_pass) - self._validated = False - - def add_constraint(self, constraint): - self.constraints.append(constraint) - self._validated = False - - def remove_pass(self, _passes: List[Callable]): - if _passes is None: - return - passes_left = [] - for ps in self.passes: - if ps.__name__ not in _passes: - passes_left.append(ps) - self.passes = passes_left - self._validated = False - - def validate(self): - """ - Validates that current pass schedule defined by `self.passes` is valid - according to all constraints in `self.constraints` - """ - if self._validated: - return - for constraint in self.constraints: - _validate_pass_schedule_constraint(constraint, self.passes) - self._validated = True - - def __call__(self, source): - self.validate() - out = source - for _pass in self.passes: - out = _pass(out) - return out diff --git a/pippy/fx/passes/reinplace.py b/pippy/fx/passes/reinplace.py deleted file mode 100644 index 94419bbc9..000000000 --- a/pippy/fx/passes/reinplace.py +++ /dev/null @@ -1,663 +0,0 @@ -import torch -import pippy -from pippy.fx import Node -from pippy.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor -from torch.utils._pytree import tree_map, tree_flatten, tree_map_only -from torch.multiprocessing.reductions import StorageWeakRef - -import _operator -from enum import Enum -import itertools -from typing import Set, Dict -from collections import defaultdict - -__all__ = ['reinplace'] - -class _ViewType(Enum): - NonView = 0 - SingleOutputView = 1 - MultiOutputView = 2 - -def _is_view_op(tgt): - if tgt is not None and isinstance(tgt, torch._ops.OpOverload): - schema = tgt._schema - if len(schema.arguments) > 0: - first_arg = schema.arguments[0] - # check if op is a view - return first_arg.alias_info is not None and not first_arg.alias_info.is_write - -def _get_view_type(tgt) -> _ViewType: - if tgt is not None and isinstance(tgt, torch._ops.OpOverload): - schema = tgt._schema - if len(schema.arguments) > 0: - first_arg = schema.arguments[0] - # check if op is a view - if first_arg.alias_info is not None and not first_arg.alias_info.is_write: - # check if op is a multi-output view - if '*' in first_arg.alias_info.after_set: - return _ViewType.MultiOutputView - else: - return _ViewType.SingleOutputView - return _ViewType.NonView - - -# Stores a bunch of metadata related to functionalization each node. -# Relevant metadata: -# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) -# The fake tensor output from running the current node -# n.meta['view_of']: Node -# If the current node n is a view of some base tensor, the 'view_of' field tells us which -# view node was used to generate the current node (a view tensor). -# This information actually makes `fake_result` redundant, but we can use `fake_result` -# to sanity check that our aliasing information is correct. -@compatibility(is_backward_compatible=False) -class _FunctionalizationMetadataProp(pippy.fx.Interpreter): - - def run_node(self, node: Node): - self.node_counter += 1 - result = super().run_node(node) - node.meta['fake_result'] = result - node.meta['node_idx'] = self.node_counter - - # (1) Update metadata with the list of nodes that are used by this node - # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. - # We don't want to treat it as "being used as an input". - node_args = node.args - if node.target is torch.ops.aten.copy_.default: - node_args = node_args[1:] - - # (2) Update metadata to track aliasing information about view tensor nodes. - if node.op == 'call_function': - view_type = _get_view_type(node.target) - if view_type == _ViewType.SingleOutputView: - assert isinstance(node.args[0], Node) - node.meta['view_of'] = node.args[0] - elif view_type == _ViewType.MultiOutputView: - self.multi_output_view_nodes[node] = node.args[0] - - # Check if we returned a multi-output view, - # and we're now grabbing the individual views from the output. - # - # For multi-output views, we want to map each output view to the base, - # but this mapping involves two separate nodes in FX IR. - # e.g. "a, b = x_1.split(...)" becomes: - # %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) - # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) - # %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) - # And we'd like to set: - # getitem1.meta['view_of'] = x_1 - elif node.target is _operator.getitem: - list_arg = node.args[0] - maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) - if maybe_base_of_view is not None: - # Note: we could also track indexing info here for multi-output views. - # I don't think this metadata is strictly needed for de-functionalization. - assert isinstance(maybe_base_of_view, Node) - node.meta['view_of'] = maybe_base_of_view - - if 'view_of' in node.meta: - # We're linking the current node with its first argument as views. - # Assert here that this is actually the case, and their storages are the same. - assert isinstance(node.meta['fake_result'], FakeTensor) - assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) - view_storage = StorageWeakRef(node.meta['fake_result'].storage()) - base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage()) - assert view_storage == base_storage - return result - - - - def propagate(self, *args): - self.multi_output_view_nodes = {} - self.node_counter = -1 - - with FakeTensorMode(allow_meta=True) as mode: - fake_args = [mode.from_tensor(a) for a in args] - return super().run(*fake_args) - -def _schemas_match(functional_schema, inplace_schema): - names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name - arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( - a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) - # for the inplace op, its first argument should be mutable - assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write - # and its remaining arguments shouldn't be. - assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) - return names_match and arg_types_match - -# TODO: this should be beefed up to be able to properly re-inplace with: -# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) -# - out= ops (e.g. angle -> angle.out) -# TODO: we should also figure this info out using torchgen. -def _maybe_get_inplace_op(op): - # __module__ seems broken; it returns torch._ops.aten which doesn't exist - if not isinstance(op, torch._ops.OpOverload): - return None - # Some view ops have inplace variants (as_strided_, etc), - # but we do NOT want the reinplacing pass to directly add these into the program. - # (they'll require extra special handling, aren't aren't really useful for perf anyway) - if _is_view_op(op): - return None - op_namespace = op.__module__.split(".")[-1] - op_base_name = op.overloadpacket.__name__ - maybe_namespace_module = getattr(torch.ops, op_namespace) - maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) - if maybe_inplace_op is None: - return None - - inplace_overloads = [ - getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() - ] - inplace_overloads_with_matching_schemas = [ - f - for f in inplace_overloads - if _schemas_match(op._schema, f._schema) - ] - # Just becuase foo() and foo_() are both existing operators, - # They aren't guaranteed to have compatible schemas. - # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, - # Even though several overloads of pow_ exist. - if len(inplace_overloads_with_matching_schemas) == 0: - return None - assert len(inplace_overloads_with_matching_schemas) == 1 - inplace_op = inplace_overloads_with_matching_schemas[0] - return inplace_op - -_VIEW_INVERSE_MAP = { - torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, - torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, - torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, - torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, -} - -# This function, given a set of set of (aliased) tensor nodes, -# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index -# in the node ordering. -def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): - def _add_if_tensor(x, set_): - if isinstance(x, FakeTensor): - set_.add(StorageWeakRef(x.storage())) - - nodes_used_after = set() - for t in tensor_aliases: - # get all nodes that use the current alias - usage_nodes = t.users - for n in usage_nodes: - # We only care about usages after the current node - if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: - continue - # We also don't care about intermediate view ops. - # They only matter if their output is then used elsewhere - # (either in an out-of-place op, or as an output to the function). - if n in tensor_aliases: - if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: - continue - nodes_used_after.add(n) - return nodes_used_after - -# Given an op that we're trying to re-inplace, "b = foo(a)", -# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" -# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: -# If there are any aliases in the alias_set(a) that satisfy: -# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" -# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata -# as "alias" -def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: - def matching_view_metadata(a, b): - return a.size() == b.size() and \ - a.stride() == b.stride() and \ - a.storage_offset() == b.storage_offset() - - view_inverse_nodes = set() - # Go through them in node order, so we can see chains of view_scatter ops. - for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): - if n.target not in _VIEW_INVERSE_MAP: - continue - base = n.args[0] - mutated_view = n.args[1] - assert isinstance(base, Node) - assert isinstance(base.meta['fake_result'], FakeTensor) - assert isinstance(mutated_view, Node) - assert isinstance(mutated_view.meta['fake_result'], FakeTensor) - # Check that this view_inverse op actually corresponds to taking doing the inverse - # of one of our existing self_alias nodes. - original_view = _VIEW_INVERSE_MAP[n.target] - for self_alias in self_aliases: - # We're looking for some alias of the self arg, "alias", - # that was created from some op `alias = foo(base, args...)` - # such that the current _scatter op "inverts" that foo call. - # We can check that by running the original op again, and checking that the strides match. - if 'view_of' not in self_alias.meta: - continue - self_alias_base = self_alias.meta['view_of'] - try: - # The we're trying to re-use the args from the view_scatter call inside of the corresponding - # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse - # of the current alias we're looking at. - view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) - expected_metadata = self_alias.meta['fake_result'] - # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. - if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ - matching_view_metadata(view_replay_metadata, expected_metadata): - view_inverse_nodes.add(n) - except Exception: - continue - - return view_inverse_nodes - - -@compatibility(is_backward_compatible=True) -def reinplace(gm, *sample_args): - """ - Given an fx.GraphModule, modifies it to perform "reinplacing", - mutating the nodes of the graph. - We look for out-of-place op call sites like `b = a.add(...)`, - and convert them to be inplace (`b = a.add_(...)`), - as long as the input to the current operator ("a") isn't re-used - anywhere later in the graph. - - This pass currently expects to operate on a **functional, ATen** graph. - This can be obtained by running `make_fx(functionalize(f))`. - - Sample inputs are needed to determine aliasing relationships of the inputs. - In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the - inputs to the program. - - Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: - - (1) Perform some initial checks on the metadata of "a" and "args..." - that can disqualify them from being reinplaced. - - (1a) Check that the self argument we're attempting to reinplace - has acceptable dtype/size metadata to reinplace with. - - For example, if we have: - a = torch.ones(1) - b = torch.ones(10) - out = torch.add(a, b) - We can't turn that into - a.add_(b) - Because that would require resizing "a". - - Similarly, we can't convert torch.ge(a, b) into a.ge_(b), - beause that would require changing a's dtype (from e.g. float32 to bool). - Note that in this specific example, we could technically do better.. - - If we see the pattern: - a_1 = a.ge(b) - a_2 = aten._to_copy(a_1, a.dtype) - Then we this should be valid to completely re-inplace - (this is exactly what functionalization will emit when it sees a.ge_(b)). - - This optimization is only really important for user programs - that directly use inplace comparison ops though. - - We also cannot re-inplace on tensors that have overlapping memory, - e.g. torch.ones(1).expand(4, 4).add_(1) - - (1b) Check if "a" is an alias of any of the program inputs. - - If it is, skip and move to the next node. - Inplace'ing an op that would cause it to mutate a program is not sound, - because that would be a side effect visible to the user. - - NOTE: there's a future optimization that we should make: - if "a" is a (alias of a) program input, but later in the program - there is a node that looks like "a.copy_(...)", - Then re-inplacing is ok to do - we are temporarily re-using a's buffer, - which will later be overwritten by the copy_() call. - - This will be an important optimization to have for programs that mutate - their inputs. It currently isn't implemented though. - - (1c) Check if "a" and "args..." alias - - For example, re-inplacing to create code like the below - isn't guaranteed to be sound: - - aten.mul_(a, a) - - (2) Check that "a" and all of its outstanding aliases are not used anywhere - later in the graph. If this is the case, then it's safe to re-inplace - to "b = foo_(a)". - - There are a few caveats to this, explained in more detail below: - (a) If "a" is used later as an argument to a view op, that is okay. - It's only a problem if "a" (or that view) is later passed - into a normal operator, or if it is returned as the program output. - (b) If "a" is a repeat argument in `foo()`, then don't reinplace. - Most ATen kernels don't make any guarantees that this is sound, - e.g. if you do aten.mul_(a, a). - So we'll just ban re-inplacing in this case. - It's only a problem if "a" (or that view) is later passed - (c) If "a" is used as an input into a view "inverse" / "scatter" - operator, it is potentially fine to re-inplace - (and remove that scatter operator from the graph). - See below for a more detailed example. - - NOTE: there is an optimization in this step that is crucial - to fully recovering performance from functionalization. - - Given this program: - def f(x): - a = torch.ops.aten.add(x, x) - b = torch.ops.aten.diagonal(a) - torch.ops.aten.fill_(b, 0) - return d - - Functionalization will emit the following: - def f(x): - a = torch.ops.aten.add(x, x) - b = torch.ops.aten.diagonal(a, 0, 1) - b_updated = torch.ops.aten.fill(b, 0) - a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) - return a_updated - - Ordinarily, we would not be able to reinplace the fill, - because "b" aliases with "a" which is used by the diagonal_scatter call. - - "re-inplacing" is on the hook for figuring out that it is ok to - completely, the expensive diagonal_scatter call, if we re-inplace the add(). - - So, for every `alias in alias_set(a)`, instead of checking - that "alias" is not used anywhere later in the graph, - we check that - EITHER: - (a) alias is not used anywhere later in the graph - OR: - (b) alias is used exactly once later on in the graph, - in the following op: - - out = foo_scatter(alias, x, args...) - - where the following must hold: - (i) "foo_scatter" is the "inverse" operator for foo. - This only applies to "foo" ops that are view operators, - which view into a subset of the original tensor's memory. - In practice, there are ~4 operators where this applies: - diagonal -> diagonal_scatter - slice -> slice_scatter - select -> select_scatter - as_strided -> as_strided_scatter - (ii) "args..." are the same between the foo() and foo_scatter() calls. - - (3) Perform the actual re-inplacing on foo! - - (3b) is the common case, but special care is needed for {view}_scatter (3a) - - (3a) {view}_scatter ops. - - Consider this program: - a = torch.zeros(2, 2) - b = torch.ones(2) - a[0] = b - - Post functionalization, that will look like: - a = torch.zeros(2) - b = torch.ones(1) - a_updated = torch.select_scatter(a, b, 0, 0) - - In this case though, there is no "functional" op to re-inplace! - Instead, we'd like to directly remove toe select_scatter call. - We already know from (3) that this is valid, - because "a" has no later usages in the graph. - - We perform the re-inplacing on the {view}_scatter op like so - Before: - a_updated = torch.select_scatter(a, b, args...) - After: - a_slice = a.select(a, args...) - a_slice.copy_(b) - - (3b) Otherwise, replace the functional op with its inplace variant. - Before: - b = foo(a, args...) - After: - a.foo_(args...) - - (4) Finally, after converting either: - Before: - b = foo(a) - After: - foo_(a) - or - Before: - b = {slice}_scatter(a, mutated_slice, args...) - After: - slice = {slice}(a, args...) - slice.copy_(mutated_slice) - - We now need to find all later nodes that use "b" as an argument - and update them to take in "a" instead. - - Note that for the majority of inplace ops, this isn't actually necessary - (because most inplace ops return "self" as their output). - This isn't generally true for all mutable ops though, which is why - we need to actually replace all of the arguments. - - We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], - That maps a given tensor storage to the set of all nodes that take in that storage - as an input. - Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused - together. - - (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" - during step (3) get manually deleted from the graph. - Their outputs are no longer used, so technically standard DCE would be able - to do this, but we can no longer run FX's DCE pass now that we have mutable - ops in the graph. - """ - _FunctionalizationMetadataProp(gm).propagate(*sample_args) - - # Useful debug printing - # def _print(x): - # if isinstance(x, FakeTensor): - # print(f'fake_result: {StorageWeakRef(x.storage()).cdata}') - - # for n in gm.graph.nodes: - # print(n.format_node()) - # if hasattr(n, 'meta'): - # print(f'node_idx: {n.meta["node_idx"]}') - # if 'fake_result' in n.meta: - # tree_map(_print, n.meta['fake_result']) - # if 'view_of' in n.meta: - # print(f'view_of: {str(n.meta["view_of"])}') - # print() - - # We need to know which nodes correspond to inputs (or their aliases) - # so we know not to re-inplace them. - # NOTE: later, we'll need to add an optimization for fully recovering performance - # on programs that mutate inputs. - input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder') - - - # We also need to know for a given node, what are all of its aliasing nodes. - storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) - for n in gm.graph.nodes: - if 'fake_result' in n.meta: - # Tree-mapping because some ops can return lists of tensors. - def _add_to_map(x): - if isinstance(x, FakeTensor): - storage_to_nodes[StorageWeakRef(x.storage())].add(n) - tree_map(_add_to_map, n.meta['fake_result']) - - # inplace-ify functional ops, subject to the constraints written below. - all_later_view_inverse_nodes_to_delete = set() - for idx, node in enumerate(gm.graph.nodes): - if node.op == 'call_function': - - # Today, the re-inplace pass on directly acts on: - # - functional ops with an inplace variant - # - {view}_scatter ops that can be potentially removed from the graph. - # Both of these ops take in tensor first args, so filtering on this condition - # makes the later code simpler. - # We should revisit this at some point though, particularly when we also want - # the reinplacer to be able to handle out= and mutable operators - # and tensorlist first args (like `_foreach_` ops). - if not isinstance(node.target, torch._ops.OpOverload): - continue - if len(node.target._schema.arguments) < 1: - continue - if type(node.target._schema.arguments[0].type) != torch.TensorType: - continue - - # Step 1a: Check that the self argument we're attempting to reinplace - # has the same size/stride as the output. - # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) - # As it would require resizing scalar_tensor. - # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), - # this is probably an optimization to revisit later). - self_arg = node.args[0] - self_flattened, _ = tree_flatten(self_arg.meta['fake_result']) - node_flattened, _ = tree_flatten(node.meta['fake_result']) - self_has_wrong_metadata = False - if len(self_flattened) == len(node_flattened): - for self_meta, node_meta in zip(self_flattened, node_flattened): - if self_meta.numel() != node_meta.numel(): - self_has_wrong_metadata = True - if self_meta.dtype != node_meta.dtype: - self_has_wrong_metadata = True - # We also cannot re-inplace on tensors that have internal memory overlap. - # e.g. torch.ones(1).expand(4, 4).add_(1) - if torch._debug_has_internal_overlap(self_meta) == 1: - self_has_wrong_metadata = True - # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, - # Since users should never really be calling the functional "torch.ops.aten.resize" - # op directly in their programs. - if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: - continue - - # Step 1b: ensure that the op we're trying to re-inplace isn't a program input - self_arg_name = self_arg.name - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) - if self_arg_storage in input_storages: - # TODO: later, add the optimization for handling `copy_()` calls in the graph. - continue - if len([x for x in node.args if x is self_arg]) > 1: - # Step 1c: - # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, - # so we prevent re-inplacing in this case. - continue - - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) - self_aliases = storage_to_nodes[self_arg_storage] - - # First, we find all later usages of any of the aliases of self_arg. - later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) - # Then, we check if any of those later usages are actually view_scatter ops - # that are safe to fully remove. - later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) - - # Step 2: Check to see if the input to the op is re-used later in the graph. - # If not (same goes for its aliases), then this op is safe to re-in place. - # This is a slightly roundabout way to check that there are no later usages of the current self argument. - # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) - can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 - if not can_reinplace: - continue - - # Step 3a: Special handling for when we see *_scatter operators. - # When we see an operator like `b = torch.slice_scatter(a, ...)`, - # instead of trying to "inplace" it into a.slice_scatter_(..._), - # we would prefer to remove it from the graph entirely, - # and instead copy_() the slice directly into the larger tensor. - # See the description of the algorithm for a full example. - if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: - view_op = _VIEW_INVERSE_MAP[node.target] - # Before: - # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) - # After: - # slice = torch.ops.aten.slice.default(base, args...) - # slice.copy_(mutated_slice) - with gm.graph.inserting_before(node): - mutated_slice_node = node.args[1] - remaining_slice_args = node.args[2:] - slice_node = gm.graph.create_node( - 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) - copy_node = gm.graph.create_node( - 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) - # Add the slice_scatter node to our "nodes to delete" list. - all_later_view_inverse_nodes_to_delete.add(node) - - - else: - # Step 3b: Check to see if this operator has an inplace variant. - maybe_inplace_op = _maybe_get_inplace_op(node.target) - if maybe_inplace_op is None: - continue - # And if so, replace it with its inplace variant. - node.target = maybe_inplace_op - - # At this point, 'storage_to_nodes' will be stale. - # Now that we're inplacing `b = foo(a)`, we need to effectively - # union together the dict values for b and a's storage. - # Hmm... morally I think we also want to keep the `fake_result` metadata - # up to date here, but I'm not sure how easy it is to do. - # Maybe it's fine to wait until the end of the pass to update it. - curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage()) - storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) - storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) - - # Need to remember the view_scatter view nodes we found so we can remove them alter. - all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) - - # Step 4: - # Now that we've replaced b = a.foo() with a.foo_(), - # We need to replace any later usages of "b" with "a" - for old in itertools.chain([node], later_view_inverse_node_usages): - new = old.args[0] - nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] - for node_to_update in nodes_to_update: - new_args = [] - args = node_to_update.args - - def replace_arg(a): - if a == old: - return new - return a - - # First, replace usages of "b" with "a" - node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) - node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) - - # Second, update our storage_to_nodes data structure. - old_flattened_res, _ = tree_flatten(old.meta['fake_result']) - node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result']) - - old_res_storage = set(StorageWeakRef(x.storage()) for x in old_flattened_res if isinstance(x, FakeTensor)) - node_res_storage = set(StorageWeakRef(x.storage()) for x in node_flattened_res if isinstance(x, FakeTensor)) - - # This will happen if we're updating a view op, e.g. - # e.g. replacing - # x = view(old) - # x = view(new) - # When that happens, we need to make sure to keep our - # storage mapping up to date. - # - # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, - # or multiple tensors that all share the same storage. - # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. - if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: - new_flattened_res, _ = tree_flatten(new.meta['fake_result']) - new_res_storage = set(StorageWeakRef(x.storage()) for x in new_flattened_res if isinstance(x, FakeTensor)) - assert len(new_res_storage) == 1 - (old_ref,) = old_res_storage - (new_ref,) = new_res_storage - (node_ref,) = node_res_storage - # Technically, "old_ref" and all its aliases will remain - # in our mapping. - # That should be fine though, since we deleted "old" - # from the graph at this point. - storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) - storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) - - # Step 4: delete any _scatter nodes that we de-functionalized - # Need to take care not to delete any of these nodes until after *all* modifications - # to the graph are finished. - for to_delete in all_later_view_inverse_nodes_to_delete: - gm.graph.erase_node(to_delete) - - - gm.recompile() - return gm diff --git a/pippy/fx/passes/shape_prop.py b/pippy/fx/passes/shape_prop.py deleted file mode 100644 index 9745136a2..000000000 --- a/pippy/fx/passes/shape_prop.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import torch -import pippy.fx -import traceback - -from pippy.fx.node import Node, map_aggregate -from typing import Any, Tuple, NamedTuple, Optional, Dict -from pippy.fx._compatibility import compatibility - -__all__ = ['TensorMetadata', 'ShapeProp'] - -@compatibility(is_backward_compatible=True) -class TensorMetadata(NamedTuple): - # TensorMetadata is a structure containing pertinent information - # about a tensor within a PyTorch program. - - # General Tensor metadata - shape : torch.Size - dtype : torch.dtype - requires_grad : bool - stride : Tuple[int] - memory_format : Optional[torch.memory_format] - - # Quantization metadata - is_quantized : bool - qparams: Dict[str, Any] - -def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: - """ - Extract a TensorMetadata NamedTuple describing `result`. - """ - shape = result.shape - dtype = result.dtype - requires_grad = result.requires_grad - stride = result.stride() - - memory_formats = { - torch.contiguous_format, - torch.channels_last, - torch.channels_last_3d, - } - - memory_format = None - - for query_format in memory_formats: - if result.is_contiguous(memory_format=query_format): - memory_format = query_format - break - - is_quantized = result.is_quantized - qparams: Dict[str, Any] = {} - if is_quantized: - qscheme = result.qscheme() - qparams["qscheme"] = qscheme - if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: - qparams["scale"] = result.q_scale() # type: ignore[assignment] - qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] - elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: - # In this branch, scale and zero_point are expected to be tensors, - # we store the values as immutable_list in TensorMetadata for - # easier serialization downstream - qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] - qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] - qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - - return TensorMetadata( - shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) - -@compatibility(is_backward_compatible=True) -class ShapeProp(pippy.fx.Interpreter): - """ - Execute an FX graph Node-by-Node and - record the shape and type of the result - into the corresponding node. - - Example: - In this example, we record the shape - and data type of a module given - an example input ``torch.randn(50, D_in)``. - We print the name, shape and dtype of each node. - - class TwoLayerNet(torch.nn.Module): - def __init__(self, D_in, H, D_out): - super(TwoLayerNet, self).__init__() - self.linear1 = torch.nn.Linear(D_in, H) - self.linear2 = torch.nn.Linear(H, D_out) - def forward(self, x): - h_relu = self.linear1(x).clamp(min=0) - y_pred = self.linear2(h_relu) - return y_pred - N, D_in, H, D_out = 64, 1000, 100, 10 - x = torch.randn(N, D_in) - y = torch.randn(N, D_out) - model = TwoLayerNet(D_in, H, D_out) - gm = pippy.fx.symbolic_trace(model) - sample_input = torch.randn(50, D_in) - ShapeProp(gm).propagate(sample_input) - - for node in gm.graph.nodes: - print(node.name, node.meta['tensor_meta'].dtype, - node.meta['tensor_meta'].shape) - - The output of this code is: - - x torch.float32 torch.Size([50, 1000]) - linear1 torch.float32 torch.Size([50, 100]) - clamp_1 torch.float32 torch.Size([50, 100]) - linear2 torch.float32 torch.Size([50, 10]) - output torch.float32 torch.Size([50, 10]) - - Args: - module (GraphModule): The module to be executed - - """ - def run_node(self, n : Node) -> Any: - try: - result = super().run_node(n) - except Exception: - traceback.print_exc() - raise RuntimeError( - f"ShapeProp error for: node={n.format_node()} with " - f"meta={n.meta}" - ) - - found_tensor = False - - def extract_tensor_meta(obj): - if isinstance(obj, torch.Tensor): - nonlocal found_tensor - found_tensor = True - return _extract_tensor_metadata(obj) - else: - return obj - - meta = map_aggregate(result, extract_tensor_meta) - if found_tensor: - n.meta['tensor_meta'] = meta - - n.meta['type'] = type(result) - return result - - def propagate(self, *args): - """ - Run `module` via interpretation and return the result and - record the shape and type of each node. - - Args: - *args (Tensor): the sample input. - - Returns: - Any: The value returned from executing the Module - """ - return super().run(*args) diff --git a/pippy/fx/passes/split_module.py b/pippy/fx/passes/split_module.py deleted file mode 100644 index 2ccc28108..000000000 --- a/pippy/fx/passes/split_module.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -from typing import Any, Callable, Dict, List, Optional - -import torch - -import pippy -import pippy.fx -from pippy.fx._compatibility import compatibility -from pippy.fx.graph_module import GraphModule - -__all__ = ["Partition", "split_module"] - - -@compatibility(is_backward_compatible=True) -class Partition: - def __init__(self, name: str): - self.name: str = name - self.submod_name = f"submod_{name}" - self.node_names: List[str] = [] - self.inputs: Dict[str, None] = {} - self.outputs: Dict[str, None] = {} - self.partitions_dependent_on: Dict[str, None] = {} - self.partition_dependents: Dict[str, None] = {} - self.graph: pippy.fx.graph.Graph = pippy.fx.graph.Graph() - self.environment: Dict[pippy.fx.node.Node, pippy.fx.node.Node] = {} - self.targets: Dict[str, Any] = {} - - def __repr__(self) -> str: - return ( - f"name: {self.name},\n" - f" nodes: {self.node_names},\n" - f" inputs: {self.inputs},\n" - f" outputs: {self.outputs},\n" - f" partitions depenent on: {self.partitions_dependent_on},\n" - f" parition dependents: {self.partition_dependents}" - ) - - -# Creates subgraphs out of main graph -@compatibility(is_backward_compatible=True) -def split_module( - m: GraphModule, - root_m: torch.nn.Module, - split_callback: Callable[[pippy.fx.node.Node], int], - qualname_map: Optional[Dict[str, str]] = None, - keep_original_order: Optional[bool] = False, -): - """ - Creates subgraphs out of main graph - - Args: - m (GraphModule): Graph module to split - root_m (torch.nn.Module): root nn module. Not currently used. Included - because the root nn module is usually transformed via - pippy.fx._symbolic_trace.symbolic_trace (see example below) - split_callback (Callable[[pippy.fx.node.Node], int]): Callable function - that maps a given Node instance to a numeric partition identifier. - split_module will use this function as the policy for which operations - appear in which partitions in the output Module. - qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a - mapping from new target names in the module after split to old target - names in the original module. - keep_original_order: Optional[bool]: keep the original order of the GraphModule - or use the Topological order of the new constructed GraphModule - - - Returns: - GraphModule: the module after split. - - Example: - - This is a sample setup: - - import torch - from pippy.fx.symbolic_trace import symbolic_trace - from pippy.fx.graph_module import GraphModule - from pippy.fx.node import Node - from pippy.fx.passes.split_module import split_module - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x, y): - z = self.linear(x + self.param).clamp(min=0.0, max=1.0) - w = self.linear(y).clamp(min=0.0, max=1.0) - return z + w - - # symbolically trace model - my_module = MyModule() - my_module_traced = symbolic_trace(my_module) - - # random mod partitioning - partition_counter = 0 - NPARTITIONS = 3 - - def mod_partition(node: Node): - global partition_counter - partition = partition_counter % NPARTITIONS - partition_counter = (partition_counter + 1) % NPARTITIONS - return partition - - # split module in module with submodules - module_with_submodules = split_module( - my_module_traced, my_module, mod_partition - ) - - Output looks like this. Original graph is broken into partitions - - > print(module_with_submodules) - GraphModule( - (submod_0): GraphModule( - (linear): Linear(in_features=4, out_features=5, bias=True) - ) - (submod_1): GraphModule( - (linear): Linear(in_features=4, out_features=5, bias=True) - ) - (submod_2): GraphModule() - ) - - def forward(self, x, y): - param = self.param - submod_0 = self.submod_0(x, param, y); x = param = y = None - getitem = submod_0[0] - getitem_1 = submod_0[1]; submod_0 = None - submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None - getitem_2 = submod_1[0] - getitem_3 = submod_1[1]; submod_1 = None - submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None - return submod_2 - - Output of split module is the same as output of input traced module. - This is an example within a test setting: - - > orig_out = my_module_traced(x, y) - > submodules_out = module_with_submodules(x, y) - > self.assertEqual(orig_out, submodules_out) - True - """ - partitions: Dict[str, Partition] = {} - orig_nodes: Dict[str, pippy.fx.node.Node] = {} - - def record_cross_partition_use( - def_node: pippy.fx.node.Node, use_node: Optional[pippy.fx.node.Node] - ): # noqa: B950 - def_partition_name = getattr(def_node, "_fx_partition", None) - use_partition_name = getattr(use_node, "_fx_partition", None) - if def_partition_name != use_partition_name: - if def_partition_name is not None: - def_partition = partitions[def_partition_name] - def_partition.outputs.setdefault(def_node.name) - if use_partition_name is not None: - def_partition.partition_dependents.setdefault(use_partition_name) - - if use_partition_name is not None: - use_partition = partitions[use_partition_name] - use_partition.inputs.setdefault(def_node.name) - if def_partition_name is not None: - use_partition.partitions_dependent_on.setdefault(def_partition_name) - - # split nodes into parititons - for node in m.graph.nodes: - orig_nodes[node.name] = node - - # TODO currently placeholders/parameters aren't put into random partitions, - # rather they're added to the graphs where they are used down below - if node.op in ["placeholder", "get_attr"]: - continue - if node.op == "output": - pippy.fx.graph.map_arg( - node.args[0], lambda n: record_cross_partition_use(n, None) - ) - continue - partition_name = str(split_callback(node)) - - # add node to partitions - partition = partitions.get(partition_name) - if partition is None: - partitions[partition_name] = partition = Partition(partition_name) - - partition.node_names.append(node.name) - node._fx_partition = partition_name - - pippy.fx.graph.map_arg( - node.args, lambda def_node: record_cross_partition_use(def_node, node) - ) - pippy.fx.graph.map_arg( - node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) - ) # noqa: B950 - - original_partition_order = list(partitions.keys()) - # find partitions with no dependencies - root_partitions: List[str] = [] - for partition_name, partition in partitions.items(): - if not len(partition.partitions_dependent_on): - root_partitions.append(partition_name) - - # check partitions for circular dependencies and create topological partition ordering - sorted_partitions: List[str] = [] - while root_partitions: - root_partition = root_partitions.pop() - sorted_partitions.append(root_partition) - for dependent in partitions[root_partition].partition_dependents: - partitions[dependent].partitions_dependent_on.pop(root_partition) - if not partitions[dependent].partitions_dependent_on: - root_partitions.append(dependent) - if len(sorted_partitions) != len(partitions): - raise RuntimeError("cycle exists between partitions!") - - # add placeholders to parititons - for partition_name in sorted_partitions: - partition = partitions[partition_name] - for input in partition.inputs: - placeholder = partition.graph.placeholder(input) - placeholder.meta = orig_nodes[input].meta.copy() - partition.environment[orig_nodes[input]] = placeholder - - # Transform nodes and collect targets for partition's submodule - for node in m.graph.nodes: - if hasattr(node, "_fx_partition"): - partition = partitions[node._fx_partition] - - # swap out old graph nodes in kw/args with references to new nodes in this submodule - environment = partition.environment - gathered_args = pippy.fx.graph.map_arg(node.args, lambda n: environment[n]) - gathered_kwargs = pippy.fx.graph.map_arg( - node.kwargs, lambda n: environment[n] - ) - - if node.op not in ["call_module", "get_attr"]: - target = node.target - else: - target_atoms = node.target.split(".") - target_attr = m - for atom in target_atoms: - if not hasattr(target_attr, atom): - raise RuntimeError(f"Operator target {node.target} not found!") - target_attr = getattr(target_attr, atom) - # target = target_atoms[-1] - target = "_".join(target_atoms) - partition.targets[target] = target_attr - # Fill in the passed-in mapping from new qualname to old qualname - if qualname_map is not None: - # When creating the split module later, the submodules will have - # path prefix matching the corresponding partition's submod_name - qualname = f"{partition.submod_name}.{target}" - qualname_map[qualname] = node.target - - assert isinstance(gathered_args, tuple) - assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node( - op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs - ) - new_node.meta = node.meta.copy() - partition.environment[node] = new_node - - # Set up values to construct base module - base_mod_env: Dict[str, pippy.fx.node.Node] = {} - base_mod_graph: pippy.fx.graph.Graph = pippy.fx.graph.Graph() - base_mod_attrs: Dict[str, pippy.fx.graph_module.GraphModule] = {} - for node in m.graph.nodes: - if node.op == "placeholder": - default_value = ( - node.args[0] if len(node.args) > 0 else inspect.Signature.empty - ) - base_mod_env[node.name] = base_mod_graph.placeholder( - node.target, type_expr=node.type, default_value=default_value - ) - base_mod_env[node.name].meta = node.meta.copy() - elif node.op == "get_attr": - base_mod_env[node.name] = base_mod_graph.get_attr(node.target) - base_mod_env[node.name].meta = node.meta.copy() - attr_val = m - for atom in node.target.split("."): - if not hasattr(attr_val, atom): - raise RuntimeError(f"Node target {node.target} not found!") - attr_val = getattr(attr_val, atom) - base_mod_attrs[node.target] = attr_val - - # Do some things iterating over the partitions in topological order again: - # 1) Finish off submodule Graphs by setting corresponding outputs - # 2) Construct GraphModules for each submodule - # 3) Construct the base graph by emitting calls to those submodules in - # topological order - - construct_order_partitions = ( - sorted_partitions if not keep_original_order else original_partition_order - ) - - for partition_name in construct_order_partitions: - partition = partitions[partition_name] - - # Set correct output values - output_vals = tuple( - partition.environment[orig_nodes[name]] for name in partition.outputs - ) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] - partition.graph.output(output_vals) - - # Construct GraphModule for this partition - base_mod_attrs[partition.submod_name] = pippy.fx.graph_module.GraphModule( - partition.targets, partition.graph - ) # noqa: B950 - - # Emit call in base graph to this submodule - output_val = base_mod_graph.call_module( - partition.submod_name, - tuple(base_mod_env[name] for name in partition.inputs), - ) - if len(partition.outputs) > 1: - # Unpack multiple return values from submodule - output_val_proxy = pippy.fx.proxy.Proxy(output_val) - for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] - else: - base_mod_env[list(partition.outputs)[0]] = output_val - - for node in m.graph.nodes: - if node.op == "output": - base_mod_graph.output( - pippy.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) - ) # noqa: B950 - - return pippy.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/pippy/fx/passes/split_utils.py b/pippy/fx/passes/split_utils.py deleted file mode 100644 index f6f8b90be..000000000 --- a/pippy/fx/passes/split_utils.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from dataclasses import dataclass, field -from typing import List, Optional, Dict - -import pippy.fx -from pippy.fx.graph import map_arg -from .tools_common import NodeList -from pippy.fx._compatibility import compatibility -from pippy.fx.passes.utils import lift_subgraph_as_module, HolderModule - -__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags'] - -@compatibility(is_backward_compatible=False) -def getattr_recursive(obj, name): - for layer in name.split("."): - if hasattr(obj, layer): - obj = getattr(obj, layer) - else: - return None - return obj - - -@compatibility(is_backward_compatible=False) -def setattr_recursive(obj, attr, value): - if "." not in attr: - setattr(obj, attr, value) - else: - layer = attr.split(".") - setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value) - - -@compatibility(is_backward_compatible=False) -@dataclass -class Component: - """ - A component serves as a container for a subgraph we want to create afterwards. - """ - - graph: pippy.fx.Graph - order: int - name: str - - # Stores the placeholder nodes in `graph`. - input_placeholders: List = field(default_factory=list) - - # Store the nodes in original graph that are placeholder in `graph`. - orig_inputs: List = field(default_factory=list) - - # Store the nodes in original graph that are outputs in `graph`. - orig_outputs: List = field(default_factory=list) - - # Mapping from get_attr node in original graph to get_attr node in `graph`. - getattr_maps: Dict[pippy.fx.Node, pippy.fx.Node] = field(default_factory=dict) - constructor_args: List[str] = field(default_factory=list) - gm: Optional[pippy.fx.GraphModule] = None - - -@compatibility(is_backward_compatible=False) -def split_by_tags(gm: pippy.fx.GraphModule, tags: List[str]) -> pippy.fx.GraphModule: - """ - Splits a GraphModule using tags on its graph nodes. We honor the order of - tags. For example, we have tags = ["a", "b", "c"], the function will create - the initial submodules in the order of "a_0", "b_1", "c_2". - - To set a tag: - gm.graph.nodes[idx].tag = "mytag" - - This will result in all nodes with the same tag being extracted and placed in their - own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder - and output nodes are created when needed while get_attr nodes get copied to submodules - where they are used. - - Given the following module def: - - class SimpleModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(...) - self.linear2 = torch.nn.Linear(...) - self.linear3 = torch.nn.Linear(...) - - def forward(self, in1, in2): - r1 = self.linear1(in1) - r2 = self.linear2(in2) - r3 = torch.cat([r1, r2]) - return self.linear3(r3) - - Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: - - ro_0: - def forward(self, in1): - self = self.root - linear1 = self.linear1(in1) - return linear1 - - main_1: - def forward(self, in2, linear1): - self = self.root - linear2 = self.linear2(in2) - cat_1 = torch.cat([linear1, linear2]) - linear3 = self.linear3(cat_1) - return linear3 - - main_0: - def forward(self, in1, in2): - self = self.root - ro_0 = self.ro_0(in1) - main_1 = self.main_1(in2, ro_0) - return main_1 - """ - - def flatten(x: pippy.fx.node.Argument) -> NodeList: - """ - Stores nodes in x to a list and returns the list. - """ - r: NodeList = [] - map_arg(x, r.append) - return r - - # Mapping from node in original module to node in created submodule. - node_remapping: Dict[pippy.fx.Node, pippy.fx.Node] = {} - - # Mapping from node in original module or created submodules to - # corresponding component. - node_to_component: Dict[pippy.fx.Node, Component] = {} - - # Mapping from tag to the corresponding component. - tag_to_component: Dict[str, Component] = {} - - # Stores all components. - all_components: List[Component] = [] - - # Stores nodes that will be used in main graph. - used_in_main: Dict[pippy.fx.Node, None] = {} - - # Main graph after split. - main_g = pippy.fx.Graph() - - # Mapping from node in original module to node in main graph after split. - main_remapping: Dict[pippy.fx.Node, pippy.fx.Node] = {} - - # Output node of original module. - output_node: Optional[pippy.fx.Node] = None - - # Create a component for each tag, we don't expect to create other components afterwards. - for tag in tags: - comp = Component(pippy.fx.Graph(), len(all_components), f"{tag}") - all_components.append(comp) - tag_to_component[tag] = comp - - # Traverse the nodes in original graph and take care of them. - for node in gm.graph.nodes: - if node.op == "output": - if output_node is not None: - raise RuntimeError("Multiple output nodes in graph!") - output_node = node - continue - - # Placeholders in the original graph get copied to main graph. - if node.op == "placeholder": - main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type) - continue - - # Get_attr nodes are ignored because we are not tagging them. - # Instead, we copy them directly to the submodules use them afterwards. - if node.op == "get_attr": - continue - - # Now we process callable nodes which are nodes with op of call_module, - # call_function or call_method. Every callable nodes should be tagged. - assert hasattr(node, "tag") - - upstream_components = [ - node_to_component[x] - for x in flatten(node.args) + flatten(node.kwargs) - if x.op not in {"placeholder", "get_attr"} - ] - - comp = tag_to_component[node.tag] - node_to_component[node] = comp - - # Max order of upperstream components. - mx = max((c.order for c in upstream_components), default=0) - - # Expect the componet for `node` has higher order then its upstream components. - assert comp.order >= mx - - # Map a input of `node` to nodes in the component's graph. - def remap_func(x): - # If input is a get_attr node, copy it to current component's graph. - # Returns the get_attr node in current component's graph. - if x.op == "get_attr": - if x not in comp.getattr_maps: - comp.getattr_maps[x] = comp.graph.get_attr( - x.target, type_expr=x.type - ) - return comp.getattr_maps[x] - - # If input is not a placeholder, it should have been put into a component - # already. If it's the current component then we return the corresponding - # node in the component. - if x.op != "placeholder" and node_to_component[x] == comp: - return node_remapping[x] - - # If input is a placeholder or it's in other components, we want to make it - # as a placeholder in current component's graph. - if x not in comp.orig_inputs: - comp.orig_inputs.append(x) - comp.input_placeholders.append( - comp.graph.placeholder(x.name, type_expr=x.type) - ) - used_in_main[x] = None - - return comp.input_placeholders[ - next(i for i, y in enumerate(comp.orig_inputs) if x is y) - ] - - n = comp.graph.node_copy(node, remap_func) - n.tag = node.tag # type: ignore[attr-defined] - node_remapping[node] = n - node_to_component[n] = comp - - if output_node is None: - raise RuntimeError("Graph had no output node!") - - for x in flatten(output_node.args[0]): - if x.op == "get_attr": - # We don't need components mapping for nodes of type "get_attr" - # that are consumed by the output. Only need to make sure we create - # corresponding counterparts in the resulting graph. - main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type) - else: - # All component results consumed by the output node should be - # marked as "used in main". - used_in_main[x] = None - - # If a node is used in main graph then we mark it as an output in the component - # it belongs to. - for n in used_in_main: - if n.op != "placeholder": - node_to_component[n].orig_outputs.append(n) - - # Now we create a graphmodule for each component. - for comp in all_components: - outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) - - # Take care of the args of FX output node. If there's a single - # output then the output node args is like (output_single), else - # if there're multiple outputs then the output node args is like - # ((output_0, output_1, ...)). - comp.graph.output(outs[0] if len(outs) == 1 else outs) - - comp.gm = lift_subgraph_as_module(gm, comp.graph) - - # Create a call_module node in main graph. - main_node = main_g.call_module( - comp.name, - args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)), - kwargs=None, - ) - - if len(outs) == 1: - main_remapping[comp.orig_outputs[0]] = main_node - else: - for i, o in enumerate(comp.orig_outputs): - # Use Proxy to record getitem access. - main_remapping[o] = pippy.fx.Proxy(main_node)[i].node # type: ignore[index] - - main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__)) - main_root = HolderModule({comp.name: comp.gm for comp in all_components}) - - # If the output nodes consumes get_attr directly in the original graph, - # then we need to make sure get_attr is copied to the new graph. - for x in flatten(output_node.args[0]): - if x.op == "get_attr": - setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] - - return pippy.fx.GraphModule(main_root, main_g) diff --git a/pippy/fx/passes/splitter_base.py b/pippy/fx/passes/splitter_base.py deleted file mode 100644 index 99ed92bd0..000000000 --- a/pippy/fx/passes/splitter_base.py +++ /dev/null @@ -1,854 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import logging -import warnings -from collections import defaultdict -from dataclasses import dataclass -from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple - -import torch - -import pippy -import pippy.fx -from pippy.fx._compatibility import compatibility -from pippy.fx.node import map_arg -from pippy.fx.passes.graph_manipulation import get_size_of_node -from .graph_drawer import FxGraphDrawer -from .operator_support import ( - get_node_target, - OperatorSupportBase, -) -from .shape_prop import ShapeProp -from .split_utils import split_by_tags -from .tools_common import ( - FxNetAccFusionsFinder, - CALLABLE_NODE_OPS, - Tensors, - NodeList, - NodeSet, - is_node_output_tensor, -) - -__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] -_LOGGER = logging.getLogger(__name__) - - -class _SplitterSettingBase: - def __init__(self): - parser = argparse.ArgumentParser() - parser.add_argument( - "--min_acc_module_size", - default=1, - type=int, - help="Minimum size limit of an accelerator subgraph.", - ) - parser.add_argument( - "--skip_fusion", - default=False, - action="store_true", - help="If true then no fusion groups. Fusion group is used to " - "enforce no non-tensor data flow between submodules. If we don't " - "have this constrain, setting this to false is recommended as it " - "can reduce overhead.", - ) - parser.add_argument( - "--allow_non_tensor", - default=False, - action="store_true", - help="For some backends non-tensor data flow between cpu and them " - "are not allowed. Therefore, if a node supported by accelerator but " - "it has non-tensor inputs or outputs to a cpu node we would want to " - "consider it as a cpu node during splitting. However, for some backends " - "we might not care about non-tensor data flow and we can set this option " - "to true to disable the functionality that prevent non-tensor data flow.", - ) - args, unknown = parser.parse_known_args() - - self.min_acc_module_size: int = args.min_acc_module_size - self.skip_fusion: bool = args.skip_fusion - self.allow_non_tensor: bool = args.allow_non_tensor - - -@compatibility(is_backward_compatible=False) -class FxNetAccNodesFinder: - """ - Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor - input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. - - I.e. if we have a chain: - - ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 - - where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. - - This behavior can be turned off by passing allow_non_tensor=True. - """ - - def __init__( - self, - module: pippy.fx.GraphModule, - operator_support: OperatorSupportBase, - allow_non_tensor: bool, - ): - self.module = module - self.operator_support = operator_support - self.allow_non_tensor = allow_non_tensor - - def reduce_acc_nodes_non_tensor_input_helper( - self, cpu_worklist: NodeList - ): - """ - Transitively excludes nodes from ACC supported set. - For every node in the worklist: - - removes its downstream ACC nodes from ACC supported set, - - if any downstream ACC node produces non-tensor output, - then it gets added into the worklist. - """ - while cpu_worklist: - node = cpu_worklist.pop(0) - - for user in node.users: - if user in self.acc_nodes: - self.acc_nodes.remove(user) - if not is_node_output_tensor(user): - cpu_worklist.append(user) - - def reduce_acc_nodes_non_tensor_input(self): - """ - Excludes nodes from ACC supported set that have direct - upstream CPU nodes that produce non-tensor outputs. - """ - non_tensor_cpu_nodes: NodeList = [] - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - if node in self.acc_nodes: - continue - if is_node_output_tensor(node): - continue - non_tensor_cpu_nodes.append(node) - - self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) - - def reduce_acc_nodes_non_tensor_output(self): - """ - Excludes nodes from ACC supported set that produce non-tensor - outputs and have downstream CPU nodes. - """ - while True: - new_cpu_nodes: NodeList = [] - - for acc_node in self.acc_nodes: - if is_node_output_tensor(acc_node): - continue - for user in acc_node.users: - if user not in self.acc_nodes: - new_cpu_nodes.append(acc_node) - break - - if not new_cpu_nodes: - break - - for new_cpu_node in new_cpu_nodes: - self.acc_nodes.remove(new_cpu_node) - - self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) - - def __call__(self) -> NodeSet: - submodules = dict(self.module.named_modules()) - self.acc_nodes = { - n - for n in self.module.graph.nodes - if n.op in CALLABLE_NODE_OPS - and self.operator_support.is_node_supported(submodules, n) - } - - if not self.allow_non_tensor: - self.reduce_acc_nodes_non_tensor_input() - self.reduce_acc_nodes_non_tensor_output() - - return self.acc_nodes - -@compatibility(is_backward_compatible=False) -class FxNetSplitterInternalError(Exception): - pass - -@compatibility(is_backward_compatible=False) -@dataclass -class Subgraph: - is_acc: bool - nodes: NodeList - - -@compatibility(is_backward_compatible=False) -class SplitResult(NamedTuple): - """ - Stores the results of the splitter. - - Attributes: - split_module: root module after splitting. - submodule_inputs: a dict that maps submodule name to its inputs. - non_acc_submodule_prefix: the prefix for non acc submodules. For - acc submodule the prefix is alwasy "_run_on_acc_". - """ - - split_module: pippy.fx.GraphModule - submodule_inputs: Dict[str, Any] - non_acc_submodule_prefix: str - - -@compatibility(is_backward_compatible=False) -def generate_inputs_for_submodules( - model: torch.nn.Module, - inputs: Sequence[Any], - target_submodules: Iterable[str] -) -> Dict[str, Any]: - """ - Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this - function doesn't work. - - Args: - model: root model. - inputs: inputs to the root model. - target_submodules: submodules that we want to generate inputs for. - - Returns: - A dict that maps from submodule name to its inputs. - """ - - handles = [] - results = {} - submodule_to_names = dict((mod, name) for name, mod in model.named_modules()) - - def pre_forward(module, module_inputs): - results[submodule_to_names[module]] = module_inputs - try: - for name, mod in model.named_modules(): - if name in target_submodules: - handles.append(mod.register_forward_pre_hook(pre_forward)) - model(*inputs) - except Exception as e: - warnings.warn(f"Failed to generate submodule inputs because of the following error:\n{e}") - finally: - for h in handles: - h.remove() - return results - - -class _SplitterBase: - """ - Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. - Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. - Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. - - Given the following graph: - ==> b ==> - // \\ - a d - \\ // - ==> c ==> - - class SimpleModule(torch.nn.Module): - def forward(self, a): - b = torch.sin(a) - c = torch.cos(a) - d = b + c - return d - - and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, - we will get the following split result: - - main: - def forward(self, a): - run_on_acc_0_0 = self._run_on_acc_0_0(a) - getitem = run_on_acc_0_0[0] - getitem_1 = run_on_acc_0_0[1] - run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) - return run_on_cpu_1_1 - - _run_on_acc_0_0: - def forward(self, a): - sin_1 = torch.sin(a) - cos_1 = torch.cos(a) - return (sin_1, cos_1) - - _run_on_cpu_1_1: - def forward(self, sin_1, cos_1): - add_1 = sin_1 + cos_1 - return add_1 - """ - - # PCIe bandwidth for the backend, default to 100 GB/s - PCIe_BW = 100 * 2 ** 30 - - def __init__( - self, - module: pippy.fx.GraphModule, - sample_input: Sequence[Any], - operator_support: OperatorSupportBase, - settings: _SplitterSettingBase, - non_acc_submodule_name: str = "_run_on_cpu_", - ): - """ - Preprocesses graph before splitting: - - finds nodes supported by ACC, - - finds fusion groups for ACC nodes having non-tensor IO, - - builds a graph of direct dependencies, - - builds a map of fused nodes to their fusions. - As a result we get self.acc_nodes, self.deps and self.fusions. - """ - assert isinstance(module, pippy.fx.GraphModule) - - self.module = module - ShapeProp(self.module).propagate(*sample_input) - - self.settings = settings - self.operator_support = operator_support - self.sample_input = sample_input - self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)() - - if self.settings.skip_fusion: - self.fusions = {} - else: - self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() - - # Modify deps to add more deps for fused nodes - self.deps = self.find_deps() - self.update_deps_for_fusions() - - self.non_acc_submodule_name = non_acc_submodule_name - self._node_submodule_map: Dict[str, str] = {} - - # =============================================================== - # Helpers for ctor and initial state - # =============================================================== - - def get_node_submodule_map(self) -> Dict[str, str]: - """ Returns a map from node name to submodule name, e.g. - node: main_module_impl_impl_over_arch_unary_multiple_embedding - _pooling_embedding_pooling_sparse_entity_equivalence_key - _proxy_embedding_bag - maps to submodule name of: _run_on_acc_1 - """ - return self._node_submodule_map - - def find_deps(self) -> Dict[pippy.fx.Node, NodeSet]: - """ - Builds a graph of node dependencies. Leaf nodes don't have any - dependencies and the "output" node doesn't have nodes depending on it. - - Resulting graph has only direct dependencies, i.e. there are no - transitive dependencies. - """ - deps: Dict[pippy.fx.Node, NodeSet] = defaultdict(set) - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - for user in node.users: - if user.op != "output": - deps[user].add(node) - return deps - - def update_deps_for_fusions(self): - """ - Updates graph of dependencies so that: - - nodes from the same fusion depend on the same set of outer nodes, - - outer nodes depending on a fusion depend on all nodes in that fusion. - """ - for node in self.fusions: - fusion = self.fusions[node] - for fused_neighbor in fusion: - self.deps[node].update(self.deps[fused_neighbor] - fusion) - - for user in fused_neighbor.users: - if user not in fusion: - self.deps[user].add(node) - - # =============================================================== - # Helpers for preview - # =============================================================== - - def _lower_model_to_backend( - self, mod: pippy.fx.GraphModule, inputs: Tensors - ) -> torch.nn.Module: - """ - Lower the model to a backend. - """ - - return mod - - def _find_culprit( - self, mod: pippy.fx.GraphModule, inputs: Tensors - ) -> str: - """ - When an error occurs during lowering or running the lowered mod, we use this - function to find culprits in the `mod` that causes the error. - """ - - return "Unable to find a culprit because _find_culprit() function is not implemented." - - def _draw_graph_based_on_node_support( - self, mod: pippy.fx.GraphModule, supported_nodes: NodeList - ): - color_map = { - "default": "AliceBlue", - "supported": "chartreuse1", - "unsupported": "crimson", - } - - class CustomDrawer(FxGraphDrawer): - def _get_node_style(self, node): - template = super()._get_node_style(node) - if node in supported_nodes: - template["fillcolor"] = color_map["supported"] - elif node.op in CALLABLE_NODE_OPS: - template["fillcolor"] = color_map["unsupported"] - else: - template["fillcolor"] = color_map["default"] - - return template - - drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) - dot_graph = drawer.get_main_dot_graph() - dot_graph.write_raw("node_support.dot") - - def node_support_preview(self, dump_graph: bool = False): - submodules = dict(self.module.named_modules()) - - supported_nodes: NodeList = [] - supported_node_types = defaultdict(set) - unsupported_node_types = defaultdict(set) - - def get_dtype(arg): - tensor_meta = arg.meta.get("tensor_meta") - return getattr(tensor_meta, "dtype", None) - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - target = get_node_target(submodules, node) - - # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. - arg_dtypes = [ - get_dtype(arg) if isinstance(arg, pippy.fx.Node) else None - for arg in node.args - ] - - # Find last non-None element. If all elements are None, return max_len. - last_index = len(arg_dtypes) - next( - ( - i - for i, dtype in enumerate(reversed(arg_dtypes)) - if dtype is not None - ), - len(arg_dtypes), - ) - - # Strip None elements at the end. - arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) - kwarg_dtypes_tuple = tuple( - (k, get_dtype(arg)) - for k, arg in node.kwargs.items() - if isinstance(arg, pippy.fx.Node) - ) - - if self.operator_support.is_node_supported(submodules, node): - supported_nodes.append(node) - supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) - else: - unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) - - if dump_graph: - self._draw_graph_based_on_node_support(self.module, supported_nodes) - - reports = "\nSupported node types in the model:\n" - for t, dtypes in supported_node_types.items(): - for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: - reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" - - reports += "\nUnsupported node types in the model:\n" - for t, dtypes in unsupported_node_types.items(): - for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: - reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" - - print(reports) - - # Return reports for testing purpose - return reports - - def split_preview(self, dump_graph: bool = False): - reports = "" - subgraphs = self.put_nodes_into_subgraphs() - acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) - cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num - reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" - reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" - - subgraphs = self.remove_small_acc_subgraphs(subgraphs) - acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) - cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num - reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" - reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" - - for i, subgraph in enumerate(subgraphs): - reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: " - reports += f"{len(subgraph.nodes)} node(s)\n" - - self.tag(subgraphs) - split_mod = self.split(remove_tag=True) - split_mod.eval() - - if dump_graph: - drawer = FxGraphDrawer( - split_mod, "preview", ignore_getattr=True - ) - dot_graphs = drawer.get_all_dot_graphs() - for name, dot_graph in dot_graphs.items(): - dot_graph.write_raw(f"{name}.dot") - - max_qps: float = self.PCIe_BW - bottleneck_module = "" - - for node in split_mod.graph.nodes: - if node.op == "call_module" and "acc" in node.target: - reports += f"\nProcessing acc submodule {node.target}\n" - - submod = getattr(split_mod, node.target) - - def get_submod_inputs(main_mod, submod, example_inputs): - sub_inputs = None - - def get_inputs(self, inputs): - nonlocal sub_inputs - sub_inputs = inputs - - handle = submod.register_forward_pre_hook(get_inputs) - main_mod(*example_inputs) - handle.remove() - return sub_inputs - - submod_inputs = get_submod_inputs( - split_mod, submod, self.sample_input - ) - ShapeProp(submod).propagate(*submod_inputs) - - total_input_bytes = 0 - total_output_bytes = 0 - - reports += "Checking inputs...\n" - for n in submod.graph.nodes: - if n.op == "placeholder": - if not is_node_output_tensor(n): - reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" - else: - total_input_bytes += get_size_of_node(submod, n)[0] - if n.op == "output": - output_node = n - - reports += "Checking outputs...\n" - - def get_bytes(node: pippy.fx.Node): - nonlocal total_output_bytes - nonlocal reports - if not is_node_output_tensor(node): - reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" - else: - total_output_bytes += get_size_of_node(submod, node)[0] - - map_arg(output_node.args, get_bytes) - qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) - reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," - reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" - - if qps < max_qps: - max_qps = qps - bottleneck_module = node.target - - try: - lowered_submod = self._lower_model_to_backend(submod, submod_inputs) - except RuntimeError: - reports += "Run into an error during lowering!\n" - reports += self._find_culprit(submod, submod_inputs) - continue - - try: - lowered_submod(*submod_inputs) - except RuntimeError: - reports += "Run into an error during inference!\n" - reports += self._find_culprit(submod, submod_inputs) - else: - reports += "Lowering and running succeed!\n" - - reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," - reports += f" bottleneck is submodule {bottleneck_module}." - print(reports) - - # return the reports for testing purposes - return reports - - # =============================================================== - # Helpers for extend_acc_subgraph() method - # =============================================================== - - def find_reverse_deps( - self, tag_id: Optional[int] = None - ) -> Dict[pippy.fx.Node, NodeSet]: - """ - Builds reversed topological node dependencies, if tag_id is specified, - we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. - """ - result: Dict[pippy.fx.Node, NodeSet] = defaultdict(set) - - for node in self.module.graph.nodes: - if node.op not in CALLABLE_NODE_OPS: - continue - - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - - if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): - result[node].add(user) - - return result - - def update_reverse_deps_for_fusions( - self, deps: Dict[pippy.fx.Node, NodeSet] - ): - processed_node = set() - - for node, fusion in self.fusions.items(): - if node in processed_node: - continue - - new_dep = set() - - # Create a new dependency set which include all the - # dependencies of the nodes in the fusion group - for n in fusion: - new_dep.update(deps[n]) - - # Exclude nodes in the fusion - new_dep.difference_update(fusion) - - # Update dependency - for n in fusion: - deps[n] = new_dep - - for arg in n.all_input_nodes: - if arg not in fusion: - deps[arg].update(fusion) - - processed_node.add(n) - - def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: - """ - Finds parent nodes of the `tag` subgraph. - - Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph - and is not a placeholder, we consider it as the parent node of the subgraph. - """ - parent_nodes = set() - - for node in self.module.graph.nodes: - if node.op in CALLABLE_NODE_OPS and node.tag == tag: - for arg in node.all_input_nodes: - if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: - parent_nodes.add(arg) - - return parent_nodes - - def extend_acc_subgraph(self, tag: str): - """ - Extend the acc subgraph with `tag` going the reversed topological direction. - """ - # Dict that maps node to its users and ignore users that - # are in the subgraph that has greater tag - deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) - self.update_reverse_deps_for_fusions(deps) - - # Parent nodes of the subgraph - parent_nodes = self.find_parent_nodes_of_subgraph(tag) - - visited_nodes: NodeSet = set() - - while parent_nodes: - node = None - - # Find a acc node that depends on visited nodes only - for n in parent_nodes: - if deps[n] <= visited_nodes and n in self.acc_nodes: - node = n - break - - if node is None: - break - - # Put the node into `tag` subgraph - node.tag = tag # type: ignore[attr-defined] - parent_nodes.remove(node) - visited_nodes.add(node) - - # If node is in a fusion group, add all fusion buddies to parent nodes - if node in self.fusions: - for fusion_node in self.fusions[node]: - if fusion_node not in visited_nodes: - parent_nodes.add(fusion_node) - - # Add inputs of the node to parent nodes - for arg in node.all_input_nodes: - if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: - parent_nodes.add(arg) - - # =============================================================== - # Helpers for split() method - # =============================================================== - - def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: - """ - Finds nodes that consume module inputs or get_attr nodes. - """ - starter_cpu_nodes: NodeSet = set() - starter_acc_nodes: NodeSet = set() - for node in self.module.graph.nodes: - if node.op not in {"placeholder", "get_attr"}: - continue - for user in node.users: - if user in self.acc_nodes: - starter_acc_nodes.add(user) - else: - starter_cpu_nodes.add(user) - return starter_cpu_nodes, starter_acc_nodes - - def put_nodes_into_subgraphs(self) -> List[Subgraph]: - # We start graph traversal from leaf nodes - current_cpu_nodes, current_acc_nodes = self.starter_nodes() - visited_nodes: NodeSet = set() - - # Determine which subgraph to start from based on which subgraph has - # 0-dep node - acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes]) - - current_subgraph_nodes: NodeList = [] - - # Result accumulator - subgraphs: List[Subgraph] = [] - while current_cpu_nodes or current_acc_nodes: - # Find the first node that should belong to the current subgraph and has all dependencies resolved - current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes - node = next( - (n for n in current_nodes if self.deps[n] <= visited_nodes), - None, - ) - - # If nothing was found, then it's time to flip the mode and start a new subgraph - if node is None: - if not current_subgraph_nodes: - raise FxNetSplitterInternalError("Subgraph can't be empty") - - subgraphs.append( - Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) - ) - acc_subgraph = not acc_subgraph - current_subgraph_nodes = [] - continue - - current_nodes.remove(node) - visited_nodes.add(node) - current_subgraph_nodes.append(node) - - # Add fusion buddies - if node in self.fusions: - if node in self.acc_nodes: - current_acc_nodes.update(self.fusions[node] - visited_nodes) - else: - current_cpu_nodes.update(self.fusions[node] - visited_nodes) - - # Put depending nodes into the queue - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - - # Add downstream nodes - if user in self.acc_nodes: - current_acc_nodes.add(user) - else: - current_cpu_nodes.add(user) - - # Check if the last subgraph was not created - if current_subgraph_nodes: - subgraphs.append( - Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) - ) - - if not subgraphs: - raise FxNetSplitterInternalError("Couldn't create subgraphs") - - return subgraphs - - def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: - """ - This pass finds ACC submodules with less than specified size and merges - them with adjacent CPU submodules. - """ - result: List[Subgraph] = [] - for subgraph in subgraphs: - if subgraph.is_acc: - if len(subgraph.nodes) >= self.settings.min_acc_module_size: - result.append(subgraph) - else: - print( - "Eliminating acc subgraph because it's smaller than the threshold: " - f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" - ) - if result: - result[-1].nodes.extend(subgraph.nodes) - else: - subgraph.is_acc = False - result.append(subgraph) - else: - if result and not result[-1].is_acc: - result[-1].nodes.extend(subgraph.nodes) - else: - result.append(subgraph) - return result - - def tag(self, subgraphs: List[Subgraph]): - self.tags: List[str] = [] - for subgraph in subgraphs: - tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}" - self.tags.append(tag) - for node in subgraph.nodes: - if hasattr(node, "tag"): - raise FxNetSplitterInternalError(f"Node {node} was already tagged") - - node.tag = tag # type: ignore[attr-defined] - self._node_submodule_map[node.name] = tag - - def split(self, remove_tag: bool = False) -> pippy.fx.GraphModule: - split_module = split_by_tags(self.module, self.tags) - if remove_tag: - for node in self.module.graph.nodes: - if hasattr(node, "tag"): - del node.tag - return split_module - - def __call__(self) -> pippy.fx.GraphModule: - subgraphs = self.put_nodes_into_subgraphs() - subgraphs = self.remove_small_acc_subgraphs(subgraphs) - acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) - non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count - print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") - self.tag(subgraphs) - return self.split() - - def generate_split_results(self) -> SplitResult: - split_module = self() - submodule_names = [] - for name, mod in split_module.named_children(): - submodule_names.append(name) - submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) - return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/pippy/fx/passes/tests/__init__.py b/pippy/fx/passes/tests/__init__.py deleted file mode 100644 index f2661b8c6..000000000 --- a/pippy/fx/passes/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates diff --git a/pippy/fx/passes/tests/test_pass_manager.py b/pippy/fx/passes/tests/test_pass_manager.py deleted file mode 100644 index 34b325355..000000000 --- a/pippy/fx/passes/tests/test_pass_manager.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import unittest - -from ..pass_manager import ( - inplace_wrapper, - PassManager, - these_before_those_pass_constraint, - this_before_that_pass_constraint, -) - - -class TestPassManager(unittest.TestCase): - def test_pass_manager_builder(self) -> None: - passes = [lambda x: 2 * x for _ in range(10)] - pm = PassManager(passes) - pm.validate() - - def test_this_before_that_pass_constraint(self) -> None: - passes = [lambda x: 2 * x for _ in range(10)] - pm = PassManager(passes) - - # add unfulfillable constraint - pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) - - self.assertRaises(RuntimeError, pm.validate) - - def test_these_before_those_pass_constraint(self) -> None: - passes = [lambda x: 2 * x for _ in range(10)] - constraint = these_before_those_pass_constraint(passes[-1], passes[0]) - pm = PassManager( - [inplace_wrapper(p) for p in passes] - ) - - # add unfulfillable constraint - pm.add_constraint(constraint) - - self.assertRaises(RuntimeError, pm.validate) diff --git a/pippy/fx/passes/tools_common.py b/pippy/fx/passes/tools_common.py deleted file mode 100644 index 50a242c88..000000000 --- a/pippy/fx/passes/tools_common.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import List, Tuple, Union, Dict, Any, Set, Mapping -import collections -from dataclasses import dataclass - -import torch -import pippy.fx -from pippy.fx.node import _get_qualified_name -from pippy.fx._compatibility import compatibility - -__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] - -Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] -TensorOrTensors = Union[torch.Tensor, Tensors] -NodeList = List[pippy.fx.Node] -NodeSet = Set[pippy.fx.Node] -Names = List[str] -CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} - - -@compatibility(is_backward_compatible=False) -def get_acc_ops_name(k): - if isinstance(k, str): - return k - elif k.__module__ and "acc_ops" in k.__module__: - return f"acc_ops.{k.__name__}" - else: - module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module - return f"{module if module else ''}.{k.__name__}" - - -@compatibility(is_backward_compatible=False) -def get_node_target(submodules: Mapping[str, torch.nn.Module], node: pippy.fx.Node) -> str: - """ - Given a `node` returns its target typename. - - For "call_method" node, return node.target which is the name of that method being called. - This could potential lead to conflict but should be okay because normally it's on a tensor. - - For "call_function" node, return typename of node.target. - - For "call_module" node, return typename of the module that node.target point to. - - If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by - "torch". e.g. _VariableFunctionsClass.relu would become torch.relu. - """ - - assert node.op in CALLABLE_NODE_OPS, ( - "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}" - ) - - if node.op == "call_module": - assert isinstance(node.target, str) - submod = submodules[node.target] - submod_type = getattr(submod, "_base_class_origin", type(submod)) - return get_acc_ops_name(submod_type) - elif node.op == "call_function": - target: Any = node.target - return ( - f"acc_ops.{target.__name__}" - if target.__module__ is not None and "acc_ops" in target.__module__ - else _get_qualified_name(target) - ) - else: - assert isinstance(node.target, str) - return node.target - -@compatibility(is_backward_compatible=False) -def is_node_output_tensor(node: pippy.fx.Node) -> bool: - """Checks if the node output produces a Tensor or not. - - NOTE: This requires to run `ShapeProp` on the containing fx graph before - calling this function. This is because it works by checking the `type` - metadata on the node. This metadata is produced by the `ShapeProp`. - """ - type_ = node.meta.get("type", None) - return type_ is not None and issubclass(type_, torch.Tensor) - -@compatibility(is_backward_compatible=False) -class FxNetAccFusionsFinder: - """ - Finds groups of connected ACC nodes that pass non-tensor data between each other. - Such groups are called fusion groups. - """ - - def __init__(self, module: pippy.fx.GraphModule, acc_nodes: NodeSet): - self.module = module - self.nodes = list(module.graph.nodes) - self.acc_nodes = acc_nodes - - @dataclass - class FusionGroup: - # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model. - top_node_idx: int - - # Nodes in this fusion group. - nodes: NodeSet - - # Inputs to this fusion group. - inputs: NodeSet - - # Nodes that in the fusion group that haven't been processed yet. - nodes_need_process: NodeSet - - def add_node(self, node): - """ - Add a node to fusion group. - """ - if node in self.nodes: - return - - self.nodes_need_process.add(node) - self.nodes.add(node) - self.inputs.discard(node) - self.inputs.update( - { - n - for n in node.all_input_nodes - if n.op in CALLABLE_NODE_OPS and n not in self.nodes - } - ) - - def recursive_add_node( - self, - fusion_group: "FxNetAccFusionsFinder.FusionGroup", - inputs: Union[NodeSet, NodeList], - ): - """ - Start from inputs and going reverse topological order. If any upstream node - is in the fusion group, add all the nodes in this path to fusion group. - """ - for arg in inputs: - # Skip placeholder and get_attr because they won't be in the fusion group. - if arg.op not in CALLABLE_NODE_OPS: - continue - - # If the node has smaller idx, it's already an upstream node of the fusion - # group. We don't need to check it anymore. - if self.nodes.index(arg) < fusion_group.top_node_idx: - continue - - # If the node is in the fusion group, return True. - if arg in fusion_group.nodes: - return True - - # Check the upstream nodes of the node, if any of them is in the fusion group - # we'll add this node to fusion group and return True. - if self.recursive_add_node(fusion_group, arg.all_input_nodes): - fusion_group.add_node(arg) - return True - - return False - - def __call__(self) -> Dict[pippy.fx.Node, NodeSet]: - result: Dict[pippy.fx.Node, NodeSet] = {} - acc_nodes = list(self.acc_nodes) - - for node in acc_nodes: - if node in result: - continue - if node.op not in CALLABLE_NODE_OPS: - continue - if "tensor_meta" in node.meta: - continue - if node not in self.acc_nodes: - continue - - fusion_group: "FxNetAccFusionsFinder.FusionGroup" = self.FusionGroup( - top_node_idx=self.nodes.index(node), - nodes={node}, - inputs=set(node.all_input_nodes), - nodes_need_process={node}, - ) - while fusion_group.nodes_need_process: - node = fusion_group.nodes_need_process.pop() - self.recursive_add_node(fusion_group, fusion_group.inputs) - - # Optionally add downstream nodes - if "tensor_meta" not in node.meta: - for user in node.users: - if user.op not in CALLABLE_NODE_OPS: - continue - if user in fusion_group.nodes: - continue - - fusion_group.add_node(user) - self.recursive_add_node(fusion_group, fusion_group.inputs) - - # Add some upstream nodes - for arg in node.all_input_nodes: - if arg.op not in CALLABLE_NODE_OPS: - continue - if "tensor_meta" in arg.meta: - continue - if arg in fusion_group.nodes: - continue - - fusion_group.add_node(arg) - fusion_group.top_node_idx = min( - fusion_group.top_node_idx, self.nodes.index(arg) - ) - self.recursive_add_node(fusion_group, fusion_group.inputs) - - if not (set(fusion_group.nodes) <= self.acc_nodes): - self.acc_nodes -= fusion_group.nodes - else: - for n in fusion_group.nodes: - result[n] = fusion_group.nodes - - return result - - -@compatibility(is_backward_compatible=False) -def legalize_graph(gm: pippy.fx.GraphModule) -> pippy.fx.GraphModule: - """ - Replace the graph of the given GraphModule with one that contains the same nodes as the - original, but in topologically sorted order. - - This is used by the merge_matmul transformation below, which disturbs the topologically sorted - order of its input GraphModule, so that this order is restored before further transformation. - - Arguments: - gm: The graph module to topologically sort. It is modified in-place. - - Returns: - The graph module in-place sorted - """ - indeg = {node: 0 for node in gm.graph.nodes} - new_graph = pippy.fx.Graph() - # Track how many unfulfilled dependencies each node has - for node in gm.graph.nodes: - for user in node.users: - indeg[user] += 1 - queue: collections.deque = collections.deque() - # Add all nodes with no dependencies to the queue - for node in gm.graph.nodes: - if indeg[node] == 0: - queue.append(node) - env: Dict[pippy.fx.Node, pippy.fx.Node] = {} - # Pop nodes from the queue, and add nodes that have had all their - # dependencies fulfilled - while len(queue) > 0: - cur = queue.popleft() - env[cur] = new_graph.node_copy(cur, lambda x: env[x]) - for user in cur.users: - indeg[user] -= 1 - if indeg[user] == 0: - queue.append(user) - # If the new graph's size is not as large as the old one, then there must be - # a cycle (i.e. some node's dependencies were not satisfied.) - if len(new_graph.nodes) < len(gm.graph.nodes): - raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}") - gm.graph = new_graph - return gm diff --git a/pippy/fx/passes/utils/__init__.py b/pippy/fx/passes/utils/__init__.py deleted file mode 100644 index e4af20899..000000000 --- a/pippy/fx/passes/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .common import lift_subgraph_as_module, HolderModule diff --git a/pippy/fx/passes/utils/common.py b/pippy/fx/passes/utils/common.py deleted file mode 100644 index 972848347..000000000 --- a/pippy/fx/passes/utils/common.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from torch.nn import Module - -from pippy.fx.graph_module import GraphModule -from pippy.fx.graph import Graph -from pippy.fx.passes.utils.matcher_utils import SubgraphMatcher -from pippy.fx._compatibility import compatibility - - -__all__ = ['HolderModule', 'lift_subgraph_as_module', 'compare_graphs'] - -@compatibility(is_backward_compatible=False) -class HolderModule(Module): - """ - HolderModule is used to copy all the attributes from original module to submodules - that uses the attributes - """ - - def __init__(self, d): - super().__init__() - for k, v in d.items(): - self.add_module(k, v) - - -@compatibility(is_backward_compatible=False) -def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str = 'GraphModule') -> GraphModule: - """ - Create a GraphModule for subgraph, which copies the necessory attributes from the original parent graph_module. - - Args: - gm (GraphModule): parent graph module - - subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph - - class_name (str): name for the submodule - - """ - - # Loop through all module calls (call_module) and param fetches (get_attr) - # in this component, creating HolderModules as necessary to match the path. - # e.g. if in the original module there's a get_attr node fetches "conv.weight". - # We create a HolderModule as root -> add a HolderModule named "conv" -> - # make "weight" a attribute of "conv" HolderModule and point to conv.weight in - # the original module. - submodule = HolderModule({}) - for n in subgraph.nodes: - if n.op not in ("call_module", "get_attr"): - continue - - target = n.target - assert isinstance(target, str) - target_name_parts = target.split(".") - curr = submodule - orig_gm = gm - - for name in target_name_parts[:-1]: - if not hasattr(curr, name): - curr.add_module(name, HolderModule({})) - - curr = getattr(curr, name) - orig_gm = getattr(orig_gm, name) - - leaf_node_name = target_name_parts[-1] - leaf_node = getattr(orig_gm, leaf_node_name) - - # Relies on custom __setattr__ magic. - setattr(curr, leaf_node_name, leaf_node) - - return GraphModule(submodule, subgraph, class_name) - - -@compatibility(is_backward_compatible=False) -def compare_graphs(left: Graph, right: Graph) -> bool: - """ - Return True if two graphs are identical, i.e they - - have the same number of outputs in the same order - - have the same number of inputs in the same order - - have the same set of nodes, and identical connectivity - """ - - matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) - matches = matcher.match(right) - - return len(matches) > 0 diff --git a/pippy/fx/passes/utils/fuser_utils.py b/pippy/fx/passes/utils/fuser_utils.py deleted file mode 100644 index 270739078..000000000 --- a/pippy/fx/passes/utils/fuser_utils.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import copy -from queue import SimpleQueue -from typing import List, Dict, Tuple - -import pippy.fx -from pippy.fx.graph_module import GraphModule -from pippy.fx.graph import Graph -from pippy.fx.node import Node -from pippy.fx.passes.tools_common import NodeList, NodeSet, legalize_graph -from pippy.fx.passes.utils import lift_subgraph_as_module - -def topo_sort(nodes: NodeList) -> NodeList: - # sort nodes according to the topological order - indegree_map = {node : 0 for node in nodes} - candidates: SimpleQueue = SimpleQueue() - - for node in nodes: - for n in node.all_input_nodes: - if n in indegree_map: - indegree_map[node] += 1 - if indegree_map[node] == 0: - candidates.put(node) - - sorted_nodes: NodeList = list() - while not candidates.empty(): - node = candidates.get() - sorted_nodes.append(node) - - for n in node.users: - if n in indegree_map: - indegree_map[n] -= 1 - if indegree_map[n] == 0: - candidates.put(n) - - assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" - - return sorted_nodes - - -def validate_partition(partition: NodeList) -> bool: - # verify the partition does't form a dependency cycle in the original graph - # returns True for valid partition, False for invalid - - partition_set = set(partition) - - outputs: NodeList = list() - for node in partition_set: - for user_node in node.users: - if user_node not in partition_set: - # external user node, need to expose as an output - outputs.append(user_node) - - # perform DFS on the parition outputs - # if it reaches a node within the partition, then it found a cycle - visited: NodeSet = set() - - def dfs_find_cycle(node): - if node in partition_set: - return True # found cycle, return - - visited.add(node) - for user_node in node.users: - if user_node not in visited: - if dfs_find_cycle(user_node): - return True - return False - - for output_node in outputs: - if dfs_find_cycle(output_node): - return False - - return True - - -def fuse_as_graphmodule(gm: GraphModule, - nodes: NodeList, - module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: - - """ - Fuse nodes in graph_module into a GraphModule. - - Args: - gm (GraphModule): target graph_module - - nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted - - module_name: class name for the fused GraphModule - - Returns: - fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` - - original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` - - original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` - - """ - - # assumption: nodes are already sorted in topo order - - for node in nodes: - assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" - assert not node._erased, f"{node} has been removed from owning graph" - assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" - - # validates partition doesn't introduce dependency circles in the graph - assert validate_partition(nodes), "Invalid partition, found dependency cycles" - - subgraph = Graph() - - node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph - node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph - - # handles inputs throught graph.node_copy's arg_transform functions - def remap_inputs(x): - if x.op == "get_attr": - # TODO: do we really need copy the get_attr node into the graph? - # do something here - pass - - if x in nodes: - # x is inside subgraph, return the copied node - # the node should have been copied aleady, as we are copying graph in the topological order - return node_map[x] - - if x not in node_to_placeholder: - # x is not in subgraph, create a new placeholder for subgraph - placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) - # copy all meta fields, even if some fields might be irrelvant for the placeholder node - placeholder_node.meta = copy.copy(x.meta) - node_to_placeholder[x] = placeholder_node - - return node_to_placeholder[x] - - # copy nodes in topological order - for node in nodes: - new_node = subgraph.node_copy(node, remap_inputs) - node_map[node] = new_node - - # handles outputs - output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs - - for node in nodes: - for user_node in node.users: - if user_node not in nodes: - # external user node, need to expose as an output - output_mapping[node] = node_map[node] - - # outs contain nodes in the new subgraph - outs = tuple(output_mapping.values()) - - # Take care of the args of FX output node. If there's a single - # output then the output node args is like (output_single), else - # if there're multiple outputs then the output node args is like - # ((output_0, output_1, ...)). - subgraph.output(outs[0] if len(outs) == 1 else outs) - - # lint to ensure correctness - subgraph.lint() - - fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name) - - # sub_gm's input nodes in the original module - original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) - - # sub_gm's outputs node in the original module - original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys()) - - return fused_gm, original_inputs, original_outputs - - -def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): - # add sub_gm into gm - submodule_name = sub_gm.__class__.__name__ - gm.add_submodule(submodule_name, sub_gm) - - # Create a call_module node in main graph. - module_node = gm.graph.call_module( - submodule_name, - args=orig_inputs, - kwargs=None) - - if len(orig_outputs) == 1: - # main_remapping[comp.orig_outputs[0]] = module_node - orig_outputs[0].replace_all_uses_with(module_node) - else: - for i, orig_output in enumerate(orig_outputs): - # Use Proxy to record getitem access. - proxy_out = pippy.fx.Proxy(module_node)[i].node # type: ignore[index] - orig_output.replace_all_uses_with(proxy_out) - return gm - -def erase_nodes(gm: GraphModule, nodes: NodeList): - - # erase original nodes in inversed topological order - for node in reversed(nodes): - gm.graph.erase_node(node) - - -def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule: - for partition_id, nodes in enumerate(partitions): - sorted_nodes = topo_sort(nodes) - - submodule_name = "fused_" + str(partition_id) - sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) - - insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) - - erase_nodes(gm, sorted_nodes) - - # topological sort original gm with newly created sub_gm - legalize_graph(gm) - - return gm diff --git a/pippy/fx/passes/utils/matcher_utils.py b/pippy/fx/passes/utils/matcher_utils.py deleted file mode 100644 index 27bb9240a..000000000 --- a/pippy/fx/passes/utils/matcher_utils.py +++ /dev/null @@ -1,309 +0,0 @@ -from dataclasses import dataclass, field -from collections import defaultdict -import copy -from pippy.fx.graph import Graph -from pippy.fx.node import Node -from pippy.fx._compatibility import compatibility -import torch.utils._pytree as pytree -from typing import Dict, List, Set, Any -import os -import logging - -__all__ = ['SubgraphMatcher', 'InternalMatch'] - -format_str = "%(levelname)s > %(message)s" -LOGLEVEL = os.environ.get('LOGLEVEL', 'WARNING').upper() -logging.basicConfig(level=LOGLEVEL, format=format_str) -logger = logging.getLogger(__name__) - -@compatibility(is_backward_compatible=False) -@dataclass -class InternalMatch(): - # Nodes from which the match was found - anchors: List[Node] - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] = field(default_factory=dict) - - # nodes in target graph that are matched placeholder in pattern - placeholder_nodes: List[Node] = field(default_factory=list) - - # nodes in matched subgraph returned by output - returning_nodes: List[Node] = field(default_factory=list) - - def __copy__(self): - return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(), - placeholder_nodes=self.placeholder_nodes.copy(), - returning_nodes=self.returning_nodes.copy()) - -@compatibility(is_backward_compatible=False) -class SubgraphMatcher: - def __init__(self, pattern: Graph, - match_output: bool = False, - match_placeholder: bool = False, - remove_overlapping_matches: bool = True) -> None: - """ - Args: - pattern: the targeted matching pattern, represented in fx.Graph. - match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. - If False, output node is ignored during match. - match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of - the targeted pattern. If False, placeholder nodes will be used a wildcard. - remove_overlapping_matches: If True, in the case of overlapping matches, only the first match - will be returned. - """ - - self.pattern = pattern - self.match_output = match_output - self.match_placeholder = match_placeholder - self.remove_overlapping_matches = remove_overlapping_matches - - if len(pattern.nodes) == 0: - raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern") - - for node in pattern.nodes: - if node.op != "output": - assert len(node.users) > 0, \ - "SubgraphMatcher cannot be initialized with an pattern with dead code" - - # TODO: assert pattern is a connected graph - - self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"] - output_node = next(iter(reversed(pattern.nodes))) - # nodes returned by outputs - self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes - - self.pattern_anchors: List[Node] = [] - if match_output: - self.pattern_anchors = [output_node] - else: - # If a node has output_node as the ONLY user, then this node is a graph sink, - # and should be matched against as an anchor - self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1] - - def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: - # if exact match for placeholder is not required, then use placeholder as a wildcard - if not self.match_placeholder and pn.op == "placeholder": - return True - - if pn.op == gn.op: - if pn.op == "placeholder" or pn.op == "output": - return True - return pn.target == gn.target - return False - - def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool: - # `lookup` represents all the nodes in `original_graph` - # that are part of `pattern` - lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items()} - for gn, pn in lookup.items(): - # Placeholders can be used by other nodes in the graphs - if pn.op == "placeholder": - continue - - # nodes returned by output are allowed to be used in other areas of the graph - if pn in self.pattern_returning_nodes: - continue - - for user in gn.users: - # If this node has users that were not in `lookup`, then it must leak out of the - # pattern subgraph - if user not in lookup: - return False - return True - - def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]: - non_overlapping_matches: List[InternalMatch] = list() - nodes_matched: Set[Node] = set() - - for match in matches: - found_overlap = False - for pn, gn in match.nodes_map.items(): - if pn.op not in {"placeholder", "output"} and gn in nodes_matched: - found_overlap = True - break - - if not found_overlap: - non_overlapping_matches.append(match) - for pn, gn in match.nodes_map.items(): - if pn.op not in {"placeholder", "output"}: - nodes_matched.add(gn) - return non_overlapping_matches - - def _match_args(self, pn: Any, gn: Any, match: InternalMatch) -> bool: - assert not(isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node" - - if isinstance(pn, Node) and not isinstance(gn, Node): - if pn.op == "placeholder": - # Check if we've already matched these nodes in the current - # traversal - if pn in match.nodes_map: - return match.nodes_map[pn] == gn - - match.nodes_map[pn] = gn - return True - else: - return False - elif not isinstance(pn, Node) and isinstance(gn, Node): - return False - else: - return type(gn) == type(pn) and gn == pn - - def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: - logger.info(f" matching {pn} to {gn}") - - assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}") - - # Check if we've already matched these nodes in the current - # traversal - if pn in match.nodes_map: - return match.nodes_map[pn] == gn - - # TODO: use a more efficienty way to check if gn is matched before: two-way dict - if gn in match.nodes_map.values(): - return False - - if not self._nodes_are_equal(pn, gn): - return False - - # Optimistically mark `pn` as a match for `gn`, and save a local copy of match - saved_match = copy.copy(match) - match.nodes_map[pn] = gn - - if pn.op == "placeholder": - return True - - # Recursively traverse upwards to check if `pn` is a true - # match for `gn` - match_found = True - - pn_flatten_args, _ = pytree.tree_flatten(pn.args) - gn_flatten_args, _ = pytree.tree_flatten(gn.args) - - if pn.kwargs.keys() == gn.kwargs.keys(): - for key in pn.kwargs.keys(): - pn_flatten_args.append(pn.kwargs[key]) - gn_flatten_args.append(gn.kwargs[key]) - else: - match_found = False - - if match_found and len(pn_flatten_args) == len(gn_flatten_args): - for pn_, gn_ in zip(pn_flatten_args, gn_flatten_args): - if isinstance(gn_, Node) and isinstance(pn_, Node): - matched = self._match_nodes(pn_, gn_, match) - else: - matched = self._match_args(pn_, gn_, match) - - if not matched: - match_found = False - break - else: - match_found = False - - if not match_found: - # revert to saved_match before matching with current node - match = copy.copy(saved_match) - return False - - return True - - def match(self, graph: Graph) -> List[InternalMatch]: - """ - Returns: - The matched subgraphs. - Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder - and nodes returned by output) can only be consumed by nodes within the matched subgraph. - - Subgraph pattern matcher is implemented with the backtracking style in the following steps: - - 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes - are the "sinks" (nodes with no user other than the output node) of the pattern graph. - One pattern graph could have multiple anchors if it has multiple return values. - - 2. In the target graph, we identify the potential candidate nodes that can be matched - with each anchor. These anchor-candidate pairs are the starting points for - pairwise per-node matching. - - 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both - pattern and target graphs. For every pattern nodes along traversal path, we compare it - against the target nodes. In case any comparison failed, the match for this anchor-candidate - pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` - for more details. - - 4. In the case of multiple anchors, every anchor will need to find a match using step 3. - In addition, the matches found between anchors need to have a common intersection node - in order for the match to be valid. This is implemented with backtracking. See `backtracking` - for more details. - - Notice: graph traversal must be done in the reverser order because a tensor can have multiple - consumers, but can only have a single producer. Only with reverser order, we can we jointly - traverse the pattern and target graph in a deterministic path. - - Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, - in practice, it's unlikely to blow up. - - """ - from pippy.fx.passes.utils.fuser_utils import validate_partition - - # find candidate nodes to match with pattern anchors - match_candidates: Dict[Node, List[Node]] = defaultdict(list) - for pattern_anchor in self.pattern_anchors: - for node in graph.nodes: - if self._nodes_are_equal(pattern_anchor, node): - match_candidates[pattern_anchor].append(node) - match_candidates_list = list(match_candidates.items()) - matches: List[InternalMatch] = [] - - def backtracking(anchor_index, match): - if anchor_index == len(match_candidates_list): - match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes] - match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes] - matches.append(match) - - logger.info(f"Found a match: {match}\n") - return - - pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] - saved_match = copy.copy(match) - - for node in candidate_nodes: - logger.info(f"Trying to match anchor {pattern_anchor} to {node}") - - match_found = self._match_nodes(pattern_anchor, node, match) - if match_found: - # match next anchor - backtracking(anchor_index + 1, match) - else: - logger.info(f"Failed to match anchor {pattern_anchor} to {node}\n") - - # revert to saved_match before matching with current anchor - match = copy.copy(saved_match) - - match = InternalMatch(anchors=self.pattern_anchors) - backtracking(0, match) - - # filter out the matches where the subgraph is not fully_contained - before = len(matches) - matches = [match for match in matches if self._is_contained(match.nodes_map)] - after = len(matches) - if before != after: - logger.info(f"Filtered out {before - after} matches because they are not fully contained") - - # filter out the matches that that forms a cycle if the subgraph is fused - valid_matches = [] - for match in matches: - matched_compute_nodes = \ - [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}] - if validate_partition(matched_compute_nodes): - valid_matches.append(match) - if len(valid_matches) != len(matches): - logger.info(f"Filtered out {len(matches) - len(valid_matches)} matches because \ - matched subgraph would form a cycle if fused") - - if self.remove_overlapping_matches: - before = len(valid_matches) - matches = self._remove_overlapping_matches(valid_matches) - after = len(matches) - if before != after: - logger.info(f"Filtered out {before - after} matches because matched subgraphs are overlapping") - - return matches diff --git a/pippy/fx/proxy.py b/pippy/fx/proxy.py deleted file mode 100644 index 73ec01089..000000000 --- a/pippy/fx/proxy.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import dis -import inspect -import operator -import traceback -from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable - -import torch - -import pippy.fx.traceback as fx_traceback -from ._compatibility import compatibility -from .graph import magic_methods, reflectable_magic_methods, Graph -from .node import Target, Node, Argument, base_types, map_aggregate -from .operator_schemas import check_for_mutable_operation - -__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', 'Proxy', 'Attribute', 'ParameterProxy'] - -@compatibility(is_backward_compatible=True) -class TracerBase: - graph: Graph - record_stack_traces : bool = False - # Feature flag for mutable schema checking - # Enableby default in 1.12 - check_mutable_operations : bool = False - # Feature flag for assert tracing - trace_asserts : bool = False - # Feature flag for proxying accesses to buffer values - proxy_buffer_attributes : bool = False - - # Name of the function to be traced. It will only be used when - # ``root`` is an instance of ``nn.Module`` - traced_func_name: str = "forward" - - @compatibility(is_backward_compatible=True) - def create_node(self, kind : str, target : Target, - args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: - """ - Inserts a graph node given target, args, kwargs, and name. - - This method can be overridden to do extra checking, validation, or - modification of values used in node creation. For example, one might - want to disallow in-place operations from being recorded. - """ - if kind == 'call_function' and self.check_mutable_operations: - check_for_mutable_operation(target, args, kwargs) - - return self.graph.create_node(kind, target, args, kwargs, name, type_expr) - - @compatibility(is_backward_compatible=True) - def proxy(self, node: Node) -> 'Proxy': - return Proxy(node, self) - - @compatibility(is_backward_compatible=True) - def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], - name: Optional[str] = None, type_expr : Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - ''' - Create a Node from the given arguments, then return the Node - wrapped in a Proxy object. - - If kind = 'placeholder', then we're creating a Node that - represents the parameter of a function. If we need to encode - a default parameter, we use the ``args`` tuple. ``args`` is - otherwise empty for ``placeholder`` Nodes. - ''' - - args_ = self.create_arg(args) - kwargs_ = self.create_arg(kwargs) - assert isinstance(args_, tuple) - assert isinstance(kwargs_, dict) - - node = self.create_node(kind, target, args_, kwargs_, name, type_expr) - - if not proxy_factory_fn: - proxy = self.proxy(node) - else: - proxy = proxy_factory_fn(node) - - # Optionally set stack trace on the created Node for debugging purposes - if fx_traceback.is_stack_trace_overridden(): - stacks = fx_traceback.format_stack() - proxy.node.stack_trace = '\n'.join(reversed(stacks)) - elif self.record_stack_traces: - user_frame = self._find_user_frame() - if user_frame: - walk_stack_gen = traceback.walk_stack(user_frame) - summary = traceback.StackSummary.extract(walk_stack_gen) # type: ignore[arg-type] - tb_lines = summary.format() - proxy.node.stack_trace = ''.join(tb_lines) - - return proxy - - def _find_user_frame(self): - """ - Find the Python stack frame executing the user code during - symbolic tracing. - """ - # We have to do a little dance here. Basically, walk up the callstack and - # record the first frame not in the pytorch source. This is the frame executing - # the user code during tracing. - frame = inspect.currentframe() - - pt_files = ['torch/fx/proxy.py', - 'torch/fx/_symbolic_trace.py', - 'torch/fx/experimental/proxy_tensor.py', - 'torch/_ops.py', - 'torch/_tensor.py', - 'torch/utils/_python_dispatch.py', - 'torch/_prims_common/wrappers.py', - 'torch/_refs/__init__.py', - 'torch/_refs/nn/functional/__init__.py' - ] - while frame: - frame = frame.f_back - if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): - break - - if not frame: - return None - - return frame - - @compatibility(is_backward_compatible=True) - def create_arg(self, a: Any) -> Argument: - """ - A method that lowers the objects seen as arguments during symbolic evaluation - into Argument types that can be stored in IR. - - Can be override to support more trace-specific types. - """ - if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): - return a.__fx_create_arg__(self) - # aggregates - elif isinstance(a, tuple) and hasattr(a, '_fields'): - # NamedTuple constructors don't seem to like getting a generator - # expression as an argument to their constructor, so build this - # intermediate tuple and unpack it into the NamedTuple constructor - args = tuple(self.create_arg(elem) for elem in a) - return type(a)(*args) # type: ignore[arg-type] - elif isinstance(a, (tuple, list)): - return type(a)(self.create_arg(elem) for elem in a) - elif isinstance(a, dict): - r = {} - for k, v in a.items(): - # Check for invalid dict keys. We do not want a Proxy to appear - # anywhere within the key. Since keys can be collection types, - # we iterate through the key with map_aggregate - k = self.create_arg(k) - - def no_node(arg): - if isinstance(arg, Node): - raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " - "Node. Got key: {k}") - map_aggregate(k, no_node) - - r[k] = self.create_arg(v) - return r - elif isinstance(a, slice): - return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) - - if isinstance(a, Proxy): - # base case: we unwrap the Proxy object - return a.node - elif isinstance(a, base_types) or a is None or a is ...: - return a - raise NotImplementedError(f"argument of type: {type(a)}") - - @compatibility(is_backward_compatible=True) - def to_bool(self, obj: 'Proxy') -> bool: - """Called when a proxy object is being converted to a boolean, such as - when used in control flow. Normally we don't know what to do because - we don't know the value of the proxy, but a custom tracer can attach more - information to the graph node using create_node and can choose to return a value. - """ - raise TraceError('symbolically traced variables cannot be used as inputs to control flow') - - @compatibility(is_backward_compatible=True) - def iter(self, obj: 'Proxy') -> Iterator: - """Called when a proxy object is being iterated over, such as - when used in control flow. Normally we don't know what to do because - we don't know the value of the proxy, but a custom tracer can attach more - information to the graph node using create_node and can choose to return an iterator. - """ - raise TraceError('Proxy object cannot be iterated. This can be ' - 'attempted when the Proxy is used in a loop or' - ' as a *args or **kwargs function argument. ' - 'See the pippy.fx docs on pytorch.org for a ' - 'more detailed explanation of what types of ' - 'control flow can be traced, and check out the' - ' Proxy docstring for help troubleshooting ' - 'Proxy iteration errors') - - @compatibility(is_backward_compatible=True) - def keys(self, obj: 'Proxy') -> Any: - """Called when a proxy object is has the keys() method called. - This is what happens when ** is called on a proxy. This should return an - iterator it ** is suppose to work in your custom tracer. - """ - return Attribute(obj, 'keys')() - - -# used in Proxy object when just appending to the graph while not tracing. -@compatibility(is_backward_compatible=True) -class GraphAppendingTracer(TracerBase): - def __init__(self, graph: Graph): - super().__init__() - self.graph = graph - -@compatibility(is_backward_compatible=False) -def assert_fn(x): - assert x - -@compatibility(is_backward_compatible=True) -class TraceError(ValueError): - pass - -@compatibility(is_backward_compatible=True) -class Proxy: - """ - ``Proxy`` objects are ``Node`` wrappers that flow through the - program during symbolic tracing and record all the operations - (``torch`` function calls, method calls, operators) that they touch - into the growing FX Graph. - - If you're doing graph transforms, you can wrap your own ``Proxy`` - method around a raw ``Node`` so that you can use the overloaded - operators to add additional things to a ``Graph``. - - ``Proxy`` objects cannot be iterated. In other words, the symbolic - tracer will throw an error if a ``Proxy`` is used in a loop or as - an ``*args``/``**kwargs`` function argument. - - There are two main ways around this: - 1. Factor out the untraceable logic into a top-level function and - use ``fx.wrap`` on it. - 2. If the control flow is static (i.e. the loop trip count is - based on some hyperparameter), the code can be kept in its original - position and refactored into something like:: - - for i in range(self.some_hyperparameter): - indexed_item = proxied_value[i] - - For a more detailed description into the Proxy internals, check out - the "Proxy" section in `torch/fx/OVERVIEW.md` - """ - - @compatibility(is_backward_compatible=True) - def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): - if tracer is None: - # This allows you to create a Proxy object around a raw Node - tracer = GraphAppendingTracer(node.graph) - self.tracer = tracer - self.node = node - - def __repr__(self) -> str: - return f'Proxy({self.node.name})' - - def __getattr__(self, k) -> 'Attribute': - # note: not added to the graph yet, if this is a method call - # we peephole optimize to the method invocation - return Attribute(self, k) - - def __call__(self, *args, **kwargs) -> 'Proxy': - return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) - - def __iter__(self) -> Iterable['Proxy']: - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - inst = list(dis.get_instructions(calling_frame.f_code))[calling_frame.f_lasti // 2] - if inst.opname == 'UNPACK_SEQUENCE': - return (self[i] for i in range(inst.argval)) # type: ignore[index] - - return self.tracer.iter(self) - - def __bool__(self) -> bool: - if self.tracer.trace_asserts: - # check if this boolean is used in an assertion, bytecode pattern for assertions - # is pretty stable for Python 3.7--3.9 - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - insts = list(dis.get_instructions(calling_frame.f_code)) - cur = calling_frame.f_lasti // 2 - inst = insts[cur] - - if inst.opname == 'POP_JUMP_IF_TRUE': - first = insts[cur + 1] - assert inst.arg is not None - last = insts[inst.arg // 2 - 1] - starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' - or first.opname == 'LOAD_ASSERTION_ERROR') - if starts_with_assert and last.opname == 'RAISE_VARARGS': - self.tracer.create_proxy('call_function', assert_fn, (self,), {}) - return True - - return self.tracer.to_bool(self) - - @compatibility(is_backward_compatible=True) - def keys(self): - return self.tracer.keys(self) - - def __len__(self): - raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " - "this call to be recorded, please call pippy.fx.wrap('len') at " - "module scope") - - @classmethod - def __torch_function__(cls, orig_method, types, args=None, kwargs=None): - args = args if args else () - kwargs = kwargs if kwargs else {} - - tracers : Dict[Any, None] = {} - - def find_tracer(a): - if isinstance(a, cls): - tracers[a.tracer] = None - map_aggregate(args, find_tracer) - map_aggregate(kwargs, find_tracer) - - if len(tracers) > 1: - raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' - f'trying to trace operations {orig_method}') - tracer = next(iter(tracers.keys())) - - if isinstance(orig_method, torch._C.ScriptMethod): - args = (orig_method.owner,) + args - return tracer.create_proxy('call_method', orig_method.name, args, kwargs) - if torch.overrides.is_tensor_method_or_property(orig_method): - return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) - else: - return tracer.create_proxy('call_function', orig_method, args, kwargs, - name=tracer.graph._target_to_str(orig_method.__name__)) - - -@compatibility(is_backward_compatible=True) -class Attribute(Proxy): - @compatibility(is_backward_compatible=True) - def __init__(self, root: Proxy, attr: str): - self.root = root - self.attr = attr - self.tracer = root.tracer - self._node: Optional[Node] = None - - @property - def node(self): - # the node for attributes is added lazily, since most will just be method calls - # which do not rely on the getitem call - if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node - return self._node - - def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) - - -@compatibility(is_backward_compatible=False) -class ParameterProxy(Proxy): - """ - A special proxy which lets "shape", "size", "dim", and a few other - attribute accesses pass through to the underlying module parameter object, - so that conditional tests on these attributes will not throw exception during tracing - """ - def __init__(self, tracer: TracerBase, node: Node, name, param): - super().__init__(node, tracer) - assert(isinstance(param, torch.nn.Parameter)) - self.param = param - self.name = name - - def __repr__(self) -> str: - return f'ParameterProxy({self.name})' - - @property - def shape(self): - return self.param.shape - - def size(self): - return self.param.size() - - def dim(self): - return self.param.dim() - - @property - def ndim(self): - return self.param.ndim - - def numel(self): - return self.param.numel() - - def nelement(self): - return self.param.nelement() - - -for method in magic_methods: - def _scope(method): - def impl(*args, **kwargs): - tracer = args[0].tracer - target = getattr(operator, method) - return tracer.create_proxy('call_function', target, args, kwargs) - impl.__name__ = method - as_magic = f'__{method.strip("_")}__' - setattr(Proxy, as_magic, impl) - _scope(method) - -def _define_reflectable(orig_method_name): - method_name = f'__r{orig_method_name.strip("_")}__' - - def impl(self, rhs): - target = getattr(operator, orig_method_name) - return self.tracer.create_proxy('call_function', target, (rhs, self), {}) - impl.__name__ = method_name - impl.__qualname__ = method_name - setattr(Proxy, method_name, impl) - -for orig_method_name in reflectable_magic_methods: - _define_reflectable(orig_method_name) diff --git a/pippy/fx/subgraph_rewriter.py b/pippy/fx/subgraph_rewriter.py deleted file mode 100644 index 42620e05c..000000000 --- a/pippy/fx/subgraph_rewriter.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from .graph_module import GraphModule -from .graph import Graph -from .node import Node -from ._symbolic_trace import symbolic_trace -from ._compatibility import compatibility - -import copy -from typing import Callable, Dict, List, NamedTuple, Optional, Set -import torch - -__all__ = ['Match', 'replace_pattern'] - -@compatibility(is_backward_compatible=True) -class Match(NamedTuple): - # Node from which the match was found - anchor: Node - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] - - -def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None: - gm.delete_all_unused_submodules() - - if isinstance(replacement, GraphModule): - replacement.graph.lint() - - def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]: - try: - mod_match = mod.get_submodule(target) - return mod_match - except AttributeError: - return None - - for node in gm.graph.nodes: - if node.op == "call_module" or node.op == "get_attr": - - gm_submod = try_get_submodule(gm, node.target) - - replacement_submod = try_get_submodule(replacement, node.target) - - # CASE 1: This target already exists as a submodule in our - # result GraphModule. Whether or not it exists in - # `replacement`, the existing submodule takes precedence. - if gm_submod is not None: - continue - - # CASE 2: The target exists as a submodule in `replacement` - # only, so we need to copy it over. - elif replacement_submod is not None: - new_submod = copy.deepcopy(getattr(replacement, node.target)) - gm.add_submodule(node.target, new_submod) - - # CASE 3: The target doesn't exist as a submodule in `gm` - # or `replacement` - else: - raise RuntimeError("Attempted to create a \"", node.op, - "\" node during subgraph rewriting " - f"with target {node.target}, but " - "the referenced submodule does not " - "exist in either the original " - "GraphModule `gm` or the replacement" - " GraphModule `replacement`") - - gm.graph.lint() - -@compatibility(is_backward_compatible=True) -def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]: - """ - Matches all possible non-overlapping sets of operators and their - data dependencies (``pattern``) in the Graph of a GraphModule - (``gm``), then replaces each of these matched subgraphs with another - subgraph (``replacement``). - - Args: - ``gm``: The GraphModule that wraps the Graph to operate on - ``pattern``: The subgraph to match in ``gm`` for replacement - ``replacement``: The subgraph to replace ``pattern`` with - - Returns: - List[Match]: A list of ``Match`` objects representing the places - in the original graph that ``pattern`` was matched to. The list - is empty if there are no matches. ``Match`` is defined as: - - .. code-block:: python - - class Match(NamedTuple): - # Node from which the match was found - anchor: Node - # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] - - Examples: - - .. code-block:: python - - import torch - from pippy.fx import symbolic_trace, subgraph_rewriter - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, w1, w2): - m1 = torch.cat([w1, w2]).sum() - m2 = torch.cat([w1, w2]).sum() - return x + torch.max(m1) + torch.max(m2) - - def pattern(w1, w2): - return torch.cat([w1, w2]).sum() - - def replacement(w1, w2): - return torch.stack([w1, w2]) - - traced_module = symbolic_trace(M()) - - subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) - - The above code will first match ``pattern`` in the ``forward`` - method of ``traced_module``. Pattern-matching is done based on - use-def relationships, not node names. For example, if you had - ``p = torch.cat([a, b])`` in ``pattern``, you could match - ``m = torch.cat([a, b])`` in the original ``forward`` function, - despite the variable names being different (``p`` vs ``m``). - - The ``return`` statement in ``pattern`` is matched based on its - value only; it may or may not match to the ``return`` statement in - the larger graph. In other words, the pattern doesn't have to extend - to the end of the larger graph. - - When the pattern is matched, it will be removed from the larger - function and replaced by ``replacement``. If there are multiple - matches for ``pattern`` in the larger function, each non-overlapping - match will be replaced. In the case of a match overlap, the first - found match in the set of overlapping matches will be replaced. - ("First" here being defined as the first in a topological ordering - of the Nodes' use-def relationships. In most cases, the first Node - is the parameter that appears directly after ``self``, while the - last Node is whatever the function returns.) - - One important thing to note is that the parameters of the - ``pattern`` Callable must be used in the Callable itself, - and the parameters of the ``replacement`` Callable must match - the pattern. The first rule is why, in the above code block, the - ``forward`` function has parameters ``x, w1, w2``, but the - ``pattern`` function only has parameters ``w1, w2``. ``pattern`` - doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. - As an example of the second rule, consider replacing - - .. code-block:: python - - def pattern(x, y): - return torch.neg(x) + torch.relu(y) - - with - - .. code-block:: python - - def replacement(x, y): - return torch.relu(x) - - In this case, ``replacement`` needs the same number of parameters - as ``pattern`` (both ``x`` and ``y``), even though the parameter - ``y`` isn't used in ``replacement``. - - After calling ``subgraph_rewriter.replace_pattern``, the generated - Python code looks like this: - - .. code-block:: python - - def forward(self, x, w1, w2): - stack_1 = torch.stack([w1, w2]) - sum_1 = stack_1.sum() - stack_2 = torch.stack([w1, w2]) - sum_2 = stack_2.sum() - max_1 = torch.max(sum_1) - add_1 = x + max_1 - max_2 = torch.max(sum_2) - add_2 = add_1 + max_2 - return add_2 - """ - from pippy.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch - - # Get the graphs for `gm`, `pattern`, `replacement` - original_graph: Graph = gm.graph - pattern_graph: Graph = symbolic_trace(pattern).graph - replacement_graph: Graph = symbolic_trace(replacement).graph - - matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, - remove_overlapping_matches=True) - _matches: List[InternalMatch] = matcher.match(original_graph) - - replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] - - # As we progressively replace nodes, we'll need to keep track of how the match results should change - match_changed_node: Dict[Node, Node] = {} - - for match in _matches: - - # Build connecting between replacement graph's input and original graph input producer node - - # Initialize `val_map` with mappings from placeholder nodes in - # `replacement` to their corresponding node in `original_graph` - assert len(match.placeholder_nodes) == len(replacement_placeholders) - val_map: Dict[Node, Node] = {} - for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): - val_map[rn] = match_changed_node.get(gn, gn) - - # Copy the replacement graph over - user_nodes: Set[Node] = set() - for n in match.returning_nodes: - for user in n.users: - user_nodes.add(user) - assert user_nodes, "The returning_nodes should have at least one user node" - - if len(user_nodes) == 1: - first_user_node = list(user_nodes)[0] - else: - # If there are multiple user nodes, we need to find the first user node - # in the current execution order of the `original_graph` - for n in original_graph.nodes: - if n in user_nodes: - first_user_node = n - break - - with original_graph.inserting_before(first_user_node): - copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map) - - if isinstance(copied_returning_nodes, Node): - copied_returning_nodes = (copied_returning_nodes, ) - - # Hook the output Node of the replacement subgraph in to the - # original Graph at the correct location - assert len(match.returning_nodes) == len(copied_returning_nodes) - for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): - gn.replace_all_uses_with(copied_node) - match_changed_node[gn] = copied_node - # Remove the original nodes - for node in reversed(pattern_graph.nodes): - if node.op != "placeholder" and node.op != "output": - gn = match.nodes_map[node] - gm.graph.erase_node(gn) - - # Update the passed-in GraphModule to reflect the new state of - # `original_graph` - gm.recompile() - - # If `replacement` was an nn.Module, we'll need to make sure that - # all the submodules have been copied over correctly - if isinstance(replacement, torch.nn.Module): - _replace_submodules(gm, replacement) - - # Convert _matches: InternalMatch to Match to comply with backward compatibility of this function - matches: List[Match] = [Match(anchor=match.anchors[0], nodes_map=match.nodes_map) for match in _matches] - return matches diff --git a/pippy/fx/tensor_type.py b/pippy/fx/tensor_type.py deleted file mode 100644 index a85292ea3..000000000 --- a/pippy/fx/tensor_type.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from pippy.fx.experimental.unification import Var # type: ignore[attr-defined] - -from ._compatibility import compatibility - - -@compatibility(is_backward_compatible=False) -class TensorType: - """ - TensorType defines a type for tensors, which consists of a list of dimensions. - Example: - class M(torch.nn.Module): - def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): - return torch.add(x, y) - """ - - def __init__(self, dim): - self.__origin__ = TensorType - self.__args__ = dim - - def __repr__(self): - return f'TensorType[{self.__args__}]' - - def __eq__(self, other): - if isinstance(other, self.__class__): - return list(self.__args__) == list(other.__args__) - else: - return False - - @staticmethod - def __class_getitem__(*args): - if len(args) == 1 and isinstance(args[0], tuple): - args = args[0] - return TensorType(tuple(args)) - - -class _DynType: - """ - _DynType defines a type which stands for the absence of type information. - """ - def __init__(self): - self.__name__ = '_DynType' - - def __eq__(self, other): - return isinstance(other, self.__class__) - - def __str__(self): - return "Dyn" - - def __repr__(self): - return "Dyn" - - -Dyn = _DynType() - -@compatibility(is_backward_compatible=False) -def is_consistent(t1, t2): - """ - A binary relation denoted by ~ that determines if t1 is consistent with t2. - The relation is reflexive, semmetric but not transitive. - returns True if t1 and t2 are consistent and False otherwise. - Example: - Dyn ~ TensorType((1,2,3)) - int ~ Dyn - int ~ int - TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) - """ - - if t1 == t2: - return True - - if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): - return True - - if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all([is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)]) - else: - return False - - -@compatibility(is_backward_compatible=False) -def is_more_precise(t1, t2): - """ - A binary relation denoted by <= that determines if t1 is more precise than t2. - The relation is reflexive and transitive. - returns True if t1 is more precise than t2 and False otherwise. - Example: - Dyn >= TensorType((1,2,3)) - int >= Dyn - int >= int - TensorType((1,Dyn,3)) <= TensorType((1,2,3)) - """ - if t1 == t2: - return True - - if isinstance(t2, _DynType): - return True - - if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all([is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)]) - - else: - return False diff --git a/pippy/fx/traceback.py b/pippy/fx/traceback.py deleted file mode 100644 index a07b36b99..000000000 --- a/pippy/fx/traceback.py +++ /dev/null @@ -1,62 +0,0 @@ -import traceback -from contextlib import contextmanager -from typing import Optional, List -from ._compatibility import compatibility - -__all__ = ['override_stack_trace', 'set_stack_trace', 'append_stack_trace', 'format_stack', 'is_stack_trace_overridden'] - - -current_stack: List[str] = [] -is_overridden = False - - -@compatibility(is_backward_compatible=False) -@contextmanager -def override_stack_trace(): - global is_overridden - - saved_is_overridden = is_overridden - try: - is_overridden = True - yield - finally: - is_overridden = saved_is_overridden - - -@compatibility(is_backward_compatible=False) -def set_stack_trace(stack : List[str]): - global current_stack - - if is_overridden and stack: - current_stack = stack - -@compatibility(is_backward_compatible=False) -@contextmanager -def append_stack_trace(stack : Optional[str]): - """ - The content of stack here is an entire stacktraces as a string - """ - global current_stack - - if is_overridden and stack: - try: - current_stack.append(stack) - yield - finally: - current_stack.pop() - else: - yield - - -@compatibility(is_backward_compatible=False) -def format_stack() -> List[str]: - if is_overridden: - return current_stack.copy() - else: - # fallback to traceback.format_stack() - return traceback.format_stack() - - -@compatibility(is_backward_compatible=False) -def is_stack_trace_overridden() -> bool: - return is_overridden diff --git a/pippy/microbatch.py b/pippy/microbatch.py index eb2cace9a..a84c81bbf 100644 --- a/pippy/microbatch.py +++ b/pippy/microbatch.py @@ -1,13 +1,18 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging -import warnings -from typing import Any import torch - from torch.utils._pytree import tree_flatten, tree_unflatten -from pippy.IR import TrivialLossWrapper + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False class CustomReducer: @@ -48,7 +53,6 @@ def shard_dict_of_args( args_dict, args_chunk_spec, num_chunks, - _debug_mask_minibatches: bool = False, ): # Stage 1+2: flatten and shard/replicate @@ -95,7 +99,7 @@ def shard_dict_of_args( if first_tensor: # We can only adjust number of chunks when we hit this # issue at the first tensor encountered - warnings.warn( + logger.warning( f"Tensor size on chunking dimension is {v_split_dim_size}, " f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}." ) @@ -173,7 +177,6 @@ def split_args_kwargs_into_chunks( chunks, args_chunk_spec=None, kwargs_chunk_spec=None, - _debug_mask_minibatches: bool = False, ): # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec` @@ -221,12 +224,13 @@ def split_args_kwargs_into_chunks( dict(enumerate(args)), dict(enumerate(args_chunk_spec)), chunks, - _debug_mask_minibatches, ) real_num_chunks = len(args_split_dict) kwargs_split = shard_dict_of_args( - kwargs, kwargs_chunk_spec, real_num_chunks, _debug_mask_minibatches + kwargs, + kwargs_chunk_spec, + real_num_chunks, ) if len(kwargs_split) < real_num_chunks: @@ -238,7 +242,6 @@ def split_args_kwargs_into_chunks( dict(enumerate(args)), dict(enumerate(args_chunk_spec)), real_num_chunks, - _debug_mask_minibatches, ) if len(args_split_dict) != len(kwargs_split): @@ -254,7 +257,7 @@ def split_args_kwargs_into_chunks( return args_split, kwargs_split -def merge_chunks(chunks, chunk_spec, _debug_mask_minibatches: bool = False): +def merge_chunks(chunks, chunk_spec): # Given a list of chunks and a chunk specification, merge the chunks # into a single value according to that chunk spec. This is essentially # the inverse of `split_args_kwargs_into_chunks`, so the steps are @@ -374,6 +377,8 @@ def merge_chunks(chunks, chunk_spec, _debug_mask_minibatches: bool = False): return tree_unflatten(args_flattened, flatten_spec) +# TODO: determine if we still need this helper +""" def gen_output_chunk_spec(loss_spec, loss_reducer): output_chunk_spec: Any = None if loss_spec is None: @@ -390,8 +395,9 @@ def gen_output_chunk_spec(loss_spec, loss_reducer): else: raise ValueError(f"Cannot generate output chunk spec for {loss_spec}") - logging.info( + logger.info( f"Generated output_chunk_spec for loss_spec {loss_spec}: " f"{output_chunk_spec}" ) return output_chunk_spec +""" diff --git a/pippy/utils.py b/pippy/utils.py index 7e3d4d7d4..69a9b7ae1 100644 --- a/pippy/utils.py +++ b/pippy/utils.py @@ -1,277 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -import logging -import os -import socket -from typing import List - -import torch.distributed as dist - - -# Pinning process to a separate GPU if not yet done by launch script -# Notes: -# 1. Previously this env was added to work around an issue that each RPC process creates an extra CUDA context on device -# 0. This issue may have been caused by RPC not automatically pinning spawned worker threads to same CUDA device as the -# main thread. So pinning each RPC process to one device would avoid the issue. -# 2. This pinning must be done before `import torch` at which point CUDA context may have been created. Thus, if user -# code has `import torch` before importing PiPPy, this may not work. -# (Update): the issue in #1 seems to be gone as of March 2023. Hence, we are setting the default value of -# `PIPPY_PIN_DEVICE` to 0 now. -if os.getenv("PIPPY_PIN_DEVICE", "0") == "1": - cuda_devices_str = os.getenv("CUDA_VISIBLE_DEVICES") - if ( - cuda_devices_str is None # not set - or len(cuda_devices_str.split(",")) > 1 - ): # or set to all devices - # If launchers like Torchrun sets `LOCAL_RANK`, we would use this information - local_rank_str = os.getenv("LOCAL_RANK") - if local_rank_str is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = local_rank_str - print( - f"Pinning local process {local_rank_str} to gpu {os.getenv('CUDA_VISIBLE_DEVICES')}" - ) - - import torch -import torch.distributed.rpc as rpc -import torch.multiprocessing as mp - -import pippy.fx - - -def get_rank() -> int: - worker_info = rpc.get_worker_info() - logging.debug(worker_info) - return worker_info.id - - -def get_device() -> torch.device: - worker_info = rpc.get_worker_info() - agent = rpc._get_current_rpc_agent() - dev_map = agent._get_device_map(worker_info) - logging.debug(dev_map) - num_devs = len(dev_map) - - if num_devs == 0: - logging.debug("Empty device mapping, assuming device type to be cpu") - device = torch.device("cpu") - elif num_devs != 1: - raise AssertionError( - f"Expecting at most one device for RPC worker {worker_info}, " - f"but got device map of length {num_devs}: {dev_map}" - ) - else: - src_dev = next(iter(dev_map)) - dst_dev = dev_map[src_dev] - if src_dev != dst_dev: - raise AssertionError( - f"Expecting at most one device for RPC worker {worker_info}, " - f"but got {dev_map}" - ) - device = src_dev - - logging.info(f"Found device {device} for rank {worker_info.id}") - return device - - -def get_pp_rank(rank: int, ranks: List[int]) -> int: - for index, r in enumerate(ranks): - if rank == r: - return index - raise ValueError(f"Rank {rank} not in ranks {ranks}") - - -def has_efa() -> bool: - try: - import subprocess - - return ( - subprocess.run( - ["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ).returncode - == 0 - ) - except FileNotFoundError: - return False - except PermissionError: - return False - - -def tp_transports(): - return ["shm", "uv"] if has_efa() else None - - -global _pp_group_barrier -# Defined later in `run_worker` (triggered via `run_pippy`) - - -# A barrier util for pipeline dimension -def pp_group_barrier(): - _pp_group_barrier() # type: ignore[name-defined] - - -def run_pippy(run_func, args, *extra_args): - if not hasattr(args, "world_size"): - assert hasattr(args, "pp_group_size") - args.dp_group_size = ( - args.dp_group_size if hasattr(args, "dp_group_size") else 1 - ) - else: - if not hasattr(args, "dp_group_size"): - args.pp_group_size = ( - args.pp_group_size - if hasattr(args, "pp_group_size") - else args.world_size - ) - assert args.world_size % args.pp_group_size == 0 - args.dp_group_size = args.world_size // args.pp_group_size - elif not hasattr(args, "pp_group_size"): - args.dp_group_size = ( - args.dp_group_size if hasattr(args, "dp_group_size") else 1 - ) - assert args.world_size % args.dp_group_size == 0 - args.pp_group_size = args.world_size // args.dp_group_size - else: - pass - # TODO: doesn't work for PiPPyTrainingArguments - # assert args.world_size == args.dp_group_size * args.pp_group_size - - actual_world_size = args.dp_group_size * args.pp_group_size - print( - f"[PiPPy] World size: {actual_world_size}, " - f"DP group size: {args.dp_group_size}, " - f"PP group size: {args.pp_group_size}" - ) - - if args.rank == -1: - mp.spawn( - run_worker, - args=(run_func, args, *extra_args), - nprocs=actual_world_size, - join=True, - ) - elif args.rank < actual_world_size: - run_worker(args.rank, run_func, args, *extra_args) - else: - print("I'm unused, exiting") - - -def run_worker(rank, run_func, args, *extra_args): - args.rank = rank - - os.environ["MASTER_ADDR"] = args.master_addr - os.environ["MASTER_PORT"] = args.master_port - - actual_world_size = args.dp_group_size * args.pp_group_size - - # TODO: Move to training args, blocked by: cannot pickle 'TensorPipeRpcBackendOptions' object - # Exclude IB for metadata transport due to lack of EFA support on AWS - if hasattr(args, "num_worker_threads"): - num_worker_threads = args.num_worker_threads - else: - num_worker_threads = 512 - - if hasattr(args, "rpc_timeout"): - rpc_timeout = args.rpc_timeout - else: - rpc_timeout = 1800 - - options = rpc.TensorPipeRpcBackendOptions( - num_worker_threads=num_worker_threads, - rpc_timeout=rpc_timeout, - _transports=tp_transports(), - ) - if args.cuda: - n_devs = torch.cuda.device_count() - if n_devs > 0: - dev_id = rank % n_devs - for i in range(actual_world_size): - options.set_device_map(f"worker{i}", {dev_id: i % n_devs}) - # Does not seem effective for RPC device pinning. TODO - # options.set_devices([f'cuda:{dev_id}']) - else: - args.cuda = 0 - print("Warning: no CUDA device found. Running on CPU instead.") - - args.device = f"cuda:{dev_id}" if args.cuda else "cpu" - print( - f"rank = {rank} host/pid/device = " - f"{socket.gethostname()}/{os.getpid()}/{args.device}" - ) - - # Init DDP process group - backend = "nccl" if args.cuda else "gloo" - torch.distributed.init_process_group( - backend=backend, rank=rank, world_size=actual_world_size - ) - - rpc.init_rpc( - f"worker{rank}", - rank=rank, - world_size=actual_world_size, - rpc_backend_options=options, - ) - - global dp_pg_per_pp_rank - dp_ranks_per_pp_rank = ( - torch.arange(actual_world_size) - .reshape(args.pp_group_size, args.dp_group_size) - .tolist() - ) - dp_pg_per_pp_rank = [ # type: ignore[name-defined] - torch.distributed.new_group(ranks) for ranks in dp_ranks_per_pp_rank - ] - - pp_ranks_per_dp_group = [ - [i * args.dp_group_size + rank for i in range(args.pp_group_size)] - for rank in range(args.dp_group_size) - ] - - my_pp_ranks = pp_ranks_per_dp_group[rank % args.dp_group_size] - - args.driver_group = torch.distributed.new_group( - list(range(args.dp_group_size)) - ) - - global exclude_master - exclude_master = ( # type: ignore[name-defined] - args.exclude_master if hasattr(args, "exclude_master") else 0 - ) - gspmd = ( # type: ignore[name-defined] - args.gspmd if hasattr(args, "gspmd") else 0 - ) - - # A barrier util for pipeline dimension - global _pp_group_barrier - - # ProcessGroupGloo cannot create group with strided ranks, e.g. [0, 2, 4, 6, ...] - # Skipping the `pp_group` and `pp_group_barrier` creation here - # TODO: unskip - if torch.distributed.get_backend() == "gloo" and args.dp_group_size > 1: - - def _pp_group_barrier(): - logging.warning( - f"pp_group_barrier() does not support ProcessGroupGloo with strided ranks {my_pp_ranks}. This will be a no-op." - ) - - else: - pp_group = torch.distributed.new_group(my_pp_ranks) - - def _pp_group_barrier(): - logging.debug( - f"Running pipeline group barrier on ranks {my_pp_ranks}" - ) - torch.distributed.barrier(pp_group) - - if rank >= 0 and rank // args.dp_group_size == 0: - args.driver_index = rank - args.local_driver_index = os.getenv("LOCAL_RANK", rank) - run_func(my_pp_ranks, args, *extra_args) - elif gspmd == 1: - run_func(my_pp_ranks, args, *extra_args) - - rpc.shutdown() +import torch.distributed as dist +from torch import fx def flatten_args_detach(args): @@ -287,11 +17,14 @@ def extract_tensor_args(a): flat_detached_args.append(a) return a + """ def dont_traverse_size(a): return type(a) != torch.Size + """ - new_args = pippy.fx.node.map_aggregate( - args, extract_tensor_args, dont_traverse_size + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, # dont_traverse_size ) return new_args, flat_detached_args @@ -305,10 +38,15 @@ def extract_tensor_args(a): flat_args.append(a) return a + """ def dont_traverse_size(a): return type(a) != torch.Size + """ - pippy.fx.node.map_aggregate(args, extract_tensor_args, dont_traverse_size) + fx.node.map_aggregate( + args, + extract_tensor_args, # dont_traverse_size + ) return flat_args diff --git a/requirements.txt b/requirements.txt index ac954cf5d..4d73a1d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch >= 1.13.0 +torch >= 2.2.0.dev packaging >= 21.3 diff --git a/setup.py b/setup.py index f5229da0c..b50de9597 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def write_version_file(): requirements = [ # If the torch version has a ".dev" suffix, it would represent a nightly version of PyTorch. # It can be installed as a binary or from source. - "torch>=1.13.0", + "torch>=2.2.0.dev", ] extras: Dict = {} diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect deleted file mode 100644 index 1b732fd1f..000000000 --- a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect +++ /dev/null @@ -1,19 +0,0 @@ -pippy.fx._symbolic_trace.ProxyableClassMeta [] -pippy.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'getattr', 'is_leaf_module', 'path_of_module', 'trace'] -pippy.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'on_generate_code', 'output', 'owning_module', 'placeholder', 'print_tabular', 'process_inputs', 'process_outputs', 'python_code', 'set_codegen'] -pippy.fx.graph.PythonCode [] -pippy.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'print_readable', 'recompile', 'to_folder'] -pippy.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update'] -pippy.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove'] -pippy.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node'] -pippy.fx.interpreter.Transformer ['call_function', 'call_module', 'get_attr', 'placeholder', 'transform'] -pippy.fx.node.Node ['all_input_nodes', 'append', 'args', 'format_node', 'is_impure', 'kwargs', 'next', 'normalized_arguments', 'prepend', 'prev', 'replace_all_uses_with', 'replace_input_with', 'stack_trace', 'update_arg', 'update_kwarg'] -pippy.fx.passes.shape_prop.ShapeProp ['propagate', 'run_node'] -pippy.fx.passes.shape_prop.TensorMetadata ['dtype', 'is_quantized', 'memory_format', 'qparams', 'requires_grad', 'shape', 'stride'] -pippy.fx.passes.split_module.Partition [] -pippy.fx.proxy.Attribute ['node'] -pippy.fx.proxy.GraphAppendingTracer [] -pippy.fx.proxy.Proxy ['keys'] -pippy.fx.proxy.TraceError [] -pippy.fx.proxy.TracerBase ['check_mutable_operations', 'create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'proxy_buffer_attributes', 'record_stack_traces', 'to_bool', 'trace_asserts', 'traced_func_name'] -pippy.fx.subgraph_rewriter.Match ['anchor', 'nodes_map'] \ No newline at end of file diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect deleted file mode 100644 index 25e1e641c..000000000 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ /dev/null @@ -1,74 +0,0 @@ -pippy.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (,), autowrap_functions: Tuple[Callable, ...] = (,), param_shapes_constant: bool = False) -> None -pippy.fx._symbolic_trace.Tracer.call_module(self, m: torch.nn.modules.module.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx._symbolic_trace.Tracer.create_arg(self, a: Any) -> 'Argument' -pippy.fx._symbolic_trace.Tracer.is_leaf_module(self, m: torch.nn.modules.module.Module, module_qualified_name: str) -> bool -pippy.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str -pippy.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> pippy.fx.graph.Graph -pippy.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> pippy.fx.graph_module.GraphModule -pippy.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable]) -pippy.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None) -pippy.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.eliminate_dead_code(self) -pippy.fx.graph.Graph.erase_node(self, to_erase: pippy.fx.node.Node) -> None -pippy.fx.graph.Graph.get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: Dict[pippy.fx.node.Node, pippy.fx.node.Node], return_output_node = False) -> 'Optional[Argument]' -pippy.fx.graph.Graph.inserting_after(self, n: Optional[pippy.fx.node.Node] = None) -pippy.fx.graph.Graph.inserting_before(self, n: Optional[pippy.fx.node.Node] = None) -pippy.fx.graph.Graph.lint(self) -pippy.fx.graph.Graph.node_copy(self, node: pippy.fx.node.Node, arg_transform: Callable[[pippy.fx.node.Node], Argument] = >) -> pippy.fx.node.Node -pippy.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) -pippy.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> pippy.fx.node.Node -pippy.fx.graph.Graph.print_tabular(self) -pippy.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False) -> pippy.fx.graph.PythonCode -pippy.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: pippy.fx.graph.Graph, class_name: str = 'GraphModule') -pippy.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool -pippy.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None -pippy.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool -pippy.fx.graph_module.GraphModule.recompile(self) -> pippy.fx.graph.PythonCode -pippy.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module -pippy.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module -pippy.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module -pippy.fx.interpreter.Interpreter.__init__(self, module: pippy.fx.graph_module.GraphModule, garbage_collect_values: bool = True) -pippy.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.call_module(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: pippy.fx.node.Node) -> Tuple[Tuple, Dict] -pippy.fx.interpreter.Interpreter.fetch_attr(self, target: str) -pippy.fx.interpreter.Interpreter.get_attr(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.map_nodes_to_values(self, args: pippy.fx.node.Argument, n: pippy.fx.node.Node) -> pippy.fx.node.Argument -pippy.fx.interpreter.Interpreter.output(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.placeholder(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Interpreter.run(self, *args, initial_env: Optional[Dict[pippy.fx.node.Node, Any]] = None, enable_io_processing: bool = True) -> Any -pippy.fx.interpreter.Interpreter.run_node(self, n: pippy.fx.node.Node) -> Any -pippy.fx.interpreter.Transformer.__init__(self, module) -pippy.fx.interpreter.Transformer.call_function(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Transformer.call_module(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -pippy.fx.interpreter.Transformer.get_attr(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> pippy.fx.proxy.Proxy -pippy.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> pippy.fx.proxy.Proxy -pippy.fx.interpreter.Transformer.transform(self) -> pippy.fx.graph_module.GraphModule -pippy.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None -pippy.fx.node.Node.append(self, x: 'Node') -> None -pippy.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str] -pippy.fx.node.Node.prepend(self, x: 'Node') -> None -pippy.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >) -> List[Node] -pippy.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') -pippy.fx.node.Node.update_arg(self, idx: int, arg: pippy.fx.node.Argument) -> None -pippy.fx.node.Node.update_kwarg(self, key: str, arg: pippy.fx.node.Argument) -> None -pippy.fx.node.map_aggregate(a: pippy.fx.node.Argument, fn: Callable[[pippy.fx.node.Argument], pippy.fx.node.Argument], should_traverse_fn: Optional[Callable[[pippy.fx.node.Argument], bool]] = None) -> pippy.fx.node.Argument -pippy.fx.node.map_arg(a: pippy.fx.node.Argument, fn: Callable[[pippy.fx.node.Node], pippy.fx.node.Argument]) -> pippy.fx.node.Argument -pippy.fx.passes.reinplace.reinplace(gm, *sample_args) -pippy.fx.passes.split_module.split_module(m: pippy.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[pippy.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False) -pippy.fx.proxy.Attribute.__init__(self, root: pippy.fx.proxy.Proxy, attr: str) -pippy.fx.proxy.Proxy.__init__(self, node: pippy.fx.node.Node, tracer: 'Optional[TracerBase]' = None) -pippy.fx.proxy.Proxy.keys(self) -pippy.fx.proxy.TracerBase.create_arg(self, a: Any) -> pippy.fx.node.Argument -pippy.fx.proxy.TracerBase.create_node(self, kind: str, target: pippy.fx.node.Target, args: Tuple[pippy.fx.node.Argument, ...], kwargs: Dict[str, pippy.fx.node.Argument], name: Optional[str] = None, type_expr: Optional[Any] = None) -> pippy.fx.node.Node -pippy.fx.proxy.TracerBase.create_proxy(self, kind: str, target: pippy.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[pippy.fx.node.Node], Proxy] = None) -pippy.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator -pippy.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any -pippy.fx.proxy.TracerBase.proxy(self, node: pippy.fx.node.Node) -> 'Proxy' -pippy.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool -pippy.fx.subgraph_rewriter.replace_pattern(gm: pippy.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[pippy.fx.subgraph_rewriter.Match] diff --git a/test/fx/named_tup.py b/test/fx/named_tup.py deleted file mode 100644 index 2d4f63113..000000000 --- a/test/fx/named_tup.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -from typing import NamedTuple - -import torch - -class MyNamedTup(NamedTuple): - i : torch.Tensor - f : torch.Tensor diff --git a/test/fx/quantization.py b/test/fx/quantization.py deleted file mode 100644 index 75589ddbc..000000000 --- a/test/fx/quantization.py +++ /dev/null @@ -1,325 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -r''' -**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not -rely on it for anything!** -''' -from pippy.fx import Graph, GraphModule -from pippy.fx.graph import map_arg -from pippy.fx.proxy import Proxy -import sys -import torch -from torch.nn.utils import fuse_conv_bn_weights -import operator - -# can be a -# module type, a builtin function, or a string to match target - -def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps): - min_val = min(0.0, min_val) - max_val = max(0.0, max_val) - if max_val == min_val: - return 1.0, 0 - else: - scale = (max_val - min_val) / float(qmax - qmin) - scale = max(scale, eps) - zero_point = qmin - round(min_val / scale) - zero_point = max(qmin, zero_point) - zero_point = min(qmax, zero_point) - zero_point = int(zero_point) - return scale, zero_point - -class MinMaxObserver: - def __init__(self, quantizer, node): - self.min, self.max = float('inf'), float('-inf') - self.all_tensors = True - - def observe(self, node, env): - v = env[node.name] - if not isinstance(v, torch.Tensor): - self.all_tensors = False - return - self.max = max(self.max, float(v.max())) - self.min = min(self.min, float(v.min())) - - def scale_zeropoint(self): - return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255) - -class NoObserver: - def __init__(self, quantizer, node): - pass - - def observe(self, node, env): - pass - -DEFAULT_QUANTIZATION_PATTERNS = {} -def register_pattern(pattern): - def insert(fn): - DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn - return fn - return insert - - -@register_pattern(operator.add) -class Add(MinMaxObserver): - def quantize(self, quantizer, node, load_arg): - if not self.all_tensors: - return NotImplemented - scale, zeropoint = self.scale_zeropoint() - return quantizer.quantized_graph.create_node( - 'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint}) - - -class Relu(NoObserver): - def quantize(self, quantizer, node, load_arg): - return torch.relu(load_arg(node.args[0])) # torch.relu works directly on quantized tensors? - -# these ops have quantized equivalents that do not need any extra information -@register_pattern(torch.nn.ReLU) -@register_pattern(torch.nn.AvgPool2d) -@register_pattern(torch.nn.MaxPool2d) -@register_pattern(torch.nn.AdaptiveAvgPool2d) -class CopyNode(NoObserver): - def quantize(self, quantizer, node, load_arg): - return quantizer.quantized_graph.node_copy(node, load_arg) - -class IdentityModule(torch.nn.Module): - def forward(self, x): - return x - -# handle conv, maybe followed by bn, maybe followed by relu -@register_pattern(torch.nn.modules.conv.Conv2d) -@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d)) -@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)) -@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))) -class ConvNormRelu(MinMaxObserver): - def __init__(self, quantizer, node): - super().__init__(quantizer, node) - self.relu_node, self.bn_node = None, None - if isinstance(quantizer.modules[node.target], torch.nn.ReLU): - self.relu_node = node - node = node.args[0] - if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d): - self.bn_node = node - self.bn = quantizer.modules[self.bn_node.target] - node = node.args[0] - assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d) - self.conv_node = node - self.conv = quantizer.modules[self.conv_node.target] - - def quantize(self, quantizer, node, load_arg): - mod = self.conv - weight, bias = mod.weight, mod.bias - - if self.bn_node is not None: - weight, bias = fuse_conv_bn_weights( - weight, bias, self.bn.running_mean, self.bn.running_var, - self.bn.eps, self.bn.weight, self.bn.bias) - - min_val, max_val = float(weight.min()), float(weight.max()) - - act_scale, act_zp = self.scale_zeropoint() - - weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val) - qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8) - - ctor = torch.ao.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.ao.nn.quantized.Conv2d - - qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size, - mod.stride, mod.padding, mod.dilation, mod.groups, - mod.bias is not None, mod.padding_mode) - - qconv.set_weight_bias(qweight, bias) - qconv.scale = float(act_scale) - qconv.zero_point = int(act_zp) - parent_name, name = _parent_name(self.conv_node.target) - setattr(quantizer.modules[parent_name], name, qconv) - if self.bn_node is not None: - parent_bn, bn_name = _parent_name(self.bn_node.target) - # we can't just delete this because submodules's forwards (which are not longer use) - # try to call it, so replace with something that does nothing. - setattr(quantizer.modules[parent_name], bn_name, IdentityModule()) - - return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {}) - - -# turn foo.bar -> ['foo', 'bar'] -def _parent_name(target): - r = target.rsplit('.', 1) - if len(r) == 1: - return '', r[0] - else: - return r[0], r[1] - - - -class DefaultQuant(MinMaxObserver): - def quantize(self, input): - assert self.all_tensors - scale, zeropoint = self.scale_zeropoint() - return torch.quantize_per_tensor(Proxy(input), scale, zeropoint, torch.quint8).node - -def matches(modules, node, pattern, max_uses=sys.maxsize): - if isinstance(pattern, tuple): - self_match, *arg_matches = pattern - else: - self_match = pattern - arg_matches = None - - if len(node.users) > max_uses: - return False - - if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): - if node.op != 'call_module': - return False - if not isinstance(modules[node.target], self_match): - return False - elif callable(self_match): - if node.op != 'call_function' or node.target is not self_match: - return False - elif node.target != self_match: - return False - - if not arg_matches: - return True - - if len(arg_matches) != len(node.args): - return False - - return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches)) - - -class Quantizer: - def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant): - self.root = mod - self.graph = mod.graph - self.quant_ctor = quant_ctor - - # cached information for observe - self.state_dict = self.root.state_dict() - self.modules = dict(self.root.named_modules()) - - # match the patterns that will get quantized - self.matches = self._find_matches(patterns) - # find _inputs_ to matched nodes that are not quantized, these - # have to be quantized, which requires measuring stats, - # initialize an quant_ctor object for each - self.quants = self._find_quants(quant_ctor) - - - - def observe(self, args): - # most of this function is just an interpreter for the graph - # it would be possible to put this in some abstraction, but - # it is pretty nice to just be able to see exactly what is happening here - # and hack on it. - # maybe we should just provide an example interpreter that people copy/paste - # then edit. - args_iter = iter(args) - env = {} - - def load_arg(a): - return map_arg(a, lambda node: env[node.name]) - - output_node : Optional[Node] = None - for node in self.graph.nodes: - if node.op == 'placeholder': - result = next(args_iter) - elif node.op == 'get_attr': - result = self.state_dict[node.target] - elif node.op == 'call_function': - result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) - elif node.op == 'call_method': - self_obj, *args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = getattr(self_obj, node.target)(*args, **kwargs) - elif node.op == 'call_module': - result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) - elif node.op == 'output': - return load_arg(node.args[0]) - - env[node.name] = result - root_node, obj = self.matches.get(node.name, (None, None)) - if root_node is node: - obj.observe(node, env) - if node.name in self.quants: - self.quants[node.name].observe(node, env) - - raise RuntimeError('Graph had no output node!') - - def quantize(self): - self.quantized_graph = Graph() - - env = {} - quant_env = {} - - def load_arg(n, quantized): - if not quantized: - if n.name not in env and n.name in quant_env: - env[n.name] = Proxy(quant_env[n.name]).dequantize().node - return env[n.name] - else: - if n.name not in quant_env and n.name in env: - quant_env[n.name] = self.quants[n.name].quantize(env[n.name]) - return quant_env[n.name] - - def copy_recursive(node): - def load_or_emit(n): - if n.name in env or e.name in quant_env: - return load_arg(n, quantized=False) - else: - return copy_recusive(n) - r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False)) - return r - - for node in self.graph.nodes: - root_node, obj = self.matches.get(node.name, (None, None)) - if root_node is None: - # not quantized just copy it - env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False)) - - elif root_node is node: - r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True))) - if r is NotImplemented: - # quantizer choose to to quantize the node take the entire match, and just copy it over - env[node.name] = copy_recursive(node) - else: - quant_env[node.name] = r - - return GraphModule(self.root, self.quantized_graph) - - def _find_matches(self, patterns): - modules = dict(self.root.named_modules()) - match_map = {} # node name -> (root_node, match_value?) - - def apply_match(pattern, node, match): - if isinstance(pattern, tuple): - s, *args = pattern - apply_match(s, node, match) - for subpattern, arg in zip(args, node.args): - apply_match(subpattern, arg, match) - else: - match_map[node.name] = match - - for node in reversed(self.graph.nodes): - if node.name not in match_map: - for pattern, value in patterns.items(): - if matches(modules, node, pattern): - apply_match(pattern, node, (node, value(self, node))) - - return match_map - - def _find_quants(self, quant_ctor): - quants = {} - - def visit_arg(n): - # note: we have to measure quantization information - # even for nodes where we might not use it because it is already - # quantized. This is because each match has the option to - # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate) - if n.name not in quants: - quants[n.name] = quant_ctor(self, n) - for node in self.graph.nodes: - if node.name in self.matches: - map_arg(node.args, visit_arg) - map_arg(node.kwargs, visit_arg) - return quants diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py deleted file mode 100644 index d16020e69..000000000 --- a/test/fx/test_common_passes.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["oncall: fx"] - -import torch - -from torch.testing._internal.common_utils import ( - TestCase, parametrize, instantiate_parametrized_tests, run_tests) -from pippy.fx.experimental.proxy_tensor import make_fx -from pippy.fx.passes.dialect.common.cse_pass import CSEPass -from pippy.fx.graph_module import GraphModule - -import itertools - -def FactoryFunctionCall(x, device): - y = torch.full(x.shape, 3, device=device) - z = torch.add(y, x) - return z - - -def TorchTensorCall(x): - y = torch.tensor(3) - return x + y - - -def TakeList(x): - z = torch.cat([x, x]) - return z - - -def ReturnList(x): - a = torch.arange(10).reshape(5, 2) - z = torch.split(a, [1, 4]) - return z - - -def Mutation(x): - y = x + 2 - y.add_(1) - return x + y - - -def MutationInput(x): - x.add_(1) - y = x + 2 - return x + y - - -def MutationFactory(x, device): - y = torch.full(x.shape, 3, device=device) - y.add_(1) - return x + y - - -def MutationTorchTensorCall(x): - y = torch.tensor(3) - y.add_(1) - return x + y - - -def MutationMetadata(x): - x.resize_(2) - return x - - -Passes = [CSEPass] -Test_Cases = [TakeList, - ReturnList, - Mutation, - MutationInput, - MutationMetadata, - MutationTorchTensorCall] -Factory_Test_Cases = [FactoryFunctionCall, MutationFactory] -Devices = ["cpu"] -if torch.cuda.is_available(): - Devices.append("cuda") - -@instantiate_parametrized_tests -class TestCommonPass(TestCase): - - @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) - def test_correctness(self, common_pass, f, device): - inp = torch.randn(10, device=device) - - traced_m = make_fx(f)(inp) - P = common_pass() - - res = P(traced_m) - modified_m = res.graph_module - assert isinstance(modified_m, GraphModule) - - inp_copy = inp.clone() - expected = f(inp) - result = modified_m(inp_copy) - - self.assertEqual(result, expected) - - - @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) - def test_correctness_factory(self, common_pass, f, device): - inp = torch.randn(10, device=device) - traced_m = make_fx(f)(inp, device) - P = common_pass() - - res = P(traced_m) - modified_m = res.graph_module - assert isinstance(modified_m, GraphModule) - - inp_copy = inp.clone() - expected = f(inp, device) - result = modified_m(inp_copy, device) - - self.assertEqual(result, expected) - - -if __name__ == '__main__': - run_tests() diff --git a/test/fx/test_cse_pass.py b/test/fx/test_cse_pass.py deleted file mode 100644 index 21a83e66d..000000000 --- a/test/fx/test_cse_pass.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["oncall: fx"] - -import torch - -from torch.testing._internal.common_utils import ( - TestCase, run_tests) -from pippy.fx.experimental.proxy_tensor import make_fx -from pippy.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops -from pippy.fx import symbolic_trace - -import random - - -banned_ops = get_CSE_banned_ops() -P_default = CSEPass(banned_ops=banned_ops) - -def check(self, f, t, delta, check_val=True, graph_input=False, P=None): - """ - check if the CSE modified graph of ``f`` - 1) has delta less nodes, and - 2) do not reduce the number of nodes further on a second pass, and - 3) modified returned is true only if the number of nodes decreases. - - Args: - f: function to be checked - t: tensor to be passed to f - delta: an integer >= -1. - If delta = -1, it only checks if the new graph has less or equal number of nodes - check_val: if True, check if the output of f is correct - graph_input: True is f is type GraphModule - P: the pass to use. If None, use P_default - """ - if graph_input: - fx_g = f - else: - fx_g = make_fx(f)(t) - - if P is None: - P = P_default - - res = P(fx_g) - new_g = res.graph_module - new_graph = new_g.graph - modified = res.modified - - # the number of nodes decrease/ or stay the same - old_num_nodes = len(fx_g.graph.nodes) - new_num_nodes = len(new_graph.nodes) - - assert (new_num_nodes < old_num_nodes) == modified, "modified should be True if the number of nodes decrease" - - if delta == -1: - self.assertTrue(old_num_nodes >= new_num_nodes, ( - f"number of nodes increased {old_num_nodes}, {new_num_nodes}")) - else: - self.assertTrue(old_num_nodes == new_num_nodes + delta, ( - f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}")) - - # a second pass should not reduce more nodes - res = P(new_g) - pass_2_graph = res.graph_module.graph - pass_2_num_nodes = len(pass_2_graph.nodes) - self.assertTrue(pass_2_num_nodes == new_num_nodes, ( - f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}")) - - # check correctness - if check_val: - true_result = fx_g(t) - our_result = new_g(t) - if true_result is None: # both return None - self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}") - else: # results returned are the same - self.assertTrue(torch.all(true_result == our_result), ( - f"results are different {true_result}, {our_result}")) # check results are the same - -class TestCSEPass(TestCase): - - def test_nochange(self): - def f(x): - a = x + 1 - b = x + a - a = x - d = x + a - return b + d - t = torch.randn(2, 2) - check(self, f, t, 0) - - def test_empty(self): - def f(x): - pass - t = torch.randn(2, 2) - check(self, f, t, 0) - - - def test_immutable_list_type(self): - def f(x): - a = x.sum(dim=1) - b = x.sum(dim=1) - c = x.sum() - d = x.sum() - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_immutable_list_multiple_entries(self): - def f(x): - a = x.sum(dim=[0, 1]) - b = x.sum(dim=[0, 1]) - c = x.sum(dim=1) - d = x.sum(dim=1) - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_simple(self): - def f(x): - a = x.cos() - b = x.cos() - c = a + a - d = b + b - return c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_simple_2(self): - def f(x): - a = x.cos().sin() - b = x.cos().sin() - c = a + a - d = b + b - return c + d - t = torch.randn(1) - check(self, f, t, 3) - - def test_two_args_default(self): - def f(x): - a = x.sum(dim=1) - b = x.sum(dim=1, keepdim=False) - c = x.sum(dim=1, keepdim=False) - d = x.sum(dim=1) - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 3) - - def test_two_args(self): - def f(x): - a = x.sum(dim=1) - b = x.sum(dim=1, keepdim=True) - c = x.sum(dim=1, keepdim=True) - d = x.sum(dim=1) - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 2) - - def test_simple_multiple_same_ops(self): - def f(x): - a = x.sum() - b = x.sum() - c = x.sum() - d = x.sum() - return a + b + c + d - t = torch.randn(2, 2) - check(self, f, t, 3) - - def test_nested_immutable_list_type(self): - def f(x): - a = torch.cat((x, x)) - b = torch.cat((x, x)) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 1) - - def test_kwarg(self): - def f(x): - a = torch.ones_like(x) - b = torch.ones_like(x) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 1) - - """ - Generate function with random ops and check if the result is the same - """ - def test_random(self): - def f(x): - vals = [x] - ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu] - for _ in range(100): - new_val = random.choice(ops)(random.choice(vals)) - vals.append(new_val) - return vals[-1] - - fx_g = symbolic_trace(f) - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - t = torch.randn(2, 2) - - for _ in range(30): - check(self, fx_g, t, -1, graph_input=True) - - """ - Test that banned list ban ops as expected. - """ - def test_banned_list(self): - def f(x): - a = x + 1 - b = x + 1 - return a + b - - t = torch.randn(2, 2) - P_ban_add = P = CSEPass(banned_ops=[torch.ops.aten.add]) - check(self, f, t, 0, P=P_ban_add) # check that add is banned - check(self, f, t, 1) # check that add is not banned by default - - def test_rand_like(self): - def f(x): - a = torch.rand_like(x) - b = torch.rand_like(x) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 0, check_val=False) - - def test_rand_n(self): - def f(x): - a = torch.randn(4) - b = torch.randn(4) - return a + b - t = torch.randn(2, 2) - check(self, f, t, 0, check_val=False) - - -if __name__ == '__main__': - run_tests() diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py deleted file mode 100644 index bb29df4c9..000000000 --- a/test/fx/test_dce_pass.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -from typing import Set, Type -import torch -import pippy.fx - -from torch.testing._internal.common_utils import TestCase - - -class TestDCE(TestCase): - def _has_nodes_without_users(self, m: pippy.fx.GraphModule): - for node in m.graph.nodes: - if node.is_impure(): - continue - if len(node.users) == 0: - return True - return False - - def _get_num_placeholders(self, m: pippy.fx.GraphModule) -> int: - count = 0 - for node in m.graph.nodes: - if node.op == "placeholder": - count += 1 - return count - - def _run_dce_and_test( - self, - m: torch.nn.Module, - expect_dce_changes: bool, - modules_to_be_leafs: Set[Type] = None, - ): - class TestTracer(pippy.fx.Tracer): - def is_leaf_module(self, m, qualname): - if modules_to_be_leafs and type(m) in modules_to_be_leafs: - return True - return super().trace(m, qualname) - - traced: pippy.fx.GraphModule = pippy.fx.GraphModule(m, TestTracer().trace(m)) - print(str(traced.graph)) - - # Verify there are nodes without users (if expected). - has_nodes_without_users = self._has_nodes_without_users(traced) - if expect_dce_changes: - self.assertTrue(has_nodes_without_users) - else: - self.assertFalse(has_nodes_without_users) - - # Get the original number of placeholders to verify it doesn't change - # during DCE. - orig_num_phs = self._get_num_placeholders(traced) - changed = traced.graph.eliminate_dead_code() - - self.assertTrue(changed if expect_dce_changes else not changed) - - # Verify there are no nodes without users after DCE is run. - self.assertFalse(self._has_nodes_without_users(traced)) - new_num_phs = self._get_num_placeholders(traced) - self.assertEqual(orig_num_phs, new_num_phs) - - traced.recompile() - # Make sure we run and get the same results before/after DCE. - inputs = [torch.tensor([1.5])] * new_num_phs - self.assertTrue(torch.equal(m(*inputs), traced(*inputs))) - - def test_simple(self): - """ - Tests that a single node in the graph is DCE'd correctly. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) - - def forward(self, x): - a = x + 1 - return x + self.attr_1 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_dead_chain(self): - """ - Tests that a chain of two nodes in the graph are DCE'd correctly. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) - - def forward(self, x): - a = x + 1 - b = a * 7 - return x + self.attr_1 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_dead_getattr(self): - """ - Tests that a getatrr in the graph is DCE'd correctly. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9])) - - def forward(self, x): - a = x + 1 - b = a * self.attr_1 - return x + 11 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_dead_placeholder(self): - """ - Tests that a placeholder in the graph is not DCE'd, as that would change - the function signature. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + 7 - - self._run_dce_and_test(TestModule(), expect_dce_changes=False) - - def test_dead_placeholder_with_user(self): - """ - Tests that a placeholder in the graph is not DCE'd, as that would change - the function signature. Also verifies that a dead node that uses the - placeholder is DCE'd. - - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - a = y + 2 - return x + 7 - - self._run_dce_and_test(TestModule(), expect_dce_changes=True) - - def test_keep_module_with_side_effects(self): - """ - Test that DCE doesn't remove a module if it's specified as having side effects. - """ - - class ReLUImpure(torch.nn.ReLU): - _is_impure = True - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = ReLUImpure() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - r = self.relu(a) - return a * 2 - - self._run_dce_and_test( - TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure} - ) - - def test_keep_torch_assert(self): - """ - Test that DCE doesn't remove torch._assert since it has side effects. - """ - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a: torch.Tensor) -> torch.Tensor: - torch._assert(torch.equal(a, a), "a must equal a") - return a * 2 - - # Note: Don't need to specify torch._assert as having side effects - # because it's known to. - self._run_dce_and_test(TestModule(), expect_dce_changes=False) diff --git a/test/fx/test_future.py b/test/fx/test_future.py deleted file mode 100644 index de9af1487..000000000 --- a/test/fx/test_future.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -from __future__ import annotations # type: ignore[attr-defined] -import torch -import typing -from pippy.fx import symbolic_trace - -class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - -# No forward references -class M1(torch.nn.Module): - def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: - return a(x) - -# Forward references -class M2(torch.nn.Module): - def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor': - return a(x) - -# Non-torch annotation with no internal forward references -class M3(torch.nn.Module): - def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: - return a(x[0]) - -# Non-torch annotation with internal forward references -class M4(torch.nn.Module): - def forward(self, x: typing.List['torch.Tensor'], a: A) -> 'torch.Tensor': - return a(x[0]) - -x = torch.rand(2, 3) - -ref = torch.add(x, x) - -traced1 = symbolic_trace(M1()) -res1 = traced1(x, A()) -assert torch.all(torch.eq(ref, res1)) - -traced2 = symbolic_trace(M2()) -res2 = traced2(x, A()) -assert torch.all(torch.eq(ref, res2)) - -traced3 = symbolic_trace(M3()) -res3 = traced3([x], A()) -assert torch.all(torch.eq(ref, res3)) - -traced4 = symbolic_trace(M4()) -res4 = traced4([x], A()) -assert torch.all(torch.eq(ref, res4)) diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py deleted file mode 100644 index b8207c2b5..000000000 --- a/test/fx/test_fx_const_fold.py +++ /dev/null @@ -1,712 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import operator - -import torch -import pippy.fx -from pippy.fx.experimental import const_fold -from pippy.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp -from torch.testing._internal.common_utils import TestCase - - -class TestConstFold(TestCase): - def _get_attr(self, node): - mod = node.graph.owning_module - target = str(node.target) - target_atoms = target.split(".") - curr_obj = mod - for i, atom in enumerate(target_atoms): - if not hasattr(curr_obj, atom): - raise RuntimeError( - f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; " - f" original whole target: '{target}'" - ) - curr_obj = getattr(curr_obj, atom) - return curr_obj - - def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule): - self.assertTrue(mod_folded.const_subgraph_module is not None) - - # Check that we don't have the const or non-const fold graphs in the gm, and - # that we do have the const folded get_attr. - found_folded_attrs = False - for n in mod_folded.graph.nodes: - if n.op == "get_attr" and n.target.startswith("_FX_CONST_FOLDED_ATTRS"): - found_folded_attrs = True - elif n.op == "call_module": - self.assertTrue(n.target not in {"submod_0", "submod_1"}) - self.assertTrue(found_folded_attrs) - - def test_const_fold_basic_one_attr_no_name_collision(self): - r""" - Perform constant folding conversion, from original mod to split constant folding - module with two split subgraphs, where there's a single attr to fold and - a single output attr result to replace. - - attr1 attr1 - | | | | - x add add - \ / | - sub y output (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - mul attr2 x / (input from previous subgraph - \ / \ / is attr) - add sub y - | \ / - output mul attr2 - \ / - add - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]])) - self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]])) - - def forward(self, x, y): - a = self.attr_1 + self.attr_1 - x = x - a - return x * y + self.attr_2 - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_one_attr_name_collision(self): - r""" - Perform constant folding conversion, from original mod to split constant folding - module with two split subgraphs, where there's a single attr to fold and - a single output attr result to replace. Name the attrs such that they will - collide by name with folded attrs. - - add_1 add_1 - | | | | - x add add - \ / | - sub y output (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - mul add_2 x / (input from previous subgraph - \ / \ / is attr) - add sub y - | \ / - output mul add_2 - \ / - add - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - # Note: Named as such to result in name collision. - self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]])) - self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]])) - - def forward(self, x, y): - a = self.add_1__CF + self.add_1__CF - x = x - a - return x * y + self.add_2__CF - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0]) - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_placeholder_reordered(self): - """ - Test code path where placeholder comes after normal op node in FX - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x * 2 + y - - mod = ConstFoldTestModule() - mod = pippy.fx.symbolic_trace(mod) - yy = None - for n in mod.graph.nodes: - if n.op == "placeholder" and n.target == "y": - yy = n - elif yy is not None and n.op == "call_function": - yy.prepend(n) - break - - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - - self.assertTrue(mod_folded.const_subgraph_module is None) - # Now run both folded and non-folded to check results equal. - in_x = torch.tensor([[-0.45]]) - in_y = torch.tensor([[0.45]]) - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_noop(self): - r""" - Check that a graph with no constant folding is handled correctly. - - x attr1 - \ / - sub - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) - - def forward(self, x): - return x - self.attr1 - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - - # Check that the folded graph module is None, since there was no folding to do. - self.assertTrue(mod_folded.const_subgraph_module is None) - - # Now run both folded and non-folded to check results equal. - in_x = torch.tensor([[-0.45]]) - base_result = mod(in_x) - fold_result = mod_folded(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_two_attr_three_input(self): - r""" - Perform constant folding conversion, from original mod to split constant - folding module with two split subgraphs, where there are two attrs to - fold into a single output, and there are three placeholder inputs. - - attr1 attr2 attr1 attr2 - \ / \ / - x add add - \ / | - sub y output (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - mul z x / (input from previous subgraph - \ / \ / is attr) - div sub y - | \ / - output mul z - \ / - div - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) - self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]])) - - def forward(self, x, y, z): - a = self.attr1 + self.attr1 - sub = x - a - mul = sub * y - return mul / z - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y, in_z = ( - torch.tensor([[-0.45]]), - torch.tensor([0.9]), - torch.tensor([1.1]), - ) - base_result = mod(in_x, in_y, in_z) - fold_result = mod_folded(in_x, in_y, in_z) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_basic_two_attr(self): - r""" - Perform constant folding conversion, from original mod to split constant - folding module with two split subgraphs, where there are two attrs to - fold into a single output. - - attr1 attr2 attr1 attr2 - \ / \ / - x add add (becomes attr add_1) - \ / ==> -------+------- (const/base subgraph split) - sub x | (input from previous subgraph is attr) - | \ / - output sub - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.randn(2, 3)) - self.attr2 = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - y = self.attr1 + self.attr2 - return x + y - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = mod_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_multi_const_folded_attrs(self): - r""" - Perform constant folding conversion, from original mod to split constant - folding module with two split subgraphs, where there are two attrs to - fold into two new attrs. - - attr1 attr2 attr1 attr2 - / \ | / \ | - permute | sum permute | sum - \ / / \ / | - x add y / add | - \ / \ / | | - sub add output output (become attrs add_1 and mul_1) - \ / ==> --------+-------+------ (const/base subgraph split) - \ / x | y | (inputs from previous subgraph - add \ / \ / are attrs) - | sub add - linear \ / - | add - sigmoid | - | linear - output | - sigmoid - | - output - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr1 = torch.nn.Parameter(torch.randn(4, 4)) - self.attr2 = torch.nn.Parameter(torch.randn(4, 4)) - self.lin = torch.nn.Linear(4, 4) - - def forward(self, x, y): - a = self.attr1 + self.attr1.permute(1, 0) - x = x - a - amax = torch.sum(self.attr2, dim=1) - y = y + amax - return torch.sigmoid(self.lin(x + y)) - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x, in_y = torch.randn(4, 4), torch.randn(4) - fold_result = mod_folded(in_x, in_y) - base_result = mod(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_submod_hierarchy(self): - r""" - Perform constant folding conversion, from original mod to split constant folding - module where one of the folded attrs comes from a submod deeper in the hierarchy - of the base module. - """ - - class TracedThroughModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.internal_attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self): - return self.internal_attr - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.my_mod = TracedThroughModule() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - return self.attr + self.my_mod() + x - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = mod_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_retain_node_meta(self): - r""" - Perform constant folding conversion, and validate that node meta is retained. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.attr + self.attr - return x - a - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - - # Add a count for each node to check after we const fold. - for idx, node in enumerate(gm.graph.nodes): - if node.op != "output": - node.meta["meta_idx"] = idx - - # Pre-folding: - # idx 0: placeholder - # idx 1: get_attr (will no longer be used, hence removed) - # idx 2: add (will be folded into a get_attr) - # idx 3: sub - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) - self._verify_const_fold_mod(gm_folded) - - # Post-folding: - # idx 0: placeholder - # idx 2: get_attr (replaced original add; original get_attr was removed) - # idx 3: sub - - # Check the expected indices are still here. - for node in gm_folded.graph.nodes: - if node.op == "placeholder": - self.assertEqual(node.meta["meta_idx"], 0) - elif node.op == "get_attr": - self.assertEqual(node.meta["meta_idx"], 2) - elif node.op == "call_function" and node.target == operator.sub: - self.assertEqual(node.meta["meta_idx"], 3) - else: - self.assertEqual(node.op, "output") - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_has_inlined_call_module_node(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - self.mod = torch.nn.Identity() - self.mod.relu = torch.nn.ReLU() - - def forward(self, x): - a = self.attr + self.attr - return self.mod.relu(x - a) - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_module_attr(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - self.mod = torch.nn.Identity() - self.mod.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.mod.attr - x = x + a - return x + self.mod.attr - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_const_fold_unused_placeholder(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x, y, z): - a = self.const + self.const - return y + a - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x, in_x, in_x) - base_result = mod(in_x, in_x, in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_dict_output(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.const - return {"result": x + a} - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result["result"], base_result["result"])) - - def test_two_outputs(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.const - return x, x + a - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result[0], base_result[0])) - self.assertTrue(torch.equal(fold_result[1], base_result[1])) - - def test_three_outputs(self): - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.const = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.const + self.const - return x, x + a, x + a - - mod = ConstFoldTestModule() - gm_folded = const_fold.split_const_subgraphs(mod) - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result[0], base_result[0])) - self.assertTrue(torch.equal(fold_result[1], base_result[1])) - self.assertTrue(torch.equal(fold_result[2], base_result[2])) - - def test_check_inline_non_const(self): - r""" - Perform constant folding conversion and check that the non-const module is inlined - correctly. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.attr + self.attr - return (x - a * x) / 2 - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) - self._verify_const_fold_mod(gm_folded) - - # Check there are no call modules, because they've been inlined or extracted for - # const folding. - for node in gm_folded.graph.nodes: - self.assertNotEqual(node.op, "call_module") - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_check_inline_non_const_mult_return(self): - r""" - Perform constant folding conversion and check that the non-const module is inlined - correctly. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.nn.Parameter(torch.randn(2, 3)) - - def forward(self, x): - a = self.attr + self.attr - return x - a, x / 2 - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) - self._verify_const_fold_mod(gm_folded) - - # Check there are no call modules, because they've been inlined or extracted for - # const folding. - for node in gm_folded.graph.nodes: - self.assertNotEqual(node.op, "call_module") - - # Now run both folded and non-folded to check results equal. - in_x = torch.randn(2, 3) - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result[0], base_result[0])) - self.assertTrue(torch.equal(fold_result[1], base_result[1])) - - def test_check_skip_folding_quant_dequant_pattern(self): - r""" - Set up skip_folding_quant_dequant function to skip quant/dequant pattern. - This example shows how to use skip_folding_node_fn. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(4, 4)) - self.bias = torch.nn.Parameter(torch.randn(4)) - self.relu = torch.nn.ReLU() - - def forward(self, x): - quant_weight = torch.quantize_per_tensor( - self.weight, 0.5, 3, torch.quint8 - ) - dequant_weight = torch.dequantize(quant_weight) - output = torch.nn.functional.linear(x, dequant_weight, self.bias) - return self.relu(output) - - mod = ConstFoldTestModule() - in_x = torch.randn(2, 4) - gm = pippy.fx.symbolic_trace(mod) - - def skip_folding_quant_dequant(node: pippy.fx.Node): - if node.target != torch.quantize_per_tensor: - return False - # If quantize_per_node -> dequantize, then skip folding. - for user in node.users: - if user.target == torch.dequantize: - return True - return False - - gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( - gm, skip_folding_node_fn=skip_folding_quant_dequant - ) - - # Check that the folded graph module is None, since there was no folding to do. - self.assertTrue(gm_folded.const_subgraph_module is None) - - # Now run both folded and non-folded to check results equal. - fold_result = gm_folded(in_x) - base_result = mod(in_x) - self.assertTrue(torch.equal(fold_result, base_result)) - - def test_fold_module(self): - r""" - Perform constant folding with a call_module node. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin_input = torch.nn.Parameter(torch.randn(4, 4)) - self.lin = torch.nn.Linear(4, 4) - - def forward(self, x): - return self.lin(self.lin_input) + x - - mod = ConstFoldTestModule() - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) - self._verify_const_fold_mod(mod_folded) - - # Now run both folded and non-folded to check results equal. - inp = torch.randn(4, 4) - self.assertTrue(torch.equal(mod_folded(inp), mod(inp))) - - def test_const_fold_tensor_meta(self): - self._test_const_fold_tensor_meta(True) - self._test_const_fold_tensor_meta(False) - - def _test_const_fold_tensor_meta(self, requires_grad): - """ - Verify tensor_meta is handled correctly. - """ - - class ConstFoldTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad) - self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad) - - def forward(self, x, y): - a = self.attr_1 + self.attr_1 - x = x - a - return x * y + self.attr_2 - - mod = ConstFoldTestModule() - gm = pippy.fx.symbolic_trace(mod) - in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) - ShapeProp(gm).propagate(in_x, in_y) - mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( - gm, device_for_folded_attrs="cpu" - ) - self._verify_const_fold_mod(mod_folded) - - mod_folded.run_folding() - - for n in mod_folded.graph.nodes: - if n.op == "get_attr": - attr = self._get_attr(n) - self.assertEquals(_extract_tensor_metadata(attr), n.meta["tensor_meta"]) - - # Now run both folded and non-folded to check results equal. - base_result = mod(in_x, in_y) - fold_result = mod_folded(in_x, in_y) - self.assertTrue(torch.equal(fold_result, base_result)) diff --git a/test/fx/test_fx_param_shape_control_flow.py b/test/fx/test_fx_param_shape_control_flow.py deleted file mode 100644 index 88f19642c..000000000 --- a/test/fx/test_fx_param_shape_control_flow.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import unittest -import torch -import pippy.fx - -from torch.testing._internal.common_utils import TestCase - - -class MyModuleBase(torch.nn.Module): - def forward(self, x): - matrx = self.get_mul_matrix() - if self.no_relu(): - return torch.mm(x, matrx) - else: - return torch.relu(torch.mm(x, matrx)) - - def get_mul_matrix(self): - return self.param - - def no_relu(self): - raise Exception("not implemented") - -class MyModuleParamShape(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.shape[0] < 10 - - -class MyModuleParamSize(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.size()[0] < 10 - - -class MyModuleParamDim(MyModuleBase): - def __init__(self, param): - super().__init__() - self.param = param - - def get_mul_matrix(self): - return self.param[0] if (self.param.dim() == 3) else self.param - - def no_relu(self): - return self.param.dim() == 3 - - -class MyModuleParamNDim(MyModuleBase): - def __init__(self, param): - super().__init__() - self.param = param - - def get_mul_matrix(self): - return self.param[0] if (self.param.ndim == 3) else self.param - - def no_relu(self): - return self.param.ndim == 3 - - -class MyModuleParamNumEl(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.numel() < 10 * 3 - - - -class MyModuleParamNElement(MyModuleBase): - def __init__(self, in_channels): - super().__init__() - self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) - - def no_relu(self): - return self.param.nelement() < 10 * 3 - - - -class TestConstParamShapeInControlFlow(TestCase): - - def verify_mm_relu_mods(self, mm_only_mod, relu_mod): - """ - Verify one module only does a mm op while the other - performs both mm and relu ops in cascade - """ - x = torch.randn(10, 5) - torch.testing.assert_allclose(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())) - tracer = pippy.fx.Tracer(param_shapes_constant=True) - traced_graph = tracer.trace(mm_only_mod) - - # verify the graph module calculates the same result - graph_mod_mm = pippy.fx.GraphModule(mm_only_mod, traced_graph) - torch.testing.assert_allclose(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())) - - - # Make a new module with different parameter shape to go down the different - # code path - x = torch.randn(10, 15) - torch.testing.assert_allclose(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) - - tracer2 = pippy.fx.Tracer(param_shapes_constant=True) - traced_graph2 = tracer2.trace(relu_mod) - - # verify the graph module calculates the same result - graph_mod_relu = pippy.fx.GraphModule(relu_mod, traced_graph2) - torch.testing.assert_allclose(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) - - - graph1_node_targets = [n.target for n in traced_graph.nodes] - graph2_node_targets = [n.target for n in traced_graph2.nodes] - - # the second graph has an exta relu function call node - assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets - assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets - - def test_param_shape_const(self): - mymod = MyModuleParamShape(in_channels=5) - mymod2 = MyModuleParamShape(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_size_const(self): - mymod = MyModuleParamSize(in_channels=5) - mymod2 = MyModuleParamSize(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_dim_const(self): - mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3))) - mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3))) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_ndim_const(self): - mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3))) - mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3))) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_numel_const(self): - mymod = MyModuleParamNumEl(in_channels=5) - mymod2 = MyModuleParamNumEl(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - def test_param_nelement_const(self): - mymod = MyModuleParamNElement(in_channels=5) - mymod2 = MyModuleParamNElement(in_channels=15) - self.verify_mm_relu_mods(mymod, mymod2) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py deleted file mode 100644 index 9f82f0810..000000000 --- a/test/fx/test_gradual_type.py +++ /dev/null @@ -1,1017 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import unittest -import torch -import pippy -import pippy.fx -from pippy.fx import symbolic_trace -from pippy.fx.experimental.unify_refinements import infer_symbolic_types -from pippy.fx.experimental.refinement_types import Equality -from pippy.fx.tensor_type import TensorType, Dyn, is_consistent, is_more_precise -from pippy.fx.annotate import annotate -from pippy.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, broadcast_types, Refine -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx import GraphModule -from pippy.fx.passes.shape_prop import ShapeProp -from torch.testing._internal.common_utils import TestCase - - -try: - import sympy - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") - - -try: - from torchvision.models import resnet50 - - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - -class AnnotationsTest(TestCase): - - def test_annotations(self): - """ - Test type annotations in the forward function. - The annoation should appear in the n.graph - where n is the corresoinding node in the resulting graph. - """ - class M(torch.nn.Module): - def forward(self, - x: TensorType((1, 2, 3, Dyn)), - y: Dyn, - z: TensorType[Dyn, 3, Dyn]): - return torch.add(x, y) + z - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - - expected_ph_types = [TensorType((1, 2, 3, Dyn)), Dyn, TensorType((Dyn, 3, Dyn))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == next(expected_iter) - - def test_annotate(self): - class M(torch.nn.Module): - - def forward(self, x): - y = annotate(x, TensorType((1, 2, 3, Dyn))) - return torch.add(x, y) - - module = M() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(module) - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((1, 2, 3, Dyn)) - - def test_consistency(self): - """ - Test the consistency relation. - """ - self.assertTrue(is_consistent(TensorType((1, 2, 3)), TensorType((1, Dyn, 3)))) - self.assertTrue(is_consistent(int, Dyn)) - self.assertTrue(is_consistent(int, int)) - self.assertFalse(is_consistent(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5)))) - self.assertFalse(is_consistent(TensorType((1, 2, 3)), int)) - - def test_precision(self): - """ - Test the consistency relation. - """ - self.assertTrue(is_more_precise(TensorType((1, 2, 3)), TensorType((1, Dyn, 3)))) - self.assertTrue(is_more_precise(int, Dyn)) - self.assertTrue(is_more_precise(int, int)) - self.assertFalse(is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5)))) - self.assertFalse(is_more_precise(TensorType((1, 2, 3)), int)) - - def test_broadcasting1(self): - t1 = TensorType((1, 2, 3, 4)) - t2 = TensorType((1, 2, 1, 4)) - t3 = TensorType(()) - t4 = TensorType((4, 1)) - t5 = TensorType((4, 4, 4)) - # todo switch all code to use list instead of tuple - t6 = TensorType([1]) - assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4))) - assert broadcast_types(t3, t4) == (t4, t4) - assert broadcast_types(t5, t6) == (t5, t5) - - def test_broadcasting2(self): - t1 = TensorType((2, 3, 4)) - t2 = TensorType((1, 2, 1, 4)) - - assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4))) - - def test_broadcasting3(self): - t1 = TensorType((1, 2, 3, Dyn)) - t2 = TensorType((2, 3, 4)) - assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, Dyn)), TensorType((1, 2, 3, 4))) - -class TypeCheckerTest(TestCase): - - def test_type_check_add_with_broadcast(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - expected_ph_types = [TensorType((1, 2, 3, Dyn)), - TensorType((2, 3, 4)), - TensorType((1, 2, 3, Dyn)), - TensorType((1, 2, 3, Dyn))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - if n.op == 'call_function': - assert n.meta['broadcast'] - assert n.type == next(expected_iter) - - def test_type_check_add_with_scalar(self): - class M(torch.nn.Module): - def forward(self, x: int, y: TensorType((2, 3, 4))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - expected_ph_types = [int, - TensorType((2, 3, 4)), - TensorType((2, 3, 4)), - TensorType((2, 3, 4))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - assert n.type == next(expected_iter) - - def test_type_check_add_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((1, 2, 3))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_add_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, Dyn)), y: TensorType((1, 2, 3))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - expected_ph_types = [TensorType((1, 2, Dyn)), TensorType((1, 2, 3))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == next(expected_iter) - if n.op == 'output': - assert n.type == TensorType((1, 2, Dyn)) - - def test_type_check_reshape_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 6))): - return torch.reshape(x, [1, 2, 3]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((1, 6)) - - if n.op == 'call_function': - assert n.type == TensorType((1, 2, 3)) - - if n.op == 'output': - assert n.type == TensorType((1, 2, 3)) - - def test_type_check_reshape_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 5))): - return torch.reshape(x, [1, 2, 3]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_reshape_dyn_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 5))): - return torch.reshape(x, [1, 2, -1]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_reshape_dyn_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 15))): - return torch.reshape(x, [1, 5, -1]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - def test_type_check_reshape_dyn_true_param_false(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((Dyn, 5))): - return torch.reshape(x, [1, 2, -1]) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_transpose_true(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, 5))): - return torch.transpose(x, 0, 1) - - module = M() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - self.assertTrue(tc.type_check()) - - for n in symbolic_traced.graph.nodes: - if n.op == 'call_function': - assert n.type == TensorType([2, 1, 3, 5]) - if n.op == 'output': - assert n.type == TensorType([2, 1, 3, 5]) - if n.op == 'x': - assert n.placeholder == TensorType([1, 2, 3, 5]) - - def test_type_check_transpose_False(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, 5))): - return torch.transpose(x, 0, 10) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_batch_norm_2D(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: TensorType((2, 2, 5, 4))): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((2, 2, 5, 4)) - if n.op == 'output': - assert n.type == TensorType((2, 2, 5, 4)) - if n.op == 'call_module': - assert n.type == TensorType((2, 2, 5, 4)) - if n.op == 'call_function': - assert n.type == TensorType((2, 2, 5, 4)) - - def test_type_check_batch_norm_2D_false(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: TensorType((2, 2, 5))): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_batch_norm_2D_broadcast(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_function': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'output': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_module': - assert n.type == TensorType((2, 2, Dyn, 4)) - - B = BasicBlock(1, 1) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_conv2D(self): - class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.conv1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_function': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'output': - assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn)) - if n.op == 'call_module': - assert n.type == TensorType((2, 2, Dyn, 4)) - - def test_type_check_conv2D_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - - def forward(self, x: TensorType((5, 2, 3, 4))): - identity = x - out = self.conv1(x) - out += identity - return out - - B = BasicBlock(2, 2) - b = B.forward(torch.rand(5, 2, 3, 4)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - t = TensorType((5, 2, 3, 4)) - for n in graph.nodes: - if n.op == 'placeholder': - assert n.type == t - if n.op == 'call_function': - assert n.type == t - if n.op == 'output': - assert torch.Size(n.type.__args__) == b.shape - if n.op == 'call_module': - assert n.type == t - - B = BasicBlock(1, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - with self.assertRaises(TypeError): - tc.type_check() - - def test_type_check_conv2D_2_fully_static(self): - annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)] - input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (1, 2, 2, 3)] - intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5), - (10, 15, 7, 7), (1, Dyn, Dyn, Dyn)] - in_planes_list = [2, 5, 15, 15, 2] - stride_list = [1, 2, 3, 2, 2] - out_planes_list = [2, 5, 15, 15, 2] - groups_list = [1, 5, 5, 5, 2] - dilation_list = [1, 2, 3, 3, 3] - padding_list = [1, 2, 3, 3, 3] - kernel_size_list = [1, 2, 3, 3, 3] - output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)] - - for i in range(5): - annotation = annotation_list[i] - input = input_list[i] - in_planes = in_planes_list[i] - stride = stride_list[i] - out_planes = out_planes_list[i] - groups = groups_list[i] - dilation = dilation_list[i] - padding = padding_list[i] - kernel_size = kernel_size_list[i] - intermediate_type = intermediate_types[i] - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x): - out = self.conv1(x) - return out - - B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType(annotation) - - b = B.forward(torch.rand(input)) - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'output': - assert is_consistent(n.type, TensorType(b.size())) - - # test with intermediate annotations - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x): - out = self.conv1(x) - return out - - B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # populate our intermediate notes - for n in traced.graph.nodes: - if n.op == 'call_module': - n.type = TensorType(intermediate_type) - - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType(output_types[i]) - assert is_consistent(n.type, TensorType(b.size())) - - def test_typecheck_basicblock(self): - class BasicBlock(torch.nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = torch.nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: TensorType((2, 2, 4, 5))): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - B = BasicBlock(2, 2) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.target == 'output': - assert isinstance(n.type, TensorType) - assert torch.Size(n.type.__args__) == B.forward(torch.rand(2, 2, 4, 5)).size() - - def test_type_check_conv2D_maxpool2d_flatten(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - expected_ph_types = [TensorType((4, 3, 32, 32)), TensorType((4, 6, 28, 28)), - TensorType((4, 6, 14, 14)), TensorType((4, 16, 10, 10)), - TensorType((4, 16, 5, 5)), TensorType((4, 16, 5, 120)), - TensorType((4, 16, 6, 7)), TensorType((4, 672)), TensorType((4, 672))] - - expected_iter = iter(expected_ph_types) - traced.graph.eliminate_dead_code() - - for n in traced.graph.nodes: - assert n.type == next(expected_iter) - - def test_type_check_flatten(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, 5, Dyn))): - return torch.flatten(x, 1, 2) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - for n in symbolic_traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType((1, 6, 5, Dyn)) - - - def test_type_check_flatten_2(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, Dyn, 3, 5, Dyn))): - return torch.flatten(x, 1, 2) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - for n in symbolic_traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType((1, Dyn, 5, Dyn)) - - def test_type_check_flatten3(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((2, 3, 4, 5))): - return torch.flatten(x, start_dim=1, end_dim=3) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - for n in symbolic_traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType((2, 60)) - r = Refine(symbolic_traced) - r.refine() - c = r.constraints - assert c == [Equality(2, 2)] - - def test_type_typechecl_maxpool2d_3dinput(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.pool = torch.nn.MaxPool2d(5, 8) - - def forward(self, x : TensorType((64, 8, 8))): - out = self.pool(x) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.target == 'output': - assert n.type == TensorType((64, 1, 1)) - - def test_type_maxpool2d_fully_static(self): - annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)] - input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (2, 2, 10, 10)] - intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2), - (10, 15, 2, 3), (2, Dyn, Dyn, Dyn)] - stride_list = [1, 2, 3, 2, 1] - dilation_list = [1, 2, 3, 3, 2] - padding_list = [1, 2, 3, 3, 1] - kernel_size_list = [2, 4, 6, 6, 3] - output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2), (10, 15, 2, 3), (2, Dyn, Dyn, 8)] - - for i in range(5): - annotation = annotation_list[i] - input = input_list[i] - stride = stride_list[i] - dilation = dilation_list[i] - padding = padding_list[i] - kernel_size = kernel_size_list[i] - intermediate_type = intermediate_types[i] - - class BasicBlock(torch.nn.Module): - def __init__(self, kernel_size, stride, padding, dilation): - super(BasicBlock, self).__init__() - self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, - padding=padding, dilation=dilation, - return_indices=False, ceil_mode=False) - - def forward(self, x): - out = self.pool(x) - return out - - B = BasicBlock(kernel_size, stride, padding, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType(annotation) - - b = B.forward(torch.rand(input)) - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'output': - assert is_consistent(n.type, TensorType(b.size())) - - # test with intermediate annotations - class BasicBlock(torch.nn.Module): - def __init__(self, kernel_size, stride, padding, dilation): - super(BasicBlock, self).__init__() - self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride, - padding=padding, dilation=dilation, - return_indices=False, ceil_mode=False) - - def forward(self, x): - out = self.pool(x) - return out - - B = BasicBlock(kernel_size, stride, padding, dilation) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType(annotation) - - # populate our intermediate notes - for n in traced.graph.nodes: - if n.op == 'call_module': - n.type = TensorType(intermediate_type) - - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in traced.graph.nodes: - if n.op == 'output': - assert n.type == TensorType(output_types[i]) - assert is_consistent(n.type, TensorType(b.size())) - - def test_flatten_fully_static(self): - annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)), - TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))] - input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (2, 2, 10, 10)] - - intermediate_list = [Dyn, (2, 5, 6, 9), (10, 15, 13, 14), - (10, 15, 13, 14), (2, 2, 10, 10)] - - start_dim = [1, 2, 1, 2, 0] - end_dim = [1, 3, 3, 3, -2] - - for i in range(5): - annotation = annotation_list[i] - input = input_list[i] - # intermediate_type = intermediate_list[i] - - class BasicBlock(torch.nn.Module): - def __init__(self, start, end): - super(BasicBlock, self).__init__() - self.start = start - self.end = end - - def forward(self, x): - out = torch.flatten(x, self.start, self.end) - return out - - B = BasicBlock(start_dim[i], end_dim[i]) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # annotate our argument - for n in graph.nodes: - if n.op == 'placeholder': - n.type = annotation - - b = B.forward(torch.rand(input)) - tc = GraphTypeChecker({}, traced) - tc.type_check() - - for n in graph.nodes: - if n.op == 'output': - assert is_consistent(n.type, TensorType(b.size())) - - @skipIfNoSympy - @skipIfNoTorchVision - def test_resnet50(self): - gm_run = symbolic_trace(resnet50()) - sample_input = torch.randn(1, 3, 224, 224) - - # run our nodes - ShapeProp(gm_run).propagate(sample_input) - - gm_static = symbolic_trace(resnet50()) - - for n in gm_static.graph.nodes: - n.type = None - - g = GraphTypeChecker({}, gm_static) - g.type_check() - gm_static.graph.eliminate_dead_code() - gm_run.graph.eliminate_dead_code() - # here we are checking for consistency with fully dynamic nodes - for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes): - assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape)) - - # here we give the same input as to runtume - gm_static_with_types = symbolic_trace(resnet50()) - - # we initialize our placeholder - for n in gm_static_with_types.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType((1, 3, 224, 224)) - - g = GraphTypeChecker({}, gm_static_with_types) - g.type_check() - for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes): - assert n1.type == TensorType(n2.meta['tensor_meta'].shape) - - # apply shape inference to graph and check - # that the batch size is equal across all layers - infer_symbolic_types(gm_static) - - - batch_sizes = set() - gm_static.graph.eliminate_dead_code() - for n in gm_static.graph.nodes: - assert isinstance(n.type, TensorType) - batch_sizes.add(n.type.__args__[0]) - assert (len(batch_sizes) == 1) - - @skipIfNoSympy - def test_type_check_batch_norm_symbolic(self): - class BasicBlock(torch.nn.Module): - - def __init__(self, inplanes, planes): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.bn1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - - infer_symbolic_types(traced) - - my_types = iter([TensorType[(2, 2, sympy.symbols('~7'), 4)], - TensorType[(2, 2, sympy.symbols('~7'), 4)], - TensorType[(2, 2, sympy.symbols('~7'), 4)], - TensorType[(2, 2, sympy.symbols('~7'), 4)]]) - - for n in graph.nodes: - assert n.type == next(my_types) - - @skipIfNoSympy - def test_symbolic_add_with_broadcast(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - infer_symbolic_types(symbolic_traced) - r = Refine(symbolic_traced) - r.refine() - - assert r.constraints == [Equality(1, 1), Equality(2, 2), Equality(3, 3)] - # note that there is no equality constraint between dyn and 4 because - # dyn could be 4 or 1 - - infer_symbolic_types(symbolic_traced) - - expected_ph_types = [TensorType((1, 2, 3, sympy.symbols('~0'))), - TensorType((2, 3, 4)), - TensorType((1, 2, 3, sympy.symbols('~1'))), - TensorType((1, 2, 3, sympy.symbols('~1')))] - expected_iter = iter(expected_ph_types) - - - for n in symbolic_traced.graph.nodes: - assert n.type == next(expected_iter) - - @skipIfNoSympy - def test_symbolic_add_with_broadcast_2(self): - class M(torch.nn.Module): - def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))): - return torch.add(x, y) - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - tc = GraphTypeChecker({}, symbolic_traced) - tc.type_check() - infer_symbolic_types(symbolic_traced) - r = Refine(symbolic_traced) - r.refine() - - expected_ph_types = [TensorType((1, 2)), - TensorType((sympy.symbols('~1'), 2)), - TensorType((sympy.symbols('~1'), 2)), - TensorType((sympy.symbols('~1'), 2))] - expected_iter = iter(expected_ph_types) - - for n in symbolic_traced.graph.nodes: - assert n.type == next(expected_iter) - - @skipIfNoSympy - def test_type_check_conv2D_types(self): - class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1): - super(BasicBlock, self).__init__() - norm_layer = torch.nn.BatchNorm2d - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - - def forward(self, x: Dyn): - identity = x - out: TensorType((2, 2, Dyn, 4)) = self.conv1(x) - out += identity - return out - - B = BasicBlock(2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - tc = GraphTypeChecker({}, traced) - tc.type_check() - infer_symbolic_types(traced) - - for n in traced.graph.nodes: - if n.op == 'call_module': - assert isinstance(n.type.__args__[2], sympy.floor) - assert isinstance(n.type.__args__[3], sympy.floor) - - @skipIfNoSympy - def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, Dyn, Dyn))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - traced = symbolic_trace(B) - tc = GraphTypeChecker({}, traced) - tc.type_check() - infer_symbolic_types(traced) - - for n in traced.graph.nodes: - if n.target == 'conv1': - assert n.type == TensorType((4, 6, sympy.floor((sympy.symbols('~0') - 4)), - sympy.floor((sympy.symbols('~1') - 4)))) - - elif n.target == 'conv2': - assert n.type == TensorType((4, 16, sympy.floor((sympy.symbols('~4') - 4)), - sympy.floor((sympy.symbols('~5') - 4)))) - -if __name__ == '__main__': - unittest.main() diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py deleted file mode 100644 index e41aaf0a6..000000000 --- a/test/fx/test_pass_infra.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import torch -from torch.testing._internal.common_utils import TestCase - -import pippy -import pippy.fx as fx -from pippy.fx.passes.infra.pass_base import PassResult -from pippy.fx.passes.infra.pass_manager import ( - PassManager, - this_before_that_pass_constraint, - _topological_sort_passes, -) - - -def replace_add_with_mul_pass(gm): - modified = False - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.add: - node.target = torch.mul - modified = True - return PassResult(gm, modified) - -def replace_mul_with_div_pass(gm): - modified = False - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.mul: - node.target = torch.div - modified = True - return PassResult(gm, modified) - -class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - y = torch.add(x, x) - z = torch.add(y, x) - return z - - -class TestPassManager(TestCase): - def test_pass_manager(self): - """ - Tests that the pass manager runs the passes correctly. - """ - - m = AddModule() - traced_m = pippy.fx.symbolic_trace(m) - pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass], steps=5) - - pm.validate_constraints() - self.assertEqual(len(pm.passes), 2) - - res = pm(traced_m) - modified_m = res.graph_module - assert isinstance(modified_m, fx.GraphModule) - - # Check that all call_function nodes are divs - for node in modified_m.graph.nodes: - if node.op == "call_function": - self.assertEqual(node.target, torch.div) - - def test_this_before_that_pass_constraint(self): - """ - Tests the construction of constraints - """ - passes = [lambda x: 2 * x for _ in range(10)] - pm = PassManager(passes) - - # add unfulfillable constraint - pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) - - with self.assertRaises(RuntimeError): - pm.validate_constraints() - - - def test_pass_manager_checks(self): - """ - Tests that users can add in check functions correctly - """ - m = AddModule() - traced_m = fx.symbolic_trace(m) - pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass]) - - def check_div_target(graph_module): - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target != torch.div: - raise ValueError("Target should be div!") - pm.add_checks(check_div_target) - - with self.assertRaises(ValueError): - pm(traced_m) - - def test_pass_manager_bad_checks(self): - """ - Checks that we error if we pass in a check function with the wrong parameters - """ - def check_bad_args(graph_module, i): - pass - - pm = PassManager() - self.assertRaises(TypeError, pm.add_checks, check_bad_args) - - def test_topological_sort(self): - """ - Tests that passes are correctly ordered based on contraints. - """ - - def pass0(x): - return x - - def pass1(x): - return x + 1 - - def pass2(x): - return x + 2 - - def pass3(x): - return x + 3 - - def pass4(x): - return x + 4 - - def pass5(x): - return x + 5 - - # Not passing any constraints should keep the original order - passes = [pass0, pass1, pass2, pass3, pass4, pass5] - sorted = _topological_sort_passes(passes, []) - self.assertEqual(sorted, passes) - - # Graph that we are constructing: - # 5 ----> 0 <---- 4 - # | | - # +-> 2 -> 3 -> 1 <-+ - # Which has a possible topological order of: [4, 5, 0, 2, 3, 1] - passes = [pass0, pass1, pass2, pass3, pass4, pass5] - constraints = [ - this_before_that_pass_constraint(pass5, pass0), - this_before_that_pass_constraint(pass5, pass2), - this_before_that_pass_constraint(pass4, pass0), - this_before_that_pass_constraint(pass4, pass1), - this_before_that_pass_constraint(pass2, pass3), - this_before_that_pass_constraint(pass3, pass1), - ] - sorted = _topological_sort_passes(passes, constraints) - self.assertEqual(sorted, [pass4, pass5, pass0, pass2, pass3, pass1]) - - # Circular dependency should result in the circular_dep flag being set - passes = [pass0, pass1, pass2] - constraints = [ - this_before_that_pass_constraint(passes[0], passes[1]), - this_before_that_pass_constraint(passes[1], passes[2]), - this_before_that_pass_constraint(passes[2], passes[0]), - ] - with self.assertRaises(RuntimeError) as e: - _topological_sort_passes(passes, constraints) - expected_error_msg = f"Circular dependency detected within the following passes: {passes}" - self.assertEqual(e.exception.args[0], expected_error_msg) - - def test_pass_manager_error(self): - """ - Tests error catching + debug - """ - def pass_fail(graph_module): - raise RuntimeError("bad") - - m = AddModule() - traced_m = pippy.fx.symbolic_trace(m) - pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass, pass_fail]) - - # Comment out this line to see the actual error message - with self.assertRaises(RuntimeError): - pm(traced_m) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py deleted file mode 100644 index d8a1bc77d..000000000 --- a/test/fx/test_subgraph_rewriter.py +++ /dev/null @@ -1,777 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import os -import sys - -import torch - -import pippy -from pippy.fx import symbolic_trace, subgraph_rewriter -from pippy.fx.annotate import annotate -# Make the helper files in test/ importable -from pippy.fx.experimental.rewriter import RewritingTracer - -pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -sys.path.append(pytorch_test_dir) -from torch.testing._internal.jit_utils import JitTestCase - -if __name__ == '__main__': - raise RuntimeError("This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_fx.py TESTNAME\n\n" - "instead.") - -@pippy.fx.wrap -def wrapped_gemm_bias_mul(a, b, bias): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = lin_res * a - return lin_res, mul_res - -@pippy.fx.wrap -def wrapped_gemm_bias_mul_with_c(a, b, bias, c): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = lin_res * c - return lin_res, mul_res - -class TestSubgraphRewriter(JitTestCase): - - def test_subgraph_rewriter_preserves_logic(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - def pattern(x): - return torch.neg(x) + torch.relu(x) - - def comparison(x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - - # Replace `pattern` with the same pattern (shouldn't change - # the underlying logic) - subgraph_rewriter.replace_pattern(traced, pattern, pattern) - - traced.graph.lint() - - ref_output = comparison_fn(x) - test_output = traced.forward(x) - self.assertEqual(ref_output, test_output) - - def test_subgraph_rewriter_with_oneliner_pattern(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) - return torch.add(val, val) - - def pattern(x): - return torch.neg(x) - - def replacement(x): - return torch.relu(x) - - def comparison(x): - val = torch.relu(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_output = comparison_fn(x) - test_output = traced.forward(x) - self.assertEqual(ref_output, test_output) - - def test_subgraph_rewriter_single_pattern_match(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - def pattern(x): - return torch.neg(x) + torch.relu(x) - - def replacement(x): - return torch.relu(x) - - def comparison(x): - val = torch.relu(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_output = comparison_fn(x) - test_output = traced.forward(x) - self.assertEqual(ref_output, test_output) - - def test_subgraph_rewriter_multiple_pattern_match(self): - class M(torch.nn.Module): - def forward(self, x, w1, w2): - m1 = torch.cat([w1, w2]).sum() - m2 = torch.cat([w1, w2]).sum() - return x + torch.max(m1) + torch.max(m2) - - def pattern(w1, w2): - return torch.cat([w1, w2]).sum() - - def replacement(w1, w2): - return torch.stack([w1, w2]) - - def comparison(x, w1, w2): - m1 = torch.stack([w1, w2]) - m2 = torch.stack([w1, w2]) - return x + torch.max(m1) + torch.max(m2) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.rand(1, 3) - w1 = torch.rand(1, 3) - w2 = torch.rand(1, 3) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x, w1, w2) - test_outs = traced.forward(x, w1, w2) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_graph_argument_order(self): - class M(torch.nn.Module): - def forward(self, x, y): - return torch.mm(x, y) - - def pattern(x, y): - return torch.mm(x, y) - - def comparison(x, y): - return torch.mm(x, y) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - y = torch.randn(4, 5) - - subgraph_rewriter.replace_pattern(traced, pattern, pattern) - - traced.graph.lint() - - ref_outs = comparison_fn(x, y) - test_outs = traced.forward(x, y) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_correct_output_replacement(self): - class M(torch.nn.Module): - def forward(self, x, y): - val = torch.neg(y) + torch.relu(x) - return torch.add(val, val) - - def pattern(x): - return torch.relu(x) - - def replacement(x): - return torch.neg(x) - - def comparison(x, y): - val = torch.neg(y) + torch.neg(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(4, 4) - y = torch.randn(4, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x, y) - test_outs = traced.forward(x, y) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_traced_as_callable(self): - class M(torch.nn.Module): - def forward(self, x): - val = torch.neg(x) + torch.relu(x) - return torch.add(val, val) - - class Pattern(torch.nn.Module): - def forward(self, x): - return torch.neg(x) + torch.relu(x) - - class Replacement(torch.nn.Module): - def forward(self, x): - return torch.sigmoid(x) - - def comparison(x): - val = torch.sigmoid(x) - return torch.add(val, val) - - traced = symbolic_trace(M()) - traced_pattern = symbolic_trace(Pattern()) - traced_replacement = symbolic_trace(Replacement()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_pattern_is_entire_graph(self): - class M(torch.nn.Module): - def forward(self, x): - a = torch.neg(x) - return torch.add(a, a) - - def pattern(x): - a = torch.neg(x) - return torch.add(a, a) - - def replacement(x): - a = torch.sigmoid(x) - return torch.cat([a, a]) - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(replacement) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self): - class M(torch.nn.Module): - def forward(self, x): - y = torch.relu(x) - return torch.neg(y) - y - - def pattern(x): - return torch.relu(x) - - def replacement(x): - return torch.sigmoid(x) - - def comparison(x): - y = torch.sigmoid(x) - return torch.neg(y) - y - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self): - class M(torch.nn.Module): - def forward(self, x, w1, w2, b1, b2): - m0 = torch.cat([w1, w2]) - m1 = torch.cat([w1, w2]) - m2 = torch.cat([x, b2]) - t0 = torch.addmm(b1, m1, m2.t()) - t1 = torch.sum(w1, 1) - t2 = torch.addmm(b1, m1, m2.t()) - return torch.sum(t1), torch.sum(t2) - - def pattern(x, w1, w2, b1, b2): - m1 = torch.cat([w1, w2]) - m2 = torch.cat([x, b2]) - return torch.addmm(b1, m1, m2.t()) - - def replacement(x, w1, w2, b1, b2): - return torch.cat([x, w1, w2]) - - traced = symbolic_trace(M()) - - # Result should be [] since no matches can be found - res = subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - self.assertEqual(res, []) - - def test_subgraph_rewriter_placeholder_matching(self): - """ - This tests that a placeholder Node can be matched to a Node with - a different number of input Nodes. In the example below, the - original traced Module looks like this: - - opcode target args kwargs - ------------- ---------------------------------------------------------- ------------------------ -------- - placeholder x () {} - call_function (x, 3) {} - call_method dequantize (add,) {} - call_function (dequantize,) {} - call_method to (sigmoid, torch.float16) {} - output output (to,) {} - - while the pattern we want to match looks like this: - - opcode target args kwargs - ------------- ---------------------------------------------------------- ------------------------ -------- - placeholder x () {} - call_method dequantize (x,) {} - call_function (dequantize,) {} - call_method to (sigmoid, torch.float16) {} - output output (to,) {} - - Here, we want to be able to match the original graph's - `call_function.add` Node with the pattern graph's - `plaeholder.x` Node. - - Credit to Jerry Zhang (GitHub: jerryzh168) for this test case - """ - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.dtype = torch.float16 - - def forward(self, x): - x += 3 - x = x.dequantize() - x = torch.sigmoid(x) - dtype = self.dtype - x = x.to(dtype) - return x - - def pattern(x): - x = x.dequantize() - x = torch.sigmoid(x) - x = x.to(torch.float16) - return x - - def replacement(x): - return x - - def comparison(x): - return x + 3 - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_replaces_referenced_submodules(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - self.submod = torch.nn.ReLU() - - def forward(self, x): - x = x + 1 - return self.submod(self.sigmoid(x)) - - class Pattern(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - self.submod = torch.nn.ReLU() - - def forward(self, x): - return self.submod(self.sigmoid(x)) - - class Replacement(torch.nn.Module): - def __init__(self): - super().__init__() - self.tanh = torch.nn.Tanh() - self.submod = torch.nn.ReLU() - - def forward(self, x): - return self.submod(self.tanh(x)) - - class Comparison(torch.nn.Module): - def __init__(self): - super().__init__() - self.tanh = torch.nn.Tanh() - self.submod = torch.nn.ReLU() - - def forward(self, x): - x = x + 1 - return self.submod(self.tanh(x)) - - traced = symbolic_trace(M()) - comparison = Comparison() - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement()) - - traced.graph.lint() - - ref_outs = comparison(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - traced.get_submodule("tanh") - with self.assertRaisesRegex(AttributeError, "has no attribute"): - traced.get_submodule("sigmoid") - - submod = traced.get_submodule("submod") - self.assertEqual(type(submod), torch.nn.ReLU) - - def test_subgraph_rewriter_annotations_int(self): - - class M1(torch.nn.Module): - def forward(self, x): - y: int = x - return torch.add(x, y) - - class M2(torch.nn.Module): - def forward(self, x): - y = annotate(x, int) - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M1()) - - module = M2() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): - if n.op == 'placeholder': - assert n.type == int - assert m.type == int - - def test_subgraph_rewriter_replace_consecutive_submodules(self): - - def f(x): - x = torch.sigmoid(x) - x = torch.sigmoid(x) - return torch.sigmoid(x) - - def pattern(x): - return torch.sigmoid(x) - - def replacement(x): - return torch.exp(x) - - def comparison(x): - x = torch.exp(x) - x = torch.exp(x) - return torch.exp(x) - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_with_overlapping_matches(self): - - def f(x): - x = torch.sigmoid(x) - x = torch.sigmoid(x) - x = torch.sigmoid(x) - return torch.sigmoid(x) - - def pattern(x): - x = torch.sigmoid(x) - x = torch.sigmoid(x) - return x - - def replacement(x): - return torch.neg(x) - - def comparison(x): - x = torch.neg(x) - return torch.neg(x) - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_replace_with_multiple_outputs(self): - - def f(x): - y = torch.sigmoid(x) - z = torch.relu(x) - return y + z - - def pattern(a): - b = torch.sigmoid(a) - c = torch.relu(a) - return b, c - - def replacement(x): - return torch.exp(x), torch.abs(x) - - def comparison(x): - y = torch.exp(x) - z = torch.abs(x) - return y + z - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x) - test_outs = traced.forward(x) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_replace_with_duplicated_outputs(self): - - def f(x1, x2): - x = x1 - x2 - y = torch.sigmoid(x) - z = torch.relu(x) - return y + z - - def pattern(a1, a2): - a = a1 - a2 - b = torch.sigmoid(a) - c = torch.relu(a) - return b, c, a - - def replacement(x1, x2): - y1 = torch.exp(x1) - y2 = torch.abs(x2) - return y2, y2, y1 - - def comparison(x1, x2): - y2 = torch.abs(x2) - return y2 + y2 - - traced = symbolic_trace(f) - comparison_fn = symbolic_trace(comparison) - - x1 = torch.randn(3, 4) - x2 = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x1, x2) - test_outs = traced.forward(x1, x2) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_with_unused_args(self): - class M(torch.nn.Module): - def forward(self, x, y, z): - return x + y - - def pattern(x, y): - return x + y - - def replacement(x, y): - return x - y - - def comparison(x1, x2, x3): - return x1 - x2 - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(comparison) - - x1 = torch.randn(3, 4) - x2 = torch.randn(3, 4) - x3 = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - placeholder_nodes = [n for n in traced.graph.nodes if n.op == "placeholder"] - assert len(placeholder_nodes) == 3 - - ref_outs = comparison_fn(x1, x2, x3) - test_outs = traced.forward(x1, x2, x3) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_call_method(self): - - class M(torch.nn.Module): - def forward(self, x): - x = x.dequantize() - x = x.sigmoid() - x = x.to(torch.float16) - return x - - def pattern(x): - x = x.dequantize() - x = x.sigmoid() - x = x.to(torch.float16) - return x - - def replacement(x): - return x - - traced = symbolic_trace(M()) - comparison_fn = symbolic_trace(replacement) - - x1 = torch.randn(3, 4) - - subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - traced.graph.lint() - - ref_outs = comparison_fn(x1) - test_outs = traced.forward(x1) - self.assertEqual(ref_outs, test_outs) - - def test_subgraph_rewriter_nodes_with_kwargs(self): - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.w0 = torch.nn.Parameter(torch.empty([128, 128])) - self.b0 = torch.nn.Parameter(torch.empty([128])) - - def forward(self, in0): - lin_res = torch.nn.functional.linear(in0, self.w0, bias=self.b0) - mul_res = in0 * lin_res - sum_res = mul_res + in0 - return sum_res - - def pattern(a, b, bias): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = a * lin_res - return lin_res, mul_res - - def replacement(a, b, bias): - lin_res, mul_res = wrapped_gemm_bias_mul(a, b, bias) - return lin_res, mul_res - - traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement) - - self.assertEqual(len(matches), 1) - - found_repalcement_node = False - for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul: - found_repalcement_node = True - break - - self.assertTrue(found_repalcement_node) - - def test_subgraph_rewriter_local_revert(self): - - # Following model will have 3 anchors as the matching candidate with the given pattern - # Anchor 1 and 3 is a real match, but anchor 2 is not. - # The subgraph rewriter should be able to revert the changes made while matching anchor 2. - # Final match with anchor 3 should be successful. - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.w0 = torch.nn.Parameter(torch.empty([128, 128])) - self.b0 = torch.nn.Parameter(torch.empty([128])) - self.w1 = torch.nn.Parameter(torch.empty([128, 128])) - self.b1 = torch.nn.Parameter(torch.empty([128])) - self.w2 = torch.nn.Parameter(torch.empty([128, 128])) - self.b2 = torch.nn.Parameter(torch.empty([128])) - self.w3 = torch.nn.Parameter(torch.empty([128, 128])) - self.b3 = torch.nn.Parameter(torch.empty([128])) - self.w4 = torch.nn.Parameter(torch.empty([128, 128])) - self.b4 = torch.nn.Parameter(torch.empty([128])) - - def forward(self, in0, in1): - lin_res_1 = torch.nn.functional.linear(in1, self.w0, bias=self.b0) - lin_res_2 = torch.nn.functional.linear(lin_res_1, self.w1, bias=self.b1) - # potential match at anchor 1 - mul_res_1 = in1 * lin_res_2 - sum_res_1 = mul_res_1 + in1 - lin_res_3 = torch.nn.functional.linear( - sum_res_1, self.w2, bias=self.b2 - ) - sigmoid_res_1 = torch.sigmoid(lin_res_3) - # potential match at anchor 2 - mul_res_2 = lin_res_3 * sigmoid_res_1 - lin_res_4 = torch.nn.functional.linear(in0, self.w3, bias=self.b3) - lin_res_5 = torch.nn.functional.linear(lin_res_4, self.w4, bias=self.b4) - # potential match at anchor 3 - mul_res_3 = in0 * lin_res_5 - sum_res_2 = mul_res_3 + in0 - cat_res = torch.cat( - [mul_res_2, sum_res_2], - dim=1, - ) - return cat_res - - def gemm_bias_mul_pattern_with_c(a, b, bias, c): - lin_res = torch.nn.functional.linear(a, b, bias=bias) - mul_res = c * lin_res - return lin_res, mul_res - - def gemm_bias_mul_replacement_with_c(a, b, bias, c): - lin_res, mul_res = wrapped_gemm_bias_mul_with_c(a, b, bias, c) - return lin_res, mul_res - - traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern( - traced, - gemm_bias_mul_pattern_with_c, - gemm_bias_mul_replacement_with_c) - - self.assertEqual(len(matches), 2) - - repalcement_node_found = 0 - for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul_with_c: - repalcement_node_found += 1 - - self.assertEqual(repalcement_node_found, 2) diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py deleted file mode 100644 index 567a31dfe..000000000 --- a/test/fx/test_z3_gradual_types.py +++ /dev/null @@ -1,2481 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] -import operator -import unittest -from pippy.fx import GraphModule, symbolic_trace -from pippy.fx.experimental.meta_tracer import symbolic_trace as meta_symbolic_trace -from pippy.fx.experimental.migrate_gradual_types.constraint import BinConstraintT, DVar, TVar, T -from pippy.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator -from pippy.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint -from pippy.fx.experimental.migrate_gradual_types.operation import op_precision, op_matching, op_consistency -from pippy.fx.experimental.migrate_gradual_types.transform_to_z3 import transform_all_constraints,\ - evaluate_conditional_with_constraints -from pippy.fx.experimental.migrate_gradual_types.z3_types import tensor_type, D, z3_dyn -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx.tensor_type import Dyn, TensorType -import torch - - -try: - import z3 # type: ignore[import] - HAS_Z3 = True -except ImportError: - HAS_Z3 = False - - -try: - from torchvision import models - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") - -class TorchDynamoUseCases(unittest.TestCase): - - def test_dim(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: TensorType([1, 2])): - y = x.dim() - return y - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - y_res = z3.z3.Int(2) - self.assertEqual(s.model()[y_res], 2) - - - def test_reshape(self): - """ - In this example, we prove that some nodes must - always have a fixed shape regardless of the input - """ - - class BasicBlock(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Dyn): - y = x.view(100) - tmp = y.size()[0] - return tmp - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - dim = z3.Int(4) - self.assertEqual(s.model()[dim], 100) - # print(s.model()[dim]) - - -class HFOperations(unittest.TestCase): - - def test_eq_dim(self): - """ - test dimensions and equalities - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([32, 4, 4])): - eq = x.dim() == 3 - return eq - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.eq: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.unsat) - - def test_conditional_ne_1(self): - """ - This test case is for the HFmodels interface. - A function takes a node and a graph and considers - the conditional the node represents and its negation - and solves each formula with the remaining sets of constraints - Returns: - - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([32, 4, 4]), y: TensorType([32, 4, 4])): - size_5 = x.size() - getitem_7 = size_5[0] - getitem_8 = size_5[1] - getitem_9 = size_5[2] - ne_1 = y != (getitem_7, getitem_8, getitem_9) - return ne_1 - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.ne: - node = n - - # since x and y are equal, the requirement that x != y cannot be true, so we should get unsat - # for the positive condition and sat for the negative condition - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.unsat) - self.assertEqual(negative, z3.sat) - - def test_bmm(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 2, 3]), y: TensorType([1, 3, 2])): - bmm = torch.bmm(x, y) - return bmm - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3), torch.rand(1, 3, 2)) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - - output = z3.Const(3, tensor_type) - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), b.shape[1]) - self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) - - - def test_bmm2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: TensorType([1, 3, 2])): - bmm = torch.bmm(x, y) - return bmm - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3), torch.rand(1, 3, 2)) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - - output = z3.Const(3, tensor_type) - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(0), 0) - self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) - - def test_bmm3(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 3, 3]), y: TensorType([1, 3, 2])): - bmm = torch.bmm(x, y) - return bmm - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.unsat) - - - def test_transpose(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([1, 2, 3, 4])): - transpose = x.transpose(0, 1) - return transpose - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3, 4)) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - output = z3.Const(2, tensor_type) - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[output].arg(0).arg(1), b.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), b.shape[1]) - self.assertEqual(s.model()[output].arg(2).arg(1), b.shape[2]) - self.assertEqual(s.model()[output].arg(3).arg(1), b.shape[3]) - - # change the annotation to Dyn - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - - def test_index_select(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2050, 1024]), y: Dyn): - index_select = x.index_select(0, y) - return index_select - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - # print(symbolic_traced) - b = BasicBlock().forward(torch.rand(2050, 1024), torch.ones(8).int()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - index_select = z3.Const(3, tensor_type) - - # the second dimension of the result should not be affected since - # the index is 0 - self.assertEqual(s.model()[index_select].arg(1).arg(1), b.shape[1]) - - replacement_vector = z3.Const(2, tensor_type) - - # we set the vector to Dyn - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - index_select = z3.Const(3, tensor_type) - s.add(replacement_vector == z3_dyn) - self.assertEqual(s.check(), z3.sat) - - # this implies that the index at 0 should be dyn - self.assertEqual(s.model()[index_select].arg(0).arg(0), 0) - - def test_get_attr(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([1, 2, 3])): - getattr = x.device - to = x.to(getattr) - return to - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - b = BasicBlock().forward(torch.rand(1, 2, 3)) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - attr_res = z3.Const(3, tensor_type) - assert s.model()[attr_res].arg(0).arg(1) == b.shape[0] - assert s.model()[attr_res].arg(1).arg(1) == b.shape[1] - assert s.model()[attr_res].arg(2).arg(1) == b.shape[2] - - - def test_expand(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([1, 4])): - size = x.size() - getitem = size[-1] - expand = x.expand(getitem, 4) - return expand - - b = BasicBlock().forward(torch.rand(1, 4)) - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - expand_res = z3.Const(4, tensor_type) - assert s.model()[expand_res].arg(0).arg(1) == b.shape[0] - assert s.model()[expand_res].arg(1).arg(1) == b.shape[1] - - # change the annotation on the input to Dyn. - # the last dimension should still be 4 - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - assert s.model()[expand_res].arg(1).arg(1) == b.shape[1] - - def test_getitem_tensor(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([4, 4])): - getitem = x[(None, None, slice(None, None, None), slice(None, None, None))] - return getitem - - B = BasicBlock() - b = B.forward(torch.rand(4, 4)) - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(B) - transformed = transform_all_constraints(symbolic_traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - get_item_res = z3.Const(2, tensor_type) - assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] - assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] - assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] - assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] - - # change the annotation on the input to make sure it propagates - # to the output - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, 4]) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - # dyn check - assert s.model()[get_item_res].arg(2).arg(0) == 0 - - - def test_getitem_tensor2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([4, 4])): - getitem = x[(None, None)] - return getitem - - B = BasicBlock() - b = B.forward(torch.rand(4, 4)) - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(B) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - get_item_res = z3.Const(2, tensor_type) - assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] - assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] - assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] - assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] - - - def test_getitem_tensor_3(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([4, 4])): - getitem = x[(None, slice(None, None, None), None, slice(None, None, None))] - return getitem - - B = BasicBlock() - b = B.forward(torch.rand(4, 4)) - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(B) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - get_item_res = z3.Const(2, tensor_type) - assert s.model()[get_item_res].arg(0).arg(1) == b.shape[0] - assert s.model()[get_item_res].arg(1).arg(1) == b.shape[1] - assert s.model()[get_item_res].arg(2).arg(1) == b.shape[2] - assert s.model()[get_item_res].arg(3).arg(1) == b.shape[3] - - - - def test_layer_norm(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.l = torch.nn.LayerNorm((1024,)) - - def forward(self, x: Dyn): - return self.l(x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # make the output a size 1 tensor which should result - # in the migration of the input - - b = BasicBlock().forward(torch.rand(1024)) - input = z3.Const(1, tensor_type) - output = z3.Const(2, tensor_type) - s.add(output == tensor_type.tensor1(D(1, 1024))) - s.check() - self.assertEqual(s.model()[input], s.model()[output]) - # input shape = output shape - self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1)) - - # change annotation to the wrong shape - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([10, 10]) - - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.unsat) - - # fix the annotation - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([10, 1024]) - - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - s.check() - b = BasicBlock().forward(torch.rand(10, 1024)).shape - self.assertEqual(s.model()[output].arg(0).arg(1), b[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), b[1]) - - - def test_layer_norm_functional(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - return torch.nn.functional.layer_norm(x, (1024,)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # make the output a size 1 tensor which should result - # in the migration of the input - - b = BasicBlock().forward(torch.rand(1024)) - input = z3.Const(1, tensor_type) - output = z3.Const(2, tensor_type) - s.add(output == tensor_type.tensor1(D(1, 1024))) - s.check() - self.assertEqual(s.model()[input], s.model()[output]) - # input shape = output shape - self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1)) - - def test_ne_int_long_type_as(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn, Dyn])): - ne_int = torch.ne(x, y).int() - type_as = ne_int.type_as(y) - long = type_as.long() - return long - - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(BasicBlock()) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # migrate one of the parameters to a fully static shape so we can compare - - input = z3.Const(1, tensor_type) - input_2 = z3.Const(2, tensor_type) - s1, s2 = z3.Ints('s1 s2') - - output_long = z3.Const(8, tensor_type) - s.add(input == tensor_type.tensor2(D(1, 2), D(1, 4))) - s.add(input_2 == tensor_type.tensor2(D(1, s1), D(1, s2))) - - self.assertEquals(s.check(), z3.sat) - actual_shape = BasicBlock().forward(torch.rand(2, 4), torch.rand(2, 4)).shape - self.assertEqual(s.model()[output_long].arg(0).arg(1), actual_shape[0]) - self.assertEqual(s.model()[output_long].arg(1).arg(1), actual_shape[1]) - - - def test_ne(self): - s1, s2 = z3.Ints('s1 s2') - s11, s22 = z3.Ints('s11 s22') - d1, d2 = D(s11, s1), D(0, s2) - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.ne(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # change the annotations - for n in graph.nodes: - if n.name == 'x': - n.type = TensorType([1, 2]) - if n.name == 'y': - n.type = TensorType([2, Dyn]) - - # resulting type should be TensorType([2, 2]) - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # force the second dimension to be Dyn - # output should still be TensorType([2, 2]) - input = z3.Const(2, tensor_type) - s.add(input == tensor_type.tensor2(d1, d2)) - self.assertEqual(s.check(), z3.sat) - B = BasicBlock().forward(torch.rand(1, 2), torch.rand(2, 1)) - output = z3.Const(3, tensor_type) - self.assertEqual(s.model()[output].arg(0).arg(1), B.shape[0]) - self.assertEqual(s.model()[output].arg(1).arg(1), B.shape[0]) - - - def test_cumsum(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 4, 3])): - t = torch.cumsum(x, 3) - return t - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - - # should be unsat since the index is not valid for this annotation - self.assertEqual(s.check(), z3.unsat) - - # modify the annotation to Dyn which should give sat - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - # # modify the annotation to the right tensor size - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([1, 2, 3, 4]) - - # verify that the input is equal to the output - B = BasicBlock().forward(torch.rand(1, 2, 3, 4)) - res_shape = B.shape - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - # confirm the output matches the expected tensor - result = z3.Const(2, tensor_type) - self.assertEqual(s.model()[result].arg(0).arg(1), res_shape[0]) - self.assertEqual(s.model()[result].arg(1).arg(1), res_shape[1]) - self.assertEqual(s.model()[result].arg(2).arg(1), res_shape[2]) - self.assertEqual(s.model()[result].arg(3).arg(1), res_shape[3]) - - # confirm the output is not dyn - self.assertNotEqual(s.model()[result].arg(0).arg(0).as_long(), 0) - self.assertNotEqual(s.model()[result].arg(1).arg(0).as_long(), 0) - self.assertNotEqual(s.model()[result].arg(2).arg(0).as_long(), 0) - self.assertNotEqual(s.model()[result].arg(3).arg(0).as_long(), 0) - - - def test_cumsum_kwargs(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 4, 3])): - t = torch.cumsum(x, dim=3) - return t - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - - # should be unsat since the index is not valid for this annotation - self.assertEqual(s.check(), z3.unsat) - - # modify the annotation to Dyn which should give sat - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - - def test_arange(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - arange = torch.arange(getitem) - return arange - - B = BasicBlock().forward(torch.rand(2, 4)) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - arange_result = z3.Const(5, tensor_type) - self.assertNotEqual(s.model()[arange_result].arg(0).arg(0).as_long(), 0) - self.assertEqual(s.model()[arange_result].arg(0).arg(1).as_long(), B.size()[0]) - - # change the annotation to Dyn. This will migrate to an arbitirary type - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, Dyn, Dyn]) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - def test_scalar_add(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - arange = torch.arange(getitem) - add = arange + 1 - return add - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - arange_result = z3.Const(5, tensor_type) - add_result = z3.Const(6, tensor_type) - self.assertEqual(s.model()[arange_result], s.model()[add_result]) - - - def test_regular_add_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - to = x.to() - size = to.size() - getitem = size[-1] - add = getitem + 1 - return add - - b = BasicBlock().forward(torch.rand(2, 4)) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - res = z3.Int(5) - self.assertEqual(s.model()[res], b) - - - def test_regular_add_3(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - to = x.to() - size = to.size() - getitem = size[-1] - add = 1 + getitem - return add - - b = BasicBlock().forward(torch.rand(2, 4)) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - res = z3.Int(5) - self.assertEqual(s.model()[res], b) - - def test_embedding(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embedding = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([2, 4])): - return self.embedding(x) - - B = BasicBlock().forward(torch.ones([2, 4], dtype=torch.long)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - embedding_result = z3.Const(2, tensor_type) - - assert s.model()[embedding_result].arg(0).arg(1) == B[0] - assert s.model()[embedding_result].arg(1).arg(1) == B[1] - assert s.model()[embedding_result].arg(2).arg(1) == B[2] - - # change the type. This should still be satisfiable - for n in traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn]) - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - assert s.model()[embedding_result].arg(0).arg(0) == 0 - assert s.model()[embedding_result].arg(1).arg(0) == 0 - assert s.model()[embedding_result].arg(2).arg(1) == B[2] - - # change the type to Dyn. Here, we will get an arbitirary migration - for n in traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - - - def test_embedding_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4]), y: TensorType([Dyn, 1024])): - return torch.nn.functional.embedding(x, y) - - B = BasicBlock().forward(torch.ones([2, 4], dtype=torch.long), torch.rand(256008, 1024)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - embedding_result = z3.Const(5, tensor_type) - - assert s.model()[embedding_result].arg(0).arg(1) == B[0] - assert s.model()[embedding_result].arg(1).arg(1) == B[1] - assert s.model()[embedding_result].arg(2).arg(1) == B[2] - - def test_size_two_args(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 2, Dyn])): - size = x.size(-1) - return size - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - d1, d2 = z3.Int(39), z3.Int(2) - d4, d5 = z3.Int('input_d1'), z3.Int('input_d2') - - # migrate the third dimension - s.add(d1 != 0) - - self.assertEqual(s.check(), z3.sat) - input = z3.Const(1, tensor_type) - s.add(input == tensor_type.tensor3(D(3, 39), D(1, 2), D(d4, d5))) - - # check if the item we got is the right one - self.assertEqual(s.check(), z3.sat) - self.assertEqual(s.model()[d5], s.model()[d2]) - self.assertEqual(s.model()[d1], s.model()[d4]) - - def test_size_getitem(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - size = x.size() - getitem = size[-1] - return getitem - - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - - # force the input to be of size 4 - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - input = z3.Const(1, tensor_type) - s.add(input == tensor_type.tensor4(d1, d2, d3, d4)) - - # check if the model is still SAT - self.assertEquals(s.check(), z3.sat) - - s1, s2 = z3.Int(23), z3.Int(3) - - # check that the item is correct - self.assertEquals(s.model()[s1], s.model()[s2]) - - # invalid index but should still be SAT because input will be Dyn - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - size = x.size() - getitem = size[-10] - return getitem - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - s.add(input != z3_dyn) - self.assertEqual(s.check(), z3.unsat) - - def test_view_mul(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - view = x.view(-1, getitem) - embed_tokens = self.embed_tokens(view) - mul = embed_tokens * 32.0 - return mul - - - # print(B) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - # print(traced) - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - # print(s.model()) - - embedding_result = z3.Const(6, tensor_type) - - # note that the view output will be: tensor3(dim(0, 0), dim(1, 4), dim(1, 1024)) - # this is due to the reshape constraints. This can be lifted - # but would require revising the type rules accordingly so we leave it for now - assert (s.model()[embedding_result].arg(1).arg(1)) == 4 - assert (s.model()[embedding_result].arg(2).arg(1)) == 1024 - - mul_result = z3.Const(13, tensor_type) - assert s.model()[mul_result] == s.model()[embedding_result] - - def test_gt(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 4])): - size = x.size() - getitem_1 = size[-1] - gt = getitem_1 > 1 - return gt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - res = z3.Bool(4) - self.assertEqual(s.model()[res], True) - - def test_view(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - view = x.view(-1, 8) - return view - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - def test_lt_tensor(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4]), y: Dyn): - lt = x > y - return lt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - - def test_conditional_wrong_assumption(self): - """ - Test condition after making the wrong assumption about the input - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - gt = x > 1 - return gt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.gt: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.sat) - - def test_conditional(self): - """ - This test case is for the HFmodels interface. - A function takes a node and a graph and considers - the conditional the node represents and its negation - and solves each formula with the remaining sets of constraints - Returns: - - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([Dyn, 4])): - size = x.size() - getitem = size[-1] - view = x.view(-1, getitem) - embed_tokens = self.embed_tokens(view) - mul = embed_tokens * 32.0 - getitem_1 = size[-1] - gt = getitem_1 > 1 - return gt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.gt: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.unsat) - - # change the annotation to Dyn - for n in graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - # here, both should be SAT since the input is Dyn - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.sat) - - - # change the annotation to TensorType[Dyn, Dyn] - for n in graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn]) - - # here, both should be SAT as well - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - - self.assertEqual(positive, z3.sat) - self.assertEqual(negative, z3.sat) - - - def test_conditional_2(self): - """ - This test case is for the HFmodels interface. - A function takes a node and a graph and considers - the conditional the node represents and its negation - and solves each formula with the remaining sets of constraints - Returns the opposite result of the above testcase - - """ - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1) - - def forward(self, x: TensorType([Dyn, 4])): - size = x.size() - getitem = size[-1] - view = x.view(-1, getitem) - embed_tokens = self.embed_tokens(view) - mul = embed_tokens * 32.0 - getitem_1 = size[-1] - lt = getitem_1 < 1 - return lt - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - - # The node we are considering is the gt node - for n in graph.nodes: - if n.target == operator.lt: - node = n - - positive, negative = evaluate_conditional_with_constraints(ast_rewriter.root, graph, node) - self.assertEqual(positive, z3.unsat) - self.assertEqual(negative, z3.sat) - - -class ComposeOperationsGradualTypes(unittest.TestCase): - - def test_masked_fill(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 4])): - size = x.size() - getitem = size[-1] - arange = torch.arange(getitem) - view = x.view(-1, getitem) - lt = arange > view - masked_fill = x.masked_fill_(lt, 0) - return masked_fill - - B = BasicBlock().forward(torch.rand(2, 4)) - # print(B.shape) - - symbolic_traced: pippy.fx.GraphModule = meta_symbolic_trace(BasicBlock(), meta_args={}) - # print(symbolic_traced) - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - masked_fill_res = z3.Const(10, tensor_type) - self.assertEqual(s.model()[masked_fill_res].arg(0).arg(1).as_long(), B.size()[0]) - self.assertEqual(s.model()[masked_fill_res].arg(1).arg(1).as_long(), B.size()[1]) - - # change the annotation to Dyn. This will migrate to an arbitirary type - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = Dyn - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, Dyn, Dyn]) - - transformed = transform_all_constraints(symbolic_traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEqual(s.check(), z3.sat) - - def test_add_reshape_1(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - def test_add_reshape_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(torch.reshape(x, (-1, 2)), torch.reshape(y, (2, 2, 2))) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - def test_conv_reshape_add_0(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: Dyn): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - - - def test_conv_reshape_add_0_2(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: TensorType([4, 1])): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - - # 4,1 - # 1, 2, 4, 8 - res = B.forward(torch.rand(20, 20), torch.rand(1, 2, 4, 8)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - - - conv_result = z3.Const(4, tensor_type) - add_result = z3.Const(9, tensor_type) - input_2 = z3.Const(2, tensor_type) - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - - solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) - solver.check() - assert solver.model()[s1].as_long() == res[0] - assert solver.model()[s2].as_long() == res[1] - assert solver.model()[s3].as_long() == res[2] - assert solver.model()[s4].as_long() == res[3] - - solver.add(input_2 == tensor_type.tensor2(D(1, 4), D(1, 1))) - self.assertEquals(solver.check(), z3.sat) - solver.add(add_result == tensor_type.tensor4(d1, d2, d3, d4)) - self.assertEquals(solver.check(), z3.sat) - - # first dimension could be anything because we have broadcasting - assert solver.model()[s1] == res[0] - assert solver.model()[s2] == res[1] - assert solver.model()[s3] == res[2] - assert solver.model()[s4] == res[3] - - def test_conv_reshape_add_0_3(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: TensorType([11, 1])): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.unsat) - - - def test_conv_reshape_add_1(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: TensorType([1, 2, 10, 20])): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.unsat) - - -class GradualTypes(unittest.TestCase): - def test_conv_reshape_unsat(self): - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn): - return self.conv1(torch.reshape(x, (1, 2, 10))) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.unsat) - - def test_conv_reshape0(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn): - return self.conv1(torch.reshape(x, (1, 2, 10, 20))) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - res = B.forward(torch.rand(20, 20)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - conv_result = z3.Const(3, tensor_type) - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) - solver.check() - # print(solver.model()) - # print(type(solver.model()[s1])) - assert solver.model()[s1].as_long() == res[0] - assert solver.model()[s2].as_long() == res[1] - assert solver.model()[s3].as_long() == res[2] - assert solver.model()[s4].as_long() == res[3] - - s1, s2, s3, s4 = z3.Ints('y1 y2 y3 y4') - s11, s22, s33, s44 = z3.Ints('y11 y22 y33 y44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(d1, d2, d3, d4)) - - # assert solver.check() == sat - # solver.add(s11 == 1) - # solver.add(s22 == 1) - # solver.add(s33 == 1) - # solver.add(s44 == 1) - # - # print(solver.check()) - # print(solver.model()) - - - def test_conv_reshape1(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: TensorType([20, 20])): - return self.conv1(torch.reshape(x, (1, -1, 10, 20))) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - res = B.forward(torch.rand(20, 20)).size() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - conv_result = z3.Const(3, tensor_type) - - s1, s2, s3, s4 = z3.Ints('x1 x2 x3 x4') - s11, s22, s33, s44 = z3.Ints('x11 x22 x33 x44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - solver.add(conv_result == tensor_type.tensor4(d1, d2, d3, d4)) - solver.check() - # print(solver.model()) - assert solver.model()[s1].as_long() == res[0] - assert solver.model()[s2].as_long() == res[1] - assert solver.model()[s3].as_long() == res[2] - assert solver.model()[s4].as_long() == res[3] - - -class TestSingleOperation(unittest.TestCase): - - def test_conv_wrong_example(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, - kernel_size=2, stride=2, - padding=2, groups=2, bias=False, dilation=2) - - self.conv2 = torch.nn.Conv2d(in_channels=4, out_channels=2, - kernel_size=2, stride=2, - padding=2, groups=2, bias=False, dilation=2) - - self.relu = torch.nn.ReLU(inplace=True) - - def forward(self, x: Dyn): - y = self.relu(self.conv1(x)) - z = self.relu(self.conv2(x)) - return z - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - - solver3 = z3.Solver() - solver3.add(transformed) - print(solver3.check()) - assert solver3.check() == z3.sat - - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - x = z3.Const(1, tensor_type) - solver3.add(x == tensor_type.tensor4(d1, d2, d3, d4)) - assert solver3.check() == z3.sat - - solver3.add(s22 != 0) - assert solver3.check() == z3.unsat - - def test_conv_dyn(self): - - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - e1, e2, e3, e4 = z3.Ints('e1 e2 e3 e4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - e11, e22, e33, e44 = z3.Ints('e11 e22 e33 e44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - b1, b2, b3, b4 = D(e11, e1), D(e22, e2), D(e33, e3), D(e44, e4) - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn): - return self.conv1(x) - - BasicBlock(2, 2, 2, 2, 2, 2, 2).forward(torch.rand(4, 2, 3, 4)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock(2, 2, 2, 2, 2, 2, 2)) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - - solver3 = z3.Solver() - solver3.add(transformed) - assert solver3.check() == z3.sat - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - solver3.add(x == tensor_type.tensor4(d1, d2, d3, d4), - y == tensor_type.tensor4(b1, b2, b3, b4)) - - assert solver3.check() == z3.sat - assert solver3.model()[s1].as_long() == solver3.model()[e1].as_long() - assert solver3.model()[s11].as_long() == solver3.model()[e11].as_long() - - solver3.add(s2 != 2) - assert solver3.check() == z3.sat - assert solver3.model()[s22].as_long() == 0 - - solver3.add(s22 != 0) - self.assertEquals(solver3.check(), z3.unsat) - - solver2 = z3.Solver() - solver2.add(transformed) - assert solver2.check() == z3.sat - solver2.add(x == tensor_type.tensor3(d1, d2, d3)) - self.assertEquals(solver2.check(), z3.unsat) - - - def test_add(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - # make the tensor be of size 1 - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - self.assertEquals(s.check(), z3.sat) - - y = z3.Const(2, tensor_type) - s.add(y == tensor_type.tensor1(D(1, s22))) - self.assertEquals(s.check(), z3.sat) - - s.add(s11 == 1) # tensor[1] - s.add(s22 == 2) # tensor[2] - self.assertEquals(s.check(), z3.sat) - - class BasicBlock2(torch.nn.Module): - def __init__(self): - super(BasicBlock2, self).__init__() - - def forward(self, x: TensorType((Dyn,)), y: Dyn): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock2()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - # make the tensor be of size 1 - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - self.assertEquals(s.check(), z3.sat) - y = z3.Const(2, tensor_type) - s.add(y == tensor_type.tensor1(D(1, s22))) - self.assertEquals(s.check(), z3.sat) - s.add(s11 == 4) # tensor[4] - s.add(s22 == 5) # tensor[5] - self.assertEquals(s.check(), z3.unsat) - - class BasicBlock3(torch.nn.Module): - def __init__(self): - super(BasicBlock3, self).__init__() - - def forward(self, x: TensorType((Dyn,)), y: Dyn): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock3()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor2(d1, d2)) - self.assertEquals(s.check(), z3.unsat) - - def test_add_padding(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType((Dyn,)), y: TensorType((Dyn, Dyn))): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s1))) - - self.assertEquals(s.check(), z3.sat) - - # print(s.model()) - - def test_add_padding_2(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, Dyn]), y: TensorType([Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - # print(s.model()) - - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor2(D(1, s1), D(1, s2))) - self.assertEquals(s.check(), z3.sat) - - y = z3.Const(2, tensor_type) - s.add(y == tensor_type.tensor1(D(0, s3))) - self.assertEquals(s.check(), z3.sat) - - add_result = z3.Const(3, tensor_type) - broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const(5, tensor_type) - - # print(s.model()) - - assert s.model()[broadcast_res1].decl() == tensor_type.tensor2 - assert s.model()[broadcast_res2].decl() == tensor_type.tensor2 - assert s.model()[add_result].decl() == tensor_type.tensor2 - assert s.model()[y].decl() == tensor_type.tensor1 - - # print(s.model()) - - # prevent broadcasting for that dimension - s.add(s2 > 1) - - assert s.check() - - # the second dimension of the result is a number, not Dyn. - # however if the first input dimension had been 1, we would - # have had dyn in the result, as seen in the next test case - assert s.model()[add_result].arg(1).arg(0).as_long() != 0 - - def test_add_padding_3(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, 1]), y: TensorType([Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - # print(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s.add(s2 != 0) - s.add(x == tensor_type.tensor2(D(0, s1), D(s2, 1))) - s.add(y == tensor_type.tensor1(D(0, s3))) - - self.assertEquals(s.check(), z3.sat) - - # print(s.model()) - - add_result = z3.Const(3, tensor_type) - assert s.model()[add_result].arg(0).arg(0).as_long() == 0 - assert s.model()[add_result].arg(1).arg(0).as_long() == 0 - - - def test_add_padding_4(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 1]), y: TensorType([3])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - - add_result = z3.Const(3, tensor_type) - assert s.model()[add_result] == tensor_type.tensor2(D(1, 2), D(1, 3)) - - def test_add_padding_5(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([2, 2]), y: TensorType([3])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.unsat) - - def test_add_size_3(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn, Dyn, Dyn]), y: TensorType([Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - - s.add(x == tensor_type.tensor3(D(1, s1), D(1, 1), D(1, s2))) - s.add(y == tensor_type.tensor3(D(1, s3), D(1, s4), D(1, s5))) - - self.assertEquals(s.check(), z3.sat) - s.add(s2 == 5) - self.assertEquals(s.check(), z3.sat) - s.add(s5 == 6) - self.assertEquals(s.check(), z3.unsat) - - def test_add_padding_6(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - - s.add(x == tensor_type.tensor1(D(1, s1))) - s.add(y == tensor_type.tensor3(D(1, s2), D(1, s3), D(1, s4))) - - self.assertEquals(s.check(), z3.sat) - - s.add(s1 == 4) - s.add(s4 == 5) - - self.assertEquals(s.check(), z3.unsat) - - def test_add_padding_7(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - s.add(x == tensor_type.tensor2(D(s1, s2), D(s2, s3))) - self.assertEquals(s.check(), z3.unsat) - - - def test_add_padding_8(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn]), y: TensorType([Dyn, Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5 = z3.Ints('s1 s2 s3 s4 s5') - s.add(x == tensor_type.tensor1(D(s1, 1))) - s.add(s1 >= 0) - - self.assertEquals(s.check(), z3.sat) - - s.add(y == tensor_type.tensor4(D(0, s2), D(0, s3), D(0, s4), D(0, s5))) - self.assertEquals(s.check(), z3.sat) - - def test_add_padding_9(self): - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: TensorType([Dyn, Dyn, Dyn, Dyn])): - return torch.add(x, y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced, counter=0) - s = z3.Solver() - s.add(transformed) - - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - s1, s2, s3, s4, s5, s6, s7 = z3.Ints('s1 s2 s3 s4 s5 s6 s7') - s.add(x == tensor_type.tensor1(D(s1, s7))) - s.add(s1 == 1) - self.assertEquals(s.check(), z3.sat) - - s.add(y == tensor_type.tensor4(D(0, s2), D(0, s3), D(0, s4), D(s6, s5))) - self.assertEquals(s.check(), z3.sat) - - s.add(s6 == 1) - - self.assertEquals(s.check(), z3.sat) - s.add(s5 != 1, s7 != 1) - assert s.check() - - assert s.model()[s5].as_long() == s.model()[s7].as_long() - - def test_conv_static(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - e1, e2, e3, e4 = z3.Ints('e1 e2 e3 e4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - e11, e22, e33, e44 = z3.Ints('e11 e22 e33 e44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - b1, b2, b3, b4 = D(e11, e1), D(e22, e2), D(e33, e3), D(e44, e4) - - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation) - - def forward(self, x: TensorType((1, 2, 10, 20))): - return self.conv1(x) - - ast_rewriter = RewritingTracer() - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - res = B.forward(torch.rand(1, 2, 10, 20)).size() - - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - new_transformed_c = transform_all_constraints(traced) - solver = z3.Solver() - solver.add(new_transformed_c) - self.assertEquals(solver.check(), z3.sat) - - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - solver.add(x == tensor_type.tensor4(d1, d2, d3, d4)) - solver.add(y == tensor_type.tensor4(b1, b2, b3, b4)) - self.assertEquals(solver.check(), z3.sat) - # print(solver.model()) - assert solver.model()[e3].as_long() == res[2] - assert solver.model()[e4].as_long() == res[3] - - B2 = BasicBlock(2, 4, 5, 2, 9, 2, 2) - res2 = B2.forward(torch.rand(1, 2, 10, 20)).size() - - graph2 = ast_rewriter.trace(B2) - traced2 = GraphModule(ast_rewriter.root, graph2, "gm") - new_transformed_c = transform_all_constraints(traced2) - solver = z3.Solver() - solver.add(new_transformed_c) - - solver.add(x == tensor_type.tensor4(d1, d2, d3, d4)) - solver.add(y == tensor_type.tensor4(b1, b2, b3, b4)) - - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[e3].as_long() == res2[2] - assert solver.model()[e4].as_long() == res2[3] - - def test_reshape_dyn(self): - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - return torch.reshape(x, (2, -1)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - self.assertEquals(s.check(), z3.sat) - s.add(z3.Or([s11 == 2, s11 == 4, s11 == 9])) - self.assertEquals(s.check(), z3.sat) - s.add(s11 == 9) - self.assertEquals(s.check(), z3.unsat) - - - def test_reshape_annotated(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - d1, d2, d3, d4 = D(s11, s1), D(s22, s2), D(s33, s3), D(s44, s4), - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn])): - return torch.reshape(x, (2, -1)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor2(d1, d2)) - self.assertEquals(s.check(), z3.unsat) - - def test_reshape_static_target(self): - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: TensorType([Dyn])): - return torch.reshape(x, (2, 3)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced) - # print(transformed) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - s.check() - assert s.model()[s11].as_long() == 6 - s.add(s11 != 6) - self.assertEquals(s.check(), z3.unsat) - - def test_reshape_static_target2(self): - s11, s22, s33, s44 = z3.Ints('s11 s22 s33 s44') - - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn): - return torch.reshape(x, (2, 3, 1, 1)) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - transformed = transform_all_constraints(traced) - s = z3.Solver() - s.add(transformed) - self.assertEquals(s.check(), z3.sat) - x = z3.Const(1, tensor_type) - s.add(x == tensor_type.tensor1(D(1, s11))) - s.check() - assert s.model()[s11].as_long() == 6 - s.add(s11 != 6) - self.assertEquals(s.check(), z3.unsat) - - - def test_conv2D_maxpool2d_flatten(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - res = B.forward(torch.rand(4, 3, 32, 32)).shape - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - solver.check() - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 4), D(1, 3), D(1, 32), D(1, 32))) - solver.check() - output = z3.Const(48, tensor_type) - assert solver.model()[output].arg(0).arg(1) == res[0] - assert solver.model()[output].arg(1).arg(1) == res[1] - - def test_conv2D_maxpool2d_flatten_unsat(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((4, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - solver.check() - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 4), D(1, 3), D(1, 32), D(1, 45))) - self.assertEquals(solver.check(), z3.unsat) - - def test_conv2D_maxpool2d_flatten_dyn(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - self.conv1 = torch.nn.Conv2d(3, 6, 5) - self.pool = torch.nn.MaxPool2d(2, 2) - self.conv2 = torch.nn.Conv2d(6, 16, 5) - self.fc1 = torch.nn.Linear(5, 120) - self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) - - def forward(self, x : TensorType((Dyn, 3, 32, 32))): - out = self.conv1(x) - out = self.pool(out) - out = self.conv2(out) - out = self.pool(out) - out = self.fc1(out) - out = self.pool2(out) - out = torch.flatten(out, 1) - return out - - B = BasicBlock() - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - - def test_type_check_flatten(self): - s1, s2, s3, s4 = z3.Ints('s1 s2 s3 s4') - - class M(torch.nn.Module): - def forward(self, x: TensorType([2, 3, 4, 5])): - return torch.flatten(x, start_dim=1, end_dim=3) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - flatten = z3.Const(2, tensor_type) - - res = M().forward(torch.rand(2, 3, 4, 5)).size() - assert solver.model()[flatten].arg(0).arg(1) == res[0] - assert solver.model()[flatten].arg(1).arg(1) == res[1] - - class M(torch.nn.Module): - def forward(self, x: TensorType([2, 3, Dyn, 5])): - return torch.flatten(x, start_dim=1, end_dim=3) - - module = M() - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - x = z3.Const(1, tensor_type) - y = z3.Const(2, tensor_type) - - solver.add(x == tensor_type.tensor4(D(1, 2), D(1, 3), D(0, s1), D(1, 5))) - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[y].arg(1).arg(0) == 0 - - - class M(torch.nn.Module): - def forward(self, x: TensorType([2, 3, Dyn])): - return torch.flatten(x, 10, 0) - - module = M() - # print(module.forward(torch.rand(2,3,5)).shape) - symbolic_traced: pippy.fx.GraphModule = symbolic_trace(module) - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.unsat) - -class ConstraintGeneration(unittest.TestCase): - - def test_add_reshape(self): - class BasicBlock(torch.nn.Module): - def __init__(self): - super(BasicBlock, self).__init__() - - def forward(self, x: Dyn, y: Dyn): - return torch.add(torch.reshape(x, (1, 2)), torch.reshape(y, (2, 2))) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(BasicBlock()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - generator = ConstraintGenerator(traced) - new_constraints, counter = generator.generate_constraints(0) - assert len(new_constraints.conjucts) == 11 - - - def test_conv_reshape_add(self): - class BasicBlock(torch.nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation): - super(BasicBlock, self).__init__() - self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes, - kernel_size=kernel_size, stride=stride, - padding=padding, groups=groups, bias=False, dilation=dilation) - - def forward(self, x: Dyn, y: Dyn): - return torch.add(self.conv1(torch.reshape(x, (1, 2, 10, 20))), y) - - B = BasicBlock(2, 2, 2, 3, 2, 2, 2) - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(B) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - generator = ConstraintGenerator(traced) - new_constraints, counter = generator.generate_constraints(0) - assert len(new_constraints.conjucts) == 16 - - -class TestInternalConstraints(unittest.TestCase): - def test_precision(self): - - c1 = BinConstraintT(Dyn, TVar('x'), op_precision) - transformed, _ = transform_constraint(c1, 0) - assert transformed == T() - - c2 = BinConstraintT(TensorType([1, Dyn, 3]), TVar('x'), op_precision) - transformed, counter = transform_constraint(c2, 0) - assert len(transformed.conjucts) == 7 - - def test_matching(self): - c1 = BinConstraintT(TVar('x'), - TensorType([DVar('a'), DVar('b'), DVar('c'), DVar('d')]), op_matching) - transformed, _ = transform_constraint(c1, 0) - assert len(transformed.disjuncts) == 2 - - def test_consistency(self): - c1 = BinConstraintT(TVar('x'), - TensorType([DVar('a'), DVar('b')]), op_consistency) - transformed, count = transform_constraint(c1, 0) - - assert len(transformed.disjuncts) == 5 - transformed, count = transform_constraint(transformed, count) - assert len(transformed.disjuncts) == 5 - - # def test_apply_broadcasting(self): - # c1 = ApplyBroadcasting(TVar(1), TVar(2), TVar(3), TVar(4)) - # transformed, count = transform_apply_broadcasting(c1, 5) - # assert len(transformed.conjucts) == 41 - -@skipIfNoTorchVision -class TestResNet(unittest.TestCase): - - def test_resnet50_unsat(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - input = z3.Const(1, tensor_type) - # input with 3 dimensions - solver.add(input == tensor_type.tensor3(D(1, 1), D(1, 3), D(1, 224))) - self.assertEquals(solver.check(), z3.unsat) - - - - def test_resnet50(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - sample_input = torch.randn(1, 3, 224, 224) - res = models.resnet50().forward(sample_input).size() - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - linear = z3.Const(650, tensor_type) - - input = z3.Const(1, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 1), D(1, 3), D(1, 224), D(1, 224))) - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[linear] == tensor_type.tensor2(D(1, res[0]), D(1, res[1])) - - def test_resnet502(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - linear = z3.Const(650, tensor_type) - input = z3.Const(1, tensor_type) - batch = z3.Int('b') - solver.add(input == tensor_type.tensor4(D(1, batch), D(1, 3), D(1, 224), D(1, 224))) - solver.add(batch > 4) - solver.check() - assert solver.model()[batch] == solver.model()[linear].arg(0).arg(1) - - def test_resnet503(self): - traced = symbolic_trace(models.resnet50()) - for n in traced.graph.nodes: - n.type = Dyn - - constraints = transform_all_constraints(traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - linear = z3.Const(650, tensor_type) - input = z3.Const(1, tensor_type) - batch, d1, d2 = z3.Ints('b d1 d2') - solver.add(input == tensor_type.tensor4(D(1, batch), D(1, 3), D(1, 224), D(1, 224))) - solver.add(linear == tensor_type.tensor2(D(1, d1), D(1, d2))) - self.assertEquals(solver.check(), z3.sat) - solver.add(batch != d1) - self.assertEquals(solver.check(), z3.unsat) - -@skipIfNoTorchVision -class TestAlexNet(unittest.TestCase): - def test_alexnet1(self): - - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - n.type = Dyn - - # print(symbolic_traced) - - res = alexnet.forward(torch.rand(10, 3, 227, 227)).size() - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - input = z3.Const(1, tensor_type) - conv = z3.Const(2, tensor_type) - solver.add(input == tensor_type.tensor4(D(1, 10), D(1, 3), D(1, 227), D(1, 227))) - self.assertEquals(solver.check(), z3.sat) - assert solver.model()[conv] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 56), D(1, 56)) - - relu = z3.Const(7, tensor_type) - assert solver.model()[relu] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 56), D(1, 56)) - - maxpool = z3.Const(8, tensor_type) - assert solver.model()[maxpool] == tensor_type.tensor4(D(1, 10), D(1, 64), D(1, 27), D(1, 27)) - - maxpool2 = z3.Const(42, tensor_type) - assert solver.model()[maxpool2] == tensor_type.tensor4(D(1, 10), D(1, 256), D(1, 6), D(1, 6)) - - flatten = z3.Const(52, tensor_type) - assert solver.model()[flatten] == tensor_type.tensor2(D(1, 10), D(1, 9216)) - - linear = z3.Const(64, tensor_type) - assert solver.model()[linear] == tensor_type.tensor2(D(1, 10), D(1, 4096)) - - linear2 = z3.Const(109, tensor_type) - assert solver.model()[linear2] == tensor_type.tensor2(D(1, res[0]), D(1, res[1])) - - - def test_alexnet2(self): - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, 4, 227, 227]) - - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.unsat) - - def test_alexnet3(self): - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, 227, 227]) - - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.sat) - - def test_alexnet4(self): - alexnet = models.alexnet() - symbolic_traced : pippy.fx.GraphModule = symbolic_trace(alexnet) - - for n in symbolic_traced.graph.nodes: - if n.op == 'placeholder': - n.type = TensorType([Dyn, Dyn, 227]) - - constraints = transform_all_constraints(symbolic_traced, counter=0) - solver = z3.Solver() - solver.add(constraints) - self.assertEquals(solver.check(), z3.unsat) - - - -if __name__ == '__main__': - unittest.main() diff --git a/test/local_test_compile.py b/test/local_test_compile.py deleted file mode 100644 index 6ee079912..000000000 --- a/test/local_test_compile.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import pippy - -import torch -from pippy import run_pippy -from pippy.IR import pipe_split - -d_hid = 512 -bs = 256 - - -class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - pipe_split() - x = torch.mm(x, self.mm_param) - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - pipe_split() - x = self.lin(x) - x = torch.relu(x) - return {"out": x} - - -def run_master(_, args): - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - - # Create pipeline model - pipe_ec = pippy.compile( - ec, - num_ranks=args.world_size, - num_chunks=4, - schedule=args.schedule, - checkpoint=bool(args.checkpoint), - _debug_mask_minibatches=True, # for numerical equivalence test only - ) - - # Warm up and correctness runs - out = pipe_ec(ec_input) - ref_out = ec(ec_input) - - # run with different chunk size to exercise microbatch and scheduling components - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default="FillDrain", - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestCompileTest(unittest.TestCase): - def test_compile(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward.py b/test/local_test_forward.py deleted file mode 100644 index 924759cba..000000000 --- a/test/local_test_forward.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.IR import MultiUseParameterConfig, Pipe, pipe_split -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - d_hid = 512 - bs = 503 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - pipe_split() - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - pipe_split() - x = torch.relu(x) - return {"out": x} - - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ec(ec_input) - - ec_pipe = Pipe.from_tracing(ec, MULTI_USE_PARAM_CONFIG) - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - - # run with different chunk size to exercise microbatch and scheduling components - pipe_driver.chunks = 1 - pipe_driver(ec_input) - pipe_driver.chunks = 100 - pipe_driver(ec_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver.chunks = 5 - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - print( - f'profiling run completed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestForwardTest(unittest.TestCase): - def test_forward(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_visualizer.py b/test/local_test_visualizer.py deleted file mode 100644 index 95eacbe20..000000000 --- a/test/local_test_visualizer.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import time -import unittest -from collections import defaultdict -from functools import reduce -from typing import Any, Dict, List - -import pippy.fx - -import torch -import torch.nn as nn -from pippy import run_pippy -from pippy.events import Event -from pippy.IR import ( - MultiUseParameterConfig, - Pipe, - pipe_split, - TrivialLossWrapper, -) -from pippy.PipelineDriver import ( - EventsContext, - Phase, - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) -from pippy.visualizer import events_to_json -from torch.autograd import Function - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -@pippy.fx.wrap -def sleep(x, t=1.0): - time.sleep(t) - return x - - -class SlowMSELoss(nn.MSELoss): - def forward(self, input, target): - return super().forward(sleep(input, t=0.01), target) - - -# Inherit from Function -class MyLinearFunction(Function): - # Note that both forward and backward are @staticmethods - @staticmethod - # bias is an optional argument - def forward(ctx, input, weight, bias=None): - # print("my forward") - input = sleep(input, t=0.1) - ctx.save_for_backward(input, weight, bias) - output = input.mm(weight.t()) - if bias is not None: - output += bias.unsqueeze(0).expand_as(output) - return output - - # This function has only a single output, so it gets only one gradient - @staticmethod - def backward(ctx, grad_output): - # print("my backward") - grad_output = sleep(grad_output, t=0.3) - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - input, weight, bias = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. - if ctx.needs_input_grad[0]: - grad_input = grad_output.mm(weight) - if ctx.needs_input_grad[1]: - grad_weight = grad_output.t().mm(input) - if bias is not None and ctx.needs_input_grad[2]: - grad_bias = grad_output.sum(0) - - return grad_input, grad_weight, grad_bias - - -@pippy.fx.wrap -def linear(input, weight, bias): - return MyLinearFunction.apply(input, weight, bias) - - -class MyLinear(nn.Module): - def __init__(self, input_features, output_features, bias=True): - super(MyLinear, self).__init__() - self.input_features = input_features - self.output_features = output_features - - # nn.Parameter is a special kind of Tensor, that will get - # automatically registered as Module's parameter once it's assigned - # as an attribute. Parameters and buffers need to be registered, or - # they won't appear in .parameters() (doesn't apply to buffers), and - # won't be converted when e.g. .cuda() is called. You can use - # .register_buffer() to register buffers. - # nn.Parameters require gradients by default. - self.weight = nn.Parameter(torch.empty(output_features, input_features)) - if bias: - self.bias = nn.Parameter(torch.empty(output_features)) - else: - # You should always register all possible parameters, but the - # optional ones can be None if you want. - self.register_parameter("bias", None) - - # Not a very smart way to initialize weights - nn.init.uniform_(self.weight, -0.1, 0.1) - if self.bias is not None: - nn.init.uniform_(self.bias, -0.1, 0.1) - - def forward(self, input): - # See the autograd section for explanation of what happens here. - return linear(input, self.weight, self.bias) - - def extra_repr(self): - # (Optional)Set the extra information about this module. You can test - # it by printing an object of this class. - return "input_features={}, output_features={}, bias={}".format( - self.input_features, self.output_features, self.bias is not None - ) - - -def run_master(_, args): - d_hid = 100 - bs = 400 - chunks = 4 - batches = 1 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - print("Using schedule:", args.schedule) - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.l1 = MyLinear(d_hid, d_hid) - self.l2 = MyLinear(d_hid, d_hid) - self.l3 = MyLinear(d_hid, d_hid) - self.l4 = MyLinear(d_hid, d_hid) - - def forward(self, x): - x = self.l1(x) - pipe_split() - x = self.l2(x) - pipe_split() - x = self.l3(x) - pipe_split() - x = self.l4(x) - return x - - ec = ExampleCode() - ec.to(args.device) - - mse_loss = SlowMSELoss(reduction="sum") - wrapper = TrivialLossWrapper(ec, mse_loss) - ec_pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - - all_ranks = list(range(1, args.world_size)) # exclude master rank = 0 - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - chunks, - args.world_size - 1, - all_ranks=all_ranks, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - ec_input = torch.randn(bs, d_hid, device=args.device) - target = torch.randn(bs, d_hid, device=args.device) - - pipe_visualized_filename = "pipe_visualized.json" - batches_events_contexts = [] - for i in range(batches): - pipe_driver(ec_input, target) - batches_events_contexts.append(pipe_driver.retrieve_events()) - - # first: save file - all_events_contexts: EventsContext = reduce( - lambda c1, c2: EventsContext().update(c1).update(c2), - batches_events_contexts, - EventsContext(), - ) - with open(pipe_visualized_filename, "w") as f: - f.write(events_to_json(all_events_contexts)) - - # TODO: Investigate flakiness! TODO(https://github.com/pytorch/PiPPy/issues/136) - # # second: perform checks - # for events_context in batches_events_contexts: - # check_events_for_single_batch(events_context.events, all_ranks, chunks, pipe_visualized_filename) - - -def check_events_for_single_batch( - events: List[Event], - all_stages: List[int], - chunks: int, - pipe_visualized_filename: str, -): - events_by_type_by_rank_by_mbid: Dict[ - Any, Dict[Any, Dict[Any, Event]] - ] = defaultdict(lambda: defaultdict(lambda: dict())) - for event in events: - events_by_type_by_rank_by_mbid[event.type][event.rank][ - event.mbid - ] = event - - def start_ts(e: Event, eps=0.1): - return e.start_ts + (e.finish_ts - e.start_ts) * eps - - def finish_ts(e: Event, eps=0.1): - return e.finish_ts - (e.finish_ts - e.start_ts) * eps - - # Basic happens-before cross rank checks - for i in range(len(all_stages) - 1): - rank = all_stages[i] - next_rank = all_stages[i + 1] - for mbid in range(chunks): - rank_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][rank][ - mbid - ] - next_rank_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - next_rank - ][mbid] - # happens-before cross-rank forward check - assert start_ts(next_rank_forward) >= finish_ts(rank_forward), ( - f"{rank_forward.name}({rank_forward.finish_ts}) must happen before " - f"{next_rank_forward.name}({next_rank_forward.start_ts}), see {pipe_visualized_filename}" - ) - - rank_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - next_rank - ][mbid] - next_rank_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - rank - ][mbid] - # happens-before cross-rank backward check - assert start_ts(next_rank_backward) >= finish_ts(rank_backward), ( - f"{rank_backward.name}({rank_backward.finish_ts}) must happen before " - f"{next_rank_backward.name}({next_rank_backward.start_ts}), see {pipe_visualized_filename}" - ) - - # Basic happens-before cross-microbatch checks - for mbid in range(chunks - 1): - next_mbid = mbid + 1 - for i in range(len(all_stages) - 1): - rank = all_stages[i] - rank_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][rank][ - mbid - ] - next_mbid_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - rank - ][next_mbid] - # happens-before cross-microbatch forward check - assert start_ts(next_mbid_forward) >= finish_ts(rank_forward), ( - f"{rank_forward.name}({rank_forward.finish_ts}) must happen before " - f"{next_mbid_forward.name}({next_mbid_forward.start_ts}), see {pipe_visualized_filename}" - ) - - rank_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - rank - ][mbid] - next_mbid_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - rank - ][next_mbid] - # happens-before cross-microbatch backward check - assert start_ts(next_mbid_backward) >= finish_ts(rank_backward), ( - f"{rank_backward.name}({rank_backward.finish_ts}) must happen before " - f"{next_mbid_backward.name}({next_mbid_backward.start_ts}), see {pipe_visualized_filename}" - ) - - # Overlap checks - for mbid in range(chunks - 1): - next_mbid = mbid + 1 - last_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - all_stages[-1] - ][mbid] - first_next_forward = events_by_type_by_rank_by_mbid[Phase.FORWARD][ - all_stages[0] - ][next_mbid] - # cross-microbatch forward overlap check - assert ( - last_forward.finish_ts >= first_next_forward.start_ts - ), f"Forward microbatch {mbid} doesn't overlap with next microbatch {next_mbid}" - - last_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - all_stages[0] - ][mbid] - first_next_backward = events_by_type_by_rank_by_mbid[Phase.BACKWARD][ - all_stages[-1] - ][next_mbid] - # cross-microbatch forward overlap check - assert ( - last_backward.finish_ts >= first_next_backward.start_ts - ), f"Backward microbatch {mbid} doesn't overlap with next microbatch {next_mbid}" - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 5)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestVisualizer(unittest.TestCase): - def test_visualizer(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_c10d_bwd.py b/test/test_bwd.py similarity index 72% rename from test/local_test_c10d_bwd.py rename to test/test_bwd.py index 21db17240..fb5f45aa0 100644 --- a/test/local_test_c10d_bwd.py +++ b/test/test_bwd.py @@ -3,12 +3,16 @@ import os import unittest +import pippy + import torch import torch.distributed as dist +from pippy.IR import Pipe, pipe_split +from pippy.microbatch import sum_reducer, TensorChunkSpec +from pippy.PipelineStage import PipelineStage -from pippy.compile import compile_stage -from pippy.IR import pipe_split +pippy.microbatch._debug_mask_minibatches = True schedules = [ "FillDrain", @@ -16,11 +20,12 @@ ] d_hid = 512 -chunk_size = 256 +batch_size = 256 torch.manual_seed(0) +# Basic example class ExampleCode(torch.nn.Module): def __init__(self): super().__init__() @@ -29,7 +34,7 @@ def __init__(self): self.lin = torch.nn.Linear(d_hid, d_hid) self.mse_loss = torch.nn.MSELoss(reduction="sum") - def forward(self, x, target): + def forward(self, x, y): x = torch.mm(x, self.mm_param) skip_connection = x x = torch.relu(x) @@ -42,35 +47,41 @@ def forward(self, x, target): x = torch.mm(x, self.mm_param2) pipe_split() x = self.lin(x) - x = torch.relu(x) - loss = self.mse_loss(x, target) - return {"logits": x, "loss": loss} + logits = torch.relu(x) + loss = self.mse_loss(x, y) + return logits, loss def run_worker(args): - ec = ExampleCode() - ec.to(args.device) - ec.train() + mod = ExampleCode() + mod.to(args.device) - ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) - target = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + x = torch.randn(batch_size, d_hid, device=args.device) + y = torch.randn(batch_size, d_hid, device=args.device) - stage = compile_stage( - ec, - args.rank, - args.world_size, + output_chunk_spec = ( + TensorChunkSpec(0), # logits + sum_reducer, # loss + ) + + pipe = Pipe.from_tracing( + mod, args.chunks, - args.device, - None, - [ec_x, target], - schedule=args.schedule, + example_args=(x, y), + output_chunk_spec=output_chunk_spec, + ) + + stage = PipelineStage( + pipe, + args.rank, + device=args.device, ) # Run if args.rank == 0: - out = stage(ec_x) + out = stage(x) elif args.rank == args.world_size - 1: - out = stage(target) + out = stage(y) else: stage() @@ -79,11 +90,9 @@ def run_worker(args): # Last rank checks result if args.rank == args.world_size - 1: - ref_out = ec(ec_x, target) + ref_out = mod(x, y) torch.testing.assert_close(out, ref_out) - print( - f"equivalence test passed, loss = {out['loss']}, ref loss = {ref_out['loss']}" - ) + print(f"equivalence test passed loss={out[1]} ref_loss={ref_out[1]}") def main(args=None): @@ -135,8 +144,8 @@ def main(args=None): main() -class LocalTestC10DBwdTest(unittest.TestCase): - def test_c10d_bwd(self): +class TestBwd(unittest.TestCase): + def test_bwd(self): import random port = random.randint(29500, 30000) diff --git a/test/local_test_c10d.py b/test/test_fwd.py similarity index 82% rename from test/local_test_c10d.py rename to test/test_fwd.py index 5044c50d3..675844170 100644 --- a/test/local_test_c10d.py +++ b/test/test_fwd.py @@ -3,15 +3,18 @@ import os import unittest +import pippy + import torch import torch.distributed as dist +from pippy.IR import Pipe, pipe_split +from pippy.PipelineStage import PipelineStage -from pippy.compile import compile_stage -from pippy.IR import pipe_split +pippy.microbatch._debug_mask_minibatches = True d_hid = 512 -chunk_size = 256 +batch_size = 256 torch.manual_seed(0) @@ -42,25 +45,27 @@ def forward(self, x, y): def run_worker(args): - ec = ExampleCode() - ec.to(args.device) + mod = ExampleCode() + mod.to(args.device) - ec_x = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) - ec_y = torch.randn(args.chunks * chunk_size, d_hid, device=args.device) + x = torch.randn(batch_size, d_hid, device=args.device) + y = torch.randn(batch_size, d_hid, device=args.device) - stage = compile_stage( - ec, - args.rank, - args.world_size, + pipe = Pipe.from_tracing( + mod, args.chunks, - args.device, - None, - [ec_x, ec_y], + example_args=(x, y), + ) + + stage = PipelineStage( + pipe, + args.rank, + device=args.device, ) # Run if args.rank == 0: - out = stage(ec_x, ec_y) + stage(x, y) elif args.rank == args.world_size - 1: out = stage() else: @@ -71,7 +76,7 @@ def run_worker(args): # Last rank checks result if args.rank == args.world_size - 1: - ref_out = ec(ec_x, ec_y) + ref_out = mod(x, y) torch.testing.assert_close(out, ref_out) print( f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" @@ -121,8 +126,8 @@ def main(args=None): main() -class LocalTestC10DTest(unittest.TestCase): - def test_c10d(self): +class TestFwd(unittest.TestCase): + def test_fwd(self): import random port = random.randint(29500, 30000) diff --git a/test/test_fx.py b/test/test_fx.py deleted file mode 100644 index b366395ec..000000000 --- a/test/test_fx.py +++ /dev/null @@ -1,4658 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import builtins -import contextlib -import copy -import functools -import inspect -import io -import math -import numbers -import operator -import os -import pickle -import sys -import traceback -import types -import typing -import unittest -import warnings -from collections import namedtuple -from copy import deepcopy -from math import sqrt - -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union - -import pippy -import pippy.fx._pytree as fx_pytree - -import torch -import torch.nn.utils._stateless as _stateless -import torch.utils._pytree as pytree - -from fx.named_tup import MyNamedTup -from pippy.fx import ( - CodeGen, - Graph, - GraphModule, - Interpreter, - Node, - PH, - Proxy, - symbolic_trace, - Tracer, - Transformer, - wrap, -) -from pippy.fx._compatibility import ( - _BACK_COMPAT_OBJECTS, - _MARKED_WITH_COMATIBLITY, -) -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx.immutable_collections import immutable_dict, immutable_list -from pippy.fx.node import _format_arg, Argument, Target -from pippy.fx.operator_schemas import get_signature_for_torch_op -from pippy.fx.passes import shape_prop -from pippy.fx.proxy import TraceError -from torch.multiprocessing import Process -from torch.testing import FileCheck -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - ops, -) -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_utils import ( - find_library_location, - IS_FBCODE, - IS_MACOS, - IS_WINDOWS, - run_tests, - skipIfSlowGradcheckEnv, -) -from torch.testing._internal.jit_utils import JitTestCase - -try: - from torchvision import models as torchvision_models - - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") - - -class SimpleTest(torch.nn.Module): - def forward(self, x): - return torch.relu(x + 3.0) - - -def a_non_torch_leaf(a, b): - return a + b - - -# Used for test_autowrap_function. Autowrapped functions need to be global -def fx_int(x: float) -> int: - return int(x) - - -def fx_int_x2(x: float) -> int: - return int(x) * 2 - - -# used in test_pytree. It's all the way out here because pickling a GraphModule -# that uses Point errors out if Point is local to the function -Point = namedtuple("Point", ["x", "y"]) - - -# Test wrap() passing both a function name as well as a function -# directly -def a_lifted_leaf(a, b): - return a[0] + a[1] + b - - -wrap("a_lifted_leaf") -# Test wrapping twice doesn't break anything -wrap("a_lifted_leaf") - - -def a_lifted_leaf2(a, b): - return a[0] + a[1] + b - - -wrap(a_lifted_leaf2) - -wrap("len") - -wrap("getattr") - - -def wrapped_named_tup(p1, *, p2): - return p1.x + p2.y - - -wrap(wrapped_named_tup) - - -@wrap -def wrapped_via_decorator(a): - return a + 1 - - -wrap("wrapped_with_submodule") - - -def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d): - return batchnorm1d(x) - - -def my_decorator(f): - @functools.wraps(f) - def wrapper_inside_decorator(*args, **kwargs): - return f(*args, **kwargs) - - return wrapper_inside_decorator - - -@wrap -@my_decorator -def wrapped_decorated_fn(x): - return x - - -real_wrapped_via_decorator = wrapped_via_decorator -real_a_lifed_leaf = a_lifted_leaf -real_a_lifed_leaf2 = a_lifted_leaf2 -_sqrt = sqrt - -wrap("wrapper_fn") - - -def wrapper_fn(x): - return torch.foo(x) - - -class Pair(NamedTuple): - x: torch.Tensor - y: torch.Tensor - - def _custom_fx_repr_fn(self) -> str: - return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})" - - -# for testing pytrees -class Foo(object): # noqa: B209 - def __init__(self, a, b): - self.a = a - self.b = b - - -class TestFX(JitTestCase): - def setUp(self): - super().setUp() - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - if not (IS_FBCODE or IS_WINDOWS or IS_MACOS): - lib_file_path = find_library_location("libtorchbind_test.so") - torch.ops.load_library(str(lib_file_path)) - - def tearDown(self): - super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): - """Check that an nn.Module's results match the GraphModule version - for a given set of args/kwargs. - """ - kwargs = kwargs if kwargs else {} - ref_outs = m(*args, **kwargs) - gm = symbolic_trace(m) - gm.graph.lint() - test_outs = gm(*args, **kwargs) - self.assertEqual(ref_outs, test_outs) - - def test_graph_module(self): - class MySub(torch.nn.Module): - def __init__(self): - super().__init__() - self.w = torch.nn.Parameter(torch.rand(4, 3)) - - def forward(self, x): - return self.w + x - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(4, 3) - self.sub_mod = MySub() - self.w = torch.nn.Parameter(torch.rand(3)) - - def forward(self, A, B, c): - t = torch.sigmoid(A) + self.lin(c) - return self.sub_mod( - t.data - + self.w - + t - + 1 - - A - + B // A - + -A - + A.add(B, alpha=3) - ) - - m = MyModule() - gm = symbolic_trace(m) - - ms = torch.jit.script(gm) - - class M2(torch.nn.Module): - def forward(self, A): - m, idx = torch.max(A, 0) - return m + 1, idx + 1 - - m2 = M2() - gm2 = symbolic_trace(m2) - - class T(torch.nn.Module): - def forward(self, A, b=4, *args, c=5, **kwargs): - x = A + 1 + args[0] + kwargs["3"] - return x - - t = T() - symbolic_trace(t) - - # test for issue described at https://github.com/pytorch/pytorch/issues/63883 - class M3(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - m3 = M3() - gm3 = symbolic_trace(m3) - new_instance = gm3.__new__(type(gm3)) - new_instance.__init__(gm3, gm3.graph) - - x = torch.randn(5, 3) - torch.testing.assert_allclose(new_instance(x), torch.relu(x)) - - def test_custom_import(self): - graph = pippy.fx.Graph() - a = graph.placeholder("x") - b = graph.placeholder("y") - c = graph.call_function(a_non_torch_leaf, (a, b)) - d = graph.call_function(torch.sin, (c,)) - graph.output(d) - gm = GraphModule(torch.nn.Module(), graph) - x, y = torch.rand(1), torch.rand(1) - self.assertEqual(torch.sin(x + y), gm(x, y)) - - def test_args_kwargs(self): - class T(torch.nn.Module): - def forward(self, *args, **kwargs): - x = args[0] + kwargs["foo"] - return x - - t = T() - self.checkGraphModule( - t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} - ) - - def test_args_kwargs_no_self(self): - class T(torch.nn.Module): - def forward(*args, **kwargs): # noqa: B902 - self = args[0] - return torch.relu(args[1]) - - t = T() - with self.assertRaisesRegex( - RuntimeError, r"cannot be part of \*args expansion" - ): - self.checkGraphModule( - t, (torch.rand(1), torch.rand(1)), {"foo": torch.rand(1)} - ) - - def test_fx_shifts(self): - class MyModule(torch.nn.Module): - def forward(self, x): - return x << 3, x >> 3 - - input = torch.LongTensor(10).random_(0, 1024) - - m = MyModule() - self.checkGraphModule(m, (input,)) - - def test_fx_and_or(self): - class MyModule(torch.nn.Module): - def forward(self, x): - return x & x, x | x - - input = torch.LongTensor(10).random_(0, 1024) - - m = MyModule() - self.checkGraphModule(m, (input,)) - - def test_dict(self): - class MyDictMod(torch.nn.Module): - def forward(self, d): - return d["3"].relu(), {"4": d["3"].neg()} - - input_dict = {"3": torch.rand(3, 4)} - m = MyDictMod() - - self.checkGraphModule(m, (input_dict,)) - - def test_matmul_tracing(self): - const = torch.randn(3) - - def matmul_f(x): - return x @ const - - mod = symbolic_trace(matmul_f) - inp = torch.randn(3) - self.assertEqual(mod(inp), matmul_f(inp)) - - def rmatmul_f(x): - return const @ x - - mod = symbolic_trace(rmatmul_f) - inp = torch.randn(3) - self.assertEqual(mod(inp), rmatmul_f(inp)) - - def test_disallow_override(self): - # Custom delegate to disallow in-place tensor operations - class NoMutableCallTracer(Tracer): - def create_node( - self, - kind: str, - target: Union[str, Callable], - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - ) -> Node: - name = ( - target - if isinstance(target, str) - else torch.typename(target) - ) - if name[-1] == "_": - raise RuntimeError("In-place operations are not supported") - return super().create_node(kind, target, args, kwargs, name) - - # Test method - class MyInplaceMod(torch.nn.Module): - def forward(self, x): - x.add_(3.0) - return x - - m = MyInplaceMod() - - with self.assertRaisesRegex(RuntimeError, "In-place operations"): - NoMutableCallTracer().trace(m) - - # Test free function - class MyInplaceMod2(torch.nn.Module): - def forward(self, x): - torch.log_(x) - return x - - m2 = MyInplaceMod2() - with self.assertRaisesRegex(RuntimeError, "In-place operations"): - NoMutableCallTracer().trace(m2) - - # Test symbolic node as an arg - class MyInplaceMod3(torch.nn.Module): - def forward(self, x): - y = torch.ones(3, 4) - y.add_(x) - return x - - m3 = MyInplaceMod3() - with self.assertRaisesRegex(RuntimeError, "In-place operations"): - NoMutableCallTracer().trace(m3) - - def test_leaf_module(self): - # Custom delegate to make it so that there are no leaf modules, everything - # should get traced through - class NoLeafModulesTracer(Tracer): - def is_leaf_module(self, m, qualname): - return False - - class MyReluMod(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(x) - - mrm = MyReluMod() - sym = NoLeafModulesTracer().trace(mrm) - for node in sym.nodes: - self.assertNotEqual(node.op, "call_module") - sym.lint() - - def test_wrap(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf((4, y), 3) - + a_lifted_leaf((3, 4), 5) - + a_lifted_leaf((y, y), y) - ) - - m = symbolic_trace(to_trace) - self.assertIn("a_lifted_leaf", m.code) - self.assertEqual(27, m(2)) - self.assertIs(a_lifted_leaf, real_a_lifed_leaf) - - def test_wrap_fn_directly(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf2((4, y), 3) - + a_lifted_leaf2((3, 4), 5) - + a_lifted_leaf2((y, y), y) - ) - - m = symbolic_trace(to_trace) - self.assertIn("a_lifted_leaf2", m.code) - self.assertEqual(27, m(2)) - self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) - - def test_wrapped_via_decorator(self): - self.assertEqual(wrapped_via_decorator(0), 1) - - def to_trace(y): - return wrapped_via_decorator(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_via_decorator", m.code) - self.assertEqual(m(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_wrapped_via_decorator_and_transformed(self): - self.assertEqual(wrapped_via_decorator(0), 1) - - def to_trace(y): - return wrapped_via_decorator(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_via_decorator", m.code) - self.assertEqual(m(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - transformed = pippy.fx.Transformer(m).transform() - self.assertIn("wrapped_via_decorator", transformed.code) - self.assertEqual(transformed(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_wrap_with_submodule(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - - def forward(self, x: torch.Tensor): - return wrapped_with_submodule(x, self.batchnorm1d) - - m = symbolic_trace(M()) - - self.assertIn("wrapped_with_submodule", m.code) - - input = torch.rand(3, 2) - ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - self.assertEqual(ref_batchnorm1d(input), m(input)) - - def test_wrapped_retrace(self): - def to_trace(y): - return wrapped_via_decorator(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_via_decorator", m.code) - self.assertEqual(m(0), 1) - - retraced = symbolic_trace(m) - self.assertIn("wrapped_via_decorator", retraced.code) - self.assertEqual(retraced(0), 1) - - def test_wrap_decorated_function(self): - def to_trace(y): - return wrapped_decorated_fn(y) - - m = symbolic_trace(to_trace) - self.assertIn("wrapped_decorated_fn", m.code) - self.assertEqual(m(1), 1) - - def test_graph_edit_with_proxy(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = M() - g = symbolic_trace(m).graph - new_g = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - output_val = new_g.graph_copy(g, val_map) - t = Proxy(output_val) - # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. - new_g.output((t + t).node) - gm = GraphModule(m, new_g) - gm.graph.lint() - self.assertEqual(gm(3, 4), 14) - - def test_concrete_arg_none_assert(self): - class Foo(torch.nn.Module): - def forward(self, x, val=None): - return x if val is None else x + val - - f = Foo() - traced = pippy.fx.symbolic_trace(f, concrete_args={"val": None}) - with self.assertRaisesRegex( - AssertionError, "val has been specialized to have value None" - ): - traced(torch.randn(5), torch.randn(5)) - - x = torch.randn(5) - torch.testing.assert_close(traced(x), f(x)) - - def test_trace_multiple_funcs(self): - class Foo(torch.nn.Module): - def forward(self, x, y): - return x + y - - def minus_forward(self, x, y): - return x - y - - def multiply_forward(self, x, y): - return x * y - - f = Foo() - x, y = torch.randn(5), torch.randn(5) - - print(torch.__version__) - - tracer = Tracer() - torch.testing.assert_close( - GraphModule(f, tracer.trace(f))(x, y), f(x, y) - ) - - tracer.traced_func_name = "minus_forward" - torch.testing.assert_close( - GraphModule(f, tracer.trace(f))(x, y), - f.minus_forward(x, y), - ) - - tracer.traced_func_name = "multiply_forward" - torch.testing.assert_close( - GraphModule(f, tracer.trace(f))(x, y), - f.multiply_forward(x, y), - ) - - tracer.traced_func_name = "add_forward" - with self.assertRaisesRegex(AssertionError, "doesn't exist in"): - tracer.trace(f) - - def test_graph_unique_names(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = M() - g = symbolic_trace(m).graph - new_g = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - output_val = new_g.graph_copy(g, val_map) - t = Proxy(output_val) - # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. - new_g.output((t + t).node) - gm = GraphModule(m, new_g) - seen_names: Set[str] = set() - for node in gm.graph.nodes: - assert node.name not in seen_names - seen_names.add(node.name) - - def test_stack_traces(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - tracer = pippy.fx.Tracer() - tracer.record_stack_traces = True - - graph = tracer.trace(M()) - # saving the original list because we will insert new nodes as a part of a test - orig_graph_nodes = list(graph.nodes) - for node in orig_graph_nodes: - if node.op == "output": - continue - self.assertTrue(node.stack_trace is not None) - assert "test_fx.py" in node.stack_trace - - # verify that copying the node does not lose the stack trace - new_node = graph.node_copy(node) - self.assertTrue(new_node.stack_trace is not None) - assert "test_fx.py" in new_node.stack_trace - - def test_stack_traces_with_transformer(self): - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - tracer = pippy.fx.Tracer() - tracer.record_stack_traces = True - - graph = tracer.trace(M()) - gm = GraphModule(tracer.root, graph) - new_gm = Transformer(gm).transform() - - # nodes after Transformer should still preserve the original node's stack trace - for node in new_gm.graph.nodes: - if node.op in {"placeholder", "output"}: - continue - self.assertTrue(node.stack_trace is not None) - assert "test_fx.py" in node.stack_trace - - def test_graph_unique_names_manual(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_module", "linear_mod", args=(a,), name="foo_1_1" - ) - c: pippy.fx.Node = graph.create_node("get_attr", "y_attr", name="foo_1") - d: pippy.fx.Node = graph.create_node( - "call_function", operator.add, args=(b, c) - ) - graph.output(d) - graph2 = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - graph2.graph_copy(graph, val_map) - seen_names: Set[str] = set() - for node in graph2.nodes: - assert node.name not in seen_names - seen_names.add(node.name) - - def test_unpack(self): - class M(torch.nn.Module): - def forward(self, a, b): - c, d = a - return c + d + b - - a = (torch.rand(1), torch.rand(1)) - b = torch.rand(1) - m = M() - self.checkGraphModule(m, (a, b)) - - def test_native_callable(self): - if IS_FBCODE or IS_WINDOWS or IS_MACOS: - raise unittest.SkipTest( - "non-portable load_library call used in test" - ) - # This test exercises the case where we use FX to translate from Python - # code to some native callable object - # - # For the purposes of testing, we use ElementwiseInterpreter defined - # in test_custom_class.cpp. - # - # We test that we can - # 1) Construct a native callable from FX IR - # 2) Construct a drop-in replacement module that delegates to the - # native callable rather than the original code - # 3) Run both the original code and native callable wrapper with - # equivalent results - # 4) TorchScript compile the native callable wrapper and confirm - # equivalent results with the reference - # 5) TorchScript serialize and deserialize the native callable - # and confirm equivalent results with the reference - - # We use this simple Module as a reference computation - class MySimpleMod(torch.nn.Module): - def forward(self, x): - return 3.0 * x + x - - msm = MySimpleMod() - - # This is what a lowering pass might look like: a function that takes - # a valid nn.Module, symbolically traces it, lowers the Module to some - # representation, and wraps that representation up into another - # nn.Module instance that handles dispatch to the compiled/lowered code. - def lower_to_elementwise_interpreter( - orig_mod: torch.nn.Module, - ) -> torch.nn.Module: - # ===== Stage 1: Symbolic trace the module ===== - mod = symbolic_trace(orig_mod) - - # ===== Stage 2: Lower GraphModule representation to the C++ - # interpreter's instruction format ====== - instructions = [] - constant_idx = 0 - constants = {} - fn_input_names = [] - - target_to_name = {operator.add: "add", operator.mul: "mul"} - - output_node: Optional[Node] = None - # For each instruction, create a triple - # (instruction_name : str, inputs : List[str], output : str) - # to feed into the C++ interpreter - for n in mod.graph.nodes: - target, args, out_name = n.target, n.args, n.name - assert len(n.kwargs) == 0, "kwargs currently not supported" - - if n.op == "placeholder": - # Placeholders specify function argument names. Save these - # for later when we generate the wrapper GraphModule - fn_input_names.append(target) - elif n.op == "call_function": - assert target in target_to_name, ( - "Unsupported call target " + target - ) - arg_names = [] - for arg in args: - if not isinstance(arg, Node): - # Pull out constants. These constants will later be - # fed to the interpreter C++ object via add_constant() - arg_name = f"constant_{constant_idx}" - constants[arg_name] = torch.tensor( - [arg] - if isinstance(arg, numbers.Number) - else arg - ) - arg_names.append(arg_name) - constant_idx += 1 - else: - arg_names.append(arg.name) - instructions.append( - (target_to_name[target], arg_names, out_name) - ) - elif n.op == "output": - if output_node is not None: - raise RuntimeError("Multiple output nodes!") - output_node = n - else: - raise RuntimeError("Unsupported opcode " + n.op) - - interpreter = ( - torch.classes._TorchScriptTesting._ElementwiseInterpreter() - ) - # Load constants - for k, v in constants.items(): - interpreter.add_constant(k, v) - # Specify names for positional input arguments - interpreter.set_input_names(fn_input_names) - # Load instructions - interpreter.set_instructions(instructions) - # Specify name for single output - assert isinstance(output_node.args[0], pippy.fx.Node) - interpreter.set_output_name(output_node.args[0].name) - - # ===== Stage 3: Create a wrapper GraphModule around the interpreter ===== - class WrapperModule(torch.nn.Module): - def __init__(self, interpreter): - super().__init__() - self.interpreter = interpreter - - wrapper = WrapperModule(interpreter) - - # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter - # 3) Returns the speficied return value - - # FIXME: The following code could be greatly simplified by symbolic_trace'ing - # the wrapper with a Tracer that considers the Wrapper instance a root - # module, however, I can't get `__call__` exposed on TorchBind classes - # without it messing up Python `hasattr` for some reason. More digging - # into CPython's implementation of hasattr is probably in order... - - graph = pippy.fx.Graph() - # Add placeholders for fn inputs - placeholder_nodes = [] - for name in fn_input_names: - placeholder_nodes.append(graph.create_node("placeholder", name)) - - # Get the interpreter object - interpreter_node = graph.create_node("get_attr", "interpreter") - - # Add a node to call the interpreter instance - output_node = graph.create_node( - op="call_method", - target="__call__", - args=(interpreter_node, placeholder_nodes), - ) - - # Register output - graph.output(output_node) - - graph.lint() - - # Return final GraphModule!!! - return GraphModule(wrapper, graph) - - # Lower GraphModule to C++ interpreter - lowered = lower_to_elementwise_interpreter(msm) - - # Compare correctness with original module - x = torch.rand(3, 4) - ref_out = msm(x) - test_out = lowered(x) - torch.testing.assert_close(test_out, ref_out) - - # Test TorchScript compilation - scripted_lowered = torch.jit.script(lowered) - script_out = scripted_lowered(x) - torch.testing.assert_close(script_out, ref_out) - - # Test TorchScript ser/de - import_copy = self.getExportImportCopy(scripted_lowered) - imported_out = import_copy(x) - torch.testing.assert_close(imported_out, ref_out) - - def test_reserved_getattr(self): - """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" - - class M(torch.nn.Module): - def forward(self, a): - return a.foo.bar.baz - - m = M() - m_g = symbolic_trace(m) - m_g.graph.lint() - for node in m_g.graph.nodes: - self.assertTrue(node.name != "getattr") - - @unittest.skip("Hotfix for SEV remediation") - def test_trace_buffer_slice(self): - bs, d_hid = 10, 23 - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - return x - - ec = ExampleCode() - - traced = pippy.fx.symbolic_trace(ec) - - x = torch.randn(bs, d_hid) - torch.testing.assert_allclose(ec(x), traced(x)) - - def test_node_tagging(self): - class TaggingTracer(Tracer): - def create_node( - self, - kind: str, - target: Union[str, Callable], - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - ) -> Node: - n = super().create_node(kind, target, args, kwargs, name) - n.tag = "foo" - return n - - class M(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = M() - g = TaggingTracer().trace(m) - g.lint() - for n in g.nodes: - self.assertTrue(hasattr(n, "tag")) - self.assertEqual(n.tag, "foo") - - def test_tensor_attribute(self): - class TensorAttribute(torch.nn.Module): - def __init__(self): - super().__init__() - self.tensor = torch.rand(3, 4) - - def forward(self, x): - return torch.nn.functional.linear(x, self.tensor) - - ta = TensorAttribute() - traced = symbolic_trace(ta) - traced(torch.rand(4, 4)) - - class WrapperForQualname(torch.nn.Module): - def __init__(self): - super().__init__() - self.ta = TensorAttribute() - - def forward(self, x): - return torch.nn.functional.linear(x, self.ta.tensor) - - wfq = WrapperForQualname() - traced2 = symbolic_trace(wfq) - traced2.graph.lint() - traced2(torch.rand(4, 4)) - - def test_tensor_attribute_coalseced(self): - def count_attrs(fx_module): - targets = set() - for node in traced.graph.nodes: - if node.op == "get_attr": - targets.add(node.target) - return len(targets) - - val = torch.tensor(5) - - def f(x): - return x + val + val - - traced = symbolic_trace(f) - traced.graph.lint() - self.assertEqual(count_attrs(traced), 1) - - val2 = torch.tensor(5) - - def f(x): - val = torch.tensor(5) - return x + val + val2 - - traced = symbolic_trace(f) - traced.graph.lint() - self.assertEqual(count_attrs(traced), 2) - - def test_symbolic_trace_sequential(self): - class Simple(torch.nn.Module): - def forward(self, x): - return torch.neg(x) - - seq = torch.nn.Sequential(Simple(), Simple(), Simple()) - traced = symbolic_trace(seq) - traced.graph.lint() - x = torch.rand(3, 4) - self.assertEqual(traced(x), seq(x)) - - def test_tensor_constant(self): - class ConstTensor(torch.nn.Module): - def forward(self, x): - return torch.nn.functional.linear(x, torch.zeros(3, 4)) - - ct = ConstTensor() - traced = symbolic_trace(ct) - traced.graph.lint() - traced(torch.rand(4, 4)) - - def test_pickle_graphmodule(self): - class Nested(torch.nn.Module): - def __init__(self): - super().__init__() - self.st = torch.nn.Linear(4, 4) - - def forward(self, x): - return self.st(x) - - n = Nested() - traced = symbolic_trace(n) - traced.graph.lint() - pickled = pickle.dumps(traced) - loaded = pickle.loads(pickled) - loaded.graph.lint() - x = torch.rand(3, 4) - self.assertEqual(loaded(x), traced(x)) - - def test_pickle_custom_import(self): - graph = pippy.fx.Graph() - a = graph.placeholder("x") - b = graph.placeholder("y") - c = graph.call_function(a_non_torch_leaf, (a, b)) - d = graph.call_function(torch.sin, (c,)) - graph.output(d) - gm = GraphModule(torch.nn.Module(), graph) - pickled = pickle.dumps(gm) - loaded = pickle.loads(pickled) - loaded.graph.lint() - x, y = torch.rand(1), torch.rand(1) - self.assertEqual(loaded(x, y), gm(x, y)) - - def test_all_input_nodes(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.placeholder("x") - b: pippy.fx.Node = graph.call_module("linear_mod", args=(a,)) - c: pippy.fx.Node = graph.get_attr("y_attr") - d: pippy.fx.Node = graph.call_function(operator.add, args=(b, c)) - e: pippy.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) - graph.output(e) - graph.lint() - - self.assertEqual(b.all_input_nodes, [a]) - self.assertEqual(c.all_input_nodes, []) - self.assertEqual(d.all_input_nodes, [b, c]) - self.assertEqual(e.all_input_nodes, [d]) - - def test_deepcopy_graphmodule_with_transform(self): - st = SimpleTest() - traced = symbolic_trace(st) - traced.graph.lint() - - def transform(traced): - new_graph = pippy.fx.Graph() - val_map: Dict[Node, Node] = {} - output_value = new_graph.graph_copy(traced.graph, val_map) - relu_out = new_graph.create_node( - op="call_method", target="neg", args=(output_value,), kwargs={} - ) - new_graph.output(relu_out) - return GraphModule(traced, new_graph) - - transformed = transform(traced) - transformed.graph.lint() - copied = copy.deepcopy(transformed) - self.assertNotEqual(id(type(transformed)), id(type(copied))) - x = torch.randn(3, 4) - self.assertEqual(copied(x), transformed(x)) - - def test_deepcopy_with_submods_params(self): - class Bar(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - - def forward(self, x): - return torch.relu(x) + self.param - - class Baz(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.bar = Bar() - - def forward(self, x): - return self.bar(x) - self.param - - baz = Baz() - traced = symbolic_trace(baz) - traced.graph.lint() - copied = copy.deepcopy(traced) - copied.graph.lint() - - def test_deepcopy_graph_with_tracer_cls(self): - class TestTracer(Tracer): - def is_leaf_module(self, module, name): - return True - - g = Graph(tracer_cls=TestTracer) - x = g.placeholder("x") - g.output(x) - - h = copy.deepcopy(g) - self.assertIsNotNone(h._tracer_cls) - self.assertTrue(g._tracer_cls == h._tracer_cls) - - def test_unpack_list_better_error(self): - class SomeArgs(torch.nn.Module): - def forward(self, a, b): - return torch.rand(3, 4) - - class UnpacksList(torch.nn.Module): - def __init__(self): - super().__init__() - self.sa = SomeArgs() - - def forward(self, x: list): - return self.sa(*x) - - ul = UnpacksList() - with self.assertRaisesRegex( - TraceError, "Proxy object cannot be iterated." - ): - symbolic_trace(ul) - - def test_unpack_dict_better_error(self): - class SomeKwargs(torch.nn.Module): - def forward(self, x=3, y=4): - return torch.rand(3, 4) - - class UnpacksDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.sk = SomeKwargs() - - def forward(self, x: dict): - return self.sk(**x) - - ud = UnpacksDict() - with self.assertRaisesRegex( - TraceError, "Proxy object cannot be iterated." - ): - symbolic_trace(ud) - - def test_pretty_print_targets(self): - # Test that Graph pretty-print prints friendly name for targets - # in `operator` and `builtins` - - class SomeMod(torch.nn.Module): - def forward(self, x): - return torch.add(x.foo + x.bar, 3.0) - - traced = symbolic_trace(SomeMod()) - graph_str = str(traced.graph) - self.assertIn("builtins.getattr", graph_str) - self.assertIn("operator.add", graph_str) - self.assertIn("torch.add", graph_str) - - def test_pretty_print_node(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.param: torch.nn.Parameter = torch.nn.Parameter( - torch.rand(3, 4) - ) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x: torch.Tensor, y: int = 2): - return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0) - - traced = symbolic_trace(M()) - - all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) - - FileCheck().check("x").check("placeholder").check("y").check( - "placeholder" - ).check("getitem").check("call_function").check("param").check( - "get_attr" - ).check( - "add" - ).check( - "call_function" - ).check( - "linear" - ).check( - "call_module" - ).check( - "clamp" - ).check( - "call_method" - ).run( - all_formatted - ) - - def test_script_tensor_constant(self): - # TorchScript seems to ignore attributes that start with `__`. - # We used to call anonymous Tensor values `__tensor_constant*`, but - # they were getting ignored by script. Now they're called - # `_tensor_constant*` - class IHaveATensorConstant(torch.nn.Module): - def forward(self, x): - return x + torch.rand(3, 4) - - traced = pippy.fx.symbolic_trace(IHaveATensorConstant()) - torch.jit.script(traced) - - def test_autowrap_functions(self): - class AutowrapFnTest(torch.nn.Module): - def forward(self, x): - return fx_int(x.shape[0] / 2) - - class AutowrapFnTest2(torch.nn.Module): - def forward(self, x): - return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2) - - # Check function(s) are wrapped - # `int` would normally throw a TypeError as argument can't be `Proxy` - tracer = Tracer(autowrap_functions=(fx_int,)) - graph = tracer.trace(AutowrapFnTest()) - traced = GraphModule(tracer.root, graph, "test") - tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2)) - tracer_2.trace(AutowrapFnTest2()) - - # Test scriptability - traced_scripted = torch.jit.script(traced) - self.assertEqual(traced_scripted(torch.rand(4)), 2) - - def test_tuple_no_subscript(self): - def foo(x: Tuple): - return x[0] - - traced = pippy.fx.symbolic_trace(foo) - x = (torch.randn(5, 3),) - torch.testing.assert_allclose(traced(x), x[0]) - - bio = io.BytesIO() - - torch.save(traced, bio) - - bio.seek(0) - - loaded = torch.load(bio) - - torch.testing.assert_allclose(loaded(x), x[0]) - - def test_torch_fx_len(self): - class FXLenTest(torch.nn.Module): - def forward(self, x): - return len(x) - - traced = symbolic_trace(FXLenTest()) - self.assertEqual(traced(torch.rand(3, 4)), 3) - - # Test scriptability - scripted = torch.jit.script(FXLenTest()) - self.assertEqual(scripted(torch.rand(3)), 3) - - traced_scripted = torch.jit.script(traced) - self.assertEqual(traced_scripted(torch.rand(3)), 3) - - # Test non-proxy len - class FXLenTest2(torch.nn.Module): - def __init__(self): - super().__init__() - self.l = [3, 4, 5] - - def forward(self, x): - return x + len(self.l) - - traced2 = symbolic_trace(FXLenTest2()) - inp = torch.rand(3, 4) - self.assertEqual(traced2(inp), inp + 3.0) - self.assertIs(len, builtins.len) - - def test_torch_fx_getattr(self): - class FXGetattrTest(torch.nn.Module): - def forward(self, x): - return getattr(x, "nonexistent_attr", torch.Tensor([2, 3])) - - traced = symbolic_trace(FXGetattrTest()) - self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) - - def test_sqrt(self): - class Sqrt1(torch.nn.Module): - def forward(self, x): - return sqrt(x.size(0)) - - class Sqrt2(torch.nn.Module): - def forward(self, x): - return math.sqrt(x.size(0)) - - class Sqrt3(torch.nn.Module): - def forward(self, x): - return x + math.sqrt(2) + sqrt(2) - - self.checkGraphModule(Sqrt1(), [torch.zeros(8)]) - self.checkGraphModule(Sqrt2(), [torch.zeros(8)]) - self.checkGraphModule(Sqrt3(), [torch.zeros(8)]) - self.assertIs(sqrt, _sqrt) - self.assertIs(math.sqrt, _sqrt) - - def test_torch_custom_ops(self): - class M(torch.nn.Module): - def forward(self, a): - b = torch.ops.aten.sigmoid(a) - c = torch.ops.aten.cat([a, b]) - return torch.ops.aten.cat((c, c)) - - m = M() - input = torch.randn(3) - ref_out = m(input) - gm = symbolic_trace(m) - gm.graph.lint() - out = gm(input) - self.assertEqual(out, ref_out) - - def test_torch_op_overloads(self): - class M(torch.nn.Module): - def forward(self, a): - b = torch.ops.aten.add.Tensor(a, a) - return b - - m = M() - input = torch.randn(3) - ref_out = m(input) - gm = symbolic_trace(m) - gm.graph.lint() - out = gm(input) - self.assertEqual(out, ref_out) - - for node in gm.graph.nodes: - if node.op == "call_function": - assert isinstance(node.target, torch._ops.OpOverload) - assert node.target.__name__ == "add.Tensor" - - def test_pickle_torch_custom_ops(self): - class M(torch.nn.Module): - def forward(self, a): - b = torch.ops.aten.sigmoid(a) - c = torch.ops.aten.cat([a, b]) - return torch.ops.aten.cat((c, c)) - - m = M() - input = torch.randn(3) - ref_out = m(input) - gm = symbolic_trace(m) - gm.graph.lint() - pickled = pickle.dumps(gm) - loaded = pickle.loads(pickled) - self.assertEqual(loaded(input), gm(input)) - - def test_pretty_print(self): - st = SimpleTest() - traced = symbolic_trace(st) - traced.graph.lint() - printed = str(traced) - assert "SimpleTest()" in printed - assert "torch.relu" in printed - - def test_pretty_print_graph(self): - class KwargPrintTest(torch.nn.Module): - def forward(self, x): - return torch.squeeze(x + 3.0, dim=2) - - st = KwargPrintTest() - traced = symbolic_trace(st) - traced.graph.lint() - stringed = str(traced.graph) - for s in ["args", "kwargs", "#users"]: - assert s in stringed - - def test_custom_proxy_type(self): - class TensorPair: - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair(x: TensorPair, y: TensorPair): - s = x.add(y) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - - ref_out = use_tensor_pair(x, y) - - traced = symbolic_trace(use_tensor_pair) - - traced_out = traced(x, y) - self.assertEqual(traced_out.left, ref_out.left) - self.assertEqual(traced_out.right, ref_out.right) - - def test_custom_proxy_type_literal(self): - class TensorPair(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_literal(x: TensorPair): - s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3))) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - - ref_out = use_tensor_pair_literal(x) - - traced = symbolic_trace(use_tensor_pair_literal) - - traced_out = traced(x) - self.assertEqual(traced_out.left, ref_out.left) - self.assertEqual(traced_out.right, ref_out.right) - - def test_custom_proxy_dynamic_value(self): - class TensorPair(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, left, right): - self.left, self.right = left, right - - def add(self, other): - l = self.left + other.left - r = self.right + other.right - return TensorPair(l, r) - - def mul(self, other): - l = self.left * other.left - r = self.right * other.right - return TensorPair(l, r) - - def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): - s = x.add(TensorPair(y, y)) - return s.mul(x) - - x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) - y = torch.randn(5, 3) - ref_out = use_tensor_pair_ctor(x, y) - - traced = symbolic_trace(use_tensor_pair_ctor) - - traced_out = traced(x, y) - self.assertEqual(traced_out.left, ref_out.left) - self.assertEqual(traced_out.right, ref_out.right) - - def test_custom_proxy_input_dependent_control_flow(self): - class ZeroTensor(metaclass=pippy.fx.ProxyableClassMeta): - def __init__(self, inp): - if inp.sum() == 0: - self.is_zero = True - self.tensor = torch.tensor([]) - else: - self.is_zero = False - self.tensor = inp - - def add(self, other): - if self.is_zero: - return ZeroTensor(other.tensor) - elif other.is_zero: - return self - - def use_zero_tensor(x: torch.Tensor, y: torch.Tensor): - return ZeroTensor(x + y) - - x, y = torch.randn(5, 3), torch.randn(5, 3) - - ref_out = use_zero_tensor(x, y) - - traced = symbolic_trace(use_zero_tensor) - - traced_out = traced(x, y) - - self.assertEqual(traced_out.is_zero, ref_out.is_zero) - self.assertEqual(traced_out.tensor, ref_out.tensor) - - def test_graph_fns(self): - g = Graph() - a = g.placeholder("a") - b = g.call_module("linear", (a,)) - c = g.get_attr("bias") - d = g.call_method("add", (b, c)) - e = g.call_function(torch.sin, (d,)) - g.output(e) - mod = torch.nn.Module() - mod.linear = torch.nn.Linear(3, 4) - mod.bias = torch.rand(4) - gm = GraphModule(mod, g) - gm.graph.lint() - input = torch.rand(3) - r = gm(input) - ref = torch.sin(mod.linear(input) + mod.bias) - self.assertEqual(r, ref) - - def test_remove_uses(self): - g: pippy.fx.Graph = Graph() - x: pippy.fx.Node = g.placeholder("x") - relu: pippy.fx.Node = g.call_function(torch.relu, (x,)) - neg: pippy.fx.Node = g.call_function(torch.neg, (relu,)) - g.output(neg) - - neg.replace_all_uses_with(relu) - g.erase_node(neg) - - self.assertTrue(neg not in relu.users) - - def test_remove_uses_with_custom_filter(self): - g: pippy.fx.Graph = Graph() - x: pippy.fx.Node = g.placeholder("x") - relu: pippy.fx.Node = g.call_function(torch.relu, (x,)) - neg: pippy.fx.Node = g.call_function(torch.neg, (relu,)) - g.output(neg) - - neg.replace_all_uses_with(relu, lambda x: x != neg) - - self.assertTrue(neg in relu.users) - - def test_nonetype_annotation(self): - eb = torch.nn.EmbeddingBag(3, 4) - symbolic_trace(eb) - - def test_pickle_nonetype_annotation(self): - eb = torch.nn.EmbeddingBag(10, 3, mode="sum") - traced = symbolic_trace(eb) - pickled = pickle.dumps(traced) - loaded = pickle.loads(pickled) - loaded.graph.lint() - input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) - offsets = torch.LongTensor([0, 4]) - self.assertEqual(loaded(input, offsets), traced(input, offsets)) - - def test_return_tuple(self): - class M(torch.nn.Module): - def forward( - self, x: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - return (x, x + x) - - original = M() - traced = symbolic_trace(original) - self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1))) - - def test_construct_root_dict(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_module", "foo.bar.baz", args=(a,) - ) - c: pippy.fx.Node = graph.create_node("get_attr", "zip.zap.zam") - d: pippy.fx.Node = graph.create_node( - "call_function", operator.add, args=(b, c) - ) - graph.output(d) - - linear_mod: torch.nn.Module = torch.nn.Linear(3, 4) - add_param: torch.Tensor = torch.rand(3, 4) - gm: pippy.fx.GraphModule = pippy.fx.GraphModule( - {"foo.bar.baz": linear_mod, "zip.zap.zam": add_param}, graph - ) - gm.graph.lint() - - assert "self.foo.bar.baz" in gm.code - - x: torch.Tensor = torch.rand(3, 3) - out: torch.Tensor = gm(x) - ref_out: torch.Tensor = linear_mod(x) + add_param - self.assertEqual(out, ref_out) - - def test_symbolic_trace_assert(self): - class AssertsTensorShape(torch.nn.Module): - def forward(self, x): - torch._assert(x.shape[1] > 4, "assert_foobar") - return x - - m = AssertsTensorShape() - # verify traceability - traced = symbolic_trace(m) - # verify assertion on traced model works correctly at runtime - traced(torch.rand(4, 5)) - with self.assertRaisesRegex(AssertionError, "assert_foobar"): - traced(torch.rand(4, 3)) - # verify the symbolically traced module is scriptable - ms = torch.jit.script(m) - with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"): - ms(torch.rand(4, 3)) - - def test_fx_create_arg(self): - class CustomArgObject: - def __init__(self, x, y): - self.x = x - self.y = y - - def __fx_create_arg__(self, tracer: pippy.fx.Tracer): - return tracer.create_node( - "call_function", - CustomArgObject, - args=( - tracer.create_arg(self.x), - tracer.create_arg(self.y), - ), - kwargs={}, - ) - - class HasCustomArgObjectWhenLeaf(torch.nn.Module): - def forward(self, o: CustomArgObject): - # Not normally traceable; good reason to make - # this module a leaf. - for x in o.x: - o.y += x - return o.y - - class Root(torch.nn.Module): - def __init__(self): - super().__init__() - self.inner = HasCustomArgObjectWhenLeaf() - - def forward(self, x, y): - o = CustomArgObject(x, y) - return self.inner(o) - - class CreateArgTracer(pippy.fx.Tracer): - def is_leaf_module(self, m, module_qualified_name): - return type(m) is HasCustomArgObjectWhenLeaf - - m = Root() - graph = CreateArgTracer().trace(m) - gm = pippy.fx.GraphModule(m, graph) - assert "CustomArgObject(" in gm.code - - def test_trace_fn_constant(self): - some_constant = torch.rand(3, 4) - - def add_const(x): - return some_constant + x - - traced = symbolic_trace(add_const) - - input = torch.rand(3, 4) - self.assertEqual(traced(input), add_const(input)) - - def test_copy_no_remap(self): - traced = symbolic_trace(SimpleTest()) - g = traced.graph - copied = pippy.fx.Graph() - for node in g.nodes: - copied.node_copy(node) - with self.assertRaisesRegex( - RuntimeError, "does not belong to this Graph" - ): - copied.lint() - - def test_wrong_topo(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - a: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_module", "foo.bar.baz", args=(a,) - ) - c: pippy.fx.Node = graph.create_node("get_attr", "zip.zap.zam") - d: pippy.fx.Node = graph.create_node( - "call_function", operator.add, args=(b, c) - ) - graph.output(d) - nodes = list(graph.nodes) - nodes[3].append(nodes[2]) - with self.assertRaisesRegex( - RuntimeError, "was used before it has been defined" - ): - graph.lint() - - def test_wrong_target_type(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - with self.assertRaises(ValueError): - n = pippy.fx.Node( - graph=graph, - name="foo", - op="call_function", - target="foo", - args=(), - kwargs={}, - ) - - def test_example_shape_prop(self): - class TestCase(torch.nn.Module): - def __init__(self): - super().__init__() - self.attr = torch.randn(3, 4) - self.submod = torch.nn.Linear(4, 4) - - def forward(self, x): - return torch.neg(self.submod(x.relu() + self.attr)) - - tc = TestCase() - tc_traced = symbolic_trace(tc) - ref_out = tc_traced(torch.rand(3, 4)) - shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4)) - - # Make sure we're testing all opcodes - opcodes = set() - output_shape: Optional[torch.Shape] = None - output_stride: Optional[Tuple[int]] = None - for node in tc_traced.graph.nodes: - opcodes.add(node.op) - if node.op == "output": - output_shape = node.args[0].meta["tensor_meta"].shape - output_stride = node.args[0].meta["tensor_meta"].stride - self.assertEqual( - opcodes, - set( - [ - "placeholder", - "get_attr", - "call_function", - "call_method", - "call_module", - "output", - ] - ), - ) - - # Test shape propagation and make sure results match actual - self.assertEqual(output_shape, ref_out.shape) - self.assertEqual(output_stride, ref_out.stride()) - - def test_shape_prop_layout(self): - class ConvTest(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_mod = torch.nn.Conv2d(5, 5, 3) - - def forward(self, x): - return self.conv_mod(x) - - # contiguous layout - test_mod = ConvTest() - traced = symbolic_trace(test_mod) - x = torch.randn(5, 5, 224, 224) - shape_prop.ShapeProp(traced).propagate(x) - - assert all( - node.meta["tensor_meta"].memory_format is torch.contiguous_format - for node in traced.graph.nodes - ) - - x_channels_last = x.contiguous(memory_format=torch.channels_last) - traced.to(memory_format=torch.channels_last) - shape_prop.ShapeProp(traced).propagate(x_channels_last) - for node in traced.graph.nodes: - # NB: the implementation of conv may not preserve the memory format, - # unfortunately. The best we can do is just check that the placeholder - # node is channels-last - if node.op in {"placeholder"}: - self.assertEqual( - node.meta["tensor_meta"].memory_format, torch.channels_last - ) - - def test_shape_prop_aggregate(self): - class ReturnTwo(torch.nn.Module): - def forward(self, x): - return (3, torch.sum(x)) - - class UnderTest(torch.nn.Module): - def __init__(self): - super().__init__() - self.rt = ReturnTwo() - - def forward(self, x): - return self.rt(x) - - ut = UnderTest() - - class RTTracer(pippy.fx.Tracer): - def is_leaf_module(self, m, module_qualified_name): - return type(m) is ReturnTwo - - graph = RTTracer().trace(ut) - mod = pippy.fx.GraphModule(ut, graph) - - shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4)) - - for node in mod.graph.nodes: - if node.op == "call_module": - assert "tensor_meta" in node.meta - tensor_meta = node.meta["tensor_meta"] - assert tensor_meta[0] == 3 - assert tensor_meta[1].shape == torch.Size([]) - - def test_shape_prop_layout_3d(self): - class ConvTest3d(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv_mod = torch.nn.Conv3d(5, 5, 3) - - def forward(self, x): - return self.conv_mod(x) - - test_mod_3d = ConvTest3d() - traced_3d = symbolic_trace(test_mod_3d) - x_3d = torch.randn(5, 5, 224, 224, 15) - shape_prop.ShapeProp(traced_3d).propagate(x_3d) - assert all( - node.meta["tensor_meta"].memory_format is torch.contiguous_format - for node in traced_3d.graph.nodes - ) - - x_channels_last_3d = x_3d.contiguous( - memory_format=torch.channels_last_3d - ) - traced_3d.to(memory_format=torch.channels_last_3d) - shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d) - for node in traced_3d.graph.nodes: - # NB: the implementation of conv may not preserve the memory format, - # unfortunately. The best we can do is just check that the placeholder - # node is channels-last - if node.op in {"placeholder"}: - self.assertEqual( - node.meta["tensor_meta"].memory_format, - torch.channels_last_3d, - ) - - def test_interpreter(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - interpreter = Interpreter(gm) - input = torch.randn(3, 4) - self.assertEqual(interpreter.run(input), gm(input)) - self.assertEqual(interpreter.run(input), m(input)) - - def test_interpreter_run_node_override(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - class RunNodeInterpreter(Interpreter): - def __init__(self, module): - super().__init__(module) - - def run_node(self, n: Node) -> Any: - result = super().run_node(n) - n.cached_value = result - return result - - input = torch.randn(3, 4) - RunNodeInterpreter(gm).run(input) - for node in gm.graph.nodes: - assert hasattr(node, "cached_value") - - def test_interpreter_onthefly_swap(self): - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - - class NegSigmSwapInterpreter(Interpreter): - def call_function( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == "neg": - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - input = torch.randn(3, 4) - result = NegSigmSwapInterpreter(gm).run(input) - self.assertEqual(result, torch.neg(input).sigmoid()) - - def test_interpreter_partial_eval(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - gm = pippy.fx.symbolic_trace(MyModule()) - interp = Interpreter(gm) - env = {} - for node in gm.graph.nodes: - if node.op == "call_module" and node.target == "linear": - env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0 - break - assert len(env) == 1 - x = torch.randn(3, 4) - result = interp.run(x, initial_env=env) - self.assertEqual( - result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0) - ) - - def test_interpreter_star_args(self): - def with_star_args(x, *args): - return x + args[0] - - gm = pippy.fx.symbolic_trace(with_star_args) - interp = Interpreter(gm) - result = interp.run( - torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4) - ) - self.assertEqual(result, torch.ones(3, 4) * 2.0) - - @skipIfNoTorchVision - def test_interpreter_noop_resnet18(self): - rn18 = torchvision_models.resnet18() - transformed = pippy.fx.Transformer(symbolic_trace(rn18)).transform() - inp = torch.randn(5, 3, 224, 224) - self.assertEqual(transformed(inp), rn18(inp)) - - @skipIfNoTorchVision - def test_interpreter_gc_values(self): - rn18 = torchvision_models.resnet18() - interp = Interpreter(symbolic_trace(rn18)) - inp = torch.rand(5, 3, 224, 224) - out = interp.run(inp) - env_key_names = set(n.name for n in interp.env.keys()) - self.assertEqual(env_key_names, set(["output"])) - - def test_interpreter_default_args(self): - class Model(torch.nn.Module): - def forward(self, x, y=3.14159): - return x + y - - model = Model() - gm = pippy.fx.symbolic_trace(model) - - interp = Interpreter(gm) - x = torch.randn(5, 3) - out = interp.run(x) - torch.testing.assert_allclose(out, x + 3.14159) - - def test_interpreter_not_enough_args(self): - class Model(torch.nn.Module): - def forward(self, x, y): - return x + y - - model = Model() - gm = pippy.fx.symbolic_trace(model) - - interp = Interpreter(gm) - x = torch.randn(5, 3) - with self.assertRaisesRegex( - RuntimeError, - "Expected positional argument for parameter y, but one was not passed in", - ): - out = interp.run(x) - - def test_transformer_noop(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - return self.linear(x + self.param).clamp(min=0.0, max=1.0) - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - new_gm = Transformer(gm).transform() - - input = torch.randn(3, 4) - self.assertEqual(new_gm(input), gm(input)) - - def test_transformer_op_swap(self): - def fn(x): - return torch.sigmoid(x).neg() - - gm = pippy.fx.symbolic_trace(fn) - - class NegSigmSwapXformer(Transformer): - def call_function( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == torch.sigmoid: - return torch.neg(*args, **kwargs) - return super().call_function(n) - - def call_method( - self, target: Target, args: Tuple, kwargs: Dict - ) -> Any: - if target == "neg": - call_self, *args_tail = args - return call_self.sigmoid(*args_tail, **kwargs) - return super().call_method(n) - - transformed = NegSigmSwapXformer(gm).transform() - input = torch.randn(3, 4) - self.assertEqual(transformed(input), torch.neg(input).sigmoid()) - - def test_transformer_multi_outputs(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x): - x = x + self.param - out = self.linear(x) - return x, out - - m = MyModule() - gm = pippy.fx.symbolic_trace(m) - - new_gm = Transformer(gm).transform() - - input = torch.randn(3, 4) - self.assertEqual(new_gm(input), gm(input)) - - def test_fn_type_annotations(self): - class Foo(torch.nn.Module): - def forward( - self, p: Pair, z: torch.Tensor, i: int - ) -> Dict[str, torch.Tensor]: - return {"a": p.x + p.y + z + i} - - foo_scripted = torch.jit.script(Foo()) - foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) - - fxed = symbolic_trace(Foo()) - fxed_scripted = torch.jit.script(fxed) - fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) - - def test_fn_type_annotation_empty(self): - def forward(a: List[torch.Tensor]): - return a[0] - - torch.jit.script(symbolic_trace(forward)) - - def test_wrapped_method(self): - def wrap_with_relu(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - return torch.relu(fn(*args, **kwargs)) - - return wrapper - - class Foo(torch.nn.Module): - @wrap_with_relu - def forward(self, x, w): - return torch.matmul(x, w) - - f = Foo() - traced = symbolic_trace(f) - x, w = torch.rand(3, 4), torch.rand(4, 4) - self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) - - def test_empty_graph_codegen(self): - graph = pippy.fx.Graph() - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - self.assertEqual(gm(), None) - - def test_sequential(self): - m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) - gm = pippy.fx.symbolic_trace(m) - gm_copy = copy.deepcopy(gm) - - def test_ctx_mgr(self): - @contextlib.contextmanager - def do_nothing(): - yield - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - @do_nothing() - def forward(self, x): - return torch.relu(x) - - m = M() - self.checkGraphModule(m, (torch.rand(3, 4),)) - - def test_typename_print(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,), type_expr=List[float] - ) - output: pippy.fx.Node = graph.output(b) - - self.assertTrue("typing.List[float]" in str(graph)) - - def test_layout(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.empty_like( - x, layout=torch.strided, pin_memory=False - ).fill_(0) - - traced = symbolic_trace(M()) - x = torch.rand(5, 9, 3, 4) - self.assertEqual(traced(x), torch.zeros_like(x)) - - def test_ellipsis(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return x + y[:, 1:10, ...] - - traced = symbolic_trace(M()) - x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4) - self.assertEqual(traced(x, y), x + y[:, 1:10, ...]) - - def test_inf_nan(self): - class FooMod(torch.nn.Module): - def forward(self, x): - return x + float("inf"), x + float("-inf"), x + float("nan") - - fm = FooMod() - self.checkGraphModule(fm, (torch.rand(3, 4),)) - - def test_inf_nan_kwds(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", operator.add, (x, float("inf")), {}, name="inf" - ) - c: pippy.fx.Node = graph.create_node( - "call_function", operator.add, (x, float("nan")), {}, name="nan" - ) - graph.output((b, c)) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - x = torch.rand(3, 4) - self.assertEqual(gm(x), (x + float("inf"), x + float("nan"))) - - def test_deepcopy_recursion_depth(self): - depth = sys.getrecursionlimit() + 20 - - g = pippy.fx.Graph() - x = g.placeholder("x") - for i in range(depth): - x = g.call_function(torch.relu, (x,)) - g.output(x) - - copied_graph = copy.deepcopy(g) - - val_map = {} - for orig_node, new_node in zip(g.nodes, copied_graph.nodes): - val_map[orig_node] = new_node - - for orig_node, new_node in zip(g.nodes, copied_graph.nodes): - orig_users = set(orig_node.users.keys()) - orig_users_equiv = set(val_map[u] for u in orig_users) - new_users = set(new_node.users.keys()) - self.assertEqual(orig_users_equiv, new_users) - - @skipIfNoTorchVision - def test_replace_uses(self): - rn18 = torchvision_models.resnet18() - - class LowerReluTracer(pippy.fx.Tracer): - def is_leaf_module(self, m: torch.nn.Module, qualname: str): - if isinstance(m, torch.nn.ReLU): - return False - return super().is_leaf_module(m, qualname) - - rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18)) - - to_erase = [] - for node in rn18_traced.graph.nodes: - if node.op == "call_function" and node.target in [ - torch.relu, - torch.nn.functional.relu, - ]: - kwargs = node.kwargs.copy() - # Neg doesn't have in-place - kwargs.pop("inplace") - with rn18_traced.graph.inserting_before(node): - new_node = rn18_traced.graph.call_function( - the_function=torch.neg, - args=node.args, - kwargs=node.kwargs, - ) - node.replace_all_uses_with(replace_with=new_node) - to_erase.append(node) - - for node in to_erase: - rn18_traced.graph.erase_node(node) - - def test_replace_input(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - y: pippy.fx.Node = graph.create_node("placeholder", "y") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - b.replace_input_with(x, y) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - - input_x = torch.randn(33, 44) - input_y = torch.randn(11, 22) - self.assertEqual(gm(input_x, input_y), torch.relu(input_y)) - - def test_insertion_point(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - with graph.inserting_before(b): - neg: pippy.fx.Node = graph.call_function( - the_function=torch.neg, args=(x,) - ) - _, *relu_args = b.args - b.args = (neg, *relu_args) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - - input = torch.randn(33, 44) - self.assertEqual(gm(input), torch.relu(torch.neg(input))) - - def test_update_args_api(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - y: pippy.fx.Node = graph.create_node("placeholder", "y") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - orig_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) - self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - - b.update_arg(0, y) - new_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) - - def test_update_kwargs_api(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - y: pippy.fx.Node = graph.create_node("placeholder", "y") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, kwargs={"input": x} - ) - output: pippy.fx.Node = graph.output(b) - - orig_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5) - self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x)) - - b.update_kwarg("input", y) - new_gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y)) - - def test_immutable_list_pytree_ops(self): - rand_tensor = torch.randn(5, 3) - l = immutable_list([3, [rand_tensor, 42]]) - - flattened, spec = pytree.tree_flatten(l) - assert flattened == [3, rand_tensor, 42] - - unflattened = pytree.tree_unflatten(flattened, spec) - assert unflattened == l - assert isinstance(unflattened, immutable_list) - - def test_immutable_dict_pytree_ops(self): - rand_tensor = torch.randn(5, 3) - d = immutable_dict({"a": 3, "b": [rand_tensor, 42]}) - - flattened, spec = pytree.tree_flatten(d) - assert flattened == [3, rand_tensor, 42] - - unflattened = pytree.tree_unflatten(flattened, spec) - assert unflattened == d - assert isinstance(unflattened, immutable_dict) - - def test_move_before(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - neg: pippy.fx.Node = graph.call_function( - the_function=torch.neg, args=(x,) - ) - _, *relu_args = b.args - b.args = (neg, *relu_args) - b.prepend(neg) - - gm = pippy.fx.GraphModule(torch.nn.Module(), graph) - - input = torch.randn(33, 44) - self.assertEqual(gm(input), torch.relu(torch.neg(input))) - - def test_prepend_self(self): - graph: pippy.fx.Graph = pippy.fx.Graph() - x: pippy.fx.Node = graph.create_node("placeholder", "x") - b: pippy.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,) - ) - output: pippy.fx.Node = graph.output(b) - - b.prepend(b) - x.append(b) - self.assertEqual(len(graph.nodes), 3) - - def test_erase_node_error(self): - st = SimpleTest() - traced = symbolic_trace(st) - - for node in traced.graph.nodes: - # Test deleting with uses both in another Node and at the output - if node.target in [operator.add, torch.relu]: - with self.assertRaisesRegex( - RuntimeError, "but it still had .* users in the graph" - ): - traced.graph.erase_node(node) - - def test_copy_it(self): - d = immutable_dict([(3, 4), (5, 6)]) - l = immutable_list([(3, 4), (5, 6)]) - - self.assertEqual(d, deepcopy(d)) - self.assertEqual(l, deepcopy(l)) - - def test_get_torch_func_signature(self): - for key in dir(torch): - obj = getattr(torch, key) - if callable(obj): - schemas = get_signature_for_torch_op(obj) - - def test_find_uses(self): - graph = pippy.fx.Graph() - x = pippy.fx.Proxy(graph.placeholder("x")) - - y = torch.relu(x) - z = x + x - u = torch.neg(x) - graph.output((y + z + u).node) - graph.lint() - - users_of_x = x.node.users - self.assertEqual(len(users_of_x), 3) - expected_ops = set(["relu", "add", "neg"]) - for use in users_of_x: - assert any(use.name.startswith(prefix) for prefix in expected_ops) - - def test_inline_graph(self): - class InlineInto(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - class ToInline(torch.nn.Module): - def forward(self, x): - return torch.neg(x) - - inline_into = symbolic_trace(InlineInto()) - to_inline = symbolic_trace(ToInline()) - - combined_graph = pippy.fx.Graph() - output_node = combined_graph.graph_copy(inline_into.graph, {}) - - input_node = list(to_inline.graph.nodes)[0] - assert input_node and input_node.op == "placeholder" - - val_map = {input_node: output_node} - output = combined_graph.graph_copy(to_inline.graph, val_map) - combined_graph.output(output) - - combined_module = pippy.fx.GraphModule( - torch.nn.Module(), combined_graph - ) - - input = torch.rand(3, 4) - self.assertEqual(combined_module(input), input.relu().neg()) - - def test_multi_insert_point(self): - graph = pippy.fx.Graph() - x = pippy.fx.Proxy(graph.placeholder("x")) - relu = torch.relu(x) - - with graph.inserting_before(relu.node): - y = torch.neg(x) - z = torch.tanh(y) - - graph.output((relu.node, z.node)) - graph.lint() - - expected_ops = ["x", "neg", "tanh", "relu"] - for node, expected in zip(graph.nodes, expected_ops): - assert expected in node.name - - def test_reassign_args_kwargs_uses(self): - graph = pippy.fx.Graph() - x, y = Proxy(graph.placeholder("x")), Proxy(graph.placeholder("y")) - z = x + y - zed = z + z + z - graph.output(zed.node) - graph.lint() - - # zed = z + z + z -> zed = z + z + x - zed.node.args = (zed.node.args[0], x.node) - self.assertEqual(list(x.node.users.keys()), [z.node, zed.node]) - - # z = x + y -> z = y + y - z.node.args = (y.node, y.node) - self.assertEqual(list(x.node.users.keys()), [zed.node]) - - def test_trace_function(self): - def foo(x, y): - return torch.relu(x) + y - - x, y = torch.randn(3, 4), torch.randn(3, 4) - self.checkGraphModule(foo, (x, y)) - - def test_trace_dict_int_keys(self): - class ModWithDictArg(torch.nn.Module): - def forward(self, d: Dict[int, torch.Tensor]): - return d[42] - - class CallsModWithDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.m = ModWithDictArg() - - def forward(self, x): - return self.m({42: x}) - - class MyTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - return isinstance(m, ModWithDictArg) - - traced_graph = MyTracer().trace(CallsModWithDict()) - - def test_trace_dict_proxy_keys(self): - class ModWithDictArg(torch.nn.Module): - def forward(self, d: Dict[torch.Tensor, torch.Tensor]): - return d[42] - - class CallsModWithDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.m = ModWithDictArg() - - def forward(self, x): - return self.m({x: x}) - - class MyTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - return isinstance(m, ModWithDictArg) - - with self.assertRaisesRegex(RuntimeError, "cannot contain a Node"): - traced_graph = MyTracer().trace(CallsModWithDict()) - - def test_module_deepcopy_edit_nodes(self): - class Foo(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - traced1 = symbolic_trace(Foo()) - copied = copy.deepcopy(traced1) - - for node in copied.graph.nodes: - if node.target == torch.relu: - node.target = torch.neg - - copied.recompile() - traced1.recompile() - - x = torch.randn(15, 15) - torch.testing.assert_allclose(traced1(x), torch.relu(x)) - torch.testing.assert_allclose(copied(x), torch.neg(x)) - - def test_direct_param_use(self): - class TransposeTest(torch.nn.Module): - def __init__(self): - super().__init__() - self.b = torch.nn.Parameter(torch.rand(4, 3)) - - def forward(self, x): - return self.b - - class Foo(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = TransposeTest() - - def forward(self, x): - return self.a.b, self.a.b.t(), self.a.b.view(12) - - traced = pippy.fx.symbolic_trace(Foo()) - assert all("constant" not in node.target for node in traced.graph.nodes) - - def test_single_default_arg(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, y=1): - return y - - m = M() - self.checkGraphModule(m, ()) - self.checkGraphModule(m, (3,)) - - def test_multiple_default_args(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, y=1, z=2): - return y + z - - m = M() - self.checkGraphModule(m, ()) - self.checkGraphModule(m, (3,)) - self.checkGraphModule(m, (3, 4)) - - def test_regular_and_default_args(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y=1): - return x + y - - m = M() - self.checkGraphModule(m, (2,)) - self.checkGraphModule(m, (2, 3)) - - def test_string_literal_return(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self): - return "foo" - - m = M() - self.checkGraphModule(m, ()) - - def test_namedtuple_return_qualname(self): - class NamedTupReturn(torch.nn.Module): - def forward(self, x): - return MyNamedTup(x, x) - - traced = symbolic_trace(NamedTupReturn()) - input = torch.rand(3, 4) - self.assertEqual(traced(input), MyNamedTup(input, input)) - - def test_update_args_kwargs_yells_at_you(self): - symtraced = symbolic_trace(SimpleTest()) - node = next(iter(symtraced.graph.nodes)) - with self.assertRaisesRegex(AttributeError, "__update_args_kwargs"): - node.__update_args_kwargs((), {}) - - def test_torchbind_class_attribute_in_fx(self): - if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest( - "torch.classes._TorchScriptTesting._StackString is registered, skipping" - ) - - class FooBar1234(torch.nn.Module): - def __init__(self): - super(FooBar1234, self).__init__() - self.f = torch.classes._TorchScriptTesting._StackString( - ["3", "4"] - ) - - def forward(self): - return self.f.top() - - m = FooBar1234() - self.checkGraphModule(m, ()) - - def test_torchbind_class_attribute_in_fx_tensor_arg(self): - if IS_FBCODE or IS_WINDOWS or IS_MACOS: - self.skipTest( - "torch.classes._TorchScriptTesting._ReLUClass is registered, skipping" - ) - - class FooBar2341(torch.nn.Module): - def __init__(self): - super(FooBar2341, self).__init__() - self.f = torch.classes._TorchScriptTesting._ReLUClass() - - def forward(self, x): - return self.f.run(x) - - m = FooBar2341() - - traced = symbolic_trace(m) - input = torch.randn(3, 4) - self.assertEqual(traced(input), m(input)) - - self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) - - def test_script_method_trace(self): - class Scripted(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - class Holder(torch.nn.Module): - def __init__(self): - super().__init__() - self.s = torch.jit.script(Scripted()) - - def forward(self, x): - return self.s(x) - - h = Holder() - traced = symbolic_trace(h) - input = torch.randn(3, 4) - self.assertEqual(traced(input), h(input)) - - self.assertTrue(any(n.op == "call_method" for n in traced.graph.nodes)) - - def test_namedtuple_return_trace(self): - class NamedTupReturn(torch.nn.Module): - def forward(self, x): - return Pair(x, x) - - traced = symbolic_trace(NamedTupReturn()) - input = torch.rand(3, 4) - self.assertEqual(traced(input), Pair(input, input)) - - def test_named_tuple_inlined(self): - class NamedTupMod(torch.nn.Module): - def forward(self, inp): - return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp)) - - m = NamedTupMod() - input = torch.rand(3, 4) - ref = m(input) - traced = symbolic_trace(m) - - res = traced(input) - self.assertEqual(ref, res) - - # Check Pair NamedTuple works when inlined into the function call. - ph = call_func = None - for node in traced.graph.nodes: - if node.op == "placeholder": - ph = node - elif ( - node.op == "call_function" and node.target == wrapped_named_tup - ): - node.update_arg(0, Pair(ph, 1.2)) - node.update_kwarg("p2", Pair(3.4, ph)) - call_func = node - break - self.assertTrue(call_func is not None) - self.assertTrue(isinstance(call_func.args[0], Pair)) - self.assertTrue(isinstance(call_func.kwargs["p2"], Pair)) - self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)") - self.assertEqual( - _format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)" - ) - - traced.graph.eliminate_dead_code() - traced.recompile() - res = traced(input) - self.assertEqual(ref, res) - - def test_return_type_exists(self): - class ReturnTypeModule(torch.nn.Module): - def other(self, x: List[str]) -> List[str]: - return x - - def forward(self, x: List[str]) -> List[str]: - return self.other(x) - - traced = symbolic_trace(ReturnTypeModule()) - self.assertIn("-> typing_List[str]", traced._code) - scripted = torch.jit.script(traced) - self.assertIn("-> List[str]", scripted.code) - - def getitem_inner(self): - class GetItemBase(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("pe", torch.randn(8, 8)) - - class GetItem1(GetItemBase): - def forward(self, x): - return self.pe[:, : x.size(0)] - - class GetItem2(GetItemBase): - def forward(self, x): - return self.pe[x.size(0)] - - class GetItem3(GetItemBase): - def forward(self, x): - return self.pe[4] # fx creates `self._tensor_constant0` here - - self.checkGraphModule(GetItem1(), [torch.zeros(4)]) - self.checkGraphModule(GetItem2(), [torch.zeros(4)]) - self.checkGraphModule(GetItem3(), [torch.zeros(4)]) - - @unittest.skipUnless( - os.environ.get("FX_PATCH_GETITEM") == "1", - "Will be checked in test_getitem_subproc", - ) - def test_getitem(self): - self.getitem_inner() - - def test_getitem_subproc(self): - # need to run this test in a subproc to work around: - # https://github.com/pytorch/pytorch/issues/50710 - proc = Process(target=run_getitem_target) - proc.start() - proc.join() - self.assertEqual(proc.exitcode, 0) - - def test_user_friendly_call_provenance_with_function(self): - def fn(x): - return wrapper_fn(x) - - traced = pippy.fx.symbolic_trace(fn) - - with self.assertRaisesRegex( - RuntimeError, - "'wrapper_fn' is " - "being compiled since it was called" - " from 'fn.forward'", - ): - scripted = torch.jit.script(traced) - - def test_user_friendly_call_provenance_with_module(self): - class M(torch.nn.Module): - def forward(self, x): - return wrapper_fn(x) - - traced = pippy.fx.symbolic_trace(M()) - - with self.assertRaisesRegex( - RuntimeError, - "'wrapper_fn' is " - "being compiled since it was called" - " from 'M.forward'", - ): - scripted = torch.jit.script(traced) - - def test_snake_case(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.activations = torch.nn.ModuleDict( - [ - ["snake_case", torch.nn.ReLU()], - ["PascalCase", torch.nn.LeakyReLU()], - ["ALL_CAPS", torch.nn.PReLU()], - ] - ) - - def forward(self, x): - a = self.activations["snake_case"](x) - b = self.activations["PascalCase"](x) - c = self.activations["ALL_CAPS"](x) - return a, b, c - - traced = symbolic_trace(M()) - - check = [ - ("activations_snake_case", "activations.snake_case"), - ("activations_pascal_case", "activations.PascalCase"), - ("activations_all_caps", "activations.ALL_CAPS"), - ] - - i = 0 - for node in traced.graph.nodes: - if node.op == "placeholder" or node.op == "output": - continue - name = check[i][0] - target = check[i][1] - self.assertEqual(name, node.name) - self.assertEqual(target, node.target) - i += 1 - self.assertEqual(i, 3) - - def test_no_mutation(self): - from pippy.fx.immutable_collections import immutable_list - - x = immutable_list([3, 4]) - with self.assertRaisesRegex(NotImplementedError, "new_args"): - x[0] = 4 - - def test_partial_trace(self): - class Foo(torch.nn.Module): - def forward(self, x, y): - if y: - return 2 * x - else: - return x - - mod = Foo() - mod_true = symbolic_trace(mod, concrete_args={"y": True}) - mod_false = symbolic_trace(mod, concrete_args={"y": False}) - self.assertEqual(mod_true(3, True), 6) - print(mod_true.code) - assert any([i.target == torch._assert for i in mod_true.graph.nodes]) - with self.assertRaises(AssertionError): - mod_true(3, False) - self.assertEqual(mod_false(3, False), 3) - with self.assertRaises(AssertionError): - mod_false(3, True) - - def f_higher(a, f): - return f(a) - - nf = symbolic_trace(f_higher, concrete_args={"f": lambda x: x * 2}) - self.assertEqual(nf(3, lambda x: x * 2), 6) - - def test_custom_traceback_raised_when_exception_source_is_graphmodule(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.W = torch.nn.Parameter(torch.randn(5)) - - def forward(self, x): - return torch.dot(self.W, x) - - traced = pippy.fx.symbolic_trace(M()) - - out = [n for n in traced.graph.nodes if n.op == "output"][-1] - with traced.graph.inserting_before(out): - relu_out = traced.graph.call_method( - method_name="relu", args=(out.args[0],) - ) - out.args = (relu_out,) - - traced.recompile() - - with self.capture_stderr() as captured: - with self.assertRaises(TypeError): - traced(5) - - self.assertRegex( - captured[0], - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:", - ) - - def test_custom_traceback_not_raised_when_exception_source_is_submodule( - self, - ): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 4) - - def forward(self, x): - return self.linear(x) - - traced = pippy.fx.symbolic_trace(M()) - - # Do not change this to `capture_stderr` or another context - # manager without ensuring that the output is as expected - try: - traced(torch.rand(5, 5)) - except RuntimeError: - captured = traceback.format_exc() - - self.assertNotRegex( - captured, - r"Call using an FX-traced Module, line .* of the " - r"traced Module's generated forward function:", - ) - - def test_graph_module_replicate_for_dp(self): - class Foo(torch.nn.Module): - def forward(self, x): - return torch.relu(x) - - gm = pippy.fx.symbolic_trace(Foo()) - - x = torch.randn(5, 3) - out = gm(x) - - replica = gm._replicate_for_data_parallel() - out_replica = replica(x) - - torch.testing.assert_allclose(out_replica, out) - - def test_ast_rewriter_rewrites_assert(self): - class M(torch.nn.Module): - def forward(self, x: torch.Tensor, y: int, z: int): - assert y == z - return torch.add(x, x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - traced.graph.lint() - - def test_ast_rewriter_rewrites_assert_with_message(self): - class M(torch.nn.Module): - def forward(self, x: torch.Tensor, y: int, z: int): - assert y == z, "msg" - return torch.add(x, x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - traced.graph.lint() - - def test_throw_out_variant(self): - def foo(x): - y = torch.rand_like(x) - torch.sigmoid(x, out=y) - return y - - class MyTracer(pippy.fx.Tracer): - check_mutable_operations = True - - tracer = MyTracer() - with self.assertRaisesRegex( - RuntimeError, "mutable operation aten::sigmoid.out" - ): - traced_graph = tracer.trace(foo) - - def test_ast_rewriter_reassigns_submodules(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(100) - - def forward(self, x: torch.Tensor): - return torch.add(x, x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - traced.graph.lint() - - def test_ast_rewriter_wrap(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf((4, y), 3) - + a_lifted_leaf((3, 4), 5) - + a_lifted_leaf((y, y), y) - ) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(to_trace) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("a_lifted_leaf", traced.code) - self.assertEqual(27, traced(2)) - self.assertIs(a_lifted_leaf, real_a_lifed_leaf) - - def test_ast_rewriter_wrap_fn_directly(self): - self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) - - def to_trace(y): - return ( - a_lifted_leaf2((4, y), 3) - + a_lifted_leaf2((3, 4), 5) - + a_lifted_leaf2((y, y), y) - ) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(to_trace) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("a_lifted_leaf2", traced.code) - self.assertEqual(27, traced(2)) - self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) - - def test_profiler_ranges_side_effect(self): - g = pippy.fx.Graph() - handle = g.call_function( - torch.ops.profiler._record_function_enter, ("test_range",) - ) - g.call_function(torch.ops.profiler._record_function_exit, (handle,)) - g.output(None) - - found_targets = {} - for node in g.nodes: - if node.op == "call_function": - found_targets.setdefault(node.target) - self.assertEqual( - list(found_targets.keys()), - [ - torch.ops.profiler._record_function_enter, - torch.ops.profiler._record_function_exit, - ], - ) - - g.eliminate_dead_code() - found_targets = {} - for node in g.nodes: - if node.op == "call_function": - found_targets.setdefault(node.target) - self.assertEqual( - list(found_targets.keys()), - [ - torch.ops.profiler._record_function_enter, - torch.ops.profiler._record_function_exit, - ], - ) - - def test_ast_rewriter_wrapped_via_decorator(self): - class F(torch.nn.Module): - def forward(self, x): - return wrapped_via_decorator(x) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(F()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("wrapped_via_decorator", traced.code) - self.assertEqual(traced(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_ast_rewriter_wrapped_via_decorator_and_transformed(self): - self.assertEqual(wrapped_via_decorator(0), 1) - - def to_trace(y): - return wrapped_via_decorator(y) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(to_trace) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("wrapped_via_decorator", traced.code) - self.assertEqual(traced(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - transformed = pippy.fx.Transformer(traced).transform() - self.assertIn("wrapped_via_decorator", transformed.code) - self.assertEqual(transformed(0), 1) - self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) - self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) - - def test_ast_rewriter_wrap_with_submodule(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - - def forward(self, x: torch.Tensor): - return wrapped_with_submodule(x, self.batchnorm1d) - - ast_rewriter = RewritingTracer() - graph = ast_rewriter.trace(M()) - traced = GraphModule(ast_rewriter.root, graph, "gm") - - self.assertIn("wrapped_with_submodule", traced.code) - - input = torch.rand(3, 2) - ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) - self.assertEqual(ref_batchnorm1d(input), traced(input)) - - def test_submodule_manipulation_API(self): - class C(torch.nn.Module): - def __init__(self): - super(C, self).__init__() - self.conv = torch.nn.Conv2d(16, 33, 3, stride=2) - self.param = torch.nn.Parameter(torch.rand(2, 3)) - - def forward(self, x): - return self.conv(torch.cat([self.param, x])) - - class B(torch.nn.Module): - def __init__(self): - super(B, self).__init__() - self.linear = torch.nn.Linear(100, 200) - self.register_buffer("buf", torch.randn(2, 3)) - self.net_c = C() - - def forward(self, x): - return self.linear(torch.cat([self.buf, self.net_c(x)])) - - class A(torch.nn.Module): - def __init__(self): - super(A, self).__init__() - self.net_b = B() - self.param = torch.nn.Parameter(torch.rand(2, 3)) - - def forward(self, x): - return self.net_b(x) + self.param - - a = symbolic_trace(A()) - - a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2)) - - conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] - with a.graph.inserting_before(conv): - with warnings.catch_warnings(record=True) as w: - dropout = a.graph.call_module( - module_name="net_b.net_c.dropout", args=conv.args - ) - self.assertEqual(len(w), 0) - - conv.replace_all_uses_with(dropout) - a.graph.erase_node(conv) - a.recompile() - - def module_exists(gm: GraphModule, path: str) -> bool: - return any(path == name for name, _ in gm.named_modules()) - - def parameter_exists(gm: GraphModule, path: str) -> bool: - return any( - path == name for name, _ in gm.named_parameters() - ) and any(path == name for name in gm.state_dict().keys()) - - def buffer_exists(gm: GraphModule, path: str) -> bool: - return any(path == name for name, _ in gm.named_buffers()) and any( - path == name for name in gm.state_dict().keys() - ) - - # Test that we added the "dropout" submodule - self.assertTrue(module_exists(a, "net_b.net_c.dropout")) - - # Test `get_submodule` with an added submodule - self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout")) - - # Test that the "conv" submodule is still there - self.assertTrue(module_exists(a, "net_b.net_c.conv")) - - # Test `get_submodule` with an original module - self.assertIsNotNone(a.get_submodule("net_b.net_c.conv")) - - # Test that the "conv" node is NOT still there - conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"] - self.assertEqual(conv, []) - - a.delete_submodule("net_b.net_c.conv") - - # Test that the "conv" submodule is now gone - self.assertFalse(module_exists(a, "net_b.net_c.conv")) - - # Test `get_submodule` with a deleted submodule - with self.assertRaisesRegex( - AttributeError, "has no attribute " "`conv`" - ): - self.assertIsNone(a.get_submodule("net_b.net_c.conv")) - - # Test `get_attr` warnings - cat = [n for n in a.graph.nodes if n.target == torch.cat][-1] - - with a.graph.inserting_before(cat): - with warnings.catch_warnings(record=True) as w: - param = a.graph.get_attr(qualified_name="net_b.net_c.param") - self.assertEqual(len(w), 0) - - with self.assertWarnsRegex( - UserWarning, - "Attempted to " - "insert a get_attr Node with no " - "underlying reference in the " - "owning GraphModule", - ): - bad_param = a.graph.get_attr(qualified_name="net_b.param") - a.graph.erase_node(bad_param) - - cat.args = (*cat.args, param) - - a.recompile() - - a.graph.lint() - - # Test `get_parameter` - a.get_parameter("net_b.net_c.param") - with self.assertRaisesRegex( - AttributeError, "is not an " "nn.Parameter" - ): - a.get_parameter("net_b.buf") - with self.assertRaisesRegex( - AttributeError, "has no attribute " "`param`" - ): - a.get_parameter("net_b.param") - - # Test `get_buffer` - a.get_buffer("net_b.buf") - with self.assertRaisesRegex(AttributeError, "is not a " "buffer"): - a.get_buffer("net_b.net_c.param") - with self.assertRaisesRegex( - AttributeError, "has no attribute " "`buf`" - ): - a.get_buffer("net_b.net_c.buf") - - # Test non-nested attributes - a.get_submodule("") - a.get_parameter("param") - - # Insert some unused submodules - a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3)) - a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3)) - a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2)) - a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100)) - - # Garbage collection - a.delete_all_unused_submodules() - - # Test that all the unused submodules are gone - self.assertFalse(module_exists(a, "net_b.embedding")) - self.assertFalse(module_exists(a, "net_b.net_c.embedding")) - self.assertFalse(module_exists(a, "net_b.net_c.rnn")) - self.assertFalse(module_exists(a, "batch_norm_2d")) - - # Test that we didn't delete any unused Parameters or buffers - self.assertTrue(parameter_exists(a, "net_b.net_c.param")) - self.assertTrue(buffer_exists(a, "net_b.buf")) - - a.graph.lint() - - def test_delete_unused_submodules_leaf(self): - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 10) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.linear(x) - x = self.relu(x) - return x - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.submod = SubModule() - - def forward(self, x): - x = self.submod(x) - return x - - model = Model() - - class MyCustomTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - return module_qualified_name == "submod" - - inputs = torch.randn(1, 10) - traced_graph = MyCustomTracer().trace(model) - gm2 = pippy.fx.GraphModule(model, traced_graph) - gm2.delete_all_unused_submodules() - torch.testing.assert_allclose(gm2(inputs), model(inputs)) - - def test_fx_stateless(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.l1 = torch.nn.Linear(1, 1) - self.register_buffer("buffer", torch.ones(1)) - - def forward(self, x): - return self.l1(x) + self.buffer - - module = MockModule() - x = torch.rand((1, 1)) - weight = torch.tensor([[1.0]], requires_grad=True) - bias = torch.tensor([0.0], requires_grad=True) - buffer = torch.tensor([0.0]) - parameters = {"l1.weight": weight, "l1.bias": bias, "buffer": buffer} - fx_module = pippy.fx.symbolic_trace(module) - res = _stateless.functional_call(fx_module, parameters, x) - res.backward() - self.assertIsNotNone(weight.grad) - self.assertIsNotNone(bias.grad) - self.assertIsNone(buffer.grad) - # Gradient was not calculated for the module stated and buffers - self.assertIsNone(module.l1.weight.grad) - self.assertIsNone(module.l1.bias.grad) - self.assertIsNone(module.buffer.grad) - - def test_tracing_graphmodules_as_leaf_submodules(self): - class A(torch.nn.Module): - def forward(self, t): - return t + t - - class B(torch.nn.Module): - def __init__(self): - super(type(self), self).__init__() - self.calling = False - self.called = False - - def forward(self, t): - if self.calling: - return t - t - else: - return t + t - - def __call__(self, *args): - self.called = True - self.calling = True - return super(type(self), self).__call__(*args) - self.calling = False - - class M(torch.nn.Module): - def __init__(self, a, b): - super().__init__() - self.a = a - self.b = b - - def forward(self, t): - x = self.a(t) - y = self.b(t) - return x + y - - class LeafTracer(Tracer): - def is_leaf_module(self, module, name): - return True - - class LeafTracerNotB(Tracer): - def is_leaf_module(self, module, name): - return False if "b" in name else True - - # Recompile calls added "for fun", since they - # chain __call__ wrappers. - - # - # Test: B as a regular, non-leaf module - # - a = symbolic_trace(A()) - a.recompile() - m = M(a, B()) - graph = LeafTracerNotB().trace(m) - gm = GraphModule(m, graph) - gm.recompile() - - # Test graphmodule/submodule a is not inlined. - self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "a" - ] - self.assertTrue(len(match) == 1) - - # Test submodule b is not treated as leaf. - self.assertFalse(hasattr(gm, "b")) - - # Test assert custom __call__ on submodule b was honored. - match = [ - n - for n in gm.graph.nodes - if n.op == "call_function" and n.target == operator.sub - ] - self.assertTrue(len(match) == 1) - - # - # Test: B as a regular, leaf module - # symbolic_trace should only patch torch.nn.Module.__call__, - # which means B.__call__ should still execute - # - a = symbolic_trace(A()) - a.recompile() - b = B() - m = M(a, b) - graph = LeafTracer().trace(m) - gm = GraphModule(m, graph) - gm.recompile() - - # Test graphmodule/submodule a is not inlined. - self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "a" - ] - self.assertTrue(len(match) == 1) - - # Test submodule b is leaf: - self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "b" - ] - self.assertTrue(len(match) == 1) - - # Test b.__call__ was run - self.assertTrue(b.called) - self.assertTrue(gm.get_submodule("b").called) - - # - # Test: B as GraphModule leaf - # __call__ not honored since symbolic_trace directly invokes forward() - # - a = symbolic_trace(A()) - a.recompile() - b = symbolic_trace(B()) - b.recompile() - m = M(a, b) - graph = LeafTracer().trace(m) - gm = GraphModule(m, graph) - gm.recompile() - - self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "a" - ] - self.assertTrue(len(match) == 1) - - self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module)) - match = [ - n - for n in gm.graph.nodes - if n.op == "call_module" and n.target == "b" - ] - self.assertTrue(len(match) == 1) - - def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("my_buff", torch.rand(3, 4)) - self.register_parameter( - "my_param", torch.nn.Parameter(torch.rand(3, 4)) - ) - - def forward(self, x): - return x + self.my_buff + self.my_param - - mod = MyModule() - mod_traced = symbolic_trace(mod) - - # Create new GraphModule based on original, either w/ dict or root module. - orig_buff = mod_traced.get_buffer("my_buff") - orig_param = mod_traced.get_parameter("my_param") - mod_traced_new = GraphModule( - {"my_buff": orig_buff, "my_param": orig_param} - if use_dict_init - else mod, - mod_traced.graph, - ) - - # Check that both my_buff and my_param are found and the same. - try: - new_buff = mod_traced_new.get_buffer("my_buff") - except Exception: - self.fail("Did not find my_buff") - self.assertEqual(orig_buff, new_buff) - - try: - new_param = mod_traced_new.get_parameter("my_param") - except Exception: - self.fail("Did not find my_param") - self.assertEqual(orig_param, new_param) - - x = torch.rand(3, 4) - orig_out = mod_traced(x) - submodules_out = mod_traced_new(x) - - self.assertEqual(orig_out, submodules_out) - - def test_graph_module_init_buffer_param_copied_dict_init(self): - self._test_graph_module_init_buffer_param_copied(use_dict_init=True) - - def test_graph_module_init_buffer_param_copied_mod_init(self): - self._test_graph_module_init_buffer_param_copied(use_dict_init=False) - - def test_annotations_with_no_forward_references(self): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: torch.Tensor, a: A) -> torch.Tensor: - return a(x) - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - def test_annotations_with_forward_references(self): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: "torch.Tensor", a: "A") -> "torch.Tensor": - return a(x) - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - def test_annotations_with_non_torch_reference_and_no_internal_forward_references( - self, - ): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor: - return a(x[0]) - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - def test_annotations_with_non_torch_reference_and_internal_forward_references( - self, - ): - class A: - def __call__(self, x: torch.Tensor): - return torch.add(x, x) - - class M(torch.nn.Module): - def forward(self, x: List["torch.Tensor"], a: A) -> "torch.Tensor": - return a(x)[0] - - self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None) - - @unittest.skipIf( - sys.version_info < (3, 7), - "`__future__` feature " "`annotations` is not defined in Python <3.7", - ) - def test_annotation_with_future(self): - try: - import fx.test_future # noqa: F401 - finally: - del sys.modules["__future__"] - - def test_annotations_empty_tuple(self): - class Foo(torch.nn.Module): - def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): - return "foo" - - traced = pippy.fx.symbolic_trace(Foo()) - - x = () - y = ("bar", ()) - - traced(x, y) - - FileCheck().check("_Tuple[()]").check( - "typing_Tuple[str,typing_Tuple[()]]" - ).run(traced.code) - - scripted = torch.jit.script(traced) - - scripted(x, y) - - FileCheck().check("Tuple[()]").check("Tuple[str, Tuple[()]]").run( - scripted.code - ) - - @unittest.skipIf( - IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108" - ) - @unittest.skipIf( - sys.version_info >= (3, 10), "Does not work on Python-3.10" - ) - def test_assert(self): - def f(x): - assert x > 1 - return x + 1 - - try: - pippy.fx.proxy.TracerBase.trace_asserts = True - traced = symbolic_trace(f) - finally: - pippy.fx.proxy.TracerBase.trace_asserts = False - - self.assertEqual(f(2), traced(2)) - with self.assertRaises(AssertionError): - traced(0) - - def test_pytree(self): - def f_sum(x): - return sum(x) - - def f_sum_dict(x): - out = 0 - for k, v in x.items(): - out += v - return out - - def f_dict_list_map(x): - new_dict = {} - for k, v in x.items(): - new_dict[k] = [i + 1 for i in v] - return new_dict - - def f_dict_add(x): - return x["a"] + sum(x["z"]) - - def f_namedtuple_add(x): - return x.x + x.y - - pytree._register_pytree_node( - Foo, - lambda x: ([x.a, x.b], None), - lambda x, _: Foo(x[0], x[1]), - ) - fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b]) - - def f_custom(x): - return x.a + x.b - - def f_custom_dict(x): - return f_sum_dict(x.a) + x.b - - def f_return_custom(x): - return Foo(x.b, x.a) - - tests = [ - (f_sum, [PH, PH, PH]), - (f_sum, []), - (f_sum_dict, {"a": PH, "b": PH, "c": PH}), - (f_dict_list_map, {"a": (PH, PH), "b": [PH], "c": []}), - (f_dict_list_map, {5: (PH, PH, PH)}), - (f_dict_add, {"a": PH, "z": (PH, PH, PH)}), - (f_dict_add, {"a": PH, "z": []}), - (f_custom, Foo(PH, PH)), - (f_custom, Foo(PH, 3)), - (f_custom_dict, Foo({"a": PH, "b": PH}, PH)), - # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees - (f_namedtuple_add, Point(PH, PH)), - ] - - def verify_pytree(f, inp): - val = pytree.tree_map( - lambda x: torch.randn(3) if x == PH else x, inp - ) - num_flat_args = len([i == PH for i in pytree.tree_flatten(inp)[0]]) - orig_out = f(val) - nf = symbolic_trace(f, concrete_args={"x": inp}) - self.assertEqual(nf(val), orig_out) - - bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) - bare_fx.graph.set_codegen(CodeGen()) - bare_fx.recompile() - self.assertEqual( - nf.graph.process_outputs( - bare_fx(*nf.graph.process_inputs(val)) - ), - orig_out, - ) - - assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert ( - sum([i.op == "placeholder" for i in nf.graph.nodes]) - == num_flat_args - ) - - nf = symbolic_trace(nf) - self.assertEqual(nf(val), orig_out) - assert "tree_flatten_spec" not in nf.code - assert sum([i.op == "placeholder" for i in nf.graph.nodes]) == 1 - - nf = symbolic_trace(nf, concrete_args={"x": inp}) - self.assertEqual(nf(val), orig_out) - assert num_flat_args == 0 or "tree_flatten_spec" in nf.code - assert ( - sum([i.op == "placeholder" for i in nf.graph.nodes]) - == num_flat_args - ) - - pickled = pickle.dumps(nf) - nf = pickle.loads(pickled) - self.assertEqual(nf(val), orig_out) - - for f, inp in tests: - verify_pytree(f, inp) - - def test_pytree_concrete(self): - def f(b, a): - if b: - return a["a"] - else: - return a["z"] - - inp = {"a": {"a": PH, "z": PH}, "b": True} - nf = symbolic_trace(f, concrete_args=inp) - val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp) - self.assertEqual(nf(**val), f(**val)) - - nf = symbolic_trace(nf) - self.assertEqual(nf(**val), f(**val)) - - def test_custom_codegen(self): - class ListCodeGen(CodeGen): - def gen_fn_def(self, free_vars, maybe_return_annotation): - lst_unpack = f""" -def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: - {', '.join(free_vars)} = args_list""" - return lst_unpack - - def additional_globals(self): - return [("List", typing.List)] - - def process_inputs(self, *inputs): - assert len(inputs) == 1 - return inputs[0] - - def f(a, b): - return a + b - - nf = symbolic_trace(f) - vals = [torch.randn(3), torch.randn(3)] - self.assertEqual(nf(*vals), f(*vals)) - - nf.graph.set_codegen(ListCodeGen()) - nf.recompile() - - bare_fx = GraphModule({}, copy.deepcopy(nf.graph)) - bare_fx.graph.set_codegen(CodeGen()) - bare_fx.recompile() - - self.assertEqual(nf(vals), f(*vals)) - self.assertEqual( - nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), - f(*vals), - ) - - ts_f = torch.jit.script(nf) - self.assertEqual(nf(vals), ts_f(vals)) - - def test_custom_codegen_with_transformer(self): - class ListCodeGen(CodeGen): - def gen_fn_def(self, free_vars, maybe_return_annotation): - lst_unpack = f""" -def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: - {', '.join(free_vars)} = args_list""" - return lst_unpack - - def additional_globals(self): - return [("List", typing.List)] - - def process_inputs(self, *inputs): - assert len(inputs) == 1 - return inputs[0] - - def f(a, b): - return a + b - - nf = symbolic_trace(f) - vals = [torch.randn(3), torch.randn(3)] - self.assertEqual(nf(*vals), f(*vals)) - - nf.graph.set_codegen(ListCodeGen()) - nf.recompile() - self.assertEqual(nf(vals), f(*vals)) - - transformed_gm = Transformer(nf).transform() - self.assertEqual(nf(vals), transformed_gm(vals)) - - def test_interpreter_with_codegen(self): - class ListCodeGen(CodeGen): - def gen_fn_def(self, free_vars, maybe_return_annotation): - lst_unpack = f""" -def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}: - {', '.join(free_vars)} = args_list""" - return lst_unpack - - def additional_globals(self): - return [("List", typing.List)] - - def process_inputs(self, *inputs): - assert len(inputs) == 1 - return inputs[0] - - def generate_output(self, output_args): - return f"return list({repr(output_args)})" - - def process_outputs(self, outputs): - return list(outputs) - - def f(a, b): - a = a + b - b = a + b - return a, b - - nf = symbolic_trace(f) - vals = [torch.randn(3), torch.randn(3)] - nf.graph.set_codegen(ListCodeGen()) - nf.recompile() - self.assertEqual(Interpreter(nf).run(vals), nf(vals)) - - def test_imul_code_print(self): - graph = pippy.fx.Graph() - a = graph.placeholder("a") - b = graph.placeholder("b") - graph.call_function(operator.imul, (a, b), {}) - graph.output(a) - gm = pippy.fx.GraphModule({}, graph) - gm.recompile() - self.assertEqual(gm(2, 3), 6) - self.assertIn("a *= b", gm.code) - - def test_deepcopy_tracer(self): - def fn(x, y): - return (x + y).relu().sin() - - tracer = Tracer() - tracer_before = copy.deepcopy(tracer) - tracer.trace(fn) - tracer_after = copy.deepcopy(tracer) - - self.assertEqual(str(tracer.graph), str(tracer_after.graph)) - self.assertTrue( - not hasattr(tracer_before, "graph") - or str(tracer.graph) != str(tracer_before.graph) - ) - - -def run_getitem_target(): - from pippy.fx._symbolic_trace import _wrapped_methods_to_patch - - _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) - try: - TestFX().getitem_inner() - finally: - _wrapped_methods_to_patch.pop() - - -class TestOperatorSignatures(JitTestCase): - def setUp(self): - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - @onlyCPU - @ops(op_db, allowed_dtypes=(torch.float,)) - def test_get_torch_func_signature_exhaustive(self, device, dtype, op): - if not isinstance(op.op, types.BuiltinFunctionType): - raise unittest.SkipTest( - "This path doesn't work on Python functions" - ) - sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) - schemas = get_signature_for_torch_op(op.op) - if not schemas: - raise RuntimeError("No Schemas Returned") - for sample_input in sample_inputs_itr: - # Iterate through overloads until we hit a match. If we exit this - # loop via `else`, we haven't found a match - for schema in schemas: - try: - bound_args = schema.bind( - sample_input.input, - *sample_input.args, - **sample_input.kwargs, - ) - bound_args.apply_defaults() - op(*bound_args.args, **bound_args.kwargs) - break - except TypeError as e: - pass - else: - raise RuntimeError( - f"Did not match any schemas for op {op.name}!" - ) - - -class TestFXAPIBackwardCompatibility(JitTestCase): - def setUp(self): - super().setUp() - self.maxDiff = None - - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - def _fn_to_stable_annotation_str(self, obj): - """ - Unfortunately we have to serialize function signatures manually since - serialization for `inspect.Signature` objects is not stable across - python versions - """ - fn_name = torch.typename(obj) - - signature = inspect.signature(obj) - - sig_str = f"{fn_name}{signature}" - - arg_strs = [] - for k, v in signature.parameters.items(): - maybe_type_annotation = ( - f": {self._annotation_type_to_stable_str(v.annotation, sig_str)}" - if v.annotation is not inspect.Signature.empty - else "" - ) - - def default_val_str(val): - if isinstance(val, (tuple, list)): - str_pieces = ["(" if isinstance(val, tuple) else "["] - str_pieces.append( - ", ".join(default_val_str(v) for v in val) - ) - if isinstance(val, tuple) and len(str_pieces) == 2: - str_pieces.append(",") - str_pieces.append(")" if isinstance(val, tuple) else "]") - return "".join(str_pieces) - - # Need to fix up some default value strings. - # First case: modules. Default module `repr` contains the FS path of the module. - # Don't leak that - if isinstance(val, types.ModuleType): - return f"" - - # Second case: callables. Callables (such as lambdas) encode their address in - # their string repr. Don't do that - if callable(val): - return f"" - - return str(val) - - if v.default is not inspect.Signature.empty: - default_val_str = ( - default_val_str(v.default) - if not isinstance(v.default, str) - else f"'{v.default}'" - ) - maybe_default = f" = {default_val_str}" - else: - maybe_default = "" - maybe_stars = "" - if v.kind == inspect.Parameter.VAR_POSITIONAL: - maybe_stars = "*" - elif v.kind == inspect.Parameter.VAR_KEYWORD: - maybe_stars = "**" - arg_strs.append( - f"{maybe_stars}{k}{maybe_type_annotation}{maybe_default}" - ) - - return_annot = ( - f" -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}" - if signature.return_annotation is not inspect.Signature.empty - else "" - ) - - return f'{fn_name}({", ".join(arg_strs)}){return_annot}' - - def _annotation_type_to_stable_str(self, t, sig_str): - if t is inspect.Signature.empty: - return "" - - # Forward ref - if isinstance(t, str): - return f"'{t}'" - if hasattr(typing, "ForwardRef") and isinstance(t, typing.ForwardRef): - return t.__forward_arg__ - if hasattr(typing, "_ForwardRef") and isinstance(t, typing._ForwardRef): - return t.__forward_arg__ - - trivial_mappings = { - str: "str", - int: "int", - float: "float", - bool: "bool", - torch.dtype: "torch.dtype", - torch.Tensor: "torch.Tensor", - torch.device: "torch.device", - torch.memory_format: "torch.memory_format", - slice: "slice", - torch.nn.Module: "torch.nn.modules.module.Module", - pippy.fx.Graph: "pippy.fx.graph.Graph", - pippy.fx.Node: "pippy.fx.node.Node", - pippy.fx.Proxy: "pippy.fx.proxy.Proxy", - pippy.fx.node.Target: "pippy.fx.node.Target", - pippy.fx.node.Argument: "pippy.fx.node.Argument", - pippy.fx.graph.PythonCode: "pippy.fx.graph.PythonCode", - pippy.fx.graph_module.GraphModule: "pippy.fx.graph_module.GraphModule", - pippy.fx.subgraph_rewriter.Match: "pippy.fx.subgraph_rewriter.Match", - Ellipsis: "...", - typing.Any: "Any", - type(None): "NoneType", - None: "None", - typing.Iterator: "Iterator", - } - - mapping = trivial_mappings.get(t, None) - if mapping: - return mapping - - # Handle types with contained types - contained = getattr(t, "__args__", None) or [] - - # Callables contain a bare List for arguments - contained = t if isinstance(t, list) else contained - - # Python 3.8 puts type vars into __args__ for unbound types such as Dict - if all(isinstance(ct, typing.TypeVar) for ct in contained): - contained = [] - - contained_type_annots = [ - self._annotation_type_to_stable_str(ct, sig_str) for ct in contained - ] - contained_type_str = ( - f'[{", ".join(contained_type_annots)}]' - if len(contained_type_annots) > 0 - else "" - ) - - origin = getattr(t, "__origin__", None) - if origin is None: - # Unbound types don't have `__origin__` in some Python versions, so fix that up here. - origin = ( - t - if t - in { - typing.Tuple, - typing.Union, - typing.Dict, - typing.List, - typing.Type, - typing.Callable, - } - else origin - ) - - if origin in {tuple, typing.Tuple}: - return f"Tuple{contained_type_str}" - if origin in {typing.Union}: - # Annoying hack to detect Optional - if len(contained) == 2 and (contained[0] is type(None)) ^ ( - contained[1] is type(None) - ): - not_none_param = ( - contained[0] - if contained[0] is not type(None) - else contained[1] - ) - return f"Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]" - return f"Union{contained_type_str}" - if origin in {dict, typing.Dict}: - return f"Dict{contained_type_str}" - if origin in {list, typing.List}: - return f"List{contained_type_str}" - if origin in {type, typing.Type}: - return f"Type{contained_type_str}" - if isinstance(t, typing.Callable): - if len(contained) > 0 and contained[0] is not Ellipsis: - return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' - else: - return f"Callable{contained_type_str}" - - raise RuntimeError( - f"Unrecognized type {t} used in BC-compatible type signature {sig_str}." - f"Please add support for this type and confirm with the " - f"FX team that your signature change is valid." - ) - - def test_function_back_compat(self): - """ - Test backward compatibility for function signatures with - @compatibility(is_backward_compatible=True). Currently this checks for - exact signature matches, which may lead to false positives. If this - becomes too annoying, we can refine this check to actually parse out - the saved schema strings and check if the change is truly backward- - incompatible. - """ - signature_strs = [] - - for obj in _BACK_COMPAT_OBJECTS: - if not isinstance(obj, type): - signature_strs.append(self._fn_to_stable_annotation_str(obj)) - - signature_strs.sort() - - try: - self.assertExpected( - "\n".join(signature_strs) + "\n", - "fx_backcompat_function_signatures", - ) - except AssertionError as e: - msg = ( - f"{e}\n****** ERROR ******\nAn FX function that has been marked " - f"as backwards-compatible has experienced a signature change. See the " - f"above exception context for more information. If this change was " - f"unintended, please revert it. If it was intended, check with the FX " - f"team to ensure that the proper deprecation protocols have been followed " - f"and subsequently --accept the change." - ) - raise AssertionError(msg) - - def test_class_member_back_compat(self): - """ - Test backward compatibility for members of classes with - @compatibility(is_backward_compatible=True). Currently this checks for - exact matches on the publicly visible members of the class. - """ - class_method_strs = [] - - for obj in _BACK_COMPAT_OBJECTS: - if isinstance(obj, type): - public_members = [ - name for name in obj.__dict__ if not name.startswith("_") - ] - class_method_strs.append( - f"{torch.typename(obj)} {sorted(public_members)}" - ) - - class_method_strs.sort() - - try: - self.assertExpected( - "\n".join(class_method_strs), "fx_backcompat_class_members" - ) - except AssertionError as e: - msg = ( - f"{e}\n****** ERROR ******\nAn FX class that has been marked " - f"as backwards-compatible has experienced change in its public members. See the " - f"above exception context for more information. If this change was " - f"unintended, please revert it. If it was intended, check with the FX " - f"team to ensure that the proper deprecation protocols have been followed " - f"and subsequently --accept the change." - ) - raise AssertionError(msg) - - def test_public_api_surface(self): - non_back_compat_objects = {} - - def check_symbols_have_bc_designation(m, prefix): - if not m.__name__.startswith("pippy.fx"): - return - if m.__name__.startswith("pippy.fx.experimental"): - return - for k, v in m.__dict__.items(): - if v is m: - continue - if k.startswith("_"): - continue - if isinstance(v, types.ModuleType): - check_symbols_have_bc_designation(v, prefix + [k]) - elif isinstance(v, type) or isinstance(v, types.FunctionType): - if v not in _MARKED_WITH_COMATIBLITY: - non_back_compat_objects.setdefault(v) - - check_symbols_have_bc_designation(pippy.fx, ["torch", "fx"]) - check_symbols_have_bc_designation( - pippy.fx.passes, ["torch", "fx", "passes"] - ) - - non_back_compat_strs = [ - torch.typename(obj) for obj in non_back_compat_objects.keys() - ] - # Only want objects in pippy.fx - non_back_compat_strs = [ - s - for s in non_back_compat_strs - if s.startswith("pippy.fx") - and not s.startswith("pippy.fx.experimental") - ] - # Only want objects in public namespaces - non_back_compat_strs = [ - s - for s in non_back_compat_strs - if all(not atom.startswith("_") for atom in s.split(".")) - ] - non_back_compat_strs.sort() - - if len(non_back_compat_strs) != 0: - raise AssertionError( - f"Public FX API(s) {non_back_compat_strs} introduced but not given a " - f"backwards-compatibility classification! Please decorate these " - f"API(s) with `@pippy.fx._compatibility.compatibility` to specify " - f"BC guarantees." - ) - - -class TestFunctionalTracing(JitTestCase): - def setUp(self): - super().setUp() - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - super().tearDown() - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - IGNORE_FUNCS = ( - "has_torch_function", - "has_torch_function_unary", - "has_torch_function_variadic", - "handle_torch_function", - "boolean_dispatch", - ) - TO_PATCH = { - "has_torch_function": None, - "has_torch_function_unary": None, - "has_torch_function_variadic": None, - } - - BUILT_IN_FUNC = (AssertionError, "") - PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable") - PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") - LEN_ERROR = ( - RuntimeError, - r"'len' is not supported in symbolic tracing by default", - ) - ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") - CONTROL_FLOW = ( - TraceError, - r"symbolically traced variables cannot be used as inputs to control flow", - ) - INTERPOLATE_ARGS_CONFLICT = ( - ValueError, - r"only one of size or scale_factor should be defined", - ) - MUTABLE = (RuntimeError, r"Tried to trace mutable operation") - - UNTRACEABLE_FUNCTIONALS = { - "adaptive_avg_pool1d": BUILT_IN_FUNC, - "avg_pool1d": BUILT_IN_FUNC, - "avg_pool2d": BUILT_IN_FUNC, - "avg_pool3d": BUILT_IN_FUNC, - "bilinear": BUILT_IN_FUNC, - "celu_": BUILT_IN_FUNC, - "channel_shuffle": BUILT_IN_FUNC, - "native_channel_shuffle": BUILT_IN_FUNC, - "conv1d": BUILT_IN_FUNC, - "conv2d": BUILT_IN_FUNC, - "conv3d": BUILT_IN_FUNC, - "conv_tbc": BUILT_IN_FUNC, - "conv_transpose1d": BUILT_IN_FUNC, - "conv_transpose2d": BUILT_IN_FUNC, - "conv_transpose3d": BUILT_IN_FUNC, - "cosine_similarity": BUILT_IN_FUNC, - "elu_": BUILT_IN_FUNC, - "gelu": BUILT_IN_FUNC, - "hardshrink": BUILT_IN_FUNC, - "hardtanh_": BUILT_IN_FUNC, - "leaky_relu_": BUILT_IN_FUNC, - "linear": BUILT_IN_FUNC, - "logsigmoid": BUILT_IN_FUNC, - "one_hot": BUILT_IN_FUNC, - "pad": BUILT_IN_FUNC, - "pairwise_distance": BUILT_IN_FUNC, - "pdist": BUILT_IN_FUNC, - "pixel_shuffle": BUILT_IN_FUNC, - "pixel_unshuffle": BUILT_IN_FUNC, - "prelu": BUILT_IN_FUNC, - "relu_": BUILT_IN_FUNC, - "rrelu_": BUILT_IN_FUNC, - "selu_": BUILT_IN_FUNC, - "softplus": BUILT_IN_FUNC, - "softshrink": BUILT_IN_FUNC, - "threshold_": BUILT_IN_FUNC, - "adaptive_avg_pool2d": LEN_ERROR, - "adaptive_avg_pool3d": LEN_ERROR, - "adaptive_max_pool2d_with_indices": LEN_ERROR, - "adaptive_max_pool3d_with_indices": LEN_ERROR, - "instance_norm": CONTROL_FLOW, - "adaptive_max_pool1d": PROXY_ITERABLE, - "adaptive_max_pool2d": PROXY_ITERABLE, - "adaptive_max_pool3d": PROXY_ITERABLE, - "fractional_max_pool2d": PROXY_ITERABLE, - "fractional_max_pool3d": PROXY_ITERABLE, - "max_pool1d": PROXY_ITERABLE, - "max_pool2d": PROXY_ITERABLE, - "max_pool3d": PROXY_ITERABLE, - "group_norm": PROXY_ITERATED, - "lp_pool2d": PROXY_ITERATED, - "max_unpool1d": PROXY_ITERATED, - "max_unpool2d": PROXY_ITERATED, - "max_unpool3d": PROXY_ITERATED, - "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, - "layer_norm": ARG_TYPE_MISMATCH, - "lp_pool1d": ARG_TYPE_MISMATCH, - "affine_grid": CONTROL_FLOW, - "alpha_dropout": CONTROL_FLOW, - "batch_norm": CONTROL_FLOW, - "binary_cross_entropy": CONTROL_FLOW, - "binary_cross_entropy_with_logits": CONTROL_FLOW, - "celu": CONTROL_FLOW, - "cosine_embedding_loss": CONTROL_FLOW, - "cross_entropy": CONTROL_FLOW, - "ctc_loss": CONTROL_FLOW, - "dropout": CONTROL_FLOW, - "dropout1d": CONTROL_FLOW, - "dropout2d": CONTROL_FLOW, - "dropout3d": CONTROL_FLOW, - "elu": CONTROL_FLOW, - "embedding": CONTROL_FLOW, - "embedding_bag": CONTROL_FLOW, - "feature_alpha_dropout": CONTROL_FLOW, - "fold": CONTROL_FLOW, - "gaussian_nll_loss": CONTROL_FLOW, - "glu": CONTROL_FLOW, - "grid_sample": CONTROL_FLOW, - "gumbel_softmax": CONTROL_FLOW, - "hardsigmoid": CONTROL_FLOW, - "hardswish": CONTROL_FLOW, - "hardtanh": CONTROL_FLOW, - "hinge_embedding_loss": CONTROL_FLOW, - "huber_loss": CONTROL_FLOW, - "interpolate": CONTROL_FLOW, - "kl_div": CONTROL_FLOW, - "l1_loss": CONTROL_FLOW, - "leaky_relu": CONTROL_FLOW, - "local_response_norm": CONTROL_FLOW, - "margin_ranking_loss": CONTROL_FLOW, - "max_pool1d_with_indices": ARG_TYPE_MISMATCH, - "max_pool2d_with_indices": ARG_TYPE_MISMATCH, - "max_pool3d_with_indices": ARG_TYPE_MISMATCH, - "mse_loss": CONTROL_FLOW, - "multi_head_attention_forward": CONTROL_FLOW, - "multi_margin_loss": CONTROL_FLOW, - "multilabel_margin_loss": CONTROL_FLOW, - "multilabel_soft_margin_loss": CONTROL_FLOW, - "nll_loss": CONTROL_FLOW, - "poisson_nll_loss": CONTROL_FLOW, - "relu": CONTROL_FLOW, - "relu6": CONTROL_FLOW, - "rrelu": CONTROL_FLOW, - "selu": CONTROL_FLOW, - "silu": CONTROL_FLOW, - "mish": CONTROL_FLOW, - "smooth_l1_loss": CONTROL_FLOW, - "soft_margin_loss": CONTROL_FLOW, - "threshold": CONTROL_FLOW, - "triplet_margin_loss": CONTROL_FLOW, - "triplet_margin_with_distance_loss": CONTROL_FLOW, - "unfold": CONTROL_FLOW, - "upsample": CONTROL_FLOW, - "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, - "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, - } - - # List of nn.functionals with Tensor inputs but not with type annotation - FUNCTIONALS_WITHOUT_ANNOTATION = ( - "adaptive_max_pool1d", - "adaptive_max_pool2d", - "adaptive_max_pool3d", - "fractional_max_pool2d", - "fractional_max_pool3d", - "max_pool1d", - "max_pool2d", - "max_pool3d", - "gaussian_nll_loss", - "upsample", - "upsample_bilinear", - "upsample_nearest", - ) - - # Inconsistent behavior between Python 3.8 and other Python versions: - # - Python 3.8+: Re-raise internal exception like `PROXY_ITERATED` - # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the same - # internal exception above - # Use the following map to override the expected exception for Python 3.8 - UNTRACEABLE_FUNCTIONALS_PY38 = { - "adaptive_max_pool1d": PROXY_ITERATED, - "adaptive_max_pool2d": PROXY_ITERATED, - "adaptive_max_pool3d": PROXY_ITERATED, - "fractional_max_pool2d": PROXY_ITERATED, - "fractional_max_pool3d": PROXY_ITERATED, - "max_pool1d": PROXY_ITERATED, - "max_pool2d": PROXY_ITERATED, - "max_pool3d": PROXY_ITERATED, - "group_norm": LEN_ERROR, - } - - @classmethod - def _get_functional(cls): - functional_list = [] - for f in dir(torch.nn.functional): - if not f.islower(): - continue - # Ignore internal functions - if f.startswith("_"): - continue - # Ignore supporting functions - if f in cls.IGNORE_FUNCS: - continue - fn = getattr(torch.nn.functional, f) - # Ignore non-callable object like modules - if not isinstance(fn, Callable): - continue - if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION: - try: - sig = inspect.signature(fn) - has_tensor_arg = False - for arg, param in sig.parameters.items(): - if isinstance(param.annotation, type) and issubclass( - param.annotation, torch.Tensor - ): - has_tensor_arg = True - if not has_tensor_arg: - continue - # No signature or Object is not supported - except ValueError: - pass - functional_list.append((f, fn)) - return functional_list - - @classmethod - def generate_test_func(cls, func_name, fn): - def functional_test(self): - if ( - func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 - and sys.version_info >= (3, 8) - and sys.version_info < (3, 11) - ): - exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name] - with self.assertRaisesRegex(exc, err): - symbolic_trace(fn) - elif func_name in self.UNTRACEABLE_FUNCTIONALS: - exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name] - with self.assertRaisesRegex(exc, err): - symbolic_trace(fn) - else: - symbolic_trace(fn) - - return functional_test - - @classmethod - def generate_tests(cls): - functional_list = cls._get_functional() - for func_name, fn in functional_list: - test_name = "test_nn_functional_" + func_name - functional_test = cls.generate_test_func(func_name, fn) - setattr(cls, test_name, functional_test) - - @classmethod - def setUpClass(cls): - def no(*args, **kwargs): - return False - - for name in cls.TO_PATCH.keys(): - cls.TO_PATCH[name] = getattr(torch.nn.functional, name) - setattr(torch.nn.functional, name, no) - - @classmethod - def tearDownClass(cls): - for name in cls.TO_PATCH.keys(): - setattr(torch.nn.functional, name, cls.TO_PATCH[name]) - - -TestFunctionalTracing.generate_tests() - - -instantiate_device_type_tests(TestOperatorSignatures, globals()) - - -@skipIfNoTorchVision -@skipIfSlowGradcheckEnv -class TestVisionTracing(JitTestCase): - def setUp(self): - # Checking for mutable operations while tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = ( - pippy.fx.proxy.TracerBase.check_mutable_operations - ) - pippy.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - pippy.fx.proxy.TracerBase.check_mutable_operations = ( - self.orig_tracer_mutable_flag - ) - - PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") - INCONSISTENT_TYPE = ( - RuntimeError, - r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor", - ) - - UNTRACEABLE_MODELS = { - "fasterrcnn_resnet50_fpn": PROXY_ITERATED, - "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED, - "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED, - "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED, - "maskrcnn_resnet50_fpn": PROXY_ITERATED, - "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED, - "keypointrcnn_resnet50_fpn": PROXY_ITERATED, - "retinanet_resnet50_fpn": PROXY_ITERATED, - "retinanet_resnet50_fpn_v2": PROXY_ITERATED, - "ssd300_vgg16": PROXY_ITERATED, - "fcos_resnet50_fpn": PROXY_ITERATED, - "ssdlite320_mobilenet_v3_large": PROXY_ITERATED, - } - UNSCRIPTABLE_MODELS = { - "googlenet": INCONSISTENT_TYPE, - "inception_v3": INCONSISTENT_TYPE, - } - - output_transform = { - "fcn_resnet50": lambda x: x["out"], - "fcn_resnet101": lambda x: x["out"], - "deeplabv3_resnet50": lambda x: x["out"], - "deeplabv3_resnet101": lambda x: x["out"], - "deeplabv3_mobilenet_v3_large": lambda x: x["out"], - "lraspp_mobilenet_v3_large": lambda x: x["out"], - "fasterrcnn_resnet50_fpn": lambda x: x[1], - "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], - "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], - "maskrcnn_resnet50_fpn": lambda x: x[1], - "keypointrcnn_resnet50_fpn": lambda x: x[1], - "retinanet_resnet50_fpn": lambda x: x[1], - } - - @classmethod - def generate_test_fn(cls, name, x, kwargs): - def run_test(self): - model = torchvision_models.get_model(name, **kwargs) - model = model.eval() - if name in self.UNTRACEABLE_MODELS: - err, exc = self.UNTRACEABLE_MODELS[name] - with self.assertRaisesRegex(err, exc): - graph = symbolic_trace(model) - else: - out_transform = self.output_transform.get(name, lambda x: x) - graph: pippy.fx.GraphModule = symbolic_trace(model) - a = out_transform(model(x)) - b = out_transform(graph(x)) - self.assertEqual(a, b) - - if name in self.UNSCRIPTABLE_MODELS: - err, exc = self.UNSCRIPTABLE_MODELS[name] - with self.assertRaisesRegex(err, exc): - script = torch.jit.script(graph) - else: - script = torch.jit.script(graph) - c = out_transform(script(x)) - self.assertEqual(a, c) - - return run_test - - @classmethod - def generate_classification_tests(cls): - for k in torchvision_models.list_models(module=torchvision_models): - test_name = "test_torchvision_models_" + k - x = ( - torch.rand(1, 3, 299, 299) - if k in ["inception_v3"] - else torch.rand(1, 3, 224, 224) - ) - kwargs = dict(num_classes=50) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_segmentation_tests(cls): - for k in torchvision_models.list_models( - module=torchvision_models.segmentation - ): - test_name = "test_torchvision_models_segmentation_" + k - x = torch.rand(1, 3, 32, 32) - kwargs = dict(num_classes=10, pretrained_backbone=False) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_detection_tests(cls): - for k in torchvision_models.list_models( - module=torchvision_models.detection - ): - test_name = "test_torchvision_models_detection_" + k - x = [torch.rand(3, 300, 300)] - kwargs = dict(num_classes=10, pretrained_backbone=False) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_video_tests(cls): - for k in torchvision_models.list_models( - module=torchvision_models.video - ): - test_name = "test_torchvision_models_video_" + k - x = ( - torch.rand(1, 3, 4, 112, 112) - if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"} - else torch.rand(1, 3, 16, 224, 224) - ) - kwargs = dict(num_classes=50) - model_test = cls.generate_test_fn(k, x, kwargs) - setattr(cls, test_name, model_test) - - @classmethod - def generate_tests(cls): - cls.generate_classification_tests() - cls.generate_detection_tests() - cls.generate_segmentation_tests() - cls.generate_video_tests() - - -if HAS_TORCHVISION: - TestVisionTracing.generate_tests() - -if __name__ == "__main__": - run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py deleted file mode 100644 index 0ba3f2b8e..000000000 --- a/test/test_fx_experimental.py +++ /dev/null @@ -1,1717 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# Owner(s): ["module: fx"] - -import math -import numbers -import operator -import pickle -import sys -import tempfile -import unittest -from types import BuiltinFunctionType -from typing import Callable, Dict, List, Optional, Union - -import pippy.fx.experimental.meta_tracer -import pippy.fx.experimental.optimization as optimization - -import torch -from pippy.fx._symbolic_trace import symbolic_trace -from pippy.fx.experimental import merge_matmul -from pippy.fx.experimental.accelerator_partitioner import Partitioner -from pippy.fx.experimental.normalize import NormalizeArgs, NormalizeOperators -from pippy.fx.experimental.partitioner_utils import ( - Device, - get_latency_of_partitioned_graph, - get_partition_to_latency_mapping, - NodeLatency, - PartitionerConfig, - PartitionMode, -) -from pippy.fx.experimental.rewriter import RewritingTracer -from pippy.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema -from pippy.fx.graph_module import GraphModule -from pippy.fx.node import Node -from pippy.fx.operator_schemas import ( - _torchscript_type_to_python_type, - create_type_hint, - normalize_function, - normalize_module, - type_matches, -) -from pippy.fx.passes import graph_manipulation -from pippy.fx.passes.param_fetch import lift_lowering_attrs_to_nodes -from pippy.fx.passes.shape_prop import ShapeProp -from pippy.fx.passes.split_module import split_module -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - ops, -) -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_nn import module_tests, new_module_tests -from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.jit_utils import JitTestCase - -try: - import torchvision.models - from torchvision.models import resnet18 - - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -skipIfNoMkldnn = unittest.skipIf( - not ( - torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available() - ), - "no MKLDNN", -) - - -def symbolic_trace_with_rewrite( - root: Union[torch.nn.Module, Callable] -) -> GraphModule: - return GraphModule( - root if isinstance(root, torch.nn.Module) else torch.nn.Module(), - RewritingTracer().trace(root), - ) - - -class TestFXExperimental(JitTestCase): - def test_find_single_partition(self): - class TestModule(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(1) - b = torch.rand(1) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [ - Device("dev_0", 125, 0), - Device("dev_1", 150, 1), - Device("dev_2", 125, 2), - ] - partitioner_config = PartitionerConfig(devices) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a, b), module_with_submodules(a, b)) - assert dag.nodes[0].logical_device_ids == [1] - - def test_lack_of_devices(self): - class TestModule(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - b = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - catch_runtime_error = False - try: - ret = partitioner.partition_graph(traced, m, partitioner_config) - except RuntimeError: - catch_runtime_error = True - assert catch_runtime_error - - def test_large_node_error(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - linear = self.linear(a) - add = linear + a - return add - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - partitioner = Partitioner() - devices = [ - Device("dev_0", 40, 0), - Device("dev_1", 40, 0), - Device("dev_2", 40, 0), - Device("dev_3", 40, 0), - Device("dev_4", 40, 0), - ] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - catch_runtime_error = False - try: - ret = partitioner.partition_graph(traced, m, partitioner_config) - except RuntimeError: - catch_runtime_error = True - assert catch_runtime_error - - def test_partition_node_manipulation(self): - class TestModule(torch.nn.Module): - def forward(self, a, b): - add_1 = a + b - add_2 = add_1 + torch.rand(4) - add_3 = add_2 + torch.rand(4) - return add_3 - - m = TestModule() - traced = symbolic_trace(m) - a, b = torch.rand(4), torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [Device("dev_0", 1000, 0)] - partitioner_config = PartitionerConfig(devices) - ret = partitioner.partition_graph(traced, m, partitioner_config) - partition = partitioner.partitions[0] - assert partition.used_mem_bytes == 112 - # Select add_2 node to remove - selected_node = None - for node in partition.nodes: - if node.name == "add_2": - selected_node = node - partition.remove_node(selected_node) - assert partition.used_mem_bytes == 80 - - def test_size_based_partition(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - self.c = torch.rand(4) - - def forward(self, a, b): - add_1 = a + b - linear = self.linear(add_1) - add_2 = linear + self.c - return add_2 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - b = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a, b]) - partitioner = Partitioner() - devices = [ - Device("dev_0", 125, 0), - Device("dev_1", 125, 1), - Device("dev_2", 125, 2), - ] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a, b), module_with_submodules(a, b)) - for i, node in enumerate(dag.nodes): - assert node.logical_device_ids == [i] - - def test_partition_device_mapping(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - b = torch.rand(4) - add_1 = a + b - linear_1 = self.linear(add_1) - add_2 = torch.rand(4) + a - add_3 = add_2 + linear_1 - return add_3 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - partitioner = Partitioner() - devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] - partitioner_config = PartitionerConfig( - devices, PartitionMode.size_based - ) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a), module_with_submodules(a)) - for i, node in enumerate(dag.nodes): - if i == 1: - assert node.logical_device_ids == [1] - else: - assert node.logical_device_ids == [0] - - def test_sparse_nn_partition(self): - class MyRecommendationModule(torch.nn.Module): - def create_mlp( - self, num_of_layers: int, input_size: int, output_size: int - ): - layers = torch.nn.ModuleList() - for _ in range(num_of_layers): - ll = torch.nn.Linear(input_size, output_size) - layers.append(ll) - layers.append(torch.nn.ReLU()) - return layers - - def __init__(self): - super(MyRecommendationModule, self).__init__() - layers = self.create_mlp(4, 4, 4) - self.bottom_layers = torch.nn.Sequential(*layers) - layers = self.create_mlp(3, 24, 24) - self.top_layers = torch.nn.Sequential(*layers) - self.embedding_layers = torch.nn.ModuleList() - el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) - self.embedding_layers.append(el) - for i in range(3): - el = torch.nn.EmbeddingBag( - 1000000, 4, mode="sum", sparse=True - ) - self.embedding_layers.append(el) - el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) - self.embedding_layers.append(el) - - def forward(self, a, b, offset): - x = self.bottom_layers(a) - y = [] - c = [] - for i in range(len(self.embedding_layers)): - temp = torch.randint(10, (8,)) - c.append(temp + b) - for i in range(len(self.embedding_layers)): - if i % 2 == 0: - y.append(self.embedding_layers[i](c[i], offset)) - else: - y.append( - self.embedding_layers[i]( - torch.randint(10, (8,)), offset - ) - ) - z = torch.cat([x] + y, dim=1) - p = self.top_layers(z) - return p - - m = MyRecommendationModule() - a = torch.rand(2, 4) - b = torch.randint(10, (8,)) - offset = torch.randint(1, (2,)) - traced = symbolic_trace(m) - graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) - devices = [ - Device("dev_0", 33000000, 0), - Device("dev_1", 33000000, 1), - Device("dev_2", 33000000, 2), - ] - partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) - partitioner = Partitioner() - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual( - traced(a, b, offset), module_with_submodules(a, b, offset) - ) - assert len(module_with_submodules.graph.nodes) == 24 - - def test_partition_latency(self): - class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - add_1 = a + torch.rand(4) - add_2 = add_1 + torch.rand(4) - linear_1 = self.linear(add_1) - add_3 = add_2 + linear_1 - add_4 = add_2 + add_3 - return add_4 - - def get_node_to_latency_mapping(fx_module: GraphModule): - """Given a fx module, generate node latency for each node - based on the size of each node - """ - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - for node in fx_module.graph.nodes: - if node.op not in {"output", "placeholder", "get_attr"}: - if ( - node.size_bytes.total_size - == node.size_bytes.output_size - ): - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, - 2.0 * node.size_bytes.total_size, - ) - else: - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, - node.size_bytes.output_size, - ) - return node_to_latency_mapping - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - node_to_latency_mapping = get_node_to_latency_mapping(traced) - devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] - partitioner = Partitioner() - partitioner_config = PartitionerConfig(devices) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - self.assertEqual(traced(a), module_with_submodules(a)) - partitions = partitioner.partitions - partition_to_latency_mapping = get_partition_to_latency_mapping( - partitions, node_to_latency_mapping - ) - for p in partition_to_latency_mapping: - if p.partition_id == 0: - assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0) - else: - assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) - transfer_rate_bytes_per_sec = 2 - critical_path_latency_sec = get_latency_of_partitioned_graph( - partitions, - partition_to_latency_mapping, - transfer_rate_bytes_per_sec, - ) - assert critical_path_latency_sec == 208.0 - - def test_cost_aware_partition(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - add_1 = a + torch.rand(4) - add_2 = add_1 + torch.rand(4) - linear_1 = self.linear(add_1) - add_3 = add_2 + torch.rand(4) - add_4 = add_2 + linear_1 - add_5 = add_3 + add_4 - return add_5 - - def get_node_to_latency_mapping(fx_module: GraphModule): - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - for node in fx_module.graph.nodes: - if node.op not in {"output", "placeholder", "get_attr"}: - if ( - node.size_bytes.total_size - == node.size_bytes.output_size - ): - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, 1 - ) - else: - node_to_latency_mapping[node] = NodeLatency( - node.size_bytes.total_size, - node.size_bytes.output_size, - ) - return node_to_latency_mapping - - m = MyModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - devices = [ - Device("dev_0", 125, 0), - Device("dev_1", 125, 1), - Device("dev_2", 125, 2), - Device("dev_3", 125, 3), - ] - node_to_latency_mapping = get_node_to_latency_mapping(traced) - partitioner_config = PartitionerConfig( - devices, - mode=PartitionMode.cost_aware, - transfer_rate_bytes_per_sec=2, - node_to_latency_mapping=node_to_latency_mapping, - ) - partitioner = Partitioner() - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(traced(a), module_with_submodules(a)) - partitions = partitioner.partitions - partition_to_latency_mapping = get_partition_to_latency_mapping( - partitions, node_to_latency_mapping - ) - critical_path_latency_sec = get_latency_of_partitioned_graph( - partitions, - partition_to_latency_mapping, - partitioner_config.transfer_rate_bytes_per_sec, - ) - assert critical_path_latency_sec == 160.0 - - def test_aot_based_partition(self): - class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.b = torch.rand(4) - self.c = torch.rand(4) - - def forward(self, a): - add_1 = a + self.b - add_2 = self.c + add_1 - return add_2 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - node_to_partition_id = {} - partition_to_logical_devices = {} - count = 0 - graph_manipulation.get_size_of_all_nodes(traced, [a]) - for node in traced.graph.nodes: - if node.op not in {"placeholder", "get_attr", "output"}: - node_to_partition_id[node] = count - partition_to_logical_devices[count] = [0] - count += 1 - devices = [Device("dev_0", 200, 0)] - partitioner_config = PartitionerConfig( - devices=devices, - mode=PartitionMode.aot_based, - node_to_partition_mapping=node_to_partition_id, - partition_to_logical_device_mapping=partition_to_logical_devices, - ) - partitioner = Partitioner() - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - dag = ret.dag - self.assertEqual(module_with_submodules(a), traced(a)) - for node in dag.nodes: - assert node.size_bytes == 48 - assert node.logical_device_ids == [0] - - def test_replace_target_nodes_with(self): - class testModule(torch.nn.Module): - def forward(self, a, b): - return a + b - - m = testModule() - traced = symbolic_trace(m) - input1 = torch.randn(1) - input2 = torch.randn(1) - assert (input1 + input2) == traced(input1, input2) - graph_manipulation.replace_target_nodes_with( - fx_module=traced, - old_op="call_function", - old_target=operator.add, - new_op="call_function", - new_target=operator.mul, - ) - assert (input1 * input2) == traced(input1, input2) - - def test_saturate_host(self): - class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a): - add_1 = a + torch.rand(4) - add_2 = add_1 + torch.rand(4) - linear_1 = self.linear(add_1) - add_3 = add_2 + linear_1 - add_4 = add_2 + add_3 - return add_4 - - m = TestModule() - traced = symbolic_trace(m) - a = torch.rand(4) - graph_manipulation.get_size_of_all_nodes(traced, [a]) - devices = [ - Device("dev_0", 200, 0), - Device("dev_1", 200, 1), - Device("dev_2", 100, 2), - Device("dev_3", 100, 3), - Device("dev_4", 200, 4), - Device("dev_5", 100, 5), - ] - partitioner = Partitioner() - # Without host saturation, the model will be split into two partitions. - # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes. - partitioner_config = PartitionerConfig(devices, saturate_host=True) - ret = partitioner.partition_graph(traced, m, partitioner_config) - module_with_submodules = ret.module_with_submodules - self.assertEqual(traced(a), module_with_submodules(a)) - - partitions = partitioner.partitions - self.assertEqual(len(partitions), 2) - # With host saturation, partition 1 will be replicated to dev_4, and partition 2 - # will be replicated to dev_2. - self.assertEqual(partitions[0].logical_device_ids, [0, 4]) - self.assertEqual(partitions[1].logical_device_ids, [1, 2]) - - @skipIfNoTorchVision - def test_conv_bn_fusion(self): - rn18 = resnet18().eval() - traced = symbolic_trace(rn18) - fused = optimization.fuse(traced) - - self.assertTrue( - all( - not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules() - ) - ) - - N, C, H, W = 20, 3, 224, 224 - inp = torch.randn(N, C, H, W) - - self.assertEqual(fused(inp), rn18(inp)) - - def test_conv_bn_fusion_not_running_state(self): - class M(torch.nn.Module): - def __init__(self): - super(M, self).__init__() - self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) - self.bn = torch.nn.BatchNorm2d( - 64, - eps=1e-05, - momentum=0.1, - affine=True, - track_running_stats=False, - ) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return x - - model = M().eval() - - traced = symbolic_trace(model) - fused = optimization.fuse(traced) - inp = torch.randn([1, 32, 50, 50]) - - # bn need not be folded in conv - self.assertTrue( - any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) - ) - self.assertEqual(fused(inp), model(inp)) - - def test_call_to_assert_no_msg(self): - class M(torch.nn.Module): - def forward(self, a, b): - assert a == b - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - traced(3, 3) - with self.assertRaisesRegex(AssertionError, ""): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_meta_tracer(self): - class MetaTracerTestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = torch.nn.Embedding( - num_embeddings=42, embedding_dim=16 - ) - self.layernorm = torch.nn.LayerNorm(16) - - def forward(self, x): - emb = self.emb(x) - emb = emb + torch.arange( - emb.shape[-1], dtype=torch.float, device=emb.device - ) - lol = self.layernorm(emb) - return ( - torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) - ) - - mttm = MetaTracerTestModule() - for BS in [15, 35]: - x = torch.zeros(BS, dtype=torch.long).random_(42) - meta_args = {"x": x.to(device="meta")} - gm = pippy.fx.experimental.meta_tracer.symbolic_trace( - mttm, meta_args=meta_args - ) - torch.testing.assert_close(gm(x), mttm(x)) - - # Test serialization/deserialization - with tempfile.TemporaryDirectory() as tmp_dir: - with open(f"{tmp_dir}/meta_module.pkl", "wb") as f: - pickle.dump(gm, f) - - with open(f"{tmp_dir}/meta_module.pkl", "rb") as f: - loaded = pickle.load(f) - - torch.testing.assert_close(loaded(x), mttm(x)) - - def test_call_to_assert_with_msg(self): - class M(torch.nn.Module): - def forward(self, a, b): - assert a == b, "test message" - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - traced(3, 3) - with self.assertRaisesRegex(AssertionError, "test message"): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_call_to_assert_with_empty_msg(self): - class M(torch.nn.Module): - def forward(self, a, b): - assert a == b, "" - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - traced(3, 3) - with self.assertRaisesRegex(AssertionError, ""): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_call_to_assert_with_multiline_message(self): - class M(torch.nn.Module): - def forward(self, a, b): - error_msg = """ -An error message with -terrible spacing - """ - assert a == b, error_msg - return a + b - - m = M() - traced = symbolic_trace_with_rewrite(m) - - # Make sure the graph is well-formed - traced.graph.lint() - - # Check the IR to make sure there's a call_function node with target == "Assert" - self.assertTrue( - any( - node.op == "call_function" and node.target == torch._assert - for node in traced.graph.nodes - ) - ) - - # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to - error_msg = """ -An error message with -terrible spacing - """ - traced(3, 3) - with self.assertRaisesRegex(AssertionError, error_msg): - traced(3, 5) - - # Confirm that the output is correct - self.assertEqual(traced(3, 3), m(3, 3)) - - def test_subgraph_creation(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) - - def forward(self, x, y): - z = self.linear(x + self.param).clamp(min=0.0, max=1.0) - w = self.linear(y).clamp(min=0.0, max=1.0) - return z + w - - # symbolically trace model - my_module = MyModule() - my_module_traced = symbolic_trace(my_module) - - # random mod partitioning - partition_counter = 0 - NPARTITIONS = 3 - - # Add some random meta info to make sure it is kept around. - for node in my_module_traced.graph.nodes: - if node.op != "output": - node.meta["test_meta_info"] = True - - def mod_partition(node: Node): - nonlocal partition_counter - partition = partition_counter % NPARTITIONS - partition_counter = (partition_counter + 1) % NPARTITIONS - return partition - - # split module in module with submodules - module_with_submodules = split_module( - my_module_traced, my_module, mod_partition - ) - - # Check that test_meta_info was still on all nodes. - submodules = dict(module_with_submodules.named_modules()) - for node in module_with_submodules.graph.nodes: - if node.op == "call_module": - submod = submodules[node.target] - self.assertTrue(isinstance(submod, pippy.fx.GraphModule)) - for submod_node in submod.graph.nodes: - if submod_node.op != "output": - stored_op = submod_node.meta.get("test_meta_info") - self.assertTrue(stored_op is not None and stored_op) - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - - orig_out = my_module_traced(x, y) - submodules_out = module_with_submodules(x, y) - - self.assertEqual(orig_out, submodules_out) - - def test_split_module_kwargs_expansion(self): - class ModuleWithKwargsExpansion(torch.nn.Module): - def forward(self, x, **kwargs): - return x + kwargs["foo"] - - mod = ModuleWithKwargsExpansion() - traced = pippy.fx.symbolic_trace(mod) - - seen_getitem = False - - def split_callback(n): - nonlocal seen_getitem - split_idx = int(seen_getitem) - if n.target == operator.getitem: - seen_getitem = True - return split_idx - - split = split_module(traced, mod, split_callback) - - x = torch.randn(5, 3) - foo = torch.randn(5, 3) - torch.testing.assert_allclose(split(x, foo=foo), traced(x, foo=foo)) - - @skipIfNoTorchVision - def test_subgraph_trivial_resnet(self): - # Smoke test trivially splitting resnet into 1 partition works - # There was an issue before causing submodule names to be aliased - m = resnet18() - traced = symbolic_trace(m) - a = torch.rand(64, 3, 7, 7) - module_with_submodules = split_module(traced, m, lambda node: 0) - module_with_submodules(a) - - def test_split_module_default_arg(self): - class ModelToTrace(torch.nn.Module): - def __init__(self): - super().__init__() - self.lin = torch.nn.Linear(512, 512) - - def forward(self, x, targets=None): - x = self.lin(x) - - if targets is not None: - x = x + targets - - return x - - mtt = ModelToTrace() - traced = pippy.fx.symbolic_trace(mtt, concrete_args={"targets": None}) - - split = split_module(traced, mtt, lambda node: 0) - - x = torch.randn(50, 512) - torch.testing.assert_allclose(split(x), traced(x)) - - def test_normalize_binary_operators(self): - ops_to_test = { - torch.add, - torch.mul, - torch.sub, - torch.div, - torch.floor_divide, - torch.remainder, - torch.eq, - torch.ne, - torch.lt, - torch.le, - torch.gt, - torch.ge, - } - - # Test Tensor/Tensor callsite - for op in ops_to_test: - - class WrapperMod(torch.nn.Module): - def forward(self, x, y): - return op(x, y) - - traced = symbolic_trace(WrapperMod()) - normalized = NormalizeOperators(traced).transform() - x, y = torch.randn(3, 4), torch.randn(3, 4) - torch.testing.assert_close(traced(x, y), normalized(x, y)) - self.assertFalse( - any(n.target in ops_to_test for n in normalized.graph.nodes) - ) - - # Test Tensor/scalar callsite - for op in ops_to_test: - - class WrapperMod(torch.nn.Module): - def forward(self, x): - return op(x, 42) - - traced = symbolic_trace(WrapperMod()) - normalized = NormalizeOperators(traced).transform() - x = torch.randn(3, 4) - torch.testing.assert_close(traced(x), normalized(x)) - self.assertFalse( - any(n.target in ops_to_test for n in normalized.graph.nodes) - ) - - @skipIfNoTorchVision - def test_normalize_args(self): - m = resnet18() - - class FunctionalTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - # `leaves` contains the set of standard `nn.Modules` that are not - # currently symbolically traceable. Ideally this set would be empty - leaves = set([torch.nn.BatchNorm2d]) - return type(m) in leaves - - traced = pippy.fx.GraphModule(m, FunctionalTracer().trace(m)) - - input = torch.randn(5, 3, 224, 224) - ref_outs = traced(input) - - ShapeProp(traced).propagate(input) - traced = NormalizeArgs(traced).transform() - - modules = dict(traced.named_modules()) - - for node in traced.graph.nodes: - if node.op == "call_function" and node.target != operator.add: - self.assertEqual(len(node.args), 0) - elif node.op == "call_module": - submod_class = modules[node.target].__class__ - nn_class = getattr(torch.nn, submod_class.__name__) - if submod_class == nn_class: - self.assertEqual(len(node.args), 0) - traced(input) - self.assertEqual(traced(input), ref_outs) - - def test_normalize_modules_exhaustive(self): - """ - Exhaustively test `Node.normalized_arguments` on all standard - torch.nn Module classes - """ - for test_params in module_tests + new_module_tests: - if "constructor" not in test_params: - constructor = getattr(torch.nn, test_params["module_name"]) - else: - constructor = test_params["constructor"] - - if "constructor_args" not in test_params: - args = () - else: - args = test_params["constructor_args"] - - mod = constructor(*args) - # Skip modules that are not standard `torch.nn` - # instances, including functionals. (functionals - # are tested in test_normalize_args) - if mod.__class__.__name__ not in dir(torch.nn): - continue - - if "input_fn" not in test_params: - inputs = torch.randn(test_params["input_size"]) - else: - inputs = test_params["input_fn"]() - - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - - params = ", ".join(f"v{i}" for i in range(len(inputs))) - - # Generate a class to wrap this standard `nn.Module` instance - test_classname = f"Test{mod.__class__.__name__}" - test_mod_code = f""" -class {test_classname}(torch.nn.Module): - def __init__(self, mod): - super().__init__() - self.mod = mod - - def forward(self, {params}): - return self.mod({params}) - """ - - gbls = {"torch": torch} - exec(test_mod_code, gbls) - - test_instance = gbls[test_classname](mod) - traced = symbolic_trace(test_instance) - - # Use `Node.normalized_arguments` to get a new set of arguments - # to feed to the Module. Then, rewrite the node to only take - # in those arguments as kwargs - modules = dict(traced.named_modules()) - for node in traced.graph.nodes: - if node.op == "call_module": - submod_class = modules[node.target].__class__ - nn_class = getattr(torch.nn, submod_class.__name__) - if submod_class == nn_class: - normalized_args = node.normalized_arguments(traced) - normalized_args2 = normalize_module( - traced, node.target, node.args, node.kwargs - ) - assert normalized_args == normalized_args2 - assert normalized_args - node.args = normalized_args.args - node.kwargs = normalized_args.kwargs - - traced.recompile() - - # These Modules have an RNG in their forward, so testing - # correctness by comparing outputs is not correct. Skip that - # check for these - stochastic_modules = { - "FractionalMaxPool2d", - "FractionalMaxPool3d", - "RReLU", - } - - if mod.__class__.__name__ not in stochastic_modules: - self.assertEqual(traced(*inputs), mod(*inputs)) - - traced = NormalizeArgs(symbolic_trace(test_instance)).transform() - modules = dict(traced.named_modules()) - for node in traced.graph.nodes: - if node.op == "call_module": - submod_class = modules[node.target].__class__ - nn_class = getattr(torch.nn, submod_class.__name__) - if submod_class == nn_class: - self.assertEqual(len(node.args), 0) - - def test_normalize_args_preserve_meta(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a): - return torch.add(a, 3) - - m = MyModule() - traced = symbolic_trace(m) - - for node in traced.graph.nodes: - if node.op == "call_function" and node.target == torch.add: - node.meta["my_key"] = 7 - break - else: - self.fail("Didn't find call_function torch.add") - - input = torch.randn(2, 3) - ShapeProp(traced).propagate(input) - traced = NormalizeArgs(traced).transform() - - for node in traced.graph.nodes: - if node.op == "call_function" and node.target == torch.add: - self.assertTrue("my_key" in node.meta) - self.assertEqual(node.meta["my_key"], 7) - break - else: - self.fail("Didn't find call_function torch.add") - - def test_normalize_args_perserve_type(self): - class MyModule(torch.nn.Module): - def forward(self, a: List[torch.Tensor]): - return torch.add(a[0], a[1]) - - m = MyModule() - traced = symbolic_trace(m) - traced = NormalizeArgs(traced).transform() - - for node in traced.graph.nodes: - if node.op == "placeholder": - self.assertEqual(node.type, List[torch.Tensor]) - - @skipIfNoTorchVision - def test_annotate_returns_with_schema(self): - m = resnet18() - - traced_modules = symbolic_trace(m) - traced_modules_annotated = AnnotateTypesWithSchema( - traced_modules - ).transform() - for node in traced_modules_annotated.graph.nodes: - if node.type is None: - check = (node.op, node.target) - self.assertIn( - check, - { - ("placeholder", "x"), - ("call_module", "maxpool"), - ("call_function", operator.add), - ("call_function", torch.flatten), - ("output", "output"), - }, - ) - - # Smoke test torchscript compilation since now we're emitting type annotations - torch.jit.script(traced_modules_annotated) - - class FunctionalTracer(pippy.fx.Tracer): - def is_leaf_module( - self, m: torch.nn.Module, module_qualified_name: str - ) -> bool: - # `leaves` contains the set of standard `nn.Modules` that are not - # currently symbolically traceable. Ideally this set would be empty - leaves = set([torch.nn.BatchNorm2d]) - return type(m) in leaves - - traced_functionals = pippy.fx.GraphModule( - m, FunctionalTracer().trace(m) - ) - - traced_functionals_annotated = AnnotateTypesWithSchema( - traced_functionals - ).transform() - for node in traced_functionals_annotated.graph.nodes: - if node.type is None: - check = (node.op, node.target) - excluded_nodes = { - ("placeholder", "x"), - # Return type differs based on boolean dispatch :( - ("call_function", torch.nn.functional.max_pool2d), - ("output", "output"), - } - # AnnotateTypesWithSchema doesn't work with bound C++ functions - if not isinstance(node.target, BuiltinFunctionType): - self.assertIn(check, excluded_nodes) - - # Smoke test torchscript compilation since now we're emitting type annotations - torch.jit.script(traced_functionals_annotated) - - def test_subgraph_uniquename(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(4, 4) - - def forward(self, a, b, c, d): - add_1 = a + b - add_2 = add_1 + c - linear_1 = self.linear(add_1) - add_3 = add_2 + d - add_4 = add_2 + linear_1 - add_5 = add_3 + add_4 - return add_5 - - a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4) - mm = MyModule() - traced = symbolic_trace(mm) - - def split_cb(node: pippy.fx.Node): - if node.name == "a" or node.name == "b" or node.name == "add": - return 0 - else: - return 1 - - module_with_submodule = split_module(traced, mm, split_cb) - self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d)) - - def test_split_qualname_mapping(self): - d_hid = 4 - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - x = torch.relu(x) - x = torch.mm(x, self.mm_param) - x = self.lin(x) - x = torch.relu(x) - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - return x - - my_module = ExampleCode() - my_module_traced = symbolic_trace(my_module) - - part_idx = 0 - - def split_callback(n: pippy.fx.Node): - nonlocal part_idx - if (n.op, n.target) == ("call_module", "lin"): - part_idx += 1 - return part_idx - - # split module in module with submodules - qualname_map: Dict[str, str] = {} - module_with_submodules = split_module( - my_module_traced, my_module, split_callback, qualname_map - ) - expected_qualname_map = {"submod_1.lin": "lin", "submod_2.lin": "lin"} - self.assertEqual(qualname_map, expected_qualname_map) - - def test_traceable_function_with_nonstandard_name(self): - def foo(x): - return torch.relu(x) - - traced = symbolic_trace_with_rewrite(foo) - - def test_to_folder(self): - class Test(torch.nn.Module): - def __init__(self): - super(Test, self).__init__() - self.W = torch.nn.Parameter(torch.randn(2)) - self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) - self.linear = torch.nn.Linear(2, 2) - self.attr = torch.randn(2) - self.register_buffer("attr2", torch.randn(2)) - self.register_buffer("attr3", torch.ones(2, dtype=torch.int32)) - - def forward(self, x): - return self.linear( - self.seq(self.W + self.attr + self.attr2 + self.attr3 + x) - ) - - mod = symbolic_trace(Test()) - module_name = "Foo" - import tempfile - from pathlib import Path - - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = Path(tmp_dir) - mod.to_folder(tmp_dir, module_name) - # Recipe taken from here: - # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - import importlib.util - - spec = importlib.util.spec_from_file_location( - module_name, tmp_dir / "__init__.py" - ) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - t = torch.randn(2, 2) - self.assertEqual(module.Foo()(t), mod(t)) - - def test_fetch(self): - attrs_for_lowering: Dict[str, List[str]] = { - "torch.nn.modules.conv.Conv2d": [ - "weight", - "bias", - "kernel_size", - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - ], - "torch.nn.modules.batchnorm.BatchNorm2d": [ - "weight", - "bias", - "running_mean", - "running_var", - "eps", - ], - } - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 2) - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, a): - a = self.conv(a) - a += a - return self.bn(a) - - mod = TestModule() - traced = symbolic_trace(mod) - lift_lowering_attrs_to_nodes(traced) - - for node in traced.graph.nodes: - if node.op == "call_module": - assert hasattr(node, "attrs_for_lowering") - para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] - - # node.attrs_for_lowering has an addition field of class name - assert len(para_list) + 1 == len(node.attrs_for_lowering) - for p_name in para_list: - assert p_name in node.attrs_for_lowering - - def test_merge_matmuls(self): - """ - A collection of test cases for pippy.fx.experimental.merge_matmul, - a graph transformation that merges matrix multiplication operations. - """ - - # Utility function for counting matmuls for test assertions. - def _count_matmuls(mod): - gm = pippy.fx.symbolic_trace(mod) - - num_matmuls = 0 - for node in gm.graph.nodes: - if node.target == torch.matmul: - num_matmuls += 1 - - return num_matmuls - - # Simple test case in which there are two matmuls of the same size to merge. - class SimpleMergeMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, x, y): - a = torch.matmul(x, self.rhs) - b = torch.matmul(y, self.rhs) - return a + b - - # Initialize inputs. - a = torch.randn(3, 3) - b = torch.randn(3, 3) - - # Initialize RHS for matmuls. - rhs = torch.randn(3, 4) - - # Construct SimpleMergeMatmulModule and call merge_matmul on it. - module = SimpleMergeMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(a, b) - after = opt_module(a, b) - before.allclose(after) - - # Basic graph structure check; original module should have 2 matmuls - # and optimized module should have 1. - self.assertEqual(_count_matmuls(module), 2) - self.assertEqual(_count_matmuls(opt_module), 1) - - # Test case in which there are multiple matmuls of different sizes to merge. - class FiveMergeMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, a, b, c, d, e): - s = torch.tensor([]) - matmuls = [] - - # For some reason using a list comprehension or for-loop for this - # doesn't work. - matmuls.append(torch.matmul(a, self.rhs)) - matmuls.append(torch.matmul(b, self.rhs)) - matmuls.append(torch.matmul(c, self.rhs)) - matmuls.append(torch.matmul(d, self.rhs)) - matmuls.append(torch.matmul(e, self.rhs)) - - for m in matmuls: - s += torch.sum(m) - - return s - - # Initialize inputs. - inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] - - # Initialize RHS. - rhs = torch.randn(5, 4) - - # Construct FiveMergeMatmulModule and call merge_matmul on it. - module = FiveMergeMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(*inputs) - after = opt_module(*inputs) - before.allclose(after) - - # Basic graph structure check; original module should have len(inputs) matmuls - # and optimized module should have 1. - self.assertEqual(_count_matmuls(module), len(inputs)) - self.assertEqual(_count_matmuls(opt_module), 1) - - # Simple test case in which two matmuls cannot be merged due to a data dependency between - # the LHS operands. - class UnmergeableMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, x): - a = torch.matmul(x, self.rhs) - a_abs = torch.abs(a) - b = torch.matmul(a_abs.transpose(1, 0), self.rhs) - return b - - # Initialize inputs. - a = torch.randn(3, 3) - - # Initialize RHS for matmuls. - rhs = torch.randn(3, 4) - - # Construct UnmergeableMatmulModule and call merge_matmul on it. - module = UnmergeableMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(a) - after = opt_module(a) - before.allclose(after) - - # Basic graph structure check; the number of matrix multiplcations should not have changed. - self.assertEqual(_count_matmuls(module), 2) - self.assertEqual(_count_matmuls(opt_module), 2) - - def test_type_matches(self): - should_be_equal = [ - (int, type(5)), - (numbers.Number, type(5)), - (numbers.Number, type(5.0)), - (int, type(torch.float)), - (Union[int, float], type(5)), - (Union[int, float], type(5.0)), - (List[int], type(5)), - (List[int], create_type_hint([int, int])), - (List[int], create_type_hint((int, int))), - ( - List[torch.Tensor], - create_type_hint([torch.Tensor, torch.Tensor]), - ), - ( - List[torch.Tensor], - create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), - ), - (torch.Tensor, torch.nn.Parameter), - ( - List[torch.Tensor], - create_type_hint([torch.nn.Parameter, torch.Tensor]), - ), - ( - List[torch.Tensor], - create_type_hint([torch.Tensor, torch.nn.Parameter]), - ), - ( - List[torch.Tensor], - create_type_hint((torch.Tensor, torch.Tensor)), - ), - ( - List[torch.Tensor], - create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), - ), - (torch.Tensor, torch.nn.Parameter), - ( - List[torch.Tensor], - create_type_hint((torch.nn.Parameter, torch.Tensor)), - ), - ( - List[torch.Tensor], - create_type_hint((torch.Tensor, torch.nn.Parameter)), - ), - (Optional[List[torch.Tensor]], List[torch.Tensor]), - (Optional[List[int]], List[int]), - ] - for sig_type, arg_type in should_be_equal: - self.assertTrue(type_matches(sig_type, arg_type)) - - should_fail = [ - (int, float), - (Union[int, float], str), - (List[torch.Tensor], List[int]), - ] - - for sig_type, arg_type in should_fail: - self.assertFalse(type_matches(sig_type, arg_type)) - - @skipIfNoMkldnn - def test_optimize_for_inference_cpu(self): - import torch.nn as nn - - class Foo(nn.Module): - def __init__(self): - super().__init__() - layers = [] - layers2 = [] - for _ in range(10): - layers.append(nn.Conv2d(3, 3, 1)) - layers.append(nn.BatchNorm2d(3)) - layers.append(nn.ReLU()) - - layers2.append(nn.Conv2d(3, 3, 1)) - layers2.append(nn.BatchNorm2d(3)) - layers2.append(nn.ReLU()) - self.model = nn.Sequential(*layers) - self.model2 = nn.Sequential(*layers2) - - def forward(self, x): - return self.model(x) + self.model2(x) - - (N, C, H, W) = ( - 1, - 3, - 224, - 224, - ) - inp = torch.randn(N, C, H, W) - with torch.no_grad(): - model = Foo().eval() - optimized_model = optimization.optimize_for_inference(model) - torch.testing.assert_close(model(inp), optimized_model(inp)) - - optimized_model2 = optimization.optimize_for_inference( - model, pass_config={"remove_dropout": False} - ) - torch.testing.assert_close(model(inp), optimized_model2(inp)) - - @skipIfNoTorchVision - @skipIfNoMkldnn - def test_optimize_for_inference_cpu_torchvision(self): - models = [ - torchvision.models.resnet18, - torchvision.models.resnet50, - torchvision.models.densenet121, - torchvision.models.shufflenet_v2_x1_0, - torchvision.models.vgg16, - torchvision.models.mobilenet_v2, - torchvision.models.mnasnet1_0, - torchvision.models.resnext50_32x4d, - ] - with torch.no_grad(): - for model_type in models: - model = model_type() - (C, H, W) = ( - 3, - 224, - 224, - ) - inp = torch.randn(3, C, H, W) - model(inp) - model.eval() - inp = torch.randn(1, C, H, W) - heuristic = optimization.gen_mkl_autotuner( - inp, iters=0, warmup=0 - ) - optimized_model = optimization.optimize_for_inference(model) - - orig_out = model(inp) - new_out = optimized_model(inp) - torch.testing.assert_close(orig_out, new_out) - - -class TestNormalizeOperators(JitTestCase): - @onlyCPU - @ops(op_db, allowed_dtypes=(torch.float,)) - def test_normalize_operator_exhaustive(self, device, dtype, op): - # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors) - fx_fail = { - "cat", - "stack", - "hstack", - "vstack", - "dstack", - "linalg.multi_dot", - } - sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) - if isinstance(op.op, torch._ops.OpOverload): - self.skipTest("normalize operator doesn't work on torch.ops") - for sample_input in sample_inputs_itr: - unsupported_arg_type = False - arg_values = [sample_input.input] + list(sample_input.args) - kwarg_values = sample_input.kwargs - arg_types = [] - kwarg_types = {} - - def jit_infer_type(v): - inferred_arg_type = torch._C._jit_try_infer_type(v) - assert inferred_arg_type.success() - t = _torchscript_type_to_python_type(inferred_arg_type.type()) - return t - - for v in arg_values: - if isinstance(v, torch.Tensor): - arg_types.append(type(v)) - else: - if isinstance(v, complex): - # Complex type not supported in FX - unsupported_arg_type = True - arg_types.append(jit_infer_type(v)) - - for k, v in kwarg_values.items(): - if isinstance(v, torch.Tensor): - kwarg_types[k] = type(v) - else: - if isinstance(v, complex): - # Complex type not supported in FX - unsupported_arg_type = True - kwarg_types[k] = jit_infer_type(v) - - if unsupported_arg_type: - continue - # Test normalize_function by itself - ref_out = op.op(*arg_values, **kwarg_values) - norm_args_and_kwargs = normalize_function( - op.op, arg_values, kwarg_values, arg_types, kwarg_types - ) - if norm_args_and_kwargs is None: - raise RuntimeError( - """ - FX failed to normalize op - add the op to the op_skip list. - A common reason is if your OpInfo was implemented with a lambda - - otherwise, file an issue - """ - ) - test_out = op.op( - *norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs - ) - self.assertEqual(test_out, ref_out) - - # Test normalized_arguments as part of FX - if op.name in fx_fail: - continue - param_names = [] - param_values = [] - fx_args = [] - for idx, v in enumerate(arg_values): - if isinstance(v, torch.Tensor): - param_names.append(f"arg_{idx}") - param_values.append(v) - fx_args.append(param_names[-1]) - else: - fx_args.append(f"{repr(v)}") - - for k, v in kwarg_values.items(): - if isinstance(v, torch.Tensor): - param_names.append(k) - param_values.append(v) - fx_args.append(f"{k} = {k}") - else: - fx_args.append(f"{k} = {repr(v)}") - - code = f""" -class TestModule(torch.nn.Module): - def forward(self, {', '.join(param_names)}): - return torch.{op.name}({', '.join(fx_args)}) - """ - - g = {"torch": torch, "inf": math.inf} - exec(code, g) - TestModule = g["TestModule"] - - m = TestModule() - traced = pippy.fx.symbolic_trace(m) - ref_out = traced(*param_values) - - for node in traced.graph.nodes: - if node.op == "call_function": - normalized_args = node.normalized_arguments( - traced, arg_types, kwarg_types - ) - assert normalized_args - node.args = normalized_args.args - node.kwargs = normalized_args.kwargs - traced.recompile() - - test_out = traced(*param_values) - self.assertEqual(test_out, ref_out) - - def test_normalize_quantized_eb(self): - target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets - args = ( - torch.empty((2, 3), dtype=torch.uint8), - torch.empty((2,), dtype=torch.int64), - torch.empty((2,), dtype=torch.int64), - ) - norm_args_and_kwargs = normalize_function( - target, args, normalize_to_only_use_kwargs=True - ) - self.assertTrue(norm_args_and_kwargs is not None) - self.assertEqual( - set(norm_args_and_kwargs.kwargs.keys()), - { - "weight", - "indices", - "offsets", - "scale_grad_by_freq", - "mode", - "pruned_weights", - "per_sample_weights", - "compressed_indices_mapping", - "include_last_offset", - }, - ) - self.assertEqual(norm_args_and_kwargs.args, tuple()) - - def test_normalize_args_op_overload(self): - for target in [ - torch.ops.aten.resize_as_.default, - torch.ops.aten.resize_as_, - ]: - inp1 = torch.rand([1]) - inp2 = torch.rand([4]) - args, kwargs = normalize_function( - target, - (inp1,), - {"the_template": inp2}, - normalize_to_only_use_kwargs=True, - ) - self.assertIs(kwargs["input"], inp1) - self.assertIs(kwargs["the_template"], inp2) - - -instantiate_device_type_tests(TestNormalizeOperators, globals()) - -if __name__ == "__main__": - run_tests() diff --git a/test/test_pipe.py b/test/test_pipe.py new file mode 100644 index 000000000..6259953e7 --- /dev/null +++ b/test/test_pipe.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import unittest + +import torch + +from pippy.IR import Pipe, pipe_split + + +d_hid = 512 +batch_size = 256 + +torch.manual_seed(0) + + +# Basic example +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x, y): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = x + y + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + x = torch.relu(x) + return x + + +# MLP example +class MLPModule(torch.nn.Module): + def __init__(self, d_hid): + super(MLPModule, self).__init__() + self.net1 = torch.nn.Linear(d_hid, d_hid) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.net1(x) + x = self.relu(x) + x = self.net2(x) + return x + + +class MultiMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.mlp0 = MLPModule(d_hid) + self.mlp1 = MLPModule(d_hid) + self.mlp2 = MLPModule(d_hid) + self.mlp3 = MLPModule(d_hid) + + def forward(self, x, y): + x = self.mlp0(x) + pipe_split() + x = self.mlp1(x) + pipe_split() + x = self.mlp2(x) + pipe_split() + x = self.mlp3(x) + return x - y + + +def run_worker(args, model_class): + mod = model_class() + x = torch.randn(batch_size, d_hid) + y = torch.randn(batch_size, d_hid) + + pipe = Pipe.from_tracing( + mod, + args.chunks, + example_args=(x, y), + ) + + ref_out = mod(x, y) + out = pipe(x, y)[0] + torch.testing.assert_close(out, ref_out) + print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + for model_class in [ExampleCode, MultiMLP]: + print("Testing ", model_class.__name__) + run_worker(args, model_class) + + +if __name__ == "__main__": + main() + + +class TestPipe(unittest.TestCase): + def test_pipe(self): + main(args) diff --git a/test/test_pipe_bwd.py b/test/test_pipe_bwd.py new file mode 100644 index 000000000..e9657621d --- /dev/null +++ b/test/test_pipe_bwd.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import argparse +import unittest + +import torch +from pippy.IR import Pipe, pipe_split + +from pippy.microbatch import sum_reducer, TensorChunkSpec + + +d_hid = 512 +batch_size = 256 + +torch.manual_seed(0) + + +# Basic example +class ExampleCode(torch.nn.Module): + def __init__(self): + super().__init__() + self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) + self.lin = torch.nn.Linear(d_hid, d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, y): + x = torch.mm(x, self.mm_param) + skip_connection = x + x = torch.relu(x) + pipe_split() + x = torch.mm(x, self.mm_param) + x = self.lin(x) + pipe_split() + x = torch.relu(x) + x = x + skip_connection + x = torch.mm(x, self.mm_param2) + pipe_split() + x = self.lin(x) + logits = torch.relu(x) + loss = self.mse_loss(x, y) + return logits, loss + + +# MLP example +class MLPModule(torch.nn.Module): + def __init__(self, d_hid): + super(MLPModule, self).__init__() + self.net1 = torch.nn.Linear(d_hid, d_hid) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(d_hid, d_hid) + + def forward(self, x): + x = self.net1(x) + x = self.relu(x) + x = self.net2(x) + return x + + +class MultiMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.mlp0 = MLPModule(d_hid) + self.mlp1 = MLPModule(d_hid) + self.mlp2 = MLPModule(d_hid) + self.mlp3 = MLPModule(d_hid) + self.mse_loss = torch.nn.MSELoss(reduction="sum") + + def forward(self, x, y): + x = self.mlp0(x) + pipe_split() + x = self.mlp1(x) + pipe_split() + x = self.mlp2(x) + pipe_split() + x = self.mlp3(x) + loss = self.mse_loss(x, y) + return x, loss + + +def run_worker(args, model_class): + mod = model_class() + x = torch.randn(batch_size, d_hid) + y = torch.randn(batch_size, d_hid) + + output_chunk_spec = ( + TensorChunkSpec(0), # logits + sum_reducer, # loss + ) + + pipe = Pipe.from_tracing( + mod, + args.chunks, + example_args=(x, y), + output_chunk_spec=output_chunk_spec, + ) + + ref_out = mod(x, y) + out = pipe(x, y) + torch.testing.assert_close(out, ref_out) + print(f"equivalence test passed loss={out[1]} ref_loss={ref_out[1]}") + + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "--chunks", + type=int, + default=4, + ) + args = parser.parse_args(args) + + for model_class in [ExampleCode, MultiMLP]: + print("Testing ", model_class.__name__) + run_worker(args, model_class) + + +if __name__ == "__main__": + main() + + +class TestPipeBwd(unittest.TestCase): + def test_pipe_bwd(self): + main(args)