-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e542e98
commit 0996a0f
Showing
20 changed files
with
2,385 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright 2021 Agnostiq Inc. | ||
# | ||
# This file is part of Covalent. | ||
# | ||
# Licensed under the Apache License 2.0 (the "License"). A copy of the | ||
# License may be obtained with this software package or at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Use of this file is prohibited except in compliance with the License. | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any | ||
|
||
import pennylane as qml | ||
from pennylane.tape import QuantumTape | ||
|
||
from .._shared_files.qresult_utils import re_execute | ||
from ..quantum.qclient.core import middleware | ||
|
||
|
||
class QNodeFutureResult: | ||
""" | ||
A class that stores the `batch_id` of a batch of circuits submitted to the | ||
middleware. The `result` method can then be called to retrieve the results. | ||
Attributes: | ||
device: The Pennylane device used by the original QNode. | ||
interface: The interface of the original QNode. | ||
diff_method: The differentiation method of the original QNode. | ||
qfunc_output: The return value (measurement definition) of the original QNode. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
batch_id: str, | ||
interface: str, | ||
original_qnode: qml.QNode, | ||
original_tape: QuantumTape, | ||
): | ||
""" | ||
Initialize a `QNodeFutureResult` instance. | ||
Args: | ||
batch_id: A UUID that identifies a batch of circuits submitted to | ||
the middleware. | ||
""" | ||
self.batch_id = batch_id | ||
self.interface = interface # NOT the original QNode's interface | ||
|
||
# Required for batch_transforms and correct output typing. | ||
self.device = original_qnode.device | ||
self.qnode = original_qnode | ||
self.tape = original_tape | ||
|
||
self.args = None | ||
self.kwargs = None | ||
self._result = None | ||
|
||
def __call__(self, *args, **kwargs): | ||
""" | ||
Store the arguments and keyword arguments of the original QNode call. | ||
""" | ||
self.args = args | ||
self.kwargs = kwargs | ||
return self | ||
|
||
def result(self) -> Any: | ||
""" | ||
Retrieve the results for the given `batch_id` from middleware. This method | ||
is blocking until the results are available. | ||
Returns: | ||
The results of the circuit execution. | ||
""" | ||
|
||
if self._result is None: | ||
# Get raw results from the middleware. | ||
results = middleware.get_results(self.batch_id) | ||
|
||
# Required correct gradient post-processing in some cases. | ||
if self.interface == "autograd": | ||
self._result = results | ||
res = results[0] | ||
|
||
if self.interface != "numpy": | ||
interface = self.interface # re-execute with any non-numpy interface | ||
res = results[0] # re-execute with this result | ||
|
||
elif self.qnode.interface is None: | ||
interface = None | ||
res = results[0] | ||
|
||
elif self.qnode.interface == "auto": | ||
interface = "auto" | ||
res = results | ||
|
||
else: | ||
# Skip re-execution. | ||
self._result = results | ||
return results | ||
|
||
args, kwargs = self.args, self.kwargs | ||
self._result = re_execute(res, self.qnode, self.tape)(interface, *args, **kwargs) | ||
|
||
return self._result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2021 Agnostiq Inc. | ||
# | ||
# This file is part of Covalent. | ||
# | ||
# Licensed under the Apache License 2.0 (the "License"). A copy of the | ||
# License may be obtained with this software package or at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Use of this file is prohibited except in compliance with the License. | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
"""Hopefully temporary custom tools to handle pickling and/or un-pickling.""" | ||
|
||
from contextlib import contextmanager | ||
from typing import Any, Callable, Tuple | ||
|
||
from pennylane.ops.qubit.observables import Projector | ||
|
||
_PENNYLANE_METHOD_OVERRIDES = ( | ||
# class, method_name, method_func | ||
(Projector, "__reduce__", lambda self: (Projector, (self.data[0], self.wires))), | ||
) | ||
|
||
|
||
def _qml_mods_pickle(func: Callable) -> Callable: | ||
""" | ||
A decorator that applies overrides to select PennyLane objects, making them | ||
pickleable and/or un-pickleable in the local scope. | ||
""" | ||
|
||
def _wrapper(*args, **kwargs): | ||
with _method_overrides(_PENNYLANE_METHOD_OVERRIDES): | ||
return func(*args, **kwargs) | ||
|
||
return _wrapper | ||
|
||
|
||
@contextmanager | ||
def _method_overrides(overrides: Tuple[Any, str, Callable]) -> None: | ||
""" | ||
Creates a context where all `overrides` are applied on entry and un-applied on exit. | ||
""" | ||
|
||
unapply_overrides = None | ||
try: | ||
unapply_overrides = _apply_method_overrides(overrides) | ||
yield | ||
finally: | ||
unapply_overrides() | ||
|
||
|
||
def _apply_method_overrides(overrides: Tuple[Any, str, Callable]) -> Callable: | ||
""" | ||
This function is called by the `_method_overrides()` context manager. | ||
It applies the overrides in `_METHOD_OVERRIDES` to the corresponding objects | ||
and returns a function that can later restore those objects. | ||
""" | ||
|
||
restoration_list = [] | ||
for cls, method_name, func in overrides: | ||
# Attribute will be deleted later if `attr` is a length-1 tuple. | ||
attr = (method_name,) | ||
if hasattr(cls, method_name): | ||
# Attribute will be restored later to the corresponding method. | ||
attr += (getattr(cls, method_name),) | ||
|
||
# Store attribute information. | ||
restoration_list.append(attr) | ||
|
||
# Use `func` to create or replace the method by name. | ||
setattr(cls, method_name, func) | ||
|
||
def _unapply_overrides(): | ||
for attr in restoration_list: | ||
# Here `attr` is `(method_name,)` or `(method_name, original_func)`. | ||
if len(attr) == 1: | ||
# Delete attribute that did not exist before. | ||
delattr(cls, attr[0]) | ||
else: | ||
# Restore original attribute. | ||
setattr(cls, attr[0], attr[1]) | ||
|
||
return _unapply_overrides |
Oops, something went wrong.