Skip to content

Commit

Permalink
Merge pull request #340 from waketzheng/type-hint-tests
Browse files Browse the repository at this point in the history
Improve type hints for tests/
  • Loading branch information
long2ice authored Jun 6, 2024
2 parents 219633a + 7b73349 commit 13dd44b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
18 changes: 10 additions & 8 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
from typing import Generator

import pytest
from tortoise import Tortoise, expand_db_url, generate_schema_for_client
Expand All @@ -12,8 +13,9 @@
from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate

db_url = os.getenv("TEST_DB", "sqlite://:memory:")
db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:")
MEMORY_SQLITE = "sqlite://:memory:"
db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
Expand All @@ -27,7 +29,7 @@


@pytest.fixture(scope="function", autouse=True)
def reset_migrate():
def reset_migrate() -> None:
Migrate.upgrade_operators = []
Migrate.downgrade_operators = []
Migrate._upgrade_fk_m2m_index_operators = []
Expand All @@ -37,20 +39,20 @@ def reset_migrate():


@pytest.fixture(scope="session")
def event_loop():
def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
asyncio.set_event_loop(res)
res._close = res.close
res.close = lambda: None
res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None # type:ignore[method-assign]

yield res

res._close()
res._close() # type:ignore[attr-defined]


@pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request):
async def initialize_tests(event_loop, request) -> None:
await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)

Expand Down
12 changes: 8 additions & 4 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Email(Model):
email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")


def default_name():
Expand All @@ -47,13 +47,15 @@ def default_name():
class Category(Model):
slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=False)
created_at = fields.DatetimeField(auto_now_add=True)


class Product(Model):
categories = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField()
Expand All @@ -75,7 +77,9 @@ class Config(Model):
key = fields.CharField(max_length=20)
value = fields.JSONField()
status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)


class NewModel(Model):
Expand Down
12 changes: 9 additions & 3 deletions tests/models_second.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,24 @@ class User(Model):
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models_second.User", db_constraint=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", db_constraint=False
)


class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models_second.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True)


class Product(Model):
categories = fields.ManyToManyField("models_second.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models_second.Category"
)
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
Expand Down
10 changes: 7 additions & 3 deletions tests/old_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,22 @@ class User(Model):
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", db_constraint=False
)


class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True)


class Product(Model):
categories = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
Expand Down
12 changes: 5 additions & 7 deletions tests/test_migrate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import tempfile
from pathlib import Path

import pytest
Expand Down Expand Up @@ -1010,14 +1009,13 @@ def test_sort_all_version_files(mocker):
]


async def test_empty_migration(mocker) -> None:
async def test_empty_migration(mocker, tmp_path: Path) -> None:
mocker.patch("os.listdir", return_value=[])
Migrate.app = "foo"
expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="")
with tempfile.TemporaryDirectory() as temp_dir:
Migrate.migrate_location = temp_dir
Migrate.migrate_location = tmp_path

migration_file = await Migrate.migrate("update", True)
migration_file = await Migrate.migrate("update", True)

with open(Path(temp_dir, migration_file), "r") as f:
assert f.read() == expected_content
f = tmp_path / migration_file
assert f.read_text() == expected_content

0 comments on commit 13dd44b

Please sign in to comment.