diff --git a/utils/deprecate_models.py b/utils/deprecate_models.py index d5160e93842095..2307f997202ce0 100644 --- a/utils/deprecate_models.py +++ b/utils/deprecate_models.py @@ -79,17 +79,13 @@ def insert_tip_to_model_doc(model_doc_path, tip_message): def get_model_doc_path(model: str) -> Tuple[Optional[str], Optional[str]]: # Possible variants of the model name in the model doc path - model_doc_paths = [ - REPO_PATH / f"docs/source/en/model_doc/{model}.md", - # Try replacing _ with - in the model name - REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '-')}.md", - # Try replacing _ with "" in the model name - REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '')}.md", - ] + model_names = [model, model.replace("_", "-"), model.replace("_", "")] + + model_doc_paths = [REPO_PATH / f"docs/source/en/model_doc/{model_name}.md" for model_name in model_names] - for model_doc_path in model_doc_paths: + for model_doc_path, model_name in zip(model_doc_paths, model_names): if os.path.exists(model_doc_path): - return model_doc_path, model + return model_doc_path, model_name return None, None @@ -186,6 +182,7 @@ def remove_model_references_from_file(filename, models, condition): models (List[str]): The models to remove condition (Callable): A function that takes the line and model and returns True if the line should be removed """ + filename = REPO_PATH / filename with open(filename, "r") as f: init_file = f.read()