refine vllm inference and keep the API same as non-vllm
SeaOfOcean committed Aug 30, 2024
1 parent ba42dcd commit 97aad53
Showing 12 changed files with 81 additions and 202 deletions.
7 changes: 7 additions & 0 deletions chatlearn/data/storage.py
@@ -43,3 +43,10 @@ def get(self, key):
logger.warning("%s is not found in storage", key)
return None
return future.get(ref)

def delete(self, key):
if isinstance(key, str):
key = [key]
for k in key:
# TODO: do we need to release the remote obj?
self._storage.pop(k)
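
The new `delete` above accepts either a single key or a list of keys. A minimal, self-contained sketch of that normalization pattern (a toy dict stands in for `self._storage`; this is illustration only, not part of the commit):

```python
# Toy illustration of Storage.delete's "single key or list of keys" handling.
_storage = {"a": 1, "b": 2, "c": 3}

def delete(key):
    if isinstance(key, str):
        key = [key]
    for k in key:
        # dict.pop(k) raises KeyError if k is missing; pop(k, None) would ignore it.
        _storage.pop(k)

delete("a")           # single key
delete(["b", "c"])    # list of keys
print(_storage)       # -> {}
```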
66 changes: 54 additions & 12 deletions chatlearn/models/vllm_module.py
@@ -51,7 +51,6 @@

from chatlearn.utils.vllm_utils import initialize_vllm, Megatron2LlamaSyncMap, Megatron2QWenSyncMap

from chatlearn.utils import utils
from chatlearn.utils.vllm_utils import get_model, print_rank_0
from .torch_module import TorchModule
try:
@@ -315,7 +314,7 @@ def reinit_cache_engine(self):
elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1.value:
self.worker.initialize_cache(self.cache_config.num_gpu_blocks, self.cache_config.num_cpu_blocks)

def free_cache_engine(self):
def empty_cache(self):
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0.value:
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition
self.worker.cache_engine.cpu_cache = None
@@ -338,7 +337,7 @@ def clear_cache(self):
gc.collect()
self.timers("gc").stop()

self.empty_cache()
super().empty_cache()

def profile_cache_blocks(self):
"""Profiles the memory usage and initializes the KV cache."""
@@ -396,6 +395,11 @@ def convert_v1_inputs(self, prompts, prompt_token_ids):

return inputs

def _add_request(self, data, is_eval=False): # pylint: disable=arguments-differ
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")
return self._add_request_internal(data[prompt_key], data[input_ids_key], is_eval=is_eval)

def _add_request_internal(self, prompt_list, prompt_token_id_list, is_eval=False):
if self._need_to_reset_scheduler:
self._reset_scheduler()
@@ -583,6 +587,53 @@ def num_layers(self):
"""
return self.model_config.hf_config.num_hidden_layers

def broadcast_var_object_dict(self, obj_dict, src_rank):
if torch.distributed.get_rank() == src_rank:
dict_as_list = list(obj_dict.items())
list_length = len(dict_as_list)
length_tensor = torch.tensor(list_length, device='cuda')
torch.distributed.broadcast(length_tensor, src_rank)
torch.distributed.broadcast_object_list(dict_as_list, src=src_rank)
return obj_dict
else:
length_tensor = torch.tensor(0, device='cuda')
torch.distributed.broadcast(length_tensor, src_rank)
list_length = length_tensor.item()
dict_as_list = [None] * list_length
torch.distributed.broadcast_object_list(dict_as_list, src=src_rank)
return dict(dict_as_list)

def generate_vllm(self, query, is_eval):
num_gpu_blocks, num_cpu_blocks = self.profile_cache_blocks()
num_blocks = torch.tensor([num_gpu_blocks, num_cpu_blocks], device='cuda')
torch.distributed.all_reduce(num_blocks, op=torch.distributed.ReduceOp.MIN)
min_gpu_blocks = num_blocks[0].item()
min_cpu_blocks = num_blocks[1].item()
self.set_cache_config(min_gpu_blocks, min_cpu_blocks)
if self.is_last_rank():
self.build_scheduler()
self.reinit_cache_engine()
# add requests of current episode to vllm scheduler
if self.is_last_rank():
self._add_request(query, is_eval=is_eval)
step_outputs = True
while step_outputs:
schedule_query = None
if self.is_last_rank():
schedule_query = self.schedule()
schedule_query = self.broadcast_var_object_dict(schedule_query, torch.distributed.get_world_size()-1)
output = self.execute_step(schedule_query)
if self.is_last_rank():
step_outputs = bool(output)
signal_tensor = torch.tensor(step_outputs, device='cuda')
torch.distributed.broadcast(signal_tensor, torch.distributed.get_world_size()-1)
else:
signal_tensor = torch.tensor(True, device='cuda')
torch.distributed.broadcast(signal_tensor, torch.distributed.get_world_size()-1)
step_outputs = signal_tensor.item()
if self.is_last_rank():
return self.outputs

def schedule(self):
if self.start_time is None:
self.start_time = time.monotonic()
@@ -664,15 +715,6 @@ def execute_step(self, data):

return output

def decode(self):
if not self.timers("decode").started_:
self.timers("decode").start()
self.outputs = sorted(self.outputs, key=lambda x: int(x.request_id))
rets = self.decode_internal(self.outputs)
rets = utils.to_device('cpu', rets)
self.timers("decode").stop()
return rets

def offload_weights(self):
"""
offload weights
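
To make the control flow of `generate_vllm`/`broadcast_var_object_dict` above easier to follow, here is a distilled, self-contained sketch of the same coordination pattern: the last rank owns the scheduler, broadcasts each step's work as a variable-length object dict, and broadcasts a stop/continue signal. It uses the gloo backend and dummy work items so it runs on CPU under `torchrun`; the real module uses CUDA tensors, the vLLM scheduler, and `execute_step`.

```python
# Sketch only (not ChatLearn code). Run with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist

def broadcast_object_dict(obj_dict, src_rank):
    """Broadcast a variable-length dict from src_rank to every rank."""
    if dist.get_rank() == src_rank:
        items = list(obj_dict.items())
        length = torch.tensor(len(items))
        dist.broadcast(length, src_rank)
        dist.broadcast_object_list(items, src=src_rank)
        return obj_dict
    length = torch.tensor(0)
    dist.broadcast(length, src_rank)
    items = [None] * length.item()
    dist.broadcast_object_list(items, src=src_rank)
    return dict(items)

def main():
    dist.init_process_group(backend="gloo")
    driver = dist.get_world_size() - 1        # the last rank drives, as in the diff
    steps = list(range(3)) if dist.get_rank() == driver else None

    keep_going = True
    while keep_going:
        # 1. the driver decides what to run this step and shares it with everyone
        work = {"step": steps.pop(0)} if dist.get_rank() == driver else {}
        work = broadcast_object_dict(work, driver)

        # 2. every rank would run its shard of inference on `work` here

        # 3. the driver tells everyone whether another step is needed
        signal = torch.tensor(int(bool(steps))) if dist.get_rank() == driver else torch.tensor(1)
        dist.broadcast(signal, driver)
        keep_going = bool(signal.item())

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```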
44 changes: 1 addition & 43 deletions chatlearn/runtime/environment.py
@@ -17,15 +17,11 @@
import math
from itertools import cycle

from ray.util.queue import Queue

from chatlearn.data.ranking import batch_generation_ranking
from chatlearn.utils import future
from chatlearn.utils.logger import logger
from .executor import Executor
from .utils import vllm_post_process_generate_step_one_model
from .utils import encode_data, reinit_cache_engine, prepare_vllm
from .utils import execute_in_parallel, decode_data
from .utils import encode_data

# pylint: disable=not-callable
class Environment(Executor):
@@ -115,30 +111,6 @@ def num_iteration(self):
else:
return self.batch_per_episode

def execute_vllm(self, model_replica, query, out_queues, mb, is_eval, func_name):
self.execute_onload(model_replica)

# profile cache blocks
prepare_vllm(model_replica)

# reinit cache engine
reinit_cache_engine(model_replica)
# add requests of current episode to vllm scheduler
ret = model_replica.tailer._add_request.remote(query, is_eval=is_eval)
future.get(ret)
step_outputs = True
data_queue_internal = Queue()
while step_outputs:
query = model_replica.tailer.schedule.remote()
data_queue_internal.put(encode_data(mb, query))
output = self.generate_step_one_model_internal(self.first_model, data_queue_internal, mb, \
model_replica, func_name, False, is_eval=is_eval)
data = output[-1][0]
step_outputs = future.get(data)
vllm_post_process_generate_step_one_model(model_replica, out_queues, mb)
self.execute_offload(model_replica)


def execute(self, is_eval):
data_queues, out_queue = self.setup_queues()
data_producer_iter = cycle(iter(self.models[0].replicas))
@@ -149,20 +121,6 @@ def execute(self, is_eval):
encoded_data = encode_data(mb, query)
for data_queue in data_queues:
data_queue.put(encoded_data)

if self.first_model.use_vllm_backend:
data_queue = self.first_node.get_input_queues()
self.timers(f"{self.first_model.name}").start()
args_list = []
for model_replica in self.first_model.replicas:
if data_queue.qsize() == 0:
break
data = data_queue.get()
mb, query = decode_data(data)
func_name = self.first_node.func_name
args_list.append((model_replica, query, self.first_node.out_queues, mb, is_eval, func_name))
execute_in_parallel(self.execute_vllm, args_list)
self.timers(f"{self.first_model.name}").stop()
self.compute_loop(out_queue, self.num_iteration)
return out_queue

6 changes: 2 additions & 4 deletions chatlearn/runtime/executor.py
@@ -21,7 +21,7 @@
from chatlearn.utils import future
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
from .utils import encode_data, decode_data
from .utils import encode_data
from .utils import FlowParser


@@ -285,10 +285,8 @@ def compute_loop_one_model(self, model_node, num_batch, is_eval):
return results

def compute_loop(self, out_queue, num_batch):
for i, model_group in enumerate(self.model_flow.flow_topology):
for model_group in self.model_flow.flow_topology:
for model_node in model_group:
if model_node.model.use_vllm_backend and i == 0:
continue
self.compute_loop_one_model(model_node, num_batch, self.is_eval)

data = [None] * len(self.model_flow.return_model_nodes)
63 changes: 0 additions & 63 deletions chatlearn/runtime/utils.py
@@ -18,7 +18,6 @@
import concurrent.futures
import textwrap
import inspect
from chatlearn.utils import future


def encode_data(mb, data):
@@ -31,68 +30,6 @@ def decode_data(data):
return mb, data


def build_scheduler(model_replica):
# build for only last rank of each replica.
future.get(model_replica.tailer.build_scheduler.remote())


def free_cache_engine(model_replica):
rets = []
for actor in model_replica.all_actors:
rets.append(actor.free_cache_engine.remote())
rets = future.get(rets)


def prepare_vllm(model_replica):
"""Profiling cache blocks and build scheduler."""
profile_cache_blocks(model_replica)
# setup vllm scheduler
build_scheduler(model_replica)


def profile_cache_blocks(model_replica):
rets = []
for actor in model_replica.all_actors:
rets.append(actor.profile_cache_blocks.remote())
rets = future.get(rets)

num_gpu_blocks = min(a[0] for a in rets)
num_cpu_blocks = min(a[1] for a in rets)

rets = []
for actor in model_replica.all_actors:
rets.append(actor.set_cache_config.remote(num_gpu_blocks, num_cpu_blocks))
rets = future.get(rets)


def reinit_cache_engine(model_replica):
rets = []
for actor in model_replica.all_actors:
rets.append(actor.reinit_cache_engine.remote())
rets = future.get(rets)


def vllm_post_process_generate_step_one_model(replica, out_queue, mb):
"""
Args:
model: DistModel
out_queue: Queue
"""
output = replica.tailer.decode.remote()

free_cache_engine(replica)

# If tp > 1 or pp > 1 for current model, its `output` will be a list whose
# length is the number of Actors. In this case, all members in the list
# are the same, and we choose output[-1] to put into out_queue.
last_output = output[-1] if isinstance(output, list) else output
if isinstance(out_queue, list):
for oq in out_queue:
oq.put(encode_data(mb, last_output))
else:
out_queue.put(encode_data(mb, last_output))


def parse_assign_target(line):
targets = []
for target in line.targets:
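
The bodies of `encode_data`/`decode_data` fall outside this hunk; a plausible stand-in that matches the signatures shown above (an assumption for illustration, not the actual ChatLearn implementation) is simply a tagged round trip:

```python
def encode_data(mb, data):
    """Tag a payload with the micro-batch index it belongs to (assumed layout)."""
    return {"mb": mb, "data": data}

def decode_data(data):
    """Recover (micro-batch index, payload) from an encoded item."""
    return data["mb"], data["data"]

mb, query = decode_data(encode_data(3, {"prompt": ["hello"]}))
assert mb == 3 and query == {"prompt": ["hello"]}
```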
5 changes: 1 addition & 4 deletions chatlearn/schedule/model_manager.py
@@ -179,12 +179,9 @@ def set_func_decorator(self, model):

# public user function
# TODO: use decorator to annotate
# TODO: we may need to merge these vllm func call
vllm_funcs = ['build_scheduler', 'free_cache_engine', 'profile_cache_blocks',
'set_cache_config', 'reinit_cache_engine', 'decode', '_add_request', 'schedule']
for func_name in ["save_checkpoint", "model_setup", "onload_optimizer_states", "offload_optimizer_states",
'offload_weights', 'onload_weights', 'offload_main_weights', 'onload_main_weights',
'free_grad_buffers', 'build_grad_buffers', 'build_dataset', '_build_dataloader'] + model.call_funcs + vllm_funcs:
'free_grad_buffers', 'build_grad_buffers', 'build_dataset', '_build_dataloader', "generate_vllm"] + model.call_funcs:
decorate_class_func(model_cls, func_name, monitor_error, func_name)
set_decorated(model.name)

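
`decorate_class_func` and `monitor_error` are defined elsewhere in ChatLearn and not shown in this diff; a generic, hypothetical sketch of the pattern used by the registration loop above — wrapping a list of class methods with an error-monitoring decorator — might look like this:

```python
import functools

def monitor_error(func, func_name):
    """Hypothetical stand-in for ChatLearn's monitor_error: log and re-raise failures."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception:
            print(f"[monitor] {func_name} failed on {type(self).__name__}")
            raise
    return wrapper

def decorate_class_func(cls, func_name, decorator, *decorator_args):
    """Hypothetical stand-in: replace cls.<func_name> with a decorated version, if it exists."""
    func = getattr(cls, func_name, None)
    if callable(func):
        setattr(cls, func_name, decorator(func, *decorator_args))

class DummyModule:
    def generate_vllm(self, query, is_eval=False):
        return f"generated for {query}"

# Mirrors the registration loop above: names the class does not define are skipped.
for name in ["model_setup", "generate_vllm"]:
    decorate_class_func(DummyModule, name, monitor_error, name)

print(DummyModule().generate_vllm("hi"))  # -> generated for hi
```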
1 change: 0 additions & 1 deletion docs/en/advanced.rst
@@ -170,7 +170,6 @@ Here is an example demonstrating how to pass the `adaptive_parallel_strategy_on_

.. code-block:: python
# model = get_model(model_provider)
load_checkpoint(
model, None, None,
adaptive_parallel_strategy=self.args.adaptive_parallel_strategy_on_checkpoint
21 changes: 1 addition & 20 deletions docs/en/programming/vllm.md
@@ -7,12 +7,9 @@ For now, we enable vLLM to accelerate policy generation.
## Model Definition

Similar to inheriting `MegatronModule` to implement the [PolicyInference Model](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/models/old_policy_inference.py), the vLLM backend can be enabled by inheriting the `VLLMModule` class and implementing the following key methods:
- model_provider: model definition function.
- setup: call model_provider to define the model. Optionally, call `load_checkpoint` or perform other initialization.
- build_dataset: Preprocess train/eval dataset with vLLM tokenizer.
- eval_forward: distributed inference tasks in eval mode.
- forward_step: distributed inference tasks in training mode.
- _add_request: prepare inputs for vLLM scheduler.
- decode_internal: decode vLLM generation outputs as needed.

The code structure is as follows:
@@ -31,24 +28,14 @@ class VLLMPolicyInference(VLLMModule):
def build_dataset(self, train_prompts, is_eval=False):
pass

def model_provider(self):
"""Build the model."""
pass

def eval_forward(self, data, iteration=0):
pass

def _add_request(self, data):
pass

def forward_step(self, data, iteration=0):
pass

def decode_internal(self, batched_outputs):
pass
```

You can refer to [vllm_policy_inference.py](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/models/vllm_policy_inference.py), in which build_dataset/_add_request/forward_step/decode_internal are illustrated as follows:
You can refer to [vllm_policy_inference.py](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/models/vllm_policy_inference.py), in which build_dataset/forward_step/decode_internal are illustrated as follows:

- build_dataset: using the `tokenizer`, you only need to return prompt_ids and the prompt string. In `build_dataset`, [VLLMPromptPipeline](https://github.com/alibaba/ChatLearn/blob/main/examples/megatron/data/prompt_dataset.py#141) is used as follows:
```python
@@ -82,12 +69,6 @@ class VLLMPolicyInference(VLLMModule):
return prompts_dataset
```

- _add_request: add preprocessed request pairs (input_ids, prompt) to vLLM scheduler
```python
def _add_request(self, data, is_eval=False):
return self._add_request_internal(data["prompt"], data["input_ids"], is_eval=is_eval)
```

- forward_step: take the batch `data` scheduled by the vLLM scheduler as input, and call `execute_step` for distributed inference.

```python
1 change: 0 additions & 1 deletion docs/zh/advanced.rst
@@ -160,7 +160,6 @@ YAML Configuration

.. code-block:: python
# model = get_model(model_provider)
load_checkpoint(
model, None, None,
adaptive_parallel_strategy=self.args.adaptive_parallel_strategy_on_checkpoint