From f0924048308cacc4fa8e7e114361fe80aa7f2614 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 3 Oct 2023 09:08:41 +0000 Subject: [PATCH] Handling bloom prefix. --- server/text_generation_server/models/bloom.py | 2 +- .../text_generation_server/utils/weights.py | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 0151b017025..8e8daad3547 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -74,7 +74,7 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer", ) if config.quantize == "gptq": weights._set_gptq_params(model_id) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 8a19fd9f722..4bae8cc07bf 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -16,6 +16,7 @@ def __init__( dtype, process_group, aliases: Optional[Dict[str, List[str]]] = None, + prefix: Optional[str] = None ): routing = {} for filename in filenames: @@ -33,6 +34,7 @@ def __init__( self.device = device self.dtype = dtype self.process_group = process_group + self.prefix = prefix self._handles = {} def _get_handle(self, filename): @@ -43,15 +45,22 @@ def _get_handle(self, filename): return self._handles[filename] def get_filename(self, tensor_name: str) -> (str, str): - filename = self.routing.get(tensor_name, None) - if filename is None: - aliases = self.aliases.get(tensor_name, []) + + names = [tensor_name] + if self.prefix is not None: + prefixed = f"{self.prefix}.{tensor_name}" + names.append(prefixed) + for name in names: + filename = self.routing.get(name, None) + if filename is not None: + return str(filename), name + + aliases = self.aliases.get(name, []) for alias in aliases: filename = self.routing.get(alias, None) if filename is not None: return str(filename), alias - raise RuntimeError(f"weight {tensor_name} does not exist") - return str(filename), tensor_name + raise RuntimeError(f"weight {tensor_name} does not exist") def _get_slice(self, tensor_name: str): filename, tensor_name = self.get_filename(tensor_name)