Skip to content

Commit

Permalink
Merge pull request #223 from InnopolisUni/pose_estimation
Browse files Browse the repository at this point in the history
pose estimation dataset and model
  • Loading branch information
InnopolisUni authored Jun 11, 2024
2 parents 9b0d6ee + 23266b3 commit 8fc36fa
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 26 deletions.
31 changes: 31 additions & 0 deletions config/datasets/detection/coco_pose_recognition.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
task:
- image-detection

name: pose recognition dataset
description: "COCO Pose dataset for human pose recognitoon. See more: https://cocodataset.org/#keypoints-2017"
markup_info: There is a total of 2346 unique images, format compatible with ultralytics
date_time: 22.05.2024

_target_: innofw.core.integrations.ultralytics.datamodule.UltralyticsDataModuleAdapter

train:
source: https://api.blackhole.ai.innopolis.university/public-datasets/coco_pose_recognition/train.zip
target: ./data/coco_pose/train/

test:
source: https://api.blackhole.ai.innopolis.university/public-datasets/coco_pose_recognition/test.zip
target: ./data/coco_pose/test/

infer:
source: https://api.blackhole.ai.innopolis.university/public-datasets/coco_pose_recognition/test.zip
target: ./data/coco_pose/test/

num_workers: 8

val_size: 0.2
channels_num: 3
image_size: 512
num_classes: 1
is_keypoint: True
names:
- person
21 changes: 21 additions & 0 deletions config/experiments/detection/KG_220524_coco.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# @package _global_
defaults:
- override /datasets: detection/coco_pose_recognition.yaml
- override /models: detection/yolo8_pose.yaml

# devices: 1
epochs: 3
accelerator: gpu
batch_size: 32
project: coco_pose_recognition
random_seed: 0
task: image-detection
weights_freq: 1


wandb:
enable: True
project: coco_pose_recognition
entity: "k-galliamov"
group: none
job_type: training
4 changes: 4 additions & 0 deletions config/models/detection/yolo8_pose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: ultralytics.YOLO
description: yolov8 pose by ultralytics
model: yolov8n-pose.pt
name: yolov8
56 changes: 31 additions & 25 deletions innofw/core/integrations/ultralytics/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ class UltralyticsDataModuleAdapter(BaseDataModule):
framework = [Frameworks.ultralytics]

def __init__(
self,
train: Optional[str],
# val: Optional[str], # todo: add this
test: Optional[str],
infer: Optional[str],
num_workers: int,
image_size: int,
num_classes: int,
names: List[str],
batch_size: int = 4,
val_size: float = 0.2,
augmentations=None,
stage=False,
channels_num: int = 3,
random_state: int = 42,
*args,
**kwargs,
self,
train: Optional[str],
# val: Optional[str], # todo: add this
test: Optional[str],
infer: Optional[str],
num_workers: int,
image_size: int,
num_classes: int,
names: List[str],
batch_size: int = 4,
val_size: float = 0.2,
augmentations=None,
stage=False,
channels_num: int = 3,
random_state: int = 42,
*args,
**kwargs,
):
"""
Arguments:
Expand All @@ -73,7 +73,8 @@ def __init__(
if self.train:
self.train_source = Path(self.train)
# # In this datamodule, the train source should be the folder train itself not the folder "train/images"
if str(self.train_source).endswith("images") or str(self.train_source).endswith("labels"):
if str(self.train_source).endswith("images") or str(self.train_source).endswith(
"labels"):
self.train_source = Path(str(self.train_source)[:-7])
if self.test:
self.test_source = Path(self.test)
Expand All @@ -86,7 +87,8 @@ def __init__(
if not (type(self.infer) == str and self.infer.startswith("rts"))
else self.infer
)
if str(self.infer_source).endswith("images") or str(self.infer_source).endswith("labels"):
if str(self.infer_source).endswith("images") or str(self.infer_source).endswith(
"labels"):
self.infer_source = Path(str(self.infer_source)[:-7])

self.batch_size = batch_size
Expand All @@ -97,6 +99,7 @@ def __init__(
self.names = names
self.random_state = random_state
self.augmentations = augmentations
self.is_keypoint = kwargs.get("is_keypoint", False)

def setup_train_test_val(self, **kwargs):
"""
Expand Down Expand Up @@ -146,7 +149,7 @@ def setup_train_test_val(self, **kwargs):
img_files = list(train_img_path.iterdir())
label_files = list(train_lbl_path.iterdir())
assert (
len(label_files) == len(img_files) != 0
len(label_files) == len(img_files) != 0
), "number of images and labels should be the same"

# sort the files so that the images and labels are in the same order
Expand All @@ -168,7 +171,7 @@ def setup_train_test_val(self, **kwargs):

# Creating the images directory
for files, folder_name in zip(
[train_img_files, val_img_files], ["train", "val"]
[train_img_files, val_img_files], ["train", "val"]
):
# create a folder
new_path = new_img_path / folder_name
Expand All @@ -180,7 +183,7 @@ def setup_train_test_val(self, **kwargs):

# Creating the labels directory
for files, folder_name in zip(
[train_label_files, val_label_files], ["train", "val"]
[train_label_files, val_label_files], ["train", "val"]
):
# create a folder
new_path = new_lbl_path / folder_name
Expand All @@ -201,14 +204,17 @@ def setup_train_test_val(self, **kwargs):
file.write(f"val: {self.val_dataset}\n")
file.write(f"test: {self.test_dataset}\n")

if self.is_keypoint:
file.write(
"kpt_shape: [17, 3]\nflip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]\n")
file.write(f"nc: {self.num_classes}\n")
file.write(f"names: {self.names}\n")

def setup_infer(self):
if (
type(self.infer_source) == str
and self.infer_source.startswith("rts")
or Path(self.infer_source).is_file()
type(self.infer_source) == str
and self.infer_source.startswith("rts")
or Path(self.infer_source).is_file()
):
return
# root_dir
Expand Down
2 changes: 1 addition & 1 deletion pckg_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def install_and_import(package, version="", params="", link="", packageimportnam
[sys.executable, "-m", "pip", *installation_cmd_list]
)
finally:
if packageimportname is None:
if packageimportname is None or packageimportname == "":
globals()[package] = importlib.import_module(package)
else:
globals()[packageimportname] = importlib.import_module(packageimportname)
Expand Down

0 comments on commit 8fc36fa

Please sign in to comment.