Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset CLI for db/import tasks #26

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/create_db.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:
run: poetry install

- name: Run create-db
run: poetry run create-db
run: poetry run cli db create_schema
10 changes: 8 additions & 2 deletions ddlitlab2024/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os

from rich.logging import RichHandler

from ddlitlab2024 import LOGGING_PATH, SESSION_ID

MODULE_NAME: str = "dataset"
Expand All @@ -14,9 +16,13 @@
)

# Create additional logging config for the shell with configurable log level
console = logging.StreamHandler()
console = RichHandler(
log_time_format="%H:%M:%S",
show_path=False,
rich_tracebacks=True,
tracebacks_show_locals=True,
)
console.setLevel(os.environ.get("LOGLEVEL", "INFO"))
console.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

logger = logging.getLogger(MODULE_NAME)
logger.addHandler(console)
73 changes: 73 additions & 0 deletions ddlitlab2024/dataset/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import sys
from enum import Enum
from pathlib import Path

from tap import Tap

from ddlitlab2024 import DB_PATH


class ImportType(str, Enum):
ROS_BAG = "rosbag"


class CLICommand(str, Enum):
DB = "db"
IMPORT = "import"


class DBArgs(Tap):
create_schema: bool = False

def configure(self) -> None:
self.add_argument(
"create_schema",
type=bool,
help="Create the base database schema, if it doesn't exist",
nargs="?",
)


class ImportArgs(Tap):
import_type: ImportType
file: Path

def configure(self) -> None:
self.add_argument(
"import-type",
type=ImportType,
help="Type of import to perform",
)
self.add_argument(
"file",
type=Path,
help="File to import",
)


class CLIArgs(Tap):
dry_run: bool = False
db_path: str = DB_PATH # Path to the sqlite database file
version: bool = False # if set print version and exit

def __init__(self):
super().__init__(
description="ddlitlab dataset CLI",
underscores_to_dashes=True,
)

def configure(self) -> None:
self.add_subparsers(dest="command", help="Command to run")
self.add_subparser(CLICommand.DB.value, DBArgs, help="Database management commands")
self.add_subparser(CLICommand.IMPORT.value, ImportArgs, help="Import data into the database")

def print_help_and_exit(self) -> None:
self.print_help()
sys.exit(0)

def process_args(self) -> None:
if self.command == CLICommand.DB:
all_args = (self.create_schema,)

if not any(all_args):
self.print_help_and_exit()
33 changes: 33 additions & 0 deletions ddlitlab2024/dataset/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker

from ddlitlab2024.dataset import logger
from ddlitlab2024.dataset.models import Base


class Database:
def __init__(self, db_path: str):
self.db_path = db_path
self.engine: Engine = self._setup_sqlite()
self.session: Session | None = None

def _setup_sqlite(self) -> Engine:
return create_engine(f"sqlite:///{self.db_path}")

def _create_schema(self) -> None:
logger.info("Creating database schema")
Base.metadata.create_all(self.engine)
logger.info("Database schema created")

def create_session(self, create_schema: bool = True) -> Session:
logger.info("Setting up database session")
if create_schema:
self._create_schema()
return sessionmaker(bind=self.engine)()

def close_session(self) -> None:
if self.session:
self.session.close()
logger.info("Database session closed")
else:
logger.warning("No database session to close")
2 changes: 2 additions & 0 deletions ddlitlab2024/dataset/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class CLIArgumentError(Exception):
"""Raised when the configuration of CLI arguments is not valid and execution is impossible"""
39 changes: 39 additions & 0 deletions ddlitlab2024/dataset/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

import os
import sys

from rich.console import Console

from ddlitlab2024 import __version__
from ddlitlab2024.dataset import logger
from ddlitlab2024.dataset.cli import CLIArgs, CLICommand
from ddlitlab2024.dataset.db import Database

err_console = Console(stderr=True)


def main():
debug_mode = os.getenv("LOGLEVEL") == "DEBUG"

try:
logger.debug("Parsing CLI args...")
args: CLIArgs = CLIArgs().parse_args()
if args.version:
logger.info(f"running ddlitlab2024 CLI v{__version__}")
sys.exit(0)

if args.command == CLICommand.DB:
db = Database(args.db_path).create_session(args.create_schema)
logger.info(f"Database session created: {db}")

logger.info(f"CLI args: {args}")
sys.exit(0)
except Exception as e:
logger.error(e)
err_console.print_exception(show_locals=debug_mode)
sys.exit(1)


if __name__ == "__main__":
main()
27 changes: 2 additions & 25 deletions ddlitlab2024/dataset/schema.py → ddlitlab2024/dataset/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import argparse
from datetime import datetime
from enum import Enum
from typing import List, Optional

from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship, sessionmaker
from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship
from sqlalchemy.types import LargeBinary

from ddlitlab2024 import DB_PATH
from ddlitlab2024.dataset import logger

Base = declarative_base()


Expand Down Expand Up @@ -229,22 +225,3 @@ class GameState(Base):
recording: Mapped["Recording"] = relationship("Recording", back_populates="game_states")

__table_args__ = (CheckConstraint(state.in_(RobotState.values())),)


def parse_args():
parser = argparse.ArgumentParser(description="Create the database schema")
parser.add_argument("--db-path", type=str, default=DB_PATH, help="Path to the database file")
return parser.parse_args()


def main():
logger.info("Creating database schema")
args = parse_args()
engine = create_engine(f"sqlite:///{args.db_path}")
Base.metadata.create_all(engine)
sessionmaker(bind=engine)()
logger.info("Database schema created")


if __name__ == "__main__":
main()
9 changes: 8 additions & 1 deletion ddlitlab2024/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os

from rich.logging import RichHandler

from ddlitlab2024 import LOGGING_PATH, SESSION_ID

MODULE_NAME: str = "ml"
Expand All @@ -14,7 +16,12 @@
)

# Create additional logging config for the shell with configurable log level
console = logging.StreamHandler()
console = RichHandler(
log_time_format="%H:%M:%S",
show_path=False,
rich_tracebacks=True,
tracebacks_show_locals=True,
)
console.setLevel(os.environ.get("LOGLEVEL", "INFO"))
console.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

Expand Down
Loading
Loading