diff --git a/pyproject.toml b/pyproject.toml index 59d0cd330..5abcfc481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,3 +124,68 @@ [build-system] build-backend="poetry.core.masonry.api" requires=["poetry-core"] + +[tool.pyright] + # All rules apart from base are shown explicitly below + deprecateTypingAliases=true + disableBytesTypePromotions=true + exclude = [ + "*/**/*.py", + "!/transformer_lens/hook_points.py" + ] + reportAssertAlwaysTrue=true + reportConstantRedefinition=true + reportDeprecated=true + reportDuplicateImport=true + reportFunctionMemberAccess=true + reportGeneralTypeIssues=true + reportIncompatibleMethodOverride=true + reportIncompatibleVariableOverride=true + reportIncompleteStub=true + reportInconsistentConstructor=true + reportInvalidStringEscapeSequence=true + reportInvalidStubStatement=true + reportInvalidTypeVarUse=true + reportMatchNotExhaustive=true + reportMissingParameterType=true + reportMissingTypeArgument=false + reportMissingTypeStubs=false + reportOptionalCall=true + reportOptionalContextManager=true + reportOptionalIterable=true + reportOptionalMemberAccess=true + reportOptionalOperand=true + reportOptionalSubscript=true + reportOverlappingOverload=true + reportPrivateImportUsage=true + reportPrivateUsage=true + reportSelfClsParameterName=true + reportTypeCommentUsage=true + reportTypedDictNotRequiredAccess=true + reportUnboundVariable=true + reportUnknownArgumentType=false + reportUnknownLambdaType=true + reportUnknownMemberType=false + reportUnknownParameterType=false + reportUnknownVariableType=false + reportUnnecessaryCast=true + reportUnnecessaryComparison=true + reportUnnecessaryContains=true + reportUnnecessaryIsInstance=true + reportUnsupportedDunderAll=true + reportUntypedBaseClass=true + reportUntypedClassDecorator=true + reportUntypedFunctionDecorator=true + reportUntypedNamedTuple=true + reportUnusedClass=true + reportUnusedCoroutine=true + reportUnusedExpression=true + reportUnusedFunction=true + reportUnusedImport=true + reportUnusedVariable=true + reportWildcardImportFromLibrary=true + strictDictionaryInference=true + strictListInference=true + strictParameterNoneValue=true + strictSetInference=true + diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index f7afca989..72d95f52e 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -7,8 +7,21 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast - +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Protocol, + Sequence, + Tuple, + Union, + runtime_checkable, +) + +import torch import torch.nn as nn import torch.utils.hooks as hooks @@ -33,6 +46,20 @@ class LensHandle: NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str]]] +@runtime_checkable +class _HookFunctionProtocol(Protocol): + """Protocol for hook functions.""" + + def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]: + ... + + +HookFunction = _HookFunctionProtocol # Callable[..., _HookFunctionProtocol] + +DeviceType = Optional[torch.device] +_grad_t = Union[Tuple[torch.Tensor, ...], torch.Tensor] + + class HookPoint(nn.Module): """ A helper class to access intermediate activations in a PyTorch model (inspired by Garcon). @@ -49,14 +76,14 @@ def __init__(self): # A variable giving the hook's name (from the perspective of the root # module) - this is set by the root module at setup. - self.name = None + self.name: Union[str, None] = None - def add_perma_hook(self, hook, dir="fwd") -> None: + def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None: self.add_hook(hook, dir=dir, is_permanent=True) def add_hook( self, - hook: Callable, + hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd", is_permanent: bool = False, level: Optional[int] = None, @@ -69,7 +96,11 @@ def add_hook( If prepend is True, add this hook before all other hooks """ - def full_hook(module, module_input, module_output): + def full_hook( + module: torch.nn.Module, + module_input: Any, + module_output: Any, + ): if ( dir == "bwd" ): # For a backwards hook, module_output is a tuple of (grad,) - I don't know why. @@ -85,7 +116,7 @@ def full_hook(module, module_input, module_output): _internal_hooks = self._forward_hooks visible_hooks = self.fwd_hooks elif dir == "bwd": - pt_handle = self.register_full_backward_hook(full_hook) + pt_handle = self.register_backward_hook(full_hook) _internal_hooks = self._backward_hooks visible_hooks = self.bwd_hooks else: @@ -101,7 +132,12 @@ def full_hook(module, module_input, module_output): else: visible_hooks.append(handle) - def remove_hooks(self, dir="fwd", including_permanent=False, level=None) -> None: + def remove_hooks( + self, + dir: Literal["fwd", "bwd", "both"] = "fwd", + including_permanent: bool = False, + level: Optional[int] = None, + ) -> None: def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]: output_handles = [] for handle in handles: @@ -124,14 +160,15 @@ def clear_context(self): del self.ctx self.ctx = {} - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return x def layer(self): # Returns the layer index if the name has the form 'blocks.{layer}.{...}' # Helper function that's mainly useful on HookedTransformer # If it doesn't have this form, raises an error - - assert self.name is not None # keep mypy happy + if self.name is None: + raise ValueError("Name cannot be None") split_name = self.name.split(".") return int(split_name[1]) @@ -157,7 +194,11 @@ class HookedRootModule(nn.Module): loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks) """ - def __init__(self, *args): + name: Optional[str] + mod_dict: Dict[str, nn.Module] + hook_dict: Dict[str, HookPoint] + + def __init__(self, *args: Any): super().__init__() self.is_caching = False self.context_level = 0 @@ -172,19 +213,25 @@ def setup(self): "HookPoint". """ self.mod_dict = {} - self.hook_dict: Dict[str, HookPoint] = {} + self.hook_dict = {} for name, module in self.named_modules(): if name == "": continue module.name = name self.mod_dict[name] = module - if "HookPoint" in str(type(module)): + # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):" + if isinstance(module, HookPoint): self.hook_dict[name] = module def hook_points(self): return self.hook_dict.values() - def remove_all_hook_fns(self, direction="both", including_permanent=False, level=None): + def remove_all_hook_fns( + self, + direction: Literal["fwd", "bwd", "both"] = "both", + including_permanent: bool = False, + level: Union[int, None] = None, + ): for hp in self.hook_points(): hp.remove_hooks(direction, including_permanent=including_permanent, level=level) @@ -194,10 +241,10 @@ def clear_contexts(self): def reset_hooks( self, - clear_contexts=True, - direction="both", - including_permanent=False, - level=None, + clear_contexts: bool = True, + direction: Literal["fwd", "bwd", "both"] = "both", + including_permanent: bool = False, + level: Union[int, None] = None, ): if clear_contexts: self.clear_contexts() @@ -206,15 +253,16 @@ def reset_hooks( def check_and_add_hook( self, - hook_point, - hook_point_name, - hook, - dir="fwd", - is_permanent=False, - level=None, - prepend=False, + hook_point: HookPoint, + hook_point_name: str, + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + level: Union[int, None] = None, + prepend: bool = False, ) -> None: """Runs checks on the hook, and then adds it to the hook point""" + self.check_hooks_to_add( hook_point, hook_point_name, @@ -227,22 +275,32 @@ def check_and_add_hook( def check_hooks_to_add( self, - hook_point, - hook_point_name, - hook, - dir="fwd", - is_permanent=False, - prepend=False, + hook_point: HookPoint, + hook_point_name: str, + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + prepend: bool = False, ) -> None: """Override this function to add checks on which hooks should be added""" pass def add_hook( - self, name, hook, dir="fwd", is_permanent=False, level=None, prepend=False + self, + name: Union[str, Callable[[str], bool]], + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + level: Union[int, None] = None, + prepend: bool = False, ) -> None: - if type(name) == str: + if isinstance(name, str): + hook_point = self.mod_dict[name] + assert isinstance( + hook_point, HookPoint + ) # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes. self.check_and_add_hook( - self.mod_dict[name], + hook_point, name, hook, dir=dir, @@ -264,7 +322,12 @@ def add_hook( prepend=prepend, ) - def add_perma_hook(self, name, hook, dir="fwd") -> None: + def add_perma_hook( + self, + name: Union[str, Callable[[str], bool]], + hook: HookFunction, + dir: Literal["fwd", "bwd"] = "fwd", + ) -> None: self.add_hook(name, hook, dir=dir, is_permanent=True) @contextmanager @@ -308,7 +371,7 @@ def hooks( self.mod_dict[name].add_hook(hook, dir="bwd", level=self.context_level) else: # Otherwise, name is a Boolean function on names - for hook_name, hp in self.hook_dict: # type: ignore + for hook_name, hp in self.hook_dict.items(): if name(hook_name): hp.add_hook(hook, dir="bwd", level=self.context_level) yield self @@ -321,12 +384,12 @@ def hooks( def run_with_hooks( self, - *model_args, + *model_args: Any, # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific? fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], - reset_hooks_end=True, - clear_contexts=False, - **model_kwargs, + reset_hooks_end: bool = True, + clear_contexts: bool = False, + **model_kwargs: Any, ): """ Runs the model with specified forward and backward hooks. @@ -361,7 +424,7 @@ def add_caching_hooks( self, names_filter: NamesFilter = None, incl_bwd: bool = False, - device=None, + device: DeviceType = None, # TODO: unsure about whether or not this device typing is correct or not? remove_batch_dim: bool = False, cache: Optional[dict] = None, ) -> dict: @@ -382,19 +445,19 @@ def add_caching_hooks( if names_filter is None: names_filter = lambda name: True - elif type(names_filter) == str: + elif isinstance(names_filter, str): filter_str = names_filter names_filter = lambda name: name == filter_str - elif type(names_filter) == list: + elif isinstance(names_filter, list): filter_list = names_filter names_filter = lambda name: name in filter_list - # mypy can't seem to infer this - names_filter = cast(Callable[[str], bool], names_filter) + assert callable(names_filter), "names_filter must be a callable" self.is_caching = True - def save_hook(tensor, hook, is_backward): + def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool): + assert hook.name is not None hook_name = hook.name if is_backward: hook_name += "_grad" @@ -412,15 +475,15 @@ def save_hook(tensor, hook, is_backward): def run_with_cache( self, - *model_args, + *model_args: Any, names_filter: NamesFilter = None, - device=None, - remove_batch_dim=False, - incl_bwd=False, - reset_hooks_end=True, - clear_contexts=False, - pos_slice=None, - **model_kwargs, + device: DeviceType = None, + remove_batch_dim: bool = False, + incl_bwd: bool = False, + reset_hooks_end: bool = True, + clear_contexts: bool = False, + pos_slice: Optional[Union[Slice, SliceInput]] = None, + **model_kwargs: Any, ): """ Runs the model and returns the model output and a Cache object. @@ -477,7 +540,7 @@ def get_caching_hooks( self, names_filter: NamesFilter = None, incl_bwd: bool = False, - device=None, + device: DeviceType = None, remove_batch_dim: bool = False, cache: Optional[dict] = None, pos_slice: Union[Slice, SliceInput] = None, @@ -509,12 +572,19 @@ def get_caching_hooks( elif isinstance(names_filter, list): filter_list = names_filter names_filter = lambda name: name in filter_list + elif callable(names_filter): + names_filter = names_filter + else: + raise ValueError("names_filter must be a string, list of strings, or function") + assert callable(names_filter) # Callable[[str], bool] + self.is_caching = True - # mypy can't seem to infer this - names_filter = cast(Callable[[str], bool], names_filter) + def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool = False): + # for attention heads the pos dimension is the third from last + if hook.name is None: + raise RuntimeError("Hook should have been provided a name") - def save_hook(tensor, hook, is_backward=False): hook_name = hook.name if is_backward: hook_name += "_grad" @@ -522,7 +592,6 @@ def save_hook(tensor, hook, is_backward=False): if remove_batch_dim: resid_stream = resid_stream[0] - # for attention heads the pos dimension is the third from last if ( hook.name.endswith("hook_q") or hook.name.endswith("hook_k") @@ -544,7 +613,7 @@ def save_hook(tensor, hook, is_backward=False): fwd_hooks = [] bwd_hooks = [] - for name, hp in self.hook_dict.items(): + for name, _ in self.hook_dict.items(): if names_filter(name): fwd_hooks.append((name, partial(save_hook, is_backward=False))) if incl_bwd: @@ -552,7 +621,13 @@ def save_hook(tensor, hook, is_backward=False): return cache, fwd_hooks, bwd_hooks - def cache_all(self, cache, incl_bwd=False, device=None, remove_batch_dim=False): + def cache_all( + self, + cache: Optional[dict], + incl_bwd: bool = False, + device: DeviceType = None, + remove_batch_dim: bool = False, + ): logging.warning( "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" ) @@ -566,11 +641,11 @@ def cache_all(self, cache, incl_bwd=False, device=None, remove_batch_dim=False): def cache_some( self, - cache, + cache: Optional[dict], names: Callable[[str], bool], - incl_bwd=False, - device=None, - remove_batch_dim=False, + incl_bwd: bool = False, + device: DeviceType = None, + remove_batch_dim: bool = False, ): """Cache a list of hook provided by names, Boolean function on names""" logging.warning(