Skip to content

Commit

Permalink
catch UC not found
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincent Chen committed Dec 12, 2024
1 parent a27c720 commit 5ce2048
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
17 changes: 15 additions & 2 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import math
import os
import re
import warnings
from dataclasses import dataclass, fields
from typing import (
Expand All @@ -21,12 +22,14 @@
import mlflow
from composer.loggers import Logger
from composer.utils import dist, parse_uri
from exceptions import UCNotFoundError
from mlflow.data import (
delta_dataset_source,
http_dataset_source,
huggingface_dataset_source,
uc_volume_dataset_source,
)
from mlflow.exceptions import MlflowException
from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
from omegaconf import OmegaConf as om
from transformers import PretrainedConfig
Expand Down Expand Up @@ -788,13 +791,23 @@ def log_dataset_uri(cfg: dict[str, Any]) -> None:

# Map data source types to their respective MLFlow DataSource.
for dataset_type, path, split in data_paths:

if dataset_type in dataset_source_mapping:
source_class = dataset_source_mapping[dataset_type]
if dataset_type == 'delta_table':
source = source_class(delta_table_name=path)
elif dataset_type == 'hf' or dataset_type == 'uc_volume':
source = source_class(path=path)
try:
source = source_class(path=path)
except MlflowException as e:
error_str = str(e)
match = re.search(
r'MlflowException:\s+(.*?)\s+does not exist in Databricks Unified Catalog\.',
error_str,
)
if match:
uc_path = match.group(1)
raise UCNotFoundError(uc_path)
raise
else:
source = source_class(url=path)
else:
Expand Down
15 changes: 15 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'StoragePermissionError',
'UCNotEnabledError',
'DeltaTableNotFoundError',
'UCNotFoundError',
]

ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
Expand Down Expand Up @@ -585,3 +586,17 @@ def __init__(
volume_name=volume_name,
table_name=table_name,
)


class UCNotFoundError(UserError):
"""Error thrown when the UC passed in training doesn't exist."""

def __init__(
self,
path: str,
) -> None:
message = f'Your data path {path} does not exist. Please double check your UC path'
super().__init__(
message=message,
path=path,
)

0 comments on commit 5ce2048

Please sign in to comment.