-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
605e6c9
commit 8b809c2
Showing
3 changed files
with
150 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Verification | ||
|
||
## General Overview | ||
|
||
When comparing our `compiled_model` with the `framework_model` (e.g., `PyTorch`), we aim to verify whether the output from the `compiled_model` is sufficiently similar to the output from the `framework_model` (where required degree of similarity is configurable). | ||
|
||
So generally we want to perform the following steps: | ||
1. Create a framework model | ||
2. Run a forward pass through the framework model | ||
3. Compile the framework model using `Forge` | ||
4. Run a forward pass through the compiled model | ||
5. Compare the outputs. | ||
|
||
Most of the above steps `verify()` function does for us: | ||
- Handles forward passes for both framework and compiled models | ||
- Compares results using a combination of comparison methods | ||
- Supports customization through the `VerifyConfig` class. | ||
|
||
**Example of usage** | ||
```python | ||
def test_add(): | ||
|
||
class Add(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return a + b | ||
|
||
|
||
inputs = [torch.rand(2, 32, 32), torch.rand(2, 32, 32)] | ||
|
||
framework_model = Add() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model) | ||
``` | ||
|
||
**Notes:** | ||
- If you only want to compile model and perform forward pass without comparing outputs you can just: | ||
```python | ||
def test_add(): | ||
|
||
class Add(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return a + b | ||
|
||
|
||
inputs = [torch.rand(2, 32, 32), torch.rand(2, 32, 32)] | ||
|
||
framework_model = Add() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
fw_out = framework_model(*inputs) | ||
co_out = compiled_model(*inputs) | ||
``` | ||
|
||
## Verify Config Overview | ||
|
||
Currently through `VerifyConfig` you can disable/enable: | ||
| Feature | Enabled (default) | | ||
|-----------------------------------|:-----------------:| | ||
| Verification as a method | `True` | | ||
| Output size check | `True` | | ||
| Output type check | `True` | | ||
| Output shape check | `True` | | ||
|
||
For more information about `VerifyConfig` you can check `forge/forge/verify/config.py` | ||
</br> | ||
|
||
Besides that, config also includes value checker. There are 3 types of checker: | ||
- `AutomaticValueChecker` **(default)** | ||
- `AllCloseValueChecker` | ||
- `FullValueChecker` | ||
|
||
<br/> | ||
<br/> | ||
|
||
<img src="imgs/checkers/checkers.svg" alt="Checkers" style="width: 1200%;"> | ||
|
||
|
||
|
||
### AutomaticValueChecker | ||
|
||
This checker performs tensor checks based on the **shape** and **type** of tensor (e.g. for scalars it will perform `torch.allclose` as `pcc` shouldn't be applied to the scalars) | ||
|
||
For this checker you can set: | ||
- `pcc` | ||
- `rtol` | ||
- `atol` | ||
- `dissimilarity_threshold` | ||
|
||
**Example of usage:** | ||
```python | ||
# default behavior | ||
verify(inputs, framework_model, compiled_model) | ||
# this will result same as the default behavior | ||
verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker()) | ||
# setting pcc and rtol | ||
verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95, rtol=1e-03))) | ||
``` | ||
|
||
### AllCloseValueChecker | ||
|
||
This checker checks tensors using `torch.allclose` method. | ||
|
||
For this checker you can set: | ||
- `rtol` | ||
- `atol` | ||
|
||
**Example of usage:** | ||
```python | ||
# setting allclose checker with default values | ||
verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AllCloseValueChecker())) | ||
# setting allclose checker with custom values | ||
verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=AllCloseValueChecker(rtol=1e-03))) | ||
|
||
``` | ||
|
||
### FullValueChecker | ||
|
||
This checker is combination of `AutomaticValueChecker` and `AllCloseValueChecker`. | ||
|
||
For this checker you can set: | ||
- `pcc` | ||
- `rtol` | ||
- `atol` | ||
- `dissimilarity_threshold` | ||
|
||
**Examples of usage:** | ||
```python | ||
# setting full checker with default values | ||
verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=FullValueChecker()) | ||
# setting full checker with custom values | ||
verify(inputs, framework_model, compiled_model, VerifyConfig(value_checker=FullValueChecker(pcc=0.95, rtol=1e-03))) | ||
``` |