KTO training produces NaN rewards #1447

Closed
claralp opened this issue Mar 19, 2024 · 10 comments

@claralp
Contributor

claralp commented Mar 19, 2024

During training with the KTO Trainer, I occasionally get NaN values as rewards.
I am running the training as a job on MS Azure with one GPU (NVIDIA A100 80GB PCIe).
Ultimately, these issues cause my Azure job to crash and retry...

The log output I get from the KTOTrainer:

{'loss': 0.0202, 'grad_norm': 0.016206106171011925, 'learning_rate': 8.595282268751656e-06, 'rewards/chosen': 11.158143043518066, 'rewards/rejected': -29.0671443939209, 'rewards/margins': 40.22528839111328, 'kl': 0.0, 'logps/chosen': -15.192975044250488, 'logps/rejected': -180.1438446044922, 'epoch': 0.43}
{'loss': 0.0155, 'grad_norm': 4.091757774353027, 'learning_rate': 8.568778160614896e-06, 'rewards/chosen': 10.752923965454102, 'rewards/rejected': -26.606868743896484, 'rewards/margins': 37.35979461669922, 'kl': 0.0, 'logps/chosen': -13.974691390991211, 'logps/rejected': -156.9815673828125, 'epoch': 0.44}
{'loss': 0.0124, 'grad_norm': 0.06709074974060059, 'learning_rate': 8.542274052478135e-06, 'rewards/chosen': 10.838713645935059, 'rewards/rejected': -29.24416732788086, 'rewards/margins': 40.08287811279297, 'kl': 0.0, 'logps/chosen': -10.99155044555664, 'logps/rejected': -165.8121795654297, 'epoch': 0.44}
{'loss': 0.0113, 'grad_norm': 14.28693675994873, 'learning_rate': 8.515769944341374e-06, 'rewards/chosen': 11.07004451751709, 'rewards/rejected': -30.99440574645996, 'rewards/margins': 42.064453125, 'kl': 0.0, 'logps/chosen': -13.967004776000977, 'logps/rejected': -176.50094604492188, 'epoch': 0.45}
{'loss': 0.0193, 'grad_norm': 3.899095296859741, 'learning_rate': 8.489265836204611e-06, 'rewards/chosen': 10.825413703918457, 'rewards/rejected': -34.434303283691406, 'rewards/margins': 45.25971984863281, 'kl': 0.0, 'logps/chosen': -12.9598388671875, 'logps/rejected': -186.38381958007812, 'epoch': 0.46}
{'loss': 0.0109, 'grad_norm': 0.009407841600477695, 'learning_rate': 8.46276172806785e-06, 'rewards/chosen': nan, 'rewards/rejected': -33.95360565185547, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': nan, 'logps/rejected': -176.4713897705078, 'epoch': 0.47}
{'loss': 0.0324, 'grad_norm': 17.832523345947266, 'learning_rate': 8.43625761993109e-06, 'rewards/chosen': 10.286358833312988, 'rewards/rejected': -33.60068893432617, 'rewards/margins': 43.887046813964844, 'kl': 0.0, 'logps/chosen': -20.224634170532227, 'logps/rejected': -184.18112182617188, 'epoch': 0.48}
{'loss': 0.0029, 'grad_norm': 0.03802444413304329, 'learning_rate': 8.409753511794329e-06, 'rewards/chosen': 10.086004257202148, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -12.816671371459961, 'logps/rejected': nan, 'epoch': 0.48}
{'loss': 0.012, 'grad_norm': 2.815098524093628, 'learning_rate': 8.383249403657568e-06, 'rewards/chosen': 10.549690246582031, 'rewards/rejected': -31.304590225219727, 'rewards/margins': 41.85428237915039, 'kl': 0.0, 'logps/chosen': -13.178544998168945, 'logps/rejected': -169.447509765625, 'epoch': 0.49}
{'loss': 0.0074, 'grad_norm': 0.001768477726727724, 'learning_rate': 8.356745295520805e-06, 'rewards/chosen': 11.22235107421875, 'rewards/rejected': -33.09156799316406, 'rewards/margins': 44.31391906738281, 'kl': 0.0, 'logps/chosen': -13.94648265838623, 'logps/rejected': -178.08566284179688, 'epoch': 0.5}
{'loss': 0.0055, 'grad_norm': 8.117822647094727, 'learning_rate': 8.330241187384045e-06, 'rewards/chosen': 11.166982650756836, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -14.707374572753906, 'logps/rejected': nan, 'epoch': 0.51}
{'loss': 0.0206, 'grad_norm': 1.6973105669021606, 'learning_rate': 8.303737079247284e-06, 'rewards/chosen': 10.326757431030273, 'rewards/rejected': -33.753868103027344, 'rewards/margins': 44.08062744140625, 'kl': 0.0, 'logps/chosen': -19.15297508239746, 'logps/rejected': -181.1234130859375, 'epoch': 0.52}
{'loss': 0.0136, 'grad_norm': 9.740607261657715, 'learning_rate': 8.277232971110523e-06, 'rewards/chosen': 10.298160552978516, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -15.718103408813477, 'logps/rejected': nan, 'epoch': 0.52}

My pip freeze:

accelerate==0.28.0
aiohttp==3.9.3
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.43.0
certifi==2024.2.2
charset-normalizer==3.3.2
datasets==2.18.0
dill==0.3.8
docstring_parser==0.16
filelock==3.13.1
frozenlist==1.4.1
fsspec==2024.2.0
huggingface-hub==0.21.4
idna==3.6
Jinja2==3.1.3
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pandas==2.2.1
peft==0.9.0
protobuf==5.26.0
psutil==5.9.8
pyarrow==15.0.1
pyarrow-hotfix==0.6
Pygments==2.17.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
rich==13.7.1
safetensors==0.4.2
sentencepiece==0.2.0
shtab==1.7.1
six==1.16.0
sympy==1.12
tokenizers==0.15.2
torch==2.2.1
tqdm==4.66.2
transformers==4.38.2
triton==2.2.0
trl @ git+https://github.com/huggingface/trl@a2aa0f0b09671eaf81a945eb5e4913165fee92fa
typing_extensions==4.10.0
tyro==0.7.3
tzdata==2024.1
urllib3==2.2.1
xxhash==3.4.1
yarl==1.9.4

The training script I use:

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, BitsAndBytesConfig

from trl import KTOConfig, KTOTrainer, ModelConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
import torch


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the KTO training script.
    """

    dataset_path: Optional[str] = field(default=None, metadata={"help": "the online dataset to use, should include keys: [prompt, completion, label] OR [messages, completion, label]"})
    data_files: Optional[str] = field(default=None, metadata={"help": "the file(s) including data to use, this looks for 'data/{data_files}_train/test.jsonl.gz'. Datasets should include keys: [prompt, completion, label] OR [messages, completion, label]"})
    file_type: Optional[str] = field(default=None, metadata={"help": "the file type to open, e.g. 'json', 'csv'"})
    max_tokens: Optional[int] = field(default=4096, metadata={"help": "the maximum number of tokens returned by the data collator"})
    # debugging
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
    print(f"train with {script_args}, \n{model_args}")

    # Peft & Quantisation
    quantization_config = BitsAndBytesConfig(load_in_8bit=model_args.load_in_8bit)
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=model_args.lora_r, lora_alpha=model_args.lora_alpha, lora_dropout=model_args.lora_dropout)

    # Load the trainable model
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
                                                 quantization_config = quantization_config,
                                                 torch_dtype = getattr(torch, model_args.torch_dtype) if model_args.torch_dtype is not None else None,
                                                 device_map = "auto")

    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config = peft_config)
    model.print_trainable_parameters()

    # Reference Model
    model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
                                                     quantization_config = quantization_config,
                                                     torch_dtype = getattr(torch, model_args.torch_dtype) if model_args.torch_dtype is not None else None,
                                                     device_map = "auto")

    model_ref = prepare_model_for_kbit_training(model_ref)
    model_ref = get_peft_model(model_ref, peft_config=peft_config)

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.truncation_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the desired dataset
    if script_args.dataset_path is not None:
        dataset = load_dataset(script_args.dataset_path)
    elif script_args.data_files is not None and script_args.file_type is not None:
        dataset = load_dataset(script_args.file_type, data_files={"train": f"./data/{script_args.data_files}_train.jsonl.gz", "test": f"./data/{script_args.data_files}_test.jsonl.gz"})    
    else:
        print("either dataset_path or data_files & file_type have to be defined")
        exit(1)

    if script_args.sanity_check:
        dataset["train"] = dataset["train"].select(range(1000))
    
    # Create Split if not existing already
    if "test" not in dataset:
        dataset = dataset["train"].train_test_split(train_size=0.9)
    
    # apply chat template if not preformatted
    if "prompt" not in dataset["train"].features:
        dataset = dataset.map(lambda x: {"prompt": tokenizer.apply_chat_template(x["messages"], tokenize=False, add_generation_prompt=False)})

    # Set max. lengths for DefaultDataCollator

    max_prompt_len, max_compl_len, max_len = 0, 0, 0
    tokenizer.model_max_length = script_args.max_tokens
    tokenizer.max_model_input_sizes = script_args.max_tokens

    for sample in dataset["train"]:

        compl_len = len(tokenizer(sample["completion"], truncation=True)["input_ids"])
        total_len = len(tokenizer(sample["prompt"] + sample["completion"], truncation=True)["input_ids"])
        prompt_len = total_len - compl_len
    
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_compl_len = max(max_compl_len, compl_len)
        max_len = max(max_len, total_len)

    kto_args.max_prompt_length = max_prompt_len
    kto_args.max_completion_length = max_compl_len
    kto_args.max_length = max_len

    print(dataset)
    print(f"max_prompt_length={kto_args.max_prompt_length}, max_completion_length={kto_args.max_completion_length}, max_len={kto_args.max_length}")

    # set desired/undesired weights

    desired_weight = len(dataset['train']) / (2 * len(dataset["train"].filter(lambda d: d["label"] == True)))
    undesired_weight = len(dataset['train']) / (2 * len(dataset["train"].filter(lambda d: d["label"] == False)))

    kto_args.desirable_weight = desired_weight
    kto_args.undesirable_weight = undesired_weight

    # initialize the KTO trainer
    kto_trainer = KTOTrainer(
        model,
        model_ref,
        args=kto_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer
    )

    # train
    kto_trainer.train()
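The desirable/undesirable weighting in the script gives both classes the same total weight: with N training examples, `N / (2 * n_pos) * n_pos = N / 2` on the desirable side and likewise on the undesirable side, so the weighted ratio is always 1:1. The plain-Python sketch below reproduces that arithmetic and checks it against the 1:1 to 4:3 range cited later in this thread. The helper names are hypothetical, and the zero-count guard is an addition for the sketch (the original lines would raise `ZeroDivisionError` on a single-class dataset).

```python
import math


def class_weights(labels):
    """Weights as computed in the script: N / (2 * count of each class)."""
    n = len(labels)
    n_pos = sum(1 for label in labels if label)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        # Guard added for this sketch; the script would divide by zero here.
        raise ValueError("need both desirable and undesirable examples")
    return n / (2 * n_pos), n / (2 * n_neg)


def weighted_ratio(labels, desirable_weight, undesirable_weight):
    """(desirable_weight * n_pos) : (undesirable_weight * n_neg) as one number."""
    n_pos = sum(1 for label in labels if label)
    n_neg = len(labels) - n_pos
    return (desirable_weight * n_pos) / (undesirable_weight * n_neg)


def in_range(ratio, low=1.0, high=4 / 3):
    """True if ratio lies in [low, high], tolerating float rounding at the ends."""
    return low <= ratio <= high or math.isclose(ratio, low) or math.isclose(ratio, high)


# Class counts matching the 2k desired / 10k undesired experiment in this thread.
labels = [True] * 2000 + [False] * 10000
dw, uw = class_weights(labels)
print(dw, uw, weighted_ratio(labels, dw, uw), in_range(weighted_ratio(labels, dw, uw)))
# Without reweighting (both weights 1.0) the ratio is 1:5, outside the range.
print(weighted_ratio(labels, 1.0, 1.0), in_range(weighted_ratio(labels, 1.0, 1.0)))
```

With these counts the script's weights come out as 3.0 and 0.6, which balances the two sides exactly.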

The call arguments:

python train_kto.py \
    --model_name_or_path DiscoResearch/DiscoLM_German_7b_v1 \
    --data_files wp_rag_kto_20k \
    --file_type json \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --num_train_epochs 3 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 2 \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --output_dir kto_finetuned \
    --optim adamw_bnb_8bit \
    --warmup_steps 10 \
    --logging_first_step \
    --use_peft \
    --lora_r 8 \
    --lora_alpha 16 \
    --report_to none \
    --disable_tqdm False \
    --beta 0.5 \
    --torch_dtype bfloat16 \
    --bf16 \
    --load_in_8bit

Maybe @lewtun can help.

@lewtun
Member

lewtun commented Mar 19, 2024

cc also @kashif

@kashif
Collaborator

kashif commented Mar 19, 2024

@claralp depending on the batch size, some of the metrics could be NaN; this should not affect the training, and special attention has been paid to making sure the loss etc. is robust to these NaNs during back-prop.
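To illustrate the mechanism described here, the following is a minimal pure-Python sketch, not TRL's actual code. It assumes a per-class metric such as `rewards/chosen` is a plain mean over the desirable (or undesirable) examples in the batch, which becomes NaN when the batch contains none of them, just as `torch.tensor([]).mean()` returns `nan`. The loss, by contrast, is modelled as one mean over all per-example losses, so a missing class cannot turn it into NaN. The function names are hypothetical.

```python
import math


def masked_mean(values, mask):
    """Mean of the entries of `values` whose mask is True.

    Like torch.tensor([]).mean(), an empty selection yields NaN instead of
    raising an error.
    """
    selected = [v for v, keep in zip(values, mask) if keep]
    if not selected:
        return float("nan")
    return sum(selected) / len(selected)


def batch_metrics(rewards, is_desirable, per_example_losses):
    """Hypothetical per-batch logging in the style of the KTO metrics."""
    return {
        # Per-class metrics: NaN whenever the batch lacks that class.
        "rewards/chosen": masked_mean(rewards, is_desirable),
        "rewards/rejected": masked_mean(rewards, [not d for d in is_desirable]),
        # Loss: a single mean over every example, so it stays finite
        # even when one class is missing from the batch.
        "loss": sum(per_example_losses) / len(per_example_losses),
    }


# A batch that happens to contain only undesirable examples.
metrics = batch_metrics([0.5, -1.2, -0.8], [False, False, False], [0.3, 0.4, 0.5])
print(metrics)
print("chosen is NaN:", math.isnan(metrics["rewards/chosen"]))
```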

@kashif
Collaborator

kashif commented Mar 19, 2024

@claralp I do not think NaNs in a dict should cause this to crash... do you have any crash backtraces?

@claralp
Contributor Author

claralp commented Mar 19, 2024

@kashif there are no errors or warnings in stdout/stderr; it just stops at some point after the NaN rewards appear, so I cannot provide a stack trace here.
However, the Azure execution wrapper log shows a blocking process:

2024-03-19T03:33:30.165457Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution::process_manager: Failed blocking user process detected, process name: echo, process pid: 34, code: None success_return_code=Zero { additional_codes: [] } code=None
2024-03-19T03:33:31.167084Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution: Execution process terminated by a signal, which may be due to failure in other user processes on the same node or node ran out of memory. local_rank=0 name=echo

The lifecycler log shows only a preemption signal:

2024-03-19T03:33:29.494161Z  WARN run_lifecycler:run_service_and_step_through_lifecycle:step_through_lifecycle: lifecycler::lifecycle: Received abort message, exiting lifecycle abort_message=AbortMessage { error: Some(Error { code: "ReceivedPreemptionSignal", message: "{\"Compliant\":\"Job was terminated due to: Runtime received a preemption signal.\"}", target: "", node_info: None, category: UserError, error_details: [], inner_error: None }), broadcast_abort: true, request_timeout: 25 }

@PhilipMay
Contributor

PhilipMay commented Mar 19, 2024

I think this could be the "normal" low-priority Azure preemption? :-(

@claralp
Contributor Author

claralp commented Mar 20, 2024

Important note: the crash only appears after the training shows NaN values; otherwise it doesn't.
I have even seen cases where all results converge to NaN values:

{'loss': 0.0, 'grad_norm': 281.6248474121094, 'learning_rate': 9.856115107913668e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.17875319719314575, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 192.08326721191406, 'learning_rate': 9.848121502797762e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0570355653762817, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 33.55568313598633, 'learning_rate': 9.840127897681853e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.1016669273376465, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 44.5154914855957, 'learning_rate': 9.832134292565947e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.197722911834717, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 10.592936515808105, 'learning_rate': 9.82414068745004e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0713751316070557, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}
{'loss': 0.0, 'grad_norm': 61.1552734375, 'learning_rate': 9.81614708233413e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3863883912563324, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}

@kashif, could there be anything wrong with the hyperparameter choice?
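To find the first logged step that contains NaNs without scanning printouts by eye, a small helper like the hypothetical `nan_metrics` below can check each log dict. One option, stated here as an assumption rather than something tested in this thread, is to call it from the `on_log(self, args, state, control, logs=None, **kwargs)` hook of a `transformers.TrainerCallback` and set `control.should_training_stop = True` when it returns anything, so the run stops (and can be inspected) at the first NaN step.

```python
import math


def nan_metrics(logs):
    """Return, sorted, the names of logged metrics whose value is NaN."""
    return sorted(
        key
        for key, value in logs.items()
        if isinstance(value, float) and math.isnan(value)
    )


# Values copied from the first log posted in this issue (subset of keys).
logs = {
    "loss": 0.0109,
    "rewards/chosen": float("nan"),
    "rewards/rejected": -33.95360565185547,
    "rewards/margins": float("nan"),
    "kl": 0.0,
    "logps/chosen": float("nan"),
    "logps/rejected": -176.4713897705078,
    "epoch": 0.47,
}
print(nan_metrics(logs))
```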

@kashif
Collaborator

kashif commented Mar 20, 2024

@claralp so the main hyperparameter that could affect this is the batch size, as it needs a good mix of good and bad examples and also matters for the KL estimates... your learning rate is tiny, so that should be good... what is your batch size when you get all NaNs?

Also, does this happen if you try locally, outside of Azure?

@claralp
Contributor Author

claralp commented Mar 20, 2024

The output below is from a test with very unbalanced data, namely 2k desired completions and 10k undesired ones.
I know that a (desired:undesired) ratio between 4:3 and 1:1 is required for proper training.
This is just an experiment to see whether missing positive/negative samples in a batch might be the reason behind the NaN rewards.
But here I get NaN losses even without NaN rewards...
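As a rough check of the batch-composition hypothesis, assume each batch of 8 is drawn i.i.d. with a desired fraction of 2k / 12k = 1/6 (the default sampler actually shuffles without replacement, so this is only an approximation). The chance that a batch holds no desired example is (5/6)^8, about 23%, while the chance that it holds no undesired example is (1/6)^8, well under one in a million. Under this simplified model, missing undesired examples would not obviously account for the frequent NaN values in `rewards/rejected` in the log below. The function name is hypothetical.

```python
def prob_batch_missing_class(p_desirable, batch_size):
    """Under i.i.d. sampling, return (P[no desirable example], P[no undesirable example])."""
    p_no_desirable = (1 - p_desirable) ** batch_size
    p_no_undesirable = p_desirable ** batch_size
    return p_no_desirable, p_no_undesirable


# 2k desired out of 12k total, per-device batch size 8.
unbalanced = prob_batch_missing_class(2000 / 12000, 8)
# A perfectly balanced dataset for comparison.
balanced = prob_batch_missing_class(0.5, 8)
print(f"unbalanced: no desirable {unbalanced[0]:.4f}, no undesirable {unbalanced[1]:.2e}")
print(f"balanced:   no desirable {balanced[0]:.4f}, no undesirable {balanced[1]:.4f}")
```

Even a perfectly balanced dataset leaves roughly a 0.4% chance per batch of size 8 that one class is absent, and that chance shrinks as the batch grows.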

{'loss': 1.0431, 'grad_norm': 42.099464416503906, 'learning_rate': 1.0000000000000002e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/margins': 0.0, 'kl': 0.0, 'logps/chosen': -37.16696548461914, 'logps/rejected': -87.62107849121094, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 41.9438362121582, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': 0.0, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -32.92508316040039, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 29.28327178955078, 'learning_rate': 3e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.15479230880737305, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.70748519897461, 'learning_rate': 4.000000000000001e-06, 'rewards/chosen': 0.06518054008483887, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.43951892852783203, 'logps/chosen': -31.101844787597656, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 44.989227294921875, 'learning_rate': 5e-06, 'rewards/chosen': 0.3087962865829468, 'rewards/rejected': 0.23543643951416016, 'rewards/margins': 0.07335984706878662, 'kl': 1.230994462966919, 'logps/chosen': -32.83413314819336, 'logps/rejected': -74.81724548339844, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 55.32667541503906, 'learning_rate': 6e-06, 'rewards/chosen': 0.3336696922779083, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3016533851623535, 'logps/chosen': -38.598453521728516, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.44403839111328, 'learning_rate': 7e-06, 'rewards/chosen': 0.8524215817451477, 'rewards/rejected': 0.5893988609313965, 'rewards/margins': 0.2630227208137512, 'kl': 0.7648882865905762, 'logps/chosen': -35.86614227294922, 'logps/rejected': -93.13447570800781, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 26.85154914855957, 'learning_rate': 8.000000000000001e-06, 'rewards/chosen': 0.8056153059005737, 'rewards/rejected': 0.40718716382980347, 'rewards/margins': 0.39842814207077026, 'kl': 1.3891675472259521, 'logps/chosen': -34.07681655883789, 'logps/rejected': -113.53411102294922, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 25.181703567504883, 'learning_rate': 9e-06, 'rewards/chosen': nan, 'rewards/rejected': 0.9289813041687012, 'rewards/margins': nan, 'kl': 1.279036521911621, 'logps/chosen': nan, 'logps/rejected': -132.0060272216797, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 36.62141799926758, 'learning_rate': 1e-05, 'rewards/chosen': 1.4094278812408447, 'rewards/rejected': 0.8396401405334473, 'rewards/margins': 0.5697878003120422, 'kl': 2.0255985260009766, 'logps/chosen': -30.87615394592285, 'logps/rejected': -102.92286682128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.035221099853516, 'learning_rate': 9.997300215982722e-06, 'rewards/chosen': 1.5928469896316528, 'rewards/rejected': 1.5922844409942627, 'rewards/margins': 0.0005625784397125244, 'kl': 2.884922981262207, 'logps/chosen': -39.46299362182617, 'logps/rejected': -121.78970336914062, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 33.07608413696289, 'learning_rate': 9.994600431965443e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.1301448345184326, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.48128128051758, 'learning_rate': 9.991900647948165e-06, 'rewards/chosen': 2.113973617553711, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.475428819656372, 'logps/chosen': -26.679065704345703, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 31.501819610595703, 'learning_rate': 9.989200863930886e-06, 'rewards/chosen': 2.6266024112701416, 'rewards/rejected': 2.2295963764190674, 'rewards/margins': 0.3970060348510742, 'kl': 4.643209934234619, 'logps/chosen': -42.25154495239258, 'logps/rejected': -95.91471862792969, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 34.09553527832031, 'learning_rate': 9.986501079913607e-06, 'rewards/chosen': 2.7660703659057617, 'rewards/rejected': 2.6509010791778564, 'rewards/margins': 0.11516910791397095, 'kl': 4.8384199142456055, 'logps/chosen': -49.93422317504883, 'logps/rejected': -73.00190734863281, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.591957092285156, 'learning_rate': 9.983801295896329e-06, 'rewards/chosen': 3.131122350692749, 'rewards/rejected': 2.9620559215545654, 'rewards/margins': 0.1690664291381836, 'kl': 4.498130798339844, 'logps/chosen': -29.836196899414062, 'logps/rejected': -105.75230407714844, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 13.737163543701172, 'learning_rate': 9.98110151187905e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.1204824447631836, 'rewards/margins': nan, 'kl': 6.049262523651123, 'logps/chosen': nan, 'logps/rejected': -96.40724182128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.375396728515625, 'learning_rate': 9.978401727861771e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.636046886444092, 'rewards/margins': nan, 'kl': 6.3599958419799805, 'logps/chosen': nan, 'logps/rejected': -97.00442504882812, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 27.26076889038086, 'learning_rate': 9.975701943844493e-06, 'rewards/chosen': 4.384129524230957, 'rewards/rejected': 3.9822707176208496, 'rewards/margins': 0.40185898542404175, 'kl': 8.23063850402832, 'logps/chosen': -24.248661041259766, 'logps/rejected': -105.89572143554688, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 18.513507843017578, 'learning_rate': 9.973002159827214e-06, 'rewards/chosen': 4.265963077545166, 'rewards/rejected': 3.8863425254821777, 'rewards/margins': 0.3796207308769226, 'kl': 6.635190010070801, 'logps/chosen': -24.802963256835938, 'logps/rejected': -68.99553680419922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.997692108154297, 'learning_rate': 9.970302375809935e-06, 'rewards/chosen': 5.037494659423828, 'rewards/rejected': 4.227317810058594, 'rewards/margins': 0.8101770877838135, 'kl': 8.07493782043457, 'logps/chosen': -24.345657348632812, 'logps/rejected': -74.88150024414062, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 26.245861053466797, 'learning_rate': 9.967602591792658e-06, 'rewards/chosen': 4.526309490203857, 'rewards/rejected': 4.603299140930176, 'rewards/margins': -0.07698965072631836, 'kl': 8.698637008666992, 'logps/chosen': -22.94290542602539, 'logps/rejected': -99.22356414794922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 22.14063835144043, 'learning_rate': 9.964902807775378e-06, 'rewards/chosen': 5.355809211730957, 'rewards/rejected': 4.891297340393066, 'rewards/margins': 0.464511513710022, 'kl': 8.954204559326172, 'logps/chosen': -23.850910186767578, 'logps/rejected': -87.7445068359375, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.642059326171875, 'learning_rate': 9.962203023758101e-06, 'rewards/chosen': 5.606294631958008, 'rewards/rejected': 6.807004928588867, 'rewards/margins': -1.2007099390029907, 'kl': 9.733396530151367, 'logps/chosen': -24.039264678955078, 'logps/rejected': -119.2092514038086, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 10.412492752075195, 'learning_rate': 9.959503239740822e-06, 'rewards/chosen': 5.953470230102539, 'rewards/rejected': 5.025949954986572, 'rewards/margins': 0.9275206327438354, 'kl': 10.74533462524414, 'logps/chosen': -16.727996826171875, 'logps/rejected': -80.9796142578125, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 17.695709228515625, 'learning_rate': 9.956803455723542e-06, 'rewards/chosen': nan, 'rewards/rejected': 6.109594345092773, 'rewards/margins': nan, 'kl': 11.900070190429688, 'logps/chosen': nan, 'logps/rejected': -121.30842590332031, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 35.035892486572266, 'learning_rate': 9.954103671706265e-06, 'rewards/chosen': 6.687896251678467, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 12.4317626953125, 'logps/chosen': -16.511333465576172, 'logps/rejected': nan, 'epoch': 0.02}

@claralp
Contributor Author

claralp commented Mar 20, 2024

> @claralp so the main hyperparameter that could affect this is the batch size, as it needs a good mix of good and bad examples and also matters for the KL estimates... your learning rate is tiny, so that should be good... what is your batch size when you get all NaNs?

The batch size is 8 and gradient_accumulation_steps is 2, as in the config above.

> Also, does this happen if you try locally, outside of Azure?

Currently checking this.

@claralp
Contributor Author

claralp commented Apr 11, 2024

Closed with #1499 and #1514.

claralp closed this as completed Apr 11, 2024