Skip to content

Commit

Permalink
[Bug] Fix GPSR differentiation mode (#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
vytautas-a authored Nov 19, 2024
1 parent 4655065 commit 66219fb
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
3 changes: 1 addition & 2 deletions qadence/analog/parse_analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ def _build_rot_ham_evo(
if block.add_pattern and h_addr is not None:
h_block += h_addr
duration = block.parameters.duration
h_norm = block.parameters.h_norm
h_block += h_drive
return HamEvo(h_block / h_norm, duration * h_norm / 1000)
return HamEvo(h_block, duration / 1000)


def _analog_to_hevo(
Expand Down
10 changes: 8 additions & 2 deletions qadence/backends/gpsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def general_psr(spectrum: Tensor, n_eqs: int | None = None, shift_prefac: float
sorted_unique_spectral_gaps = torch.tensor(list(sorted_unique_spectral_gaps)[:n_eqs])

if n_eqs == 1:
return single_gap_psr
return partial(
single_gap_psr,
spectral_gap=sorted_unique_spectral_gaps,
shift=shift_prefac * torch.tensor([PI / 2], dtype=torch.get_default_dtype()),
)
else:
return partial(
multi_gap_psr,
Expand Down Expand Up @@ -110,7 +114,9 @@ def multi_gap_psr(
batch_size = max(t.size(0) for t in param_dict.values())

# get shift values
shifts = shift_prefac * torch.linspace(PI / 2 - PI / 5, PI / 2 + PI / 5, n_eqs)
shifts = shift_prefac * torch.linspace(
PI / 2 - PI / 4, PI / 2 + PI / 5, n_eqs
) # breaking the symmetry of sampling range around PI/2
device = torch.device("cpu")
try:
device = [v.device for v in param_dict.values()][0]
Expand Down
2 changes: 0 additions & 2 deletions qadence/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@ def _(
configuration: Union[BackendConfiguration, dict, None] = None,
) -> Tensor:
observable = observable if isinstance(observable, list) else [observable]
if backend == BackendName.PYQTORCH:
diff_mode = DiffMode.AD
bknd = backend_factory(backend, diff_mode=diff_mode, configuration=configuration)
conv = bknd.convert(circuit, observable)

Expand Down
2 changes: 1 addition & 1 deletion tests/backends/pulser_basic/test_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_pulser_gpsr(block_id: int) -> None:
pulser_backend = PulserBackend() # type: ignore[arg-type]
conv = pulser_backend.convert(circ, obs)
pulser_circ, pulser_obs, embedding_fn, params = conv
diff_backend = DifferentiableBackend(pulser_backend, diff_mode=DiffMode.GPSR, shift_prefac=0.2)
diff_backend = DifferentiableBackend(pulser_backend, diff_mode=DiffMode.GPSR, shift_prefac=1.0)
expval_pulser = diff_backend.expectation(pulser_circ, pulser_obs, embedding_fn(params, values))
dexpval_x_pulser = torch.autograd.grad(
expval_pulser, values["x"], torch.ones_like(expval_pulser), create_graph=True
Expand Down

0 comments on commit 66219fb

Please sign in to comment.