Skip to content

Commit

Permalink
feat: Reduce migrate_db calls
Browse files Browse the repository at this point in the history
This update significantly enhances test execution efficiency by
invoking the `migrate_db` function just once for all tests.
This optimization is accomplished by maintaining database changes
in memory, rather than persisting them with `db.commit()`.
Consequently, it is essential to expire all objects in scenarios
where `db.commit()` would have been typically used. This approach
introduces minimal overhead, given the limited number and frequency
of object interactions during testing.
Additionally, some test cases
have been modified to accommodate this change.
They now rely on dynamic ids returned from the backend instead
of static ids. With the database no longer resetting after each test,
ids now increment continuously, necessitating this adjustment.
  • Loading branch information
dominik003 committed Jan 2, 2024
1 parent 1bc5a44 commit 30afd57
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
43 changes: 25 additions & 18 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
import capellacollab.projects.models as projects_models
import capellacollab.projects.users.crud as projects_users_crud
import capellacollab.projects.users.models as projects_users_models
import capellacollab.users.crud as users_crud
import capellacollab.users.models as users_models
from capellacollab.__main__ import app
from capellacollab.core import database
from capellacollab.core.authentication.jwt_bearer import JWTBearer
from capellacollab.core.database import migration
from capellacollab.users import crud as users_crud
Expand All @@ -33,38 +30,48 @@


@pytest.fixture(name="postgresql", scope="session")
def fixture_postgresql() -> engine.Engine:
def fixture_postgreql_engine() -> engine.Engine:
with postgres.PostgresContainer(image="postgres:14.1") as _postgres:
database_url = _postgres.get_connection_url()

_engine = sqlalchemy.create_engine(database_url.replace("***", "test"))
with pytest.MonkeyPatch.context() as monkeypatch:
_engine = sqlalchemy.create_engine(
database_url.replace("***", "test")
)

yield _engine
session_local = orm.sessionmaker(
autocommit=False, autoflush=False, bind=_engine
)

monkeypatch.setattr(database, "engine", _engine)
monkeypatch.setattr(database, "SessionLocal", session_local)

migration.migrate_db(
_engine, str(_engine.url).replace("***", "test")
)

yield _engine


@pytest.fixture(name="db")
def fixture_db(
postgresql: engine.Engine, monkeypatch: pytest.MonkeyPatch
) -> orm.Session:
session_local = orm.sessionmaker(
with orm.sessionmaker(
autocommit=False, autoflush=False, bind=postgresql
)

monkeypatch.setattr(database, "engine", postgresql)
monkeypatch.setattr(database, "SessionLocal", session_local)

delete_all_tables_if_existent(postgresql)
migration.migrate_db(
postgresql, str(postgresql.url).replace("***", "test")
)

with session_local() as session:
)() as session:

def mock_get_db() -> orm.Session:
return session

app.dependency_overrides[database.get_db] = mock_get_db

def commit(*args, **kwargs):
session.flush()
session.expire_all()

monkeypatch.setattr(session, "commit", commit)

yield session


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def test_get_all_pipelines_of_capellamodel(

assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == 1


@pytest.mark.usefixtures(
Expand Down
8 changes: 6 additions & 2 deletions backend/tests/users/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def test_create_and_delete_token(
response = client.post("/api/v1/users/current/tokens", json=POST_TOKEN)
assert response.status_code == 200

response = client.delete("/api/v1/users/current/tokens/1")
token_id = response.json()["id"]

response = client.delete(f"/api/v1/users/current/tokens/{token_id}")
assert response.status_code == 204


Expand All @@ -91,7 +93,9 @@ def test_token_lifecycle(
response_string = response.content.decode("utf-8")
assert len(json.loads(response_string)) == 1

response = client.delete("/api/v1/users/current/tokens/1")
token_id = response.json()[0]["id"]

response = client.delete(f"/api/v1/users/current/tokens/{token_id}")
assert response.status_code == 204

response = client.get("/api/v1/users/current/tokens")
Expand Down

0 comments on commit 30afd57

Please sign in to comment.