Skip to content

Commit

Permalink
Prepping for edgefxn
Browse files Browse the repository at this point in the history
  • Loading branch information
olokobayusuf committed Feb 13, 2024
1 parent 38d3c7b commit c666dcc
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 81 deletions.
3 changes: 3 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
## 0.0.27
+ Added support for streaming when making predictions with the Function CLI.
+ Added `PredictionResource.type` field for inspecting the type of a prediction resource.
+ Fixed pydantic forward reference errors when constructing `Signature` and `Predictor` instances.
+ Fixed `model_dump` error when making predictions in Google Colab due to outdated `pydantic` dependency.
+ Refactored `fxn.predictions.create` method to accept an `inputs` dictionary instead of relying on keyword arguments.

## 0.0.26
+ Added support for serializing `BytesIO` instances in `fxn.predictions.to_value` method.
Expand Down
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ Let's run the [`@samplefxn/stable-diffusion`](https://fxn.ai/@samplefxn/stable-d
### In Python
Run the following Python script:
```py
import fxn
from fxn import Function

prediction = fxn.Prediction.create(
# Create the Function client
fxn = Function()
# Create a prediction
prediction = fxn.predictions.create(
tag="@samplefxn/stable-diffusion",
prompt="An astronaut riding a horse on Mars"
inputs={
"prompt": "An astronaut riding a horse on Mars"
}
)
# Show the generated image
image = prediction.results[0]
image.show()
```
Expand Down
34 changes: 15 additions & 19 deletions fxn/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright © 2024 NatML Inc. All Rights Reserved.
#

from asyncio import run as run_async
from io import BytesIO
from numpy import ndarray
from pathlib import Path, PurePath
Expand All @@ -11,6 +12,7 @@
from rich.progress import Progress, SpinnerColumn, TextColumn
from tempfile import mkstemp
from typer import Argument, Context, Option
from typing import Any, Dict

from ..function import Function
from .auth import get_access_key
Expand All @@ -20,32 +22,26 @@ def predict (
raw_outputs: bool = Option(False, "--raw-outputs", help="Output raw Function values instead of converting into plain Python values."),
context: Context = 0
):
# Predict
inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) }
run_async(_predict_async(tag, inputs, raw_outputs=raw_outputs))

async def _predict_async (tag: str, inputs: Dict[str, Any], raw_outputs: bool):
with Progress(
SpinnerColumn(spinner_name="dots"),
TextColumn("[progress.description]{task.description}"),
transient=True
) as progress:
progress.add_task(description="Running Function...", total=None)
fxn = Function(get_access_key())
inputs = { context.args[i].replace("-", ""): _parse_value(context.args[i+1]) for i in range(0, len(context.args), 2) }
prediction = fxn.predictions.create(
tag=tag,
**inputs,
raw_outputs=raw_outputs,
return_binary_path=True,
)
# Parse results
images = []
if hasattr(prediction, "results") and prediction.results is not None:
images = [value for value in prediction.results if isinstance(value, Image.Image)]
results = [_serialize_value(value) for value in prediction.results]
object.__setattr__(prediction, "results", results)
# Print
print_json(data=prediction.model_dump())
# Show images
for image in images:
image.show()
async for prediction in fxn.predictions.stream(tag, inputs=inputs, raw_outputs=raw_outputs, return_binary_path=True):
# Parse results
images = [value for value in prediction.results or [] if isinstance(value, Image.Image)]
prediction.results = [_serialize_value(value) for value in prediction.results] if prediction.results is not None else None
# Print
print_json(data=prediction.model_dump())
# Show images
for image in images:
image.show()

