Skip to content

Commit

Permalink
Merge branch 'main' into feature/torch_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Nov 7, 2024
2 parents a5dd2f1 + 8b7e3d2 commit e346f14
Show file tree
Hide file tree
Showing 14 changed files with 998 additions and 72 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Test Database Creation
name: Create DB with dummy data

on:
push:
Expand Down Expand Up @@ -29,5 +29,8 @@ jobs:
- name: Install dependencies
run: poetry install

- name: Run create-db
run: poetry run cli db create_schema
- name: Create DB schema
run: poetry run cli db create-schema

- name: Populate DB with dummy data
run: poetry run cli db dummy-data -n 2
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ logs/
*.csv
*.sqlite
*.sqlite3
*.mcap
*/metadata.yaml

# Created by .ignore support plugin (hsz.mobi)
### JetBrains template
Expand Down
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
{
"cSpell.words": [
"bigendian",
"ddlitlab",
"mcap",
"nanosec",
"nullable",
"rclpy",
"rosbag",
"sessionmaker",
"sqlalchemy",
"sqlite"
Expand Down
3 changes: 2 additions & 1 deletion ddlitlab2024/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib.metadata
import os
import sys
from pathlib import Path
from uuid import UUID, uuid4

_project_name: str = "ddlitlab2024"
Expand Down Expand Up @@ -39,4 +40,4 @@

SESSION_ID: UUID = uuid4()

DB_PATH: str = os.path.join(os.path.dirname(__file__), "dataset", "db.sqlite3")
DB_PATH: Path = Path.joinpath(Path(__file__).parent, "dataset", "db.sqlite3")
72 changes: 22 additions & 50 deletions ddlitlab2024/dataset/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import sys
import argparse
from enum import Enum
from pathlib import Path

from tap import Tap

from ddlitlab2024 import DB_PATH


Expand All @@ -16,58 +14,32 @@ class CLICommand(str, Enum):
IMPORT = "import"


class DBArgs(Tap):
create_schema: bool = False

def configure(self) -> None:
self.add_argument(
"create_schema",
type=bool,
help="Create the base database schema, if it doesn't exist",
nargs="?",
)
class CLIArgs:
def __init__(self):
self.parser = argparse.ArgumentParser(description="ddlitlab dataset CLI")

self.parser.add_argument("--dry-run", action="store_true", help="Dry run")
self.parser.add_argument("--db-path", type=Path, default=DB_PATH, help="Path to the sqlite database file")
self.parser.add_argument("--version", action="store_true", help="Print version and exit")

class ImportArgs(Tap):
import_type: ImportType
file: Path
subparsers = self.parser.add_subparsers(dest="command", help="Command to run")
# import_parser = subparsers.add_parser(CLICommand.IMPORT.value, help="Import data into the database")

def configure(self) -> None:
self.add_argument(
"import-type",
type=ImportType,
help="Type of import to perform",
)
self.add_argument(
"file",
type=Path,
help="File to import",
)
db_parser = subparsers.add_parser(CLICommand.DB.value, help="Database management commands")
db_subcommand_parser = db_parser.add_subparsers(dest="db_command", help="Database command")

db_subcommand_parser.add_parser("create-schema", help="Create the base database schema, if it doesn't exist.")

class CLIArgs(Tap):
dry_run: bool = False
db_path: str = DB_PATH # Path to the sqlite database file
version: bool = False # if set print version and exit

def __init__(self):
super().__init__(
description="ddlitlab dataset CLI",
underscores_to_dashes=True,
dummy_data_subparser = db_subcommand_parser.add_parser("dummy-data", help="Insert dummy data into the database")
dummy_data_subparser.add_argument(
"-n", "--num_recordings", type=int, default=10, help="Number of recordings to insert"
)

def configure(self) -> None:
self.add_subparsers(dest="command", help="Command to run")
self.add_subparser(CLICommand.DB.value, DBArgs, help="Database management commands")
self.add_subparser(CLICommand.IMPORT.value, ImportArgs, help="Import data into the database")

def print_help_and_exit(self) -> None:
self.print_help()
sys.exit(0)

def process_args(self) -> None:
if self.command == CLICommand.DB:
all_args = (self.create_schema,)
recording2mcap_subparser = db_subcommand_parser.add_parser(
"recording2mcap", help="Convert a recording to an mcap file"
)
recording2mcap_subparser.add_argument("recording", type=str, help="Recording to convert")
recording2mcap_subparser.add_argument("output_dir", type=Path, help="Output directory to write to")

if not any(all_args):
self.print_help_and_exit()
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
4 changes: 3 additions & 1 deletion ddlitlab2024/dataset/db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker

Expand All @@ -6,7 +8,7 @@


class Database:
def __init__(self, db_path: str):
def __init__(self, db_path: Path):
self.db_path = db_path
self.engine: Engine = self._setup_sqlite()
self.session: Session | None = None
Expand Down
161 changes: 161 additions & 0 deletions ddlitlab2024/dataset/dummy_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import datetime
import math
import random

import numpy as np
from sqlalchemy.orm import Session

from ddlitlab2024.dataset import logger
from ddlitlab2024.dataset.models import (
GameState,
Image,
JointCommand,
JointState,
Recording,
RobotState,
Rotation,
TeamColor,
)


def insert_recordings(db: Session, n) -> list[int]:
logger.debug("Inserting recordings...")
for i in range(n):
db.add(
Recording(
allow_public=True,
original_file=f"dummy_original_file{i}",
team_name=f"dummy_team_name{i}",
team_color=random.choice(list(TeamColor)),
robot_type=f"dummy_robot_type{i}",
start_time=datetime.datetime.now(),
location=f"dummy_location{i}",
simulated=True,
img_width_scaling=1.0,
img_height_scaling=1.0,
),
)
db.flush() # Ensure the recording is written to the database and the ID is generated
recording = db.query(Recording).order_by(Recording._id.desc()).limit(n).all()
if recording is None:
raise ValueError("Failed to insert recordings")
return [r._id for r in reversed(recording)]


def insert_images(db: Session, recording_ids: list[int], n: int) -> None:
for recording_id in recording_ids:
# Get width and height from the recording
recording = db.query(Recording).get(recording_id)
if recording is None:
raise ValueError(f"Recording '{recording_id}' not found")
for i in range(n):
db.add(
Image(
stamp=float(i),
recording_id=recording_id,
data=np.random.randint(
0, 255, (recording.img_height, recording.img_width, 3), dtype=np.uint8
).tobytes(),
)
)


def insert_rotations(db: Session, recording_ids: list[int], n: int) -> None:
for recording_id in recording_ids:
for i in range(n):
db.add(
Rotation(
stamp=float(i),
recording_id=recording_id,
x=random.random(),
y=random.random(),
z=random.random(),
w=random.random(),
),
)


def insert_joint_states(db: Session, recording_ids: list[int], n: int) -> None:
for recording_id in recording_ids:
for i in range(n):
db.add(
JointState(
stamp=float(i),
recording_id=recording_id,
r_shoulder_pitch=random.random() * 2 * math.pi,
l_shoulder_pitch=random.random() * 2 * math.pi,
r_shoulder_roll=random.random() * 2 * math.pi,
l_shoulder_roll=random.random() * 2 * math.pi,
r_elbow=random.random() * 2 * math.pi,
l_elbow=random.random() * 2 * math.pi,
r_hip_yaw=random.random() * 2 * math.pi,
l_hip_yaw=random.random() * 2 * math.pi,
r_hip_roll=random.random() * 2 * math.pi,
l_hip_roll=random.random() * 2 * math.pi,
r_hip_pitch=random.random() * 2 * math.pi,
l_hip_pitch=random.random() * 2 * math.pi,
r_knee=random.random() * 2 * math.pi,
l_knee=random.random() * 2 * math.pi,
r_ankle_pitch=random.random() * 2 * math.pi,
l_ankle_pitch=random.random() * 2 * math.pi,
r_ankle_roll=random.random() * 2 * math.pi,
l_ankle_roll=random.random() * 2 * math.pi,
head_pan=random.random() * 2 * math.pi,
head_tilt=random.random() * 2 * math.pi,
),
)


def insert_joint_commands(db: Session, recording_ids: list[int], n: int) -> None:
for recording_id in recording_ids:
for i in range(n):
db.add(
JointCommand(
stamp=float(i),
recording_id=recording_id,
r_shoulder_pitch=random.random() * 2 * math.pi,
l_shoulder_pitch=random.random() * 2 * math.pi,
r_shoulder_roll=random.random() * 2 * math.pi,
l_shoulder_roll=random.random() * 2 * math.pi,
r_elbow=random.random() * 2 * math.pi,
l_elbow=random.random() * 2 * math.pi,
r_hip_yaw=random.random() * 2 * math.pi,
l_hip_yaw=random.random() * 2 * math.pi,
r_hip_roll=random.random() * 2 * math.pi,
l_hip_roll=random.random() * 2 * math.pi,
r_hip_pitch=random.random() * 2 * math.pi,
l_hip_pitch=random.random() * 2 * math.pi,
r_knee=random.random() * 2 * math.pi,
l_knee=random.random() * 2 * math.pi,
r_ankle_pitch=random.random() * 2 * math.pi,
l_ankle_pitch=random.random() * 2 * math.pi,
r_ankle_roll=random.random() * 2 * math.pi,
l_ankle_roll=random.random() * 2 * math.pi,
head_pan=random.random() * 2 * math.pi,
head_tilt=random.random() * 2 * math.pi,
),
)


def insert_game_states(db: Session, recording_ids: list[int], n: int) -> None:
for recording_id in recording_ids:
for i in range(n):
db.add(
GameState(
stamp=float(i),
recording_id=recording_id,
state=random.choice(list(RobotState)),
),
)


def insert_dummy_data(db: Session, n: int = 10) -> None:
logger.info("Inserting dummy data...")
recording_ids: list[int] = insert_recordings(db, n)
insert_images(db, recording_ids, n)
insert_rotations(db, recording_ids, n)
insert_joint_states(db, recording_ids, n)
insert_joint_commands(db, recording_ids, n)
insert_game_states(db, recording_ids, n)
db.commit()
logger.info(f"Dummy data inserted. Recording IDs: {recording_ids}")
Loading

0 comments on commit e346f14

Please sign in to comment.