Skip to content

Commit

Permalink
Generate op tests based on run model (#599)
Browse files Browse the repository at this point in the history
- Introduce new compiler configuration: tvm_generate_op_tests
  (true/false)
- Introduce logic that will generate op tests for a model we run

_Note: This change is based on few dependant issues that are cherry-picked. Feel free to review only last commit_

Fix #589
  • Loading branch information
nvukobratTT authored Nov 7, 2024
1 parent fcbee46 commit 8957339
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
3 changes: 3 additions & 0 deletions forge/forge/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ class CompilerConfig:
# Number of patterns to match for each module
tvm_module_to_num_patterns: Dict[str, int] = field(default_factory=lambda: dict())

# If enabled, for given test, it generates Forge Modules in form of PyTest for each op that exists in given module
tvm_generate_op_tests: bool = False

# Enables a transform for conv that directly reads input, such that it goes from stride > 1 to stride = 1
# This usually translates to lower DRAM BW and less math as the input better populates tiles
enable_conv_prestride: bool = True
Expand Down
86 changes: 86 additions & 0 deletions forge/forge/tvm_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
import numpy as np
import pytest

# import forge._C.pattern_matcher as pypattern_matcher
from forge.module import OnnxModule, ForgeModule, TFLiteModule
Expand Down Expand Up @@ -2682,7 +2683,92 @@ def delete_unneeded_outputs(ops, returns):

modules.append(writer)

# Generate op tests based on requested model. Currently only supported
# for PyTorch framework.
if compiler_cfg.tvm_generate_op_tests:
generate_op_tests(
ops,
current_module_name,
framework,
contains_incompatible_np_floats,
delete_inputs,
params,
constants,
param_names,
param_file_name,
names_params_file_name,
named_buffers_file_name,
)

# Exit python progrems without error
# - Two different exit methods depending on whether compile is run using
# pytest, or as a standalone python script
if "pytest" in sys.modules:
pytest.exit("Exiting test without error", returncode=0)
else:
sys.exit(0)

if compiler_cfg.retain_tvm_python_files:
save_writers_metadata(modules, flattened_pytorch_inputs, forge_inputs, graph_name)

return modules, forge_inputs


def generate_op_tests(
ops,
current_module_name,
framework,
contains_incompatible_np_floats,
delete_inputs,
params,
constants,
param_names,
param_file_name,
names_params_file_name,
named_buffers_file_name,
):
"""
Generates test modules for a list of operations.
This function creates unique test modules for each operation in the provided list.
It initializes a ForgeWriter to generate the necessary code for testing each operation,
including headers, class definitions, forward functions, parameter parsers, and pytest functions.
The generated tests are designed to run the operations as standalone tests.
"""
for op_idx, key in enumerate(sorted(ops)):
# Create unique module name
module_name = "test_" + current_module_name.lower() + str(op_idx)

# Initialize Forge writer and generate header and class definition
writer = ForgeWriter(
module_name,
framework,
contains_incompatible_np_floats=contains_incompatible_np_floats,
delete_inputs=delete_inputs,
)
writer.write_header()
writer.write_class_definition(params, constants)

# Focus on generating test for a single op
single_op = {key: ops[key]}

# Create new inputs for the single op
new_inputs = {}
for i, input_name in enumerate(single_op[key].input_names):
# Detected parameter as input, insert dummy input
# TODO: Need to handle this case better. Probably just ignoring
# model parameters, and using new generated inputs.
if "." in input_name:
input_name = "dummy_input_" + str(i)
new_inputs[input_name] = input_name

# Force output to be same as the op we're running
single_return = {key: single_op[key].output_name}

# Generate forward function and parameter parser (loading params and constants)
writer.write_forward(single_op, new_inputs, single_return)
writer.write_param_parser(param_names, param_file_name, names_params_file_name, named_buffers_file_name)

# Generate pytest function that enables runing Forge Module as standalone test
writer.write_pytest_function(module_name, single_op[key].input_shapes)
writer.close_file()

0 comments on commit 8957339

Please sign in to comment.