Detect.py supports running against a Triton container (ultralytics#9228)
* update coco128-seg comments
* Enable detect.py to use Triton for inference. Triton Inference Server is open-source inference serving software that streamlines AI inferencing (https://github.com/triton-inference-server/server). The user can now provide a "--triton-url" argument to detect.py to use a local or remote Triton server for inference; e.g., http://localhost:8000 uses HTTP over port 8000 and grpc://localhost:8001 uses gRPC over port 8001. Note that it is not necessary to specify a weights file to use Triton. A Triton container can be created by first exporting the YOLOv5 model to a Triton-supported runtime; ONNX, TorchScript, and TensorRT are supported by both Triton and the export.py script. The exported model can then be containerized via the OctoML CLI (see https://github.com/octoml/octo-cli#getting-started for a guide).
* added triton client to requirements
* fixed support for TF SavedModels in Triton
* reverted change
* Test CoreML update
* Update ci-testing.yml
* Use pathlib
* Refactor DetectMultiBackend to directly accept a Triton URL as --weights http://...
* Deploy category
* Update detect.py
* Update common.py
* Update predict.py
* Update triton.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add printout and requirements check
* Cleanup
* triton fixes
* fixed triton model query over grpc
* Update check_requirements('tritonclient[all]')
* group imports
* Fix likely remote URL bug
* update comment
* Update is_url()
* Fix 2x download attempt on http://path/to/model.pt

Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: glennjocher <[email protected]>
Co-authored-by: Gaz Iqbal <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1320ce1 · commit d669a74 · 7 changed files with 126 additions and 22 deletions
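
With this change, a Triton endpoint can be passed wherever a weights file used to go, e.g. python detect.py --weights http://localhost:8000. A minimal sketch of the same path in code, assuming a local Triton server is already serving an exported YOLOv5 model (the endpoints below are illustrative, not defaults set by this commit):

# Sketch: DetectMultiBackend now routes http:// and grpc:// weights to Triton
from models.common import DetectMultiBackend

model = DetectMultiBackend(weights='http://localhost:8000')  # Triton over HTTP
# model = DetectMultiBackend(weights='grpc://localhost:8001')  # Triton over gRPC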
utils/triton.py (new file, +85 lines):
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""Utils to interact with the Triton Inference Server"""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """A wrapper over a model served by the Triton Inference Server. It can
    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
    as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Arguments:
        url: Fully qualified address of the Triton server, e.g. grpc://localhost:8000
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == 'grpc':
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name  # gRPC returns a protobuf message
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']  # HTTP returns a list of dicts
            self.metadata = self.client.get_model_metadata(self.model_name)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders
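
    # Illustrative note (not part of this commit): the model metadata parsed by
    # create_input_placeholders() looks roughly like the dict below; the model
    # name, input/output names, dtypes and shapes are hypothetical and depend
    # on the exported model being served:
    #   {'name': 'yolov5s', 'platform': 'onnxruntime_onnx',
    #    'inputs':  [{'name': 'images',  'datatype': 'FP32', 'shape': [1, 3, 640, 640]}],
    #    'outputs': [{'name': 'output0', 'datatype': 'FP32', 'shape': [1, 25200, 85]}]}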

    @property
    def runtime(self):
        """Returns the model runtime"""
        return self.metadata.get('backend', self.metadata.get('platform'))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """Invokes the model. Parameters can be provided via args or kwargs.
        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))
            result.append(tensor)
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError('No inputs provided.')
        if args_len and kwargs_len:
            raise RuntimeError('Cannot specify args and kwargs at the same time')

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
            for input, value in zip(placeholders, args):
                input.set_data_from_numpy(value.cpu().numpy())
        else:
            for input in placeholders:
                value = kwargs[input.name()]  # InferInput.name() returns the input's name string
                input.set_data_from_numpy(value.cpu().numpy())
        return placeholders
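
A minimal usage sketch of TritonRemoteModel, assuming a Triton server is already serving an exported YOLOv5 model; the endpoint URL, the 640x640 FP32 input shape, and the reported runtime below are illustrative assumptions, not guarantees of this commit:

import torch

from utils.triton import TritonRemoteModel

model = TritonRemoteModel('http://localhost:8000')  # or 'grpc://localhost:8001' for gRPC
print(model.runtime)  # backend/platform reported by the server, e.g. 'onnxruntime_onnx'

im = torch.zeros(1, 3, 640, 640)  # dummy FP32 BCHW batch matching the served model's input
y = model(im)  # positional args are matched to the model's inputs in order
print(y.shape if isinstance(y, torch.Tensor) else [t.shape for t in y])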