def _parse_value (value: str):
"""
Expand Down
86 changes: 40 additions & 46 deletions fxn/services/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from PIL import Image
from platform import system
from pydantic import BaseModel
from requests import get
from requests import get, post
from tempfile import NamedTemporaryFile
from typing import Any, AsyncIterator, Dict, List, Union
from uuid import uuid4
Expand All @@ -32,49 +32,60 @@ def __init__ (self, client: GraphClient, storage: StorageService) -> None:
def create (
self,
tag: str,
*,
inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]] = {},
raw_outputs: bool=False,
return_binary_path: bool=True,
data_url_limit: int=None,
**inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]],
) -> Prediction:
"""
Create a prediction.
Parameters:
tag (str): Predictor tag.
raw_outputs (bool): Skip converting output values into Pythonic types.
inputs (dict): Input values. This only applies to `CLOUD` predictions.
raw_outputs (bool): Skip converting output values into Pythonic types. This only applies to `CLOUD` predictions.
return_binary_path (bool): Write binary values to file and return a `Path` instead of returning `BytesIO` instance.
data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. Only applies to `CLOUD` predictions.
inputs (dict): Input values. Only applies to `CLOUD` predictions.
data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. This only applies to `CLOUD` predictions.
Returns:
Prediction: Created prediction.
"""
# Serialize inputs
key = uuid4().hex
inputs = [{ "name": name, **self.to_value(value, name, key=key).model_dump() } for name, value in inputs.items()]
inputs = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() }
# Query
response = self.client.query(f"""
mutation ($input: CreatePredictionInput!) {{
createPrediction (input: $input) {{
{PREDICTION_FIELDS}
}}
}}""",
{ "input": { "tag": tag, "client": self.__get_client_id(), "inputs": inputs, "dataUrlLimit": data_url_limit } }
response = post(
f"{self.client.api_url}/predict/{tag}?rawOutputs=true&dataUrlLimit={data_url_limit}",
json=inputs,
headers={
"Authorization": f"Bearer {self.client.access_key}",
"fxn-client": self.__get_client_id()
}
)
# Parse
prediction = response["createPrediction"]
prediction = self.__parse_cloud_prediction(prediction, raw_outputs=raw_outputs, return_binary_path=return_binary_path)
# Check
prediction = response.json()
try:
response.raise_for_status()
except:
raise RuntimeError(prediction.get("error"))
# Parse prediction
prediction = Prediction(**prediction)
prediction.results = [Value(**value) for value in prediction.results] if prediction.results is not None else None
prediction.results = [self.to_object(value, return_binary_path=return_binary_path) for value in prediction.results] if prediction.results is not None and not raw_outputs else prediction.results
# Create edge outputs

# Return
return prediction

async def stream (
self,
tag: str,
*,
inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]] = {},
raw_outputs: bool=False,
return_binary_path: bool=True,
data_url_limit: int=None,
**inputs: Dict[str, Union[ndarray, str, float, int, bool, List, Dict[str, Any], Path, Image.Image, Value]],
) -> AsyncIterator[Prediction]:
"""
Create a streaming prediction.
Expand All @@ -83,17 +94,17 @@ async def stream (
Parameters:
tag (str): Predictor tag.
raw_outputs (bool): Skip converting output values into Pythonic types.
inputs (dict): Input values. This only applies to `CLOUD` predictions.
raw_outputs (bool): Skip converting output values into Pythonic types. This only applies to `CLOUD` predictions.
return_binary_path (bool): Write binary values to file and return a `Path` instead of returning `BytesIO` instance.
data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. Only applies to `CLOUD` predictions.
inputs (dict): Input values. Only applies to `CLOUD` predictions.
data_url_limit (int): Return a data URL if a given output value is smaller than this size in bytes. This only applies to `CLOUD` predictions.
Returns:
Prediction: Created prediction.
"""
# Serialize inputs
key = uuid4().hex
inputs = { name: self.to_value(value, name, key=key).model_dump() for name, value in inputs.items() }
inputs = { name: self.to_value(value, name, key=key).model_dump(mode="json") for name, value in inputs.items() }
# Request
url = f"{self.client.api_url}/predict/{tag}?stream=true&rawOutputs=true&dataUrlLimit={data_url_limit}"
headers = {
Expand All @@ -104,12 +115,15 @@ async def stream (
async with ClientSession(headers=headers) as session:
async with session.post(url, data=dumps(inputs)) as response:
async for chunk in response.content.iter_any():
payload = loads(chunk)
prediction = loads(chunk)
# Check status
if response.status >= 400:
raise RuntimeError(payload.get("error"))
raise RuntimeError(prediction.get("error"))
# Parse prediction
prediction = Prediction(**prediction)
prediction.results = [Value(**value) for value in prediction.results] if prediction.results is not None else None
prediction.results = [self.to_object(value, return_binary_path=return_binary_path) for value in prediction.results] if prediction.results is not None and not raw_outputs else prediction.results
# Yield
prediction = self.__parse_cloud_prediction(payload, raw_outputs=raw_outputs, return_binary_path=return_binary_path)
yield prediction

def to_object (
Expand Down Expand Up @@ -239,26 +253,6 @@ def to_value (
return Value(data=data, type=dtype)
# Unsupported
raise RuntimeError(f"Cannot create Function value '{name}' for object {object} of type {type(object)}")

def __parse_cloud_prediction (
self,
prediction: Dict[str, Any],
raw_outputs: bool=False,
return_binary_path: bool=True
) -> Prediction:
# Check null
if not prediction:
return None
# Check type
if prediction["type"] != PredictorType.Cloud:
return prediction
# Gather results
if "results" in prediction and prediction["results"] is not None:
prediction["results"] = [Value(**value) for value in prediction["results"]]
if not raw_outputs:
prediction["results"] = [self.to_object(value, return_binary_path=return_binary_path) for value in prediction["results"]]
# Return
return Prediction(**prediction)

def __get_data_dtype (self, data: Union[Path, BytesIO]) -> Dtype:
mime = guess_mime(str(data) if isinstance(data, Path) else data)
Expand Down Expand Up @@ -312,11 +306,10 @@ def __get_client_id (self) -> str:
id
tag
type
created
implementation
configuration
resources {{
id
type
url
}}
results {{
Expand All @@ -327,4 +320,5 @@ def __get_client_id (self) -> str:
latency
error
logs
created
"""
18 changes: 9 additions & 9 deletions fxn/types/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class PredictionResource (BaseModel):
Members:
id (str): Resource identifier.
type (str): Resource type.
url (str): Resource URL.
"""
id: str
type: str
url: str

class Prediction (BaseModel):
Expand All @@ -27,23 +29,21 @@ class Prediction (BaseModel):
id (str): Prediction ID.
tag (str): Predictor tag.
type (PredictorType): Prediction type.
created (str): Date created.
configuration (str): Prediction configuration token. This is only populated for `EDGE` predictions.
resources (list): Prediction resources. This is only populated for `EDGE` predictions.
results (list): Prediction results.
latency (float): Prediction latency in milliseconds.
error (str): Prediction error. This is `null` if the prediction completed successfully.
logs (str): Prediction logs.
implementation (str): Predictor implementation. This is only populated for `EDGE` predictions.
resources (list): Prediction resources. This is only populated for `EDGE` predictions.
configuration (str): Prediction configuration token. This is only populated for `EDGE` predictions.
created (str): Date created.
"""
id: str
tag: str
type: PredictorType
created: str
configuration: Optional[str] = None
resources: Optional[List[PredictionResource]] = None
results: Optional[List[Any]] = None
latency: Optional[float] = None
error: Optional[str] = None
logs: Optional[str] = None
implementation: Optional[str] = None
resources: Optional[List[PredictionResource]] = None
configuration: Optional[str] = None
logs: Optional[str] = None
created: str
18 changes: 14 additions & 4 deletions test/prediction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,29 @@
#

from fxn import Function
from numpy import allclose, pi
import pytest

pytest_plugins = ("pytest_asyncio",)

def test_create_prediction ():
def test_create_cloud_prediction ():
fxn = Function()
prediction = fxn.predictions.create(
tag="@yusuf-delete/streaming",
sentence="Hello world"
inputs={ "sentence": "Hello world" }
)
assert(prediction.results[0] == "world")

def test_create_prediction_raise_informative_error ():
def test_create_edge_prediction ():
fxn = Function()
radius = 4
prediction = fxn.predictions.create(
tag="@yusuf-delete/math",
inputs={ "radius": radius }
)
assert(allclose(prediction.results[0], pi * (radius ** 2)))

def test_create_invalid_prediction ():
fxn = Function()
with pytest.raises(RuntimeError):
fxn.predictions.create(tag="@yusuf-delete/invalid-predictor")
Expand All @@ -26,7 +36,7 @@ async def test_stream_prediction ():
fxn = Function()
stream = fxn.predictions.stream(
tag="@yusuf-delete/streaming",
sentence="Hello world"
inputs={ "sentence": "Hello world" }
)
async for prediction in stream:
print(prediction)

0 comments on commit c666dcc

Please sign in to comment.