float8 delayed scaling: remove need to use workaround for AC (#1291)
Summary:

The logic to check whether the user has called
`sync_float8_amax_and_scale_history`, while nice, hasn't been that
useful in practice.  Removing this logic simplifies the integration
with manual activation checkpointing - now it "just works" without
needing a non-standard config as a workaround.
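
As a rough illustration (not part of this commit), a delayed-scaling
`Float8Linear` can now be called directly under
`torch.utils.checkpoint.checkpoint`. The import paths and the
`Float8LinearConfig`/`CastConfig` construction below are assumptions based on
the surrounding torchao code at the time of this change, and the shapes/dtypes
are arbitrary:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

# assumed import locations; they may differ across torchao versions
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import sync_float8_amax_and_scale_history

# hypothetical delayed-scaling config for input, weight, and grad_output
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)

m = Float8Linear.from_float(
    nn.Linear(16, 32, device="cuda", dtype=torch.bfloat16), config
)
x = torch.randn(4, 16, device="cuda", dtype=torch.bfloat16)

for _ in range(2):
    # delayed scaling still requires syncing amaxes/scales before forward
    sync_float8_amax_and_scale_history(m)
    # manual activation checkpointing now works with the default config
    y = torch.utils.checkpoint.checkpoint(m, x, use_reentrant=False)
    y.sum().backward()
```

This mirrors the new `use_ac=True` path exercised by the test change below.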

Test Plan:

```
pytest test/float8/test_base.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Nov 15, 2024
1 parent 6735461 commit 56bf2e8
Showing 2 changed files with 12 additions and 17 deletions.
test/float8/test_base.py: 12 additions & 2 deletions
```diff
@@ -261,6 +261,7 @@ def _test_linear_impl(
         x,
         m_ref,
         config: Float8LinearConfig,
+        use_ac: bool = False,
     ):
         m_fp8 = Float8Linear.from_float(
             copy.deepcopy(m_ref),
@@ -269,9 +270,15 @@ def _test_linear_impl(
         for _ in range(2):
             if linear_requires_sync(config):
                 sync_float8_amax_and_scale_history(m_fp8)
-        y_fp8 = m_fp8(x)
+        if use_ac:
+            y_fp8 = torch.utils.checkpoint.checkpoint(m_fp8, x, use_reentrant=False)
+        else:
+            y_fp8 = m_fp8(x)
         y_fp8.sum().backward()
-        y_ref = m_ref(x)
+        if use_ac:
+            y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False)
+        else:
+            y_ref = m_ref(x)
         y_ref.sum().backward()

         assert y_ref.shape == y_fp8.shape
@@ -344,6 +351,7 @@ def _test_linear_impl(
     )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
+    @pytest.mark.parametrize("use_ac", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_from_config_params(
         self,
@@ -354,6 +362,7 @@ def test_linear_from_config_params(
         scaling_type_grad_output: ScalingType,
         linear_dtype: torch.dtype,
         linear_bias: bool,
+        use_ac: bool,
     ):
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
@@ -369,6 +378,7 @@ def test_linear_from_config_params(
             x,
             m_ref,
             config,
+            use_ac,
         )

     # Note: there are now too many config combinations to test all of
```
torchao/float8/float8_linear.py: 0 additions & 15 deletions
```diff
@@ -334,10 +334,6 @@ def __init__(self, *args, **kwargs):
         # TODO(future PR): add serialization for this flag
         self.is_amax_initialized = not self.config.enable_amax_init

-        # Syncing of amaxes and scales happens outside of this function. This
-        # flag is here to enforce that the user does not forget to do this.
-        self.amax_and_scale_synced = not self.config.enable_amax_init
-
         # This is needed to properly handle autocast in the amax/scale
         # update function for torch.float16
         self.last_seen_input_dtype = None
@@ -544,23 +540,12 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
     def float8_pre_forward(self, input):
         if not self.enable_pre_and_post_forward:
             return
-        if (
-            self.is_amax_initialized
-            and (not self.amax_and_scale_synced)
-            and torch.is_grad_enabled()
-        ):
-            raise AssertionError(
-                "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
-            )
         self.last_seen_input_dtype = input.dtype

     def float8_post_forward(self):
         if not self.enable_pre_and_post_forward:
             return
-        # Ensure that calling forward again will fail until the user syncs
-        # amaxes and scales
         self.is_amax_initialized = True
-        self.amax_and_scale_synced = False

     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
         has_any_axiswise_scaling = (
```
