diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 32f7aceea3..7ce061ee3a 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -8,7 +8,7 @@ import torch from composer import algorithms from composer.callbacks import (EarlyStopper, LRMonitor, MemoryMonitor, - OptimizerMonitor, RuntimeEstimator, + OptimizerMonitor, RuntimeEstimator, EvalOutputLogging, SpeedMonitor) from composer.core import Algorithm, Callback, Evaluator from composer.datasets.in_context_learning_evaluation import \ @@ -101,6 +101,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': return HuggingFaceCheckpointer(**kwargs) + elif name == 'eval_output_logging': + return EvalOutputLogging(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml index 5cf4cf98da..3171671567 100644 --- a/mcli/mcli-hf-eval.yaml +++ b/mcli/mcli-hf-eval.yaml @@ -7,6 +7,7 @@ integrations: ssh_clone: false # Should be true if using a private repo command: | + pip install git+https://github.com/bmosaicml/composer.git@error_logging_callback cd llm-foundry/scripts composer eval/eval.py /mnt/config/parameters.yaml diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 1ba723a172..a13a087e43 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -7,6 +7,7 @@ import time import warnings from typing import Any, Dict, List, Optional, Union +from composer.core.callback import Callback import pandas as pd import torch @@ -21,7 +22,7 @@ from llmfoundry.models import MPTForCausalLM from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY -from llmfoundry.utils.builders import (build_icl_data_and_gauntlet, +from llmfoundry.utils.builders import (build_icl_data_and_gauntlet, build_callback, build_logger, build_tokenizer) from llmfoundry.utils.config_utils import pop_config, process_init_device @@ -106,6 +107,7 @@ def evaluate_model( precision: str, eval_gauntlet_df: Optional[pd.DataFrame], icl_subset_num_batches: Optional[int], + callback_configs: Optional[Dict] ): print(f'Evaluating model: {model_cfg.model_name}', flush=True) # Build tokenizer and model @@ -120,7 +122,12 @@ def evaluate_model( icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size, max_seq_len, icl_subset_num_batches) - callbacks = [] + # Callbacks + callbacks: List[Callback] = [ + build_callback(str(name), callback_cfg) + for name, callback_cfg in callback_configs.items() + ] if callback_configs else [] + if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) @@ -170,6 +177,7 @@ def evaluate_model( dist_timeout=dist_timeout, python_log_level=python_log_level, ) + breakpoint() if torch.cuda.is_available(): torch.cuda.synchronize() @@ -245,7 +253,11 @@ def main(cfg: DictConfig): default_value=None) # Pop out interpolation variables. pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None) - + callback_configs: Optional[DictConfig] = pop_config(cfg, + 'callbacks', + must_exist=False, + default_value=None) + # Warn for unused parameters for key in cfg: warnings.warn( @@ -283,7 +295,9 @@ def main(cfg: DictConfig): python_log_level=python_log_level, precision=precision, eval_gauntlet_df=eval_gauntlet_df, - icl_subset_num_batches=icl_subset_num_batches) + icl_subset_num_batches=icl_subset_num_batches, + callback_configs=callback_configs + ) if eval_gauntlet_callback is not None: composite_scores = eval_gauntlet_callback.eval_after_all( diff --git a/scripts/eval/yamls/hf_eval.yaml b/scripts/eval/yamls/hf_eval.yaml index 8eecf57c30..dd2794005b 100644 --- a/scripts/eval/yamls/hf_eval.yaml +++ b/scripts/eval/yamls/hf_eval.yaml @@ -45,3 +45,8 @@ device_eval_batch_size: 4 icl_tasks: 'eval/yamls/tasks.yaml' eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml' + +callbacks: + eval_output_logging: + print_only_incorrect: false + subset_sample: 100 \ No newline at end of file diff --git a/scripts/eval/yamls/tasks.yaml b/scripts/eval/yamls/tasks.yaml index 70ef2ca667..0dce273212 100644 --- a/scripts/eval/yamls/tasks.yaml +++ b/scripts/eval/yamls/tasks.yaml @@ -1,175 +1,175 @@ icl_tasks: - label: jeopardy - dataset_uri: eval/local_data/world_knowledge/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI + dataset_uri: eval/local_data/world_knowledge/jeopardy_small.jsonl # ADD YOUR OWN DATASET URI num_fewshot: [10] icl_task_type: language_modeling continuation_delimiter: "\nAnswer: " # this separates questions from answers has_categories: true -- - label: bigbench_qa_wikidata - dataset_uri: eval/local_data/world_knowledge/bigbench_qa_wikidata.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: language_modeling -- - label: arc_easy - dataset_uri: eval/local_data/world_knowledge/arc_easy.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice - continuation_delimiter: "\nAnswer: " # this separates questions from answers -- - label: arc_challenge - dataset_uri: eval/local_data/world_knowledge/arc_challenge.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice - continuation_delimiter: "\nAnswer: " # this separates questions from answers -- - label: mmlu - dataset_uri: eval/local_data/world_knowledge/mmlu.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice - continuation_delimiter: "\nAnswer: " # this separates questions from answers - has_categories: true -- - label: bigbench_misconceptions - dataset_uri: eval/local_data/world_knowledge/bigbench_misconceptions.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: copa - dataset_uri: eval/local_data/commonsense_reasoning/copa.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - icl_task_type: multiple_choice -- - label: piqa - dataset_uri: eval/local_data/commonsense_reasoning/piqa.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice - continuation_delimiter: "\nAnswer: " # this separates questions from answers -- - label: openbook_qa - dataset_uri: eval/local_data/commonsense_reasoning/openbook_qa.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - icl_task_type: multiple_choice -- - label: bigbench_novel_concepts - dataset_uri: eval/local_data/commonsense_reasoning/bigbench_novel_concepts.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: bigbench_strange_stories - dataset_uri: eval/local_data/commonsense_reasoning/bigbench_strange_stories.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: bigbench_strategy_qa - dataset_uri: eval/local_data/commonsense_reasoning/bigbench_strategy_qa.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: lambada_openai - dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl - num_fewshot: [0] - icl_task_type: language_modeling -- - label: hellaswag - dataset_uri: eval/local_data/language_understanding/hellaswag.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: winograd - dataset_uri: eval/local_data/language_understanding/winograd_wsc.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - icl_task_type: schema -- - label: winogrande - dataset_uri: eval/local_data/language_understanding/winogrande.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [0] - icl_task_type: schema -- - label: bigbench_conlang_translation - dataset_uri: eval/local_data/language_understanding/bigbench_conlang_translation.jsonl - num_fewshot: [0] - icl_task_type: language_modeling -- - label: bigbench_language_identification - dataset_uri: eval/local_data/language_understanding/bigbench_language_identification.jsonl - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: bigbench_conceptual_combinations - dataset_uri: eval/local_data/language_understanding/bigbench_conceptual_combinations.jsonl - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: bigbench_elementary_math_qa - dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_elementary_math_qa.jsonl - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: bigbench_dyck_languages - dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_dyck_languages.jsonl - num_fewshot: [10] - icl_task_type: language_modeling -- - label: bigbench_cs_algorithms - dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_cs_algorithms.jsonl - num_fewshot: [10] - icl_task_type: language_modeling -- - label: bigbench_logical_deduction - dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_logical_deduction.jsonl - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: bigbench_operators - dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_operators.jsonl - num_fewshot: [10] - icl_task_type: language_modeling -- - label: bigbench_repeat_copy_logic - dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_repeat_copy_logic.jsonl - num_fewshot: [10] - icl_task_type: language_modeling -- - label: simple_arithmetic_nospaces - dataset_uri: eval/local_data/symbolic_problem_solving/simple_arithmetic_nospaces.jsonl - num_fewshot: [10] - icl_task_type: language_modeling -- - label: simple_arithmetic_withspaces - dataset_uri: eval/local_data/symbolic_problem_solving/simple_arithmetic_withspaces.jsonl - num_fewshot: [10] - icl_task_type: language_modeling -- - label: math_qa - dataset_uri: eval/local_data/symbolic_problem_solving/math_qa.jsonl - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: logi_qa - dataset_uri: eval/local_data/symbolic_problem_solving/logi_qa.jsonl - num_fewshot: [10] - icl_task_type: multiple_choice - continuation_delimiter: "\nAnswer: " # this separates questions from answers -- - label: pubmed_qa_labeled - dataset_uri: eval/local_data/reading_comprehension/pubmed_qa_labeled.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: language_modeling -- - label: squad - dataset_uri: eval/local_data/reading_comprehension/squad.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: language_modeling -- - label: bigbench_understanding_fables - dataset_uri: eval/local_data/reading_comprehension/bigbench_understanding_fables.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice -- - label: boolq - dataset_uri: eval/local_data/reading_comprehension/boolq.jsonl # ADD YOUR OWN DATASET URI - num_fewshot: [10] - icl_task_type: multiple_choice - continuation_delimiter: "\nAnswer: " # this separates questions from answers +# - +# label: bigbench_qa_wikidata +# dataset_uri: eval/local_data/world_knowledge/bigbench_qa_wikidata.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: arc_easy +# dataset_uri: eval/local_data/world_knowledge/arc_easy.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# continuation_delimiter: "\nAnswer: " # this separates questions from answers +# - +# label: arc_challenge +# dataset_uri: eval/local_data/world_knowledge/arc_challenge.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# continuation_delimiter: "\nAnswer: " # this separates questions from answers +# - +# label: mmlu +# dataset_uri: eval/local_data/world_knowledge/mmlu.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# continuation_delimiter: "\nAnswer: " # this separates questions from answers +# has_categories: true +# - +# label: bigbench_misconceptions +# dataset_uri: eval/local_data/world_knowledge/bigbench_misconceptions.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: copa +# dataset_uri: eval/local_data/commonsense_reasoning/copa.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# icl_task_type: multiple_choice +# - +# label: piqa +# dataset_uri: eval/local_data/commonsense_reasoning/piqa.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# continuation_delimiter: "\nAnswer: " # this separates questions from answers +# - +# label: openbook_qa +# dataset_uri: eval/local_data/commonsense_reasoning/openbook_qa.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# icl_task_type: multiple_choice +# - +# label: bigbench_novel_concepts +# dataset_uri: eval/local_data/commonsense_reasoning/bigbench_novel_concepts.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: bigbench_strange_stories +# dataset_uri: eval/local_data/commonsense_reasoning/bigbench_strange_stories.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: bigbench_strategy_qa +# dataset_uri: eval/local_data/commonsense_reasoning/bigbench_strategy_qa.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: lambada_openai +# dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl +# num_fewshot: [0] +# icl_task_type: language_modeling +# - +# label: hellaswag +# dataset_uri: eval/local_data/language_understanding/hellaswag.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: winograd +# dataset_uri: eval/local_data/language_understanding/winograd_wsc.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# icl_task_type: schema +# - +# label: winogrande +# dataset_uri: eval/local_data/language_understanding/winogrande.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [0] +# icl_task_type: schema +# - +# label: bigbench_conlang_translation +# dataset_uri: eval/local_data/language_understanding/bigbench_conlang_translation.jsonl +# num_fewshot: [0] +# icl_task_type: language_modeling +# - +# label: bigbench_language_identification +# dataset_uri: eval/local_data/language_understanding/bigbench_language_identification.jsonl +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: bigbench_conceptual_combinations +# dataset_uri: eval/local_data/language_understanding/bigbench_conceptual_combinations.jsonl +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: bigbench_elementary_math_qa +# dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_elementary_math_qa.jsonl +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: bigbench_dyck_languages +# dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_dyck_languages.jsonl +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: bigbench_cs_algorithms +# dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_cs_algorithms.jsonl +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: bigbench_logical_deduction +# dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_logical_deduction.jsonl +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: bigbench_operators +# dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_operators.jsonl +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: bigbench_repeat_copy_logic +# dataset_uri: eval/local_data/symbolic_problem_solving/bigbench_repeat_copy_logic.jsonl +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: simple_arithmetic_nospaces +# dataset_uri: eval/local_data/symbolic_problem_solving/simple_arithmetic_nospaces.jsonl +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: simple_arithmetic_withspaces +# dataset_uri: eval/local_data/symbolic_problem_solving/simple_arithmetic_withspaces.jsonl +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: math_qa +# dataset_uri: eval/local_data/symbolic_problem_solving/math_qa.jsonl +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: logi_qa +# dataset_uri: eval/local_data/symbolic_problem_solving/logi_qa.jsonl +# num_fewshot: [10] +# icl_task_type: multiple_choice +# continuation_delimiter: "\nAnswer: " # this separates questions from answers +# - +# label: pubmed_qa_labeled +# dataset_uri: eval/local_data/reading_comprehension/pubmed_qa_labeled.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: squad +# dataset_uri: eval/local_data/reading_comprehension/squad.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: language_modeling +# - +# label: bigbench_understanding_fables +# dataset_uri: eval/local_data/reading_comprehension/bigbench_understanding_fables.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# - +# label: boolq +# dataset_uri: eval/local_data/reading_comprehension/boolq.jsonl # ADD YOUR OWN DATASET URI +# num_fewshot: [10] +# icl_task_type: multiple_choice +# continuation_delimiter: "\nAnswer: " # this separates questions from answers diff --git a/setup.py b/setup.py index b07b8afe08..238b83abfb 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,8 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,mlflow]>=0.16.1,<0.17', + # 'mosaicml[libcloud,wandb,mlflow]>=0.16.1,<0.17', + 'mosaicml@git+https://github.com/bmosaicml/composer.git@error_logging_callback', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.33,<4.34', 'mosaicml-streaming>=0.5.1,<0.6', @@ -80,12 +81,12 @@ ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.16.1,<0.17', + # 'mosaicml[tensorboard]>=0.16.1,<0.17', ] extra_deps['gpu'] = [ 'flash-attn==v1.0.3.post0', - 'mosaicml-turbo==0.0.3', + # 'mosaicml-turbo==0.0.3', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy', ] diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 2852d99b8b..dbd6ff6352 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -504,45 +504,59 @@ def _set_state_dict_type(model: nn.Module): def test_fused_as_fast_as_unfused(N: int, D: int, min_elems_traversed: int = 1000000): - W = torch.randn((N, D), device='cuda', requires_grad=True) - W.grad = torch.randn((N, D), device='cuda', requires_grad=False) - - num_iters = int(np.ceil(min_elems_traversed / W.grad.numel())) - num_iters = min(100, num_iters) # don't take all day when overhead-bound - - times = {} - kwargs = {'weight_decay': .01} - combos = [(True, False), (True, True), (False, False), ('NA', False)] - for fused, use_errors in combos: - if fused == 'NA': - opt = Lion8bit([W], quantize=False, - **kwargs) # type:ignore (reportGeneralTypeIssues) - else: - opt = Lion8bit([W], - _fused=fused, - error_correction=use_errors, - **kwargs) # type:ignore (reportGeneralTypeIssues) - for _ in range(3): - opt.step() # warmup iters - torch.cuda.synchronize() - t_start = time.time() - for _ in range(num_iters): - opt.step() - torch.cuda.synchronize() - t_end = time.time() - dur = (t_end - t_start) / num_iters - if use_errors: - times['ecc'] = dur - else: - times[fused] = dur - - atol = 20e-6 # should always be faster, but avoids rare flakiness - assert times[True] < times[False] + atol - assert times[True] < times['NA'] + atol - assert times['ecc'] < times['NA'] + atol - - print('') - print('time fused (ms): ', times[True] * 1e3) - print('time fused+ecc (ms): ', times['ecc'] * 1e3) - print('time unfused (ms): ', times[False] * 1e3) - print('time unquantized (ms): ', times['NA'] * 1e3) + + def _time_kernels(N: int, D: int, min_elems_traversed: int): + W = torch.randn((N, D), device='cuda', requires_grad=True) + W.grad = torch.randn((N, D), device='cuda', requires_grad=False) + + num_iters = int(np.ceil(min_elems_traversed / W.grad.numel())) + num_iters = min(100, + num_iters) # don't take all day when overhead-bound + + times = {} + kwargs = {'weight_decay': .01} + combos = [(True, False), (True, True), (False, False), ('NA', False)] + for fused, use_errors in combos: + if fused == 'NA': + opt = Lion8bit( + [W], quantize=False, + **kwargs) # type:ignore (reportGeneralTypeIssues) + else: + opt = Lion8bit( + [W], _fused=fused, error_correction=use_errors, + **kwargs) # type:ignore (reportGeneralTypeIssues) + for _ in range(3): + opt.step() # warmup iters + torch.cuda.synchronize() + t_start = time.time() + for _ in range(num_iters): + opt.step() + torch.cuda.synchronize() + t_end = time.time() + dur = (t_end - t_start) / num_iters + if use_errors: + times['ecc'] = dur + else: + times[fused] = dur + return times + + times = _time_kernels(N, D, min_elems_traversed) + + atol = 2e-4 # should always be faster, but atol helps avoid flakiness + it = 0 + while True: + try: + assert times[True] < times[False] + atol + assert times[True] < times['NA'] + atol + assert times['ecc'] < times['NA'] + atol + print('') + print('time fused (ms): ', times[True] * 1e3) + print('time fused+ecc (ms): ', times['ecc'] * 1e3) + print('time unfused (ms): ', times[False] * 1e3) + print('time unquantized (ms): ', times['NA'] * 1e3) + break + except AssertionError as e: + if it >= 2: # allow 3 retries to avoid flakiness + raise e + times = _time_kernels(N, D, min_elems_traversed) + it += 1