-
Notifications
You must be signed in to change notification settings - Fork 536
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add EnvironmentLogger Callback (#1350)
- Loading branch information
Showing
3 changed files
with
191 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import platform | ||
import socket | ||
from typing import Any, Optional | ||
|
||
import git | ||
import pkg_resources | ||
import psutil | ||
import torch | ||
from composer.core import Callback, State | ||
from composer.loggers import Logger | ||
from composer.utils import dist | ||
|
||
from mcli import sdk | ||
|
||
# Public API of this module.
__all__ = ['EnvironmentLoggingCallback']

# Distributions whose installed versions are reported when the caller does not
# pass an explicit ``packages_to_log`` list to ``EnvironmentLoggingCallback``.
# Order is preserved in the logged ``package_versions`` mapping.
_PACKAGES_TO_LOG: list[str] = [
    'llm-foundry', 'mosaicml', 'megablocks',
    'grouped-gemm', 'torch', 'flash_attn',
    'transformers', 'datasets', 'peft',
]
|
||
|
||
class EnvironmentLoggingCallback(Callback):
    """A callback for logging environment information during model training.

    This callback collects various pieces of information about the training
    environment, including git repository details, package versions, system
    information, GPU details, distributed training setup, NVIDIA driver
    information, and Docker container details.

    The collected information is logged as hyperparameters, under the
    ``environment_data`` key, at the start of model fitting.

    Args:
        workspace_dir (str): The directory whose immediate sub-directories are
            scanned for git repositories. Defaults to '/workspace'. A missing
            directory is tolerated and yields empty git info.
        log_git (bool): Whether to log git repository information.
            Defaults to True.
        log_nvidia (bool): Whether to log NVIDIA driver information.
            Defaults to True.
        log_docker (bool): Whether to log Docker container information.
            Defaults to True.
        log_packages (bool): Whether to log package versions.
            Defaults to True.
        log_system (bool): Whether to log system information.
            Defaults to False.
        log_gpu (bool): Whether to log GPU information. Defaults to False.
        log_distributed (bool): Whether to log distributed training
            information. Defaults to False.
        packages_to_log (list[str]): A list of package names to log versions
            for. Defaults to None, in which case ``_PACKAGES_TO_LOG`` is used.
    """

    def __init__(
        self,
        workspace_dir: str = '/workspace',
        log_git: bool = True,
        log_nvidia: bool = True,
        log_docker: bool = True,
        log_packages: bool = True,
        log_system: bool = False,
        log_gpu: bool = False,
        log_distributed: bool = False,
        packages_to_log: Optional[list[str]] = None,
    ):
        self.workspace_dir = workspace_dir
        self.log_git = log_git
        self.log_packages = log_packages
        self.log_nvidia = log_nvidia
        self.log_docker = log_docker
        self.log_system = log_system
        self.log_gpu = log_gpu
        self.log_distributed = log_distributed
        # Filled in by ``fit_start``; kept on the instance so callers can
        # inspect what was logged.
        self.env_data: dict[str, Any] = {}
        # An empty list also falls back to the module default.
        self.packages_to_log = packages_to_log or _PACKAGES_TO_LOG

    def _get_git_info(self, repo_path: str) -> Optional[dict[str, str]]:
        """Return commit hash and branch for the repository at ``repo_path``.

        Returns:
            A dict with ``commit_hash`` and ``branch`` keys, or None when
            ``repo_path`` is not a directory, is not a git repository, or has
            no commits yet. For a detached HEAD the branch is reported as
            ``'HEAD (detached)'``.
        """
        if not os.path.isdir(repo_path):
            return None
        try:
            # The context manager releases the handles GitPython keeps open.
            with git.Repo(repo_path) as repo:
                head = repo.head
                # ``active_branch`` raises TypeError on a detached HEAD, which
                # is the normal state for CI / pinned-commit checkouts.
                if head.is_detached:
                    branch = 'HEAD (detached)'
                else:
                    branch = repo.active_branch.name
                return {
                    'commit_hash': head.commit.hexsha,
                    'branch': branch,
                }
        except (
            git.InvalidGitRepositoryError,
            git.NoSuchPathError,
            ValueError,  # Raised when HEAD has no commit (empty repository).
        ):
            return None

    def _get_package_version(self, package_name: str) -> Optional[str]:
        """Return the installed version of ``package_name``, or None."""
        try:
            return pkg_resources.get_distribution(package_name).version
        except pkg_resources.DistributionNotFound:
            return None

    def _get_system_info(self) -> dict[str, Any]:
        """Return Python/OS/host details plus CPU and memory figures."""
        # Take a single snapshot so ``total`` and ``available`` are consistent.
        memory = psutil.virtual_memory()
        return {
            'python_version': platform.python_version(),
            'os': f'{platform.system()} {platform.release()}',
            'hostname': socket.gethostname(),
            'cpu_info': {
                'model': platform.processor(),
                'cores': psutil.cpu_count(logical=False),
                'threads': psutil.cpu_count(logical=True),
            },
            'memory': {
                'total': memory.total,
                'available': memory.available,
            },
        }

    def _get_gpu_info(self) -> dict[str, Any]:
        """Return details of CUDA device 0, or ``{'available': False}``."""
        if not torch.cuda.is_available():
            return {'available': False}
        return {
            'model': torch.cuda.get_device_name(0),
            'count': torch.cuda.device_count(),
            'memory': {
                'total': torch.cuda.get_device_properties(0).total_memory,
                'allocated': torch.cuda.memory_allocated(0),
            },
        }

    def _get_nvidia_info(self) -> dict[str, Any]:
        """Return CUDA, cuDNN and NCCL versions, or ``{'available': False}``."""
        if not torch.cuda.is_available():
            return {'available': False}
        nccl_version = torch.cuda.nccl.version()  # type: ignore
        return {
            'cuda_version':
                torch.version.cuda,  # type: ignore[attr-defined]
            'cudnn_version':
                str(torch.backends.cudnn.version()),  # type: ignore[attr-defined]
            'nccl_version': '.'.join(map(str, nccl_version)),
        }

    def _get_distributed_info(self) -> dict[str, Any]:
        """Return world sizes and ranks as reported by ``composer.utils.dist``."""
        return {
            'world_size': dist.get_world_size(),
            'local_world_size': dist.get_local_world_size(),
            'rank': dist.get_global_rank(),
            'local_rank': dist.get_local_rank(),
        }

    @staticmethod
    def _split_image_reference(image_ref: str) -> tuple[str, Optional[str]]:
        """Split ``name[:tag]`` into ``(name, tag)``; tag is None if absent.

        Only a colon after the last ``/`` separates a tag, so a registry port
        (``host:5000/img``) is not mistaken for one.
        """
        name, sep, tag = image_ref.rpartition(':')
        if not sep or '/' in tag:
            return image_ref, None
        return name, tag

    def _get_docker_info(self) -> Optional[dict[str, Any]]:
        """Return the image name and tag of the current run.

        Returns None when the ``RUN_NAME`` environment variable is not set.
        Otherwise the run is looked up through ``mcli.sdk.get_run``.
        """
        if 'RUN_NAME' not in os.environ:
            return None
        run = sdk.get_run(os.environ['RUN_NAME'])
        image, tag = self._split_image_reference(run.image)
        return {
            'image': image,
            'tag': tag,
        }

    def fit_start(self, state: State, logger: Logger) -> None:
        """Collect the enabled environment sections and log them once."""
        # Collect environment data
        if self.log_git:
            git_info: dict[str, Any] = {}
            # The default workspace only exists inside the training image;
            # a missing directory must not abort the run.
            if os.path.isdir(self.workspace_dir):
                for folder in os.listdir(self.workspace_dir):
                    repo_info = self._get_git_info(
                        os.path.join(self.workspace_dir, folder),
                    )
                    if repo_info:
                        git_info[folder] = repo_info
            self.env_data['git_info'] = git_info

        if self.log_packages:
            self.env_data['package_versions'] = {
                pkg: self._get_package_version(pkg)
                for pkg in self.packages_to_log
            }

        if self.log_nvidia:
            self.env_data['nvidia'] = self._get_nvidia_info()

        if self.log_docker:
            if docker_info := self._get_docker_info():
                self.env_data['docker'] = docker_info

        if self.log_system:
            self.env_data['system_info'] = self._get_system_info()

        if self.log_gpu:
            self.env_data['gpu_info'] = self._get_gpu_info()

        if self.log_distributed:
            self.env_data['distributed_info'] = self._get_distributed_info()

        # Log the collected data
        logger.log_hyperparameters({'environment_data': self.env_data})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,6 +73,7 @@ | |
'tenacity>=8.2.3,<9', | ||
'catalogue>=2,<3', | ||
'typer<1', | ||
'GitPython==3.1.43', | ||
] | ||
|
||
extra_deps = {} | ||
|