diff --git a/deepspeed/runtime/zero/compile/stage3_backend.py b/deepspeed/runtime/zero/compile/stage3_backend.py index 52a9bec396a5..808784696ec5 100644 --- a/deepspeed/runtime/zero/compile/stage3_backend.py +++ b/deepspeed/runtime/zero/compile/stage3_backend.py @@ -237,10 +237,8 @@ def fw(gm, sample_inputs): gm.recompile() if debug_log: - nz3.start_forward() mem_prof = MemoryProfilingInterpreter(gm) mem_prof.run(*real_inputs) - nz3.end_forward() if rank == 0: mem_prof.dump(f"mem_prof_fwd_{graph_id}.csv") @@ -310,10 +308,8 @@ def bw(gm, sample_inputs): gm.recompile() if debug_log: - nz3.start_backward(True) mem_prof = MemoryProfilingInterpreter(gm) mem_prof.run(*validated_inputs) - nz3.end_backward() if rank == 0: mem_prof.dump(f"mem_prof_bwd_{graph_id}.csv")