Distributed Data Parallelism #402

Merged: 128 commits, Sep 21, 2023

Commits
2f0260a
new amp flag
ncassereau Mar 20, 2023
1ada463
use amp during training
ncassereau Mar 20, 2023
86afa02
zero_grad set to none
ncassereau Mar 20, 2023
c4177a6
linter is bullying me
ncassereau Mar 20, 2023
0e4ef7c
sort imports
ncassereau Mar 20, 2023
025bc94
add ddp flag
ncassereau Mar 20, 2023
59baef3
new logger filter in order not to pollute stdout
ncassereau Mar 20, 2023
a78a275
Cluster resolver and DDP manager
ncassereau Mar 20, 2023
193c090
New sampler and test code
ncassereau Mar 20, 2023
0c9852b
Merge branch 'aramis-lab:dev' into amp
ncassereau Mar 23, 2023
7e01a74
Merge branch 'aramis-lab:dev' into ddp
ncassereau Mar 24, 2023
9b559b1
add resolver flag
ncassereau Mar 24, 2023
064b90c
update forward function such that DDP will call forward instead of co…
ncassereau Mar 24, 2023
ca07222
change doc such that it matches previous commit changes'
ncassereau Mar 24, 2023
13faf08
First batch of DDP-converted methods
ncassereau Mar 24, 2023
fef8c41
Zero Redundancy Optimizer
ncassereau Mar 24, 2023
43053d1
debugged port for mono task resolver in slurmless environments
ncassereau Mar 24, 2023
9c6ef6a
linter
ncassereau Mar 24, 2023
7d01e35
black linter
ncassereau Mar 24, 2023
44e959b
Merge branch 'aramis-lab:dev' into amp
ncassereau Mar 24, 2023
eb5bd6c
added forgotten ddp in test_loader method
ncassereau Mar 24, 2023
e0ad3f4
updated predict and interpret to work with AMP
ncassereau May 5, 2023
85fd15c
update cli for predict & interpret
ncassereau May 5, 2023
4a522b7
solve conflict with profiler
ncassereau May 5, 2023
1e2897f
satisfy linter god
ncassereau May 5, 2023
024f992
Cb issues tsvtools (#422)
camillebrianceau May 15, 2023
3c72f0e
fix json option bug (#423)
camillebrianceau May 15, 2023
e66d794
change output directory for `tsvtools get-labels` (#415)
camillebrianceau May 15, 2023
e2c2345
add caps_directory option in get-labels (#416)
camillebrianceau May 15, 2023
0134a33
Fix missing mods parsing (#424)
14thibea May 15, 2023
0b083e8
Update CHANGELOG and pyproject.toml for release (#425)
camillebrianceau May 15, 2023
e229385
Cb rh thesis2 (#420)
camillebrianceau May 24, 2023
3a3a2de
ensure image_path is a Path (#428)
14thibea May 24, 2023
3b5f154
Clean code (#427)
camillebrianceau May 24, 2023
1c8088e
Filter diagnoses in split manager (#429)
14thibea May 24, 2023
ed742fc
Bump pymdown-extensions from 9.10 to 10.0 (#426)
dependabot[bot] May 24, 2023
bb65bb3
Bump requests from 2.28.2 to 2.31.0 (#430)
dependabot[bot] May 24, 2023
2298867
Cb ssim 3 d (#433)
camillebrianceau May 26, 2023
a118721
Add joblib to parallelize commands (#399)
camillebrianceau May 31, 2023
08dec2c
add __init__.py when missing (#434)
camillebrianceau Jun 8, 2023
f0e77ca
Data augmentation with torchio (#417)
sophieloiz Jun 8, 2023
5534254
Fix unexpected keyword argument 'split' in issue #438 (#439)
camillebrianceau Jun 8, 2023
05b524c
Add generate motion using torchio (#419)
sophieloiz Jun 9, 2023
746ba03
Fix merged_tsv option in get_labels command (#437)
camillebrianceau Jun 9, 2023
b977d49
change is None to not (#444)
camillebrianceau Jun 9, 2023
24e795e
prepare release (#443)
camillebrianceau Jun 9, 2023
c979c49
Update main.html
camillebrianceau Jun 9, 2023
f6d65bb
[INFRA TEST] Add unit tests to ClinicaDL (#446)
NicolasGensollen Jun 15, 2023
a5a2ca9
Create CITATION.cff (#451)
camillebrianceau Jun 15, 2023
ce82ce4
Merge branch 'dev' into amp
camillebrianceau Jun 15, 2023
125f341
map to amp
camillebrianceau Jun 15, 2023
22de9b7
solve AMP conflict with recent commit
ncassereau Jul 20, 2023
0c5f741
linter
ncassereau Jul 20, 2023
e27eabe
more consistent use of AMP
ncassereau Jul 20, 2023
5ce6abf
a little doc
ncassereau Jul 20, 2023
eccb345
Merge branch 'aramis-lab:dev' into amp
ncassereau Jul 27, 2023
710e237
Merge branch dev into branch ddp and solve conflicts
ncassereau Jul 27, 2023
222b45b
Merge branch 'dev' into ddp
ncassereau Jul 27, 2023
0a57a21
linter god
ncassereau Jul 27, 2023
785f822
solve bug with conflict solve
ncassereau Jul 27, 2023
2d9edc7
new cluster resolver
ncassereau Jul 27, 2023
9cb3bd6
remove useless flags
ncassereau Jul 27, 2023
d401142
typing and cpus per task for default api
ncassereau Jul 27, 2023
e5a8511
linter god
ncassereau Jul 27, 2023
92d8e70
pls work
ncassereau Jul 27, 2023
90a3477
Merge branch 'amp' into ddp
ncassereau Jul 27, 2023
4374e31
pls linter god, be nice 2 me
ncassereau Jul 27, 2023
bce605e
pls linter god, i'm a nice person
ncassereau Jul 27, 2023
0383c09
why r u so mean linter god ?
ncassereau Jul 28, 2023
2fe7987
solve no-gpu dist bug
ncassereau Jul 28, 2023
eb87f1b
solve gpu and no nvidiadriver bug
ncassereau Jul 28, 2023
5518926
remove device_ids in DDP class in case there is no gpu
ncassereau Jul 28, 2023
c080d80
solve reduce for predict
ncassereau Jul 28, 2023
dbe3ad2
linter god
ncassereau Jul 28, 2023
03ed55d
add no_sync step to use gradient accumulation with DDP
ncassereau Jul 28, 2023
b517e4d
typing
ncassereau Jul 31, 2023
dad2ab2
typing
ncassereau Jul 31, 2023
5123bfe
Monkeypatch classes so that users can still use the forward method to…
ncassereau Jul 31, 2023
1b70257
solve compile bug in case a default argument is given
ncassereau Jul 31, 2023
27d0e13
solve namespace bug for monkeypatching
ncassereau Jul 31, 2023
23e86e9
Abort patching if _forward method already present. Abort patching if …
ncassereau Jul 31, 2023
fb456f2
Docstrings and comments
ncassereau Jul 31, 2023
b2155bb
raise error if AMP is enabled but GPU usage is not
ncassereau Aug 4, 2023
3bbf14d
add amp to QC
ncassereau Aug 4, 2023
5db02d4
remove default option so that it correctly fetches the default from t…
ncassereau Aug 4, 2023
6b7fe82
Docs for AMP
ncassereau Aug 4, 2023
6bb226e
Add Docs for profiler
ncassereau Aug 4, 2023
7a26ab2
Merge branch 'amp' into ddp
ncassereau Aug 4, 2023
676b50b
remove ZERO flag default so that train_config.toml is taken into account
ncassereau Aug 4, 2023
e368d9d
Docs for ZeRO
ncassereau Aug 4, 2023
fd6c85b
Docs for Distributed Data Parallelism
ncassereau Aug 4, 2023
dea287e
better free port search (shamelessly stolen from lightning fabric :p
ncassereau Aug 4, 2023
9b5643f
remove unused attribute for default cluster api
ncassereau Aug 4, 2023
4d9c4c3
Merge branch 'dev' into ddp
camillebrianceau Aug 7, 2023
eadfa2e
resolve conflict pb
camillebrianceau Aug 7, 2023
0d07319
remove async op
ncassereau Aug 7, 2023
93387fe
solve failed merge between branch AMP and branch DDP
ncassereau Aug 7, 2023
0063070
solve Sampler bug in case of multi-network
ncassereau Aug 7, 2023
73e398f
correct number of samples in case of DDP with weighted sampler
ncassereau Aug 7, 2023
fda196c
linter
ncassereau Aug 7, 2023
c83e10a
Merge branch 'dev' into ddp
ncassereau Aug 9, 2023
b00dff2
Improved Cluster module interface.
ncassereau Aug 10, 2023
1c42694
linter
ncassereau Aug 10, 2023
d353c97
remove useless stuff
ncassereau Aug 10, 2023
bf4a365
remove no longer needed function
ncassereau Aug 10, 2023
b08057c
corrected stack level warning
ncassereau Aug 16, 2023
df8ff79
Add a small patch for kineto for pytorch > 1.11
ncassereau Aug 16, 2023
6a103a8
Merge branch 'dev' into ddp
ncassereau Sep 7, 2023
b5b429f
solved issue where older versions of pytorch could not compare their …
ncassereau Sep 8, 2023
2c4f3b7
remove needlessly duplicated line
ncassereau Sep 8, 2023
029daa2
corrected some typing mistakes
ncassereau Sep 8, 2023
bbfd835
Remove useless imports
ncassereau Sep 8, 2023
8cab7a9
ensemble prediction is now only performed on master
ncassereau Sep 8, 2023
eba8ba9
Merge branch 'dev' into ddp
ncassereau Sep 8, 2023
f74c898
Rename world_size as dp_degree in task manager
ncassereau Sep 8, 2023
d14c9c0
Rename ZeRO flag to FSDP
ncassereau Sep 8, 2023
e6c558d
typo
ncassereau Sep 8, 2023
5d73bf7
move alias in correct place with other aliases
ncassereau Sep 8, 2023
a6fe6df
Add docstrings and comments to explicitly explain the way the cluster…
ncassereau Sep 8, 2023
a58ee46
add docstrings to kineto patcher
ncassereau Sep 8, 2023
2b47ca6
set regex as raw string
ncassereau Sep 8, 2023
ad4a3b2
Support typing for python 3.8 or older
ncassereau Sep 8, 2023
0112de1
fix taxk_manager's method generate_sampler's argument name change in …
ncassereau Sep 11, 2023
79fb0b6
Merge branch 'dev' into ddp
camillebrianceau Sep 12, 2023
e77efe7
Update maps_manager.py
camillebrianceau Sep 12, 2023
89fdb71
linter
ncassereau Sep 12, 2023
d2cfd80
Rename fsdp to fullyshardeddataparallel
ncassereau Sep 12, 2023
cad088d
rename flag with underscore to separate words
ncassereau Sep 20, 2023
1 change: 1 addition & 0 deletions clinicadl/resources/config/train_config.toml
@@ -39,6 +39,7 @@ gpu = true
n_proc = 2
batch_size = 8
evaluation_steps = 0
fully_sharded_data_parallel = false
amp = false

[Reproducibility]
1 change: 1 addition & 0 deletions clinicadl/train/tasks/classification_cli.py
@@ -18,6 +18,7 @@
@train_option.n_proc
@train_option.batch_size
@train_option.evaluation_steps
@train_option.fully_sharded_data_parallel
@train_option.amp
# Reproducibility
@train_option.seed
1 change: 1 addition & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
@@ -18,6 +18,7 @@
@train_option.n_proc
@train_option.batch_size
@train_option.evaluation_steps
@train_option.fully_sharded_data_parallel
@train_option.amp
# Reproducibility
@train_option.seed
1 change: 1 addition & 0 deletions clinicadl/train/tasks/regression_cli.py
@@ -18,6 +18,7 @@
@train_option.n_proc
@train_option.batch_size
@train_option.evaluation_steps
@train_option.fully_sharded_data_parallel
@train_option.amp
# Reproducibility
@train_option.seed
1 change: 1 addition & 0 deletions clinicadl/train/tasks/task_utils.py
@@ -45,6 +45,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"dropout",
"epochs",
"evaluation_steps",
"fully_sharded_data_parallel",
"gpu",
"learning_rate",
"multi_cohort",
11 changes: 11 additions & 0 deletions clinicadl/utils/cli_param/train_option.py
@@ -44,6 +44,17 @@
help="Fix the number of iterations to perform before computing an evaluation. Default will only "
"perform one evaluation at the end of each epoch.",
)
fully_sharded_data_parallel = cli_param.option_group.computational_group.option(
"--fully_sharded_data_parallel",
"-fsdp",
type=bool,
is_flag=True,
help="Enables Fully Sharded Data Parallel with Pytorch to save memory at the cost of communications. "
"Currently this only enables ZeRO Stage 1 but will be entirely replaced by FSDP in a later patch, "
"this flag is already set to FSDP to that the zero flag is never actually removed.",
default=False,
)

amp = cli_param.option_group.computational_group.option(
"--amp/--no-amp",
type=bool,
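The help text above says the new flag currently maps to ZeRO Stage 1 (optimizer-state sharding). As an illustrative sketch only, not ClinicaDL's actual wiring (which lives in the training code rather than in this diff), such a flag could be mapped to PyTorch's ZeroRedundancyOptimizer as follows; the function name and the choice of Adam are assumptions.

# Illustrative sketch only: mapping a fully_sharded_data_parallel-style flag to
# ZeRO Stage 1 via PyTorch's ZeroRedundancyOptimizer. A process group must
# already be initialized before step() is called.
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer


def build_optimizer(model: torch.nn.Module, fully_sharded_data_parallel: bool, lr: float = 1e-4):
    if fully_sharded_data_parallel:
        # Optimizer states are sharded across ranks (ZeRO Stage 1): each rank
        # keeps only its shard and broadcasts updated parameters after step().
        return ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
        )
    # Default: every rank keeps a full copy of the optimizer state.
    return torch.optim.Adam(model.parameters(), lr=lr)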
11 changes: 11 additions & 0 deletions clinicadl/utils/maps_manager/cluster/__init__.py
@@ -0,0 +1,11 @@
import sys

# These imports won't be available at runtime, but will help VSCode completion.
from .api import API as API
from .api import AutoMasterAddressPort as AutoMasterAddressPort
from .config import *
from .interface import Interface
from .utils import ClinicaClusterResolverWarning as ClinicaClusterResolverWarning
from .utils import Rank0Filter as Rank0Filter

sys.modules[__name__] = Interface()
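The final line above replaces the cluster package with an Interface instance in sys.modules, so attribute lookups on the package are resolved dynamically at runtime. A minimal, hypothetical sketch of this module-swap pattern (the _Interface class below is a stand-in, not ClinicaDL's Interface):

# mypkg/__init__.py -- hypothetical sketch of the "module replaced by an
# instance" pattern used above; ClinicaDL's real Interface dispatches to the
# detected cluster API instead of returning constants.
import sys


class _Interface:
    @property
    def rank(self) -> int:
        # A real implementation would query the active cluster API here.
        return 0

    @property
    def world_size(self) -> int:
        return 1


# After this assignment, `import mypkg; mypkg.rank` triggers the property above.
sys.modules[__name__] = _Interface()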
13 changes: 13 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/__init__.py
@@ -0,0 +1,13 @@
from .auto_master_addr_port import AutoMasterAddressPort
from .base import API
from .default import DefaultAPI
from .slurm import SlurmAPI
from .torchelastic import TorchElasticAPI

__all__ = [
"API",
"AutoMasterAddressPort",
"DefaultAPI",
"SlurmAPI",
"TorchElasticAPI",
]
46 changes: 46 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/auto_master_addr_port.py
@@ -0,0 +1,46 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
from functools import wraps
from typing import Callable, Type

from ..config import __all__ as all_API_methods
from .base import API

# Defines a class decorator that wraps every API method so that the master
# address and the master port are set, allowing the process group to
# initialize correctly.

env_variables_set: bool = False


def set_master_addr_port_env_variables(func):
# The parameter should be a method of a subclass of the API abstract class.
@wraps(func)
def wrapper(self):
global env_variables_set
if not env_variables_set:
            env_variables_set = True  # must be flipped before setting the variables below to prevent infinite recursion
os.environ["MASTER_ADDR"] = self.master_address()
os.environ["MASTER_PORT"] = str(self.port())
return func(self)

return wrapper


def decorate_methods(cls: Type[API], func_to_apply: Callable) -> Type[API]:
# Decorate all API methods defined in the config file with the given function.
for obj_name in dir(cls):
if obj_name in all_API_methods:
decorated = func_to_apply(getattr(cls, obj_name))
setattr(cls, obj_name, decorated)

return cls


def AutoMasterAddressPort(cls: Type[API]) -> Type[API]:
# When we call a cluster API function for the first time, we set the MASTER_ADDR
# and MASTER_PORT environment variables, so that the Pytorch wrapper
# DistributedDataParallel can set up communication correctly.
return decorate_methods(cls, func_to_apply=set_master_addr_port_env_variables)
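AutoMasterAddressPort therefore rewrites every API method listed in the config module's __all__ so that the first call also exports MASTER_ADDR and MASTER_PORT. Below is a self-contained toy version of the same mechanism, with hypothetical names and independent of ClinicaDL's classes:

# Toy, hypothetical illustration of the decorator above: wrap selected methods
# of a class so that MASTER_ADDR / MASTER_PORT are exported once, on the first
# wrapped call.
import os
from functools import wraps

WRAPPED = {"rank", "master_address", "port"}
_env_set = False


def _set_env_once(func):
    @wraps(func)
    def wrapper(self):
        global _env_set
        if not _env_set:
            _env_set = True  # flipped first to avoid recursion between wrapped methods
            os.environ["MASTER_ADDR"] = self.master_address()
            os.environ["MASTER_PORT"] = str(self.port())
        return func(self)

    return wrapper


def auto_master(cls):
    for name in WRAPPED:
        if hasattr(cls, name):
            setattr(cls, name, _set_env_once(getattr(cls, name)))
    return cls


@auto_master
class LocalAPI:
    def rank(self) -> int:
        return 0

    def master_address(self) -> str:
        return "localhost"

    def port(self) -> int:
        return 29500


api = LocalAPI()
api.rank()  # first wrapped call exports the environment variables
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])  # localhost 29500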
93 changes: 93 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/base.py
@@ -0,0 +1,93 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import List, Union


class API(ABC):
priority: int = 5000
name: str = "AbstractAPI"

@abstractmethod
def is_launcher(self) -> bool:
"""
Detects if the given API is the one used to launch the current job.
"""
raise NotImplementedError()

@abstractmethod
def rank(self) -> int:
"""
Property containing the rank of the process.
"""
raise NotImplementedError()

@abstractmethod
def local_rank(self) -> int:
"""
Property containing the local rank of the process.
"""
raise NotImplementedError()

@abstractmethod
def world_size(self) -> int:
"""
Property containing the number of processes launched.
"""
raise NotImplementedError()

@abstractmethod
def local_world_size(self) -> int:
"""
        Property containing the number of processes launched on each node.
"""
raise NotImplementedError()

@abstractmethod
def num_nodes(self) -> int:
"""
Property containing the number of nodes.
"""
raise NotImplementedError()

@abstractmethod
def cpus(self) -> int:
"""
Property containing the number of CPUs allocated to each process.
"""
raise NotImplementedError()

@abstractmethod
def gpus(self) -> List[str]:
"""
Property containing all GPUs ids.
"""
raise NotImplementedError()

@abstractmethod
def nodelist(self) -> Union[str, List[str]]:
"""
Property containing the list of nodes.
"""
raise NotImplementedError()

@abstractmethod
def master_address(self) -> str:
"""
Property containing the master node.
"""
raise NotImplementedError()

@abstractmethod
def port(self) -> int:
"""
Property containing the port to communicate with the master process.
"""
raise NotImplementedError()

def is_master(self) -> bool:
"""
        Detects whether the current process is the master (i.e. rank 0).
"""
return self.rank() == 0
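Only is_master has a default implementation; every other method must be supplied by a concrete API. As a purely hypothetical example (not the PR's SlurmAPI or TorchElasticAPI), a subclass driven by torchrun-style environment variables might look like this:

# Hypothetical concrete API reading generic torchrun-style environment
# variables; shown only to illustrate how the abstract class is meant to be
# subclassed, not taken from this PR.
import os
from typing import List

from .base import API


class EnvVarAPI(API):
    priority: int = 100
    name: str = "EnvVar"

    def is_launcher(self) -> bool:
        return "RANK" in os.environ and "WORLD_SIZE" in os.environ

    def rank(self) -> int:
        return int(os.environ["RANK"])

    def local_rank(self) -> int:
        return int(os.environ.get("LOCAL_RANK", 0))

    def world_size(self) -> int:
        return int(os.environ["WORLD_SIZE"])

    def local_world_size(self) -> int:
        return int(os.environ.get("LOCAL_WORLD_SIZE", 1))

    def num_nodes(self) -> int:
        return max(1, self.world_size() // self.local_world_size())

    def cpus(self) -> int:
        return len(os.sched_getaffinity(0))

    def gpus(self) -> List[str]:
        devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        return devices.split(",") if devices else []

    def nodelist(self) -> List[str]:
        return [self.master_address()]

    def master_address(self) -> str:
        return os.environ.get("MASTER_ADDR", "localhost")

    def port(self) -> int:
        return int(os.environ.get("MASTER_PORT", 29500))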
65 changes: 65 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/default.py
@@ -0,0 +1,65 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
import socket
from contextlib import closing
from typing import List, Optional

from .auto_master_addr_port import AutoMasterAddressPort
from .base import API


@AutoMasterAddressPort
class DefaultAPI(API):
priority: int = 0
name: str = "Sequential"

def __init__(self):
self.current_port: Optional[int] = None

@staticmethod
def find_available_port() -> int:
"""
Tries to bind to local port until it finds one which is available.
This is used to set the master port environment variable.
"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("localhost", 0))
port = sock.getsockname()[1]
return port

def is_launcher(self) -> bool:
return True

def rank(self) -> int:
return 0

def local_rank(self) -> int:
return 0

def world_size(self) -> int:
return 1

def local_world_size(self) -> int:
return 1

def num_nodes(self) -> int:
return 1

def cpus(self) -> int:
return len(os.sched_getaffinity(0))

def gpus(self) -> List[str]:
return []

def nodelist(self) -> List[str]:
return ["localhost"]

def master_address(self) -> str:
return "localhost"

def port(self) -> int:
if self.current_port is None:
self.current_port = self.find_available_port()
return self.current_port
55 changes: 55 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/slurm.py
@@ -0,0 +1,55 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
from typing import List

from ..utils import get_first_host
from .auto_master_addr_port import AutoMasterAddressPort
from .base import API


@AutoMasterAddressPort
class SlurmAPI(API):
priority: int = 10000
name: str = "Slurm"

def is_launcher(self) -> bool:
return "SLURM_STEP_ID" in os.environ

def rank(self) -> int:
return int(os.environ["SLURM_PROCID"])

def local_rank(self) -> int:
return int(os.environ["SLURM_LOCALID"])

def world_size(self) -> int:
return int(os.environ["SLURM_STEP_NUM_TASKS"])

def local_world_size(self) -> int:
return int(os.environ["SLURM_STEP_TASKS_PER_NODE"])

def num_nodes(self) -> int:
return int(os.environ["SLURM_STEP_NUM_NODES"])

def cpus(self) -> int:
cpu = int(os.environ.get("SLURM_CPUS_PER_TASK", 0))
return cpu or len(os.sched_getaffinity(0))

def gpus(self) -> List[str]:
step_gpus = os.environ.get("SLURM_STEP_GPUS", None)
if step_gpus is not None:
return step_gpus.split(",")
return []

def nodelist(self) -> str:
return os.environ["SLURM_STEP_NODELIST"]

def master_address(self) -> str:
return get_first_host(self.nodelist())

def jobid(self) -> int:
return int(os.environ["SLURM_JOB_ID"])

def port(self) -> int:
return 10000 + self.jobid() % 20000
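These SLURM-derived values are exactly what torch.distributed needs to form a process group. The sketch below shows one typical way to consume them; it is illustrative only, as ClinicaDL's actual DDP setup lives in its DDP manager, which is not part of this diff, and the function name is an assumption.

# Illustrative only: initializing torch.distributed from the same SLURM
# variables SlurmAPI reads above.
import os

import torch
import torch.distributed as dist


def init_distributed_from_slurm() -> None:
    rank = int(os.environ["SLURM_PROCID"])                 # SlurmAPI.rank()
    world_size = int(os.environ["SLURM_STEP_NUM_TASKS"])   # SlurmAPI.world_size()
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    # With the default env:// init method, MASTER_ADDR and MASTER_PORT must be
    # set; in ClinicaDL, AutoMasterAddressPort exports them on the first API call.
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        # One GPU per local rank.
        torch.cuda.set_device(int(os.environ["SLURM_LOCALID"]))  # SlurmAPI.local_rank()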