Error differentiating ResNet from torchvision
#24
Comments
https://github.com/rejuvyesh/PyCallChainRules.jl/blob/main/test/test_pytorch_hub.jl might be of interest. |
I see! Thanks for sharing that, I'll try it out and get back here once I have a working FastAI.jl example. Feel free to close the issue, though |
I used the linked code to load a pretrained ResNet, and the forward and backward passes work. I then started training it with the standard FastAI.jl image classification setup, which also works; however, after 50 or so steps I get a CUDA out-of-memory error thrown by PyTorch. Since training ran fine for those 50 batches and reducing the batch size didn't help, I am assuming there is a GPU memory leak somewhere. Have you run into this, and do you have any advice on pinpointing or alleviating the problem? Thanks for your help! |
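(Aside, not from the original comment: one low-tech way to confirm a leak like this is to log the device usage every few training steps and watch whether it climbs monotonically. A minimal sketch assuming CUDA.jl is loaded; `logmemory` is a hypothetical helper name.)

using CUDA

# Sketch: print a one-line device-memory summary every `every` training steps.
function logmemory(step; every = 10)
    step % every == 0 || return nothing
    info = CUDA.MemoryInfo()
    usedgb = (info.total_bytes - info.free_bytes) / 2^30
    println("step $step: $(round(usedgb; digits = 2)) GiB in use")
    return nothing
end

Calling this from inside the training loop makes a leak show up as a steadily increasing number rather than a sudden OOM.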
Would it be possible to share the script you are running? It's definitely possible that DLPack's memory pool is not freeing the tensors appropriately. |
Sure, here it is (adapted lines from the standard FastAI.jl image classification example):

using PyCall, PyCallChainRules, Zygote, Flux
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using FastAI
using CUDA

py"""
import torch

def bn2group(module):
    num_groups = 16  # hyperparameter of GroupNorm
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.GroupNorm(num_groups,
                                           module.num_features,
                                           module.eps,
                                           module.affine,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, bn2group(child))
    del module
    return module_output
"""

function loadresnet(c::Int)
    model = torch.hub.load("pytorch/vision", "resnet18")
    model.fc = torch.nn.Linear(model.fc.in_features, c)  # change number of output classes
    model_gn = py"bn2group"(model)
    return TorchModuleWrapper(model_gn)
end

Flux.gpu(m::TorchModuleWrapper) = fmap(CUDA.cu, m)

# FastAI.jl part
data, blocks = loaddataset("imagenette2-320")
task = ImageClassificationSingle(blocks)
learner = tasklearner(
    task, data;
    callbacks=[ToGPU()],
    batchsize=4,
    model=gpu(loadresnet(length(blocks[2].classes))))  # model being loaded here

# Training
fitonecycle!(learner, 1) |
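(Aside, not from the thread: a quick sanity check one could run on the converted model, assuming the `torch` binding and `py"bn2group"` definition from the script above; `nbatchnorm` is a hypothetical helper name.)

# Sketch: count BatchNorm modules left after the bn2group conversion --
# this should be zero if the recursion caught every submodule.
nbatchnorm(m) = count(mod -> pyisinstance(mod, torch.nn.modules.batchnorm._BatchNorm),
                      collect(m.modules()))

For example, `nbatchnorm(py"bn2group"(torch.hub.load("pytorch/vision", "resnet18")))` should return 0, while the unconverted resnet18 gives a positive count.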
Again, seems to run correctly for me. Edit: also, in case that matters. |
I also let this run for a few more epochs. While I do see a slight uptick in memory usage with multiple epochs, it didn't become drastic enough to kill training. |
By the way, I found this page in the functorch docs which gives some options for dealing with batch norm layers that may be a bit more convenient, e.g.:

from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net) |
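(Aside, not from the thread: a sketch of how that call could slot into the `loadresnet` function above in place of the hand-written `bn2group`; `loadresnet_functorch` is a hypothetical name, and `pyimport("functorch.experimental")` matches the import used in the later MWE.)

fexp = pyimport("functorch.experimental")

function loadresnet_functorch(c::Int)
    model = torch.hub.load("pytorch/vision", "resnet18")
    model.fc = torch.nn.Linear(model.fc.in_features, c)  # change number of output classes
    fexp.replace_all_batch_norm_modules_(model)          # in-place swap of BatchNorm layers
    return TorchModuleWrapper(model)
end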
Nice, I'll move to using this function then! |
I updated my CUDA drivers to 11.6, but am still experiencing memory leaks as described above 😢. I also tried using a vanilla training loop to take FluxTraining.jl out of the equation, so the training loop is not the issue. My updated CUDA version info:
|
I put together a smaller MWE that reproduces the GPU OOM error:

using CUDA, PyCall, PyCallChainRules
using PyCallChainRules.Torch: TorchModuleWrapper, torch

fexp = pyimport("functorch.experimental")

model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModuleWrapper(model_pygn)

function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    xs = cu(randn(Float32, 224, 224, 3, 16))
    usage = [memoryused()]
    try
        for _ in 1:1000
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end

oom()

This produces linearly growing utilization values before the error. @rejuvyesh, any idea where this leak may be coming from or how to get started debugging this? |
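(Aside, not from the original comment: one way to start narrowing the leak down is to compare what PyTorch's allocator thinks it holds with the device-wide usage CUDA.jl reports, using the `torch` binding and `model` from the MWE above; `memorybreakdown` is a hypothetical helper.)

# Sketch: if the device-wide usage grows while torch.cuda.memory_allocated()
# stays flat, the extra memory is being held outside PyTorch's allocator
# (e.g. by shared DLPack buffers kept alive on the Julia side), and vice versa.
function memorybreakdown()
    info = CUDA.MemoryInfo()
    return (
        device_used_bytes  = info.total_bytes - info.free_bytes,  # everything on the GPU
        torch_alloc_bytes  = torch.cuda.memory_allocated(),       # live PyTorch tensors
        torch_cached_bytes = torch.cuda.memory_reserved(),        # PyTorch's caching pool
    )
end

Calling this before and after a batch of forward passes should show which side of the Julia/PyTorch boundary the growth is on.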
Just wanted to comment that I can reproduce this, but haven't been able to get the time to figure out the reason. Likely need to create a reproducer with just DLPack.jl because this is just forward pass, with no gradients. |
This is definitely a bug in the DLPack interaction. Does the following, which uses only DLPack.jl directly, also fail for you?

using CUDA, PyCall, DLPack, Functors

dlpack = pyimport("torch.utils.dlpack")
torch = pyimport("torch")
fexp = pyimport("functorch.experimental")

pyto_dlpack(x) = @pycall dlpack.to_dlpack(x)::PyObject
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

struct TorchModel
    fn::PyObject
end

function (wrap::TorchModel)(args...; kwargs...)
    return wrap.fn(fmap(x -> DLPack.share(x, PyObject, pyfrom_dlpack), args)...; kwargs...)
end

model_py = torch.hub.load("pytorch/vision", "resnet18")
model_pygn = fexp.replace_all_batch_norm_modules_(model_py).to(device="cuda")
model = TorchModel(model_pygn)

function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end

function oom()
    batchsize = 128
    usage = [memoryused()]
    try
        for _ in 1:1000
            xs = cu(randn(Float32, 224, 224, 3, batchsize))
            model(xs)
            push!(usage, memoryused())
        end
    catch
    finally
        return usage
    end
end

oom()

You might need to change the batchsize for your GPU. |
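(Aside, not from the thread: an even smaller sketch that drops the model entirely and only exercises `DLPack.share` in a loop; if memory still climbs here, the growth comes purely from the shared buffers not being collected. The helper name `shareonly` is hypothetical.)

using CUDA, PyCall, DLPack

dlpack = pyimport("torch.utils.dlpack")
pyfrom_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject

# Repeatedly hand fresh CuArrays to PyTorch via DLPack and record device usage.
function shareonly(; iters = 1000, batchsize = 128)
    usage = Float64[]
    for _ in 1:iters
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        DLPack.share(xs, PyObject, pyfrom_dlpack)  # returned PyObject is dropped immediately
        info = CUDA.MemoryInfo()
        push!(usage, 1 - info.free_bytes / info.total_bytes)
    end
    return usage
end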
Hi @rejuvyesh, I could run your code on my machine. It seems the value returned by memoryused() keeps growing:

julia> oom()
0.1589230896872148
0.3943831411266905
0.3996930224840288
0.3996930224840288
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.4050029038413673
0.41296772587737496
0.41296772587737496
0.41296772587737496
⋮
0.8988218700738405
0.8988218700738405
0.8988218700738405
0.8988218700738405
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9067866921098482
0.9120965734671866
0.9120965734671866

(EDIT) Here is my hardware information.

julia> CUDA.versioninfo()
CUDA toolkit 11.7, artifact installation
NVIDIA driver 510.60.2, for CUDA 11.6
CUDA driver 11.6
Libraries:
- CUBLAS: 11.10.1
- CURAND: 10.2.10
- CUFFT: 10.7.2
- CUSOLVER: 11.3.5
- CUSPARSE: 11.7.3
- CUPTI: 17.0.0
- NVML: 11.0.0+510.60.2
- CUDNN: 8.30.2 (for CUDA 11.5.0)
- CUTENSOR: 1.4.0 (for CUDA 11.5.0)
Toolchain:
- Julia: 1.7.2
- LLVM: 12.0.1
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device capability support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80
2 devices:
0: NVIDIA GeForce RTX 3060 (sm_86, 11.762 GiB / 12.000 GiB available)
1: NVIDIA GeForce RTX 3060 (sm_86, 11.752 GiB / 12.000 GiB available)
(tmp) pkg> st
Status `~/tmp/Project.toml`
[fbb218c0] BSON v0.3.5
[052768ef] CUDA v3.10.0
[53c2dc0f] DLPack v0.1.1
[2e981812] DataLoaders v0.1.3
[587475ba] Flux v0.13.1
[d9f16b24] Functors v0.2.8
[dbeba491] Metalhead v0.7.1
[3bd65402] Optimisers v0.2.4
[92933f4c] ProgressMeter v1.7.2
[438e738f] PyCall v1.93.1
[b12ccfe2] PyCallChainRules v0.3.2
[e88e6eb3] Zygote v0.6.40

Python 3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import functorch
>>> torch.__version__
'1.11.0+cu113'
>>> functorch.__version__
'0.1.1' |
Thanks for looking into this. I am currently away from my GPU workstation, but can hopefully investigate myself next week.
Can you tell if this is an issue in DLPack.jl or just the way it's used? If the former, I guess we should open an issue there? |
The issue is definitely just with the DLPack.jl and PyTorch interaction, so maybe the issue should go there. I need to check whether this happens with JAX as well. |
@terasakisatoshi can you increase the batchsize and see if it actually OOMs (earlier)? |
Setting batchsize = 384 returns:

julia> function oom384()
           batchsize = 384
           usage = [memoryused()]
           try
               for _ in 1:1000
                   xs = cu(randn(Float32, 224, 224, 3, batchsize))
                   model(xs)
                   push!(usage, memoryused())
               end
           catch
           finally
               return usage
           end
       end
oom384 (generic function with 1 method)

julia> oom384()
101-element Vector{Float64}:
0.1589230896872148
0.8455571227080395
0.864141707458724
0.864141707458724
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.8827262922094085
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
0.9013108769600929
⋮
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9942338007135153
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
0.9995436820708538
julia> |
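(Aside, not from the thread: a tiny helper, hypothetical name `leakrate`, to turn a usage vector like the ones above into an average per-iteration growth, which makes different batch sizes easy to compare.)

# Average fraction of device memory gained per recorded iteration.
leakrate(usage) = (usage[end] - usage[1]) / (length(usage) - 1)

For the batchsize = 384 run above this comes out to roughly (0.9995 - 0.1589) / 100 ≈ 0.0084, i.e. close to one percent of the device memory per forward pass.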
I don't think this is the best solution, but the following function, which triggers a garbage collection on every iteration, does not run out of memory:

using ProgressMeter

function notoom()
    batchsize = 384
    usage = [memoryused()]
    @showprogress for _ in 1:1000
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        GC.gc(false) # <---
    end
    @assert length(usage) == 1+1000
end
julia> notoom()
Progress: 100%|█████████████████████████████████████████| Time: 0:11:57 |
Interesting 🤔. I wonder if it would be enough to do that only every few steps, and what the memory usage curve would look like. |
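(Aside, not from the thread: a minimal sketch of that idea, reusing `memoryused` and `model` from the MWE above; the helper name `usagewithperiodicgc` and the `gcevery` keyword are hypothetical.)

# Sketch: only run an incremental collection every `gcevery` iterations and
# return the recorded usage curve, to see how GC frequency affects the growth.
function usagewithperiodicgc(; iters = 200, batchsize = 128, gcevery = 10)
    usage = Float64[memoryused()]
    for i in 1:iters
        xs = cu(randn(Float32, 224, 224, 3, batchsize))
        model(xs)
        push!(usage, memoryused())
        i % gcevery == 0 && GC.gc(false)  # incremental GC every few steps
    end
    return usage
end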
This is not really an issue with DLPack.jl, PyCallChainRules.jl, or PyCall.jl. As mentioned in pabloferz/DLPack.jl#26, it's just that Julia has no way of knowing how often it should garbage collect; @terasakisatoshi's example above is the correct way of handling this. |
In trying to get an image classification example working for FastAI.jl, I tried training a pretrained ResNet model from torchvision. The forward pass works fine, but when differentiating, I get an error. I think this is actually a limitation of functorch, but figured I'd report here.

Minimum working example (last line fails on cpu and gpu):
Stacktrace