diff --git a/Changelog.md b/Changelog.md index a088281..d98f053 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,6 +1,9 @@ ## 0.0.27 ++ Added support for streaming when making predictions with the Function CLI. ++ Added `PredictionResource.type` field for inspecting the type of a prediction resource. + Fixed pydantic forward reference errors when constructing `Signature` and `Predictor` instances. + Fixed `model_dump` error when making predictions in Google Colab due to outdated `pydantic` dependency. ++ Refactored `fxn.predictions.create` method to accept an `inputs` dictionary instead of relying on keyword arguments. ## 0.0.26 + Added support for serializing `BytesIO` instances in `fxn.predictions.to_value` method. diff --git a/README.md b/README.md index 0c8126e..ef3a8bb 100644 --- a/README.md +++ b/README.md @@ -20,12 +20,18 @@ Let's run the [`@samplefxn/stable-diffusion`](https://fxn.ai/@samplefxn/stable-d ### In Python Run the following Python script: ```py -import fxn +from fxn import Function -prediction = fxn.Prediction.create( +# Create the Function client +fxn = Function() +# Create a prediction +prediction = fxn.predictions.create( tag="@samplefxn/stable-diffusion", - prompt="An astronaut riding a horse on Mars" + inputs={ + "prompt": "An astronaut riding a horse on Mars" + } ) +# Show the generated image image = prediction.results[0] image.show() ``` diff --git a/fxn/cli/predict.py b/fxn/cli/predict.py index 14c18cb..c0f69cc 100644 --- a/fxn/cli/predict.py +++ b/fxn/cli/predict.py @@ -3,6 +3,7 @@ # Copyright © 2024 NatML Inc. All Rights Reserved. 
# +from asyncio import run as run_async from io import BytesIO from numpy import ndarray from pathlib import Path, PurePath @@ -11,6 +12,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn from tempfile import mkstemp from typer import Argument, Context, Option +from typing import Any, Dict from ..function import Function from .auth import get_access_key @@ -20,7 +22,10 @@ def predict ( raw_outputs: bool = Option(False, "--raw-outputs", help="Output raw Function values instead of converting into plain Python values."), context: Context = 0 ): - # Predict + inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) } + run_async(_predict_async(tag, inputs, raw_outputs=raw_outputs)) + +async def _predict_async (tag: str, inputs: Dict[str, Any], raw_outputs: bool): with Progress( SpinnerColumn(spinner_name="dots"), TextColumn("[progress.description]{task.description}"), @@ -28,24 +33,15 @@ def predict ( ) as progress: progress.add_task(description="Running Function...", total=None) fxn = Function(get_access_key()) - inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) } - prediction = fxn.predictions.create( - tag=tag, - **inputs, - raw_outputs=raw_outputs, - return_binary_path=True, - ) - # Parse results - images = [] - if hasattr(prediction, "results") and prediction.results is not None: - images = [value for value in prediction.results if isinstance(value, Image.Image)] - results = [_serialize_value(value) for value in prediction.results] - object.__setattr__(prediction, "results", results) - # Print - print_json(data=prediction.model_dump()) - # Show images - for image in images: - image.show() + async for prediction in fxn.predictions.stream(tag, inputs=inputs, raw_outputs=raw_outputs, return_binary_path=True): + # Parse results + images = [value for value in prediction.results or [] if isinstance(value, Image.Image)] + 
prediction.results = [_serialize_value(value) for value in prediction.results] if prediction.results is not None else None + # Print + print_json(data=prediction.model_dump()) + # Show images + for image in images: + image.show() def _parse_value (value: str): """ diff --git a/fxn/services/prediction.py b/fxn/services/prediction.py index 5c42760..42aed60 100644 --- a/fxn/services/prediction.py +++ b/fxn/services/prediction.py @@ -13,7 +13,7 @@ from PIL import Image from platform import system from pydantic import BaseModel -from requests import get +from requests import get, post from tempfile import NamedTemporaryFile from typing import Any, AsyncIterator, Dict, List, Union from uuid import uuid4 @@ -32,49 +32,60 @@ def __init__ (self, client: GraphClient, storage: StorageService) -> None: def create ( self, tag: str, + *, + inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]] = {}, raw_outputs: bool=False, return_binary_path: bool=True, data_url_limit: int=None, - **inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]], ) -> Prediction: """ Create a prediction. Parameters: tag (str): Predictor tag. - raw_outputs (bool): Skip converting output values into Pythonic types. + inputs (dict): Input values. This only applies to `CLOUD` predictions. + raw_outputs (bool): Skip converting output values into Pythonic types. This only applies to `CLOUD` predictions. return_binary_path (bool): Write binary values to file and return a `Path` instead of returning `BytesIO` instance. - data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. Only applies to `CLOUD` predictions. - inputs (dict): Input values. Only applies to `CLOUD` predictions. + data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. This only applies to `CLOUD` predictions. Returns: Prediction: Created prediction. 
""" # Serialize inputs key = uuid4().hex - inputs = [{ "name": name, **self.to_value(value, name, key=key).model_dump() } for name, value in inputs.items()] + inputs = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() } # Query - response = self.client.query(f""" - mutation ($input: CreatePredictionInput!) {{ - createPrediction (input: $input) {{ - {PREDICTION_FIELDS} - }} - }}""", - { "input": { "tag": tag, "client": self.__get_client_id(), "inputs": inputs, "dataUrlLimit": data_url_limit } } + response = post( + f"{self.client.api_url}/predict/{tag}?rawOutputs=true&dataUrlLimit={data_url_limit}", + json=inputs, + headers={ + "Authorization": f"Bearer {self.client.access_key}", + "fxn-client": self.__get_client_id() + } ) - # Parse - prediction = response["createPrediction"] - prediction = self.__parse_cloud_prediction(prediction, raw_outputs=raw_outputs, return_binary_path=return_binary_path) + # Check + prediction = response.json() + try: + response.raise_for_status() + except: + raise RuntimeError(prediction.get("error")) + # Parse prediction + prediction = Prediction(**prediction) + prediction.results = [Value(**value) for value in prediction.results] if prediction.results is not None else None + prediction.results = [self.to_object(value, return_binary_path=return_binary_path) for value in prediction.results] if prediction.results is not None and not raw_outputs else prediction.results + # Create edge outputs + # Return return prediction async def stream ( self, tag: str, + *, + inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]] = {}, raw_outputs: bool=False, return_binary_path: bool=True, data_url_limit: int=None, - **inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]], ) -> AsyncIterator[Prediction]: """ Create a streaming prediction. 
@@ -83,17 +94,17 @@ async def stream ( Parameters: tag (str): Predictor tag. - raw_outputs (bool): Skip converting output values into Pythonic types. + inputs (dict): Input values. This only applies to `CLOUD` predictions. + raw_outputs (bool): Skip converting output values into Pythonic types. This only applies to `CLOUD` predictions. return_binary_path (bool): Write binary values to file and return a `Path` instead of returning `BytesIO` instance. - data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. Only applies to `CLOUD` predictions. - inputs (dict): Input values. Only applies to `CLOUD` predictions. + data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. This only applies to `CLOUD` predictions. Returns: Prediction: Created prediction. """ # Serialize inputs key = uuid4().hex - inputs = { name: self.to_value(value, name, key=key).model_dump() for name, value in inputs.items() } + inputs = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() } # Request url = f"{self.client.api_url}/predict/{tag}?stream=true&rawOutputs=true&dataUrlLimit={data_url_limit}" headers = { @@ -104,12 +115,15 @@ async def stream ( async with ClientSession(headers=headers) as session: async with session.post(url, data=dumps(inputs)) as response: async for chunk in response.content.iter_any(): - payload = loads(chunk) + prediction = loads(chunk) # Check status if response.status >= 400: - raise RuntimeError(payload.get("error")) + raise RuntimeError(prediction.get("error")) + # Parse prediction + prediction = Prediction(**prediction) + prediction.results = [Value(**value) for value in prediction.results] if prediction.results is not None else None + prediction.results = [self.to_object(value, return_binary_path=return_binary_path) for value in prediction.results] if prediction.results is not None and not raw_outputs else prediction.results # 
Yield - prediction = self.__parse_cloud_prediction(payload, raw_outputs=raw_outputs, return_binary_path=return_binary_path) yield prediction def to_object ( @@ -239,26 +253,6 @@ def to_value ( return Value(data=data, type=dtype) # Unsupported raise RuntimeError(f"Cannot create Function value '{name}' for object {object} of type {type(object)}") - - def __parse_cloud_prediction ( - self, - prediction: Dict[str, Any], - raw_outputs: bool=False, - return_binary_path: bool=True - ) -> Prediction: - # Check null - if not prediction: - return None - # Check type - if prediction["type"] != PredictorType.Cloud: - return prediction - # Gather results - if "results" in prediction and prediction["results"] is not None: - prediction["results"] = [Value(**value) for value in prediction["results"]] - if not raw_outputs: - prediction["results"] = [self.to_object(value, return_binary_path=return_binary_path) for value in prediction["results"]] - # Return - return Prediction(**prediction) def __get_data_dtype (self, data: Union[Path, BytesIO]) -> Dtype: mime = guess_mime(str(data) if isinstance(data, Path) else data) @@ -312,11 +306,10 @@ def __get_client_id (self) -> str: id tag type -created -implementation configuration resources {{ id + type url }} results {{ @@ -327,4 +320,5 @@ def __get_client_id (self) -> str: latency error logs +created """ \ No newline at end of file diff --git a/fxn/types/prediction.py b/fxn/types/prediction.py index 1d500e2..17afc36 100644 --- a/fxn/types/prediction.py +++ b/fxn/types/prediction.py @@ -14,9 +14,11 @@ class PredictionResource (BaseModel): Members: id (str): Resource identifier. + type (str): Resource type. url (str): Resource URL. """ id: str + type: str url: str class Prediction (BaseModel): @@ -27,23 +29,21 @@ class Prediction (BaseModel): id (str): Prediction ID. tag (str): Predictor tag. type (PredictorType): Prediction type. - created (str): Date created. + configuration (str): Prediction configuration token. 
This is only populated for `EDGE` predictions. + resources (list): Prediction resources. This is only populated for `EDGE` predictions. results (list): Prediction results. latency (float): Prediction latency in milliseconds. error (str): Prediction error. This is `null` if the prediction completed successfully. logs (str): Prediction logs. - implementation (str): Predictor implementation. This is only populated for `EDGE` predictions. - resources (list): Prediction resources. This is only populated for `EDGE` predictions. - configuration (str): Prediction configuration token. This is only populated for `EDGE` predictions. + created (str): Date created. """ id: str tag: str type: PredictorType - created: str + configuration: Optional[str] = None + resources: Optional[List[PredictionResource]] = None results: Optional[List[Any]] = None latency: Optional[float] = None error: Optional[str] = None - logs: Optional[str] = None - implementation: Optional[str] = None - resources: Optional[List[PredictionResource]] = None - configuration: Optional[str] = None \ No newline at end of file + logs: Optional[str] = None + created: str \ No newline at end of file diff --git a/test/prediction_test.py b/test/prediction_test.py index c953580..9e6f88d 100644 --- a/test/prediction_test.py +++ b/test/prediction_test.py @@ -4,19 +4,29 @@ # from fxn import Function +from numpy import allclose, pi import pytest pytest_plugins = ("pytest_asyncio",) -def test_create_prediction (): +def test_create_cloud_prediction (): fxn = Function() prediction = fxn.predictions.create( tag="@yusuf-delete/streaming", - sentence="Hello world" + inputs={ "sentence": "Hello world" } ) assert(prediction.results[0] == "world") -def test_create_prediction_raise_informative_error (): +def test_create_edge_prediction (): + fxn = Function() + radius = 4 + prediction = fxn.predictions.create( + tag="@yusuf-delete/math", + inputs={ "radius": radius } + ) + assert(allclose(prediction.results[0], pi * (radius ** 2))) + 
+def test_create_invalid_prediction (): fxn = Function() with pytest.raises(RuntimeError): fxn.predictions.create(tag="@yusuf-delete/invalid-predictor") @@ -26,7 +36,7 @@ async def test_stream_prediction (): fxn = Function() stream = fxn.predictions.stream( tag="@yusuf-delete/streaming", - sentence="Hello world" + inputs={ "sentence": "Hello world" } ) async for prediction in stream: print(prediction) \ No newline at end of file