Skip to content

Commit

Permalink
Merge branch 'main' into convert_mds_script
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Sep 15, 2023
2 parents 6342902 + 7ec2fe0 commit 1fa9797
Showing 1 changed file with 56 additions and 42 deletions.
98 changes: 56 additions & 42 deletions tests/test_lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,45 +504,59 @@ def _set_state_dict_type(model: nn.Module):
def test_fused_as_fast_as_unfused(N: int,
D: int,
min_elems_traversed: int = 1000000):
W = torch.randn((N, D), device='cuda', requires_grad=True)
W.grad = torch.randn((N, D), device='cuda', requires_grad=False)

num_iters = int(np.ceil(min_elems_traversed / W.grad.numel()))
num_iters = min(100, num_iters) # don't take all day when overhead-bound

times = {}
kwargs = {'weight_decay': .01}
combos = [(True, False), (True, True), (False, False), ('NA', False)]
for fused, use_errors in combos:
if fused == 'NA':
opt = Lion8bit([W], quantize=False,
**kwargs) # type:ignore (reportGeneralTypeIssues)
else:
opt = Lion8bit([W],
_fused=fused,
error_correction=use_errors,
**kwargs) # type:ignore (reportGeneralTypeIssues)
for _ in range(3):
opt.step() # warmup iters
torch.cuda.synchronize()
t_start = time.time()
for _ in range(num_iters):
opt.step()
torch.cuda.synchronize()
t_end = time.time()
dur = (t_end - t_start) / num_iters
if use_errors:
times['ecc'] = dur
else:
times[fused] = dur

atol = 20e-6 # should always be faster, but avoids rare flakiness
assert times[True] < times[False] + atol
assert times[True] < times['NA'] + atol
assert times['ecc'] < times['NA'] + atol

print('')
print('time fused (ms): ', times[True] * 1e3)
print('time fused+ecc (ms): ', times['ecc'] * 1e3)
print('time unfused (ms): ', times[False] * 1e3)
print('time unquantized (ms): ', times['NA'] * 1e3)

def _time_kernels(N: int, D: int, min_elems_traversed: int):
W = torch.randn((N, D), device='cuda', requires_grad=True)
W.grad = torch.randn((N, D), device='cuda', requires_grad=False)

num_iters = int(np.ceil(min_elems_traversed / W.grad.numel()))
num_iters = min(100,
num_iters) # don't take all day when overhead-bound

times = {}
kwargs = {'weight_decay': .01}
combos = [(True, False), (True, True), (False, False), ('NA', False)]
for fused, use_errors in combos:
if fused == 'NA':
opt = Lion8bit(
[W], quantize=False,
**kwargs) # type:ignore (reportGeneralTypeIssues)
else:
opt = Lion8bit(
[W], _fused=fused, error_correction=use_errors,
**kwargs) # type:ignore (reportGeneralTypeIssues)
for _ in range(3):
opt.step() # warmup iters
torch.cuda.synchronize()
t_start = time.time()
for _ in range(num_iters):
opt.step()
torch.cuda.synchronize()
t_end = time.time()
dur = (t_end - t_start) / num_iters
if use_errors:
times['ecc'] = dur
else:
times[fused] = dur
return times

times = _time_kernels(N, D, min_elems_traversed)

atol = 2e-4 # should always be faster, but atol helps avoid flakiness
it = 0
while True:
try:
assert times[True] < times[False] + atol
assert times[True] < times['NA'] + atol
assert times['ecc'] < times['NA'] + atol
print('')
print('time fused (ms): ', times[True] * 1e3)
print('time fused+ecc (ms): ', times['ecc'] * 1e3)
print('time unfused (ms): ', times[False] * 1e3)
print('time unquantized (ms): ', times['NA'] * 1e3)
break
except AssertionError as e:
if it >= 2: # allow 3 retries to avoid flakiness
raise e
times = _time_kernels(N, D, min_elems_traversed)
it += 1

0 comments on commit 1fa9797

Please sign in to comment.