Skip to content

Commit

Permalink
hugging face (#468)
Browse files Browse the repository at this point in the history
* adapt to clinicadl hugging face organization

* finish skip leak check to use clinicadl with downloaded MAPS
  • Loading branch information
camillebrianceau authored Feb 15, 2024
1 parent 51d3175 commit 972e94d
Show file tree
Hide file tree
Showing 8 changed files with 391 additions and 4 deletions.
2 changes: 2 additions & 0 deletions clinicadl/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import click

from clinicadl.generate.generate_cli import cli as generate_cli
from clinicadl.hugging_face.hugging_face_cli import cli as hf_cli
from clinicadl.interpret.interpret_cli import cli as interpret_cli
from clinicadl.predict.predict_cli import cli as predict_cli
from clinicadl.prepare_data.prepare_data_cli import cli as prepare_data_cli
Expand Down Expand Up @@ -50,6 +51,7 @@ def cli(verbose):
cli.add_command(interpret_cli)
cli.add_command(qc_cli)
cli.add_command(random_search_cli)
cli.add_command(hf_cli)

if __name__ == "__main__":
cli()
232 changes: 232 additions & 0 deletions clinicadl/hugging_face/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import importlib
import os
from logging import getLogger
from pathlib import Path

import toml

from clinicadl.utils.exceptions import ClinicaDLArgumentError
from clinicadl.utils.maps_manager.maps_manager_utils import (
change_str_to_path,
read_json,
remove_unused_tasks,
)

logger = getLogger("clinicadl")


def hf_hub_is_available():
    """Tell whether the optional ``huggingface_hub`` package can be imported.

    Returns
    -------
    bool
        ``True`` if an import spec is found for ``huggingface_hub``,
        ``False`` otherwise. The package itself is not imported here.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before using it.
    import importlib.util

    return importlib.util.find_spec("huggingface_hub") is not None


def push_to_hf_hub(
    hf_hub_path: str,
    maps_dir: Path,
    model_name: str,
):
    """Upload a MAPS (README, ``maps.json`` and best-loss models) to the HF hub.

    A README is generated from ``maps.json`` with `create_readme`, then the
    README, ``maps.json`` and the ``split-<i>/best-loss/model.pth.tar`` file of
    every split are committed to the ``<hf_hub_path>/<model_name>`` repository.
    The repository is created first if it does not exist yet.

    Parameters
    ----------
    hf_hub_path: str
        Name of the Hugging Face user or organization that owns the
        repository. Any casing of "clinicadl" is mapped to "ClinicaDL".
    maps_dir: Path
        Local MAPS directory containing ``maps.json`` and the split folders.
    model_name: str
        Name of the repository on the hub.

    Raises
    ------
    ModuleNotFoundError
        If ``huggingface_hub`` is not installed.
    ClinicaDLArgumentError
        If the logged-in user is not allowed to push under ``hf_hub_path``
        (or, from `create_readme`, if ``maps.json`` is missing).
    """
    if hf_hub_is_available():
        from huggingface_hub import CommitOperationAdd, HfApi
        from huggingface_hub.utils import RepositoryNotFoundError
    else:
        raise ModuleNotFoundError(
            "`huggingface_hub` package must be installed to push your model to the HF hub. "
            "Run `python -m pip install huggingface_hub` and log in to your account with "
            "`huggingface-cli login`."
        )

    model_card_ = """
---
language: en
library_name: clinicadl
tags:
- clinicadl
license: mit
---
"""
    hf_hub_path = "ClinicaDL" if hf_hub_path.lower() == "clinicadl" else hf_hub_path
    maps_dir = Path(maps_dir)
    # Hub repository ids always use "/" as separator, whatever the local OS is,
    # so do not build them with os.path.join.
    id_ = f"{hf_hub_path}/{model_name}"
    # `create_readme` writes the README to this file in the working directory.
    tmp_file = Path("tmp_README.md")

    try:
        n_splits = create_readme(
            config_file=maps_dir / "maps.json",
            model_name=model_name,
            model_card=model_card_,
        )
        logger.info(f"Uploading {model_name} model to {hf_hub_path} repo in HF hub...")

        api = HfApi()
        user = api.whoami()
        list_orgs = [x["name"] for x in user["orgs"]]

        if hf_hub_path == "ClinicaDL":
            if "ClinicaDL" not in list_orgs:
                raise ClinicaDLArgumentError(
                    "You're not in the ClinicaDL organization on Hugging Face. Please follow the link to request to join the organization: https://huggingface.co/clinicadl-test"
                )
        elif hf_hub_path != user["name"]:
            raise ClinicaDLArgumentError(
                f"You're logged as {user['name']} in Hugging Face and you are trying to push a model under {hf_hub_path} logging."
            )

        hf_operations = [
            CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=str(tmp_file)),
            CommitOperationAdd(
                path_in_repo="maps.json", path_or_fileobj=str(maps_dir / "maps.json")
            ),
        ]
        # One operation per split. The previous extra `os.walk` pass only
        # re-added the last split's model once per file found (duplicate
        # operations on the same path), so it has been dropped.
        for split in range(n_splits):
            model_in_repo = f"split-{split}/best-loss/model.pth.tar"
            hf_operations.append(
                CommitOperationAdd(
                    path_in_repo=model_in_repo,
                    path_or_fileobj=str(
                        maps_dir / f"split-{split}" / "best-loss" / "model.pth.tar"
                    ),
                )
            )

        commit_message = f"Uploading {model_name} in {maps_dir}"
        try:
            # NOTE: `create_commit` has no `private` argument; passing one made
            # this call fail systematically and hid the error behind a bare
            # `except`. Visibility is a property of the repository itself.
            api.create_commit(
                commit_message=commit_message,
                repo_id=id_,
                operations=hf_operations,
            )
        except RepositoryNotFoundError:
            logger.info(f"Creating {id_} in the HF hub since it does not exist...")
            api.create_repo(repo_id=id_)
            logger.info(f"Successfully created {id_} in the HF hub!")

            api.create_commit(
                commit_message=commit_message,
                repo_id=id_,
                operations=hf_operations,
            )
        logger.info(f"Successfully uploaded {model_name} to {id_} repo in HF hub!")
    finally:
        # Always remove the temporary README, including when the push fails.
        if tmp_file.exists():
            tmp_file.unlink()


def create_readme(
    config_file: Path = None, model_name: str = "test", model_card: str = None
):
    """Write a model card (``tmp_README.md``) describing a trained MAPS.

    The default training configuration (``resources/config/train_config.toml``)
    is loaded and restricted to the task found in ``maps.json``. Every default
    value that also appears in ``maps.json`` is replaced by the trained value;
    the keys of ``maps.json`` that match no default are listed under
    "Other information". The file is written in the current working directory.

    Parameters
    ----------
    config_file: Path
        Path to the ``maps.json`` file of the MAPS.
    model_name: str
        Name used in the title of the model card.
    model_card: str
        Header (e.g. YAML metadata) written at the very top of the README.
        ``None`` is treated as an empty header.

    Returns
    -------
    The ``n_splits`` value of the "Cross_validation" section after the values
    of ``maps.json`` have been applied.

    Raises
    ------
    ClinicaDLArgumentError
        If ``config_file`` is missing or is not an existing file.
    """
    if config_file is None or not Path(config_file).is_file():
        raise ClinicaDLArgumentError("There is no maps.json file in your repository.")

    clinicadl_root_dir = (Path(__file__) / "../..").resolve()
    config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml"
    config_dict = toml.load(config_path)

    train_dict = change_str_to_path(read_json(config_file))
    task = train_dict["network_task"]

    config_dict = remove_unused_tasks(config_dict, task)
    config_dict = change_str_to_path(config_dict)

    list_lines = [
        model_card if model_card is not None else "",
        f"# Model Card for {model_name} \n",
        "This model was trained with ClinicaDL. You can find here all the information.\n",
        "## General information \n",
    ]

    if train_dict["multi_cohort"]:
        list_lines.append(
            "This model was trained on several datasets at the same time. \n"
        )
    list_lines.append(
        f"This model was trained for **{task}** and the architecture chosen is **{train_dict['architecture']}**. \n"
    )

    for config_section, section_dict in config_dict.items():
        list_lines.append(f"### {config_section} \n")
        for key in section_dict:
            if key == "preprocessing_dict":
                list_lines.append("### Preprocessing \n")
                for key_bis, value_bis in section_dict[key].items():
                    list_lines.append(f"**{key_bis}**: {value_bis} \n")
            else:
                if key in train_dict:
                    # The trained value overrides the default one; consumed
                    # keys are removed so that only the unmatched ones remain
                    # for the "Other information" section.
                    section_dict[key] = train_dict.pop(key)
                list_lines.append(f"**{key}**: {section_dict[key]} \n")

    list_lines.append("### Other information \n")
    for key, value in train_dict.items():
        list_lines.append(f"**{key}**: {value} \n")

    # Open the file only once the content is complete, so that a failure above
    # neither leaks a file handle nor leaves an empty README behind.
    with open("tmp_README.md", "w", encoding="utf-8") as file:
        file.writelines(list_lines)

    return config_dict["Cross_validation"]["n_splits"]


def load_from_hf_hub(
    output_maps: Path, hf_hub_path: str, maps_name: str
):  # pragma: no cover
    """Download a MAPS stored on the Hugging Face hub.

    The whole ``<hf_hub_path>/<maps_name>`` repository is downloaded into
    ``output_maps`` with ``snapshot_download``.

    Parameters
    ----------
    output_maps: Path
        Local directory in which the repository files are written.
    hf_hub_path: str
        Name of the Hugging Face user or organization hosting the repository.
        Any casing of "clinicadl" is mapped to "ClinicaDL".
    maps_name: str
        Name of the repository on the hub.

    Raises
    ------
    ModuleNotFoundError
        If ``huggingface_hub`` is not installed.
    ClinicaDLArgumentError
        If the repository belongs to the ClinicaDL organization and the
        logged-in user is not a member of it.
    """

    if hf_hub_is_available():
        from huggingface_hub import HfApi, snapshot_download
    else:
        raise ModuleNotFoundError(
            "`huggingface_hub` package must be installed to pull a model from the HF hub. "
            "Run `python -m pip install huggingface_hub` and log in to your account with "
            "`huggingface-cli login`."
        )

    hf_hub_path = "ClinicaDL" if hf_hub_path.lower() == "clinicadl" else hf_hub_path

    api = HfApi()
    # Hub repository ids always use "/" as separator, whatever the local OS is.
    id_ = f"{hf_hub_path}/{maps_name}"
    user = api.whoami()
    list_orgs = [x["name"] for x in user["orgs"]]

    if hf_hub_path == "ClinicaDL":
        if "ClinicaDL" not in list_orgs:
            raise ClinicaDLArgumentError(
                "You're not in the ClinicaDL organization on Hugging Face. Please follow the link to request to join the organization: https://huggingface.co/clinicadl-test"
            )
    elif hf_hub_path != user["name"]:
        # Pulling from someone else's account is allowed, only reported.
        logger.warning(
            f"You're logged as {user['name']} in Hugging Face and you are trying to pull a model from {hf_hub_path}."
        )

    # Logged for every download, not only when pulling from one's own account.
    logger.info(f"Downloading {hf_hub_path} files for rebuilding...")
    snapshot_download(repo_id=id_, local_dir=output_maps)
18 changes: 18 additions & 0 deletions clinicadl/hugging_face/hugging_face_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import click

from .pull_cli import cli as pull_cli
from .push_cli import cli as push_cli


@click.group(name="hugging-face", no_args_is_help=True)
def cli():
    """Push a MAPS to, or pull a MAPS from, the Hugging Face hub."""
    # The docstring above is the help text displayed by click for this group;
    # it previously was a copy of the `train` command description.
    pass


# Sub-commands exposed as `hugging-face push` and `hugging-face pull`.
cli.add_command(push_cli)
cli.add_command(pull_cli)


if __name__ == "__main__":
    cli()
32 changes: 32 additions & 0 deletions clinicadl/hugging_face/pull_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pathlib import Path

import click

from clinicadl.utils import cli_param
from clinicadl.utils.maps_manager import MapsManager


@click.command(name="pull", no_args_is_help=True)
@click.argument(
    "hf_hub_path",
    type=str,
    default=None,
)
@click.argument(
    "maps_name",
    type=str,
    default="maps",
)
@cli_param.argument.output_maps
def cli(hf_hub_path, maps_name, output_maps_directory):
    """Download a MAPS from the Hugging Face hub.

    HF_HUB_PATH is the name of the Hugging Face user or organization hosting the repository.

    MAPS_NAME is the name of the repository to download.

    OUTPUT_MAPS_DIRECTORY is the local directory in which the files are written.
    """
    # Imported lazily so that loading the command line does not import the
    # Hugging Face helpers until the command is actually run.
    from .hugging_face import load_from_hf_hub

    load_from_hf_hub(
        output_maps=output_maps_directory,
        hf_hub_path=hf_hub_path,
        maps_name=maps_name,
    )


if __name__ == "__main__":
    cli()
33 changes: 33 additions & 0 deletions clinicadl/hugging_face/push_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import click

from clinicadl.utils import cli_param


@click.command(name="push", no_args_is_help=True)
@click.argument(
    "organization",
    type=str,
    default=None,
)
@cli_param.argument.input_maps
@click.argument(
    "hf_maps_directory",
    type=str,
    default=None,
)
def cli(
    organization,
    input_maps_directory,
    hf_maps_directory,
):
    """Upload a MAPS to the Hugging Face hub.

    ORGANIZATION is the name of the Hugging Face user or organization under which the repository is pushed.

    INPUT_MAPS_DIRECTORY is the local MAPS directory to upload.

    HF_MAPS_DIRECTORY is the name of the repository on the hub.
    """
    # Imported lazily so that loading the command line does not import the
    # Hugging Face helpers until the command is actually run.
    from .hugging_face import push_to_hf_hub

    push_to_hf_hub(
        hf_hub_path=organization,
        maps_dir=input_maps_directory,
        model_name=hf_maps_directory,
    )


if __name__ == "__main__":
    cli()
9 changes: 5 additions & 4 deletions clinicadl/utils/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def predict(

if cluster.master:
self._ensemble_prediction(
data_group, split, selection_metrics, use_labels
data_group, split, selection_metrics, use_labels, skip_leak_check
)

def interpret(
Expand Down Expand Up @@ -1992,6 +1992,7 @@ def _ensemble_prediction(
split,
selection_metrics,
use_labels=True,
skip_leak_check=False,
):
"""Computes the results on the image-level."""

Expand All @@ -2000,14 +2001,14 @@ def _ensemble_prediction(

for selection_metric in selection_metrics:
# Soft voting
if self.num_networks > 1:
if self.num_networks > 1 and not skip_leak_check:
self._ensemble_to_tsv(
split,
selection=selection_metric,
data_group=data_group,
use_labels=use_labels,
)
elif self.mode != "image":
elif self.mode != "image" and not skip_leak_check:
self._mode_to_image_tsv(
split,
selection=selection_metric,
Expand Down Expand Up @@ -2243,7 +2244,7 @@ def _check_data_group(
f"To erase {data_group} please set overwrite to True."
)

if not group_dir.is_dir() and (
elif not group_dir.is_dir() and (
caps_directory is None or df is None
): # Data group does not exist yet / was overwritten + missing data
raise ClinicaDLArgumentError(
Expand Down
Loading

0 comments on commit 972e94d

Please sign in to comment.