Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add bin_spike_ms to Spiketrain data #27

Merged
merged 4 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion synapse/server/nodes/spike_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,5 @@ async def run(self):
spike_counts.append(spike_count)

await self.emit_data(
SpiketrainData(t0=data.t0, spike_counts=spike_counts)
SpiketrainData(t0=data.t0, bin_size_ms=self.bin_size_ms, spike_counts=spike_counts)
)
5 changes: 4 additions & 1 deletion synapse/simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from synapse.api.node_pb2 import NodeType
from synapse.server.entrypoint import ENTRY_DEFAULTS, main as server
from synapse.server.nodes.spectral_filter import SpectralFilter
from synapse.server.nodes.spike_detect import SpikeDetect
from synapse.server.nodes.stream_in import StreamIn
from synapse.server.nodes.stream_out import StreamOut
from synapse.simulator.nodes.electrical_broadband import ElectricalBroadband
from synapse.simulator.nodes.optical_stimulation import OpticalStimulation


SIMULATOR_NODE_OBJECT_MAP = {
NodeType.kStreamIn: StreamIn,
NodeType.kStreamOut: StreamOut,
NodeType.kSpectralFilter: SpectralFilter,
NodeType.kSpikeDetect: SpikeDetect,
NodeType.kElectricalBroadband: ElectricalBroadband,
NodeType.kOpticalStimulation: OpticalStimulation
}
Expand Down
6 changes: 4 additions & 2 deletions synapse/tests/test_ndtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,13 @@ def test_ndtp_payload_broadband_large():
def test_ndtp_payload_spiketrain():
samples = [0, 1, 2, 3, 2]

payload = NDTPPayloadSpiketrain(samples)
payload = NDTPPayloadSpiketrain(10, samples)
packed = payload.pack()
unpacked = NDTPPayloadSpiketrain.unpack(packed)

assert unpacked == payload
assert unpacked.bin_size_ms == 10
assert list(unpacked.spike_counts) == samples


def test_ndtp_header():
Expand Down Expand Up @@ -379,7 +381,7 @@ def test_ndtp_message_broadband_large():

def test_ndtp_message_spiketrain():
header = NDTPHeader(DataType.kSpiketrain, timestamp=1234567890, seq_number=42)
payload = NDTPPayloadSpiketrain(spike_counts=[1, 2, 3, 2, 1])
payload = NDTPPayloadSpiketrain(bin_size_ms=10, spike_counts=[1, 2, 3, 2, 1])
message = NDTPMessage(header, payload)

packed = message.pack()
Expand Down
2 changes: 2 additions & 0 deletions synapse/tests/test_stream_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,15 @@ def test_packing_spiketrain_data():

sdata = SpiketrainData(
t0=1234567890,
bin_size_ms=10,
spike_counts=[0, 1, 2, 3, 2, 1, 0],
)

packed = node._pack(sdata)[0]
unpacked = NDTPMessage.unpack(packed)

assert unpacked.header.timestamp == sdata.t0
assert unpacked.payload.bin_size_ms == sdata.bin_size_ms
assert len(unpacked.payload.spike_counts) == len(sdata.spike_counts)

