diff --git a/pydantic_extra_types/timezone_name.py b/pydantic_extra_types/timezone_name.py new file mode 100644 index 00000000..93b2213d --- /dev/null +++ b/pydantic_extra_types/timezone_name.py @@ -0,0 +1,189 @@ +"""Time zone name validation and serialization module.""" + +from __future__ import annotations + +import importlib +import sys +import warnings +from typing import Any, Callable, List, Set, Type, cast + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +def _is_available(name: str) -> bool: + """Check if a module is available for import.""" + try: + importlib.import_module(name) + return True + except ModuleNotFoundError: # pragma: no cover + return False + + +def _tz_provider_from_zone_info() -> Set[str]: # pragma: no cover + """Get timezones from the zoneinfo module.""" + from zoneinfo import available_timezones + + return set(available_timezones()) + + +def _tz_provider_from_pytz() -> Set[str]: # pragma: no cover + """Get timezones from the pytz module.""" + from pytz import all_timezones + + return set(all_timezones) + + +def _warn_about_pytz_usage() -> None: + """Warn about using pytz with Python 3.9 or later.""" + warnings.warn( # pragma: no cover + 'Projects using Python 3.9 or later should be using the support now included as part of the standard library. ' + 'Please consider switching to the standard library (zoneinfo) module.' + ) + + +def get_timezones() -> Set[str]: + """Determine the timezone provider and return available timezones.""" + if _is_available('zoneinfo') and _is_available('tzdata'): # pragma: no cover + return _tz_provider_from_zone_info() + elif _is_available('pytz'): # pragma: no cover + if sys.version_info[:2] > (3, 8): + _warn_about_pytz_usage() + return _tz_provider_from_pytz() + else: # pragma: no cover + if sys.version_info[:2] == (3, 8): + raise ImportError('No pytz module found. Please install it with "pip install pytz"') + raise ImportError('No timezone provider found. Please install tzdata with "pip install tzdata"') + + +class TimeZoneNameSettings(type): + def __new__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> Type[TimeZoneName]: + dct['strict'] = kwargs.pop('strict', True) + return cast(Type[TimeZoneName], super().__new__(cls, name, bases, dct)) + + def __init__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> None: + super().__init__(name, bases, dct) + cls.strict = kwargs.get('strict', True) + + +def timezone_name_settings(**kwargs: Any) -> Callable[[Type[TimeZoneName]], Type[TimeZoneName]]: + def wrapper(cls: Type[TimeZoneName]) -> Type[TimeZoneName]: + cls.strict = kwargs.get('strict', True) + return cls + + return wrapper + + +@timezone_name_settings(strict=True) +class TimeZoneName(str): + """ + TimeZoneName is a custom string subclass for validating and serializing timezone names. + + The TimeZoneName class uses the IANA Time Zone Database for validation. + It supports both strict and non-strict modes for timezone name validation. + + + ## Examples: + + Some examples of using the TimeZoneName class: + + ### Normal usage: + + ```python + from pydantic_extra_types.timezone_name import TimeZoneName + from pydantic import BaseModel + class Location(BaseModel): + city: str + timezone: TimeZoneName + + loc = Location(city="New York", timezone="America/New_York") + print(loc.timezone) + + >> America/New_York + + ``` + + ### Non-strict mode: + + ```python + + from pydantic_extra_types.timezone_name import TimeZoneName, timezone_name_settings + + @timezone_name_settings(strict=False) + class TZNonStrict(TimeZoneName): + pass + + tz = TZNonStrict("america/new_york") + + print(tz) + + >> america/new_york + + ``` + """ + + __slots__: List[str] = [] + allowed_values: Set[str] = set(get_timezones()) + allowed_values_list: List[str] = sorted(allowed_values) + allowed_values_upper_to_correct: dict[str, str] = {val.upper(): val for val in allowed_values} + strict: bool + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName: + """ + Validate a time zone name from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated time zone name. + + Raises: + PydanticCustomError: If the timezone name is not valid. + """ + if __input_value not in cls.allowed_values: # be fast for the most common case + if not cls.strict: + upper_value = __input_value.strip().upper() + if upper_value in cls.allowed_values_upper_to_correct: + return cls(cls.allowed_values_upper_to_correct[upper_value]) + raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, _: Type[Any], __: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + """ + Return a Pydantic CoreSchema with the timezone name validation. + + Args: + _: The source type. + __: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the timezone name validation. + """ + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=1), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """ + Return a Pydantic JSON Schema with the timezone name validation. + + Args: + schema: The Pydantic CoreSchema. + handler: The handler to get the JSON Schema. + + Returns: + A Pydantic JSON Schema with the timezone name validation. + """ + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_values_list}) + return json_schema diff --git a/pyproject.toml b/pyproject.toml index eadf5103..824db8cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,9 @@ all = [ 'semver>=3.0.2', 'python-ulid>=1,<2; python_version<"3.9"', 'python-ulid>=1,<3; python_version>="3.9"', - 'pendulum>=3.0.0,<4.0.0' + 'pendulum>=3.0.0,<4.0.0', + 'pytz>=2024.1', + 'tzdata>=2024.1', ] phonenumbers = ['phonenumbers>=8,<9'] pycountry = ['pycountry>=23'] diff --git a/requirements/linting.in b/requirements/linting.in index 06a5fced..fa0f927c 100644 --- a/requirements/linting.in +++ b/requirements/linting.in @@ -2,3 +2,4 @@ pre-commit mypy annotated-types ruff +types-pytz diff --git a/requirements/linting.txt b/requirements/linting.txt index 629c1505..a117fc10 100644 --- a/requirements/linting.txt +++ b/requirements/linting.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --no-emit-index-url --output-file=requirements/linting.txt requirements/linting.in @@ -28,6 +28,8 @@ pyyaml==6.0.1 # via pre-commit ruff==0.5.0 # via -r requirements/linting.in +types-pytz==2024.1.0.20240417 + # via -r requirements/linting.in typing-extensions==4.10.0 # via mypy virtualenv==20.25.1 diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 303098c1..c39c88f0 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -19,6 +19,7 @@ from pydantic_extra_types.pendulum_dt import DateTime from pydantic_extra_types.script_code import ISO_15924 from pydantic_extra_types.semantic_version import SemanticVersion +from pydantic_extra_types.timezone_name import TimeZoneName from pydantic_extra_types.ulid import ULID languages = [lang.alpha_3 for lang in pycountry.languages] @@ -36,6 +37,8 @@ scripts = [script.alpha_4 for script in pycountry.scripts] +timezone_names = TimeZoneName.allowed_values_list + everyday_currencies.sort() @@ -335,6 +338,22 @@ 'type': 'object', }, ), + ( + TimeZoneName, + { + 'properties': { + 'x': { + 'title': 'X', + 'type': 'string', + 'enum': timezone_names, + 'minLength': 1, + } + }, + 'required': ['x'], + 'title': 'Model', + 'type': 'object', + }, + ), ], ) def test_json_schema(cls, expected): diff --git a/tests/test_timezone_names.py b/tests/test_timezone_names.py new file mode 100644 index 00000000..d980092e --- /dev/null +++ b/tests/test_timezone_names.py @@ -0,0 +1,209 @@ +import re + +import pytest +import pytz +from pydantic import BaseModel, ValidationError +from pydantic_core import PydanticCustomError + +from pydantic_extra_types.timezone_name import TimeZoneName, TimeZoneNameSettings, timezone_name_settings + +has_zone_info = True +try: + from zoneinfo import available_timezones +except ImportError: + has_zone_info = False + +pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones] +pytz_zones_bad.extend([(f' {zone}', zone) for zone in pytz.all_timezones_set]) + + +class TZNameCheck(BaseModel): + timezone_name: TimeZoneName + + +@timezone_name_settings(strict=False) +class TZNonStrict(TimeZoneName): + pass + + +class NonStrictTzName(BaseModel): + timezone_name: TZNonStrict + + +@pytest.mark.parametrize('zone', pytz.all_timezones) +def test_all_timezones_non_strict_pytz(zone): + assert TZNameCheck(timezone_name=zone).timezone_name == zone + assert NonStrictTzName(timezone_name=zone).timezone_name == zone + + +@pytest.mark.parametrize('zone', pytz_zones_bad) +def test_all_timezones_pytz_lower(zone): + assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1] + + +def test_fail_non_existing_timezone(): + with pytest.raises( + ValidationError, + match=re.escape( + '1 validation error for TZNameCheck\n' + 'timezone_name\n ' + 'Invalid timezone name. ' + "[type=TimeZoneName, input_value='mars', input_type=str]" + ), + ): + TZNameCheck(timezone_name='mars') + + with pytest.raises( + ValidationError, + match=re.escape( + '1 validation error for NonStrictTzName\n' + 'timezone_name\n ' + 'Invalid timezone name. ' + "[type=TimeZoneName, input_value='mars', input_type=str]" + ), + ): + NonStrictTzName(timezone_name='mars') + + +if has_zone_info: + zones = list(available_timezones()) + zones.sort() + zones_bad = [(zone.lower(), zone) for zone in zones] + + @pytest.mark.parametrize('zone', zones) + def test_all_timezones_zone_info(zone): + assert TZNameCheck(timezone_name=zone).timezone_name == zone + assert NonStrictTzName(timezone_name=zone).timezone_name == zone + + @pytest.mark.parametrize('zone', zones_bad) + def test_all_timezones_zone_info_NonStrict(zone): + assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1] + + +def test_timezone_name_settings_metaclass(): + class TestStrictTZ(TimeZoneName, strict=True, metaclass=TimeZoneNameSettings): + pass + + class TestNonStrictTZ(TimeZoneName, strict=False, metaclass=TimeZoneNameSettings): + pass + + assert TestStrictTZ.strict is True + assert TestNonStrictTZ.strict is False + + # Test default value + class TestDefaultStrictTZ(TimeZoneName, metaclass=TimeZoneNameSettings): + pass + + assert TestDefaultStrictTZ.strict is True + + +def test_timezone_name_validation(): + valid_tz = 'America/New_York' + invalid_tz = 'Invalid/Timezone' + + assert TimeZoneName._validate(valid_tz, None) == valid_tz + + with pytest.raises(PydanticCustomError): + TimeZoneName._validate(invalid_tz, None) + + assert TZNonStrict._validate(valid_tz.lower(), None) == valid_tz + assert TZNonStrict._validate(f' {valid_tz} ', None) == valid_tz + + with pytest.raises(PydanticCustomError): + TZNonStrict._validate(invalid_tz, None) + + +def test_timezone_name_pydantic_core_schema(): + schema = TimeZoneName.__get_pydantic_core_schema__(TimeZoneName, None) + assert isinstance(schema, dict) + assert schema['type'] == 'function-after' + assert 'function' in schema + assert 'schema' in schema + assert schema['schema']['type'] == 'str' + assert schema['schema']['min_length'] == 1 + + +def test_timezone_name_pydantic_json_schema(): + core_schema = TimeZoneName.__get_pydantic_core_schema__(TimeZoneName, None) + + class MockJsonSchemaHandler: + def __call__(self, schema): + return {'type': 'string'} + + handler = MockJsonSchemaHandler() + json_schema = TimeZoneName.__get_pydantic_json_schema__(core_schema, handler) + assert 'enum' in json_schema + assert isinstance(json_schema['enum'], list) + assert len(json_schema['enum']) > 0 + + +def test_timezone_name_repr(): + tz = TimeZoneName('America/New_York') + assert repr(tz) == "'America/New_York'" + assert str(tz) == 'America/New_York' + + +def test_timezone_name_allowed_values(): + assert isinstance(TimeZoneName.allowed_values, set) + assert len(TimeZoneName.allowed_values) > 0 + assert all(isinstance(tz, str) for tz in TimeZoneName.allowed_values) + + assert isinstance(TimeZoneName.allowed_values_list, list) + assert len(TimeZoneName.allowed_values_list) > 0 + assert all(isinstance(tz, str) for tz in TimeZoneName.allowed_values_list) + + assert isinstance(TimeZoneName.allowed_values_upper_to_correct, dict) + assert len(TimeZoneName.allowed_values_upper_to_correct) > 0 + assert all( + isinstance(k, str) and isinstance(v, str) for k, v in TimeZoneName.allowed_values_upper_to_correct.items() + ) + + +def test_timezone_name_inheritance(): + class CustomTZ(TimeZoneName, metaclass=TimeZoneNameSettings): + pass + + assert issubclass(CustomTZ, TimeZoneName) + assert issubclass(CustomTZ, str) + assert isinstance(CustomTZ('America/New_York'), (CustomTZ, TimeZoneName, str)) + + +def test_timezone_name_string_operations(): + tz = TimeZoneName('America/New_York') + assert tz.upper() == 'AMERICA/NEW_YORK' + assert tz.lower() == 'america/new_york' + assert tz.strip() == 'America/New_York' + assert f'{tz} Time' == 'America/New_York Time' + assert tz.startswith('America') + assert tz.endswith('York') + + +def test_timezone_name_comparison(): + tz1 = TimeZoneName('America/New_York') + tz2 = TimeZoneName('Europe/London') + tz3 = TimeZoneName('America/New_York') + + assert tz1 == tz3 + assert tz1 != tz2 + assert tz1 < tz2 # Alphabetical comparison + assert tz2 > tz1 + assert tz1 <= tz3 + assert tz1 >= tz3 + + +def test_timezone_name_hash(): + tz1 = TimeZoneName('America/New_York') + tz2 = TimeZoneName('America/New_York') + tz3 = TimeZoneName('Europe/London') + + assert hash(tz1) == hash(tz2) + assert hash(tz1) != hash(tz3) + + tz_set = {tz1, tz2, tz3} + assert len(tz_set) == 2 + + +def test_timezone_name_slots(): + tz = TimeZoneName('America/New_York') + with pytest.raises(AttributeError): + tz.new_attribute = 'test'