Pickling error when using cuda in collate_fn #20469

Open
Richienb opened this issue Dec 5, 2024 · 2 comments
Labels
3rd party · ver: 2.4.x · waiting on author

Comments

@Richienb

Richienb commented Dec 5, 2024

Bug description

  • When I use CUDA inside the DataLoader's collate_fn to pre-process generated data in bulk, with num_workers > 0,
  • I am required to use the ddp_spawn strategy in the Trainer.
  • I then get this error:
Traceback (most recent call last):
  File "/home/myuser/myproject/scripts/../train.py", line 1139, in <module>
    trainer.fit(training, data)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 136, in launch
    process_context = mp.start_processes(
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 268, in start_processes
    idx, process, tf_name = start_process(i)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 263, in start_process
    process.start()
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'TorchGraph.create_forward_hook.<locals>.after_forward_hook'
  • Removing the line wandb_logger.watch(training) fixes the problem
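A minimal, stdlib-only sketch of the underlying failure, offered as an analogue rather than WandB's actual code: pickle refers to functions by their qualified name, so a function defined inside another function (like `after_forward_hook` in the traceback) cannot be pickled, and neither can any object that holds a reference to it. `create_forward_hook` and `try_pickle` are hypothetical names chosen for illustration.

```python
import pickle
from types import SimpleNamespace


def create_forward_hook():
    # Hypothetical stand-in for a logger helper that returns a closure.
    # Its __qualname__ contains "<locals>", so pickle cannot look it up by name.
    def after_forward_hook(module, inputs, outputs):
        return outputs

    return after_forward_hook


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the exception pickle raised."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        return exc
    return None


# A plain "model" made only of data pickles fine.
model = SimpleNamespace(weight=[0.5, -0.25], forward_hooks=[])
print(try_pickle(model))  # None

# Once a watch-style call stores the closure on it, the whole object is unpicklable.
model.forward_hooks.append(create_forward_hook())
print(repr(try_pickle(model)))
```

On CPython 3.10 the second call should report an `AttributeError` of the same "Can't pickle local object" shape as the traceback; newer Python versions may word or type the error differently, which is why both exception types are caught.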

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
  • PyTorch Lightning version: 2.4.0
  • PyTorch version: 3.10.15
  • Python version: 3.10
  • OS: Linux
  • CUDA/cuDNN version: 12.4/11.5
  • GPU models and configuration: 1x RTX 3090
  • How Lightning was installed: `conda`

More info

No response

Richienb added the "bug" and "needs triage" labels on Dec 5, 2024
@lantiga
Collaborator

lantiga commented Dec 5, 2024

Using CUDA in workers indeed means you need to spawn: once CUDA has been initialized in a process, processes forked from it can't use CUDA. In turn, spawn requires data and callables to be picklable, and the hook created by WandB's TorchGraph (a local function, per the traceback) isn't.
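To narrow down which attribute breaks spawning in a case like this, one option is to try serializing each attribute with `multiprocessing.reduction.ForkingPickler`, the pickler the spawn path in the traceback uses. This is a sketch assuming the object exposes its state through `vars()`; `unpicklable_attributes` is a hypothetical helper, and the lambda stands in for a logger hook.

```python
import pickle
from multiprocessing.reduction import ForkingPickler
from types import SimpleNamespace


def unpicklable_attributes(obj):
    """Return the names of obj's attributes that ForkingPickler cannot serialize.

    The spawn start method serializes the process object (and everything it
    references) with ForkingPickler, so any attribute listed here would make it fail.
    """
    bad = []
    for name, value in vars(obj).items():
        try:
            ForkingPickler.dumps(value)
        except (AttributeError, TypeError, pickle.PicklingError):
            bad.append(name)
    return bad


# Hypothetical example: plain data pickles, a lambda standing in for a hook does not.
model = SimpleNamespace(weight=[0.5, -0.25], hook=lambda module, inputs, outputs: outputs)
print(unpicklable_attributes(model))  # ['hook']
```

For a `torch.nn.Module`, registered hooks are kept in `__dict__` entries such as `_forward_hooks`, so attributes like that are a reasonable first place to look.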

One question: where are you calling `watch` in your code? One thing you could try is moving the `watch` call to after the training processes have been spawned, e.g. in `setup`:

    def setup(self, stage):
        self.logger.watch(self, log="gradients", log_freq=10, log_graph=False)

I haven't tried this myself; if it doesn't work, it would be great if you could create a minimal repro to speed things up.

lantiga added the "3rd party" and "waiting on author" labels and removed the "bug" and "needs triage" labels on Dec 5, 2024
@Richienb
Author

Richienb commented Jan 2, 2025

If I remember correctly, this does work, but then I need to unwatch the model right afterwards and re-watch it before the next training run, and so on.
