diff --git a/tools/tpcds.py b/tools/tpcds.py index 282e2e675b..407a4ed7b9 100644 --- a/tools/tpcds.py +++ b/tools/tpcds.py @@ -25,6 +25,20 @@ WHEEL_NAME = "getdaft-0.3.0.dev0-cp38-abi3-manylinux_2_31_x86_64.whl" RETRY_ATTEMPTS = 5 +auth = Auth.Token(gha_run_cluster_job.get_oauth_token()) +g = Github(auth=auth) +repo = g.get_repo("Eventual-Inc/Daft") + + +def dispatch(workflow: Workflow, branch_name: str, inputs: dict): + print(f"Launching workflow '{workflow.name}' on the branch '{branch_name}' with the inputs '{inputs}'") + created = workflow.create_dispatch( + ref=branch_name, + inputs=inputs, + ) + if not created: + raise RuntimeError("Could not create workflow, suggestion: run again with --verbose") + def sleep_and_then_retry(sleep_amount_sec: int = 3): time.sleep(sleep_amount_sec) @@ -55,10 +69,6 @@ def get_name_and_commit_hash(branch_name: Optional[str]) -> tuple[str, str]: return name, commit_hash -auth = Auth.Token(gha_run_cluster_job.get_oauth_token()) -g = Github(auth=auth) - - def run_build( branch_name: str, commit_hash: str, @@ -70,19 +80,11 @@ def run_build( print("Workflow aborted") exit(1) - repo = g.get_repo("Eventual-Inc/Daft") workflow = repo.get_workflow("build-commit.yaml") pre_creation_latest_run = get_latest_run(workflow) - inputs = {"arch": "x86"} - print(f"Launching new 'build-commit' workflow with the following inputs: {inputs}") - created = workflow.create_dispatch( - ref=branch_name, - inputs=inputs, - ) - if not created: - raise RuntimeError("Could not create workflow, suggestion: run again with --verbose") + dispatch(workflow=workflow, branch_name=branch_name, inputs={"arch": "x86"}) post_creation_latest_run = None for _ in range(RETRY_ATTEMPTS): @@ -99,7 +101,7 @@ def run_build( if not post_creation_latest_run: raise RuntimeError("Unable to locate the new run request for the 'build-commit' workflow") - print(f"Latest 'build-commit' workflow run found with id: {post_creation_latest_run.id}") + print(f"Launched new 'build-commit' workflow with id: {post_creation_latest_run.id}") print(f"View the workflow run at: {post_creation_latest_run.url}") while True: @@ -139,7 +141,6 @@ def find_wheel(commit_hash: str) -> Optional[str]: if "Key" in wheel: wheel_path = Path(wheel["Key"]) wheel_name = wheel_path.name - print(wheel_name) if wheel_name == WHEEL_NAME: wheel_urls.append( f"https://github-actions-artifacts-bucket.s3.us-west-2.amazonaws.com/builds/{commit_hash}/{wheel_name}" @@ -151,26 +152,49 @@ def find_wheel(commit_hash: str) -> Optional[str]: return wheel_urls[0] if wheel_urls else None -def build(branch_name: Optional[str], force: bool): +def build(branch_name: str, commit_hash: str, force: bool) -> str: """Runs a build on the given branch. If the branch has already been built, it will reuse the already built wheel. """ - branch_name, commit_hash = get_name_and_commit_hash(branch_name) - print(f"Checking if a build exists for the branch '{branch_name}' (commit-hash: {commit_hash})") wheel_url = find_wheel(commit_hash) - if wheel_url: - if force: - wheel_url = run_build(branch_name, commit_hash) - else: - print(f"Wheel already found at url {wheel_url}; re-using") - else: + should_build = force or wheel_url is None + if should_build: wheel_url = run_build(branch_name, commit_hash) + else: + # wheel_url must be non-None if this branch is executed + print(f"Wheel already found at url {wheel_url}; re-using") + + return wheel_url + - print(wheel_url) +def run( + wheel_url: str, + branch_name: str, +): + workflow = repo.get_workflow("run-cluster.yaml") + dispatch( + workflow=workflow, + branch_name=branch_name, + inputs={ + "daft_wheel_url": wheel_url, + "working_dir": "benchmarking/tpcds", + "entrypoint_script": "ray_entrypoint.py", + "entrypoint_args": "--question=3 --scale-factor=100", + }, + ) + + +def main( + branch_name: Optional[str], + force: bool, +): + branch_name, commit_hash = get_name_and_commit_hash(branch_name) + wheel_url = build(branch_name=branch_name, commit_hash=commit_hash, force=force) + run(wheel_url=wheel_url, branch_name=branch_name) if __name__ == "__main__": @@ -183,7 +207,7 @@ def build(branch_name: Optional[str], force: bool): if args.verbose: enable_console_debug_logging() - build( + main( branch_name=args.ref, force=args.force, )