Skip to content

Commit

Permalink
Disable data caching temporarily (#592)
Browse files Browse the repository at this point in the history
* disable caching

* disable caching

* fix lint
  • Loading branch information
jyu00 authored Oct 26, 2022
1 parent 6d87664 commit 6cd2420
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 60 deletions.
65 changes: 35 additions & 30 deletions qiskit_ibm_runtime/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations
import copy
import json
from typing import Iterable, Optional, Dict, Sequence, Any, Union
import logging
from dataclasses import asdict
Expand All @@ -27,11 +26,8 @@
from qiskit.primitives import BaseEstimator, EstimatorResult

# TODO import _circuit_key from terra once 0.23 is released
from .qiskit.primitives.utils import _circuit_key
from .qiskit_runtime_service import QiskitRuntimeService
from .utils.estimator_result_decoder import EstimatorResultDecoder
from .utils.json import RuntimeEncoder
from .utils.utils import _hash
from .runtime_job import RuntimeJob
from .utils.deprecation import (
deprecate_arguments,
Expand Down Expand Up @@ -211,16 +207,16 @@ def __init__(
backend = session or backend
self._session = get_default_session(service, backend)

self._first_run = True
self._circuits_map = {}
if self.circuits:
for circuit in self.circuits:
circuit_id = _hash(
json.dumps(_circuit_key(circuit), cls=RuntimeEncoder)
)
if circuit_id not in self._session._circuits_map:
self._circuits_map[circuit_id] = circuit
self._session._circuits_map[circuit_id] = circuit
# self._first_run = True
# self._circuits_map = {}
# if self.circuits:
# for circuit in self.circuits:
# circuit_id = _hash(
# json.dumps(_circuit_key(circuit), cls=RuntimeEncoder)
# )
# if circuit_id not in self._session._circuits_map:
# self._circuits_map[circuit_id] = circuit
# self._session._circuits_map[circuit_id] = circuit

def run( # pylint: disable=arguments-differ
self,
Expand Down Expand Up @@ -279,25 +275,34 @@ def _run( # pylint: disable=arguments-differ
Returns:
Submitted job
"""
circuits_map = {}
circuit_ids = []
for circuit in circuits:
circuit_id = _hash(json.dumps(_circuit_key(circuit), cls=RuntimeEncoder))
circuit_ids.append(circuit_id)
if circuit_id in self._session._circuits_map:
continue
self._session._circuits_map[circuit_id] = circuit
circuits_map[circuit_id] = circuit

if self._first_run:
self._first_run = False
circuits_map.update(self._circuits_map)

# TODO: Re-enable data caching when ntc 1748 is fixed
# circuits_map = {}
# circuit_ids = []
# for circuit in circuits:
# circuit_id = _hash(json.dumps(_circuit_key(circuit), cls=RuntimeEncoder))
# circuit_ids.append(circuit_id)
# if circuit_id in self._session._circuits_map:
# continue
# self._session._circuits_map[circuit_id] = circuit
# circuits_map[circuit_id] = circuit

# if self._first_run:
# self._first_run = False
# circuits_map.update(self._circuits_map)

# inputs = {
# "circuits": circuits_map,
# "circuit_ids": circuit_ids,
# "observables": observables,
# "observable_indices": list(range(len(observables))),
# "parameter_values": parameter_values,
# }
inputs = {
"circuits": circuits_map,
"circuit_ids": circuit_ids,
"circuits": circuits,
"circuit_indices": list(range(len(circuits))),
"observables": observables,
"observable_indices": list(range(len(observables))),
"parameters": [circ.parameters for circ in circuits],
"parameter_values": parameter_values,
}

Expand Down
63 changes: 33 additions & 30 deletions qiskit_ibm_runtime/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations
from typing import Dict, Iterable, Optional, Sequence, Any, Union
import copy
import json
import logging
from dataclasses import asdict

Expand All @@ -24,12 +23,9 @@
from qiskit.primitives import BaseSampler, SamplerResult

# TODO import _circuit_key from terra once 0.23 released
from .qiskit.primitives.utils import _circuit_key
from .qiskit_runtime_service import QiskitRuntimeService
from .options import Options
from .utils.sampler_result_decoder import SamplerResultDecoder
from .utils.json import RuntimeEncoder
from .utils.utils import _hash
from .runtime_job import RuntimeJob
from .ibm_backend import IBMBackend
from .session import get_default_session
Expand Down Expand Up @@ -181,16 +177,16 @@ def __init__(
backend = session or backend
self._session = get_default_session(service, backend)

self._first_run = True
self._circuits_map = {}
if self.circuits:
for circuit in self.circuits:
circuit_id = _hash(
json.dumps(_circuit_key(circuit), cls=RuntimeEncoder)
)
if circuit_id not in self._session._circuits_map:
self._circuits_map[circuit_id] = circuit
self._session._circuits_map[circuit_id] = circuit
# self._first_run = True
# self._circuits_map = {}
# if self.circuits:
# for circuit in self.circuits:
# circuit_id = _hash(
# json.dumps(_circuit_key(circuit), cls=RuntimeEncoder)
# )
# if circuit_id not in self._session._circuits_map:
# self._circuits_map[circuit_id] = circuit
# self._session._circuits_map[circuit_id] = circuit

def run( # pylint: disable=arguments-differ
self,
Expand Down Expand Up @@ -238,23 +234,30 @@ def _run( # pylint: disable=arguments-differ
Returns:
Submitted job.
"""
circuits_map = {}
circuit_ids = []
for circuit in circuits:
circuit_id = _hash(json.dumps(_circuit_key(circuit), cls=RuntimeEncoder))
circuit_ids.append(circuit_id)
if circuit_id in self._session._circuits_map:
continue
self._session._circuits_map[circuit_id] = circuit
circuits_map[circuit_id] = circuit

if self._first_run:
self._first_run = False
circuits_map.update(self._circuits_map)

# TODO: Re-enable data caching when ntc 1748 is fixed
# circuits_map = {}
# circuit_ids = []
# for circuit in circuits:
# circuit_id = _hash(json.dumps(_circuit_key(circuit), cls=RuntimeEncoder))
# circuit_ids.append(circuit_id)
# if circuit_id in self._session._circuits_map:
# continue
# self._session._circuits_map[circuit_id] = circuit
# circuits_map[circuit_id] = circuit

# if self._first_run:
# self._first_run = False
# circuits_map.update(self._circuits_map)

# inputs = {
# "circuits": circuits_map,
# "circuit_ids": circuit_ids,
# "parameter_values": parameter_values,
# }
inputs = {
"circuits": circuits_map,
"circuit_ids": circuit_ids,
"circuits": circuits,
"parameters": [circ.parameters for circ in circuits],
"circuit_indices": list(range(len(circuits))),
"parameter_values": parameter_values,
}
combined = Options._merge_options(self._options, kwargs.get("_user_kwargs", {}))
Expand Down
4 changes: 4 additions & 0 deletions test/integration/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

"""Integration tests for Estimator primitive."""

import unittest

import numpy as np

from qiskit.circuit import QuantumCircuit, Parameter
Expand Down Expand Up @@ -110,6 +112,7 @@ def test_estimator_session(self, service):
self.assertEqual(len(result5.values), len(circuits5))
self.assertEqual(len(result5.metadata), len(circuits5))

@unittest.skip("Skip until data caching is reenabled.")
@run_integration_test
def test_estimator_session_circuit_caching(self, service):
"""Verify if estimator primitive circuit caching works"""
Expand Down Expand Up @@ -162,6 +165,7 @@ def test_estimator_session_circuit_caching(self, service):
self.assertNotEqual(result.values[1], -1)
self.assertNotEqual(result.values[1], 1)

@unittest.skip("Skip until data caching is reenabled.")
@run_integration_test
def test_estimator_circuit_caching_with_transpilation_options(self, service):
"""Verify if circuit caching works in estimator primitive
Expand Down
3 changes: 3 additions & 0 deletions test/integration/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

"""Integration tests for Sampler primitive."""

import unittest
from math import sqrt

from qiskit.circuit import QuantumCircuit, Gate
Expand Down Expand Up @@ -87,6 +88,7 @@ def test_sampler_non_parameterized_circuits(self, service):
self.assertAlmostEqual(result3.quasi_dists[i][3], 0.5, delta=0.1)
self.assertAlmostEqual(result3.quasi_dists[i][0], 0.5, delta=0.1)

@unittest.skip("Skip until data caching is reenabled.")
@run_integration_test
def test_sampler_non_parameterized_circuit_caching(self, service):
"""Verify if circuit caching works in sampler primitive
Expand Down Expand Up @@ -134,6 +136,7 @@ def test_sampler_non_parameterized_circuit_caching(self, service):
self.assertEqual(result.quasi_dists[0][3], 1)
self.assertEqual(result.quasi_dists[1][31], 1)

@unittest.skip("Skip until data caching is reenabled.")
@run_integration_test
def test_sampler_non_parameterized_circuit_caching_with_transpilation_options(
self, service
Expand Down
2 changes: 2 additions & 0 deletions test/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

"""Tests for estimator class."""

import unittest
import json
from unittest.mock import patch

Expand All @@ -30,6 +31,7 @@
class TestEstimator(IBMTestCase):
"""Class for testing the Estimator class."""

@unittest.skip("Skip until data caching is reenabled.")
def test_estimator_circuit_caching(self):
"""Test circuit caching in Estimator class"""
psi1 = RealAmplitudes(num_qubits=2, reps=2)
Expand Down
2 changes: 2 additions & 0 deletions test/unit/test_ibm_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import warnings
from dataclasses import asdict
from typing import Dict
import unittest

from qiskit.circuit import QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
Expand Down Expand Up @@ -431,6 +432,7 @@ def test_run_same_session(self):
inst.run(self.qx, observables=self.obs)
self.assertEqual(session.run.call_count, num_runs)

@unittest.skip("Skip until data caching is reenabled.")
def test_primitives_circuit_caching(self):
"""Test circuit caching in Estimator and Sampler classes"""
psi1 = RealAmplitudes(num_qubits=2, reps=2)
Expand Down
2 changes: 2 additions & 0 deletions test/unit/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import json
from unittest.mock import patch
import unittest

from qiskit.circuit.library import RealAmplitudes

Expand All @@ -29,6 +30,7 @@
class TestSampler(IBMTestCase):
"""Class for testing the Sampler class."""

@unittest.skip("Skip until data caching is reenabled.")
def test_sampler_circuit_caching(self):
"""Test circuit caching in Sampler class"""

Expand Down

0 comments on commit 6cd2420

Please sign in to comment.