Skip to content

Commit

Permalink
chore(internal): support rendering enum types from the DMMF (#921)
Browse files Browse the repository at this point in the history
Will be useful for #878
  • Loading branch information
RobertCraigie authored Feb 24, 2024
1 parent 1fa9331 commit 64212ac
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 6 deletions.
20 changes: 16 additions & 4 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def to_params(self) -> Dict[str, Any]:
"""Get the parameters that should be sent to Jinja templates"""
params = vars(self)
params['type_schema'] = Schema.from_data(self)
params['client_types'] = ClientTypes.from_data(self)

# add utility functions
for func in [
Expand Down Expand Up @@ -628,11 +629,22 @@ def engine_type_validator(cls, value: EngineType) -> EngineType:
assert_never(value)


class DMMFEnumType(BaseModel):
name: str
values: List[object]


class DMMFEnumTypes(BaseModel):
prisma: List[DMMFEnumType]


class PrismaSchema(BaseModel):
enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes')


class DMMF(BaseModel):
datamodel: 'Datamodel'

# TODO
prisma_schema: Any = FieldInfo(alias='schema')
prisma_schema: PrismaSchema = FieldInfo(alias='schema')


class Datamodel(BaseModel):
Expand Down Expand Up @@ -1182,4 +1194,4 @@ class DefaultData(GenericData[_EmptyModel]):
TemplateError,
PartialTypeGeneratorError,
)
from .schema import Schema
from .schema import Schema, ClientTypes
34 changes: 32 additions & 2 deletions src/prisma/generator/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from enum import Enum
from typing import Any, Dict, List, Type, Tuple, Union
from typing import Any, Dict, List, Type, Tuple, Union, Optional
from typing_extensions import ClassVar

from pydantic import BaseModel

from .models import Model as ModelInfo, AnyData, PrimaryKey
from .utils import to_constant_case
from .models import Model as ModelInfo, AnyData, PrimaryKey, DMMFEnumType
from .._compat import (
PYDANTIC_V2,
ConfigDict,
Expand All @@ -18,6 +19,7 @@ class Kind(str, Enum):
alias = 'alias'
union = 'union'
typeddict = 'typeddict'
enum = 'enum'


class PrismaType(BaseModel):
Expand Down Expand Up @@ -45,6 +47,11 @@ class PrismaUnion(PrismaType):
subtypes: List[PrismaType]


class PrismaEnum(PrismaType):
kind: Kind = Kind.enum
members: List[Tuple[str, str]]


class PrismaAlias(PrismaType):
kind: Kind = Kind.alias
to: str
Expand Down Expand Up @@ -143,6 +150,29 @@ def order_by(self) -> PrismaType:
return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput')


class ClientTypes(BaseModel):
transaction_isolation_level: Optional[PrismaEnum]

@classmethod
def from_data(cls, data: AnyData) -> 'ClientTypes':
enum_types = data.dmmf.prisma_schema.enum_types.prisma

return cls(
transaction_isolation_level=construct_enum_type(enum_types, name='TransactionIsolationLevel'),
)


def construct_enum_type(dmmf_enum_types: List[DMMFEnumType], *, name: str) -> Optional[PrismaEnum]:
enum_type = next((t for t in dmmf_enum_types if t.name == name), None)
if not enum_type:
return None

return PrismaEnum(
name=name,
members=[(to_constant_case(str(value)), str(value)) for value in enum_type.values],
)


model_rebuild(Schema)
model_rebuild(PrismaType)
model_rebuild(PrismaDict)
Expand Down
5 changes: 5 additions & 0 deletions src/prisma/generator/templates/types.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ from .utils import _NoneType
},
total={{ type.total }}
)
{% elif type.kind == 'enum' %}
class {{ type.name }}(StrEnum):
{% for name, value in type.members %}
{{ name }} = "{{ value }}"
{% endfor %}
{% else %}
{{ raise_err('Unhandled type kind: %s' % type.kind) }}
{% endif %}
Expand Down

0 comments on commit 64212ac

Please sign in to comment.