Skip to content

Commit

Permalink
commit change
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuck Tang committed Jun 14, 2024
1 parent 5ce7839 commit fbf0967
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import torch
import torch.nn as nn
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
Expand Down Expand Up @@ -496,7 +496,8 @@ def dtensor_to_tensor_hook(
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported else contextlib.nullcontext()
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
Expand Down

0 comments on commit fbf0967

Please sign in to comment.