
GPU out of memory issues #242

Closed
chschroeder opened this issue Dec 22, 2022 · 11 comments
Labels
needs verification Perhaps a bug; needs further attention

Comments

@chschroeder

Hi,

When repeatedly using SetFit's train() / predict() inside a loop (for active learning), GPU memory usage steadily grows (even though all results have been correctly transferred to the CPU). Eventually, OutOfMemoryError: CUDA out of memory. is raised.

This is caused by sentence-transformers and I have reported it here in detail:
UKPLab/sentence-transformers/issues/1793

Just wanted to mention it here as well for the purposes of documentation.

@tomaarsen
Member

Hello @chschroeder,

Thank you for the heads up!
You mention that repeated use of trainer.train() causes endless GPU memory growth. I'm trying to reproduce this in SetFit with a minimal example, but I am unable to get these results. For example, in the following script:

Example script that can't reproduce endless memory growth
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
import torch

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=32,
    num_iterations=5,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    column_mapping={"sentence": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
)

for i in range(10):
    trainer.train()
    model.predict(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
    print(torch.cuda.memory_allocated(0))

This constantly prints 892729856 (i.e. 0.9GB) as the allocated memory.

Similarly, if I run the hyperparameter tuning from the README, except on the sst2 dataset, with a fixed batch size of 8, just 1 epoch, and the sentence-transformers/paraphrase-mpnet-base-v2 model as the body, then I report values of around 0.9GB allocated throughout the entire hyperparameter tuning process of 20 trials.

Would it be possible to provide a script using either SetFit or a SentenceTransformer where the endless growth is reported? That way it can eventually be used to verify that a potential fix works or modified into a test over on the sentence-transformer repository. Feel free to post the response on the aforementioned issue on the sentence-transformers repository, too.

  • Tom Aarsen

@chschroeder
Author

chschroeder commented Dec 23, 2022

Would it be possible to provide a script using either SetFit or a SentenceTransformer where the endless growth is reported?

Sure! I am on it. My first attempt, where I used a SetFit example built from scratch, did not show the same behaviour, which is why I have to take a step back and look more closely at my own code. I suspect it might be related to garbage collection: my own setup brings its own abstractions, so the references might be cleaned up later. "Normal" transformer models work flawlessly with this setup as a cross-check, which is why I still think there is some problem elsewhere. I will investigate further and provide a script or notebook.

@chschroeder
Author

Update:
I could not completely solve this yet; it is really strange.

I continued investigating my own example (active learning instead of the setfit-only example):

  • When I add code to monitor torch.cuda memory stats at multiple places in my loop the OOM error disappears.
    print('Before query:', torch.cuda.memory_stats()['inactive_split_bytes.all.current'] / 1024**2)
    
  • Once I remove this the OOM error is back.

My assessment so far:

  • It is likely not a bug, but we could help to free the GPU memory in a more reliable way. (Although it is difficult to reproduce, I can make the error disappear by explicitly deleting unused tensors using "del" statements in the sentence-transformers code; a sketch of this idiom follows this list.)
  • These improvements would likely be on the sentence-transformers side, so SetFit might only be indirectly affected.
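
A minimal sketch of this cleanup idiom (my illustration, not code from either library), applied around repeated SentenceTransformer.encode() calls, since the suspected leak lives in sentence-transformers:

import gc
import torch
from sentence_transformers import SentenceTransformer

# Sketch only: keep just the CPU copy of the embeddings, then explicitly drop
# the GPU tensor, collect garbage, and release cached blocks back to the driver.
model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2", device="cuda")
sentences = ["a short example sentence"] * 256

for i in range(5):
    embeddings = model.encode(sentences, convert_to_tensor=True)  # lives on the GPU
    result = embeddings.cpu().numpy()                             # CPU copy we actually keep

    del embeddings                 # drop the last reference to the GPU tensor
    gc.collect()                   # reclaim unreferenced Python objects
    torch.cuda.empty_cache()       # return cached blocks to the GPU driver

    print(i, torch.cuda.memory_allocated(0))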

I might defer further investigation but I will report back in this issue once I find some time to continue investigating here.

@tomaarsen added the "needs verification (Perhaps a bug; needs further attention)" label on Jan 19, 2023
@anjanvb

anjanvb commented Jan 21, 2023

I just ran into this in a SageMaker Studio instance (4 vCPU + 16 GiB + 1 GPU). I had trained a model before, using a default dataset from Datasets. My next attempt uses 8 samples each for 6 classes, with my own custom text. I am curious to know whether there is a maximum size limit for the text, and could that be the issue? Some of my texts are long (like one page of a PDF form). I killed my previous instance and started up a fresh one, but am still running into this.

OutOfMemoryError: CUDA out of memory. Tried to allocate 72.00 MiB (GPU 0; 14.76 GiB total capacity; 13.93 GiB already  allocated; 15.75 MiB free; 13.98 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@chschroeder
Author

@anjanvb If I understand you correctly, your error likely does not occur over time but instantly. This is unrelated to the specific issue described above.

(You have, however, guessed correctly: a) there is a token limit, and b) the GPU memory usage scales with the text length. So your problem is likely caused by the length of your input data. Search for the arguments max_length in SetFit and max_seq_length in sentence-transformers.)
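
For example, a minimal sketch of capping the sequence length (assuming the underlying SentenceTransformer is exposed as model.model_body, as in recent SetFit versions):

from setfit import SetFitModel

# Cap the number of tokens per example so GPU memory no longer scales with page-long inputs.
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
print(model.model_body.max_seq_length)   # current limit, in tokens
model.model_body.max_seq_length = 128    # longer inputs are truncated when encoding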

Regarding the original issue: sorry, I did not have time to investigate this further. Depending on your preference, you can close the issue; as long as everything is documented here, it might still be useful in the future.

@tomaarsen
Member

I'll close this for now, as the issue will still be accessible here. Should you find anything new on the topic that is worth our attention, do not hesitate to reopen this issue or open a new one altogether. I'm personally interested in any memory leak issues that may exist, and I'd like to thank you for the time you spent looking into this.

  • Tom Aarsen

@chschroeder
Author

Thank you, Tom! Unfortunately, I could not manage to produce a failing self-contained example (given my limited current time budget, at least). I think I know how to fix the problem (namely by adding explicit delete statements), but such a code change is hard to justify if we cannot reliably verify it. Also, the problem is located in sentence-transformers and not in SetFit, but nevertheless it is good that this is documented here as well.

For future reference: just two days ago I encountered the problem again, so it is still there. This notebook failed in 3 out of 5 tries and succeeded otherwise. The notebook was run on a 2080 with 11GB.

@matthieuvion

I ran into this OOM error myself, doing HPO with SetFit.
I tried a lot of things.
TL;DR:

  • Not SetFit-related; rather a mix of GPU vs. model choice / SentenceTransformers / Optuna
  • Useful resources for me, pretty much aligned with my own testing/debugging attempts: here, Transformers #13019, Transformers #1742
  • Leaving my (partial) fix down below, adapted for SetFit 1.0.1, as I spent a lot of time on this and it might help other souls ;)

My case:
Poetry env, torch 2.1.1, SetFit 1.0.1, sentence-transformers 2.2.2, Jupyter notebook.
(Old!) hardware: 1080 Ti (11GB VRAM), 32GB RAM.
Model used: paraphrase-multilingual-mpnet-base-v2.
OOM error after 2 or 3 trials; I could extend this to 6-7 trials by playing with smaller models (MiniLM) and the batch size.

OOM error using:

from optuna import Trial
from setfit import SetFitModel, Trainer, sample_dataset

# Optional, but for test purposes 8 ex. per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8, seed=40)

def model_init(params):
    params = params or {}
    max_iter = params.get("max_iter", 100)
    solver = params.get("solver", "liblinear")
    params = {
        "head_params": {
            "max_iter": max_iter,
            "solver": solver,
        }
    }
    return SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2", **params)

def hp_space(trial):
    """ Define hyperparams search space (Optuna) """
    
    return {
        # Embeddings fine-tuning phase params :
        
        "body_learning_rate": trial.suggest_float("body_learning_rate", 1e-6, 1e-3, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 3),
        "batch_size": trial.suggest_categorical("batch_size", [16, 32]),
        "seed": trial.suggest_int("seed", 1, 40),
        
        # LogisticRegression head params :
        
        "max_iter": trial.suggest_int("max_iter", 50, 300),
        "solver": trial.suggest_categorical("solver", ["newton-cg", "liblinear","lbfgs"]),
    }

trainer = Trainer(
    model_init=model_init,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"comment": "text", "label": "label"},
)
best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=4)

I could extend the number of trials (to at least 10) by doing some memory management.
Before that I tested a lot of things: checking my dataset (French) and sentence structure, max_seq, smaller models and batch sizes, etc.
It will eventually still run OOM, and IMO a better solution would probably be tweaking the optimizer/scheduler, as mentioned in other threads.

(Partial) fix to extend the time before running OOM:
Override the run_hp_search_optuna function, more specifically _objective, to add memory management.
Memory management is also added in model_init (which helps even more).

import gc
import torch
from optuna import Trial
from setfit import Trainer, SetFitModel, sample_dataset
from transformers.trainer_utils import BestRun  # named tuple returned by the search function below

# Model initialization function
def model_init(params):
    params = params or {}
    max_iter = params.get("max_iter", 100)
    solver = params.get("solver", "liblinear")
    params = {
        "head_params": {
            "max_iter": max_iter,
            "solver": solver,
        }
    }
    # memory management
    gc.collect()
    torch.cuda.empty_cache()

    return SetFitModel.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2", **params)
    
# Hyperparameter space definition
def hp_space(trial):
    """ Define hyperparams search space (Optuna) """
    
    return {
        # Embeddings fine-tuning phase params :
        
        "body_learning_rate": trial.suggest_float("body_learning_rate", 1e-7 , 1e-5, log=True), # 1e-6, 1e-3
        # "num_epochs": trial.suggest_int("num_epochs", 1, 2),
        "max_steps": trial.suggest_int("max_steps", 650, 800), # 200, 900
        "batch_size": trial.suggest_categorical("batch_size", [16]),
        "seed": trial.suggest_int("seed", 1, 40),
        
        # LogisticRegression head params :
        
        "max_iter": trial.suggest_int("max_iter", 120, 126), # 100, 200
        "solver": trial.suggest_categorical("solver", ["liblinear"]), # "newton-cg",'lbfgs'
    }

# Customized run_hp_search_optuna function
def run_hp_search_optuna_modified(trainer, n_trials, direction, **kwargs):
    import optuna

    def _objective(trial):
        trainer.objective = None
        trainer.train(trial=trial)

        # Evaluate first, while the trained model is still attached to the trainer
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)

        # memory management: drop the trained model and release cached GPU memory
        # before the next trial creates a fresh one via model_init
        del trainer.model
        gc.collect()
        torch.cuda.empty_cache()

        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)

    # memory management: possibly overkill, but gc_after_trial=True is also passed to study.optimize()
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=True)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params, study)
    
# Initialize Trainer
trainer = Trainer(
    model_init=model_init,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"comment": "text", "label": "label"},
)

# Replace the run_hp_search_optuna method with the modified one
trainer.run_hp_search_optuna = run_hp_search_optuna_modified

# Run hyperparameter search
best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=3)

@Alexander-Mark

Thanks @matthieuvion but didn't seem to make much difference for me - every implementation seems to have memory issues.

Running 3 trials of HPO on 64 short paragraphs of text with intfloat/multilingual-e5-small on a 4090:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 592.62 MiB is free. Including non-PyTorch memory, this process has 21.52 GiB memory in use. Of the allocated memory 20.92 GiB is allocated by PyTorch, and 156.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Maybe I'm missing something more fundamental?

@matthieuvion

Thanks @matthieuvion but didn't seem to make much difference for me - every implementation seems to have memory issues.

Running 3 trials of HPO on 64 short paragraphs of text with intfloat/multilingual-e5-small on a 4090:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 592.62 MiB is free. Including non-PyTorch memory, this process has 21.52 GiB memory in use. Of the allocated memory 20.92 GiB is allocated by PyTorch, and 156.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Maybe I'm missing something more fundamental?

I had the exact same error. Honestly, I spent too much time (a week of trials) trying to understand where it could come from, and your issue could be different. That's why I still posted my half-baked "solution" here; it does not prevent saturating the VRAM eventually. A quick search for "e5 OOM" turns up plenty of possibilities.
SetFit is a wonderful model/library with impressive few-shot performance, but as soon as I tried to go further (HPO, differentiable head, ONNX conversion, etc.), some workarounds couldn't be applied easily (at least by me), as it is a wrapper around transformers and some features are not yet implemented.
In my case, I ended up doing HPO with the same model over multiple runs, and then with a smaller model, without OOM errors. I believe you can still retain some hyperparameters, particularly the ones related to the logistic head (if that's the one you're using).

@Alexander-Mark

There are some other tricks you can try:

  • run watch nvidia-smi to see which other programs are using your GPU, then close unnecessary ones.
  • set the CUDA allocator env variable to a value that works better for you before your training code runs (these are alternatives; a second assignment overwrites the first; see the sketch after this list):
    • os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
    • os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
  • using automatic mixed precision actually caused OOM for me when training the head, i.e. don't set use_amp=True.
  • batch size is the main param that influences memory; I had to keep it small on both body and head, then gradually increase:
    • batch_size=(4, 2)
  • some transformers are just too much for consumer hardware (24GB VRAM is relatively small for larger models). Check the config.json file for num_hidden_layers, num_attention_heads, max_position_embeddings, intermediate_size, hidden_size, etc. I'm experimenting with multilingual-e5-large-instruct and it is very complex and demanding on memory. I usually start with small -> base -> large.
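
A short sketch putting a couple of these together (my illustration; batch_size=(4, 2) uses the (body, head) tuple form from the list above and assumes SetFit 1.x TrainingArguments):

import os

# Alternatives: a second assignment to PYTORCH_CUDA_ALLOC_CONF simply overwrites
# the first, so pick one (or combine options with a comma). Set it before the
# first CUDA allocation happens.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

from setfit import TrainingArguments

# Small separate batch sizes for the embedding body and the classification head;
# increase gradually while watching nvidia-smi. Mixed precision stays off, as above.
args = TrainingArguments(batch_size=(4, 2), use_amp=False)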

I don't really have time to dive into setfit source code, so can only say:

  • The main memory usage comes from storing the gradients and activations, as well as allocating memory for the input data, which includes the batch of input features. Deeper and wider models will quickly consume memory, especially if the text sequences are long.
  • Garbage collection won't actually release tensor memory until the tensor is out of scope, i.e. there are no more references to it (see the short demonstration after this list). So it depends on how setfit creates and uses tensors. I'm not sure how much flexibility there is in the setfit source code to change this. @tomaarsen or other engineers will have to weigh in on this.
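
A tiny demonstration of the reference point above (my illustration): GPU memory is only released once the last reference to a tensor is gone.

import gc
import torch

x = torch.randn(4096, 4096, device="cuda")   # ~64 MiB of GPU memory
y = x                                         # second reference to the same storage

del x
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(0))         # still ~64 MiB: `y` keeps it alive

del y
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(0))         # drops back down: no references remain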

But you are right, if you're able to work within these limits, setfit is an awesome package. We use it a lot for training production models.
