-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
07pepa
committed
Jul 1, 2024
1 parent
6133cbe
commit 636f705
Showing
5 changed files
with
227 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ lint: | |
|
||
.PHONY: mypy | ||
mypy: | ||
mypy pydantic_extra_types | ||
@mypy pydantic_extra_types | ||
|
||
.PHONY: test | ||
test: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""Time zone name validation and serialization module.""" | ||
|
||
from __future__ import annotations | ||
|
||
import importlib | ||
import sys | ||
import warnings | ||
from typing import Any, List | ||
|
||
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler | ||
from pydantic_core import PydanticCustomError, core_schema | ||
|
||
|
||
def _is_available(name: str) -> bool: | ||
try: | ||
importlib.import_module(name=name) | ||
return True | ||
except ModuleNotFoundError: # pragma: no cover | ||
return False | ||
|
||
|
||
if _is_available('zoneinfo') and _is_available('tzdata'): # pragma: no cover | ||
from zoneinfo import available_timezones | ||
|
||
def _tz_provider() -> set[str]: | ||
return set(available_timezones()) | ||
|
||
elif _is_available('pytz'): # pragma: no cover | ||
if sys.version_info[:2] > (3, 8): | ||
warnings.warn( | ||
'Projects using Python 3.9 or later' | ||
' should be using the support now included as part of the standard library zone-info. ' | ||
'Please consider switching to the standard library module.' | ||
) | ||
from pytz import all_timezones | ||
|
||
def _tz_provider() -> set[str]: | ||
return set(all_timezones) | ||
else: # pragma: no cover | ||
if sys.version_info[:2] == (3, 8): | ||
raise ImportError('No pytz module not found. Please install it with "pip install pytz') | ||
raise ImportError('No timezone provider found. Please install tzdata' 'Please install it with "pip install tzdata"') | ||
|
||
|
||
class TimeZoneNameSettings(type): | ||
def __new__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def] | ||
dct['strict'] = kwargs.pop('strict', True) | ||
return super().__new__(cls, name, bases, dct) | ||
|
||
def __init__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def] | ||
super().__init__(name, bases, dct) | ||
cls.strict = kwargs.get('strict', True) | ||
|
||
|
||
class TimeZoneName(str, metaclass=TimeZoneNameSettings): # type: ignore[misc] | ||
"""If the mode is not strict matching, it is case-insensitive with whitespace stripped. | ||
Value is then coerced to the correct case.""" | ||
|
||
__slots__: List[str] = [] | ||
allowed_values = set(_tz_provider()) | ||
allowed_values_list = list(allowed_values) | ||
allowed_values_list.sort() | ||
allowed_values_upper_to_correct = {val.upper(): val for val in allowed_values} | ||
|
||
@classmethod | ||
def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName: | ||
""" | ||
Validate a time zone name from the provided str value. | ||
Args: | ||
__input_value: The str value to be validated. | ||
_: The Pydantic ValidationInfo. | ||
Returns: | ||
The validated time zone name. | ||
Raises: | ||
PydanticCustomError: If the timezone name is not valid. | ||
""" | ||
if __input_value not in cls.allowed_values: # be fast for the most common case | ||
if not cls.strict: | ||
upper_value = __input_value.strip().upper() | ||
if upper_value in cls.allowed_values_upper_to_correct: | ||
return cls(cls.allowed_values_upper_to_correct[upper_value]) | ||
raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.') | ||
return cls(__input_value) | ||
|
||
@classmethod | ||
def __get_pydantic_core_schema__( | ||
cls, _: type[Any], __: GetCoreSchemaHandler | ||
) -> core_schema.AfterValidatorFunctionSchema: | ||
""" | ||
Return a Pydantic CoreSchema with the ISO 639-3 language code validation. | ||
Args: | ||
_: The source type. | ||
__: The handler to get the CoreSchema. | ||
Returns: | ||
A Pydantic CoreSchema with the ISO 639-3 language code validation. | ||
""" | ||
return core_schema.with_info_after_validator_function( | ||
cls._validate, | ||
core_schema.str_schema(min_length=1), | ||
) | ||
|
||
@classmethod | ||
def __get_pydantic_json_schema__( | ||
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler | ||
) -> dict[str, Any]: | ||
""" | ||
Return a Pydantic JSON Schema with the ISO 639-3 language code validation. | ||
Args: | ||
schema: The Pydantic CoreSchema. | ||
handler: The handler to get the JSON Schema. | ||
Returns: | ||
A Pydantic JSON Schema with the ISO 639-3 language code validation. | ||
""" | ||
json_schema = handler(schema) | ||
json_schema.update({'enum': cls.allowed_values_list}) | ||
return json_schema |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import re | ||
|
||
import pytest | ||
import pytz | ||
from pydantic import BaseModel, ValidationError | ||
|
||
from pydantic_extra_types.timezone_name import TimeZoneName | ||
|
||
has_zone_info = True | ||
try: | ||
from zoneinfo import available_timezones | ||
except ImportError: | ||
has_zone_info = False | ||
|
||
pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones] | ||
pytz_zones_bad.extend([(f' {zone}', zone) for zone in pytz.all_timezones_set]) | ||
|
||
|
||
class TZNameCheck(BaseModel): | ||
timezone_name: TimeZoneName | ||
|
||
|
||
class TZNonStrict(TimeZoneName, strict=False): | ||
pass | ||
|
||
|
||
class NonStrictTzName(BaseModel): | ||
timezone_name: TZNonStrict | ||
|
||
|
||
@pytest.mark.parametrize('zone', pytz.all_timezones) | ||
def test_all_timezones_non_strict_pytz(zone): | ||
assert TZNameCheck(timezone_name=zone).timezone_name == zone | ||
assert NonStrictTzName(timezone_name=zone).timezone_name == zone | ||
|
||
|
||
@pytest.mark.parametrize('zone', pytz_zones_bad) | ||
def test_all_timezones_pytz_lower(zone): | ||
assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1] | ||
|
||
|
||
def test_fail_non_existing_timezone(): | ||
with pytest.raises( | ||
ValidationError, | ||
match=re.escape( | ||
'1 validation error for TZNameCheck\n' | ||
'timezone_name\n ' | ||
'Invalid timezone name. ' | ||
"[type=TimeZoneName, input_value='mars', input_type=str]" | ||
), | ||
): | ||
TZNameCheck(timezone_name='mars') | ||
|
||
with pytest.raises( | ||
ValidationError, | ||
match=re.escape( | ||
'1 validation error for NonStrictTzName\n' | ||
'timezone_name\n ' | ||
'Invalid timezone name. ' | ||
"[type=TimeZoneName, input_value='mars', input_type=str]" | ||
), | ||
): | ||
NonStrictTzName(timezone_name='mars') | ||
|
||
|
||
if has_zone_info: | ||
zones = list(available_timezones()) | ||
zones.sort() | ||
zones_bad = [(zone.lower(), zone) for zone in zones] | ||
|
||
@pytest.mark.parametrize('zone', zones) | ||
def test_all_timezones_zone_info(zone): | ||
assert TZNameCheck(timezone_name=zone).timezone_name == zone | ||
assert NonStrictTzName(timezone_name=zone).timezone_name == zone | ||
|
||
@pytest.mark.parametrize('zone', zones_bad) | ||
def test_all_timezones_zone_info_NonStrict(zone): | ||
assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1] |