Skip to content

Commit

Permalink
isort
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg committed Aug 15, 2024
1 parent 72d100b commit 11eab11
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions llmfoundry/callbacks/env_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

from mcli import sdk

__all__ = ['EnvironmentLoggingCallback']
Expand All @@ -29,6 +30,7 @@
'peft',
]


class EnvironmentLoggingCallback(Callback):
"""A callback for logging environment information during model training.
Expand Down Expand Up @@ -118,8 +120,10 @@ def _get_nvidia_info(self) -> dict[str, Any]:
if torch.cuda.is_available():
nccl_version = torch.cuda.nccl.version() # type: ignore
return {
'cuda_version': torch.version.cuda, # type: ignore[attr-defined]
'cudnn_version': str(torch.backends.cudnn.version()), # type: ignore[attr-defined]
'cuda_version':
torch.version.cuda, # type: ignore[attr-defined]
'cudnn_version': str(torch.backends.cudnn.version()
), # type: ignore[attr-defined]
'nccl_version': '.'.join(map(str, nccl_version)),
}
return {'available': False}
Expand All @@ -144,7 +148,8 @@ def fit_start(self, state: State, logger: Logger) -> None:
# Collect environment data
if self.log_git:
self.env_data['git_info'] = {
folder: self._get_git_info(os.path.join(self.workspace_dir, folder))
folder:
self._get_git_info(os.path.join(self.workspace_dir, folder))
for folder in os.listdir(self.workspace_dir)
if os.path.isdir(os.path.join(self.workspace_dir, folder))
}
Expand Down

0 comments on commit 11eab11

Please sign in to comment.