From 72e13560ad20a459f04982372e7a4ed231a6bfab Mon Sep 17 00:00:00 2001
From: Chuck Tang
Date: Fri, 14 Jun 2024 13:27:01 -0700
Subject: [PATCH] add precision

---
 llmfoundry/callbacks/hf_checkpointer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 0ae59a21b4..017ab1e565 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from composer.core import Callback, Event, State, Time, TimeUnit
+from composer.core import Callback, Event, Precision, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.models import HuggingFaceModel
@@ -496,7 +496,8 @@ def dtensor_to_tensor_hook(
             # Needed for proper hf ckpt saving.
             context_manager = te.onnx_export(
                 True,
-            ) if is_te_imported else contextlib.nullcontext()
+            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+            )
             with context_manager:
                 new_model_instance.save_pretrained(temp_save_dir)
             if original_tokenizer is not None:
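
Note: after this patch, the te.onnx_export(True) context manager (needed for proper HF checkpoint saving with TransformerEngine) is only entered when transformer_engine is importable and the run's precision is AMP_FP8; otherwise contextlib.nullcontext() is used. A minimal standalone sketch of that gating follows; the helper name hf_save_context is hypothetical (not from the patch), and transformer_engine is treated as an optional dependency:

    import contextlib

    from composer.core import Precision, State

    try:
        import transformer_engine.pytorch as te
        is_te_imported = True
    except ModuleNotFoundError:
        is_te_imported = False


    def hf_save_context(state: State):
        # Hypothetical helper (not in the patch): choose the context manager
        # to wrap save_pretrained. te.onnx_export(True) is needed for proper
        # HF checkpoint saving, but only when training ran in AMP_FP8.
        if is_te_imported and state.precision == Precision.AMP_FP8:
            return te.onnx_export(True)
        return contextlib.nullcontext()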