From 3b7ec6017b8fd6cb3db45a9c31732fc82c9ced4d Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Thu, 30 Nov 2023 01:13:37 +0000 Subject: [PATCH] fix(enums): ensure consistent enum format in 3.11+ (#846) ## Change Summary closes #845 ## Checklist - [ ] Unit tests for the changes exist - [ ] Tests pass without significant drop in coverage - [ ] Documentation reflects changes where applicable - [ ] Test snapshots have been [updated](https://prisma-client-py.readthedocs.io/en/latest/contributing/contributing/#snapshot-tests) if applicable ## Agreement By submitting this pull request, I confirm that you can use, modify, copy and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- databases/tests/test_enum.py | 9 +++++++++ requirements/base.txt | 1 + src/prisma/_compat.py | 11 +++++++++++ src/prisma/generator/templates/enums.py.jinja | 4 ++-- .../test_exhaustive/test_async[enums.py].raw | 4 ++-- .../test_exhaustive/test_sync[enums.py].raw | 4 ++-- 6 files changed, 27 insertions(+), 6 deletions(-) diff --git a/databases/tests/test_enum.py b/databases/tests/test_enum.py index 21cd39c41..99c1f2fe6 100644 --- a/databases/tests/test_enum.py +++ b/databases/tests/test_enum.py @@ -16,6 +16,15 @@ async def test_enum_create(client: Prisma) -> None: record = await client.types.create({'enum': Role.ADMIN}) assert record.enum == Role.ADMIN + # ensure consistent format + assert str(record.enum) == 'ADMIN' + assert f'{record.enum}' == 'ADMIN' + assert '%s' % record.enum == 'ADMIN' + + assert str(Role.ADMIN) == 'ADMIN' + assert f'{Role.ADMIN}' == 'ADMIN' + assert '%s' % Role.ADMIN == 'ADMIN' + # TODO: all other actions diff --git a/requirements/base.txt b/requirements/base.txt index 5aac23545..daec40fad 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -7,3 +7,4 @@ typing-extensions>=4.0.1 tomlkit nodeenv cached-property; python_version < '3.8' +StrEnum; python_version < '3.11' diff --git a/src/prisma/_compat.py b/src/prisma/_compat.py index 4ed5bb464..9bd1a8475 100644 --- a/src/prisma/_compat.py +++ b/src/prisma/_compat.py @@ -337,6 +337,17 @@ def Field(*, env: str | None = None, **extra: Any) -> Any: nodejs = None +if TYPE_CHECKING: + from enum import Enum + + StrEnum = Enum +else: + try: + from enum import StrEnum + except ImportError: + from strenum import StrEnum + + def removeprefix(string: str, prefix: str) -> str: if string.startswith(prefix): return string[len(prefix) :] diff --git a/src/prisma/generator/templates/enums.py.jinja b/src/prisma/generator/templates/enums.py.jinja index 610e2b730..39427366d 100644 --- a/src/prisma/generator/templates/enums.py.jinja +++ b/src/prisma/generator/templates/enums.py.jinja @@ -1,10 +1,10 @@ {% include '_header.py.jinja' %} # -- template enums.py.jinja -- -from enum import Enum +from ._compat import StrEnum {% for enum in dmmf.datamodel.enums %} -class {{ enum.name }}(str, Enum): +class {{ enum.name }}(StrEnum): {% for value in enum.values %} {{ value.name }} = '{{ value.name }}' {% endfor %} diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enums.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enums.py].raw index 8d558a77a..d80a320d3 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enums.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enums.py].raw @@ -39,10 +39,10 @@ from typing_extensions import TypedDict, Literal LiteralString = str # -- template enums.py.jinja -- -from enum import Enum +from ._compat import StrEnum -class ABeautifulEnum(str, Enum): +class ABeautifulEnum(StrEnum): A = 'A' B = 'B' C = 'C' diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enums.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enums.py].raw index 8d558a77a..d80a320d3 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enums.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enums.py].raw @@ -39,10 +39,10 @@ from typing_extensions import TypedDict, Literal LiteralString = str # -- template enums.py.jinja -- -from enum import Enum +from ._compat import StrEnum -class ABeautifulEnum(str, Enum): +class ABeautifulEnum(StrEnum): A = 'A' B = 'B' C = 'C'