From f255c4c20bc5449fc8aad8a19ccabbf8c4dc40fd Mon Sep 17 00:00:00 2001
From: gkumbhat
Date: Fri, 26 Jan 2024 17:57:46 -0600
Subject: [PATCH] :loud_sound: Add model and tokenizer time logging

Signed-off-by: gkumbhat
---
 .pre-commit-config.yaml                       |  1 +
 caikit_nlp/resources/pretrained_model/base.py | 32 ++++++++++---------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c26b9621..3987f0dc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,6 +8,7 @@ repos:
     hooks:
       - id: black
         exclude: imports
+        additional_dependencies: ["platformdirs"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index eba74744..b68c8d67 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -201,25 +201,27 @@ def bootstrap(
             else "right"
         )
 
-        # Load the tokenizer and set up the pad token if needed
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            local_files_only=not get_config().allow_downloads,
-            padding_side=padding_side,
-            # We can't disable use_fast otherwise unit test fails
-            # use_fast=False,
-        )
+        with alog.ContextTimer(log.info, "Tokenizer loaded in "):
+            # Load the tokenizer and set up the pad token if needed
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name,
+                local_files_only=not get_config().allow_downloads,
+                padding_side=padding_side,
+                # We can't disable use_fast otherwise unit test fails
+                # use_fast=False,
+            )
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
-        # Load the model
-        model = cls.MODEL_TYPE.from_pretrained(
-            model_name,
-            local_files_only=not get_config().allow_downloads,
-            torch_dtype=torch_dtype,
-            **kwargs,
-        )
+        with alog.ContextTimer(log.info, f"Model {model_name} loaded in "):
+            # Load the model
+            model = cls.MODEL_TYPE.from_pretrained(
+                model_name,
+                local_files_only=not get_config().allow_downloads,
+                torch_dtype=torch_dtype,
+                **kwargs,
+            )
 
         log.debug4("Model Details: %s", model)
 
         # Create the class instance
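
For context, the timing wrapper added above is alog's ContextTimer, which calls the given
log function with the elapsed time of the wrapped block when the context exits. The snippet
below is a minimal standalone sketch of that pattern, assuming the alchemy-logging (alog)
package is installed; the "DEMO" channel name and the time.sleep stand-in for the
from_pretrained calls are illustrative only, not part of this patch.

import time

import alog

# Configure alog and grab a channel logger, mirroring how caikit components log.
alog.configure(default_level="info")
log = alog.use_channel("DEMO")

# When the with-block exits, ContextTimer invokes log.info with the prefix string
# followed by the elapsed duration of the block.
with alog.ContextTimer(log.info, "Slow step finished in "):
    # Stand-in for an expensive call such as AutoTokenizer.from_pretrained(...)
    # or cls.MODEL_TYPE.from_pretrained(...) in the patched bootstrap() method.
    time.sleep(0.5)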