From b6b2a5663920d2c4b38d82127a487b2201e3b793 Mon Sep 17 00:00:00 2001 From: texhnolyze Date: Thu, 31 Oct 2024 14:25:12 +0100 Subject: [PATCH 1/3] refactor(logging): add rich logging handler for nicer outputs and the possibility of better exception handling on exceptions showing code context and locals --- ddlitlab2024/dataset/__init__.py | 10 ++- ddlitlab2024/ml/__init__.py | 9 ++- poetry.lock | 130 +++++++++++++++++++++++-------- pyproject.toml | 1 + 4 files changed, 116 insertions(+), 34 deletions(-) diff --git a/ddlitlab2024/dataset/__init__.py b/ddlitlab2024/dataset/__init__.py index e19abb4..b3cc01f 100644 --- a/ddlitlab2024/dataset/__init__.py +++ b/ddlitlab2024/dataset/__init__.py @@ -1,6 +1,8 @@ import logging import os +from rich.logging import RichHandler + from ddlitlab2024 import LOGGING_PATH, SESSION_ID MODULE_NAME: str = "dataset" @@ -14,9 +16,13 @@ ) # Create additional logging config for the shell with configurable log level -console = logging.StreamHandler() +console = RichHandler( + log_time_format="%H:%M:%S", + show_path=False, + rich_tracebacks=True, + tracebacks_show_locals=True, +) console.setLevel(os.environ.get("LOGLEVEL", "INFO")) -console.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) logger = logging.getLogger(MODULE_NAME) logger.addHandler(console) diff --git a/ddlitlab2024/ml/__init__.py b/ddlitlab2024/ml/__init__.py index 4c4f76d..3a37a2b 100644 --- a/ddlitlab2024/ml/__init__.py +++ b/ddlitlab2024/ml/__init__.py @@ -1,6 +1,8 @@ import logging import os +from rich.logging import RichHandler + from ddlitlab2024 import LOGGING_PATH, SESSION_ID MODULE_NAME: str = "ml" @@ -14,7 +16,12 @@ ) # Create additional logging config for the shell with configurable log level -console = logging.StreamHandler() +console = RichHandler( + log_time_format="%H:%M:%S", + show_path=False, + rich_tracebacks=True, + tracebacks_show_locals=True, +) console.setLevel(os.environ.get("LOGLEVEL", "INFO")) console.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) diff --git a/poetry.lock b/poetry.lock index 580e1e1..a30e391 100644 --- a/poetry.lock +++ b/poetry.lock @@ -493,13 +493,13 @@ test = ["objgraph", "psutil"] [[package]] name = "huggingface-hub" -version = "0.26.1" +version = "0.26.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.26.1-py3-none-any.whl", hash = "sha256:5927a8fc64ae68859cd954b7cc29d1c8390a5e15caba6d3d349c973be8fdacf3"}, - {file = "huggingface_hub-0.26.1.tar.gz", hash = "sha256:414c0d9b769eecc86c70f9d939d0f48bb28e8461dd1130021542eff0212db890"}, + {file = "huggingface_hub-0.26.2-py3-none-any.whl", hash = "sha256:98c2a5a8e786c7b2cb6fdeb2740893cba4d53e312572ed3d8afafda65b128c46"}, + {file = "huggingface_hub-0.26.2.tar.gz", hash = "sha256:b100d853465d965733964d123939ba287da60a547087783ddff8a323f340332b"}, ] [package.dependencies] @@ -702,6 +702,30 @@ files = [ {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markupsafe" version = "3.0.2" @@ -835,6 +859,17 @@ python-dateutil = ">=2.7" [package.extras] dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1277,6 +1312,20 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa typing = ["typing-extensions"] xmp = ["defusedxml"] +[[package]] +name = "pygments" +version = "2.18.0" +description = "Pygments is a syntax highlighting package written in Python." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pyparsing" version = "3.2.0" @@ -1502,6 +1551,25 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.9.3" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, + {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "safetensors" version = "0.4.5" @@ -1636,23 +1704,23 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] [[package]] name = "setuptools" -version = "75.2.0" +version = "75.3.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, - {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, + {file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"}, + {file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"}, ] [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"] [[package]] name = "six" @@ -1779,28 +1847,28 @@ dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "torch" -version = "2.5.0" +version = "2.5.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.5.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7f179373a047b947dec448243f4e6598a1c960fa3bb978a9a7eecd529fbc363f"}, - {file = "torch-2.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15fbc95e38d330e5b0ef1593b7bc0a19f30e5bdad76895a5cffa1a6a044235e9"}, - {file = "torch-2.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:f499212f1cffea5d587e5f06144630ed9aa9c399bba12ec8905798d833bd1404"}, - {file = "torch-2.5.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c54db1fade17287aabbeed685d8e8ab3a56fea9dd8d46e71ced2da367f09a49f"}, - {file = "torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:499a68a756d3b30d10f7e0f6214dc3767b130b797265db3b1c02e9094e2a07be"}, - {file = "torch-2.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:9f3df8138a1126a851440b7d5a4869bfb7c9cc43563d64fd9d96d0465b581024"}, - {file = "torch-2.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b81da3bdb58c9de29d0e1361e52f12fcf10a89673f17a11a5c6c7da1cb1a8376"}, - {file = "torch-2.5.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ba135923295d564355326dc409b6b7f5bd6edc80f764cdaef1fb0a1b23ff2f9c"}, - {file = "torch-2.5.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2dd40c885a05ef7fe29356cca81be1435a893096ceb984441d6e2c27aff8c6f4"}, - {file = "torch-2.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc52d603d87fe1da24439c0d5fdbbb14e0ae4874451d53f0120ffb1f6c192727"}, - {file = "torch-2.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea718746469246cc63b3353afd75698a288344adb55e29b7f814a5d3c0a7c78d"}, - {file = "torch-2.5.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6de1fd253e27e7f01f05cd7c37929ae521ca23ca4620cfc7c485299941679112"}, - {file = "torch-2.5.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:83dcf518685db20912b71fc49cbddcc8849438cdb0e9dcc919b02a849e2cd9e8"}, - {file = "torch-2.5.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:65e0a60894435608334d68c8811e55fd8f73e5bf8ee6f9ccedb0064486a7b418"}, - {file = "torch-2.5.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:38c21ff1bd39f076d72ab06e3c88c2ea6874f2e6f235c9450816b6c8e7627094"}, - {file = "torch-2.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:ce4baeba9804da5a346e210b3b70826f5811330c343e4fe1582200359ee77fe5"}, - {file = "torch-2.5.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:03e53f577a96e4d41aca472da8faa40e55df89d2273664af390ce1f570e885bd"}, + {file = "torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:71328e1bbe39d213b8721678f9dcac30dfc452a46d586f1d514a6aa0a99d4744"}, + {file = "torch-2.5.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:34bfa1a852e5714cbfa17f27c49d8ce35e1b7af5608c4bc6e81392c352dbc601"}, + {file = "torch-2.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:32a037bd98a241df6c93e4c789b683335da76a2ac142c0973675b715102dc5fa"}, + {file = "torch-2.5.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:23d062bf70776a3d04dbe74db950db2a5245e1ba4f27208a87f0d743b0d06e86"}, + {file = "torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457"}, + {file = "torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9"}, + {file = "torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a"}, + {file = "torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c"}, + {file = "torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03"}, + {file = "torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697"}, + {file = "torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c"}, + {file = "torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1"}, + {file = "torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7"}, + {file = "torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:1f3b7fb3cf7ab97fae52161423f81be8c6b8afac8d9760823fd623994581e1a3"}, + {file = "torch-2.5.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:7974e3dce28b5a21fb554b73e1bc9072c25dde873fa00d54280861e7a009d7dc"}, + {file = "torch-2.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:46c817d3ea33696ad3b9df5e774dba2257e9a4cd3c4a3afbf92f6bb13ac5ce2d"}, + {file = "torch-2.5.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:8046768b7f6d35b85d101b4b38cba8aa2f3cd51952bc4c06a49580f2ce682291"}, ] [package.dependencies] @@ -1842,13 +1910,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.5" +version = "4.66.6" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, - {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, + {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, + {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, ] [package.dependencies] @@ -1943,4 +2011,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "75f926f2d1100a0c0470e7182f73ca56029dcaddecfb8dfc1e486bd3bcbe9891" +content-hash = "4c9028f1c631630d99d5039d2c70b11e8b8ecd4089d22916a70f65d16e8c2174" diff --git a/pyproject.toml b/pyproject.toml index 20a2a38..122317f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ torchinfo = "^1.8.0" numpy = "^2.1.2" tqdm = "^4.66.5" ema-pytorch = "^0.7.3" +rich = "^13.9.3" [tool.ruff] fix = true From 7634badca4fa0d4b93ff2badfbcac7be2e14d638 Mon Sep 17 00:00:00 2001 From: texhnolyze Date: Thu, 31 Oct 2024 14:28:09 +0100 Subject: [PATCH 2/3] refactor(db): split models and db session management --- ddlitlab2024/dataset/db.py | 33 +++++++++++++++++++ ddlitlab2024/dataset/{schema.py => models.py} | 26 ++------------- 2 files changed, 35 insertions(+), 24 deletions(-) create mode 100644 ddlitlab2024/dataset/db.py rename ddlitlab2024/dataset/{schema.py => models.py} (94%) diff --git a/ddlitlab2024/dataset/db.py b/ddlitlab2024/dataset/db.py new file mode 100644 index 0000000..fc14718 --- /dev/null +++ b/ddlitlab2024/dataset/db.py @@ -0,0 +1,33 @@ +from sqlalchemy import Engine, create_engine +from sqlalchemy.orm import Session, sessionmaker + +from ddlitlab2024.dataset import logger +from ddlitlab2024.dataset.models import Base + + +class Database: + def __init__(self, db_path: str): + self.db_path = db_path + self.engine: Engine = self._setup_sqlite() + self.session: Session | None = None + + def _setup_sqlite(self) -> Engine: + return create_engine(f"sqlite:///{self.db_path}") + + def _create_schema(self) -> None: + logger.info("Creating database schema") + Base.metadata.create_all(self.engine) + logger.info("Database schema created") + + def create_session(self, create_schema: bool = True) -> Session: + logger.info("Setting up database session") + if create_schema: + self._create_schema() + return sessionmaker(bind=self.engine)() + + def close_session(self) -> None: + if self.session: + self.session.close() + logger.info("Database session closed") + else: + logger.warning("No database session to close") diff --git a/ddlitlab2024/dataset/schema.py b/ddlitlab2024/dataset/models.py similarity index 94% rename from ddlitlab2024/dataset/schema.py rename to ddlitlab2024/dataset/models.py index 7127d06..c9a2543 100644 --- a/ddlitlab2024/dataset/schema.py +++ b/ddlitlab2024/dataset/models.py @@ -1,14 +1,11 @@ -import argparse from datetime import datetime from enum import Enum from typing import List, Optional -from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Integer, String, create_engine -from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship, sessionmaker +from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship from sqlalchemy.types import LargeBinary -from ddlitlab2024 import DB_PATH -from ddlitlab2024.dataset import logger Base = declarative_base() @@ -229,22 +226,3 @@ class GameState(Base): recording: Mapped["Recording"] = relationship("Recording", back_populates="game_states") __table_args__ = (CheckConstraint(state.in_(RobotState.values())),) - - -def parse_args(): - parser = argparse.ArgumentParser(description="Create the database schema") - parser.add_argument("--db-path", type=str, default=DB_PATH, help="Path to the database file") - return parser.parse_args() - - -def main(): - logger.info("Creating database schema") - args = parse_args() - engine = create_engine(f"sqlite:///{args.db_path}") - Base.metadata.create_all(engine) - sessionmaker(bind=engine)() - logger.info("Database schema created") - - -if __name__ == "__main__": - main() From ecf5d0a42a7e4b18921f615118e3d90dc673a88d Mon Sep 17 00:00:00 2001 From: texhnolyze Date: Thu, 31 Oct 2024 14:32:26 +0100 Subject: [PATCH 3/3] feat(cli): setup single CLI using subcommands --- .github/workflows/create_db.yml | 2 +- ddlitlab2024/dataset/cli.py | 73 +++++++++++++++++++++++++++++++++ ddlitlab2024/dataset/errors.py | 2 + ddlitlab2024/dataset/main.py | 39 ++++++++++++++++++ ddlitlab2024/dataset/models.py | 1 - poetry.lock | 59 +++++++++++++++++++++++++- pyproject.toml | 3 +- 7 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 ddlitlab2024/dataset/cli.py create mode 100644 ddlitlab2024/dataset/errors.py create mode 100755 ddlitlab2024/dataset/main.py diff --git a/.github/workflows/create_db.yml b/.github/workflows/create_db.yml index 5bba312..2f9defe 100644 --- a/.github/workflows/create_db.yml +++ b/.github/workflows/create_db.yml @@ -30,4 +30,4 @@ jobs: run: poetry install - name: Run create-db - run: poetry run create-db + run: poetry run cli db create_schema diff --git a/ddlitlab2024/dataset/cli.py b/ddlitlab2024/dataset/cli.py new file mode 100644 index 0000000..1ce8a10 --- /dev/null +++ b/ddlitlab2024/dataset/cli.py @@ -0,0 +1,73 @@ +import sys +from enum import Enum +from pathlib import Path + +from tap import Tap + +from ddlitlab2024 import DB_PATH + + +class ImportType(str, Enum): + ROS_BAG = "rosbag" + + +class CLICommand(str, Enum): + DB = "db" + IMPORT = "import" + + +class DBArgs(Tap): + create_schema: bool = False + + def configure(self) -> None: + self.add_argument( + "create_schema", + type=bool, + help="Create the base database schema, if it doesn't exist", + nargs="?", + ) + + +class ImportArgs(Tap): + import_type: ImportType + file: Path + + def configure(self) -> None: + self.add_argument( + "import-type", + type=ImportType, + help="Type of import to perform", + ) + self.add_argument( + "file", + type=Path, + help="File to import", + ) + + +class CLIArgs(Tap): + dry_run: bool = False + db_path: str = DB_PATH # Path to the sqlite database file + version: bool = False # if set print version and exit + + def __init__(self): + super().__init__( + description="ddlitlab dataset CLI", + underscores_to_dashes=True, + ) + + def configure(self) -> None: + self.add_subparsers(dest="command", help="Command to run") + self.add_subparser(CLICommand.DB.value, DBArgs, help="Database management commands") + self.add_subparser(CLICommand.IMPORT.value, ImportArgs, help="Import data into the database") + + def print_help_and_exit(self) -> None: + self.print_help() + sys.exit(0) + + def process_args(self) -> None: + if self.command == CLICommand.DB: + all_args = (self.create_schema,) + + if not any(all_args): + self.print_help_and_exit() diff --git a/ddlitlab2024/dataset/errors.py b/ddlitlab2024/dataset/errors.py new file mode 100644 index 0000000..57ec5aa --- /dev/null +++ b/ddlitlab2024/dataset/errors.py @@ -0,0 +1,2 @@ +class CLIArgumentError(Exception): + """Raised when the configuration of CLI arguments is not valid and execution is impossible""" diff --git a/ddlitlab2024/dataset/main.py b/ddlitlab2024/dataset/main.py new file mode 100755 index 0000000..4ed4792 --- /dev/null +++ b/ddlitlab2024/dataset/main.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import os +import sys + +from rich.console import Console + +from ddlitlab2024 import __version__ +from ddlitlab2024.dataset import logger +from ddlitlab2024.dataset.cli import CLIArgs, CLICommand +from ddlitlab2024.dataset.db import Database + +err_console = Console(stderr=True) + + +def main(): + debug_mode = os.getenv("LOGLEVEL") == "DEBUG" + + try: + logger.debug("Parsing CLI args...") + args: CLIArgs = CLIArgs().parse_args() + if args.version: + logger.info(f"running ddlitlab2024 CLI v{__version__}") + sys.exit(0) + + if args.command == CLICommand.DB: + db = Database(args.db_path).create_session(args.create_schema) + logger.info(f"Database session created: {db}") + + logger.info(f"CLI args: {args}") + sys.exit(0) + except Exception as e: + logger.error(e) + err_console.print_exception(show_locals=debug_mode) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/ddlitlab2024/dataset/models.py b/ddlitlab2024/dataset/models.py index c9a2543..4ecd0ea 100644 --- a/ddlitlab2024/dataset/models.py +++ b/ddlitlab2024/dataset/models.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship from sqlalchemy.types import LargeBinary - Base = declarative_base() diff --git a/poetry.lock b/poetry.lock index a30e391..b1db388 100644 --- a/poetry.lock +++ b/poetry.lock @@ -265,6 +265,17 @@ test = ["GitPython (<3.1.19)", "Jinja2", "compel (==0.1.8)", "datasets", "invisi torch = ["accelerate (>=0.31.0)", "torch (>=1.4,<2.5.0)"] training = ["Jinja2", "accelerate (>=0.31.0)", "datasets", "peft (>=0.6.0)", "protobuf (>=3.20.3,<4)", "tensorboard"] +[[package]] +name = "docstring-parser" +version = "0.16" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, +] + [[package]] name = "ema-pytorch" version = "0.7.3" @@ -887,6 +898,17 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "networkx" version = "3.4.2" @@ -1950,6 +1972,26 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] +[[package]] +name = "typed-argument-parser" +version = "1.10.1" +description = "Typed Argument Parser" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typed_argument_parser-1.10.1-py3-none-any.whl", hash = "sha256:152cad218c84f23d9d4cbbfcb153a1a53fa0b330f0f7a4b85190dd4cb86f0c0a"}, + {file = "typed_argument_parser-1.10.1.tar.gz", hash = "sha256:117c2ef9ad6536d36f84f65eb7be035244c841fea34bd5abf2301a01c8a1ec89"}, +] + +[package.dependencies] +docstring-parser = ">=0.15" +packaging = "*" +typing-inspect = ">=0.7.1" + +[package.extras] +dev = ["pydantic (>=2.5.0)", "typed-argument-parser[dev-no-pydantic]"] +dev-no-pydantic = ["flake8", "pytest", "pytest-cov"] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -1961,6 +2003,21 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +description = "Runtime inspection utilities for typing module." +optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"}, + {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "tzdata" version = "2024.2" @@ -2011,4 +2068,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "4c9028f1c631630d99d5039d2c70b11e8b8ecd4089d22916a70f65d16e8c2174" +content-hash = "28f9d0261e823e5246f49803631271ef5cdff952e14b8b996a23584c8bbb537d" diff --git a/pyproject.toml b/pyproject.toml index 122317f..2af0e4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ numpy = "^2.1.2" tqdm = "^4.66.5" ema-pytorch = "^0.7.3" rich = "^13.9.3" +typed-argument-parser = "^1.10.1" [tool.ruff] fix = true @@ -43,4 +44,4 @@ line-length = 120 select = ["F", "E", "B", "W", "I", "N", "UP"] [tool.poetry.scripts] -create-db = "ddlitlab2024.dataset.schema:main" +cli = "ddlitlab2024.dataset.main:main"