Skip to content

Commit

Permalink
Add serialization support for EstimatorPubResult (#2078)
Browse files Browse the repository at this point in the history
* initial commit

* black
  • Loading branch information
joshuasn authored Dec 6, 2024
1 parent 740916b commit e7b6e93
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
6 changes: 6 additions & 0 deletions qiskit_ibm_runtime/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
ExecutionSpans,
TwirledSliceSpan,
)
from qiskit_ibm_runtime.utils.estimator_pub_result import EstimatorPubResult

from .noise_learner_result import NoiseLearnerResult

Expand Down Expand Up @@ -323,6 +324,9 @@ def default(self, obj: Any) -> Any: # pylint: disable=arguments-differ
obj.parameter_values.as_array(obj.circuit.parameters),
obj.shots,
)
if isinstance(obj, EstimatorPubResult):
out_val = {"data": obj.data, "metadata": obj.metadata}
return {"__type__": "EstimatorPubResult", "__value__": out_val}
if isinstance(obj, SamplerPubResult):
out_val = {"data": obj.data, "metadata": obj.metadata}
return {"__type__": "SamplerPubResult", "__value__": out_val}
Expand Down Expand Up @@ -470,6 +474,8 @@ def object_hook(self, obj: Any) -> Any:
if shape is not None and isinstance(shape, list):
shape = tuple(shape)
return DataBin(shape=shape, **obj_val["fields"])
if obj_type == "EstimatorPubResult":
return EstimatorPubResult(**obj_val)
if obj_type == "SamplerPubResult":
return SamplerPubResult(**obj_val)
if obj_type == "PubResult":
Expand Down
19 changes: 19 additions & 0 deletions test/unit/test_data_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from qiskit_aer.noise import NoiseModel
from qiskit_ibm_runtime.utils import RuntimeEncoder, RuntimeDecoder
from qiskit_ibm_runtime.utils.estimator_pub_result import EstimatorPubResult
from qiskit_ibm_runtime.utils.noise_learner_result import (
PauliLindbladError,
LayerError,
Expand Down Expand Up @@ -422,6 +423,15 @@ def make_test_pub_results(self):
pub_results.append(pub_result)
return pub_results

def make_test_estimator_pub_results(self):
"""Generates test data for EstimatorPubResult test"""
pub_results = []
pub_result = EstimatorPubResult(DataBin(a=1.0, b=2))
pub_results.append(pub_result)
pub_result = EstimatorPubResult(DataBin(a=1.0, b=2), {"x": 1})
pub_results.append(pub_result)
return pub_results

def make_test_sampler_pub_results(self):
"""Generates test data for SamplerPubResult test"""
pub_results = []
Expand Down Expand Up @@ -605,6 +615,15 @@ def test_pub_result(self):
self.assertIsInstance(decoded, PubResult)
self.assert_pub_results_equal(pub_result, decoded)

def test_estimator_pub_result(self):
"""Test encoding and decoding EstimatorPubResult"""
for pub_result in self.make_test_estimator_pub_results():
payload = {"estimator_pub_result": pub_result}
encoded = json.dumps(payload, cls=RuntimeEncoder)
decoded = json.loads(encoded, cls=RuntimeDecoder)["estimator_pub_result"]
self.assertIsInstance(decoded, EstimatorPubResult)
self.assert_pub_results_equal(pub_result, decoded)

def test_sampler_pub_result(self):
"""Test encoding and decoding SamplerPubResult"""
for pub_result in self.make_test_sampler_pub_results():
Expand Down

0 comments on commit e7b6e93

Please sign in to comment.