Skip to content

Commit

Permalink
[python] add rolling batch as auto for neuron smart default (#2606)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis authored Nov 27, 2024
1 parent a75865c commit 5a6562a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from typing import List, Dict, Any
import subprocess as sp

from djl_python.properties_manager.properties import RollingBatchEnum

logger = logging.getLogger(__name__)

BILLION = 1_000_000_000.0
Expand Down Expand Up @@ -69,6 +71,10 @@ def apply_smart_defaults(self,
:param is_partition: Indicates whether we are saving pre-sharded checkpoints or not.
We set some smart defaults for it.
"""

if "rolling_batch" not in properties:
properties["rolling_batch"] = RollingBatchEnum.auto.value

if "n_positions" not in properties:
if self.get_model_parameters(
model_config) <= 0 or self.available_cores == 0:
Expand Down
8 changes: 7 additions & 1 deletion tests/integration/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,8 +1821,14 @@ def test_transformers_neuronx_handler(model, model_spec):
if "worker" in spec:
check_worker_number(spec["worker"])
for batch_size in spec["batch_size"]:
inputs = batch_generation(batch_size)
if batch_size == 1:
# for rolling batch, inputs should be a str not list.
# i.e, client side batching is not enabled when rolling batch is enabled.
# if batch_size is just 1, then we assume it is for rolling batch here.
inputs = inputs[0]
for seq_length in spec["seq_length"]:
req = {"inputs": batch_generation(batch_size)}
req = {"inputs": inputs}
params = {"max_length": seq_length}
if "use_sample" in spec:
params["use_sample"] = True
Expand Down

0 comments on commit 5a6562a

Please sign in to comment.