Skip to content

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Dec 4, 2023
1 parent 067c422 commit e20855a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 60 deletions.
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,23 @@ python src/emma_perception/commands/run_server.py

### Extracting features

For training things, we need to extract the features for each image.

#### For the pretrained datasets
Here's the command you can use to extract features from images. Obviously, you can change the paths to the folder of images, and the output dir, and whatever else you want.

```bash
python src/emma_perception/commands/extract_visual_features.py --images_dir <path_to_images> --output_dir <path to output dir>
```

<details>
<summary>`argparse` arguments for the command</summary>

#### For the Alexa Arena

</details>

```bash
```
#### Extracting features for the Alexa Arena

If you want to use the fine-tuned model to extract features with the model we trained on the Alexa Arena, just add `--is_arena` onto the above command.

### Developer tooling

Expand Down
2 changes: 1 addition & 1 deletion src/emma_perception/commands/download_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def download_arena_checkpoint(


def download_vinvl_checkpoint(
*, hf_repo_id: str = HF_REPO_ID, file_name: str = ARENA_CHECKPOINT_NAME
*, hf_repo_id: str = HF_REPO_ID, file_name: str = VINVL_CHECKPOINT_NAME
) -> Path:
"""Download the pre-trained VinVL checkpoint."""
file_path = download_file(repo_id=hf_repo_id, repo_type="model", filename=file_name)
Expand Down
85 changes: 30 additions & 55 deletions src/emma_perception/commands/extract_visual_features.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,45 @@
import argparse
from typing import Union
from ast import arg

Check failure on line 2 in src/emma_perception/commands/extract_visual_features.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/emma_perception/commands/extract_visual_features.py#L2

F401 'ast.arg' imported but unused
Raw output
./src/emma_perception/commands/extract_visual_features.py:2:1: F401 'ast.arg' imported but unused

from maskrcnn_benchmark.config import cfg
from pytorch_lightning import Trainer
from scene_graph_benchmark.config import sg_cfg

from emma_perception.callbacks.callbacks import VisualExtractionCacheCallback
from emma_perception.datamodules.visual_extraction_dataset import (
ImageDataset,
PredictDataModule,
VideoFrameDataset,
from emma_perception.commands.download_checkpoints import (
download_arena_checkpoint,
download_vinvl_checkpoint,
)
from emma_perception.datamodules.visual_extraction_dataset import ImageDataset, PredictDataModule
from emma_perception.models.vinvl_extractor import VinVLExtractor, VinVLTransform


def parse_args() -> argparse.Namespace:
"""Defines arguments."""
parser = argparse.ArgumentParser(prog="PROG")

parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser) # type: ignore[assignment]
parser.add_argument("-i", "--input_path", required=True, help="Path to input dataset")
parser.add_argument("-b", "--batch_size", type=int, default=2)
parser.add_argument("-w", "--num_workers", type=int, default=0)
parser.add_argument("-cs", "--cache_suffix", default=".pt", help="Extension of cached files")
parser.add_argument("--config_file", metavar="FILE", help="path to VinVL config file")
parser.add_argument("--return_predictions", action="store_true")

parser.add_argument(
"--downsample",
type=int,
default=0,
help="Downsampling factor for videos. If 0 then no downsampling is performed",
)

parser.add_argument(
"-c", "--cache", default="storage/data/cache", help="Path to store visual features"
"-i",
"--images_dir",
required=True,
help="Path to a folder of images to extract features from",
)

parser.add_argument(
"-d", "--dataset", required=True, choices=["images", "frames"], help="Dataset type"
"--is_arena",
action="store_true",
help="If we are extracting features from the Arena images, use the Arena checkpoint",
)
parser.add_argument("-b", "--batch_size", type=int, default=2)
parser.add_argument("-w", "--num_workers", type=int, default=0)
parser.add_argument(
"-a",
"--ann_csv",
help="Path to annotation csv file. Used for video datasets to select only the frames that have annotations",
"-c", "--output_dir", default="storage/data/cache", help="Path to store visual features"
)

parser.add_argument(
"-at",
"--ann_type",
choices=["epic_kitchens"],
default="epic_kitchens",
help="Annotation parser for video datasets",
"--num_gpus",
type=int,
default=None,
help="Number of GPUs to use for visual feature extraction",
)

parser.add_argument(
"opts",
default=None,
Expand All @@ -75,34 +61,23 @@ def main() -> None:
cfg.merge_from_list(args.opts)
cfg.freeze()

extractor = VinVLExtractor(cfg=cfg)
transform = VinVLTransform(cfg=cfg)

dataset: Union[ImageDataset, VideoFrameDataset]
if args.dataset == "images":
dataset = ImageDataset(input_path=args.input_path, preprocess_transform=transform)
elif args.dataset == "frames":
dataset = VideoFrameDataset(
input_path=args.input_path,
ann_csv=args.ann_csv,
ann_type=args.ann_type,
preprocess_transform=transform,
downsample=args.downsample,
)
if args.is_arena:
cfg.MODEL.WEIGHT = download_arena_checkpoint().as_posix()
else:
raise OSError(f"Unsupported dataset type {args.dataset}")
cfg.MODEL.WEIGHT = download_vinvl_checkpoint().as_posix()

dataset = ImageDataset(
input_path=args.images_dir, preprocess_transform=VinVLTransform(cfg=cfg)
)
dm = PredictDataModule(
dataset=dataset, batch_size=args.batch_size, num_workers=args.num_workers
)
extractor = VinVLExtractor(cfg=cfg)
trainer = Trainer(
gpus=args.gpus,
callbacks=[
VisualExtractionCacheCallback(cache_dir=args.cache, cache_suffix=args.cache_suffix)
],
profiler="advanced",
gpus=args.num_gpus,
callbacks=[VisualExtractionCacheCallback(cache_dir=args.output_dir, cache_suffix=".pt")],
)
trainer.predict(extractor, dm, return_predictions=args.return_predictions)
trainer.predict(extractor, dm)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions src/emma_perception/constants/vinvl_x152c4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ MODEL:
SCORE_THRESH: 0.2 # 0.0001
DETECTIONS_PER_IMG: 36 # 600
MIN_DETECTIONS_PER_IMG: 10
NMS_FILTER: 1
ROI_BOX_HEAD:
NUM_CLASSES: 1595
ROI_ATTRIBUTE_HEAD:
Expand Down Expand Up @@ -52,6 +53,7 @@ TEST:
TSV_SAVE_SUBSET: ["rect", "class", "conf", "feature"]
OUTPUT_FEATURE: True
GATHER_ON_CPU: True
IGNORE_BOX_REGRESSION: False
OUTPUT_DIR: "./output/X152C5_test"
DATA_DIR: "./datasets"
DISTRIBUTED_BACKEND: "gloo"

0 comments on commit e20855a

Please sign in to comment.