Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
phstratmann committed Jan 22, 2024
1 parent 40617b8 commit a34329f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 84 deletions.
38 changes: 20 additions & 18 deletions src/lava/lib/optimization/solvers/qubo/cost_integrator/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,20 @@ class CostIntegrator(AbstractProcess):
"""

def __init__(
self,
*,
shape: ty.Tuple[int, ...] = (1,),
target_cost: int = -2**31 + 1,
timeout: int = 2**24 - 1,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
self,
*,
shape: ty.Tuple[int, ...] = (1,),
target_cost: int = -2 ** 31 + 1,
timeout: int = 2 ** 24 - 1,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
) -> None:

self._input_validation(target_cost=target_cost,
timeout=timeout)

super().__init__(shape=shape,
target_cost = target_cost,
timeout = timeout,
target_cost=target_cost,
timeout=timeout,
name=name,
log_config=log_config)
self.cost_in = InPort(shape=shape)
Expand Down Expand Up @@ -95,11 +94,14 @@ def __init__(

@staticmethod
def _input_validation(target_cost, timeout) -> None:
    """Validate the solver's target_cost and timeout settings.

    Parameters
    ----------
    target_cost : int
        Cost threshold at which the solver may stop; must lie in
        [-2**31 + 1, 0] (stored as a signed 32-bit value).
    timeout : int
        Maximum number of time steps; must lie in (0, 2**24 - 1]
        (stored in 24 bits).

    Raises
    ------
    ValueError
        If either value is missing or outside its supported range.
    """
    # Both values are required. Use `or` so that a SINGLE missing value
    # raises here: with `and` (the previous check), one None slipped
    # through and crashed the range comparisons below with a TypeError.
    if target_cost is None or timeout is None:
        raise ValueError(
            "Both the target_cost and the timeout must be defined")
    # Range text matches the actual bound used in the check (2**31,
    # not 2**32 as an earlier message claimed).
    if target_cost > 0 or target_cost < -2 ** 31 + 1:
        raise ValueError(
            f"The target cost must be in the range [-2**31 + 1, 0], "
            f"but is {target_cost}.")
    if timeout <= 0 or timeout > 2 ** 24 - 1:
        raise ValueError(
            f"The timeout must be in the range (0, 2**24 - 1], but is "
            f"{timeout}.")
75 changes: 28 additions & 47 deletions src/lava/lib/optimization/solvers/qubo/solution_readout/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from lava.magma.core.resources import CPU
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol

from lava.lib.optimization.solvers.qubo.solution_readout.process import \
(
from lava.lib.optimization.solvers.qubo.solution_readout.process import (
SolutionReceiver, SpikeIntegrator
)
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
Expand Down Expand Up @@ -62,7 +61,7 @@ def run_async(self):
_, _, states = self._decompress_state(
compressed_states=results_buffer,
num_message_bits=num_message_bits,
num_vars=num_vars) #[:self.best_state.shape[0]]
num_vars=num_vars)
self.best_state = states
self._req_pause = True

Expand All @@ -76,8 +75,7 @@ def _decompress_state(compressed_states, num_message_bits, num_vars):
cost = int(compressed_states[0])
timestep = int(compressed_states[1])
states = (compressed_states[2:, None] & (
1 << np.arange(0, num_message_bits))) != 0
#1 << np.arange(num_message_bits - 1, -1, -1))) != 0
1 << np.arange(0, num_message_bits))) != 0
# reshape into a 1D array
states.reshape(-1)
# If n_vars is not a multiple of num_message_bits, then last entries
Expand All @@ -86,22 +84,6 @@ def _decompress_state(compressed_states, num_message_bits, num_vars):
return cost, timestep, states




"""
def test_code():
# Assuming you have a 32-bit integer numpy array
original_array = np.array([4294967295, 2147483647, 0, 8983218],
dtype=np.uint32)
# Use bitwise AND operation to convert each integer to a boolean array
boolean_array = (original_array[:, None] & (1 << np.arange(31, -1, -1))) != 0
# Display the result
print(boolean_array)
"""

@implements(proc=SolutionReadoutEthernet, protocol=LoihiProtocol)
@requires(CPU)
class SolutionReadoutEthernetModel(AbstractSubProcessModel):
Expand All @@ -128,7 +110,6 @@ def __init__(self, proc):
)
self.synapses_state_in_0 = Sparse(
weights=weights_state_in_0,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=0,
Expand All @@ -146,15 +127,14 @@ def __init__(self, proc):
)
self.synapses_state_in_1 = Sparse(
weights=weights_state_in_1,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=8,
)

proc.in_ports.states_in.connect(self.synapses_state_in_1.s_in)
self.synapses_state_in_1.a_out.connect(self.spike_integrators.a_in)

if num_bin_variables > 16:
weights_state_in_2 = self._get_input_weights(
num_vars=num_bin_variables,
Expand All @@ -164,7 +144,6 @@ def __init__(self, proc):
)
self.synapses_state_in_2 = Sparse(
weights=weights_state_in_2,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=16,
Expand All @@ -182,15 +161,14 @@ def __init__(self, proc):
)
self.synapses_state_in_3 = Sparse(
weights=weights_state_in_3,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=24,
)
proc.in_ports.states_in.connect(self.synapses_state_in_3.s_in)
self.synapses_state_in_3.a_out.connect(self.spike_integrators.a_in)

# Connect the CostIntegrator
# Connect the CostIntegrator
weights_cost_in = self._get_cost_in_weights(
num_spike_int=num_spike_integrators,
)
Expand All @@ -205,26 +183,24 @@ def __init__(self, proc):
)
self.synapses_timestep_in = Sparse(
weights=weights_timestep_in,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=32,
)

proc.in_ports.cost_in.connect(self.synapses_cost_in.s_in)
self.synapses_cost_in.a_out.connect(self.spike_integrators.a_in)
proc.in_ports.timestep_in.connect(self.synapses_timestep_in.s_in)
self.synapses_timestep_in.a_out.connect(self.spike_integrators.a_in)

# Define and connect the SolutionReceiver

self.solution_receiver = SolutionReceiver(
shape=(1,),
num_variables = num_bin_variables,
num_spike_integrators = num_spike_integrators,
num_message_bits = num_message_bits,
best_cost_init = proc.best_cost.get(),
best_state_init = proc.best_state.get(),
best_timestep_init = proc.best_timestep.get()
num_variables=num_bin_variables,
num_spike_integrators=num_spike_integrators,
num_message_bits=num_message_bits,
best_cost_init=proc.best_cost.get(),
best_state_init=proc.best_state.get(),
best_timestep_init=proc.best_timestep.get()
)

self.spike_integrators.s_out.connect(
Expand All @@ -236,22 +212,26 @@ def __init__(self, proc):
proc.vars.best_cost.alias(self.solution_receiver.best_cost)

@staticmethod
def _get_input_weights(num_vars,
                       num_spike_int,
                       num_vars_per_int,
                       weight_exp) -> csr_matrix:
    """To be verified. Deprecated due to efficiency

    Build the sparse weight matrix connecting binary variables to
    SpikeIntegrators for one 8-bit slice of the 32-bit message word.

    Parameters
    ----------
    num_vars : int
        Total number of binary variables (matrix columns).
    num_spike_int : int
        Total number of SpikeIntegrators (matrix rows).
    num_vars_per_int : int
        Number of variables packed into each SpikeIntegrator.
    weight_exp : int
        Column offset of this slice within each integrator's variable
        window (0, 8, 16 or 24 at the call sites).

    Returns
    -------
    csr_matrix
        Sparse (num_spike_int, num_vars) weight matrix whose non-zero
        entries are the powers of two 1, 2, ..., 128.
    """
    weights = np.zeros((num_spike_int, num_vars), dtype=np.uint8)

    # The first two SpikeIntegrators receive best_cost and best_timestep,
    # so variable weights start at row 2.
    for spike_integrator in range(2, num_spike_int - 1):
        variable_start = num_vars_per_int * (spike_integrator - 2) + \
            weight_exp
        # 8 consecutive variables are binary-encoded with weights 2**0..2**7.
        weights[spike_integrator, variable_start:variable_start + 8] = \
            np.power(2, np.arange(8))

    # The last spike integrator might be connected by less than
    # num_vars_per_int neurons.
    # This happens when mod(num_variables, num_vars_per_int) != 0
    variable_start = num_vars_per_int * (num_spike_int - 3) + weight_exp
    weights[-1, variable_start:] = np.power(2, np.arange(weights.shape[1]
                                                         - variable_start))

    return csr_matrix(weights)

Expand All @@ -261,10 +241,11 @@ def _get_state_in_weights_index(num_vars, num_spike_int, num_vars_per_int):
weights = np.zeros((num_spike_int, num_vars), dtype=np.int8)

# Compute the indices for setting the values to 1
indices = np.arange(0, num_vars_per_int * (num_spike_int - 1), num_vars_per_int)
indices = np.arange(0, num_vars_per_int * (num_spike_int - 1),
num_vars_per_int)

# Set the values to 1 using array indexing
weights[:num_spike_int-1, indices:indices + num_vars_per_int] = 1
weights[:num_spike_int - 1, indices:indices + num_vars_per_int] = 1

# Set the values for the last spike integrator
weights[-1, num_vars_per_int * (num_spike_int - 1):num_vars] = 1
Expand All @@ -274,11 +255,11 @@ def _get_state_in_weights_index(num_vars, num_spike_int, num_vars_per_int):
@staticmethod
def _get_cost_in_weights(num_spike_int: int) -> csr_matrix:
    """Sparse (num_spike_int, 1) weights routing the cost input.

    Only the first SpikeIntegrator (row 0) receives the cost value,
    with weight 1; all other rows are zero.
    """
    # Build the single non-zero entry at (0, 0) directly in COO form
    # instead of densifying a zero matrix first.
    return csr_matrix(([1], ([0], [0])),
                      shape=(num_spike_int, 1),
                      dtype=int)

@staticmethod
def _get_timestep_in_weights(num_spike_int: int) -> csr_matrix:
    """Sparse (num_spike_int, 1) weights routing the timestep input.

    Only the second SpikeIntegrator (row 1) receives the timestep
    value, with weight 1; all other rows are zero.
    """
    # Single non-zero entry at (1, 0), constructed in COO form rather
    # than via an intermediate dense zero matrix.
    return csr_matrix(([1], ([1], [0])),
                      shape=(num_spike_int, 1),
                      dtype=int)
39 changes: 20 additions & 19 deletions src/lava/lib/optimization/solvers/qubo/solution_readout/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from lava.magma.core.process.ports.connection_config import ConnectionConfig


class SpikeIntegrator(AbstractProcess):
"""GradedVec
Graded spike vector layer. Accumulates and forwards 32bit spikes.
Expand All @@ -24,7 +25,6 @@ class SpikeIntegrator(AbstractProcess):
def __init__(
self,
shape: ty.Tuple[int, ...]) -> None:

