Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/qasm convert conditional rangepredicate #1645

Merged
merged 8 commits into from
Oct 31, 2024
1 change: 1 addition & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Fixes:

* Fix `symbol_substitution` not preserving opgroups.
* Remove hardware inefficient circuit construction in `_tk1_to_rzsx`
* Support converting conditional `RangePredicate`s to QASM.

1.34.0 (October 2024)
---------------------
Expand Down
67 changes: 51 additions & 16 deletions pytket/pytket/qasm/qasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,25 +1525,27 @@ def mark_as_written(self, label: int, written_variable: str) -> None:
else:
self.variable_writes[label] = [written_variable]

def add_range_predicate(self, op: RangePredicateOp, args: List[Bit]) -> None:
comparator, value = _parse_range(op.lower, op.upper, self.maxwidth)
if (not hqs_header(self.header)) and comparator != "==":
def check_range_predicate(self, op: RangePredicateOp, args: List[Bit]) -> None:
if (not hqs_header(self.header)) and op.lower != op.upper:
raise QASMUnsupportedError(
"OpenQASM conditions must be on a register's fixed value."
)
bits = args[:-1]
variable = args[0].reg_name
assert isinstance(variable, str)
if op.n_inputs != self.cregs[variable].size:
raise QASMUnsupportedError(
"RangePredicate conditions must be an entire classical register"
)
if args[:-1] != self.cregs[variable].to_list():
raise QASMUnsupportedError(
"RangePredicate conditions must be a single classical register"
)

def add_range_predicate(self, op: RangePredicateOp, args: List[Bit]) -> None:
self.check_range_predicate(op, args)
comparator, value = _parse_range(op.lower, op.upper, self.maxwidth)
variable = args[0].reg_name
dest_bit = str(args[-1])
if not hqs_header(self.header):
assert isinstance(variable, str)
if op.n_inputs != self.cregs[variable].size:
raise QASMUnsupportedError(
"OpenQASM conditions must be an entire classical register"
)
if bits != self.cregs[variable].to_list():
raise QASMUnsupportedError(
"OpenQASM conditions must be a single classical register"
)
label = self.strings.add_string(
"".join(
[
Expand Down Expand Up @@ -1660,9 +1662,42 @@ def add_conditional(self, op: Conditional, args: Sequence[UnitID]) -> None:
# Conditional phase is ignored.
return
if op.op.type == OpType.RangePredicate:
raise QASMUnsupportedError(
"Conditional RangePredicate is currently unsupported."
# Special handling for nested ifs
# if condition
# if pred dest = 1
# if not pred dest = 0
# can be written as
# if condition s0 = 1
# if pred s1 = 1
# s2 = s0 & s1
# s3 = s0 & ~s1
# if s2 dest = 1
# if s3 dest = 0
# where s0, s1, s2, and s3 are scratch bits
s0 = self.fresh_scratch_bit()
l = self.strings.add_string(f"{s0} = 1;\n")
# we store the condition in self.strings.conditions
# as it can be later replaced by `replace_condition`
# if possible
self.strings.conditions[l] = ConditionString(variable, "==", op.value)
# output the RangePredicate to s1
s1 = self.fresh_scratch_bit()
assert isinstance(op.op, RangePredicateOp)
self.check_range_predicate(op.op, cast(List[Bit], args[op.width :]))
pred_comparator, pred_value = _parse_range(
op.op.lower, op.op.upper, self.maxwidth
)
pred_variable = args[op.width :][0].reg_name
self.strings.add_string(
f"if({pred_variable}{pred_comparator}{pred_value}) {s1} = 1;\n"
)
s2 = self.fresh_scratch_bit()
self.strings.add_string(f"{s2} = {s0} & {s1};\n")
s3 = self.fresh_scratch_bit()
self.strings.add_string(f"{s3} = {s0} & (~ {s1});\n")
self.strings.add_string(f"if({s2}==1) {args[-1]} = 1;\n")
self.strings.add_string(f"if({s3}==1) {args[-1]} = 0;\n")
return
# we assign the condition to a scratch bit, which we will later remove
# if the condition variable is unchanged.
scratch_bit = self.fresh_scratch_bit()
Expand Down
60 changes: 53 additions & 7 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,16 +1026,62 @@ def test_conditional_multi_line_ops() -> None:


def test_conditional_range_predicate() -> None:
range_predicate = RangePredicateOp(6, 0, 27)
c = Circuit(0, 8)
c.add_gate(range_predicate, [0, 1, 2, 3, 4, 5, 6], condition=Bit(7))
# remove once https://github.com/CQCL/tket/issues/1508
# is resolved
range_predicate = RangePredicateOp(2, 0, 2)
c = Circuit(0, 5)
c.add_gate(range_predicate, [1, 2, 4])
# https://github.com/CQCL/tket/issues/1642
with pytest.raises(Exception) as errorinfo:
circuit_to_qasm_str(c, header="hqslib1")
assert "Conditional RangePredicate is currently unsupported." in str(
qasm = circuit_to_qasm_str(c, header="hqslib1")
assert "RangePredicate conditions must be an entire classical register" in str(
errorinfo.value
)
# https://github.com/CQCL/tket/issues/1508
range_predicate = RangePredicateOp(6, 0, 27)
c = Circuit(0, 6)
c.add_gate(range_predicate, [0, 1, 2, 3, 4, 5, 5], condition=Bit(5))
qasm = circuit_to_qasm_str(c, header="hqslib1")
assert (
qasm
== """OPENQASM 2.0;
include "hqslib1.inc";

creg c[6];
creg tk_SCRATCH_BITREG_0[4];
if(c[5]==1) tk_SCRATCH_BITREG_0[0] = 1;
if(c<=27) tk_SCRATCH_BITREG_0[1] = 1;
tk_SCRATCH_BITREG_0[2] = tk_SCRATCH_BITREG_0[0] & tk_SCRATCH_BITREG_0[1];
tk_SCRATCH_BITREG_0[3] = tk_SCRATCH_BITREG_0[0] & (~ tk_SCRATCH_BITREG_0[1]);
if(tk_SCRATCH_BITREG_0[2]==1) c[5] = 1;
if(tk_SCRATCH_BITREG_0[3]==1) c[5] = 0;
"""
)
# more test
range_predicate = RangePredicateOp(2, 0, 2)
c = Circuit()
reg_a = c.add_c_register("a", 2)
reg_b = c.add_c_register("b", 2)
reg_d = c.add_c_register("d", 1)
c.add_gate(
range_predicate, reg_a.to_list() + reg_d.to_list(), condition=reg_gt(reg_b, 1)
)
qasm = circuit_to_qasm_str(c, header="hqslib1")
assert (
qasm
== """OPENQASM 2.0;
include "hqslib1.inc";

creg a[2];
creg b[2];
creg d[1];
creg tk_SCRATCH_BITREG_0[4];
if(b>=2) tk_SCRATCH_BITREG_0[0] = 1;
if(a<=2) tk_SCRATCH_BITREG_0[1] = 1;
tk_SCRATCH_BITREG_0[2] = tk_SCRATCH_BITREG_0[0] & tk_SCRATCH_BITREG_0[1];
tk_SCRATCH_BITREG_0[3] = tk_SCRATCH_BITREG_0[0] & (~ tk_SCRATCH_BITREG_0[1]);
if(tk_SCRATCH_BITREG_0[2]==1) d[0] = 1;
if(tk_SCRATCH_BITREG_0[3]==1) d[0] = 0;
"""
)


def test_range_with_maxwidth() -> None:
Expand Down
Loading