From bc3cd9d5b1ae520acaf1e6c00df03c0d70160cc7 Mon Sep 17 00:00:00 2001 From: Jose Javier <26491792+josejg@users.noreply.github.com> Date: Fri, 16 Aug 2024 11:37:29 -0700 Subject: [PATCH] fix --- llmfoundry/callbacks/env_logging_callback.py | 38 ++++++++++++-------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/llmfoundry/callbacks/env_logging_callback.py b/llmfoundry/callbacks/env_logging_callback.py index 1b5514823c..7210f53ba5 100644 --- a/llmfoundry/callbacks/env_logging_callback.py +++ b/llmfoundry/callbacks/env_logging_callback.py @@ -75,12 +75,17 @@ def __init__( self.env_data: dict[str, Any] = {} self.packages_to_log = packages_to_log or _PACKAGES_TO_LOG - def _get_git_info(self, repo_path: str) -> dict[str, str]: - repo = git.Repo(repo_path) - return { - 'commit_hash': repo.head.commit.hexsha, - 'branch': repo.active_branch.name, - } + def _get_git_info(self, repo_path: str) -> Optional[dict[str, str]]: + if not os.path.isdir(os.path.join(self.workspace_dir, folder)): + return None + try: + repo = git.Repo(repo_path) + return { + 'commit_hash': repo.head.commit.hexsha, + 'branch': repo.active_branch.name, + } + except (git.InvalidGitRepositoryError, git.NoSuchPathError): + return None def _get_package_version(self, package_name: str) -> Optional[str]: try: @@ -137,7 +142,9 @@ def _get_distributed_info(self) -> dict[str, Any]: 'local_rank': dist.get_local_rank(), } - def _get_docker_info(self) -> dict[str, Any]: + def _get_docker_info(self) -> Optional[dict[str, Any]]: + if 'RUN_NAME' not in os.environ: + return None run = sdk.get_run(os.environ['RUN_NAME']) image, tag = run.image.split(':') return { @@ -148,12 +155,13 @@ def _get_docker_info(self) -> dict[str, Any]: def fit_start(self, state: State, logger: Logger) -> None: # Collect environment data if self.log_git: - self.env_data['git_info'] = { - folder: - self._get_git_info(os.path.join(self.workspace_dir, folder)) - for folder in os.listdir(self.workspace_dir) - if os.path.isdir(os.path.join(self.workspace_dir, folder)) - } + self.env_data['git_info'] = {} + for folder in os.listdir(self.workspace_dir): + path = self._get_git_info( + os.path.join(self.workspace_dir, folder), + ) + if path: + self.env_data['git_info'][folder] = path if self.log_packages: self.env_data['package_versions'] = { @@ -164,7 +172,9 @@ def fit_start(self, state: State, logger: Logger) -> None: self.env_data['nvidia'] = self._get_nvidia_info() if self.log_docker: - self.env_data['docker'] = self._get_docker_info() + if docker_info := self._get_docker_info(): + self.env_data['docker'] = docker_info + if self.log_system: self.env_data['system_info'] = self._get_system_info()