diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index 84311379..a7fc1bc4 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -101,6 +101,9 @@ class BraketQubitDevice(QubitDevice): execution. verbatim (bool): Whether to run tasks in verbatim mode. Note that verbatim mode only supports the native gate set of the device. Default False. + parametrize_differentiable (bool): Whether to bind differentiable parameters (parameters + marked with ``required_grad=True``) on the Braket device rather than in PennyLane. + Default: True. `**run_kwargs`: Variable length keyword arguments for ``braket.devices.Device.run()``. """ name = "Braket PennyLane plugin" @@ -116,6 +119,7 @@ def __init__( shots: Union[int, None], noise_model: Optional[NoiseModel] = None, verbatim: bool = False, + parametrize_differentiable: bool = True, **run_kwargs, ): if DeviceActionType.OPENQASM not in device.properties.action: @@ -133,6 +137,7 @@ def __init__( self._circuit = None self._task = None self._noise_model = noise_model + self._parametrize_differentiable = parametrize_differentiable self._run_kwargs = run_kwargs self._supported_ops = supported_operations(self._device, verbatim=verbatim) self._check_supported_result_types() @@ -361,7 +366,11 @@ def shadow_expval(self, obs, circuit): def execute(self, circuit: QuantumTape, compute_gradient=False, **run_kwargs) -> np.ndarray: self.check_validity(circuit.operations, circuit.observables) - trainable = BraketQubitDevice._get_trainable_parameters(circuit) if compute_gradient else {} + trainable = ( + BraketQubitDevice._get_trainable_parameters(circuit) + if compute_gradient or self._parametrize_differentiable + else {} + ) self._circuit = self._pl_to_braket_circuit( circuit, compute_gradient=compute_gradient, @@ -590,9 +599,22 @@ def batch_execute(self, circuits, **run_kwargs): for circuit in circuits: self.check_validity(circuit.operations, circuit.observables) - braket_circuits = [ - self._pl_to_braket_circuit(circuit, **run_kwargs) for circuit in circuits - ] + all_trainable = [] + braket_circuits = [] + for circuit in circuits: + trainable = ( + BraketQubitDevice._get_trainable_parameters(circuit) + if self._parametrize_differentiable + else {} + ) + all_trainable.append(trainable) + braket_circuits.append( + self._pl_to_braket_circuit( + circuit, + trainable_indices=frozenset(trainable.keys()), + **run_kwargs, + ) + ) batch_shots = 0 if self.analytic else self.shots @@ -604,6 +626,9 @@ def batch_execute(self, circuits, **run_kwargs): max_connections=self._max_connections, poll_timeout_seconds=self._poll_timeout_seconds, poll_interval_seconds=self._poll_interval_seconds, + inputs=[{f"p_{k}": v for k, v in trainable.items()} for trainable in all_trainable] + if self._parametrize_differentiable + else [], **self._run_kwargs, ) # Call results() to retrieve the Braket results in parallel. diff --git a/test/integ_tests/test_tracking.py b/test/integ_tests/test_tracking.py index c5504f9a..62e77e43 100644 --- a/test/integ_tests/test_tracking.py +++ b/test/integ_tests/test_tracking.py @@ -65,7 +65,7 @@ def circuit(x): assert len(tracker.history["braket_task_id"]) == 3 - if type(dev) == BraketAwsQubitDevice: + if isinstance(dev, BraketAwsQubitDevice): durations = tracker.history["braket_simulator_ms"] billed_durations = tracker.history["braket_simulator_billed_ms"] assert len(durations) == 3 diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index 53423b69..61bfb419 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -350,9 +350,9 @@ def test_execute(mock_run): @patch.object(AwsDevice, "run") -def test_execute_legacy(mock_run): +def test_execute_parametrize_differentiable(mock_run): mock_run.return_value = TASK - dev = _aws_device(wires=4, foo="bar") + dev = _aws_device(wires=4, parametrize_differentiable=True, foo="bar") with QuantumTape() as circuit: qml.Hadamard(wires=0) @@ -364,10 +364,6 @@ def test_execute_legacy(mock_run): qml.var(qml.PauliY(2)) qml.sample(qml.PauliZ(3)) - # If the tape is constructed with a QNode, only the parameters marked requires_grad=True - # will appear - circuit._trainable_params = [0] - results = dev._execute_legacy(circuit) assert np.allclose( @@ -392,7 +388,9 @@ def test_execute_legacy(mock_run): Circuit() .h(0) .unitary([0], 1 / np.sqrt(2) * np.array([[1, 1], [1, -1]])) - .rx(0, 0.432) + # When using QuantumTape directly (as opposed to a QNode), + # all parameters are automatically considered differentiable + .rx(0, FreeParameter("p_1")) .cnot(0, 1) .i(2) .i(3) @@ -408,7 +406,7 @@ def test_execute_legacy(mock_run): poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, foo="bar", - inputs={}, + inputs={"p_1": 0.432}, ) @@ -458,17 +456,17 @@ def test_execute_legacy(mock_run): ) CIRCUIT_4.trainable_params = [] -PARAMS_5 = np.array([0.432, 0.543], requires_grad=True) +PARAM_5 = np.tensor(0.543, requires_grad=True) CIRCUIT_5 = QuantumScript( ops=[ qml.Hadamard(wires=0), qml.CNOT(wires=[0, 1]), - qml.RX(PARAMS_5[0], wires=0), - qml.RY(PARAMS_5[1], wires=0), + qml.RX(0.432, wires=0), + qml.RY(PARAM_5, wires=0), ], measurements=[qml.var(qml.PauliX(0) @ qml.PauliY(1))], ) -CIRCUIT_5.trainable_params = [0, 1] +CIRCUIT_5.trainable_params = [1] PARAM_6 = np.tensor(0.432, requires_grad=True) CIRCUIT_6 = QuantumScript( @@ -1067,6 +1065,7 @@ def test_batch_execute_parallel(mock_run_batch): max_connections=AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT, poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, + inputs=[], foo="bar", ) @@ -1159,6 +1158,61 @@ def test_batch_execute_partial_fail_parallel_tracker(mock_run_batch): callback.assert_called_with(latest=latest, history=history, totals=totals) +@patch.object(AwsDevice, "run_batch") +def test_batch_execute_parametrize_differentiable(mock_run_batch): + """Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK""" + mock_run_batch.return_value = TASK_BATCH + dev = _aws_device(wires=4, foo="bar", parametrize_differentiable=True, parallel=True) + + with QuantumTape() as circuit1: + qml.Hadamard(wires=0) + qml.QubitUnitary(1 / np.sqrt(2) * np.tensor([[1, 1], [1, -1]], requires_grad=True), wires=0) + qml.RX(0.432, wires=0) + qml.CNOT(wires=[0, 1]) + qml.expval(qml.PauliX(1)) + + with QuantumTape() as circuit2: + qml.Hadamard(wires=0) + qml.RX(0.123, wires=0) + qml.CNOT(wires=[0, 1]) + qml.sample(qml.PauliZ(3)) + + expected_1 = ( + Circuit() + .h(0) + .unitary([0], 1 / np.sqrt(2) * np.array([[1, 1], [1, -1]])) + .rx(0, FreeParameter("p_1")) + .cnot(0, 1) + .i(2) + .i(3) + .expectation(observable=Observable.X(), target=1) + ) + + expected_2 = ( + Circuit() + .h(0) + .rx(0, FreeParameter("p_0")) + .cnot(0, 1) + .i(2) + .i(3) + .sample(observable=Observable.Z(), target=3) + ) + + circuits = [circuit1, circuit2] + dev.batch_execute(circuits) + mock_run_batch.assert_called_with( + [expected_1, expected_2], + s3_destination_folder=("foo", "bar"), + shots=SHOTS, + max_parallel=None, + max_connections=AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT, + poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, + inputs=[{"p_1": 0.432}, {"p_0": 0.123}], + foo="bar", + ) + + @pytest.mark.parametrize("old_return_type", [True, False]) @patch.object(AwsDevice, "run") def test_execute_all_samples(mock_run, old_return_type): @@ -1803,6 +1857,7 @@ def _aws_device( device_arn="baz", action_properties=ACTION_PROPERTIES, native_gate_set=None, + parametrize_differentiable=False, **kwargs, ): properties_mock.action = {DeviceActionType.OPENQASM: action_properties} @@ -1821,6 +1876,7 @@ def _aws_device( device_arn=device_arn, aws_session=aws_session_mock, shots=shots, + parametrize_differentiable=parametrize_differentiable, **kwargs, ) # needed by the BraketAwsQubitDevice.capabilities function @@ -2032,6 +2088,7 @@ def test_batch_execute_with_noise_model( poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, max_connections=100, max_parallel=None, + inputs=[], )