diff --git a/acacore/database/base.py b/acacore/database/base.py index e58312b..08c8636 100644 --- a/acacore/database/base.py +++ b/acacore/database/base.py @@ -6,6 +6,7 @@ from sqlite3 import Connection from sqlite3 import Cursor as SQLiteCursor from sqlite3 import OperationalError +from types import TracebackType from typing import Any from typing import Generator from typing import Generic @@ -780,6 +781,12 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}({self.path})" + def __enter__(self) -> "FileDBBase": + return self + + def __exit__(self, _exc_type: Type[BaseException], _exc_val: BaseException, _exc_tb: TracebackType) -> None: + self.close() + @property def path(self) -> Optional[Path]: for _, name, filename in self.execute("PRAGMA database_list"):