Fully Sharded Data Parallel #489

Merged: 33 commits, Apr 5, 2024
Commits (33)
c33019c
fsdp first draft
ncassereau Oct 6, 2023
12908f8
change how scaler works
ncassereau Oct 6, 2023
871d04f
Add forgotten change
ncassereau Oct 6, 2023
21e0ab8
Merge branch 'dev' into fsdp
ncassereau Oct 6, 2023
03ada6d
pep8
ncassereau Oct 6, 2023
0f9d096
Provide fsdp argument to all DDP instanciation
ncassereau Oct 6, 2023
b0f232a
working even if fsdp is unavailable
ncassereau Oct 6, 2023
91ecc13
Linter
ncassereau Oct 6, 2023
6317e17
typing for archaic python version
ncassereau Oct 6, 2023
565127e
stupid mistake
ncassereau Oct 6, 2023
422bbd2
Improved monkeypatching. Now allows source inspection
ncassereau Oct 6, 2023
208b7e1
Temporarily always activate FSDP to test on github actions (should be…
ncassereau Oct 6, 2023
4e412db
Temporarily always activate FSDP to test on github actions (should be…
ncassereau Oct 6, 2023
972a988
Temporarily always activate FSDP to test on github actions (should be…
ncassereau Oct 6, 2023
292a5ff
Restore standard code and add warning for FSDP before pytorch 2.0
ncassereau Oct 6, 2023
0aa529f
Now using FSDP MixedPrecision instead of autocast
ncassereau Oct 10, 2023
ef460ad
Merge branch 'dev' into fsdp
ncassereau Nov 30, 2023
f97627e
Merge branch 'dev' into fsdp
ncassereau Dec 12, 2023
381f636
Solved an issue where the handlers would get duplicated
ncassereau Dec 13, 2023
eea7368
Correct wrong git merge
ncassereau Dec 13, 2023
1fc0a16
Compute correct validation ?
ncassereau Dec 14, 2023
8ac8260
Solve an issue where not all GPU would try a collective communication…
ncassereau Dec 14, 2023
497afe0
Add eval mode to inference functions
ncassereau Dec 14, 2023
0a09928
Better warning
ncassereau Dec 14, 2023
365585f
Useless import removed
ncassereau Dec 14, 2023
9ab6741
FSDP auto wrap
ncassereau Dec 14, 2023
cbf0d2a
linter
ncassereau Dec 15, 2023
a1b5207
fake commit
ncassereau Jan 9, 2024
9edd42e
Merge branch 'dev' into fsdp
camillebrianceau Feb 1, 2024
552cfd6
Last changes
ncassereau Apr 1, 2024
35bd1a7
Merge branch 'dev' into fsdp
ncassereau Apr 1, 2024
57615d3
linter
ncassereau Apr 2, 2024
793f6f5
Merge branch 'dev' into cb_fsdp
camillebrianceau Apr 5, 2024
1 change: 1 addition & 0 deletions clinicadl/utils/logger.py
@@ -56,6 +56,7 @@ def setup_logging(verbose: bool = False) -> None:
err_handler.addFilter(StdLevelFilter(err=True))
err_handler.setFormatter(err_formatter)

logger.handlers = []
logger.addHandler(console_handler)
logger.addHandler(err_handler)

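The single added line, logger.handlers = [], matches the commit message "Solved an issue where the handlers would get duplicated": calling setup_logging more than once used to stack additional handlers and emit every message several times. A minimal sketch of the pattern, with a hypothetical logger name and trimmed handler setup standing in for the code that is not shown in this hunk:

import logging

def setup_logging(verbose: bool = False) -> None:
    # Hypothetical, trimmed version of clinicadl's setup_logging, only to
    # illustrate the fix: reset the handler list before attaching new ones.
    logger = logging.getLogger("clinicadl")
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))

    logger.handlers = []  # drop handlers left over from any previous call
    logger.addHandler(console_handler)

setup_logging()
setup_logging(verbose=True)  # without the reset, the next message would print twice
logging.getLogger("clinicadl").info("logged once, not twice")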
138 changes: 132 additions & 6 deletions clinicadl/utils/maps_manager/ddp.py
@@ -1,17 +1,40 @@
import inspect
import linecache
import logging
from dataclasses import dataclass
from functools import partial
from logging import Logger
from textwrap import dedent
from types import CodeType, FunctionType, MethodType
from typing import Any, Optional, Set
from typing import Any, Optional, Set, Union
from uuid import uuid4

import torch
import torch.distributed as dist
from packaging.version import Version
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer

try:
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
except ImportError:
fsdp_available = False
else:
fsdp_available = True

from . import cluster

logger = logging.getLogger("DDP")


@dataclass
class Methods:
@@ -104,7 +127,8 @@ def monkeypatch(model: Module) -> None:
monkeypatched_code = dedent(
source_code.replace("self.forward", "self._forward")
)
compiled_code = compile(monkeypatched_code, "<string>", "exec")
filename = f"<dynamic-{int(uuid4())}>"
compiled_code = compile(monkeypatched_code, filename, "exec")

# If the function has default arguments, then the code of the function
# will not be the first constant in the defined code but will be after
@@ -115,6 +139,16 @@ def monkeypatch(model: Module) -> None:
else:
raise ValueError("Expected to find code object, did not find any.")

# Store the patched code source in the cache so that it can be retrieved
# later on by inspect.getsource. Otherwise, inspect would not be
# able to get source code from dynamically generated functions.
linecache.cache[filename] = (
len(monkeypatched_code),
None,
[line + "\n" for line in monkeypatched_code.splitlines()],
filename,
)

# Convert code to a method bound to the given model.
function = FunctionType(
code=const,
@@ -130,11 +164,59 @@ def monkeypatch(model: Module) -> None:
model.forward = MethodType(forward, model)


class DDP(DistributedDataParallel):
def __init__(self, model: Module, *args, **kwargs):
monkeypatch(model)
super().__init__(model, *args, **kwargs)
if fsdp_available:

class FSDP(FullyShardedDataParallel):
def __init__(self, model: Module, amp: bool = False):
sharding_strategy = ShardingStrategy.FULL_SHARD
if amp:
mixed_precision = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
keep_low_precision_grads=False,
)
else:
mixed_precision = None

super().__init__(
model,
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
cpu_offload=None,
auto_wrap_policy=partial(
size_based_auto_wrap_policy,
min_num_params=1,
),
)
self.set_state_dict_type(
self,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
)

def transfer_weights(self, *args, **kwargs):
raise RuntimeError("Please transfer weights before converting to FSDP.")

def optim_state_dict(self, optimizer: Optimizer):
return super().optim_state_dict(self, optimizer)

def load_optim_state_dict(self, optimizer: Optimizer, state_dict: dict):
optim_state_dict = self.optim_state_dict_to_load(
optim_state_dict=state_dict,
model=self,
optim=optimizer,
)
optimizer.load_state_dict(optim_state_dict)

else:

class FSDP(object):
pass


class ClinicaDDP(DistributedDataParallel):
def _forward(self, *args, **kwargs):
return self.module._forward(*args, **kwargs)

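The uuid4 and linecache.cache additions in the monkeypatch hunks above (commit "Improved monkeypatching. Now allows source inspection") are what keep inspect.getsource usable on the dynamically recompiled forward. A standalone sketch of that mechanism, independent of the clinicadl classes and using only the standard library, under the assumption that the function is rebuilt from a plain source string:

import inspect
import linecache
from types import FunctionType
from uuid import uuid4

source = "def greet(name):\n    return f'hello {name}'\n"

# Compile under a unique pseudo-filename, as the patched monkeypatch() does.
filename = f"<dynamic-{int(uuid4())}>"
code = compile(source, filename, "exec")
func_code = next(c for c in code.co_consts if inspect.iscode(c))
greet = FunctionType(func_code, globals(), "greet")

# Register the source in linecache so inspect can find it again later.
linecache.cache[filename] = (
    len(source),
    None,
    [line + "\n" for line in source.splitlines()],
    filename,
)

print(greet("world"))            # hello world
print(inspect.getsource(greet))  # works only because of the linecache entry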
@@ -147,6 +229,50 @@ def transfer_weights(self, *args, **kwargs):
def state_dict(self):
return self.module.state_dict()

def optim_state_dict(self, optimizer: Optimizer):
return optimizer.state_dict()

def load_optim_state_dict(self, optimizer: Optimizer, state_dict: dict):
optimizer.load_state_dict(state_dict)


class DDP:
def __new__(
cls, model: Module, fsdp: bool = False, amp: bool = False
) -> Union[ClinicaDDP, FSDP]:
monkeypatch(model)

if fsdp:
if Version(torch.__version__) < Version("2.0.0"):
logger.warning(
"We do not support FullyShardedDataParallel before Pytorch 2."
" Falling back to standard distributed data parallelism."
)
return ClinicaDDP(model)

if fsdp_available:
return FSDP(model, amp=amp)
else:
logger.warning(
"FSDP is not available on your system, falling back "
"to standard distributed data parallelism."
)
return ClinicaDDP(model)
else:
return ClinicaDDP(model)

def optim_state_dict(self, optimizer: Optimizer):
...

def state_dict(self):
...

def load_state_dict(self, state_dict: dict):
...

def load_optim_state_dict(self, optimizer: Optimizer, state_dict: dict):
...


def get_backend(gpu: bool = False) -> str:
if gpu and dist.is_nccl_available():
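The DDP.__new__ factory at the end of this file keeps the call site the same whether ClinicaDDP or the FSDP subclass is returned, since both wrappers expose the same state_dict, optim_state_dict and load_optim_state_dict methods. A hypothetical usage sketch, assuming a single-process gloo group and a toy network; the import path comes from this diff, but the surrounding setup is illustrative and in clinicadl it is normally driven by the cluster module:

import os

import torch
import torch.distributed as dist

from clinicadl.utils.maps_manager.ddp import DDP

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 2)

    def forward(self, x):
        return self.fc(x)

# Illustrative single-process group for the sketch only.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# fsdp=True needs PyTorch >= 2.0 and a working FSDP import; otherwise the
# factory logs a warning and returns a ClinicaDDP wrapper instead.
model = DDP(TinyNet(), fsdp=torch.cuda.is_available(), amp=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Both wrapper types answer the same checkpointing calls.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": model.optim_state_dict(optimizer),
}
model.load_optim_state_dict(optimizer, checkpoint["optimizer"])
dist.destroy_process_group()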