Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg committed Aug 16, 2024
1 parent c9d1884 commit bc3cd9d
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions llmfoundry/callbacks/env_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,17 @@ def __init__(
self.env_data: dict[str, Any] = {}
self.packages_to_log = packages_to_log or _PACKAGES_TO_LOG

def _get_git_info(self, repo_path: str) -> dict[str, str]:
repo = git.Repo(repo_path)
return {
'commit_hash': repo.head.commit.hexsha,
'branch': repo.active_branch.name,
}
def _get_git_info(self, repo_path: str) -> Optional[dict[str, str]]:
if not os.path.isdir(os.path.join(self.workspace_dir, folder)):
return None
try:
repo = git.Repo(repo_path)
return {
'commit_hash': repo.head.commit.hexsha,
'branch': repo.active_branch.name,
}
except (git.InvalidGitRepositoryError, git.NoSuchPathError):
return None

def _get_package_version(self, package_name: str) -> Optional[str]:
try:
Expand Down Expand Up @@ -137,7 +142,9 @@ def _get_distributed_info(self) -> dict[str, Any]:
'local_rank': dist.get_local_rank(),
}

def _get_docker_info(self) -> dict[str, Any]:
def _get_docker_info(self) -> Optional[dict[str, Any]]:
if 'RUN_NAME' not in os.environ:
return None
run = sdk.get_run(os.environ['RUN_NAME'])
image, tag = run.image.split(':')
return {
Expand All @@ -148,12 +155,13 @@ def _get_docker_info(self) -> dict[str, Any]:
def fit_start(self, state: State, logger: Logger) -> None:
# Collect environment data
if self.log_git:
self.env_data['git_info'] = {
folder:
self._get_git_info(os.path.join(self.workspace_dir, folder))
for folder in os.listdir(self.workspace_dir)
if os.path.isdir(os.path.join(self.workspace_dir, folder))
}
self.env_data['git_info'] = {}
for folder in os.listdir(self.workspace_dir):
path = self._get_git_info(
os.path.join(self.workspace_dir, folder),
)
if path:
self.env_data['git_info'][folder] = path

if self.log_packages:
self.env_data['package_versions'] = {
Expand All @@ -164,7 +172,9 @@ def fit_start(self, state: State, logger: Logger) -> None:
self.env_data['nvidia'] = self._get_nvidia_info()

if self.log_docker:
self.env_data['docker'] = self._get_docker_info()
if docker_info := self._get_docker_info():
self.env_data['docker'] = docker_info

if self.log_system:
self.env_data['system_info'] = self._get_system_info()

Expand Down

0 comments on commit bc3cd9d

Please sign in to comment.