Skip to content

Commit

Permalink
done till quantum-plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
kessler-frost committed Sep 27, 2023
1 parent e542e98 commit 0996a0f
Show file tree
Hide file tree
Showing 20 changed files with 2,385 additions and 1 deletion.
110 changes: 110 additions & 0 deletions covalent/_results_manager/qresult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2021 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the Apache License 2.0 (the "License"). A copy of the
# License may be obtained with this software package or at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Use of this file is prohibited except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import pennylane as qml
from pennylane.tape import QuantumTape

from .._shared_files.qresult_utils import re_execute
from ..quantum.qclient.core import middleware


class QNodeFutureResult:
"""
A class that stores the `batch_id` of a batch of circuits submitted to the
middleware. The `result` method can then be called to retrieve the results.
Attributes:
device: The Pennylane device used by the original QNode.
interface: The interface of the original QNode.
diff_method: The differentiation method of the original QNode.
qfunc_output: The return value (measurement definition) of the original QNode.
"""

def __init__(
self,
batch_id: str,
interface: str,
original_qnode: qml.QNode,
original_tape: QuantumTape,
):
"""
Initialize a `QNodeFutureResult` instance.
Args:
batch_id: A UUID that identifies a batch of circuits submitted to
the middleware.
"""
self.batch_id = batch_id
self.interface = interface # NOT the original QNode's interface

# Required for batch_transforms and correct output typing.
self.device = original_qnode.device
self.qnode = original_qnode
self.tape = original_tape

self.args = None
self.kwargs = None
self._result = None

def __call__(self, *args, **kwargs):
"""
Store the arguments and keyword arguments of the original QNode call.
"""
self.args = args
self.kwargs = kwargs
return self

def result(self) -> Any:
"""
Retrieve the results for the given `batch_id` from middleware. This method
is blocking until the results are available.
Returns:
The results of the circuit execution.
"""

if self._result is None:
# Get raw results from the middleware.
results = middleware.get_results(self.batch_id)

# Required correct gradient post-processing in some cases.
if self.interface == "autograd":
self._result = results
res = results[0]

if self.interface != "numpy":
interface = self.interface # re-execute with any non-numpy interface
res = results[0] # re-execute with this result

elif self.qnode.interface is None:
interface = None
res = results[0]

elif self.qnode.interface == "auto":
interface = "auto"
res = results

else:
# Skip re-execution.
self._result = results
return results

args, kwargs = self.args, self.kwargs
self._result = re_execute(res, self.qnode, self.tape)(interface, *args, **kwargs)

return self._result
31 changes: 31 additions & 0 deletions covalent/_results_manager/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from .._shared_files.config import get_config
from .._shared_files.context_managers import active_lattice_manager
from .._shared_files.defaults import postprocess_prefix, prefix_separator, sublattice_prefix
from .._shared_files.qelectron_utils import QE_DB_DIRNAME
from .._shared_files.util_classes import RESULT_STATUS, Status
from .._workflow.lattice import Lattice
from .._workflow.transport import TransportableObject
from ..quantum.qserver import database as qe_db

