From cc3698cc6d92e2fda18f6be951cabea5dc3fb6bf Mon Sep 17 00:00:00 2001 From: Abe Coull <85974725+math411@users.noreply.github.com> Date: Fri, 18 Oct 2024 10:48:49 -0700 Subject: [PATCH] fix: correct typing for task results methods (#1039) --- .coveragerc | 5 ++++ src/braket/aws/aws_quantum_task_batch.py | 29 ++++++++++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/.coveragerc b/.coveragerc index fe9662c0a..22cb4caeb 100644 --- a/.coveragerc +++ b/.coveragerc @@ -32,6 +32,11 @@ exclude_lines = # Avoid situation where system version causes coverage issues if sys.version_info.minor == 9: + # Avoid type checking import conditionals + if TYPE_CHECKING: + + + [html] directory = build/coverage diff --git a/src/braket/aws/aws_quantum_task_batch.py b/src/braket/aws/aws_quantum_task_batch.py index 300963a6f..30c3a8658 100644 --- a/src/braket/aws/aws_quantum_task_batch.py +++ b/src/braket/aws/aws_quantum_task_batch.py @@ -16,7 +16,7 @@ import time from concurrent.futures.thread import ThreadPoolExecutor from itertools import repeat -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation from braket.annealing import Problem @@ -30,6 +30,14 @@ from braket.registers.qubit_set import QubitSet from braket.tasks.quantum_task_batch import QuantumTaskBatch +if TYPE_CHECKING: + from braket.tasks.analog_hamiltonian_simulation_quantum_task_result import ( + AnalogHamiltonianSimulationQuantumTaskResult, + ) + from braket.tasks.annealing_quantum_task_result import AnnealingQuantumTaskResult + from braket.tasks.gate_model_quantum_task_result import GateModelQuantumTaskResult + from braket.tasks.photonic_model_quantum_task_result import PhotonicModelQuantumTaskResult + class AwsQuantumTaskBatch(QuantumTaskBatch): """Executes a batch of quantum tasks in parallel. @@ -331,7 +339,12 @@ def results( fail_unsuccessful: bool = False, max_retries: int = MAX_RETRIES, use_cached_value: bool = True, - ) -> list[AwsQuantumTask]: + ) -> list[ + GateModelQuantumTaskResult + | AnnealingQuantumTaskResult + | PhotonicModelQuantumTaskResult + | AnalogHamiltonianSimulationQuantumTaskResult + ]: """Retrieves the result of every quantum task in the batch. Polling for results happens in parallel; this method returns when all quantum tasks @@ -348,7 +361,8 @@ def results( even when results have already been cached. Default: `True`. Returns: - list[AwsQuantumTask]: The results of all of the quantum tasks in the batch. + list[GateModelQuantumTaskResult | AnnealingQuantumTaskResult | PhotonicModelQuantumTaskResult | AnalogHamiltonianSimulationQuantumTaskResult]: The # noqa: E501 + results of all of the quantum tasks in the batch. `FAILED`, `CANCELLED`, or timed out quantum tasks will have a result of None """ if not self._results or not use_cached_value: @@ -369,7 +383,14 @@ def results( return self._results @staticmethod - def _retrieve_results(tasks: list[AwsQuantumTask], max_workers: int) -> list[AwsQuantumTask]: + def _retrieve_results( + tasks: list[AwsQuantumTask], max_workers: int + ) -> list[ + GateModelQuantumTaskResult + | AnnealingQuantumTaskResult + | PhotonicModelQuantumTaskResult + | AnalogHamiltonianSimulationQuantumTaskResult + ]: with ThreadPoolExecutor(max_workers=max_workers) as executor: result_futures = [executor.submit(task.result) for task in tasks] return [future.result() for future in result_futures]