From 11eab11d2605f776b5b407fdb40be9a6dd042224 Mon Sep 17 00:00:00 2001 From: Jose Javier <26491792+josejg@users.noreply.github.com> Date: Thu, 15 Aug 2024 09:45:05 -0700 Subject: [PATCH] isort --- llmfoundry/callbacks/env_logging_callback.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/llmfoundry/callbacks/env_logging_callback.py b/llmfoundry/callbacks/env_logging_callback.py index 5c91e9be5a..1c55fb15e8 100644 --- a/llmfoundry/callbacks/env_logging_callback.py +++ b/llmfoundry/callbacks/env_logging_callback.py @@ -13,6 +13,7 @@ from composer.core import Callback, State from composer.loggers import Logger from composer.utils import dist + from mcli import sdk __all__ = ['EnvironmentLoggingCallback'] @@ -29,6 +30,7 @@ 'peft', ] + class EnvironmentLoggingCallback(Callback): """A callback for logging environment information during model training. @@ -118,8 +120,10 @@ def _get_nvidia_info(self) -> dict[str, Any]: if torch.cuda.is_available(): nccl_version = torch.cuda.nccl.version() # type: ignore return { - 'cuda_version': torch.version.cuda, # type: ignore[attr-defined] - 'cudnn_version': str(torch.backends.cudnn.version()), # type: ignore[attr-defined] + 'cuda_version': + torch.version.cuda, # type: ignore[attr-defined] + 'cudnn_version': str(torch.backends.cudnn.version() + ), # type: ignore[attr-defined] 'nccl_version': '.'.join(map(str, nccl_version)), } return {'available': False} @@ -144,7 +148,8 @@ def fit_start(self, state: State, logger: Logger) -> None: # Collect environment data if self.log_git: self.env_data['git_info'] = { - folder: self._get_git_info(os.path.join(self.workspace_dir, folder)) + folder: + self._get_git_info(os.path.join(self.workspace_dir, folder)) for folder in os.listdir(self.workspace_dir) if os.path.isdir(os.path.join(self.workspace_dir, folder)) }