Skip to content

Commit

Permalink
add timezone name validation
Browse files Browse the repository at this point in the history
  • Loading branch information
07pepa committed Jun 30, 2024
1 parent 6133cbe commit 6ce366a
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 1 deletion.
123 changes: 123 additions & 0 deletions pydantic_extra_types/timezone_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Time zone name validation and serialization module."""

from __future__ import annotations

import importlib
import sys
import warnings
from typing import Any, List

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


def _is_available(name: str) -> bool:
spec = importlib.util.find_spec(name)
return spec is not None


if _is_available('zoneinfo') and _is_available('tzdata'):
from zoneinfo import available_timezones

def _tz_provider() -> set[str]:
return set(available_timezones())

elif _is_available('pytz'):
if sys.version_info[:2] > (3, 8):
warnings.warn(
'Projects using Python 3.9 or later'
' should be using the support now included as part of the standard library zone-info. '
'Please consider switching to the standard library module.'
)
from pytz import all_timezones

def _tz_provider() -> set[str]:
return set(all_timezones)
else:
if sys.version_info[:2] == (3, 8):
raise ImportError('No pytz module not found. Please install it with "pip install pytz')
raise ImportError('No timezone provider found. Please install tzdata' 'Please install it with "pip install tzdata"')


class TimeZoneNameSettings(type):
def __new__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def]
dct['strict'] = kwargs.pop('strict', True)
return super().__new__(cls, name, bases, dct)

def __init__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def]
super().__init__(name, bases, dct)
cls.strict = kwargs.get('strict', True)


class TimeZoneName(str):
"""If the mode is not strict matching, it is case-insensitive with whitespace stripped.
Value is then coerced to the correct case."""

__metaclass__ = TimeZoneNameSettings
__slots__: List[str] = []
allowed_values = set(_tz_provider())
allowed_values_list = list(allowed_values)
allowed_values_list.sort()
allowed_values_upper_to_correct = {val.upper(): val for val in allowed_values}

@classmethod
def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName:
"""
Validate a time zone name from the provided str value.
Args:
__input_value: The str value to be validated.
_: The Pydantic ValidationInfo.
Returns:
The validated time zone name.
Raises:
PydanticCustomError: If the timezone name is not valid.
"""
if __input_value not in cls.allowed_values: # be fast for the most common case
if not cls.strict: # type: ignore[attr-defined]
upper_value = __input_value.strip().upper()
if upper_value in cls.allowed_values_upper_to_correct:
return cls(cls.allowed_values_upper_to_correct[upper_value])
raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.')
return cls(__input_value)

@classmethod
def __get_pydantic_core_schema__(
cls, _: type[Any], __: GetCoreSchemaHandler
) -> core_schema.AfterValidatorFunctionSchema:
"""
Return a Pydantic CoreSchema with the ISO 639-3 language code validation.
Args:
_: The source type.
__: The handler to get the CoreSchema.
Returns:
A Pydantic CoreSchema with the ISO 639-3 language code validation.
"""
return core_schema.with_info_after_validator_function(
cls._validate,
core_schema.str_schema(min_length=1),
)

@classmethod
def __get_pydantic_json_schema__(
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict[str, Any]:
"""
Return a Pydantic JSON Schema with the ISO 639-3 language code validation.
Args:
schema: The Pydantic CoreSchema.
handler: The handler to get the JSON Schema.
Returns:
A Pydantic JSON Schema with the ISO 639-3 language code validation.
"""
json_schema = handler(schema)
json_schema.update({'enum': cls.allowed_values_list})
return json_schema
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ all = [
'pycountry>=23',
'python-ulid>=1,<2; python_version<"3.9"',
'python-ulid>=1,<3; python_version>="3.9"',
'pendulum>=3.0.0,<4.0.0'
'pendulum>=3.0.0,<4.0.0',
'pytz',
'tzdata',
'types-pytz',
]
phonenumbers = ['phonenumbers>=8,<9']
pycountry = ['pycountry>=23']
Expand Down
19 changes: 19 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic_extra_types.payment import PaymentCardNumber
from pydantic_extra_types.pendulum_dt import DateTime
from pydantic_extra_types.script_code import ISO_15924
from pydantic_extra_types.timezone_name import TimeZoneName
from pydantic_extra_types.ulid import ULID

languages = [lang.alpha_3 for lang in pycountry.languages]
Expand All @@ -35,6 +36,8 @@

scripts = [script.alpha_4 for script in pycountry.scripts]

timezone_names = TimeZoneName.allowed_values_list

everyday_currencies.sort()


Expand Down Expand Up @@ -325,6 +328,22 @@
'type': 'object',
},
),
(
TimeZoneName,
{
'properties': {
'x': {
'title': 'X',
'type': 'string',
'enum': timezone_names,
'minLength': 1,
}
},
'required': ['x'],
'title': 'Model',
'type': 'object',
},
),
],
)
def test_json_schema(cls, expected):
Expand Down
78 changes: 78 additions & 0 deletions tests/test_timezone_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import re

import pytest
import pytz
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.timezone_name import TimeZoneName

has_zone_info = True
try:
from zoneinfo import available_timezones
except ImportError:
has_zone_info = False

pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones]
pytz_zones_bad.extend([(f' {zone}', zone) for zone in pytz.all_timezones_set])


class TZNameCheck(BaseModel):
timezone_name: TimeZoneName


class TZNonStrict(TimeZoneName, strict=False):
pass


class NonStrictTzName(BaseModel):
timezone_name: TZNonStrict


@pytest.mark.parametrize('zone', pytz.all_timezones)
def test_all_timezones_non_strict_pytz(zone):
assert TZNameCheck(timezone_name=zone).timezone_name == zone
assert NonStrictTzName(timezone_name=zone).timezone_name == zone


@pytest.mark.parametrize('zone', pytz_zones_bad)
def test_all_timezones_pytz_lower(zone):
assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1]


def test_fail_non_existing_timezone():
with pytest.raises(
ValidationError,
match=re.escape(
'1 validation error for TZNameCheck\n'
'timezone_name\n '
'Invalid timezone name. '
"[type=TimeZoneName, input_value='mars', input_type=str]"
),
):
TZNameCheck(timezone_name='mars')

with pytest.raises(
ValidationError,
match=re.escape(
'1 validation error for NonStrictTzName\n'
'timezone_name\n '
'Invalid timezone name. '
"[type=TimeZoneName, input_value='mars', input_type=str]"
),
):
NonStrictTzName(timezone_name='mars')


if has_zone_info:
zones = list(available_timezones())
zones.sort()
zones_bad = [(zone.lower(), zone) for zone in zones]

@pytest.mark.parametrize('zone', zones)
def test_all_timezones_zone_info(zone):
assert TZNameCheck(timezone_name=zone).timezone_name == zone
assert NonStrictTzName(timezone_name=zone).timezone_name == zone

@pytest.mark.parametrize('zone', zones_bad)
def test_all_timezones_zone_info_NonStrict(zone):
assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1]

0 comments on commit 6ce366a

Please sign in to comment.