Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Jan 10, 2024
1 parent f2b2e45 commit 0779f42
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
30 changes: 17 additions & 13 deletions src/transformers/models/oneformer/image_processing_oneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,21 @@ def get_oneformer_resize_output_image_size(
return output_size


def prepare_metadata(repo_id, class_info_file):
def prepare_metadata(class_info):
metadata = {}
class_names = []
thing_ids = []
for key, info in class_info.items():
metadata[key] = info["name"]
class_names.append(info["name"])
if info["isthing"]:
thing_ids.append(int(key))
metadata["thing_ids"] = thing_ids
metadata["class_names"] = class_names
return metadata


def load_metadata(repo_id, class_info_file):
fname = os.path.join("" if repo_id is None else repo_id, class_info_file)

if not os.path.exists(fname) or not os.path.isfile(fname):
Expand All @@ -346,17 +360,7 @@ def prepare_metadata(repo_id, class_info_file):
with open(fname, "r") as f:
class_info = json.load(f)

metadata = {}
class_names = []
thing_ids = []
for key, info in class_info.items():
metadata[key] = info["name"]
class_names.append(info["name"])
if info["isthing"]:
thing_ids.append(int(key))
metadata["thing_ids"] = thing_ids
metadata["class_names"] = class_names
return metadata
return class_info


class OneFormerImageProcessor(BaseImageProcessor):
Expand Down Expand Up @@ -458,7 +462,7 @@ def __init__(
self.do_reduce_labels = do_reduce_labels
self.class_info_file = class_info_file
self.repo_path = repo_path
self.metadata = prepare_metadata(repo_path, class_info_file)
self.metadata = prepare_metadata(load_metadata(repo_path, class_info_file))
self.num_text = num_text

def resize(
Expand Down
10 changes: 4 additions & 6 deletions tests/models/oneformer/test_image_processing_oneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(
self.image_mean = image_mean
self.image_std = image_std
self.class_info_file = class_info_file
self.metadata = prepare_metadata(repo_path, class_info_file)
self.num_text = num_text
self.repo_path = repo_path

Expand All @@ -95,7 +94,6 @@ def prepare_image_processor_dict(self):
"do_reduce_labels": self.do_reduce_labels,
"ignore_index": self.ignore_index,
"class_info_file": self.class_info_file,
"metadata": self.metadata,
"num_text": self.num_text,
}

Expand Down Expand Up @@ -320,21 +318,21 @@ def test_post_process_panoptic_segmentation(self):

def test_can_load_with_local_metadata(self):
# Create a temporary json file
metadata = {
class_info = {
"0": {"isthing": 0, "name": "foo"},
"1": {"isthing": 0, "name": "bar"},
"2": {"isthing": 1, "name": "baz"},
}
metadata = prepare_metadata(class_info)

with tempfile.TemporaryDirectory() as tmpdirname:
metadata_path = os.path.join(tmpdirname, "metadata.json")
with open(metadata_path, "w") as f:
json.dump(metadata, f)
json.dump(class_info, f)

config_dict = self.image_processor_dict()
config_dict = self.image_processor_dict
config_dict["class_info_file"] = metadata_path
config_dict["repo_path"] = tmpdirname

image_processor = self.image_processing_class(**config_dict)

self.assertEqual(image_processor.metadata, metadata)

0 comments on commit 0779f42

Please sign in to comment.