diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 23f29c35..ed8f974d 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -782,7 +782,7 @@ async def _get_config_and_input( # This takes into account changes in the input type when # using configuration. schema = self._runnable.with_config(config).input_schema - input_ = schema.validate(body.input) + input_ = schema.model_validate(body.input) return config, _unpack_input(input_) except ValidationError as e: raise RequestValidationError(e.errors(), body=body) @@ -892,7 +892,7 @@ async def batch( raise RequestValidationError(errors=["Invalid JSON body"]) with _with_validation_error_translation(): - body = BatchRequestShallowValidator.validate(body) + body = BatchRequestShallowValidator.model_validate(body) config = body.config # First unpack the config @@ -943,7 +943,7 @@ async def batch( inputs = [ _unpack_input( - self._runnable.with_config(config_).input_schema.validate(input_) + self._runnable.with_config(config_).input_schema.model_validate(input_) ) for config_, input_ in zip(configs_, inputs_) ] diff --git a/langserve/server.py b/langserve/server.py index 99a02039..cfe935e4 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -5,6 +5,7 @@ The main entry point is the `add_routes` function which adds the routes to an existing FastAPI app or APIRouter. """ +import warnings import weakref from typing import ( Any, @@ -201,37 +202,47 @@ def _register_path_for_app( def _setup_global_app_handlers( app: Union[FastAPI, APIRouter], endpoint_configuration: _EndpointConfiguration ) -> None: - @app.on_event("startup") - async def startup_event(): - LANGSERVE = r""" - __ ___ .__ __. _______ _______. _______ .______ ____ ____ _______ -| | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____| -| | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__ -| | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __| -| `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____ -|_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______| -""" # noqa: E501 - - def green(text: str) -> str: - """Return the given text in green.""" - return "\x1b[1;32;40m" + text + "\x1b[0m" - - def orange(text: str) -> str: - """Return the given text in orange.""" - return "\x1b[1;31;40m" + text + "\x1b[0m" - - paths = _APP_TO_PATHS[app] - print(LANGSERVE) - for path in paths: - if endpoint_configuration.is_playground_enabled: - print( - f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" is ' - f"live at:" - ) - print(f'{green("LANGSERVE:")} │') - print(f'{green("LANGSERVE:")} └──> {path}/playground/') - print(f'{green("LANGSERVE:")}') - print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/') + with warnings.catch_warnings(): + # We are using deprecated functionality here. + # This code should be re-written to simply construct a pydantic model + # using inspect.signature and create_model. + warnings.filterwarnings( + "ignore", + "[\\s.]*on_event is deprecated[\\s.]*", + category=DeprecationWarning, + ) + + @app.on_event("startup") + async def startup_event(): + LANGSERVE = r""" + __ ___ .__ __. _______ _______. _______ .______ ____ ____ _______ + | | / \ | \ | | / _____| / || ____|| _ \ \ \ / / | ____| + | | / ^ \ | \| | | | __ | (----`| |__ | |_) | \ \/ / | |__ + | | / /_\ \ | . ` | | | |_ | \ \ | __| | / \ / | __| + | `----./ _____ \ | |\ | | |__| | .----) | | |____ | |\ \----. \ / | |____ + |_______/__/ \__\ |__| \__| \______| |_______/ |_______|| _| `._____| \__/ |_______| + """ # noqa: E501 + + def green(text: str) -> str: + """Return the given text in green.""" + return "\x1b[1;32;40m" + text + "\x1b[0m" + + def orange(text: str) -> str: + """Return the given text in orange.""" + return "\x1b[1;31;40m" + text + "\x1b[0m" + + paths = _APP_TO_PATHS[app] + print(LANGSERVE) + for path in paths: + if endpoint_configuration.is_playground_enabled: + print( + f'{green("LANGSERVE:")} Playground for chain "{path or ""}/" ' + f'is live at:' + ) + print(f'{green("LANGSERVE:")} │') + print(f'{green("LANGSERVE:")} └──> {path}/playground/') + print(f'{green("LANGSERVE:")}') + print(f'{green("LANGSERVE:")} See all available routes at {app.docs_url}/') # PUBLIC API diff --git a/pyproject.toml b/pyproject.toml index 175c7b0c..10b55612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,3 +94,7 @@ addopts = "--strict-markers --strict-config --durations=5 -vv" # take more than 5 seconds timeout = 5 asyncio_mode = "auto" +filterwarnings = [ + "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", +] + diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index d2260cc4..6d78b31f 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -4,7 +4,6 @@ import json import sys import uuid -from asyncio import AbstractEventLoop from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass from enum import Enum @@ -123,7 +122,7 @@ def _replace_run_id_in_stream_resp(streamed_resp: str) -> str: return streamed_resp.replace(uuid, "") -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def event_loop(): """Create an instance of the default event loop for each test case.""" loop = asyncio.get_event_loop() @@ -134,7 +133,7 @@ def event_loop(): @pytest.fixture() -def app(event_loop: AbstractEventLoop) -> FastAPI: +def app() -> FastAPI: """A simple server that wraps a Runnable and exposes it as an API.""" async def add_one_or_passthrough( @@ -158,7 +157,7 @@ async def add_one_or_passthrough( @pytest.fixture() -def app_for_config(event_loop: AbstractEventLoop) -> FastAPI: +def app_for_config() -> FastAPI: """A simple server that wraps a Runnable and exposes it as an API.""" async def return_config( @@ -223,7 +222,7 @@ async def get_async_test_client( app=server, raise_app_exceptions=raise_app_exceptions, ) - async_client = AsyncClient(app=server, base_url=url, transport=transport) + async_client = AsyncClient(base_url=url, transport=transport) try: yield async_client finally: @@ -333,7 +332,7 @@ async def test_server_async(app: FastAPI) -> None: # test bad requests async with get_async_test_client(app, raise_app_exceptions=True) as async_client: # Test invoke - response = await async_client.post("/invoke", data="bad json []") + response = await async_client.post("/invoke", content="bad json []") # Client side error bad json. assert response.status_code == 422 @@ -353,7 +352,7 @@ async def test_server_async(app: FastAPI) -> None: async with get_async_test_client(app, raise_app_exceptions=True) as async_client: # Test invoke # Test bad batch requests - response = await async_client.post("/batch", data="bad json []") + response = await async_client.post("/batch", content="bad json []") # Client side error bad json. assert response.status_code == 422 @@ -378,7 +377,7 @@ async def test_server_async(app: FastAPI) -> None: # test stream bad requests async with get_async_test_client(app, raise_app_exceptions=True) as async_client: # Test bad stream requests - response = await async_client.post("/stream", data="bad json []") + response = await async_client.post("/stream", content="bad json []") assert response.status_code == 422 response = await async_client.post("/stream", json={}) @@ -386,7 +385,7 @@ async def test_server_async(app: FastAPI) -> None: # test stream_log bad requests async with get_async_test_client(app, raise_app_exceptions=True) as async_client: - response = await async_client.post("/stream_log", data="bad json []") + response = await async_client.post("/stream_log", content="bad json []") assert response.status_code == 422 response = await async_client.post("/stream_log", json={}) @@ -448,7 +447,7 @@ async def test_server_astream_events(app: FastAPI) -> None: # test stream_events with bad requests async with get_async_test_client(app, raise_app_exceptions=True) as async_client: - response = await async_client.post("/stream_events", data="bad json []") + response = await async_client.post("/stream_events", content="bad json []") assert response.status_code == 422 response = await async_client.post("/stream_events", json={}) @@ -854,7 +853,7 @@ async def with_errors(inputs: dict) -> AsyncIterator[int]: assert e.value.response.status_code == 500 -async def test_astream_log_allowlist(event_loop: AbstractEventLoop) -> None: +async def test_astream_log_allowlist() -> None: """Test async stream with an allowlist.""" async def add_one(x: int) -> int: @@ -1035,7 +1034,7 @@ async def test_invoke_as_part_of_sequence_async( } -async def test_multiple_runnables(event_loop: AbstractEventLoop) -> None: +async def test_multiple_runnables() -> None: """Test serving multiple runnables.""" async def add_one(x: int) -> int: @@ -1159,7 +1158,7 @@ async def add_one(x: int) -> int: await runnable.abatch(["hello"]) -async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None: +async def test_input_validation_with_lc_types() -> None: """Test client side and server side exceptions.""" app = FastAPI() @@ -1252,9 +1251,7 @@ async def test_async_client_close() -> None: assert async_client.is_closed is True -async def test_openapi_docs_with_identical_runnables( - event_loop: AbstractEventLoop, mocker: MockerFixture -) -> None: +async def test_openapi_docs_with_identical_runnables(mocker: MockerFixture) -> None: """Test client side and server side exceptions.""" async def add_one(x: int) -> int: @@ -1301,7 +1298,7 @@ async def add_one(x: int) -> int: assert response.status_code == 200 -async def test_configurable_runnables(event_loop: AbstractEventLoop) -> None: +async def test_configurable_runnables() -> None: """Add tests for using langchain's configurable runnables""" template = PromptTemplate.from_template("say {name}").configurable_fields( @@ -1391,7 +1388,7 @@ class Foo(BaseModel): assert Model.__name__ == "BarFoo" -async def test_input_config_output_schemas(event_loop: AbstractEventLoop) -> None: +async def test_input_config_output_schemas() -> None: """Test schemas returned for different configurations.""" # TODO(Fix me): need to fix handling of global state -- we get problems # gives inconsistent results when running multiple tests / results @@ -1753,7 +1750,7 @@ async def test_server_side_error() -> None: # assert e.response.text == "Internal Server Error" -def test_server_side_error_sync(event_loop: AbstractEventLoop) -> None: +def test_server_side_error_sync() -> None: """Test server side error handling.""" app = FastAPI() @@ -1982,7 +1979,7 @@ async def test_enforce_trailing_slash_in_client() -> None: assert r.url == "nosuchurl/" -async def test_per_request_config_modifier(event_loop: AbstractEventLoop) -> None: +async def test_per_request_config_modifier() -> None: """Test updating the config based on the raw request object.""" async def add_one(x: int) -> int: @@ -2025,9 +2022,7 @@ async def header_passthru_modifier( assert response.json()["output"] == 2 -async def test_per_request_config_modifier_endpoints( - event_loop: AbstractEventLoop, -) -> None: +async def test_per_request_config_modifier_endpoints() -> None: """Verify that per request modifier is only applied for the expected endpoints.""" # this test verifies that per request modifier is only @@ -2097,7 +2092,7 @@ async def buggy_modifier( assert response.status_code != 500 -async def test_uuid_serialization(event_loop: AbstractEventLoop) -> None: +async def test_uuid_serialization() -> None: """Test updating the config based on the raw request object.""" import datetime