Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
phstratmann committed Feb 19, 2024
1 parent 3dde94d commit 11380aa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
28 changes: 14 additions & 14 deletions src/lava/lib/optimization/solvers/qubo/solution_readout/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from scipy.sparse import csr_matrix
from abc import ABC, abstractmethod


@implements(SolutionReceiver, protocol=AsyncProtocol)
@requires(CPU)
class SolutionReceiverAbstractPyModel(PyAsyncProcessModel, ABC):
Expand Down Expand Up @@ -60,10 +61,10 @@ def _decompress_state(compressed_states,

variables_32bit = compressed_states[:variables_32bit_num].astype(
np.int32)

variables_1bit = (compressed_states[variables_32bit_num:, None] & (
1 << np.arange(0, num_message_bits))) != 0

# reshape into a 1D array
variables_1bit.reshape(-1)
# If n_vars is not a multiple of num_message_bits, then last entries
Expand Down Expand Up @@ -126,14 +127,11 @@ def run_async(self):
print(f"{self.variables_1bit=}")
print("==============================================================")

# End execution
#self._req_pause = True

@staticmethod
def _check_if_input(results_buffer) -> bool:
"""For QUBO, we know that the readout starts as soon as the 2nd output
(best_timestep) is > 0."""

return results_buffer[1] > 0

@staticmethod
Expand Down Expand Up @@ -244,7 +242,7 @@ def __init__(self, proc):
self.synapses_variables_1bit_3_in.s_in)
self.synapses_variables_1bit_3_in.a_out.connect(
self.spike_integrators.a_in)

# Connect the 32bit InPorts, one by one
for ii in range(variables_32bit_num):
# Create the synapses for InPort ii as self.
Expand All @@ -259,7 +257,6 @@ def __init__(self, proc):
getattr(proc.in_ports,
f"variables_32bit_{ii}_in").connect(synapses_in.s_in)
synapses_in.a_out.connect(self.spike_integrators.a_in)


# Define and connect the SolutionReceiver
self.solution_receiver = SolutionReceiver(
Expand Down Expand Up @@ -295,15 +292,18 @@ def _get_input_weights(variables_1bit_num,
weights = np.zeros((num_spike_int, variables_1bit_num), dtype=np.uint8)

# The first SpikeIntegrators receive 32bit variables
for spike_integrator_id in range(variables_32bit_num, num_spike_int - 1):
variable_start = num_1bit_vars_per_int * (spike_integrator_id - variables_32bit_num) + \
weight_exp
weights[spike_integrator_id, variable_start:variable_start + 8] = \
np.power(2, np.arange(8))
for spike_integrator_id in range(variables_32bit_num,
num_spike_int - 1):
variable_start = num_1bit_vars_per_int * (
spike_integrator_id - variables_32bit_num) + weight_exp
weights[spike_integrator_id,
variable_start:variable_start + 8] = np.power(2,
np.arange(8))
# The last spike integrator might be connected by less than
# num_1bit_vars_per_int neurons
# This happens when mod(num_variables, num_1bit_vars_per_int) != 0
variable_start = num_1bit_vars_per_int * (num_spike_int - variables_32bit_num - 1) + weight_exp
variable_start = num_1bit_vars_per_int * (
num_spike_int - variables_32bit_num - 1) + weight_exp
weights[-1, variable_start:] = np.power(2, np.arange(weights.shape[1]
- variable_start))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,15 @@ def _validate_input(self,

if isinstance(variables_32bit_init, int) and variables_32bit_num == 1:
return
elif isinstance(variables_32bit_init, list) and len(variables_32bit_init) == variables_32bit_num:
elif (isinstance(variables_32bit_init, list) and
len(variables_32bit_init) == variables_32bit_num):
return
elif isinstance(variables_32bit_init, np.ndarray) and variables_32bit_init.shape[0] == variables_32bit_num:
elif (isinstance(variables_32bit_init, np.ndarray)
and variables_32bit_init.shape[0] == variables_32bit_num):
return
else:
raise ValueError(f"The variables_32bit_num must match the number "
f"of {variables_32bit_init=} provided.")
f"of {variables_32bit_init=} provided.")


class SolutionReceiver(AbstractProcess):
Expand Down Expand Up @@ -201,7 +203,7 @@ def __init__(
log_config: LogConfig, optional
Configuration options for logging.z
"""

super().__init__(
shape=shape,
name=name,
Expand All @@ -216,6 +218,6 @@ def __init__(
init=variables_1bit_init)
self.variables_32bit = Var(shape=(variables_32bit_num,),
init=variables_32bit_init)

# Define InPorts
self.results_in = InPort(shape=(num_spike_integrators,))

0 comments on commit 11380aa

Please sign in to comment.