super().__init__(shape=shape)

self.a_in = InPort(shape=shape)
Expand Down Expand Up @@ -54,13 +54,13 @@ class SolutionReadoutEthernet(AbstractProcess):
"""

def __init__(
self,
shape: ty.Tuple[int, ...],
connection_config: ConnectionConfig,
num_bin_variables: int,
num_message_bits = 32,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
self,
shape: ty.Tuple[int, ...],
connection_config: ConnectionConfig,
num_bin_variables: int,
num_message_bits=32,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
) -> None:
"""
Parameters
Expand All @@ -80,7 +80,8 @@ def __init__(
log_config: LogConfig, optional
Configuration options for logging.z"""

num_spike_integrators = 2 + np.ceil(num_bin_variables / num_message_bits).astype(int)
num_spike_integrators = 2 + np.ceil(
num_bin_variables / num_message_bits).astype(int)

super().__init__(
shape=shape,
Expand Down Expand Up @@ -130,16 +131,16 @@ class SolutionReceiver(AbstractProcess):
"""

def __init__(
self,
shape: ty.Tuple[int, ...],
num_variables: int,
best_cost_init: int,
best_state_init: ty.Union[npty.ArrayLike, int],
num_spike_integrators: int,
best_timestep_init: int,
num_message_bits: int = 24,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
self,
shape: ty.Tuple[int, ...],
num_variables: int,
best_cost_init: int,
best_state_init: ty.Union[npty.ArrayLike, int],
num_spike_integrators: int,
best_timestep_init: int,
num_message_bits: int = 24,
name: ty.Optional[str] = None,
log_config: ty.Optional[LogConfig] = None,
) -> None:
super().__init__(
shape=shape,
Expand Down

0 comments on commit a34329f

Please sign in to comment.