From cfee4e4e3466be93a0a10e2c6b7c248bad0928f3 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 15 May 2024 18:43:44 -0400 Subject: [PATCH] log eval dataset misconfiguration (#1179) * log eval dataset misconfiguration * use context * literally * BaseException -> Exception * use my archaeological skills to find the right python syntax for 3.9 * refactor names for more general use * oops * oops II * context -> location * use variables instead of strings * Update exceptions.py * delete Mapping --- llmfoundry/utils/exceptions.py | 52 ++++++++++++++++++++-------------- scripts/train/train.py | 43 ++++++++++++++++++---------- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index d8baac2b49..51da8610e9 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Custom exceptions for the LLMFoundry.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Literal, Optional, Union __all__ = [ 'ALLOWED_RESPONSE_KEYS', @@ -34,9 +34,19 @@ ALLOWED_PROMPT_KEYS = {'prompt'} ALLOWED_MESSAGES_KEYS = {'messages'} +ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']] +TrainDataLoaderLocation = 'TrainDataloader' +EvalDataLoaderLocation = 'EvalDataloader' + + +class ContextualError(Exception): + """Error thrown when an error occurs in the context of a specific task.""" + + location: Optional[ErrorLocation] = None + # Finetuning dataloader exceptions -class MissingHuggingFaceURLSplitError(ValueError): +class MissingHuggingFaceURLSplitError(ValueError, ContextualError): """Error thrown when there's no split used in HF dataset config.""" def __init__(self) -> None: @@ -45,7 +55,7 @@ def __init__(self) -> None: super().__init__(message) -class NotEnoughDatasetSamplesError(ValueError): +class NotEnoughDatasetSamplesError(ValueError, ContextualError): """Error thrown when there is not enough data to train a model.""" def __init__( @@ -75,7 +85,7 @@ def __init__( ## Tasks exceptions -class UnknownExampleTypeError(KeyError): +class UnknownExampleTypeError(KeyError, ContextualError): """Error thrown when an unknown example type is used in a task.""" def __init__(self, example_keys: str) -> None: @@ -89,7 +99,7 @@ def __init__(self, example_keys: str) -> None: super().__init__(message) -class NotEnoughChatDataError(ValueError): +class NotEnoughChatDataError(ValueError, ContextualError): """Error thrown when there is not enough chat data to train a model.""" def __init__(self) -> None: @@ -97,7 +107,7 @@ def __init__(self) -> None: super().__init__(message) -class ConsecutiveRepeatedChatRolesError(ValueError): +class ConsecutiveRepeatedChatRolesError(ValueError, ContextualError): """Error thrown when there are consecutive repeated chat roles.""" def __init__(self, repeated_role: str) -> None: @@ -106,7 +116,7 @@ def __init__(self, repeated_role: str) -> None: super().__init__(message) -class InvalidLastChatMessageRoleError(ValueError): +class InvalidLastChatMessageRoleError(ValueError, ContextualError): """Error thrown when the last message role in a chat example is invalid.""" def __init__(self, last_role: str, expected_roles: set[str]) -> None: @@ -116,7 +126,7 @@ def __init__(self, last_role: str, expected_roles: set[str]) -> None: super().__init__(message) -class IncorrectMessageKeyQuantityError(ValueError): +class IncorrectMessageKeyQuantityError(ValueError, ContextualError): """Error thrown when a message has an incorrect number of keys.""" def __init__(self, keys: List[str]) -> None: @@ -125,7 +135,7 @@ def __init__(self, keys: List[str]) -> None: super().__init__(message) -class InvalidRoleError(ValueError): +class InvalidRoleError(ValueError, ContextualError): """Error thrown when a role is invalid.""" def __init__(self, role: str, valid_roles: set[str]) -> None: @@ -135,7 +145,7 @@ def __init__(self, role: str, valid_roles: set[str]) -> None: super().__init__(message) -class InvalidContentTypeError(TypeError): +class InvalidContentTypeError(TypeError, ContextualError): """Error thrown when the content type is invalid.""" def __init__(self, content_type: type) -> None: @@ -144,7 +154,7 @@ def __init__(self, content_type: type) -> None: super().__init__(message) -class InvalidPromptTypeError(TypeError): +class InvalidPromptTypeError(TypeError, ContextualError): """Error thrown when the prompt type is invalid.""" def __init__(self, prompt_type: type) -> None: @@ -153,7 +163,7 @@ def __init__(self, prompt_type: type) -> None: super().__init__(message) -class InvalidResponseTypeError(TypeError): +class InvalidResponseTypeError(TypeError, ContextualError): """Error thrown when the response type is invalid.""" def __init__(self, response_type: type) -> None: @@ -162,7 +172,7 @@ def __init__(self, response_type: type) -> None: super().__init__(message) -class InvalidPromptResponseKeysError(ValueError): +class InvalidPromptResponseKeysError(ValueError, ContextualError): """Error thrown when missing expected prompt and response keys.""" def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): @@ -171,7 +181,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]): super().__init__(message) -class InvalidFileExtensionError(FileNotFoundError): +class InvalidFileExtensionError(FileNotFoundError, ContextualError): """Error thrown when a file extension is not a safe extension.""" def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: @@ -184,7 +194,7 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None: super().__init__(message) -class UnableToProcessPromptResponseError(ValueError): +class UnableToProcessPromptResponseError(ValueError, ContextualError): """Error thrown when a prompt and response cannot be processed.""" def __init__(self, input: Dict) -> None: @@ -194,7 +204,7 @@ def __init__(self, input: Dict) -> None: ## Convert Delta to JSON exceptions -class ClusterDoesNotExistError(ValueError): +class ClusterDoesNotExistError(ValueError, ContextualError): """Error thrown when the cluster does not exist.""" def __init__(self, cluster_id: str) -> None: @@ -203,7 +213,7 @@ def __init__(self, cluster_id: str) -> None: super().__init__(message) -class FailedToCreateSQLConnectionError(RuntimeError): +class FailedToCreateSQLConnectionError(RuntimeError, ContextualError): """Error thrown when client can't sql connect to Databricks.""" def __init__(self) -> None: @@ -211,7 +221,7 @@ def __init__(self) -> None: super().__init__(message) -class FailedToConnectToDatabricksError(RuntimeError): +class FailedToConnectToDatabricksError(RuntimeError, ContextualError): """Error thrown when the client fails to connect to Databricks.""" def __init__(self) -> None: @@ -220,7 +230,7 @@ def __init__(self) -> None: ## Convert Text to MDS exceptions -class InputFolderMissingDataError(ValueError): +class InputFolderMissingDataError(ValueError, ContextualError): """Error thrown when the input folder is missing data.""" def __init__(self, input_folder: str) -> None: @@ -229,7 +239,7 @@ def __init__(self, input_folder: str) -> None: super().__init__(message) -class OutputFolderNotEmptyError(FileExistsError): +class OutputFolderNotEmptyError(FileExistsError, ContextualError): """Error thrown when the output folder is not empty.""" def __init__(self, output_folder: str) -> None: @@ -238,7 +248,7 @@ def __init__(self, output_folder: str) -> None: super().__init__(message) -class MisconfiguredHfDatasetError(ValueError): +class MisconfiguredHfDatasetError(ValueError, ContextualError): """Error thrown when a HuggingFace dataset is misconfigured.""" def __init__(self, dataset_name: str, split: str) -> None: diff --git a/scripts/train/train.py b/scripts/train/train.py index 880d4f2350..e0c2b8a94f 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -29,6 +29,11 @@ log_train_analytics, maybe_create_mosaicml_logger, ) +from llmfoundry.utils.exceptions import ( + ContextualError, + EvalDataLoaderLocation, + TrainDataLoaderLocation, +) install() @@ -391,8 +396,9 @@ def main(cfg: DictConfig) -> Trainer: tokenizer, train_cfg.device_train_batch_size, ) - except Exception as e: + except ContextualError as e: if mosaicml_logger is not None: + e.location = TrainDataLoaderLocation mosaicml_logger.log_exception(e) raise e @@ -409,19 +415,25 @@ def main(cfg: DictConfig) -> Trainer: train_cfg.eval_first = False else: - log.info('Building eval loader...') - eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len - evaluators, _, eval_gauntlet_callback = build_evaluators( - eval_loader_config, - icl_tasks_config, - eval_gauntlet_config, - tokenizer=tokenizer, - device_eval_batch_size=train_cfg.device_eval_batch_size, - icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=train_cfg.icl_subset_num_batches, - ) - if eval_gauntlet_callback is not None: - callbacks.append(eval_gauntlet_callback) + try: + log.info('Building eval loader...') + eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=train_cfg.device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=train_cfg.icl_subset_num_batches, + ) + if eval_gauntlet_callback is not None: + callbacks.append(eval_gauntlet_callback) + except ContextualError as e: + if mosaicml_logger is not None: + e.location = EvalDataLoaderLocation + mosaicml_logger.log_exception(e) + raise e if mosaicml_logger is not None: log_train_analytics( @@ -467,8 +479,9 @@ def main(cfg: DictConfig) -> Trainer: evaluators, non_icl_metrics, ) - except Exception as e: + except ContextualError as e: if mosaicml_logger is not None: + e.location = EvalDataLoaderLocation mosaicml_logger.log_exception(e) raise e