Skip to content

Commit

Permalink
Figure handling
Browse files Browse the repository at this point in the history
  • Loading branch information
KilianPoirier committed Jun 19, 2024
1 parent d449fcc commit b42852d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,12 @@ def empty(cls, qaoa_descriptor: QAOADescriptor, total_annealing_time: float):
def plot(self, ax=None, **kwargs):
if ax is None:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()

ax.plot(self.schedule, marker="s", **kwargs)
ax.set_xlabel("p", fontsize=14)
ax.set_ylabel("s(t)", fontsize=14)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

return fig, ax
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ def plot(self, ax=None, **kwargs):

if ax is None:
fig, ax = plt.subplots((n + 1) // 2, 2, figsize=(9, 9 if n > 2 else 5))
else:
fig = ax.get_figure()

fig.tight_layout(pad=4.0)

Expand Down Expand Up @@ -342,3 +344,5 @@ def plot(self, ax=None, **kwargs):
ax[1].axis("off")
elif k == 2:
ax[1, 1].axis("off")

return fig, ax
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def plot(self, ax=None, **kwargs):

if ax is None:
fig, ax = plt.subplots(2, figsize=(7, 9))
else:
fig = ax.get_figure()

fig.tight_layout(pad=4.0)

Expand All @@ -213,6 +215,8 @@ def plot(self, ax=None, **kwargs):
ax[1].legend()
ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))

return fig, ax


class QAOAVariationalFourierWithBiasParams(QAOAVariationalBaseParams):
"""
Expand Down Expand Up @@ -425,6 +429,8 @@ def plot(self, ax=None, **kwargs):
# "params.u_singles, params.u_pairs")
if ax is None:
fig, ax = plt.subplots(2, figsize=(7, 9))
else:
fig = ax.get_figure()

fig.tight_layout(pad=4.0)

Expand Down Expand Up @@ -456,6 +462,8 @@ def plot(self, ax=None, **kwargs):
ax[1].legend()
ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))

return fig, ax


class QAOAVariationalFourierExtendedParams(QAOAVariationalBaseParams):
r"""
Expand Down Expand Up @@ -792,3 +800,5 @@ def plot(self, ax=None, **kwargs):

if j == 0:
ax[i, j + 1].axis("off")

return fig, ax
3 changes: 2 additions & 1 deletion src/openqaoa-core/tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ def test_rqaoa_result_plot_corr_matrix(self):

# test the plot_corr_matrix method
for i in range(results["number_steps"]):
results.plot_corr_matrix(step=i)
fig, _ = results.plot_corr_matrix(step=i)
fig.close()

def test_rqaoa_result_asdict(self):
"""
Expand Down

0 comments on commit b42852d

Please sign in to comment.