Skip to content

Commit

Permalink
fixes and make it work
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Dec 4, 2023
1 parent 16713a1 commit 8026b4c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 206 deletions.
Empty file removed notebooks/.gitkeep
Empty file.
171 changes: 0 additions & 171 deletions notebooks/demo_inference.ipynb

This file was deleted.

35 changes: 33 additions & 2 deletions src/emma_perception/api/datamodels.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
import argparse
import json
from dataclasses import dataclass
from typing import Literal, Optional
from typing import Literal, Optional, TypedDict

import torch
from pydantic import BaseSettings

from emma_perception.constants import (
VINVL_ALFRED_CLASS_MAP_PATH,
VINVL_CLASS_MAP_PATH,
VINVL_SIMBOT_CLASS_MAP_PATH,
VINVL_SIMBOT_CONFIG_PATH,
)
from emma_perception.models.simbot_entity_classifier import SimBotMLPEntityClassifier
from emma_perception.models.vinvl_extractor import VinVLExtractor, VinVLTransform


ClassmapType = Literal["alfred", "original", "simbot"]


class AlfredClassMap(TypedDict):
"""Classmap for class to idx."""

label_to_idx: dict[str, int]
idx_to_label: dict[str, str]


class ApiSettings(BaseSettings):
"""Common settings, which can also be got from the environment vars."""

Expand All @@ -26,6 +40,20 @@ class ApiSettings(BaseSettings):
# batch size used to extract visual features
batch_size: int = 2

def object_classmap(self) -> AlfredClassMap:
"""Get the mapping of objects to class indices."""
if self.classmap_type == "alfred":
classmap_file = VINVL_ALFRED_CLASS_MAP_PATH
elif self.classmap_type == "original":
classmap_file = VINVL_CLASS_MAP_PATH
elif self.classmap_type == "simbot":
classmap_file = VINVL_SIMBOT_CLASS_MAP_PATH
else:
raise ValueError(f"Invalid classmap type: {self.classmap_type}")

with open(classmap_file) as in_file:
return json.load(in_file)


@dataclass(init=False)
class ApiStore:
Expand All @@ -41,7 +69,10 @@ def parse_api_args() -> argparse.Namespace:
"""Defines arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file", required=True, metavar="FILE", help="path to VinVL config file"
"--config_file",
metavar="FILE",
help="path to VinVL config file",
default=VINVL_SIMBOT_CONFIG_PATH.as_posix(),
)
parser.add_argument(
"opts",
Expand Down
9 changes: 6 additions & 3 deletions src/emma_perception/api/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from torch.utils.data import DataLoader

from emma_perception.api.api_dataset import ApiDataset
from emma_perception.api.datamodels import ApiStore
from emma_perception.constants import OBJECT_CLASSMAP
from emma_perception.api.datamodels import ApiSettings, ApiStore
from emma_perception.models.simbot_entity_classifier import SimBotMLPEntityClassifier
from emma_perception.models.vinvl_extractor import VinVLExtractor


settings = ApiSettings()
object_classmap = settings.object_classmap()


@torch.inference_mode()
def get_batch_features(
extractor: VinVLExtractor,
Expand Down Expand Up @@ -40,7 +43,7 @@ def get_batch_features(

bbox_probas = predictions.get_field("scores_all")
idx_labels = bbox_probas.argmax(dim=1)
class_labels = [OBJECT_CLASSMAP["idx_to_label"][str(idx.item())] for idx in idx_labels]
class_labels = [object_classmap["idx_to_label"][str(idx.item())] for idx in idx_labels]
entity_labels = None
bbox_features = predictions.get_field("box_features")

Expand Down
2 changes: 1 addition & 1 deletion src/emma_perception/commands/download_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def download_vinvl_checkpoint(
*, hf_repo_id: str = HF_REPO_ID, file_name: str = CHECKPOINT_NAME
) -> Path:
"""Download the checkpoint from VinVL and put it where we expect it."""
file_path = download_file(hf_repo_id=hf_repo_id, file_name=file_name)
file_path = download_file(repo_id=hf_repo_id, repo_type="model", filename=file_name)
logger.info(f"Downloaded {file_name}")
return file_path
30 changes: 1 addition & 29 deletions src/emma_perception/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import json
from pathlib import Path
from typing import TypedDict

from emma_perception.api.datamodels import ApiSettings, ClassmapType


CONSTANTS_DIR_PATH = Path(__file__).parent.resolve()


VINVL_CONFIG_PATH = CONSTANTS_DIR_PATH.joinpath("vinvl_x152c4.yaml")
VINVL_ALFRED_CONFIG_PATH = CONSTANTS_DIR_PATH.joinpath("vinvl_x152c4_alfred.yaml")


class AlfredClassMap(TypedDict):
label_to_idx: dict[str, int]
idx_to_label: dict[str, str]
VINVL_SIMBOT_CONFIG_PATH = CONSTANTS_DIR_PATH.joinpath("vinvl_x152c4_simbot_customised.yaml")


VINVL_ALFRED_CLASS_MAP_PATH = CONSTANTS_DIR_PATH.joinpath("vinvl_x152c4_alfred_classmap.json")
Expand All @@ -29,23 +21,3 @@ class AlfredClassMap(TypedDict):
SIMBOT_ENTITY_MLPCLASSIFIER_CLASSMAP_PATH = CONSTANTS_DIR_PATH.joinpath(
"entity_classlabel_map.json"
)


def _classmap(classmap_type: ClassmapType) -> AlfredClassMap:
# Returns the map that will be used by the object detector to determine the object class.
if classmap_type == "alfred":
classmap_file = VINVL_ALFRED_CLASS_MAP_PATH
elif classmap_type == "original":
classmap_file = VINVL_CLASS_MAP_PATH
elif classmap_type == "simbot":
classmap_file = VINVL_SIMBOT_CLASS_MAP_PATH
else:
raise ValueError(f"Invalid classmap type: {classmap_type}")

with open(classmap_file) as in_file:
return json.load(in_file)


settings = ApiSettings()

OBJECT_CLASSMAP = _classmap(settings.classmap_type)

0 comments on commit 8026b4c

Please sign in to comment.