Skip to content

Commit

Permalink
visualization bug + better error message for emulator code generation. (
Browse files Browse the repository at this point in the history
#996)

* visualization bug + better error message for emulator code generation.

* adding back test for location ir.
  • Loading branch information
weinbe58 authored Nov 12, 2024
1 parent fa41f6d commit c325957
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 8 deletions.
13 changes: 12 additions & 1 deletion src/bloqade/compiler/codegen/python/emulator_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ def visit_field_RunTimeVector(
self, node: field.RunTimeVector
) -> Dict[int, Decimal]:
value = self.assignments[node.name]
for new_index, original_index in enumerate(self.original_index):
if original_index >= len(value):
raise ValueError(
f"Index {original_index} is out of bounds for the runtime vector {node.name}"
)

return {
new_index: Decimal(str(value[original_index]))
for new_index, original_index in enumerate(self.original_index)
Expand All @@ -347,6 +353,12 @@ def visit_field_RunTimeVector(
def visit_field_AssignedRunTimeVector(
self, node: field.AssignedRunTimeVector
) -> Dict[int, Decimal]:
for new_index, original_index in enumerate(self.original_index):
if original_index >= len(node.value):
raise ValueError(
f"Index {original_index} is out of bounds for the mask vector."
)

return {
new_index: Decimal(str(node.value[original_index]))
for new_index, original_index in enumerate(self.original_index)
Expand All @@ -357,7 +369,6 @@ def visit_field_ScaledLocations(
self, node: field.ScaledLocations
) -> Dict[int, Decimal]:
target_atoms = {}

for location in node.value.keys():
if location.value >= self.n_sites or location.value < 0:
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions src/bloqade/ir/analog_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def figure(self, **assignments):
# analysis the SpatialModulation information
spmod_extracted_data: Dict[str, Tuple[List[int], List[float]]] = {}

def process_names(x):
return int(x.split("[")[-1].split("]")[0])

for tab in fig_seq.tabs:
pulse_name = tab.title
field_plots = tab.child.children
Expand All @@ -101,9 +104,7 @@ def figure(self, **assignments):
for ch in channels:
ch_data = Spmod_raw_data[Spmod_raw_data.d0 == ch]

sites = list(
map(lambda x: int(x.split("[")[-1].split("]")[0]), ch_data.d1)
)
sites = list(map(process_names, ch_data.d1))
values = list(ch_data.px.astype(float))

key = f"{pulse_name}.{field_name}.{ch}"
Expand Down
8 changes: 7 additions & 1 deletion src/bloqade/ir/control/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ def figure(self, **assginment):
return get_ir_figure(self, **assginment)

def _get_data(self, **assignment):
return [self.name], ["vec"]
locs = []
values = []
for i, v in enumerate(self.value):
locs.append(f"{self.name or 'value'}[{i}]")
values.append(str(v))

return locs, values

def show(self, **assignment):
display_ir(self, **assignment)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_metadata_filter_scalar():

assert filtered_batch.tasks.keys() == {0, 1, 4}

with pytest.raises(ValueError):
with pytest.raises(Exception):
filtered_batch = batch.filter_metadata(d=[1, 2, 16, 1j])


Expand All @@ -198,7 +198,7 @@ def test_metadata_filter_vector():

filters = dict(d=[1, 8], m=[[0, 1], [1, 0], (0, 0)])

with pytest.raises(ValueError):
with pytest.raises(Exception):
filtered_batch_all = batch.filter_metadata(**filters)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_assigned_runtime_vec():
)
assert x.print_node() == "AssignedRunTimeVector: sss"
assert x.children() == cast([Decimal("1.0"), Decimal("2.0")])
assert x._get_data() == (["sss"], ["vec"])
assert x._get_data() == (["sss[0]", "sss[1]"], ["1.0", "2.0"])

mystdout = StringIO()
p = PP(mystdout)
Expand Down

0 comments on commit c325957

Please sign in to comment.