diff --git a/conftest.py b/conftest.py index 4eee9e1..89639b2 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,6 @@ import asyncio import os +from typing import Generator import pytest from tortoise import Tortoise, expand_db_url, generate_schema_for_client @@ -12,8 +13,9 @@ from aerich.ddl.sqlite import SqliteDDL from aerich.migrate import Migrate -db_url = os.getenv("TEST_DB", "sqlite://:memory:") -db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:") +MEMORY_SQLITE = "sqlite://:memory:" +db_url = os.getenv("TEST_DB", MEMORY_SQLITE) +db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE) tortoise_orm = { "connections": { "default": expand_db_url(db_url, True), @@ -27,7 +29,7 @@ @pytest.fixture(scope="function", autouse=True) -def reset_migrate(): +def reset_migrate() -> None: Migrate.upgrade_operators = [] Migrate.downgrade_operators = [] Migrate._upgrade_fk_m2m_index_operators = [] @@ -37,20 +39,20 @@ def reset_migrate(): @pytest.fixture(scope="session") -def event_loop(): +def event_loop() -> Generator: policy = asyncio.get_event_loop_policy() res = policy.new_event_loop() asyncio.set_event_loop(res) - res._close = res.close - res.close = lambda: None + res._close = res.close # type:ignore[attr-defined] + res.close = lambda: None # type:ignore[method-assign] yield res - res._close() + res._close() # type:ignore[attr-defined] @pytest.fixture(scope="session", autouse=True) -async def initialize_tests(event_loop, request): +async def initialize_tests(event_loop, request) -> None: await Tortoise.init(config=tortoise_orm, _create_db=True) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) diff --git a/tests/models.py b/tests/models.py index 5dd7f61..5f8ae87 100644 --- a/tests/models.py +++ b/tests/models.py @@ -37,7 +37,7 @@ class Email(Model): email = fields.CharField(max_length=200, index=True) is_primary = fields.BooleanField(default=False) address = fields.CharField(max_length=200) - users = fields.ManyToManyField("models.User") + users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User") def default_name(): @@ -47,13 +47,15 @@ def default_name(): class Category(Model): slug = fields.CharField(max_length=100) name = fields.CharField(max_length=200, null=True, default=default_name) - user = fields.ForeignKeyField("models.User", description="User") + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", description="User" + ) title = fields.CharField(max_length=20, unique=False) created_at = fields.DatetimeField(auto_now_add=True) class Product(Model): - categories = fields.ManyToManyField("models.Category") + categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category") name = fields.CharField(max_length=50) view_num = fields.IntField(description="View Num", default=0) sort = fields.IntField() @@ -75,7 +77,9 @@ class Config(Model): key = fields.CharField(max_length=20) value = fields.JSONField() status: Status = fields.IntEnumField(Status) - user = fields.ForeignKeyField("models.User", description="User") + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", description="User" + ) class NewModel(Model): diff --git a/tests/models_second.py b/tests/models_second.py index 505aa27..71f108f 100644 --- a/tests/models_second.py +++ b/tests/models_second.py @@ -34,18 +34,24 @@ class User(Model): class Email(Model): email = fields.CharField(max_length=200) is_primary = fields.BooleanField(default=False) - user = fields.ForeignKeyField("models_second.User", db_constraint=False) + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models_second.User", db_constraint=False + ) class Category(Model): slug = fields.CharField(max_length=200) name = fields.CharField(max_length=200) - user = fields.ForeignKeyField("models_second.User", description="User") + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models_second.User", description="User" + ) created_at = fields.DatetimeField(auto_now_add=True) class Product(Model): - categories = fields.ManyToManyField("models_second.Category") + categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField( + "models_second.Category" + ) name = fields.CharField(max_length=50) view_num = fields.IntField(description="View Num") sort = fields.IntField() diff --git a/tests/old_models.py b/tests/old_models.py index b8ffc5d..e882bc8 100644 --- a/tests/old_models.py +++ b/tests/old_models.py @@ -35,18 +35,22 @@ class User(Model): class Email(Model): email = fields.CharField(max_length=200) is_primary = fields.BooleanField(default=False) - user = fields.ForeignKeyField("models.User", db_constraint=False) + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", db_constraint=False + ) class Category(Model): slug = fields.CharField(max_length=200) name = fields.CharField(max_length=200) - user = fields.ForeignKeyField("models.User", description="User") + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", description="User" + ) created_at = fields.DatetimeField(auto_now_add=True) class Product(Model): - categories = fields.ManyToManyField("models.Category") + categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category") name = fields.CharField(max_length=50) view_num = fields.IntField(description="View Num") sort = fields.IntField() diff --git a/tests/test_migrate.py b/tests/test_migrate.py index a99a143..4614606 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,4 +1,3 @@ -import tempfile from pathlib import Path import pytest @@ -1010,14 +1009,13 @@ def test_sort_all_version_files(mocker): ] -async def test_empty_migration(mocker) -> None: +async def test_empty_migration(mocker, tmp_path: Path) -> None: mocker.patch("os.listdir", return_value=[]) Migrate.app = "foo" expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="") - with tempfile.TemporaryDirectory() as temp_dir: - Migrate.migrate_location = temp_dir + Migrate.migrate_location = tmp_path - migration_file = await Migrate.migrate("update", True) + migration_file = await Migrate.migrate("update", True) - with open(Path(temp_dir, migration_file), "r") as f: - assert f.read() == expected_content + f = tmp_path / migration_file + assert f.read_text() == expected_content