Error differentiating ResNet from torchvision #24

Open
lorenzoh opened this issue Apr 27, 2022 · 21 comments

@lorenzoh

While trying to get an image classification example working for FastAI.jl, I attempted to train a pretrained ResNet model from torchvision. The forward pass works fine, but differentiating throws an error.

I think this is actually a limitation of functorch, but I figured I'd report it here.

Minimal working example (the last line fails on both CPU and GPU):

using CUDA, PyCall, Zygote, Flux
using PyCallChainRules.Torch: TorchModuleWrapper

torchvision = pyimport("torchvision")

model = TorchModuleWrapper(torchvision.models.resnet18(pretrained=true).to("cuda:0"))
xs = randn(Float32, 128, 128, 3, 1) |> cu
ys = model(xs)
Zygote.gradient(() -> Flux.mse(model(xs), ys))
Stacktrace
julia> Zygote.gradient(() -> Flux.mae(model(xs), ys))
ERROR: PyError ($(Expr(:escape, :(ccall(#= /home/lorenz/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) 
RuntimeError('During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.')
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 243, in vjp
    try:
  File "/home/lorenz/.julia/packages/PyCall/7a7w0/src/pyeval.jl", line 3, in newfn
    const Py_eval_input = 258
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/functorch/_src/make_functional.py", line 259, in forward
    @staticmethod
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 283, in forward
    return self._forward_impl(x)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torchvision/models/resnet.py", line 267, in _forward_impl
    x = self.bn1(x)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lorenz/anaconda3/envs/pycall/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 148, in forward
    self.num_batches_tracked.add_(1)  # type: ignore[has-type]

Stacktrace:
[1] pyerr_check
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:62 [inlined]
[2] pyerr_check
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:66 [inlined]
[3] _handle_error(msg::String)
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/exception.jl:83
[4] macro expansion
@ ~/.julia/packages/PyCall/7a7w0/src/exception.jl:97 [inlined]
[5] #107
@ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:43 [inlined]
[6] disable_sigint
@ ./c.jl:458 [inlined]
[7] __pycall!
@ ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:42 [inlined]
[8] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, nargs::Int64, kw::Ptr{Nothing})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:29
[9] _pycall!(ret::PyObject, o::PyObject, args::Tuple{PyObject, NTuple{62, PyObject}, PyObject}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:11
[10] (::PyObject)(::PyObject, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
[11] (::PyObject)(::PyObject, ::Vararg{Any})
@ PyCall ~/.julia/packages/PyCall/7a7w0/src/pyfncall.jl:86
[12] rrule(wrap::TorchModuleWrapper, args::Array{Float32, 4}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PyCallChainRules.Torch ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:62
[13] rrule
@ ~/.julia/packages/PyCallChainRules/ebIKG/src/pytorch.jl:57 [inlined]
[14] rrule
@ ~/.julia/packages/ChainRulesCore/RbX5a/src/rules.jl:134 [inlined]
[15] chain_rrule
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:216 [inlined]
[16] macro expansion
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0 [inlined]
[17] _pullback(ctx::Zygote.Context, f::TorchModuleWrapper, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:9
[18] _pullback
@ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86 [inlined]
[19] _pullback(::Zygote.Context, ::var"#27#28")
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[20] _pullback(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:34
[21] pullback(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:40
[22] gradient(::Function)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:75
[23] top-level scope
@ ~/.julia/dev/_InteractiveSessions/22_03/03_25_pychain_fastai.jl:86

@rejuvyesh (Owner)

https://github.com/rejuvyesh/PyCallChainRules.jl/blob/main/test/test_pytorch_hub.jl might be of interest. functorch recommends replacing in-place batch norm with alternatives such as group norm, which works equally well.

@lorenzoh (Author)

I see! Thanks for sharing that, I'll try it out and report back once I have a working FastAI.jl example. Feel free to close the issue, though.

@lorenzoh (Author)

lorenzoh commented Apr 27, 2022

I used the linked code to load a pretrained ResNet, and the forward and backward passes work:

[screenshot: REPL output of the successful forward and backward passes]

I then started training it with the standard FastAI.jl image classification setup, which also works; however, after 50 or so steps I get a CUDA out-of-memory error thrown by PyTorch. Since training ran fine for 50 batches and reducing the batch size didn't help, I assume there is a GPU memory leak somewhere.

[screenshot: CUDA out-of-memory error raised by PyTorch]

Have you run into this, and do you have any advice on pinpointing or alleviating the problem? Thanks for your help!

@rejuvyesh (Owner)

Would it be possible to share the script you are running? It's definitely possible that DLPack's memory pool is not freeing the tensors appropriately.

@lorenzoh (Author)

lorenzoh commented Apr 27, 2022

Sure, here it is (adapted lines from test_pytorch_hub.jl included for completeness):

using PyCall, PyCallChainRules, Zygote, Flux
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using FastAI
using CUDA

py"""
import torch
def bn2group(module):
    num_groups = 16 # hyper_parameter of GroupNorm
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.GroupNorm(num_groups,
                                           module.num_features,
                                           module.eps,
                                           module.affine,
                                          )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, bn2group(child))
    del module
    return module_output
"""


function loadresnet(c::Int)
    model = torch.hub.load("pytorch/vision", "resnet18")
    model.fc = torch.nn.Linear(model.fc.in_features, c)  # change number of output classes
    model_gn = py"bn2group"(model)
    return TorchModuleWrapper(model_gn)
end

Flux.gpu(m::TorchModuleWrapper) = fmap(CUDA.cu, m)


# FastAI.jl part
data, blocks = loaddataset("imagenette2-320")
task = ImageClassificationSingle(blocks)
learner = tasklearner(
    task, data;
    callbacks=[ToGPU()],
    batchsize=4,
    model=gpu(loadresnet(length(blocks[2].classes))))  # model being loaded here

# Training
fitonecycle!(learner, 1)

@rejuvyesh (Owner)

rejuvyesh commented Apr 27, 2022

Epoch 1 TrainingPhase(): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:03:50
┌───────────────┬───────┬─────────┐
│         Phase │ Epoch │    Loss │
├───────────────┼───────┼─────────┤
│ TrainingPhase │   1.0 │ 2.45761 │
└───────────────┴───────┴─────────┘
Epoch 1 ValidationPhase(): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:19
┌─────────────────┬───────┬───────┐
│           Phase │ Epoch │  Loss │
├─────────────────┼───────┼───────┤
│ ValidationPhase │   1.0 │ 2.464 │
└─────────────────┴───────┴───────┘

Again, this seems to run correctly for me. I kept nvtop open on the side as well; memory usage never went above 70% and stayed quite constant.
I think with this and #18 the common factor seems to be older CUDA (and possibly older NVIDIA drivers) on your end?

Edit:

Also:

julia> torch.__version__
"1.11.0"
julia> functorch.__version__
"0.1.0"

in case that matters.

@rejuvyesh (Owner)

I also let this run for a few more epochs. While I do see a slight uptick in memory usage over multiple epochs, it never became drastic enough to kill training.

@lorenzoh (Author)

By the way, I found this page in the functorch docs, which gives some options for dealing with batch norm layers that may be a bit more convenient, e.g.

from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
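
Calling that helper from Julia via PyCall would look roughly like the sketch below (the same pattern shows up in the MWE further down in this thread):

# Sketch: swap out all batch norm layers before wrapping the model.
using PyCall
using PyCallChainRules.Torch: TorchModuleWrapper, torch

fexp = pyimport("functorch.experimental")

model_py = torch.hub.load("pytorch/vision", "resnet18")
fexp.replace_all_batch_norm_modules_(model_py)  # replaces BatchNorm modules in place
model = TorchModuleWrapper(model_py)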

@rejuvyesh (Owner)

Nice, I'll move to using this function then!

@lorenzoh (Author)

lorenzoh commented Apr 30, 2022

I updated my CUDA drivers to 11.6, but am still experiencing memory leaks as described above 😢

I also tried using a vanilla training loop to take FluxTraining.jl out of the equation, so the training loop is not an issue.

My updated CUDA version info:

CUDA.versioninfo()
CUDA toolkit 11.6, artifact installation
NVIDIA driver 510.47.3, for CUDA 11.6
CUDA driver 11.6

Libraries: 
- CUBLAS: 11.8.1
- CURAND: 10.2.9
- CUFFT: 10.7.0
- CUSOLVER: 11.3.2
- CUSPARSE: 11.7.1
- CUPTI: 16.0.0
- NVML: 11.0.0+510.47.3
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)

Toolchain:
- Julia: 1.8.0-beta3
- LLVM: 13.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86

1 device:
  0: NVIDIA GeForce GTX 1080 Ti (sm_61, 23.625 MiB / 11.000 GiB available)

@lorenzoh (Author)

lorenzoh commented May 4, 2022

I put together a smaller MWE that reproduces the GPU OOM error:

using CUDA, PyCall, PyCallChainRules
using PyCallChainRules.Torch: TorchModuleWrapper, torch
fexp = pyimport("functorch.experimental")

model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModuleWrapper(model_pygn)

function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    xs = cu(randn(Float32, 224, 224, 3, 16))
    usage = [memoryused()]
    try
        for _ in 1:1000
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end

oom()

This produces linearly growing memory utilization before the error:

[plot: GPU memory utilization growing linearly over iterations]

@rejuvyesh any idea where this leak may be coming from, or how to get started debugging it?
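
One thing I could try to narrow it down (a sketch only; torch.cuda.memory_allocated() reports the bytes currently held by PyTorch's caching allocator) is to log PyTorch's view of memory next to the raw device numbers each step: if PyTorch's figure stays flat while device usage keeps growing, the retained memory would be on the Julia/DLPack side.

# Sketch: compare PyTorch's allocator usage with the device-level usage per step.
using CUDA, PyCall
torch = pyimport("torch")

function memorysnapshot()
    info = CUDA.MemoryInfo()
    return (device = 1 - info.free_bytes / info.total_bytes,
            pytorch = torch.cuda.memory_allocated() / info.total_bytes)
end

snaps = [memorysnapshot()]
for _ in 1:100
    model(cu(randn(Float32, 224, 224, 3, 16)))  # `model` from the MWE above
    push!(snaps, memorysnapshot())
end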

@rejuvyesh (Owner)

rejuvyesh commented May 12, 2022

Just wanted to comment that I can reproduce this, but I haven't had the time to figure out the reason. We likely need a reproducer with just DLPack.jl, because this is only a forward pass with no gradients.
DLPack.jl keeps a memory pool of shared tensors at
https://github.com/pabloferz/DLPack.jl/blob/2e491ac7e839a7428d817b652c4d525faa52ceac/src/DLPack.jl#L173, and we will need to track the state of this variable to figure out what's happening.

@rejuvyesh (Owner)

rejuvyesh commented May 16, 2022

This is definitely a bug in the DLPack interaction; the following snippet, which reproduces the forward pass with just DLPack.jl and no gradients, also fails:

using CUDA, PyCall, DLPack, Functors

dlpack = pyimport("torch.utils.dlpack")
torch = pyimport("torch")
fexp = pyimport("functorch.experimental")

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject


struct TorchModel
    fn::PyObject
end

function (wrap::TorchModel)(args...; kwargs...)
    return wrap.fn(fmap(x -> DLPack.share(x, PyObject, pyfrom_dlpack), args)...; kwargs...)
end


model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModel(model_pygn)

function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    batchsize = 128
    usage = [memoryused()]
    try
        for _ in 1:1000
            xs = cu(randn(Float32, 224, 224, 3, batchsize))
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end

oom()

You might need to change the batch size for your GPU.

@terasakisatoshi (Contributor)

terasakisatoshi commented May 18, 2022

Hi @rejuvyesh

I ran your code on my machine. It seems the value of memoryused() increases gradually.

julia> oom()
 0.1589230896872148
 0.3943831411266905
 0.3996930224840288
 0.3996930224840288
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.4050029038413673
 0.41296772587737496
 0.41296772587737496
 0.41296772587737496
 
 0.8988218700738405
 0.8988218700738405
 0.8988218700738405
 0.8988218700738405
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9067866921098482
 0.9120965734671866
 0.9120965734671866

(EDIT)

Here is my hardware information.

julia> CUDA.versioninfo()
CUDA toolkit 11.7, artifact installation
NVIDIA driver 510.60.2, for CUDA 11.6
CUDA driver 11.6

Libraries: 
- CUBLAS: 11.10.1
- CURAND: 10.2.10
- CUFFT: 10.7.2
- CUSOLVER: 11.3.5
- CUSPARSE: 11.7.3
- CUPTI: 17.0.0
- NVML: 11.0.0+510.60.2
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)

Toolchain:
- Julia: 1.7.2
- LLVM: 12.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80

2 devices:
  0: NVIDIA GeForce RTX 3060 (sm_86, 11.762 GiB / 12.000 GiB available)
  1: NVIDIA GeForce RTX 3060 (sm_86, 11.752 GiB / 12.000 GiB available)

(tmp) pkg> st
      Status `~/tmp/Project.toml`
  [fbb218c0] BSON v0.3.5
  [052768ef] CUDA v3.10.0
  [53c2dc0f] DLPack v0.1.1
  [2e981812] DataLoaders v0.1.3
  [587475ba] Flux v0.13.1
  [d9f16b24] Functors v0.2.8
  [dbeba491] Metalhead v0.7.1
  [3bd65402] Optimisers v0.2.4
  [92933f4c] ProgressMeter v1.7.2
  [438e738f] PyCall v1.93.1
  [b12ccfe2] PyCallChainRules v0.3.2
  [e88e6eb3] Zygote v0.6.40
Python 3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import functorch
>>> torch.__version__
'1.11.0+cu113'
>>> functorch.__version__
'0.1.1'

@lorenzoh (Author)

Thanks for looking into this. I am currently away from my GPU workstation, but can hopefully investigate myself next week.

> This is definitely a bug in the DLPack interaction

Can you tell if this is an issue in DLPack.jl or just the way it's used? If the former, I guess we should open an issue there?

@rejuvyesh (Owner)

The issue is definitely with the DLPack.jl and PyTorch interaction, so maybe the issue should go there. I need to check whether this happens with JAX as well.

@rejuvyesh (Owner)

@terasakisatoshi can you increase the batchsize and see if it actually OOMs (earlier)?

@terasakisatoshi (Contributor)

@rejuvyesh

Setting batchsize = 384 makes the oom384() function return before reaching length(usage) == 1000, i.e. it hits the OOM partway through.

julia> function oom384()
           batchsize = 384
           usage = [memoryused()]
           try
               for _ in 1:1000
                   xs = cu(randn(Float32, 224, 224, 3, batchsize))
                   model(xs)
                   push!(usage, memoryused())
               end
           catch
           finally
               return usage
           end
       end
oom384 (generic function with 1 method)

julia> oom384()
101-element Vector{Float64}:
 0.1589230896872148
 0.8455571227080395
 0.864141707458724
 0.864141707458724
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.8827262922094085
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 0.9013108769600929
 
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9942338007135153
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538
 0.9995436820708538

julia>

@terasakisatoshi (Contributor)

terasakisatoshi commented May 22, 2022

I don't think this is the best solution, but the following function notoom, which just adds GC.gc(false) on each loop iteration, does not hit an OOM.

using ProgressMeter
function notoom()
    batchsize = 384
    usage = [memoryused()]
    @showprogress for _ in 1:1000
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        GC.gc(false) # <---
    end
    @assert length(usage) == 1+1000
end

julia> notoom()
Progress: 100%|█████████████████████████████████████████| Time: 0:11:57

@lorenzoh (Author)

Interesting 🤔. I wonder if it would be enough to do that only every few steps, and what the memory usage curve would look like then; see the sketch below.
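
Something like this, based on notoom above (untested):

# Untested sketch: same loop as notoom, but GC only every `every` iterations
# and return the usage curve so any sawtooth pattern is visible.
function periodicgc(; every = 10)
    batchsize = 384
    usage = [memoryused()]
    for i in 1:1000
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        i % every == 0 && GC.gc(false)
    end
    return usage
end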

@pabloferz

pabloferz commented Jul 16, 2022

This is not really an issue with DLPack.jl, PyCallChainRules.jl, or PyCall.jl. As mentioned in pabloferz/DLPack.jl#26, it's just that Julia has no way of knowing how often it should garbage collect PyObjects in general.

@terasakisatoshi's example above is the correct way of handling this.
