Skip to content

Commit

Permalink
CI / CLI: Properly raise error when CLI tests failed (#1446)
Browse files Browse the repository at this point in the history
* properly raise error

* another fix

* Update tests.yml

* Update tests-main.yml
  • Loading branch information
younesbelkada authored Mar 19, 2024
1 parent f976c6d commit eb2d5b2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/transformers.git
# cpu version of pytorch
pip install -e ".[test, diffusers]"
pip install ".[test, diffusers]"
- name: Test with pytest
run: |
make test
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install -e ".[test, peft, diffusers]"
pip install ".[test, peft, diffusers]"
- name: Test with pytest
run: |
make test
Expand Down
26 changes: 16 additions & 10 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,23 @@

@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
def test_sft_cli():
subprocess.run(
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
try:
subprocess.run(
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
except BaseException as exc:
raise AssertionError("An error occured while running the CLI, please double check") from exc


@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
def test_dpo_cli():
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
try:
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
except BaseException as exc:
raise AssertionError("An error occured while running the CLI, please double check") from exc
3 changes: 2 additions & 1 deletion trl/commands/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def main():
cwd=os.getcwd(),
env=os.environ.copy(),
)
except (CalledProcessError, ChildProcessError):
except (CalledProcessError, ChildProcessError) as exc:
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
raise ValueError("TRL CLI failed! Check the traceback above..") from exc


if __name__ == "__main__":
Expand Down

0 comments on commit eb2d5b2

Please sign in to comment.