In train_model.py, the device argument passed to train is obtained via the get_device function. However, this device is never passed to the model. Instead, the model uses the default HookedTransformer device, which is "cuda" if available and "cpu" otherwise.
Steps to reproduce
Train a model on an M1 Mac, e.g.
poetry run python scripts/train_model.py ./data/maze/g4-n10
Check the logs. device will be reported as "mps" while model.device will be set to "cpu".
Mitigation
One possible fix would be to add a device parameter to ConfigHolder.create_model() and set the device on the HookedTransformerConfig.
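A minimal sketch of that idea, using only the standard library so it runs without torch or transformer_lens installed. Everything here is a hypothetical stand-in: `pick_device` is a guess at what get_device does (its real implementation is not shown in this issue), `ModelConfigSketch` stands in for HookedTransformerConfig, and `ConfigHolderSketch` for ConfigHolder. The fallback to "cpu" assumes a machine without CUDA.

```python
from dataclasses import dataclass
from typing import Optional


def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical stand-in for get_device: prefer CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


@dataclass
class ModelConfigSketch:
    """Stand-in for HookedTransformerConfig, reduced to the fields needed here."""

    n_layers: int
    d_model: int
    device: Optional[str] = None

    def __post_init__(self) -> None:
        # With no explicit device, fall back to a default. This sketch assumes
        # CUDA is unavailable, so the default is "cpu" (the case in the issue).
        if self.device is None:
            self.device = "cpu"


@dataclass
class ConfigHolderSketch:
    """Stand-in for ConfigHolder; create_model_config takes the proposed device parameter."""

    n_layers: int = 2
    d_model: int = 64

    def create_model_config(self, device: Optional[str] = None) -> ModelConfigSketch:
        # Thread the caller's device through instead of dropping it.
        return ModelConfigSketch(
            n_layers=self.n_layers, d_model=self.d_model, device=device
        )


holder = ConfigHolderSketch()
device = pick_device(cuda_available=False, mps_available=True)

default_cfg = holder.create_model_config()  # current behaviour: device is ignored
explicit_cfg = holder.create_model_config(device=device)  # proposed behaviour

print(device, default_cfg.device, explicit_cfg.device)
```

In the real code, the same device value would also be the one passed to train, so the batch and the model end up on the same device.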
I attempted this and got the following error:
Traceback (most recent call last):
File "/Users/rusheb/code/maze-transformer/scripts/train_model.py", line 75, in <module>
fire.Fire(train_model)
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/Users/rusheb/code/maze-transformer/scripts/train_model.py", line 69, in train_model
train(dataloader, cfg, logger, output_path, device)
File "/Users/rusheb/code/maze-transformer/maze_transformer/training/training.py", line 88, in train
loss = model(batch_on_device[:, :-1], return_type="loss")
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 302, in forward
residual = block(
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 693, in forward
self.attn(
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 408, in forward
attn_scores = self.apply_causal_mask(
File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 471, in apply_causal_mask
return torch.where(
RuntimeError: 0'th index 32 of x tensor does not match the other tensors
I'm not sure of the cause of this error. It might be that HookedTransformer does not support Mac (MPS) acceleration.
Yeah, maybe. I'm talking to Joseph about adding MPS support to TransformerLens, e.g. https://github.com/neelnanda-io/TransformerLens/pull/221/files, but that needs a change to the pinned PyTorch version in both CircuitsVis and TransformerLens. Maybe we should remove the MPS stuff in the meantime.