
Removing logging exception through update run metadata #1292

Merged · 12 commits · Jul 9, 2024
11 changes: 2 additions & 9 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -13,7 +13,7 @@
 
 from composer import DataSpec
 from composer.core import State, Time, TimeUnit, ensure_time
-from composer.loggers import Logger, MosaicMLLogger
+from composer.loggers import Logger
 from streaming import StreamingDataset
 from streaming.base.util import clean_stale_shared_memory
 from torch.utils.data import DataLoader
@@ -23,7 +23,6 @@
     BaseContextualError,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 
 log = logging.getLogger(__name__)
 
@@ -238,13 +237,7 @@ def _build_train_loader(
                 self.device_train_batch_size,
             )
         except BaseContextualError as e:
-            for destination in logger.destinations:
-                if (
-                    isinstance(destination, MosaicMLLogger) and
-                    no_override_excepthook()
-                ):
-                    e.location = TrainDataLoaderLocation
-                    destination.log_exception(e)
+            e.location = TrainDataLoaderLocation
             raise e
 
     def _validate_dataloader(self, train_loader: Any):
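The callback's error path now just tags the exception with its location and re-raises, instead of looping over logger destinations to call `log_exception`. A minimal, runnable sketch of the surviving pattern, with stand-ins for llmfoundry's `BaseContextualError` and `TrainDataLoaderLocation` (reporting is assumed to happen once at the process level, not per call site):

```python
# Stand-ins for llmfoundry.utils.exceptions; illustrative only.
class BaseContextualError(Exception):
    location = None

TrainDataLoaderLocation = 'train_dataloader'

def _build_train_loader():
    try:
        raise BaseContextualError('failed to build the streaming dataloader')
    except BaseContextualError as e:
        # Tag where the failure happened, then let it propagate.
        e.location = TrainDataLoaderLocation
        raise e

try:
    _build_train_loader()
except BaseContextualError as e:
    print(f'failure at {e.location}: {e}')  # failure at train_dataloader: ...
```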
19 changes: 6 additions & 13 deletions llmfoundry/callbacks/run_timeout_callback.py
@@ -8,19 +8,18 @@
 from typing import Optional
 
 from composer import Callback, Logger, State
-from composer.loggers import MosaicMLLogger
 
 from llmfoundry.utils.exceptions import RunTimeoutError
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 
 log = logging.getLogger(__name__)
 
 
-def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None):
+def _timeout(timeout: int):
     log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
-    if mosaicml_logger is not None and no_override_excepthook():
-        mosaicml_logger.log_exception(RunTimeoutError(timeout=timeout))
-    os.kill(os.getpid(), signal.SIGINT)
+    try:
+        raise RunTimeoutError(timeout=timeout)
+    except RunTimeoutError:
+        os.kill(os.getpid(), signal.SIGINT)
 
 
 class RunTimeoutCallback(Callback):
@@ -30,14 +29,8 @@ def __init__(
         timeout: int = 1800,
     ):
         self.timeout = timeout
-        self.mosaicml_logger: Optional[MosaicMLLogger] = None
         self.timer: Optional[threading.Timer] = None
 
-    def init(self, state: State, logger: Logger):
-        for callback in state.callbacks:
-            if isinstance(callback, MosaicMLLogger):
-                self.mosaicml_logger = callback
-
     def _reset(self):
         if self.timer is not None:
             self.timer.cancel()
@@ -48,7 +41,7 @@ def _timeout(self):
         self.timer = threading.Timer(
             self.timeout,
             _timeout,
-            [self.timeout, self.mosaicml_logger],
+            [self.timeout],
         )
         self.timer.daemon = True
         self.timer.start()
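With the MosaicML-specific pieces removed, what remains is a plain watchdog: a daemon `threading.Timer` fires after `timeout` seconds, raises `RunTimeoutError` locally (so it is the active exception at interrupt time, mirroring the diff above), and stops the process with SIGINT. A runnable, POSIX-only sketch with illustrative names:

```python
import os
import signal
import threading
import time

def _timeout(timeout: int) -> None:
    try:
        # Raised and caught locally so it is the active exception when the
        # interrupt is delivered.
        raise TimeoutError(f'no activity for {timeout}s after fit_end')
    except TimeoutError:
        os.kill(os.getpid(), signal.SIGINT)

timer = threading.Timer(2, _timeout, [2])
timer.daemon = True  # never keeps the process alive on its own
timer.start()

try:
    time.sleep(60)  # stand-in for post-fit work that has hung
except KeyboardInterrupt:
    print('run interrupted by the timeout watchdog')
```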
24 changes: 6 additions & 18 deletions scripts/data_prep/convert_delta_to_json.py
@@ -34,10 +34,6 @@
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row
 
-from llmfoundry.utils import (
-    maybe_create_mosaicml_logger,
-    no_override_excepthook,
-)
 from llmfoundry.utils.exceptions import (
     ClusterDoesNotExistError,
     FailedToConnectToDatabricksError,
@@ -667,18 +663,10 @@ def fetch_DT(args: Namespace) -> None:
         'The name of the combined final jsonl that combines all partitioned jsonl',
     )
     args = parser.parse_args()
-    mosaicml_logger = maybe_create_mosaicml_logger()
-
-    try:
-        w = WorkspaceClient()
-        args.DATABRICKS_HOST = w.config.host
-        args.DATABRICKS_TOKEN = w.config.token
-
-        tik = time.time()
-        fetch_DT(args)
-        log.info(f'Elapsed time {time.time() - tik}')
+    w = WorkspaceClient()
+    args.DATABRICKS_HOST = w.config.host
+    args.DATABRICKS_TOKEN = w.config.token
 
-    except Exception as e:
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
-        raise e
+    tik = time.time()
+    fetch_DT(args)
+    log.info(f'Elapsed time {time.time() - tik}')
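The per-script try/except that called `mosaicml_logger.log_exception` before re-raising is gone; the entrypoint now lets failures propagate. That only works if reporting is centralized, e.g. in a process-level excepthook installed elsewhere (the PR title points at run-metadata updates). The hook below is an assumption-laden illustration of that idea, not Composer's actual implementation:

```python
import sys

def _report_uncaught(exc_type, exc_value, exc_tb):
    # Illustrative stand-in for run-metadata reporting: record the failure
    # once, here, instead of in a try/except at every entrypoint.
    print(f'[run metadata] {exc_type.__name__}: {exc_value}', file=sys.stderr)
    sys.__excepthook__(exc_type, exc_value, exc_tb)  # keep default traceback

sys.excepthook = _report_uncaught

raise RuntimeError('simulated failure in fetch_DT')  # reported exactly once
```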
40 changes: 14 additions & 26 deletions scripts/data_prep/convert_text_to_mds.py
@@ -24,10 +24,6 @@
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from llmfoundry.data.data import AbstractConcatTokensDataset
-from llmfoundry.utils import (
-    maybe_create_mosaicml_logger,
-    no_override_excepthook,
-)
 from llmfoundry.utils.data_prep_utils import (
     DownloadingIterable,
     download_file,
@@ -608,25 +604,17 @@ def _configure_logging(logging_level: str):
 if __name__ == '__main__':
     args = parse_args()
     _configure_logging(args.logging_level)
-
-    mosaicml_logger = maybe_create_mosaicml_logger()
-
-    try:
-        convert_text_to_mds(
-            tokenizer_name=args.tokenizer,
-            output_folder=args.output_folder,
-            input_folder=args.input_folder,
-            concat_tokens=args.concat_tokens,
-            eos_text=args.eos_text,
-            bos_text=args.bos_text,
-            no_wrap=args.no_wrap,
-            compression=args.compression,
-            processes=args.processes,
-            reprocess=args.reprocess,
-            trust_remote_code=args.trust_remote_code,
-            args_str=_args_str(args),
-        )
-    except Exception as e:
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
-        raise e
+    convert_text_to_mds(
+        tokenizer_name=args.tokenizer,
+        output_folder=args.output_folder,
+        input_folder=args.input_folder,
+        concat_tokens=args.concat_tokens,
+        eos_text=args.eos_text,
+        bos_text=args.bos_text,
+        no_wrap=args.no_wrap,
+        compression=args.compression,
+        processes=args.processes,
+        reprocess=args.reprocess,
+        trust_remote_code=args.trust_remote_code,
+        args_str=_args_str(args),
+    )
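Same simplification here: `convert_text_to_mds` is now called directly, with no wrapper. A programmatic caller that still wants local handling can add its own try/except. A hedged sketch, where the import path and every argument value are placeholders, not values from this PR:

```python
from convert_text_to_mds import convert_text_to_mds  # script in scripts/data_prep/

try:
    convert_text_to_mds(
        tokenizer_name='gpt2',       # placeholder values throughout
        output_folder='/tmp/mds-out',
        input_folder='/tmp/text-in',
        concat_tokens=2048,
        eos_text='<|endoftext|>',
        bos_text='',
        no_wrap=False,
        compression='zstd',
        processes=1,
        reprocess=False,
        trust_remote_code=False,
        args_str='',
    )
except Exception:
    # Local cleanup or logging can go here; re-raise so process-level
    # reporting still sees the failure.
    raise
```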
7 changes: 0 additions & 7 deletions scripts/train/train.py
@@ -56,7 +56,6 @@
     EvalDataLoaderLocation,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 from llmfoundry.utils.registry_utils import import_file
 
 log = logging.getLogger(__name__)
@@ -398,8 +397,6 @@ def main(cfg: DictConfig) -> Trainer:
         )
     except BaseContextualError as e:
         e.location = TrainDataLoaderLocation
-        if mosaicml_logger is not None and no_override_excepthook():
-            mosaicml_logger.log_exception(e)
         raise e
 
     if mosaicml_logger is not None:
Expand Down Expand Up @@ -431,8 +428,6 @@ def main(cfg: DictConfig) -> Trainer:
callbacks.append(eval_gauntlet_callback)
except BaseContextualError as e:
e.location = EvalDataLoaderLocation
jjanezhang marked this conversation as resolved.
Show resolved Hide resolved
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e

if mosaicml_logger is not None:
Expand Down Expand Up @@ -481,8 +476,6 @@ def main(cfg: DictConfig) -> Trainer:
)
except BaseContextualError as e:
e.location = EvalDataLoaderLocation
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e

compile_config = train_cfg.compile_config
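train.py keeps setting `e.location` even though its explicit `log_exception` calls are removed, so the tag must be consumed by whatever reports the uncaught exception. An illustrative location-aware hook (an assumption about the reporting side, not code from this PR):

```python
import sys

def _report_with_location(exc_type, exc_value, exc_tb):
    # `location` mirrors the attribute set on BaseContextualError above.
    location = getattr(exc_value, 'location', None)
    prefix = f'failure at {location}' if location else 'failure'
    print(f'[run metadata] {prefix}: {exc_value}', file=sys.stderr)
    sys.__excepthook__(exc_type, exc_value, exc_tb)

sys.excepthook = _report_with_location
```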