From 6a411fa025bd2f33a2d944a4997a04bf88eec2c0 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Fri, 13 Dec 2024 09:55:40 +0000 Subject: [PATCH] Refactor JSON type handling by moving to aana.storage.types and removing custom_types module --- aana/alembic/versions/5ad873484aa3_init.py | 2 +- aana/storage/custom_types.py | 40 --------------------- aana/storage/models/task.py | 2 +- aana/storage/models/transcript.py | 2 +- aana/storage/op.py | 2 +- aana/storage/types.py | 41 +++++++++++++++++++++- 6 files changed, 44 insertions(+), 45 deletions(-) delete mode 100644 aana/storage/custom_types.py diff --git a/aana/alembic/versions/5ad873484aa3_init.py b/aana/alembic/versions/5ad873484aa3_init.py index c62abda4..89b77d13 100644 --- a/aana/alembic/versions/5ad873484aa3_init.py +++ b/aana/alembic/versions/5ad873484aa3_init.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy.schema import CreateSequence, Sequence -from aana.storage.custom_types import JSON +from aana.storage.types import JSON # revision identifiers, used by Alembic. revision: str = "5ad873484aa3" diff --git a/aana/storage/custom_types.py b/aana/storage/custom_types.py deleted file mode 100644 index 70015813..00000000 --- a/aana/storage/custom_types.py +++ /dev/null @@ -1,40 +0,0 @@ -import orjson -from snowflake.sqlalchemy.custom_types import VARIANT as SnowflakeVariantType -from sqlalchemy import func -from sqlalchemy.types import JSON as SqlAlchemyJSON -from sqlalchemy.types import TypeDecorator - - -class VARIANT(SnowflakeVariantType): - """Extends VARIANT type for better SqlAlchemy support.""" - - def bind_expression(self, bindvalue): - """Wraps value with PARSE_JSON for Snowflake.""" - return func.PARSE_JSON(bindvalue) - - def result_processor(self, dialect, coltype): - """Convert JSON string to Python dictionary when retrieving.""" - - def process(value): - if value is None: - return None - try: - return orjson.loads(value) - except (ValueError, TypeError): - return value # Return raw value if not valid JSON - - return process - - -class JSON(TypeDecorator): - """Custom JSON type that supports Snowflake-specific and standard dialects.""" - - impl = SqlAlchemyJSON # Default to standard SQLAlchemy JSON - # impl = VARIANT # Default to Snowflake VARIANT - - def load_dialect_impl(self, dialect): - """Load dialect-specific implementation.""" - if dialect.name == "snowflake": - return VARIANT() - else: - return SqlAlchemyJSON() diff --git a/aana/storage/models/task.py b/aana/storage/models/task.py index 1ad6db12..eda69b51 100644 --- a/aana/storage/models/task.py +++ b/aana/storage/models/task.py @@ -7,8 +7,8 @@ ) from sqlalchemy.orm import Mapped, mapped_column -from aana.storage.custom_types import JSON from aana.storage.models.base import BaseEntity, TimeStampEntity, timestamp +from aana.storage.types import JSON class Status(str, Enum): diff --git a/aana/storage/models/transcript.py b/aana/storage/models/transcript.py index 33d55695..a8610158 100644 --- a/aana/storage/models/transcript.py +++ b/aana/storage/models/transcript.py @@ -5,8 +5,8 @@ from sqlalchemy import CheckConstraint, Sequence from sqlalchemy.orm import Mapped, mapped_column -from aana.storage.custom_types import JSON from aana.storage.models.base import BaseEntity, TimeStampEntity +from aana.storage.types import JSON if TYPE_CHECKING: from aana.core.models.asr import ( diff --git a/aana/storage/op.py b/aana/storage/op.py index b7d4f227..c134b945 100644 --- a/aana/storage/op.py +++ b/aana/storage/op.py @@ -9,7 +9,7 @@ from sqlalchemy import create_engine, event from aana.exceptions.runtime import EmptyMigrationsException -from aana.storage.custom_types import JSON +from aana.storage.types import JSON from aana.utils.core import get_module_dir from aana.utils.json import jsonify diff --git a/aana/storage/types.py b/aana/storage/types.py index 38e29954..9ced8da6 100644 --- a/aana/storage/types.py +++ b/aana/storage/types.py @@ -1,5 +1,44 @@ from typing import TypeAlias -from sqlalchemy import String +import orjson +from snowflake.sqlalchemy.custom_types import VARIANT as SnowflakeVariantType +from sqlalchemy import String, func +from sqlalchemy.types import JSON as SqlAlchemyJSON +from sqlalchemy.types import TypeDecorator MediaIdSqlType: TypeAlias = String(36) + + +class VARIANT(SnowflakeVariantType): + """Extends VARIANT type for better SqlAlchemy support.""" + + def bind_expression(self, bindvalue): + """Wraps value with PARSE_JSON for Snowflake.""" + return func.PARSE_JSON(bindvalue) + + def result_processor(self, dialect, coltype): + """Convert JSON string to Python dictionary when retrieving.""" + + def process(value): + if value is None: + return None + try: + return orjson.loads(value) + except (ValueError, TypeError): + return value # Return raw value if not valid JSON + + return process + + +class JSON(TypeDecorator): + """Custom JSON type that supports Snowflake-specific and standard dialects.""" + + impl = SqlAlchemyJSON # Default to standard SQLAlchemy JSON + # impl = VARIANT # Default to Snowflake VARIANT + + def load_dialect_impl(self, dialect): + """Load dialect-specific implementation.""" + if dialect.name == "snowflake": + return VARIANT() + else: + return SqlAlchemyJSON()