Model device is not being set from train_model #123

Open
rusheb opened this issue Mar 24, 2023 · 3 comments
Labels
bug Something isn't working

Comments

@rusheb
Collaborator

rusheb commented Mar 24, 2023

Description

In train_model.py, the device arg passed to train is obtained via the get_device function. However, this device is never passed to the model. Instead, the model uses the default HookedTransformer device, which is "cuda" if available and "cpu" otherwise.

Steps to reproduce

Train a model on an M1 Mac, e.g.

poetry run python scripts/train_model.py ./data/maze/g4-n10

Check the logs. device will be reported as "mps" while model.device will be set to "cpu".
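The mismatch could also be caught at startup instead of by reading the logs. This is a minimal sketch, not code from the repo: `device_type` and `check_device_match` are hypothetical helpers, and it assumes both devices are available as strings such as "mps", "cpu" or "cuda:0".

```python
def device_type(device: str) -> str:
    """Return the backend part of a device string, e.g. "cuda:0" -> "cuda"."""
    return str(device).split(":", 1)[0]


def check_device_match(requested: str, actual: str) -> None:
    """Raise if the model ended up on a different backend than requested.

    Hypothetical helper: `requested` would be the value returned by
    get_device(), `actual` whatever the logs print as model.device.
    """
    if device_type(requested) != device_type(actual):
        raise RuntimeError(
            f"requested device {requested!r} but model is on {actual!r}"
        )
```

Called right after the model is created, this would fail fast for the "mps" vs "cpu" case described above, while still accepting "cuda" vs "cuda:0".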

Mitigation

One possible fix would be to add a device parameter to ConfigHolder.create_model() and set the device on the HookedTransformerConfig.
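Roughly, the plumbing would look like the sketch below. The two classes are stand-ins written for illustration only (the real ConfigHolder and HookedTransformerConfig have many more fields), and `create_model_config` is a hypothetical name; the only point is that the caller's device is forwarded rather than left at the default.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigStandIn:
    """Stand-in for HookedTransformerConfig; only the fields needed here."""

    n_layers: int
    d_model: int
    device: Optional[str] = None  # None stands for "use the library default"


@dataclass
class ConfigHolderStandIn:
    """Stand-in for ConfigHolder; field names and defaults are assumptions."""

    n_layers: int = 2
    d_model: int = 32

    def create_model_config(self, device: Optional[str] = None) -> ModelConfigStandIn:
        # Forward the caller's device instead of relying on the default.
        return ModelConfigStandIn(
            n_layers=self.n_layers,
            d_model=self.d_model,
            device=device,
        )
```

With this shape, train_model.py would pass the result of get_device() through, e.g. `ConfigHolderStandIn().create_model_config(device="mps")`.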

I attempted this and got the following error:

Traceback (most recent call last):
  File "/Users/rusheb/code/maze-transformer/scripts/train_model.py", line 75, in <module>
    fire.Fire(train_model)
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/Users/rusheb/code/maze-transformer/scripts/train_model.py", line 69, in train_model
    train(dataloader, cfg, logger, output_path, device)
  File "/Users/rusheb/code/maze-transformer/maze_transformer/training/training.py", line 88, in train
    loss = model(batch_on_device[:, :-1], return_type="loss")
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 302, in forward
    residual = block(
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 693, in forward
    self.attn(
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 408, in forward
    attn_scores = self.apply_causal_mask(
  File "/Users/rusheb/code/maze-transformer/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 471, in apply_causal_mask
    return torch.where(
RuntimeError: 0'th index 32 of x tensor does not match the other tensors

I'm not sure what causes this error. It might be that HookedTransformer does not support Mac (MPS) acceleration.

@rusheb added the bug label on Mar 24, 2023
@luciaquirke
Contributor

TransformerLens does not support Mac acceleration, possibly because of all the open MPS issues in PyTorch: pytorch/pytorch#77764

@rusheb
Collaborator Author

rusheb commented Mar 25, 2023

In that case, does it make sense to remove the mps branch from get_device()?
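For illustration, a version of the selection logic with the mps branch disabled by default might look like this. It is a sketch under assumptions, not the current get_device(): `choose_device` and its `allow_mps` flag are hypothetical, and the availability checks are passed in as booleans so the logic can be read on its own.

```python
def choose_device(
    cuda_available: bool,
    mps_available: bool,
    allow_mps: bool = False,
) -> str:
    """Pick a device string, skipping MPS unless explicitly allowed.

    In real code the two flags would presumably come from
    torch.cuda.is_available() and torch.backends.mps.is_available().
    """
    if cuda_available:
        return "cuda"
    if allow_mps and mps_available:
        return "mps"
    # Fall back to CPU, including on Macs where MPS is present but disabled.
    return "cpu"
```

Keeping the flag rather than deleting the branch outright would make it easy to turn MPS back on once the model library supports it.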

@luciaquirke
Contributor

luciaquirke commented Mar 26, 2023

Yeah, maybe. I'm talking to Joseph about adding MPS support to TransformerLens (e.g. https://github.com/neelnanda-io/TransformerLens/pull/221/files), but that requires changing the pinned PyTorch version in both CircuitsVis and TransformerLens. Maybe we should remove the MPS stuff in the meantime.
