diff --git a/.github/workflows/tests-main.yml b/.github/workflows/tests-main.yml index 6025297883..aa5b88babd 100644 --- a/.github/workflows/tests-main.yml +++ b/.github/workflows/tests-main.yml @@ -32,7 +32,7 @@ jobs: pip install -U git+https://github.com/huggingface/peft.git pip install -U git+https://github.com/huggingface/transformers.git # cpu version of pytorch - pip install -e ".[test, diffusers]" + pip install ".[test, diffusers]" - name: Test with pytest run: | make test diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4fd69d610e..a468890cbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -54,7 +54,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install -e ".[test, peft, diffusers]" + pip install ".[test, peft, diffusers]" - name: Test with pytest run: | make test diff --git a/tests/test_cli.py b/tests/test_cli.py index d26ed5ffba..861fcaf535 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -18,17 +18,23 @@ @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_sft_cli(): - subprocess.run( - "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine", - shell=True, - check=True, - ) + try: + subprocess.run( + "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine", + shell=True, + check=True, + ) + except BaseException as exc: + raise AssertionError("An error occured while running the CLI, please double check") from exc @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_dpo_cli(): - subprocess.run( - "trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine", - shell=True, - check=True, - ) + try: + subprocess.run( + "trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine", + shell=True, + check=True, + ) + except BaseException as exc: + raise AssertionError("An error occured while running the CLI, please double check") from exc diff --git a/trl/commands/cli.py b/trl/commands/cli.py index 9f4f8da1f5..239d66b80d 100644 --- a/trl/commands/cli.py +++ b/trl/commands/cli.py @@ -57,8 +57,9 @@ def main(): cwd=os.getcwd(), env=os.environ.copy(), ) - except (CalledProcessError, ChildProcessError): + except (CalledProcessError, ChildProcessError) as exc: console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.") + raise ValueError("TRL CLI failed! Check the traceback above..") from exc if __name__ == "__main__":