Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
cli99 committed Nov 8, 2023
1 parent 51beeef commit a22ce6d
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,34 +707,30 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool:

# Activation Checkpointing
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
if not hasattr(self.config, 'activation_checkpointing_target'
) or self.config.activation_checkpointing_target is None:
act_ckpt_list = getattr(self.config, 'activation_checkpointing_target',
None) or ['MPTBlock']

if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list:
log.info(
f'activation checkpointing MPTBlock as activation_checkpointing_target is not set in model_config'
'Activation checkpointing MPTBlock only (ignoring other sub-block modules if specified in activation_checkpointing_target).'
)
return isinstance(module, MPTBlock)
else:
act_ckpt_list = self.config.activation_checkpointing_target
if 'MPTBlock' in act_ckpt_list:
act_ckpt_list = ['MPTBlock']
warnings.warn(
f'activation checkpointing MPTBlock, ignoring other sub-block modules if specified'

mod_types = ()
for mod_name in act_ckpt_list:
if mod_name.lower() == 'mptblock':
mod_types += (MPTBlock,)
elif mod_name in ATTN_CLASS_REGISTRY:
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
elif mod_name in FFN_CLASS_REGISTRY:
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
elif mod_name in NORM_CLASS_REGISTRY:
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
else:
raise ValueError(
f'{mod_name=} (specified in activation_checkpointing_target) is not a recognized option, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.'
)
mod_types = ()
for mod_name in act_ckpt_list:
if mod_name.lower() == 'mptblock':
mod_types += (MPTBlock,)
elif mod_name in ATTN_CLASS_REGISTRY:
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
elif mod_name in FFN_CLASS_REGISTRY:
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
elif mod_name in NORM_CLASS_REGISTRY:
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
else:
warnings.warn(
f'module name specified in activation_checkpointing_target ({mod_name}) not recognized, available options are names in ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY, or MPTBlock.'
)
return isinstance(module, mod_types)
return isinstance(module, mod_types)

def prepare_inputs_for_generation(
self,
Expand Down

0 comments on commit a22ce6d

Please sign in to comment.