Skip to content

Commit

Permalink
Check for None (huggingface#2452)
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored Feb 15, 2024
1 parent 79016eb commit 97d2168
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/accelerate/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
found_batch_size = find_pippy_batch_size(args, kwargs)
if found_batch_size != num_chunks:
args = pad_input_tensors(args, found_batch_size, num_chunks)
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
if args is not None:
args = pad_input_tensors(args, found_batch_size, num_chunks)
if kwargs is not None:
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
stage = PipelineStage(pipe, state.local_process_index, device=state.device)

Expand Down

0 comments on commit 97d2168

Please sign in to comment.