Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliotStein committed Oct 31, 2024
1 parent 19c9da1 commit 8bd864a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
43 changes: 24 additions & 19 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,26 +202,13 @@ def _template_substitution(
return TemplateWithArithmetic(template).substitute(substitutions)


class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel):
arch_name: str = Field(default="")
parameter_names: List[str] = Field(default_factory=list)
layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict)

def __init__(self, arch_name: str, parameter_names: List[str]):
super().__init__()

self.arch_name = arch_name
self.parameter_names = parameter_names
self.layered_parameter_names = self._hierarchy(self.parameter_names)
# We could further inspect layered_parameter_names to split out pre and post weights

def _hierarchy(self, names):
# Initialize a dictionary to hold the hierarchical structure
hierarchy = defaultdict(list)
def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]:
hierarchy = defaultdict(list)

# Regular expression to match layers (denoted by .{integer}.)
layer_pattern = re.compile(r"\.\d+\.")
# Regular expression to match layers (denoted by .{integer}. by default)
layer_pattern = re.compile(layer_prefix)

if names:
for name in names:
# Find the layer part of the string (e.g., 'model.layers.0.')
match = layer_pattern.search(name)
Expand All @@ -235,15 +222,32 @@ def _hierarchy(self, names):
else:
hierarchy[name].append("")

return hierarchy
return hierarchy


class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel):
arch_name: str = Field(default="")
parameter_names: List[str] = Field(default_factory=list)
layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict)

def __init__(self, arch_name: str, parameter_names: List[str]):
# Initialise the parent classes with no arguments first; all three declared
# fields (arch_name, parameter_names, layered_parameter_names) are filled in below.
super().__init__()

# Record the caller-supplied architecture name and flat list of parameter names.
self.arch_name = arch_name
self.parameter_names = parameter_names
# Group the stored parameter names using the module-level _hierarchy helper
# (called here with its default layer_prefix pattern).
self.layered_parameter_names = _hierarchy(self.parameter_names)

def name(self) -> str:
    """Return the architecture name stored on this instance (``arch_name``)."""
    return self.arch_name

def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
    """Return the weights that precede the layers: always an empty list.

    AutomaticArchitectureInfo places all parameters into layer_weights
    rather than pre/post weights, because many models do not have a clear
    distinction between the two. ``config`` is accepted but not used.
    """
    return []

def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
    """Return the weights that follow the layers: always an empty list.

    AutomaticArchitectureInfo places all parameters into layer_weights
    rather than pre/post weights, because many models do not have a clear
    distinction between the two. ``config`` is accepted but not used.
    """
    return []

def layer_weights(
Expand All @@ -259,6 +263,7 @@ def sliceable(self) -> bool:
return True

def num_layers(self, config: PretrainedConfig) -> int:
    """Return the number of entries in ``layered_parameter_names``.

    Note: because pre/post weights are not separated out, 'model.layer.i'
    may not correspond to the ith layer. ``config`` is accepted but not used.
    """
    layer_groups = self.layered_parameter_names
    return len(layer_groups)


Expand Down
8 changes: 5 additions & 3 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run_merge(
if not merge_config.models and not merge_config.slices:
raise RuntimeError("No output requested")

arch_info = load_model_architecture(merge_config, options)
arch_info = _load_arch_info(merge_config, options)

# initialize loader cache and set options
loader_cache = LoaderCache()
Expand Down Expand Up @@ -275,13 +275,15 @@ def _update_config_vocab(
)


def load_model_architecture(merge_config, options):
def _load_arch_info(merge_config, options):
model_arch_info = [
get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
for m in merge_config.referenced_models()
]

# Check if any of the models failed to load architecture info
if any(a is False for a in model_arch_info):
# Attempt to load the architecture automatically if it's not specified
# Attempt to load the architecture automatically
model_arch_info = [
AutomaticArchitectureInfo(
arch_name=source_model.model.path,
Expand Down

0 comments on commit 8bd864a

Please sign in to comment.