Skip to content

Commit

Permalink
feat: Enable FreeParameter creation for trainable parameters (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 authored Aug 18, 2023
1 parent f5ebace commit a6a7145
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 17 deletions.
33 changes: 29 additions & 4 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class BraketQubitDevice(QubitDevice):
execution.
verbatim (bool): Whether to run tasks in verbatim mode. Note that verbatim mode only
supports the native gate set of the device. Default False.
parametrize_differentiable (bool): Whether to bind differentiable parameters (parameters
marked with ``required_grad=True``) on the Braket device rather than in PennyLane.
Default: True.
`**run_kwargs`: Variable length keyword arguments for ``braket.devices.Device.run()``.
"""
name = "Braket PennyLane plugin"
Expand All @@ -116,6 +119,7 @@ def __init__(
shots: Union[int, None],
noise_model: Optional[NoiseModel] = None,
verbatim: bool = False,
parametrize_differentiable: bool = True,
**run_kwargs,
):
if DeviceActionType.OPENQASM not in device.properties.action:
Expand All @@ -133,6 +137,7 @@ def __init__(
self._circuit = None
self._task = None
self._noise_model = noise_model
self._parametrize_differentiable = parametrize_differentiable
self._run_kwargs = run_kwargs
self._supported_ops = supported_operations(self._device, verbatim=verbatim)
self._check_supported_result_types()
Expand Down Expand Up @@ -361,7 +366,11 @@ def shadow_expval(self, obs, circuit):

def execute(self, circuit: QuantumTape, compute_gradient=False, **run_kwargs) -> np.ndarray:
self.check_validity(circuit.operations, circuit.observables)
trainable = BraketQubitDevice._get_trainable_parameters(circuit) if compute_gradient else {}
trainable = (
BraketQubitDevice._get_trainable_parameters(circuit)
if compute_gradient or self._parametrize_differentiable
else {}
)
self._circuit = self._pl_to_braket_circuit(
circuit,
compute_gradient=compute_gradient,
Expand Down Expand Up @@ -590,9 +599,22 @@ def batch_execute(self, circuits, **run_kwargs):

for circuit in circuits:
self.check_validity(circuit.operations, circuit.observables)
braket_circuits = [
self._pl_to_braket_circuit(circuit, **run_kwargs) for circuit in circuits
]
all_trainable = []
braket_circuits = []
for circuit in circuits:
trainable = (
BraketQubitDevice._get_trainable_parameters(circuit)
if self._parametrize_differentiable
else {}
)
all_trainable.append(trainable)
braket_circuits.append(
self._pl_to_braket_circuit(
circuit,
trainable_indices=frozenset(trainable.keys()),
**run_kwargs,
)
)

batch_shots = 0 if self.analytic else self.shots

Expand All @@ -604,6 +626,9 @@ def batch_execute(self, circuits, **run_kwargs):
max_connections=self._max_connections,
poll_timeout_seconds=self._poll_timeout_seconds,
poll_interval_seconds=self._poll_interval_seconds,
inputs=[{f"p_{k}": v for k, v in trainable.items()} for trainable in all_trainable]
if self._parametrize_differentiable
else [],
**self._run_kwargs,
)
# Call results() to retrieve the Braket results in parallel.
Expand Down
2 changes: 1 addition & 1 deletion test/integ_tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def circuit(x):

assert len(tracker.history["braket_task_id"]) == 3

if type(dev) == BraketAwsQubitDevice:
if isinstance(dev, BraketAwsQubitDevice):
durations = tracker.history["braket_simulator_ms"]
billed_durations = tracker.history["braket_simulator_billed_ms"]
assert len(durations) == 3
Expand Down
81 changes: 69 additions & 12 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ def test_execute(mock_run):


@patch.object(AwsDevice, "run")
def test_execute_legacy(mock_run):
def test_execute_parametrize_differentiable(mock_run):
mock_run.return_value = TASK
dev = _aws_device(wires=4, foo="bar")
dev = _aws_device(wires=4, parametrize_differentiable=True, foo="bar")

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
Expand All @@ -364,10 +364,6 @@ def test_execute_legacy(mock_run):
qml.var(qml.PauliY(2))
qml.sample(qml.PauliZ(3))

# If the tape is constructed with a QNode, only the parameters marked requires_grad=True
# will appear
circuit._trainable_params = [0]

results = dev._execute_legacy(circuit)

assert np.allclose(
Expand All @@ -392,7 +388,9 @@ def test_execute_legacy(mock_run):
Circuit()
.h(0)
.unitary([0], 1 / np.sqrt(2) * np.array([[1, 1], [1, -1]]))
.rx(0, 0.432)
# When using QuantumTape directly (as opposed to a QNode),
# all parameters are automatically considered differentiable
.rx(0, FreeParameter("p_1"))
.cnot(0, 1)
.i(2)
.i(3)
Expand All @@ -408,7 +406,7 @@ def test_execute_legacy(mock_run):
poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
foo="bar",
inputs={},
inputs={"p_1": 0.432},
)


Expand Down Expand Up @@ -458,17 +456,17 @@ def test_execute_legacy(mock_run):
)
CIRCUIT_4.trainable_params = []

PARAMS_5 = np.array([0.432, 0.543], requires_grad=True)
PARAM_5 = np.tensor(0.543, requires_grad=True)
CIRCUIT_5 = QuantumScript(
ops=[
qml.Hadamard(wires=0),
qml.CNOT(wires=[0, 1]),
qml.RX(PARAMS_5[0], wires=0),
qml.RY(PARAMS_5[1], wires=0),
qml.RX(0.432, wires=0),
qml.RY(PARAM_5, wires=0),
],
measurements=[qml.var(qml.PauliX(0) @ qml.PauliY(1))],
)
CIRCUIT_5.trainable_params = [0, 1]
CIRCUIT_5.trainable_params = [1]

PARAM_6 = np.tensor(0.432, requires_grad=True)
CIRCUIT_6 = QuantumScript(
Expand Down Expand Up @@ -1067,6 +1065,7 @@ def test_batch_execute_parallel(mock_run_batch):
max_connections=AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT,
poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
inputs=[],
foo="bar",
)

Expand Down Expand Up @@ -1159,6 +1158,61 @@ def test_batch_execute_partial_fail_parallel_tracker(mock_run_batch):
callback.assert_called_with(latest=latest, history=history, totals=totals)


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parametrize_differentiable(mock_run_batch):
"""Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK"""
mock_run_batch.return_value = TASK_BATCH
dev = _aws_device(wires=4, foo="bar", parametrize_differentiable=True, parallel=True)

with QuantumTape() as circuit1:
qml.Hadamard(wires=0)
qml.QubitUnitary(1 / np.sqrt(2) * np.tensor([[1, 1], [1, -1]], requires_grad=True), wires=0)
qml.RX(0.432, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliX(1))

with QuantumTape() as circuit2:
qml.Hadamard(wires=0)
qml.RX(0.123, wires=0)
qml.CNOT(wires=[0, 1])
qml.sample(qml.PauliZ(3))

expected_1 = (
Circuit()
.h(0)
.unitary([0], 1 / np.sqrt(2) * np.array([[1, 1], [1, -1]]))
.rx(0, FreeParameter("p_1"))
.cnot(0, 1)
.i(2)
.i(3)
.expectation(observable=Observable.X(), target=1)
)

expected_2 = (
Circuit()
.h(0)
.rx(0, FreeParameter("p_0"))
.cnot(0, 1)
.i(2)
.i(3)
.sample(observable=Observable.Z(), target=3)
)

circuits = [circuit1, circuit2]
dev.batch_execute(circuits)
mock_run_batch.assert_called_with(
[expected_1, expected_2],
s3_destination_folder=("foo", "bar"),
shots=SHOTS,
max_parallel=None,
max_connections=AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT,
poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
inputs=[{"p_1": 0.432}, {"p_0": 0.123}],
foo="bar",
)


@pytest.mark.parametrize("old_return_type", [True, False])
@patch.object(AwsDevice, "run")
def test_execute_all_samples(mock_run, old_return_type):
Expand Down Expand Up @@ -1803,6 +1857,7 @@ def _aws_device(
device_arn="baz",
action_properties=ACTION_PROPERTIES,
native_gate_set=None,
parametrize_differentiable=False,
**kwargs,
):
properties_mock.action = {DeviceActionType.OPENQASM: action_properties}
Expand All @@ -1821,6 +1876,7 @@ def _aws_device(
device_arn=device_arn,
aws_session=aws_session_mock,
shots=shots,
parametrize_differentiable=parametrize_differentiable,
**kwargs,
)
# needed by the BraketAwsQubitDevice.capabilities function
Expand Down Expand Up @@ -2032,6 +2088,7 @@ def test_batch_execute_with_noise_model(
poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
max_connections=100,
max_parallel=None,
inputs=[],
)


Expand Down

0 comments on commit a6a7145

Please sign in to comment.