diff --git a/.release-please-manifest.json b/.release-please-manifest.json index fd0ccba..000572e 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.12" + ".": "0.1.0-alpha.13" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 5979d8a..fca4be8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## 0.1.0-alpha.13 (2024-11-28) + +Full Changelog: [v0.1.0-alpha.12...v0.1.0-alpha.13](https://github.com/clear-street/studio-sdk-python/compare/v0.1.0-alpha.12...v0.1.0-alpha.13) + +### Bug Fixes + +* **client:** compat with new httpx 0.28.0 release ([#66](https://github.com/clear-street/studio-sdk-python/issues/66)) ([0586398](https://github.com/clear-street/studio-sdk-python/commit/0586398d10ecc772be8d75283162bb22fea4a073)) + + +### Chores + +* **internal:** codegen related update ([#64](https://github.com/clear-street/studio-sdk-python/issues/64)) ([02d56d8](https://github.com/clear-street/studio-sdk-python/commit/02d56d8e51138b0233039763ef0a48980dbb17c3)) +* **internal:** codegen related update ([#65](https://github.com/clear-street/studio-sdk-python/issues/65)) ([f23970d](https://github.com/clear-street/studio-sdk-python/commit/f23970d3be640a2f75909c5709c8048617e874e9)) +* **internal:** fix compat model_dump method when warnings are passed ([#63](https://github.com/clear-street/studio-sdk-python/issues/63)) ([e5874f8](https://github.com/clear-street/studio-sdk-python/commit/e5874f89c9b06e01288f03c35421c751dee5ec0f)) +* **internal:** version bump ([#56](https://github.com/clear-street/studio-sdk-python/issues/56)) ([87913d2](https://github.com/clear-street/studio-sdk-python/commit/87913d2c81c926ab7e9e051b0a26c9e91cf32932)) +* rebuild project due to codegen change ([#58](https://github.com/clear-street/studio-sdk-python/issues/58)) ([0dcf41f](https://github.com/clear-street/studio-sdk-python/commit/0dcf41f9015287afd18a8718c83403573a685201)) +* rebuild project due to codegen change ([#59](https://github.com/clear-street/studio-sdk-python/issues/59)) ([569c4cd](https://github.com/clear-street/studio-sdk-python/commit/569c4cd8025b04e6dad8fbfacdca7bb71bbaf4d3)) +* rebuild project due to codegen change ([#60](https://github.com/clear-street/studio-sdk-python/issues/60)) ([904d865](https://github.com/clear-street/studio-sdk-python/commit/904d8650ff6aa09e031901901d91f546469a8bc6)) +* rebuild project due to codegen change ([#61](https://github.com/clear-street/studio-sdk-python/issues/61)) ([e4fa147](https://github.com/clear-street/studio-sdk-python/commit/e4fa14758305145cdb11ae802d486e54cfa574c5)) +* rebuild project due to codegen change ([#62](https://github.com/clear-street/studio-sdk-python/issues/62)) ([de76347](https://github.com/clear-street/studio-sdk-python/commit/de76347c8e121c6c6df789287663e2c43661332e)) + ## 0.1.0-alpha.12 (2024-10-25) Full Changelog: [v0.1.0-alpha.11...v0.1.0-alpha.12](https://github.com/clear-street/studio-sdk-python/compare/v0.1.0-alpha.11...v0.1.0-alpha.12) diff --git a/README.md b/README.md index dc103a2..a86ec70 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![PyPI version](https://img.shields.io/pypi/v/clear-street-studio-sdk.svg)](https://pypi.org/project/clear-street-studio-sdk/) -The Studio SDK Python library provides convenient access to the Studio SDK REST API from any Python 3.7+ +The Studio SDK Python library provides convenient access to the Studio SDK REST API from any Python 3.8+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). @@ -28,8 +28,9 @@ import os from studio_sdk import StudioSDK client = StudioSDK( - # This is the default and can be omitted - bearer_token=os.environ.get("STUDIO_SDK_BEARER_TOKEN"), + bearer_token=os.environ.get( + "STUDIO_SDK_BEARER_TOKEN" + ), # This is the default and can be omitted # defaults to "production". environment="sandbox", ) @@ -55,8 +56,9 @@ import asyncio from studio_sdk import AsyncStudioSDK client = AsyncStudioSDK( - # This is the default and can be omitted - bearer_token=os.environ.get("STUDIO_SDK_BEARER_TOKEN"), + bearer_token=os.environ.get( + "STUDIO_SDK_BEARER_TOKEN" + ), # This is the default and can be omitted # defaults to "production". environment="sandbox", ) @@ -184,12 +186,14 @@ Note that requests that time out are [retried twice by default](#retries). We use the standard library [`logging`](https://docs.python.org/3/library/logging.html) module. -You can enable logging by setting the environment variable `STUDIO_SDK_LOG` to `debug`. +You can enable logging by setting the environment variable `STUDIO_SDK_LOG` to `info`. ```shell -$ export STUDIO_SDK_LOG=debug +$ export STUDIO_SDK_LOG=info ``` +Or to `debug` for more verbose logging. + ### How to tell whether `None` means `null` or missing In an API response, a field may be explicitly `null`, or missing entirely; in either case, its value is `None` in this library. You can differentiate the two cases with `.model_fields_set`: @@ -332,7 +336,7 @@ print(studio_sdk.__version__) ## Requirements -Python 3.7 or higher. +Python 3.8 or higher. ## Contributing diff --git a/mypy.ini b/mypy.ini index fd76e99..6b4ad91 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,7 +5,10 @@ show_error_codes = True # Exclude _files.py because mypy isn't smart enough to apply # the correct type narrowing and as this is an internal module # it's fine to just use Pyright. -exclude = ^(src/studio_sdk/_files\.py|_dev/.*\.py)$ +# +# We also exclude our `tests` as mypy doesn't always infer +# types correctly and Pyright will still catch any type errors. +exclude = ^(src/studio_sdk/_files\.py|_dev/.*\.py|tests/.*)$ strict_equality = True implicit_reexport = True diff --git a/pyproject.toml b/pyproject.toml index 4c31df4..de9a165 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "clear-street-studio-sdk" -version = "0.1.0-alpha.12" +version = "0.1.0-alpha.13" description = "The official Python library for the studio-sdk API" dynamic = ["readme"] license = "Apache-2.0" @@ -14,13 +14,11 @@ dependencies = [ "anyio>=3.5.0, <5", "distro>=1.7.0, <2", "sniffio", - "cached-property; python_version < '3.8'", ] -requires-python = ">= 3.7" +requires-python = ">= 3.8" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -56,6 +54,7 @@ dev-dependencies = [ "dirty-equals>=0.6.0", "importlib-metadata>=6.7.0", "rich>=13.7.1", + "nest_asyncio==1.6.0" ] [tool.rye.scripts] @@ -139,7 +138,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.7" +pythonVersion = "3.8" exclude = [ "_dev", diff --git a/requirements-dev.lock b/requirements-dev.lock index 52085ff..b5bc035 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -16,8 +16,6 @@ anyio==4.4.0 # via httpx argcomplete==3.1.2 # via nox -attrs==23.1.0 - # via pytest certifi==2023.7.22 # via httpcore # via httpx @@ -28,8 +26,9 @@ distlib==0.3.7 # via virtualenv distro==1.8.0 # via clear-street-studio-sdk -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via anyio + # via pytest filelock==3.12.4 # via virtualenv h11==0.14.0 @@ -49,9 +48,10 @@ markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py -mypy==1.11.2 +mypy==1.13.0 mypy-extensions==1.0.0 # via mypy +nest-asyncio==1.6.0 nodeenv==1.8.0 # via pyright nox==2023.4.22 @@ -60,20 +60,18 @@ packaging==23.2 # via pytest platformdirs==3.11.0 # via virtualenv -pluggy==1.3.0 - # via pytest -py==1.11.0 +pluggy==1.5.0 # via pytest -pydantic==2.7.1 +pydantic==2.9.2 # via clear-street-studio-sdk -pydantic-core==2.18.2 +pydantic-core==2.23.4 # via pydantic pygments==2.18.0 # via rich pyright==1.1.380 -pytest==7.1.1 +pytest==8.3.3 # via pytest-asyncio -pytest-asyncio==0.21.1 +pytest-asyncio==0.24.0 python-dateutil==2.8.2 # via time-machine pytz==2023.3.post1 @@ -90,10 +88,10 @@ sniffio==1.3.0 # via clear-street-studio-sdk # via httpx time-machine==2.9.0 -tomli==2.0.1 +tomli==2.0.2 # via mypy # via pytest -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via anyio # via clear-street-studio-sdk # via mypy diff --git a/requirements.lock b/requirements.lock index b4b76f4..844da47 100644 --- a/requirements.lock +++ b/requirements.lock @@ -19,7 +19,7 @@ certifi==2023.7.22 # via httpx distro==1.8.0 # via clear-street-studio-sdk -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via anyio h11==0.14.0 # via httpcore @@ -30,15 +30,15 @@ httpx==0.25.2 idna==3.4 # via anyio # via httpx -pydantic==2.7.1 +pydantic==2.9.2 # via clear-street-studio-sdk -pydantic-core==2.18.2 +pydantic-core==2.23.4 # via pydantic sniffio==1.3.0 # via anyio # via clear-street-studio-sdk # via httpx -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via anyio # via clear-street-studio-sdk # via pydantic diff --git a/src/studio_sdk/_base_client.py b/src/studio_sdk/_base_client.py index b7a98a0..2f02c32 100644 --- a/src/studio_sdk/_base_client.py +++ b/src/studio_sdk/_base_client.py @@ -792,6 +792,7 @@ def __init__( custom_query: Mapping[str, object] | None = None, _strict_response_validation: bool, ) -> None: + kwargs: dict[str, Any] = {} if limits is not None: warnings.warn( "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", @@ -804,6 +805,7 @@ def __init__( limits = DEFAULT_CONNECTION_LIMITS if transport is not None: + kwargs["transport"] = transport warnings.warn( "The `transport` argument is deprecated. The `http_client` argument should be passed instead", category=DeprecationWarning, @@ -813,6 +815,7 @@ def __init__( raise ValueError("The `http_client` argument is mutually exclusive with `transport`") if proxies is not None: + kwargs["proxies"] = proxies warnings.warn( "The `proxies` argument is deprecated. The `http_client` argument should be passed instead", category=DeprecationWarning, @@ -856,10 +859,9 @@ def __init__( base_url=base_url, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), - proxies=proxies, - transport=transport, limits=limits, follow_redirects=True, + **kwargs, # type: ignore ) def is_closed(self) -> bool: @@ -1358,6 +1360,7 @@ def __init__( custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, ) -> None: + kwargs: dict[str, Any] = {} if limits is not None: warnings.warn( "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", @@ -1370,6 +1373,7 @@ def __init__( limits = DEFAULT_CONNECTION_LIMITS if transport is not None: + kwargs["transport"] = transport warnings.warn( "The `transport` argument is deprecated. The `http_client` argument should be passed instead", category=DeprecationWarning, @@ -1379,6 +1383,7 @@ def __init__( raise ValueError("The `http_client` argument is mutually exclusive with `transport`") if proxies is not None: + kwargs["proxies"] = proxies warnings.warn( "The `proxies` argument is deprecated. The `http_client` argument should be passed instead", category=DeprecationWarning, @@ -1422,10 +1427,9 @@ def __init__( base_url=base_url, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), - proxies=proxies, - transport=transport, limits=limits, follow_redirects=True, + **kwargs, # type: ignore ) def is_closed(self) -> bool: diff --git a/src/studio_sdk/_compat.py b/src/studio_sdk/_compat.py index 162a6fb..92d9ee6 100644 --- a/src/studio_sdk/_compat.py +++ b/src/studio_sdk/_compat.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload from datetime import date, datetime -from typing_extensions import Self +from typing_extensions import Self, Literal import pydantic from pydantic.fields import FieldInfo @@ -133,17 +133,20 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: def model_dump( model: pydantic.BaseModel, *, - exclude: IncEx = None, + exclude: IncEx | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, warnings: bool = True, + mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: - if PYDANTIC_V2: + if PYDANTIC_V2 or hasattr(model, "model_dump"): return model.model_dump( + mode=mode, exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, - warnings=warnings, + # warnings are not supported in Pydantic v1 + warnings=warnings if PYDANTIC_V2 else True, ) return cast( "dict[str, Any]", @@ -211,9 +214,6 @@ def __set_name__(self, owner: type[Any], name: str) -> None: ... # __set__ is not defined at runtime, but @cached_property is designed to be settable def __set__(self, instance: object, value: _T) -> None: ... else: - try: - from functools import cached_property as cached_property - except ImportError: - from cached_property import cached_property as cached_property + from functools import cached_property as cached_property typed_cached_property = cached_property diff --git a/src/studio_sdk/_models.py b/src/studio_sdk/_models.py index d386eaa..6cb469e 100644 --- a/src/studio_sdk/_models.py +++ b/src/studio_sdk/_models.py @@ -37,6 +37,7 @@ PropertyInfo, is_list, is_given, + json_safe, lru_cache, is_mapping, parse_date, @@ -176,7 +177,7 @@ def __str__(self) -> str: # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. @classmethod @override - def construct( + def construct( # pyright: ignore[reportIncompatibleMethodOverride] cls: Type[ModelT], _fields_set: set[str] | None = None, **values: object, @@ -248,8 +249,8 @@ def model_dump( self, *, mode: Literal["json", "python"] | str = "python", - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, @@ -279,8 +280,8 @@ def model_dump( Returns: A dictionary representation of the model. """ - if mode != "python": - raise ValueError("mode is only supported in Pydantic v2") + if mode not in {"json", "python"}: + raise ValueError("mode must be either 'json' or 'python'") if round_trip != False: raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: @@ -289,7 +290,7 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") - return super().dict( # pyright: ignore[reportDeprecated] + dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, by_alias=by_alias, @@ -298,13 +299,15 @@ def model_dump( exclude_none=exclude_none, ) + return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + @override def model_dump_json( self, *, indent: int | None = None, - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, diff --git a/src/studio_sdk/_types.py b/src/studio_sdk/_types.py index ad0f6dd..695ce5a 100644 --- a/src/studio_sdk/_types.py +++ b/src/studio_sdk/_types.py @@ -16,7 +16,7 @@ Optional, Sequence, ) -from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable import httpx import pydantic @@ -193,7 +193,9 @@ def get(self, __key: str) -> str | None: ... # Note: copied from Pydantic # https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 -IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" +IncEx: TypeAlias = Union[ + Set[int], Set[str], Mapping[int, Union["IncEx", Literal[True]]], Mapping[str, Union["IncEx", Literal[True]]] +] PostParser = Callable[[Any], Any] diff --git a/src/studio_sdk/_utils/__init__.py b/src/studio_sdk/_utils/__init__.py index 3efe66c..a7cff3c 100644 --- a/src/studio_sdk/_utils/__init__.py +++ b/src/studio_sdk/_utils/__init__.py @@ -6,6 +6,7 @@ is_list as is_list, is_given as is_given, is_tuple as is_tuple, + json_safe as json_safe, lru_cache as lru_cache, is_mapping as is_mapping, is_tuple_t as is_tuple_t, diff --git a/src/studio_sdk/_utils/_sync.py b/src/studio_sdk/_utils/_sync.py index d0d8103..8b3aaf2 100644 --- a/src/studio_sdk/_utils/_sync.py +++ b/src/studio_sdk/_utils/_sync.py @@ -1,56 +1,62 @@ from __future__ import annotations +import sys +import asyncio import functools -from typing import TypeVar, Callable, Awaitable +import contextvars +from typing import Any, TypeVar, Callable, Awaitable from typing_extensions import ParamSpec -import anyio -import anyio.to_thread - -from ._reflection import function_has_argument - T_Retval = TypeVar("T_Retval") T_ParamSpec = ParamSpec("T_ParamSpec") -# copied from `asyncer`, https://github.com/tiangolo/asyncer -def asyncify( - function: Callable[T_ParamSpec, T_Retval], - *, - cancellable: bool = False, - limiter: anyio.CapacityLimiter | None = None, -) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: +if sys.version_info >= (3, 9): + to_thread = asyncio.to_thread +else: + # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread + # for Python 3.8 support + async def to_thread( + func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs + ) -> Any: + """Asynchronously run function *func* in a separate thread. + + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Returns a coroutine that can be awaited to get the eventual result of *func*. + """ + loop = asyncio.events.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + +# inspired by `asyncer`, https://github.com/tiangolo/asyncer +def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ Take a blocking function and create an async one that receives the same - positional and keyword arguments, and that when called, calls the original function - in a worker thread using `anyio.to_thread.run_sync()`. Internally, - `asyncer.asyncify()` uses the same `anyio.to_thread.run_sync()`, but it supports - keyword arguments additional to positional arguments and it adds better support for - autocompletion and inline errors for the arguments of the function called and the - return value. - - If the `cancellable` option is enabled and the task waiting for its completion is - cancelled, the thread will still run its course but its return value (or any raised - exception) will be ignored. + positional and keyword arguments. For python version 3.9 and above, it uses + asyncio.to_thread to run the function in a separate thread. For python version + 3.8, it uses locally defined copy of the asyncio.to_thread function which was + introduced in python 3.9. - Use it like this: + Usage: - ```Python - def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str: - # Do work - return "Some result" + ```python + def blocking_func(arg1, arg2, kwarg1=None): + # blocking code + return result - result = await to_thread.asyncify(do_work)("spam", "ham", kwarg1="a", kwarg2="b") - print(result) + result = asyncify(blocking_function)(arg1, arg2, kwarg1=value1) ``` ## Arguments `function`: a blocking regular callable (e.g. a function) - `cancellable`: `True` to allow cancellation of the operation - `limiter`: capacity limiter to use to limit the total amount of threads running - (if omitted, the default limiter is used) ## Return @@ -60,22 +66,6 @@ def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str: """ async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: - partial_f = functools.partial(function, *args, **kwargs) - - # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old - # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid - # surfacing deprecation warnings. - if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"): - return await anyio.to_thread.run_sync( - partial_f, - abandon_on_cancel=cancellable, - limiter=limiter, - ) - - return await anyio.to_thread.run_sync( - partial_f, - cancellable=cancellable, - limiter=limiter, - ) + return await to_thread(function, *args, **kwargs) return wrapper diff --git a/src/studio_sdk/_utils/_transform.py b/src/studio_sdk/_utils/_transform.py index 47e262a..a6b62ca 100644 --- a/src/studio_sdk/_utils/_transform.py +++ b/src/studio_sdk/_utils/_transform.py @@ -173,6 +173,11 @@ def _transform_recursive( # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) ): + # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually + # intended as an iterable, so we don't transform it. + if isinstance(data, dict): + return cast(object, data) + inner_type = extract_type_arg(stripped_type, 0) return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] @@ -186,7 +191,7 @@ def _transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) + return model_dump(data, exclude_unset=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: @@ -311,6 +316,11 @@ async def _async_transform_recursive( # Iterable[T] or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) ): + # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually + # intended as an iterable, so we don't transform it. + if isinstance(data, dict): + return cast(object, data) + inner_type = extract_type_arg(stripped_type, 0) return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] @@ -324,7 +334,7 @@ async def _async_transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True) + return model_dump(data, exclude_unset=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: diff --git a/src/studio_sdk/_utils/_utils.py b/src/studio_sdk/_utils/_utils.py index 0bba17c..e5811bb 100644 --- a/src/studio_sdk/_utils/_utils.py +++ b/src/studio_sdk/_utils/_utils.py @@ -16,6 +16,7 @@ overload, ) from pathlib import Path +from datetime import date, datetime from typing_extensions import TypeGuard import sniffio @@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]: maxsize=maxsize, ) return cast(Any, wrapper) # type: ignore[no-any-return] + + +def json_safe(data: object) -> object: + """Translates a mapping / sequence recursively in the same fashion + as `pydantic` v2's `model_dump(mode="json")`. + """ + if is_mapping(data): + return {json_safe(key): json_safe(value) for key, value in data.items()} + + if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)): + return [json_safe(item) for item in data] + + if isinstance(data, (datetime, date)): + return data.isoformat() + + return data diff --git a/src/studio_sdk/_version.py b/src/studio_sdk/_version.py index f2c5199..9a72890 100644 --- a/src/studio_sdk/_version.py +++ b/src/studio_sdk/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "studio_sdk" -__version__ = "0.1.0-alpha.12" # x-release-please-version +__version__ = "0.1.0-alpha.13" # x-release-please-version diff --git a/tests/api_resources/entities/test_regt_margin_simulations.py b/tests/api_resources/entities/test_regt_margin_simulations.py index 618f474..4015566 100644 --- a/tests/api_resources/entities/test_regt_margin_simulations.py +++ b/tests/api_resources/entities/test_regt_margin_simulations.py @@ -39,17 +39,7 @@ def test_method_create_with_all_params(self, client: StudioSDK) -> None: "price": "x", "symbol": "AAPL", "symbol_format": "cms", - }, - { - "price": "x", - "symbol": "AAPL", - "symbol_format": "cms", - }, - { - "price": "x", - "symbol": "AAPL", - "symbol_format": "cms", - }, + } ], trades=[ { @@ -58,21 +48,7 @@ def test_method_create_with_all_params(self, client: StudioSDK) -> None: "side": "buy", "symbol": "AAPL", "symbol_format": "cms", - }, - { - "price": "x", - "quantity": "x", - "side": "buy", - "symbol": "AAPL", - "symbol_format": "cms", - }, - { - "price": "x", - "quantity": "x", - "side": "buy", - "symbol": "AAPL", - "symbol_format": "cms", - }, + } ], ) assert_matches_type(RegtMarginSimulationCreateResponse, regt_margin_simulation, path=["response"]) @@ -182,17 +158,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncStudioSDK) "price": "x", "symbol": "AAPL", "symbol_format": "cms", - }, - { - "price": "x", - "symbol": "AAPL", - "symbol_format": "cms", - }, - { - "price": "x", - "symbol": "AAPL", - "symbol_format": "cms", - }, + } ], trades=[ { @@ -201,21 +167,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncStudioSDK) "side": "buy", "symbol": "AAPL", "symbol_format": "cms", - }, - { - "price": "x", - "quantity": "x", - "side": "buy", - "symbol": "AAPL", - "symbol_format": "cms", - }, - { - "price": "x", - "quantity": "x", - "side": "buy", - "symbol": "AAPL", - "symbol_format": "cms", - }, + } ], ) assert_matches_type(RegtMarginSimulationCreateResponse, regt_margin_simulation, path=["response"]) diff --git a/tests/conftest.py b/tests/conftest.py index 918845a..93f3a4f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ from __future__ import annotations import os -import asyncio import logging from typing import TYPE_CHECKING, Iterator, AsyncIterator import pytest +from pytest_asyncio import is_async_test from studio_sdk import StudioSDK, AsyncStudioSDK @@ -17,11 +17,13 @@ logging.getLogger("studio_sdk").setLevel(logging.DEBUG) -@pytest.fixture(scope="session") -def event_loop() -> Iterator[asyncio.AbstractEventLoop]: - loop = asyncio.new_event_loop() - yield loop - loop.close() +# automatically add `pytest.mark.asyncio()` to all of our async tests +# so we don't have to add that boilerplate everywhere +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") diff --git a/tests/test_client.py b/tests/test_client.py index ac074b5..cafac71 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,11 +4,14 @@ import gc import os +import sys import json import asyncio import inspect +import subprocess import tracemalloc from typing import Any, Union, cast +from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -731,7 +734,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], - [-1100, "", 7.8], # test large number potentially overflowing + [-1100, "", 8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -1524,7 +1527,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], - [-1100, "", 7.8], # test large number potentially overflowing + [-1100, "", 8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -1644,3 +1647,38 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.http_request.headers.get("x-stainless-retry-count") == "42" + + def test_get_platform(self) -> None: + # A previous implementation of asyncify could leave threads unterminated when + # used with nest_asyncio. + # + # Since nest_asyncio.apply() is global and cannot be un-applied, this + # test is run in a separate process to avoid affecting other tests. + test_code = dedent(""" + import asyncio + import nest_asyncio + import threading + + from studio_sdk._utils import asyncify + from studio_sdk._base_client import get_platform + + async def test_main() -> None: + result = await asyncify(get_platform)() + print(result) + for thread in threading.enumerate(): + print(thread.name) + + nest_asyncio.apply() + asyncio.run(test_main()) + """) + with subprocess.Popen( + [sys.executable, "-c", test_code], + text=True, + ) as process: + try: + process.wait(2) + if process.returncode: + raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code") + except subprocess.TimeoutExpired as e: + process.kill() + raise AssertionError("calling get_platform using asyncify resulted in a hung process") from e diff --git a/tests/test_models.py b/tests/test_models.py index 29cda90..dd2d76a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -520,19 +520,15 @@ class Model(BaseModel): assert m3.to_dict(exclude_none=True) == {} assert m3.to_dict(exclude_defaults=True) == {} - if PYDANTIC_V2: + class Model2(BaseModel): + created_at: datetime - class Model2(BaseModel): - created_at: datetime - - time_str = "2024-03-21T11:39:01.275859" - m4 = Model2.construct(created_at=time_str) - assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} - assert m4.to_dict(mode="json") == {"created_at": time_str} - else: - with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): - m.to_dict(mode="json") + time_str = "2024-03-21T11:39:01.275859" + m4 = Model2.construct(created_at=time_str) + assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} + assert m4.to_dict(mode="json") == {"created_at": time_str} + if not PYDANTIC_V2: with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): m.to_dict(warnings=False) @@ -558,9 +554,6 @@ class Model(BaseModel): assert m3.model_dump(exclude_none=True) == {} if not PYDANTIC_V2: - with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"): - m.model_dump(mode="json") - with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): m.model_dump(round_trip=True) @@ -568,6 +561,14 @@ class Model(BaseModel): m.model_dump(warnings=False) +def test_compat_method_no_error_for_warnings() -> None: + class Model(BaseModel): + foo: Optional[str] + + m = Model(foo="hello") + assert isinstance(model_dump(m, warnings=False), dict) + + def test_to_json() -> None: class Model(BaseModel): foo: Optional[str] = Field(alias="FOO", default=None) diff --git a/tests/test_transform.py b/tests/test_transform.py index b73010a..d66a46a 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False): foo: Annotated[date, PropertyInfo(format="iso8601")] +class DatetimeModel(BaseModel): + foo: datetime + + +class DateModel(BaseModel): + foo: Optional[date] + + @parametrize @pytest.mark.asyncio async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + tz = "Z" if PYDANTIC_V2 else "+00:00" assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] dt = dt.replace(tzinfo=None) assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == { + "foo": "2023-02-23" + } # type: ignore[comparison-overlap] @parametrize