diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py
index 2852d99b8b..dbd6ff6352 100644
--- a/tests/test_lion8b.py
+++ b/tests/test_lion8b.py
@@ -504,45 +504,59 @@ def _set_state_dict_type(model: nn.Module):
 def test_fused_as_fast_as_unfused(N: int,
                                   D: int,
                                   min_elems_traversed: int = 1000000):
-    W = torch.randn((N, D), device='cuda', requires_grad=True)
-    W.grad = torch.randn((N, D), device='cuda', requires_grad=False)
-
-    num_iters = int(np.ceil(min_elems_traversed / W.grad.numel()))
-    num_iters = min(100, num_iters)  # don't take all day when overhead-bound
-
-    times = {}
-    kwargs = {'weight_decay': .01}
-    combos = [(True, False), (True, True), (False, False), ('NA', False)]
-    for fused, use_errors in combos:
-        if fused == 'NA':
-            opt = Lion8bit([W], quantize=False,
-                           **kwargs)  # type:ignore (reportGeneralTypeIssues)
-        else:
-            opt = Lion8bit([W],
-                           _fused=fused,
-                           error_correction=use_errors,
-                           **kwargs)  # type:ignore (reportGeneralTypeIssues)
-        for _ in range(3):
-            opt.step()  # warmup iters
-        torch.cuda.synchronize()
-        t_start = time.time()
-        for _ in range(num_iters):
-            opt.step()
-        torch.cuda.synchronize()
-        t_end = time.time()
-        dur = (t_end - t_start) / num_iters
-        if use_errors:
-            times['ecc'] = dur
-        else:
-            times[fused] = dur
-
-    atol = 20e-6  # should always be faster, but avoids rare flakiness
-    assert times[True] < times[False] + atol
-    assert times[True] < times['NA'] + atol
-    assert times['ecc'] < times['NA'] + atol
-
-    print('')
-    print('time fused (ms):       ', times[True] * 1e3)
-    print('time fused+ecc (ms):   ', times['ecc'] * 1e3)
-    print('time unfused (ms):     ', times[False] * 1e3)
-    print('time unquantized (ms): ', times['NA'] * 1e3)
+
+    def _time_kernels(N: int, D: int, min_elems_traversed: int):
+        W = torch.randn((N, D), device='cuda', requires_grad=True)
+        W.grad = torch.randn((N, D), device='cuda', requires_grad=False)
+
+        num_iters = int(np.ceil(min_elems_traversed / W.grad.numel()))
+        num_iters = min(100,
+                        num_iters)  # don't take all day when overhead-bound
+
+        times = {}
+        kwargs = {'weight_decay': .01}
+        combos = [(True, False), (True, True), (False, False), ('NA', False)]
+        for fused, use_errors in combos:
+            if fused == 'NA':
+                opt = Lion8bit(
+                    [W], quantize=False,
+                    **kwargs)  # type:ignore (reportGeneralTypeIssues)
+            else:
+                opt = Lion8bit(
+                    [W], _fused=fused, error_correction=use_errors,
+                    **kwargs)  # type:ignore (reportGeneralTypeIssues)
+            for _ in range(3):
+                opt.step()  # warmup iters
+            torch.cuda.synchronize()
+            t_start = time.time()
+            for _ in range(num_iters):
+                opt.step()
+            torch.cuda.synchronize()
+            t_end = time.time()
+            dur = (t_end - t_start) / num_iters
+            if use_errors:
+                times['ecc'] = dur
+            else:
+                times[fused] = dur
+        return times
+
+    times = _time_kernels(N, D, min_elems_traversed)
+
+    atol = 2e-4  # should always be faster, but atol helps avoid flakiness
+    it = 0
+    while True:
+        try:
+            assert times[True] < times[False] + atol
+            assert times[True] < times['NA'] + atol
+            assert times['ecc'] < times['NA'] + atol
+            print('')
+            print('time fused (ms):       ', times[True] * 1e3)
+            print('time fused+ecc (ms):   ', times['ecc'] * 1e3)
+            print('time unfused (ms):     ', times[False] * 1e3)
+            print('time unquantized (ms): ', times['NA'] * 1e3)
+            break
+        except AssertionError as e:
+            if it >= 2:  # allow up to 3 attempts to avoid flakiness
+                raise e
+            times = _time_kernels(N, D, min_elems_traversed)
+            it += 1
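For context, the structural change above is to hoist the benchmark into a helper and wrap the timing assertions in a bounded retry loop, so a single noisy measurement no longer fails the test. A minimal, self-contained sketch of that pattern is below; the names _measure and check_with_retries are hypothetical stand-ins for illustration, not part of the PR, and the workload is a trivial placeholder rather than an optimizer step.

    import time


    def _measure() -> float:
        """Hypothetical stand-in for _time_kernels(): time one unit of work."""
        t0 = time.perf_counter()
        sum(range(10_000))  # trivial placeholder workload
        return time.perf_counter() - t0


    def check_with_retries(max_attempts: int = 3, atol: float = 2e-4) -> None:
        for attempt in range(max_attempts):
            # In the real test these would be fused vs. unfused timings.
            t_a, t_b = _measure(), _measure()
            try:
                # Timing comparisons are inherently noisy; atol absorbs jitter.
                assert t_a < t_b + atol
                return
            except AssertionError:
                if attempt == max_attempts - 1:
                    raise  # out of attempts; surface the failure
                # otherwise fall through and re-measure


    check_with_retries()

Re-measuring on failure (rather than just re-asserting on the same numbers) is the key design choice: the retry only helps if it draws a fresh sample from the noisy timing distribution.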