ENH: add prefix support
younik committed Nov 23, 2024
1 parent 1c01673 commit 4cfbf39
Showing 3 changed files with 34 additions and 19 deletions.
5 changes: 3 additions & 2 deletions minari/storage/hosting.py
@@ -5,7 +5,7 @@
 import warnings
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from minari.dataset.minari_dataset import gen_dataset_id, parse_dataset_id
 from minari.dataset.minari_storage import MinariStorage
@@ -180,6 +180,7 @@ def download_dataset(dataset_id: str, force_download: bool = False):
 def list_remote_datasets(
     latest_version: bool = False,
     compatible_minari_version: bool = False,
+    prefix: Optional[str] = None,
 ) -> Dict[str, Dict[str, str]]:
     """Get the names and metadata of all the Minari datasets in the remote Farama server.
@@ -200,7 +201,7 @@ def download_metadata(dataset_id):
         if supported_dataset or not compatible_minari_version:
             return metadata
 
-    dataset_ids = cloud_storage.list_datasets()
+    dataset_ids = cloud_storage.list_datasets(prefix=prefix)
    with ThreadPoolExecutor(max_workers=10) as executor:
         remote_metadatas = executor.map(download_metadata, dataset_ids)
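For context, a minimal usage sketch of the change above (not part of the commit): list_remote_datasets now forwards prefix to the cloud storage backend, so callers can narrow the remote listing to ids that start with a given string. The prefix value "D4RL/door" below is an illustrative dataset-id prefix, not something this commit defines.

# Illustrative sketch, assuming the remote server hosts dataset ids under a
# "D4RL/..." namespace; only ids starting with the prefix are returned.
from minari.storage.hosting import list_remote_datasets

remote = list_remote_datasets(prefix="D4RL/door")
for dataset_id in remote:
    print(dataset_id)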
16 changes: 10 additions & 6 deletions minari/storage/local.py
@@ -3,7 +3,7 @@
 import pathlib
 import shutil
 import warnings
-from typing import Dict, Iterable, Tuple, Union
+from typing import Dict, Iterable, Optional, Tuple, Union
 
 from minari.dataset.minari_dataset import (
     MinariDataset,
@@ -19,9 +19,9 @@
 __version__ = importlib.metadata.version("minari")
 
 
-def list_non_hidden_dirs(path: str) -> Iterable[str]:
+def list_non_hidden_dirs(path: pathlib.Path) -> Iterable[str]:
     """List all non-hidden subdirectories."""
-    for d in os.scandir(path):
+    for d in path.iterdir():
         if d.is_dir() and (not d.name.startswith(".")):
             yield d.name

@@ -60,6 +60,7 @@ def load_dataset(dataset_id: str, download: bool = False):
 def list_local_datasets(
     latest_version: bool = False,
     compatible_minari_version: bool = False,
+    prefix: Optional[str] = None,
 ) -> Dict[str, Dict[str, Union[str, int, bool]]]:
     """Get the ids and metadata of all the Minari datasets in the local database.
@@ -75,8 +76,11 @@ def list_local_datasets(
     datasets_path = get_dataset_path()
     dataset_ids = []
 
-    def recurse_directories(base_path, namespace):
-        parent_dir = os.path.join(base_path, namespace)
+    def recurse_directories(base_path: pathlib.Path, namespace):
+        parent_dir = base_path.joinpath(namespace)
+        if not parent_dir.exists():
+            return
+
         for dir_name in list_non_hidden_dirs(parent_dir):
             dir_path = os.path.join(parent_dir, dir_name)
             namespaced_dir_name = os.path.join(namespace, dir_name)
@@ -86,7 +90,7 @@ def recurse_directories(base_path, namespace):
         else:
             recurse_directories(base_path, namespaced_dir_name)
 
-    recurse_directories(datasets_path, "")
+    recurse_directories(datasets_path, prefix or "")
 
     dataset_ids = sorted(dataset_ids, key=dataset_id_sort_key)
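A companion sketch for the local side (again illustrative, not from the commit): here prefix is passed as the starting namespace, so the directory walk begins at the matching subdirectory of the local dataset root and the argument behaves like a namespace filter. "D4RL" is a hypothetical namespace directory.

# Illustrative sketch: list only local datasets under the "D4RL" namespace;
# with prefix=None the walk starts at the dataset root, as before.
from minari.storage.local import list_local_datasets

local = list_local_datasets(prefix="D4RL")
for dataset_id in sorted(local):
    print(dataset_id)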
32 changes: 21 additions & 11 deletions minari/storage/remotes/huggingface.py
@@ -28,8 +28,8 @@ def __init__(self, name: str, token: Optional[str] = None) -> None:
         self._api = HfApi(token=token)
 
     def _decompose_path(self, path: str) -> Tuple[str, str]:
-        root, *rem = path.split("/")
-        return root, "/".join(rem)
+        root, *rem = path.split('/')
+        return root, '/'.join(rem)
 
     def upload_dataset(self, dataset_id: str) -> None:
         path = get_dataset_path(dataset_id)
@@ -105,31 +105,41 @@ def upload_namespace(self, namespace: str) -> None:
         )
 
     def list_datasets(self, prefix: Optional[str] = None) -> Iterable[str]:
-        if prefix is not None:  # TODO: support prefix
-            raise NotImplementedError("prefix is not supported yet")
-
-        for hf_dataset in self._api.list_datasets(author=self.name):
+        if prefix is not None:
+            group_name, _ = self._decompose_path(prefix)
+        else:
+            prefix = ''
+            group_name = None
+
+        hf_datasets = self._api.list_datasets(
+            author=self.name,
+            dataset_name=group_name
+        )
+        for group_info in hf_datasets:
             try:
                 repo_metadata = self._api.hf_hub_download(
-                    repo_id=hf_dataset.id,
+                    repo_id=group_info.id,
                     filename=_NAMESPACE_METADATA_FILENAME,
                     repo_type="dataset",
                 )
             except EntryNotFoundError:
                 try:
                     self._api.hf_hub_download(
-                        repo_id=hf_dataset.id,
+                        repo_id=group_info.id,
                         filename=f"data/{METADATA_FILE_NAME}",
                         repo_type="dataset",
                     )
-                    yield hf_dataset.id
+                    if group_info.id.startswith(prefix):
+                        yield group_info.id
                 except Exception:
-                    warnings.warn(f"Skipping {hf_dataset.id} as it is malformed.")
+                    warnings.warn(f"Skipping {group_info.id} as it is malformed.")
             else:
                 with open(repo_metadata) as f:
                     namespace_metadata = json.load(f)
 
-                yield from namespace_metadata.get("datasets", [])
+                group_datasets = namespace_metadata.get("datasets", [])
+                group_datasets = filter(lambda x: x.startswith(prefix), group_datasets)
+                yield from group_datasets
 
     def download_dataset(self, dataset_id: Any, path: Path) -> None:
         repo_id, path_in_repo = self._decompose_path(dataset_id)
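To make the Hugging Face logic above easier to follow, here is a standalone sketch of the prefix handling (mirroring _decompose_path rather than importing it; the dataset ids are illustrative): the first path segment of the prefix selects which Hugging Face repo to query via dataset_name, and the full prefix then filters the returned ids with str.startswith.

from typing import Tuple

def decompose_path(path: str) -> Tuple[str, str]:
    # Split "group/rest/of/id" into ("group", "rest/of/id").
    root, *rem = path.split("/")
    return root, "/".join(rem)

# The group name picks the HF dataset repo to query...
assert decompose_path("D4RL/door/human-v2") == ("D4RL", "door/human-v2")
# ...and a bare group prefix matches everything in that repo.
assert decompose_path("D4RL") == ("D4RL", "")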
