Skip to content

Commit

Permalink
alternative CLI-based model download
Browse files Browse the repository at this point in the history
  • Loading branch information
ejolly committed Mar 29, 2024
1 parent 497e76d commit 7bcad4a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
14 changes: 14 additions & 0 deletions feat/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def get_pretrained_models(
)
)

print("Downloading required models...")
# Face model
if face_model is None:
raise ValueError(
Expand Down Expand Up @@ -257,3 +258,16 @@ def fetch_model(model_type, model_name):
model_type = PRETRAINED_MODELS[model_type]
matches = list(filter(lambda e: model_name in e.keys(), model_type))[0]
return list(matches.values())[0]


def download_default_models():
"""Used by CLI to download all default models at once."""
face, landmark, au, emotion, facepose, identity = get_pretrained_models(
face_model="retinaface",
landmark_model="mobilefacenet",
au_model="xgb",
emotion_model="resmasknet",
facepose_model="img2pose",
identity_model="facenet",
verbose=False,
)
21 changes: 4 additions & 17 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
from feat.pretrained import get_pretrained_models

with open("requirements.txt") as f:
requirements = f.read().splitlines()
Expand All @@ -10,18 +9,6 @@

extra_setuptools_args = dict(tests_require=["pytest"])


def download_default_models():
face, landmark, au, emotion, facepose, identity = get_pretrained_models(
face_model="retinaface",
landmark_model="mobilefacenet",
au_model="xgb",
emotion_model="resmasknet",
facepose_model="img2pose",
identity_model="facenet",
)


setup(
name="py-feat",
version=version["__version__"],
Expand All @@ -44,10 +31,10 @@ def download_default_models():
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
extras_require={
"default_models": [
download_default_models,
],
entry_points={
"console_scripts": [
"feat_get_models=feat.pretrained:download_default_models",
]
},
test_suite="feat/tests",
**extra_setuptools_args
Expand Down

0 comments on commit 7bcad4a

Please sign in to comment.