diff --git a/.github/workflows/self-new-model-pr-caller.yml b/.github/workflows/self-new-model-pr-caller.yml index a59b91df4449b9..2614a189485f76 100644 --- a/.github/workflows/self-new-model-pr-caller.yml +++ b/.github/workflows/self-new-model-pr-caller.yml @@ -43,7 +43,6 @@ jobs: # TODO: get this value from the commit message - name: Check if there are specified models run: | - echo "models=['models/bert', 'models/gpt2']" >> $GITHUB_ENV echo "models=$(python utils/check_if_new_model_added.py | tail -n 1)" >> $GITHUB_ENV # TODO: combine the values @@ -64,6 +63,16 @@ jobs: run: | echo "${{ needs.find_models_to_run.outputs.models }}" + dummy2: + runs-on: ubuntu-22.04 + name: Check specified models to test 3333444444 + needs: find_models_to_run + if: contains(fromJson(needs.find_models_to_run.outputs.matrix), 'dummy') != true + steps: + - name: Check if there are specified models + run: | + echo "${{ needs.find_models_to_run.outputs.models }}" + # # run_models_gpu: # name: Run all tests for the new model diff --git a/utils/check_if_new_model_added.py b/utils/check_if_new_model_added.py index f3ae0d585a1517..1a2730d5e18e16 100644 --- a/utils/check_if_new_model_added.py +++ b/utils/check_if_new_model_added.py @@ -82,7 +82,7 @@ def get_new_python_files() -> List[str]: return get_new_python_files_between_commits(repo.head.commit, branching_commits) -if __name__ == "__main__": +def get_new_model(): new_files = get_new_python_files() reg = re.compile(r"src/transformers/(models/.*)/modeling_.*\.py") @@ -93,4 +93,18 @@ def get_new_python_files() -> List[str]: new_model = find_new_model[0] # It's unlikely we have 2 new modeling files in a pull request. break - print(new_model) + return new_model + + +def get_models_from_commit_message(commit_message): + return ["models/bert", "models/gpt2"] + + +if __name__ == "__main__": + + new_model = get_new_model() + specified_models = get_models_from_commit_message("") + models = ([] if new_model == "" else [new_model]) + specified_models + if len(models) == 0: + models = ["dummy"] + print(models)