
Enabling enable_cuda_graph only supports two different image sizes during inference #141

Open
zhangp365 opened this issue Apr 10, 2024 · 2 comments


@zhangp365

Thank you for your excellent work on this project, @chengzeyi.

We have integrated the project into our cloud service, which has yielded significant speed improvements. However, we recently ran into a specific issue: with enable_cuda_graph enabled, the project works well for the first and second image sizes, but as soon as we switch to a third size, it consistently throws an exception.

Our environment setup includes:

torch: 2.2.0
cuda: 11.8
stable_fast: 1.0.4 wheel
xformers: 0.0.24
python: 3.10
Below is the exception log along with additional logging information:

2024-04-10 02:21:19,103 - graphs.py:38 - INFO - Dynamically graphing RecursiveScriptModule
2024-04-10 02:21:27,405 - graphs.py:130 - ERROR - CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sfast/cuda/graphs.py", line 127, in make_graphed_callable
    static_outputs = func(*static_inputs,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: CUDA driver error: operation not permitted when stream is capturing

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sfast/cuda/graphs.py", line 124, in make_graphed_callable
    with torch.cuda.graph(fwd_graph,
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/graphs.py", line 183, in __exit__
    self.cuda_graph.capture_end()
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/graphs.py", line 81, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

To investigate this issue, we added extra logging to the code, as shown below:

In python3.10/Lib/site-packages/sfast/cuda/graphs.py, at line 32, we added the following log:
    @functools.wraps(wrapped)
    def dynamic_graphed_callable(*args, **kwargs):
        if isinstance(func, torch.nn.Module):
            training = getattr(func, 'training', False)
        elif hasattr(func, '__self__') and isinstance(func.__self__,
                                                      torch.nn.Module):
            training = getattr(func.__self__, 'training', False)
        else:
            training = False
        key = (training, hash_arg(args), hash_arg(kwargs))
        logger.info(f"dynamic_graphed_callable key:{key}")
        cached_callable = cached_callables.get(key)
        if cached_callable is None:
            with lock:
                cached_callable = cached_callables.get(key)
                if cached_callable is None:
                    logger.info(
                        f'Dynamically graphing {getattr(func, "__name__", func.__class__.__name__)}'
                    )
                    cached_callable = simple_make_graphed_callable(
                        func, args, kwargs)
                    cached_callables[key] = cached_callable
        return cached_callable(*args, **kwargs)
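
For context, the cache key above includes a hash of the call arguments, so each new input shape appears to trigger a fresh CUDA Graph capture. Below is a minimal plain-PyTorch sketch of the capture/replay mechanism that simple_make_graphed_callable builds on (standard torch.cuda.CUDAGraph usage; the model and sizes here are illustrative stand-ins, not sfast's internals):

    import torch

    # Stand-in module; in the real case this would be the compiled UNet.
    model = torch.nn.Conv2d(3, 8, 3, padding=1).cuda().eval()
    static_input = torch.randn(1, 3, 512, 512, device='cuda')

    # Warm-up on a side stream is required before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture records kernels on these exact static tensors, which is why
    # a new image size cannot reuse an already-captured graph.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g), torch.no_grad():
        static_output = model(static_input)

    # Replay: copy new data of the SAME shape in, then replay the graph.
    static_input.copy_(torch.randn_like(static_input))
    g.replay()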
At line 130, we added logging around the graph capture:
    try:
        with execution_env.lock:
            with torch.cuda.device(execution_env.device), torch.cuda.stream(
                    execution_env.stream):
                with torch.cuda.graph(fwd_graph,
                                      pool=execution_env.mempool,
                                      stream=execution_env.stream):
                    static_outputs = func(*static_inputs,
                                          **static_kwarg_inputs)
    except Exception as e:
        logger.exception(e)  # added log
        logger.error('Failed to capture CUDA Graph, please try without it')
        raise

The failure occurs on the third "Dynamically graphing RecursiveScriptModule" log, which disrupts the service. We're actively looking for a solution and would appreciate any assistance.
Thank you very much.

@chengzeyi
Owner

@zhangp365 CUDA Graph capture has one implicit law: during capture, only one thread may use the GPU and execute computation on the device. This is by NVIDIA's design, so you need to modify your program to adopt a single-threaded pattern, or create a critical section, to respect that law.
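
A minimal sketch of such a critical section, assuming compiled_model stands for the sfast-compiled module (the name is illustrative):

    import threading

    # Serialize all calls that might trigger a CUDA Graph capture so no
    # other thread launches GPU work while a capture is in progress.
    _inference_lock = threading.Lock()

    def infer(*args, **kwargs):
        with _inference_lock:
            return compiled_model(*args, **kwargs)  # sfast-compiled module (assumed name)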

@zhangp365
Author

Thank you very much for your reply.

I have searched for a solution and found that CUDA itself supports updating the parameters (such as the image size) of a captured graph, but PyTorch provides no interface function to update them. Is there any other way to update the size before inference?
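
In the meantime, one possible workaround within the current design (a sketch only; `pipe` stands for the sfast-compiled pipeline and the resolution list is an example) is to pre-capture every size the service accepts during single-threaded startup, so no capture happens under concurrent traffic:

    # Hypothetical warm-up: run each expected resolution once, in a single
    # thread, before serving traffic, so every graph is captured up front.
    EXPECTED_SIZES = [(512, 512), (512, 768), (768, 768)]  # example values

    def warm_up(pipe):
        for h, w in EXPECTED_SIZES:
            pipe(prompt='warm-up', height=h, width=w, num_inference_steps=1)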