if TYPE_CHECKING:
from .._shared_files.util_classes import Status
Expand Down Expand Up @@ -259,6 +261,7 @@ def get_node_result(self, node_id: int) -> dict:
"end_time": self.lattice.transport_graph.get_node_value(node_id, "end_time"),
"status": self._get_node_status(node_id),
"output": self._get_node_output(node_id),
"qelectron": self._get_node_qelectron_data(node_id),
"error": self.lattice.transport_graph.get_node_value(node_id, "error"),
"sublattice_result": self.lattice.transport_graph.get_node_value(
node_id, "sublattice_result"
Expand Down Expand Up @@ -368,6 +371,27 @@ def _get_node_output(self, node_id: int) -> Any:
"""
return self._lattice.transport_graph.get_node_value(node_id, "output")

def _get_node_qelectron_data(self, node_id: int) -> dict:
"""
Return all QElectron data associated with a node.
Args:
node_id: The node id.
Returns:
The QElectron data of said node. Will return None if no data exists.
"""
try:
# Checks existence of QElectron data.
self._lattice.transport_graph.get_node_value(node_id, "qelectron_data_exists")
except KeyError:
return None

results_dir = get_config("dispatcher")["results_dir"]
db_dir = os.path.join(results_dir, self.dispatch_id, QE_DB_DIRNAME)

return qe_db.Database(db_dir).get_db(dispatch_id=self.dispatch_id, node_id=node_id)

def _get_node_error(self, node_id: int) -> Union[None, str]:
"""
Return the error of a node.
Expand Down Expand Up @@ -403,6 +427,7 @@ def _update_node(
sublattice_result: "Result" = None,
stdout: str = None,
stderr: str = None,
qelectron_data_exists: bool = False,
) -> None:
"""
Update the node result in the transport graph.
Expand All @@ -419,6 +444,7 @@ def _update_node(
sublattice_result: The result of the sublattice if any.
stdout: The stdout of the node execution.
stderr: The stderr of the node execution.
qelectron_data_exists: Flag indicating presence of Qelectron(s) inside the task
Returns:
None
Expand Down Expand Up @@ -460,6 +486,11 @@ def _update_node(
if stderr is not None:
self.lattice.transport_graph.set_node_value(node_id, "stderr", stderr)

if qelectron_data_exists:
self.lattice.transport_graph.set_node_value(
node_id, "qelectron_data_exists", qelectron_data_exists
)

app_log.debug("Inside update node - SUCCESS")

def _convert_to_electron_result(self) -> Any:
Expand Down
5 changes: 5 additions & 0 deletions covalent/_shared_files/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def get_default_dispatcher_config():
(os.environ.get("XDG_DATA_HOME") or (os.environ["HOME"] + "/.local/share"))
+ "/covalent/dispatcher_db.sqlite"
),
"qelectron_db_path": os.environ.get("COVALENT_DATABASE")
or (
(os.environ.get("XDG_DATA_HOME") or (os.environ["HOME"] + "/.local/share"))
+ "/covalent/qelectron_db"
),
}


Expand Down
90 changes: 90 additions & 0 deletions covalent/_shared_files/pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2021 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the Apache License 2.0 (the "License"). A copy of the
# License may be obtained with this software package or at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Use of this file is prohibited except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Hopefully temporary custom tools to handle pickling and/or un-pickling."""

from contextlib import contextmanager
from typing import Any, Callable, Tuple

from pennylane.ops.qubit.observables import Projector

_PENNYLANE_METHOD_OVERRIDES = (
# class, method_name, method_func
(Projector, "__reduce__", lambda self: (Projector, (self.data[0], self.wires))),
)


def _qml_mods_pickle(func: Callable) -> Callable:
"""
A decorator that applies overrides to select PennyLane objects, making them
pickleable and/or un-pickleable in the local scope.
"""

def _wrapper(*args, **kwargs):
with _method_overrides(_PENNYLANE_METHOD_OVERRIDES):
return func(*args, **kwargs)

return _wrapper


@contextmanager
def _method_overrides(overrides: Tuple[Any, str, Callable]) -> None:
"""
Creates a context where all `overrides` are applied on entry and un-applied on exit.
"""

unapply_overrides = None
try:
unapply_overrides = _apply_method_overrides(overrides)
yield
finally:
unapply_overrides()


def _apply_method_overrides(overrides: Tuple[Any, str, Callable]) -> Callable:
"""
This function is called by the `_method_overrides()` context manager.
It applies the overrides in `_METHOD_OVERRIDES` to the corresponding objects
and returns a function that can later restore those objects.
"""

restoration_list = []
for cls, method_name, func in overrides:
# Attribute will be deleted later if `attr` is a length-1 tuple.
attr = (method_name,)
if hasattr(cls, method_name):
# Attribute will be restored later to the corresponding method.
attr += (getattr(cls, method_name),)

# Store attribute information.
restoration_list.append(attr)

# Use `func` to create or replace the method by name.
setattr(cls, method_name, func)

def _unapply_overrides():
for attr in restoration_list:
# Here `attr` is `(method_name,)` or `(method_name, original_func)`.
if len(attr) == 1:
# Delete attribute that did not exist before.
delattr(cls, attr[0])
else:
# Restore original attribute.
setattr(cls, attr[0], attr[1])

return _unapply_overrides
Loading

0 comments on commit 0996a0f

Please sign in to comment.