diff --git a/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/annealingparams.py b/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/annealingparams.py index a7dfd6db..5f611d23 100644 --- a/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/annealingparams.py +++ b/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/annealingparams.py @@ -155,8 +155,12 @@ def empty(cls, qaoa_descriptor: QAOADescriptor, total_annealing_time: float): def plot(self, ax=None, **kwargs): if ax is None: fig, ax = plt.subplots() + else: + fig = ax.get_figure() ax.plot(self.schedule, marker="s", **kwargs) ax.set_xlabel("p", fontsize=14) ax.set_ylabel("s(t)", fontsize=14) ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + + return fig, ax \ No newline at end of file diff --git a/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/extendedparams.py b/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/extendedparams.py index a76aebf2..8db928bb 100644 --- a/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/extendedparams.py +++ b/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/extendedparams.py @@ -306,6 +306,8 @@ def plot(self, ax=None, **kwargs): if ax is None: fig, ax = plt.subplots((n + 1) // 2, 2, figsize=(9, 9 if n > 2 else 5)) + else: + fig = ax.get_figure() fig.tight_layout(pad=4.0) @@ -342,3 +344,5 @@ def plot(self, ax=None, **kwargs): ax[1].axis("off") elif k == 2: ax[1, 1].axis("off") + + return fig, ax diff --git a/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/fourierparams.py b/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/fourierparams.py index f01abe55..2aa113b9 100644 --- a/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/fourierparams.py +++ b/src/openqaoa-core/openqaoa/qaoa_components/variational_parameters/fourierparams.py @@ -198,6 +198,8 @@ def plot(self, ax=None, **kwargs): if ax is None: fig, ax = plt.subplots(2, figsize=(7, 9)) + else: + fig = ax.get_figure() fig.tight_layout(pad=4.0) @@ -213,6 +215,8 @@ def plot(self, ax=None, **kwargs): ax[1].legend() ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) + return fig, ax + class QAOAVariationalFourierWithBiasParams(QAOAVariationalBaseParams): """ @@ -425,6 +429,8 @@ def plot(self, ax=None, **kwargs): # "params.u_singles, params.u_pairs") if ax is None: fig, ax = plt.subplots(2, figsize=(7, 9)) + else: + fig = ax.get_figure() fig.tight_layout(pad=4.0) @@ -456,6 +462,8 @@ def plot(self, ax=None, **kwargs): ax[1].legend() ax[1].xaxis.set_major_locator(MaxNLocator(integer=True)) + return fig, ax + class QAOAVariationalFourierExtendedParams(QAOAVariationalBaseParams): r""" @@ -792,3 +800,5 @@ def plot(self, ax=None, **kwargs): if j == 0: ax[i, j + 1].axis("off") + + return fig, ax diff --git a/src/openqaoa-core/tests/test_results.py b/src/openqaoa-core/tests/test_results.py index bbab7ac8..c0a5993c 100644 --- a/src/openqaoa-core/tests/test_results.py +++ b/src/openqaoa-core/tests/test_results.py @@ -657,7 +657,8 @@ def test_rqaoa_result_plot_corr_matrix(self): # test the plot_corr_matrix method for i in range(results["number_steps"]): - results.plot_corr_matrix(step=i) + fig, _ = results.plot_corr_matrix(step=i) + fig.close() def test_rqaoa_result_asdict(self): """