Skip to content

Commit

Permalink
log eval dataset misconfiguration (#1179)
Browse files Browse the repository at this point in the history
* log eval dataset misconfiguration

* use context

* literally

* BaseException -> Exception

* use my archaeological skills to find the right python syntax for 3.9

* refactor names for more general use

* oops

* oops II

* context -> location

* use variables instead of strings

* Update exceptions.py

* delete Mapping
  • Loading branch information
milocress authored May 15, 2024
1 parent b414626 commit cfee4e4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 36 deletions.
52 changes: 31 additions & 21 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

"""Custom exceptions for the LLMFoundry."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Literal, Optional, Union

__all__ = [
'ALLOWED_RESPONSE_KEYS',
Expand Down Expand Up @@ -34,9 +34,19 @@
ALLOWED_PROMPT_KEYS = {'prompt'}
ALLOWED_MESSAGES_KEYS = {'messages'}

# Closed set of pipeline stages that an error can be attributed to.
ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']]

# Named constants for each stage. They are annotated with ``ErrorLocation`` so
# that type checkers keep them tied to the literal set (instead of widening
# them to ``str``) and flag any typo in the string values.
TrainDataLoaderLocation: ErrorLocation = 'TrainDataloader'
EvalDataLoaderLocation: ErrorLocation = 'EvalDataloader'


class ContextualError(Exception):
    """Error thrown when an error occurs in the context of a specific task.

    Attributes:
        location: The stage the error is attributed to, or ``None`` when no
            stage has been recorded. The value is one of the ``ErrorLocation``
            literals.
    """

    # Class-level default: every instance reports ``None`` until a handler
    # assigns a location on that specific instance.
    location: Optional[ErrorLocation] = None


# Finetuning dataloader exceptions
class MissingHuggingFaceURLSplitError(ValueError):
class MissingHuggingFaceURLSplitError(ValueError, ContextualError):
"""Error thrown when there's no split used in HF dataset config."""

def __init__(self) -> None:
Expand All @@ -45,7 +55,7 @@ def __init__(self) -> None:
super().__init__(message)


class NotEnoughDatasetSamplesError(ValueError):
class NotEnoughDatasetSamplesError(ValueError, ContextualError):
"""Error thrown when there is not enough data to train a model."""

def __init__(
Expand Down Expand Up @@ -75,7 +85,7 @@ def __init__(


## Tasks exceptions
class UnknownExampleTypeError(KeyError):
class UnknownExampleTypeError(KeyError, ContextualError):
"""Error thrown when an unknown example type is used in a task."""

def __init__(self, example_keys: str) -> None:
Expand All @@ -89,15 +99,15 @@ def __init__(self, example_keys: str) -> None:
super().__init__(message)


class NotEnoughChatDataError(ValueError):
class NotEnoughChatDataError(ValueError, ContextualError):
    """Error thrown when a chat example has fewer than two messages."""

    def __init__(self) -> None:
        # The message is fixed; this error takes no arguments.
        super().__init__('Chat example must have at least two messages')


class ConsecutiveRepeatedChatRolesError(ValueError):
class ConsecutiveRepeatedChatRolesError(ValueError, ContextualError):
"""Error thrown when there are consecutive repeated chat roles."""

def __init__(self, repeated_role: str) -> None:
Expand All @@ -106,7 +116,7 @@ def __init__(self, repeated_role: str) -> None:
super().__init__(message)


class InvalidLastChatMessageRoleError(ValueError):
class InvalidLastChatMessageRoleError(ValueError, ContextualError):
"""Error thrown when the last message role in a chat example is invalid."""

def __init__(self, last_role: str, expected_roles: set[str]) -> None:
Expand All @@ -116,7 +126,7 @@ def __init__(self, last_role: str, expected_roles: set[str]) -> None:
super().__init__(message)


class IncorrectMessageKeyQuantityError(ValueError):
class IncorrectMessageKeyQuantityError(ValueError, ContextualError):
"""Error thrown when a message has an incorrect number of keys."""

def __init__(self, keys: List[str]) -> None:
Expand All @@ -125,7 +135,7 @@ def __init__(self, keys: List[str]) -> None:
super().__init__(message)


class InvalidRoleError(ValueError):
class InvalidRoleError(ValueError, ContextualError):
"""Error thrown when a role is invalid."""

def __init__(self, role: str, valid_roles: set[str]) -> None:
Expand All @@ -135,7 +145,7 @@ def __init__(self, role: str, valid_roles: set[str]) -> None:
super().__init__(message)


class InvalidContentTypeError(TypeError):
class InvalidContentTypeError(TypeError, ContextualError):
"""Error thrown when the content type is invalid."""

def __init__(self, content_type: type) -> None:
Expand All @@ -144,7 +154,7 @@ def __init__(self, content_type: type) -> None:
super().__init__(message)


class InvalidPromptTypeError(TypeError):
class InvalidPromptTypeError(TypeError, ContextualError):
"""Error thrown when the prompt type is invalid."""

def __init__(self, prompt_type: type) -> None:
Expand All @@ -153,7 +163,7 @@ def __init__(self, prompt_type: type) -> None:
super().__init__(message)


class InvalidResponseTypeError(TypeError):
class InvalidResponseTypeError(TypeError, ContextualError):
"""Error thrown when the response type is invalid."""

def __init__(self, response_type: type) -> None:
Expand All @@ -162,7 +172,7 @@ def __init__(self, response_type: type) -> None:
super().__init__(message)


class InvalidPromptResponseKeysError(ValueError):
class InvalidPromptResponseKeysError(ValueError, ContextualError):
"""Error thrown when missing expected prompt and response keys."""

def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
Expand All @@ -171,7 +181,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
super().__init__(message)


class InvalidFileExtensionError(FileNotFoundError):
class InvalidFileExtensionError(FileNotFoundError, ContextualError):
"""Error thrown when a file extension is not a safe extension."""

def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
Expand All @@ -184,7 +194,7 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
super().__init__(message)


class UnableToProcessPromptResponseError(ValueError):
class UnableToProcessPromptResponseError(ValueError, ContextualError):
"""Error thrown when a prompt and response cannot be processed."""

def __init__(self, input: Dict) -> None:
Expand All @@ -194,7 +204,7 @@ def __init__(self, input: Dict) -> None:


## Convert Delta to JSON exceptions
class ClusterDoesNotExistError(ValueError):
class ClusterDoesNotExistError(ValueError, ContextualError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str) -> None:
Expand All @@ -203,15 +213,15 @@ def __init__(self, cluster_id: str) -> None:
super().__init__(message)


class FailedToCreateSQLConnectionError(RuntimeError):
class FailedToCreateSQLConnectionError(RuntimeError, ContextualError):
    """Error thrown when a SQL connection to Databricks cannot be created."""

    def __init__(self) -> None:
        # Adjacent literals concatenate into the single fixed message.
        super().__init__(
            'Failed to create sql connection to db workspace. '
            'To use sql connect, you need to provide http_path and cluster_id!',
        )


class FailedToConnectToDatabricksError(RuntimeError):
class FailedToConnectToDatabricksError(RuntimeError, ContextualError):
"""Error thrown when the client fails to connect to Databricks."""

def __init__(self) -> None:
Expand All @@ -220,7 +230,7 @@ def __init__(self) -> None:


## Convert Text to MDS exceptions
class InputFolderMissingDataError(ValueError):
class InputFolderMissingDataError(ValueError, ContextualError):
"""Error thrown when the input folder is missing data."""

def __init__(self, input_folder: str) -> None:
Expand All @@ -229,7 +239,7 @@ def __init__(self, input_folder: str) -> None:
super().__init__(message)


class OutputFolderNotEmptyError(FileExistsError):
class OutputFolderNotEmptyError(FileExistsError, ContextualError):
"""Error thrown when the output folder is not empty."""

def __init__(self, output_folder: str) -> None:
Expand All @@ -238,7 +248,7 @@ def __init__(self, output_folder: str) -> None:
super().__init__(message)


class MisconfiguredHfDatasetError(ValueError):
class MisconfiguredHfDatasetError(ValueError, ContextualError):
"""Error thrown when a HuggingFace dataset is misconfigured."""

def __init__(self, dataset_name: str, split: str) -> None:
Expand Down
43 changes: 28 additions & 15 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
log_train_analytics,
maybe_create_mosaicml_logger,
)
from llmfoundry.utils.exceptions import (
ContextualError,
EvalDataLoaderLocation,
TrainDataLoaderLocation,
)

install()

Expand Down Expand Up @@ -391,8 +396,9 @@ def main(cfg: DictConfig) -> Trainer:
tokenizer,
train_cfg.device_train_batch_size,
)
except Exception as e:
except ContextualError as e:
if mosaicml_logger is not None:
e.location = TrainDataLoaderLocation
mosaicml_logger.log_exception(e)
raise e

Expand All @@ -409,19 +415,25 @@ def main(cfg: DictConfig) -> Trainer:
train_cfg.eval_first = False

else:
log.info('Building eval loader...')
eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len
evaluators, _, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks_config,
eval_gauntlet_config,
tokenizer=tokenizer,
device_eval_batch_size=train_cfg.device_eval_batch_size,
icl_seq_len=eval_icl_seq_len,
icl_subset_num_batches=train_cfg.icl_subset_num_batches,
)
if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)
try:
log.info('Building eval loader...')
eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len
evaluators, _, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks_config,
eval_gauntlet_config,
tokenizer=tokenizer,
device_eval_batch_size=train_cfg.device_eval_batch_size,
icl_seq_len=eval_icl_seq_len,
icl_subset_num_batches=train_cfg.icl_subset_num_batches,
)
if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)
except ContextualError as e:
if mosaicml_logger is not None:
e.location = EvalDataLoaderLocation
mosaicml_logger.log_exception(e)
raise e

if mosaicml_logger is not None:
log_train_analytics(
Expand Down Expand Up @@ -467,8 +479,9 @@ def main(cfg: DictConfig) -> Trainer:
evaluators,
non_icl_metrics,
)
except Exception as e:
except ContextualError as e:
if mosaicml_logger is not None:
e.location = EvalDataLoaderLocation
mosaicml_logger.log_exception(e)
raise e

Expand Down

0 comments on commit cfee4e4

Please sign in to comment.