Skip to content

Commit

Permalink
Handling bloom prefix.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Oct 3, 2023
1 parent bd998d8 commit f092404
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion server/text_generation_server/models/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)
Expand Down
19 changes: 14 additions & 5 deletions server/text_generation_server/utils/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None
):
routing = {}
for filename in filenames:
Expand All @@ -33,6 +34,7 @@ def __init__(
self.device = device
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self._handles = {}

def _get_handle(self, filename):
Expand All @@ -43,15 +45,22 @@ def _get_handle(self, filename):
return self._handles[filename]

def get_filename(self, tensor_name: str) -> (str, str):
    """Resolve a tensor name to the file that stores it.

    Candidates are tried in this order:

    1. ``tensor_name`` exactly as given;
    2. each alias registered for ``tensor_name`` in ``self.aliases``;
    3. ``"<prefix>.<tensor_name>"`` when ``self.prefix`` is not ``None``;
    4. each alias registered for that prefixed name.

    Args:
        tensor_name: Name of the tensor requested by the caller.

    Returns:
        A ``(filename, resolved_name)`` pair, where ``resolved_name`` is the
        key that was actually found in ``self.routing`` (it may be the
        prefixed name or an alias rather than ``tensor_name`` itself).

    Raises:
        RuntimeError: If none of the candidates is present in
            ``self.routing``.
    """
    # The bare name goes first so checkpoints without the prefix keep
    # resolving exactly as they did before prefix support existed.
    candidates = [tensor_name]
    if self.prefix is not None:
        candidates.append(f"{self.prefix}.{tensor_name}")

    for name in candidates:
        filename = self.routing.get(name, None)
        if filename is not None:
            return str(filename), name

        # Fall back to any alternative names registered for this candidate.
        for alias in self.aliases.get(name, []):
            filename = self.routing.get(alias, None)
            if filename is not None:
                return str(filename), alias

    # Report the name the caller asked for, not an internal candidate.
    raise RuntimeError(f"weight {tensor_name} does not exist")

def _get_slice(self, tensor_name: str):
filename, tensor_name = self.get_filename(tensor_name)
Expand Down

0 comments on commit f092404

Please sign in to comment.