assert list(unpacked.payload.spike_counts) == list(sdata.spike_counts)
57 changes: 36 additions & 21 deletions synapse/utils/ndtp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def to_bytes(
elif byteorder == 'big':
byteorder_is_little = False
else:
raise ValueError(f"Invalid byteorder: {byteorder}")
raise ValueError("Invalid byteorder: " + byteorder)

for py_value in values:
value = py_value
if not (min_value <= value <= max_value):
raise ValueError(f"Value {value} cannot be represented in {bit_width} bits")
raise ValueError("Value " + str(value) + " cannot be represented in " + str(bit_width) + " bits")

# Handle negative values for signed integers
if is_signed and value < 0:
Expand Down Expand Up @@ -144,14 +144,14 @@ def to_ints(
if isinstance(data, (bytes, bytearray)):
data_view = data
else:
raise TypeError(f"Unsupported data type: {type(data)}")
raise TypeError("Unsupported data type: " + str(type(data)))

cdef Py_ssize_t data_len = len(data_view)

if count > 0 and data_len < (bit_width * count + 7) // 8:
raise ValueError(
f"insufficient data for {count} x {bit_width} bit values "
f"(expected {(bit_width * count + 7) // 8} bytes, given {data_len} bytes)"
"insufficient data for " + str(count) + " x " + str(bit_width) + " bit values " +
"(expected " + str((bit_width * count + 7) // 8) + " bytes, given " + str(data_len) + " bytes)"
)

cdef int current_value = 0
Expand All @@ -163,7 +163,7 @@ def to_ints(
cdef int value_index = 0
cdef int max_values = count if count > 0 else (data_len * 8) // bit_width
if max_values == 0:
raise ValueError(f"max_values must be > 0 (got {len(data)} data, {count} count, bit width {bit_width})")
raise ValueError("max_values must be > 0 (got " + str(len(data)) + " data, " + str(count) + " count, bit width " + str(bit_width) + ")")
cdef int[::1] values_array = cython.view.array(shape=(max_values,), itemsize=cython.sizeof(cython.int), format="i")
cdef int sign_bit = 1 << (bit_width - 1)
cdef uint8_t byte
Expand Down Expand Up @@ -218,7 +218,7 @@ def to_ints(
return [values_array[i] for i in range(value_index)], end_bit, data

else:
raise ValueError(f"Invalid byteorder: {byteorder}")
raise ValueError("Invalid byteorder: " + byteorder)

if bits_in_current_value > 0:
if bits_in_current_value == bit_width:
Expand All @@ -230,7 +230,7 @@ def to_ints(
value_index += 1
elif count == 0:
raise ValueError(
f"{bits_in_current_value} bits left over, not enough to form a complete value of bit width {bit_width}"
str(bits_in_current_value) + " bits left over, not enough to form a complete value of bit width " + str(bit_width)
)

if count > 0:
Expand Down Expand Up @@ -325,7 +325,7 @@ cdef class NDTPPayloadBroadband:
cdef int len_data = len(data)
if len_data < payload_h_size:
raise ValueError(
f"Invalid broadband data size {len_data}: expected at least {payload_h_size} bytes"
"Invalid broadband data size " + str(len_data) + ": expected at least " + str(payload_h_size) + " bytes"
)

cdef int bit_width = data[0] >> 1
Expand Down Expand Up @@ -370,10 +370,12 @@ cdef class NDTPPayloadBroadband:


cdef class NDTPPayloadSpiketrain:
cdef public int bin_size_ms
cdef public int[::1] spike_counts # Memoryview of integers

def __init__(self, spike_counts):
def __init__(self, bin_size_ms, spike_counts):
cdef int size, i
self.bin_size_ms = bin_size_ms
self.spike_counts = None

if isinstance(spike_counts, list):
Expand Down Expand Up @@ -403,6 +405,9 @@ cdef class NDTPPayloadSpiketrain:
# Pack the number of spikes (4 bytes)
payload += struct.pack(">I", spike_counts_len)

# Pack the bin_size (1 byte)
payload += struct.pack(">B", self.bin_size_ms)

# Pack clamped spike counts
spike_counts_bytes, _ = to_bytes(
clamped_counts, NDTPPayloadSpiketrain_BIT_WIDTH, is_signed=False
Expand All @@ -415,26 +420,36 @@ cdef class NDTPPayloadSpiketrain:
if isinstance(data, bytes):
data = bytearray(data)

cdef str msg;
cdef int len_data = len(data)
if len_data < 4:
raise ValueError(
f"Invalid spiketrain data size {len_data}: expected at least 4 bytes"
)
if len_data < 5:
msg = "Invalid spiketrain data size "
msg += str(len_data)
msg += " bytes: expected at least 5 bytes"
raise ValueError(msg)

cdef int num_spikes = struct.unpack(">I", data[:4])[0]
cdef bytearray payload = data[4:]
cdef int bin_size_ms = struct.unpack(">B", data[4:5])[0]
cdef bytearray payload = data[5:]
cdef int bits_needed = num_spikes * NDTPPayloadSpiketrain_BIT_WIDTH
cdef int bytes_needed = (bits_needed + 7) // 8

if len(payload) < bytes_needed:
raise ValueError("Insufficient data for spike_counts")
msg = "Insufficient data for spiketrain data (expected "
msg += str(bytes_needed)
msg += "bytes for "
msg += str(num_spikes)
msg += " spikes, got "
msg += str(len(payload))
msg += ")"
raise ValueError(msg)

# Unpack spike_counts
spike_counts, _, _ = to_ints(
payload[:bytes_needed], NDTPPayloadSpiketrain_BIT_WIDTH, num_spikes, is_signed=False
)

return NDTPPayloadSpiketrain(spike_counts)
return NDTPPayloadSpiketrain(bin_size_ms, spike_counts)

def __eq__(self, other):
if not isinstance(other, NDTPPayloadSpiketrain):
Expand Down Expand Up @@ -483,13 +498,13 @@ cdef class NDTPHeader:
cdef int expected_size = NDTPHeader.STRUCT.size
if len(data) < expected_size:
raise ValueError(
f"Invalid header size {len(data)}: expected {expected_size}"
"Invalid header size " + str(len(data)) + ": expected " + str(expected_size)
)

version, data_type, timestamp, seq_number = NDTPHeader.STRUCT.unpack(bytes(data[:expected_size]))
if version != NDTP_VERSION:
raise ValueError(
f"Incompatible version {version}: expected {hex(NDTP_VERSION)}, got {hex(version)}"
"Incompatible version " + str(version) + ": expected " + hex(NDTP_VERSION) + ", got " + hex(version)
)

return NDTPHeader(data_type, timestamp, seq_number)
Expand Down Expand Up @@ -566,10 +581,10 @@ cdef class NDTPMessage:
elif pdtype == DataType.kSpiketrain:
payload = NDTPPayloadSpiketrain.unpack(pbytes)
else:
raise ValueError(f"unknown data type {pdtype}")
raise ValueError("unknown data type " + str(pdtype))

if not NDTPMessage.crc16_verify(data[:-2], crc16_value):
raise ValueError(f"CRC16 verification failed (expected {crc16_value})")
raise ValueError("CRC16 verification failed (expected " + str(crc16_value) + ")")

msg = NDTPMessage(header, payload)
msg._crc16 = crc16_value
Expand Down
13 changes: 9 additions & 4 deletions synapse/utils/ndtp_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ def to_list(self):


class SpiketrainData:
__slots__ = ["data_type", "t0", "spike_counts"]
__slots__ = ["data_type", "t0", "bin_size_ms", "spike_counts"]

def __init__(self, t0, spike_counts):
def __init__(self, t0, bin_size_ms, spike_counts):
self.data_type = DataType.kSpiketrain
self.t0 = t0
self.bin_size_ms = bin_size_ms
self.spike_counts = spike_counts

def pack(self, seq_number: int):
Expand All @@ -114,7 +115,10 @@ def pack(self, seq_number: int):
timestamp=self.t0,
seq_number=seq_number,
),
payload=NDTPPayloadSpiketrain(spike_counts=self.spike_counts),
payload=NDTPPayloadSpiketrain(
bin_size_ms=self.bin_size_ms,
spike_counts=self.spike_counts
),
)

return [message.pack()]
Expand All @@ -123,6 +127,7 @@ def pack(self, seq_number: int):
def from_ndtp_message(msg: NDTPMessage):
return SpiketrainData(
t0=msg.header.timestamp,
bin_size_ms=msg.payload.bin_size_ms,
spike_counts=msg.payload.spike_counts,
)

Expand All @@ -132,7 +137,7 @@ def unpack(data):
return SpiketrainData.from_ndtp_message(u)

def to_list(self):
return [self.t0, list(self.spike_counts)]
return [self.t0, self.bin_size_ms, list(self.spike_counts)]


SynapseData = Union[SpiketrainData, ElectricalBroadbandData]
Loading