Skip to content

Commit

Permalink
Update batching for YOLOX
Browse files Browse the repository at this point in the history
  • Loading branch information
milank94 committed Sep 12, 2024
1 parent 55e2f79 commit 6309fd1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
14 changes: 12 additions & 2 deletions model_demos/cv_demos/yolo_x/pytorch_yolox.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0

# yolox demo script

import subprocess
Expand All @@ -22,7 +25,7 @@
from yolox.utils import demo_postprocess, multiclass_nms


def run_yolox_pytorch(variant):
def run_yolox_pytorch(variant, batch_size=1):

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
Expand Down Expand Up @@ -234,11 +237,18 @@ def run_yolox_pytorch(variant):
img, ratio = preprocess(img, input_shape)
img_tensor = torch.from_numpy(img)
img_tensor = img_tensor.unsqueeze(0)
batch_input = torch.cat([img_tensor] * batch_size, dim=0)

# Run inference on Tenstorrent device
output_q = pybuda.run_inference(tt_model, inputs=[(img_tensor)])
output_q = pybuda.run_inference(tt_model, inputs=[(batch_input)])
output = output_q.get()

# Combine outputs for data parallel runs
if os.environ.get("PYBUDA_N300_DATA_PARALLEL", "0") == "1":
concat_tensor = torch.cat((output[0].to_pytorch(), output[1].to_pytorch()), dim=0)
buda_tensor = pybuda.Tensor.create_from_torch(concat_tensor)
output = [buda_tensor]

# Post-processing
for i in range(len(output)):
output[i] = output[i].value().detach().float().numpy()
Expand Down
9 changes: 6 additions & 3 deletions model_demos/tests/test_pytorch_yolox.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0

import pytest

from cv_demos.yolo_x.pytorch_yolox import run_yolox_pytorch

variants = ["yolox_nano", "yolox_tiny", "yolox_s", "yolox_m", "yolox_l", "yolox_darknet", "yolox_x"]


@pytest.mark.parametrize("variant", variants)
@pytest.mark.parametrize("variant", variants, ids=variants)
@pytest.mark.yolox
def test_yolox_pytorch(variant, clear_pybuda, test_device):
run_yolox_pytorch(variant)
def test_yolox_pytorch(clear_pybuda, test_device, variant, batch_size):
run_yolox_pytorch(variant, batch_size=batch_size)

0 comments on commit 6309fd1

Please sign in to comment.