diff --git a/scripts/conf/asr_vicuna_lora.yaml b/scripts/conf/asr_vicuna_lora.yaml
index 3e90ad98..ccc983df 100644
--- a/scripts/conf/asr_vicuna_lora.yaml
+++ b/scripts/conf/asr_vicuna_lora.yaml
@@ -47,7 +47,7 @@ train_config:
     weight_decay: 0.0
     gamma: 0.85
     seed: 42
-    use_fp16: False
+    use_fp16: false
     mixed_precision: true
     val_batch_size: 1
@@ -91,7 +91,7 @@ dataset_config:
 
 fsdp_config:
     mixed_precision: true
-    use_fp16: "${train_config.use_fp16}"
+    use_fp16: false
     # sharding_strategy: "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
     sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP
     checkpoint_type: "StateDictType.SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
diff --git a/src/llama_recipes/pipeline/finetune.py b/src/llama_recipes/pipeline/finetune.py
index acedb1f6..de7170aa 100644
--- a/src/llama_recipes/pipeline/finetune.py
+++ b/src/llama_recipes/pipeline/finetune.py
@@ -81,6 +81,7 @@ def main(kwargs: DictConfig):
                                         kwargs.model_config, \
                                         kwargs.log_config, \
                                         kwargs.dataset_config
+    fsdp_config.use_fp16 = train_config.use_fp16
     del kwargs.train_config
     del kwargs.fsdp_config
     del kwargs.model_config
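
Note: the line added to finetune.py copies the training-side precision flag onto the FSDP config in Python, rather than having fsdp_config reference train_config through YAML interpolation. Below is a minimal sketch of that effect, assuming the configs are plain OmegaConf DictConfig objects; the literal values and the OmegaConf.create call are illustrative stand-ins for the project's actual config loading, not its real code path.

from omegaconf import OmegaConf

# Stand-in for the kwargs DictConfig passed to main(); values are illustrative.
kwargs = OmegaConf.create({
    "train_config": {"use_fp16": False, "mixed_precision": True},
    "fsdp_config": {"use_fp16": False, "mixed_precision": True},
})

train_config, fsdp_config = kwargs.train_config, kwargs.fsdp_config

# Mirror the flag explicitly, as the patch does, instead of relying on a
# "${train_config.use_fp16}" interpolation inside fsdp_config.
fsdp_config.use_fp16 = train_config.use_fp16

del kwargs.train_config          # fsdp_config already holds a concrete value
print(fsdp_config.use_fp16)      # -> False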