Skip to content

Commit

Permalink
Fix: copy copy_feature_extractor for whisper model
Browse files Browse the repository at this point in the history
Signed-off-by: sagewe <[email protected]>
  • Loading branch information
sagewe committed Nov 18, 2024
1 parent 6040dfb commit 0abd1e3
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ def run_merge(
"Chat template specified but no tokenizer found. Chat template will not be saved."
)

# Copy feature_extractor if it is a whisper model
if options.copy_feature_extractor and arch_info.definition.expected_model_type == "whisper":
try:
_copy_feature_extractor(
merge_config, out_path, trust_remote_code=options.trust_remote_code
)
except Exception as e:
logging.error(
"Failed to copy feature_extractor. The merge was still successful, just copy it from somewhere else.",
exc_info=e,
)


if tokenizer:
logging.info("Saving tokenizer")
_set_chat_template(tokenizer, merge_config)
Expand Down Expand Up @@ -229,6 +242,36 @@ def _copy_tokenizer(
tokenizer.save_pretrained(out_path, safe_serialization=True)


def _copy_feature_extractor(
merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
):
donor_model = merge_config.base_model or (merge_config.referenced_models()[0])

if (os.path.exists(
os.path.join(donor_model.model.path, "preprocessor_config.json")
)
):
logging.info(f"Copying feature_extractor from {donor_model}")

for file_name in [
"preprocessor_config.json",
]:
if os.path.exists(os.path.join(donor_model.model.path, file_name)):
shutil.copy(
os.path.join(donor_model.model.path, file_name),
os.path.join(out_path, file_name),
)
return

# fallback: try actually loading the feature_extractor and saving it
logging.info(f"Reserializing feature_extractor from {donor_model}")
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(
donor_model.model.path,
revision=donor_model.model.revision,
trust_remote_code=trust_remote_code,
)
_set_chat_template(feature_extractor, merge_config)
feature_extractor.save_pretrained(out_path, safe_serialization=True)
def _model_out_config(
config: MergeConfiguration,
arch_info: ArchitectureInfo,
Expand Down

0 comments on commit 0abd1e3

Please sign in to comment.