Skip to content

Commit

Permalink
feat(ingest/sqlite): Support sqlite < 3.24.0 (datahub-project#12137)
Browse files Browse the repository at this point in the history
  • Loading branch information
asikowitz authored Dec 16, 2024
1 parent d5e379a commit 6b8d21a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import gzip
import logging
import os
import pathlib
import pickle
import shutil
Expand Down Expand Up @@ -33,6 +34,14 @@

logger: logging.Logger = logging.getLogger(__name__)

# Escape hatch for environments stuck on sqlite < 3.24.0 (which lacks the
# ON CONFLICT upsert clause): set OVERRIDE_SQLITE_VERSION_REQ to any value
# other than empty or "false" (case-insensitive) to make FileBackedDict fall
# back to an INSERT-then-UPDATE write path instead of raising at init time.
OVERRIDE_SQLITE_VERSION_REQUIREMENT_STR = (
    os.environ.get("OVERRIDE_SQLITE_VERSION_REQ") or ""
)
# Normalize to a real bool: the bare `and` chain would otherwise leave the
# empty string (a str) in this constant when the env var is unset.
OVERRIDE_SQLITE_VERSION_REQUIREMENT: bool = bool(
    OVERRIDE_SQLITE_VERSION_REQUIREMENT_STR
    and OVERRIDE_SQLITE_VERSION_REQUIREMENT_STR.lower() != "false"
)

_DEFAULT_FILE_NAME = "sqlite.db"
_DEFAULT_TABLE_NAME = "data"

Expand Down Expand Up @@ -212,6 +221,7 @@ class FileBackedDict(MutableMapping[str, _VT], Closeable, Generic[_VT]):
_active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field(
init=False, repr=False
)
_use_sqlite_on_conflict: bool = field(repr=False, default=True)

def __post_init__(self) -> None:
assert (
Expand All @@ -232,7 +242,10 @@ def __post_init__(self) -> None:
# We use the ON CONFLICT clause to implement UPSERTs with sqlite.
# This was added in 3.24.0 from 2018-06-04.
# See https://www.sqlite.org/lang_conflict.html
raise RuntimeError("SQLite version 3.24.0 or later is required")
if OVERRIDE_SQLITE_VERSION_REQUIREMENT:
self._use_sqlite_on_conflict = False
else:
raise RuntimeError("SQLite version 3.24.0 or later is required")

# We keep a small cache in memory to avoid having to serialize/deserialize
# data from the database too often. We use an OrderedDict to build
Expand Down Expand Up @@ -295,7 +308,7 @@ def _prune_cache(self, num_items_to_prune: int) -> None:
values.append(column_serializer(value))
items_to_write.append(tuple(values))

if items_to_write:
if items_to_write and self._use_sqlite_on_conflict:
# Tricky: By using a INSERT INTO ... ON CONFLICT (key) structure, we can
# ensure that the rowid remains the same if a value is updated but is
# autoincremented when rows are inserted.
Expand All @@ -312,6 +325,26 @@ def _prune_cache(self, num_items_to_prune: int) -> None:
""",
items_to_write,
)
else:
for item in items_to_write:
try:
self._conn.execute(
f"""INSERT INTO {self.tablename} (
key,
value
{''.join(f', {column_name}' for column_name in self.extra_columns.keys())}
)
VALUES ({', '.join(['?'] *(2 + len(self.extra_columns)))})""",
item,
)
except sqlite3.IntegrityError:
self._conn.execute(
f"""UPDATE {self.tablename} SET
value = ?
{''.join(f', {column_name} = ?' for column_name in self.extra_columns.keys())}
WHERE key = ?""",
(*item[1:], item[0]),
)

def flush(self) -> None:
self._prune_cache(len(self._active_object_cache))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
)


def test_file_dict() -> None:
@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False])
def test_file_dict(use_sqlite_on_conflict: bool) -> None:
cache = FileBackedDict[int](
tablename="cache",
cache_max_size=10,
cache_eviction_batch_size=10,
_use_sqlite_on_conflict=use_sqlite_on_conflict,
)

for i in range(100):
Expand Down Expand Up @@ -92,7 +94,8 @@ def test_file_dict() -> None:
cache["a"] = 1


def test_custom_serde() -> None:
@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False])
def test_custom_serde(use_sqlite_on_conflict: bool) -> None:
@dataclass(frozen=True)
class Label:
a: str
Expand Down Expand Up @@ -139,6 +142,7 @@ def deserialize(s: str) -> Main:
deserializer=deserialize,
# Disable the in-memory cache to force all reads/writes to the DB.
cache_max_size=0,
_use_sqlite_on_conflict=use_sqlite_on_conflict,
)
first = Main(3, {Label("one", 1): 0.1, Label("two", 2): 0.2})
second = Main(-100, {Label("z", 26): 0.26})
Expand Down Expand Up @@ -186,7 +190,8 @@ def test_file_dict_stores_counter() -> None:
assert in_memory_counters[i].most_common(2) == cache[str(i)].most_common(2)


def test_file_dict_ordering() -> None:
@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False])
def test_file_dict_ordering(use_sqlite_on_conflict: bool) -> None:
"""
We require that FileBackedDict maintains insertion order, similar to Python's
built-in dict. This test makes one of each and validates that they behave the same.
Expand All @@ -196,6 +201,7 @@ def test_file_dict_ordering() -> None:
serializer=str,
deserializer=int,
cache_max_size=1,
_use_sqlite_on_conflict=use_sqlite_on_conflict,
)
data = {}

Expand Down Expand Up @@ -229,12 +235,14 @@ class Pair:


@pytest.mark.parametrize("cache_max_size", [0, 1, 10])
def test_custom_column(cache_max_size: int) -> None:
@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False])
def test_custom_column(cache_max_size: int, use_sqlite_on_conflict: bool) -> None:
cache = FileBackedDict[Pair](
extra_columns={
"x": lambda m: m.x,
},
cache_max_size=cache_max_size,
_use_sqlite_on_conflict=use_sqlite_on_conflict,
)

cache["first"] = Pair(3, "a")
Expand Down Expand Up @@ -275,14 +283,16 @@ def test_custom_column(cache_max_size: int) -> None:
]


def test_shared_connection() -> None:
@pytest.mark.parametrize("use_sqlite_on_conflict", [True, False])
def test_shared_connection(use_sqlite_on_conflict: bool) -> None:
with ConnectionWrapper() as connection:
cache1 = FileBackedDict[int](
shared_connection=connection,
tablename="cache1",
extra_columns={
"v": lambda v: v,
},
_use_sqlite_on_conflict=use_sqlite_on_conflict,
)
cache2 = FileBackedDict[Pair](
shared_connection=connection,
Expand All @@ -291,6 +301,7 @@ def test_shared_connection() -> None:
"x": lambda m: m.x,
"y": lambda m: m.y,
},
_use_sqlite_on_conflict=use_sqlite_on_conflict,
)

cache1["a"] = 3
Expand Down

0 comments on commit 6b8d21a

Please sign in to comment.