diff --git a/pyqtorch/composite/compose.py b/pyqtorch/composite/compose.py index 925ac1d6..a1515a27 100644 --- a/pyqtorch/composite/compose.py +++ b/pyqtorch/composite/compose.py @@ -9,11 +9,12 @@ from torch import Tensor, einsum, rand from torch.nn import Module, ModuleList, ParameterDict -from pyqtorch.apply import apply_operator +from pyqtorch.apply import apply_operator, apply_operator_dm from pyqtorch.embed import ConcretizedCallable, Embedding from pyqtorch.matrices import add_batch_dim from pyqtorch.primitives import CNOT, RX, RY, Parametric, Primitive from pyqtorch.utils import ( + DensityMatrix, Operator, State, ) @@ -233,6 +234,8 @@ def __init__( f"Require all operations to act on a single qubit. Got: {operations}." ) + self._contains_noise = sum([op.noise is not None for op in self.operations]) + def forward( self, state: Tensor, @@ -253,6 +256,18 @@ def forward( ) ), ) + + if self._contains_noise: + # noisy cannot use merged in tensors, fall back to super forward + return super().forward(state, values, embedding) + + if isinstance(state, DensityMatrix): + return apply_operator_dm( + state, + add_batch_dim(self.tensor(values, embedding), batch_size), + self.qubits, + ) + return apply_operator( state, add_batch_dim(self.tensor(values, embedding), batch_size), diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 26095677..7ec033f0 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -7,7 +7,10 @@ import pyqtorch as pyq from pyqtorch import run, sample +from pyqtorch.noise import DigitalNoiseProtocol, DigitalNoiseType from pyqtorch.utils import ( + DensityMatrix, + density_mat, product_state, ) @@ -55,6 +58,35 @@ def test_merge() -> None: values = {f"theta_{i}": torch.rand(1) for i in range(3)} assert torch.allclose(circ(state, values), mergecirc(state, values)) + # test with density matrices + state = density_mat(state) + circ_out = circ(state, values) + mergecirc_out = mergecirc(state, values) + assert isinstance(circ_out, DensityMatrix) + assert isinstance(mergecirc_out, DensityMatrix) + assert torch.allclose(circ_out, mergecirc_out) + + +def test_merge_noisy_op() -> None: + ops = [ + pyq.RX(0, "theta_0"), + pyq.RY( + 0, + "theta_1", + noise=DigitalNoiseProtocol(DigitalNoiseType.DEPOLARIZING, 0.1, 0), + ), + pyq.RX(0, "theta_2"), + ] + circ = pyq.QuantumCircuit(2, ops) + mergecirc = pyq.Merge(ops) + state = pyq.random_state(2) + values = {f"theta_{i}": torch.rand(1) for i in range(3)} + circ_out = circ(state, values) + mergecirc_out = mergecirc(state, values) + assert isinstance(circ_out, DensityMatrix) + assert isinstance(mergecirc_out, DensityMatrix) + assert torch.allclose(circ_out, mergecirc_out) + @pytest.mark.xfail( reason="Can only merge single qubit gates acting on the same qubit support."