Skip to content

Commit

Permalink
fix tests? 😬
Browse files Browse the repository at this point in the history
  • Loading branch information
emmazhou committed Sep 25, 2024
1 parent 5148712 commit 5b5a0d1
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 91 deletions.
1 change: 1 addition & 0 deletions synapse/server/nodes/spectral_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def run(self):

filtered_data = ElectricalBroadbandData(
bit_width=data.bit_width,
signed=data.signed,
sample_rate=data.sample_rate,
t0=data.t0,
channels=self.apply_filter(data.channels),
Expand Down
1 change: 1 addition & 0 deletions synapse/simulator/nodes/electrical_broadband.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def run(self):

data = ElectricalBroadbandData(
bit_width=bit_width,
signed=True,
sample_rate=sample_rate,
t0=t0,
channels=[
Expand Down
123 changes: 65 additions & 58 deletions synapse/tests/test_ndtp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import math
import struct
from synapse.api.datatype_pb2 import DataType
from synapse.utils.ndtp import (
Expand All @@ -14,40 +13,45 @@


def test_to_bytes():
assert to_bytes([1, 2, 3, 0], bit_width=2) == (bytearray(b'\x6C'), 0)
assert to_bytes([1, 2, 3, 0], bit_width=2) == (bytearray(b"\x6C"), 0)

assert to_bytes([1, 2, 3, 2, 1], bit_width=2) == (bytearray(b'\x6E\x40'), 2)
assert to_bytes([1, 2, 3, 2, 1], bit_width=2) == (bytearray(b"\x6E\x40"), 2)

assert to_bytes([7, 5, 3, 1], bit_width=12) == (bytearray(b'\x00\x70\x05\x00\x30\x01'), 0)
assert to_bytes([7, 5, 3, 1], bit_width=12) == (
bytearray(b"\x00\x70\x05\x00\x30\x01"),
0,
)

assert to_bytes([-7, -5, -3, -1], bit_width=12, signed=True) == (
bytearray(b"\xFF\x9F\xFB\xFF\xDF\xFF"),
0,
)

assert to_bytes([-7, -5, -3, -1], bit_width=12, signed=True) == (bytearray(b'\xFF\x9F\xFB\xFF\xDF\xFF'), 0)

assert to_bytes(
[7, 5, 3],
bit_width=12,
existing=bytearray(b'\x01\x00'),
writing_bit_offset=4
) == (bytearray(b'\x01\x00\x07\x00\x50\x03'), 0)

[7, 5, 3], bit_width=12, existing=bytearray(b"\x01\x00"), writing_bit_offset=4
) == (bytearray(b"\x01\x00\x07\x00\x50\x03"), 0)

assert to_bytes(
[-7, -5, -3],
bit_width=12,
existing=bytearray(b'\x01\x00'),
existing=bytearray(b"\x01\x00"),
writing_bit_offset=4,
signed=True
) == (bytearray(b'\x01\x0F\xF9\xFF\xBF\xFD'), 0)
signed=True,
) == (bytearray(b"\x01\x0F\xF9\xFF\xBF\xFD"), 0)

assert to_bytes([7, 5, 3], bit_width=12) == (bytearray(b'\x00p\x05\x000'), 4)
assert to_bytes([7, 5, 3], bit_width=12) == (bytearray(b"\x00p\x05\x000"), 4)

assert to_bytes([1, 2, 3, 4], bit_width=8) == (bytearray(b'\x01\x02\x03\x04'), 0)
assert to_bytes([1, 2, 3, 4], bit_width=8) == (bytearray(b"\x01\x02\x03\x04"), 0)

res, offset = to_bytes([7, 5, 3], bit_width=12)
assert res == bytearray(b'\x00p\x05\x000')
assert res == bytearray(b"\x00p\x05\x000")
assert len(res) == 5
assert offset == 4

res, offset = to_bytes([3, 5, 7], bit_width=12, existing=res, writing_bit_offset=offset)
assert res == bytearray(b'\x00\x70\x05\x00\x30\x03\x00\x50\x07')
res, offset = to_bytes(
[3, 5, 7], bit_width=12, existing=res, writing_bit_offset=offset
)
assert res == bytearray(b"\x00\x70\x05\x00\x30\x03\x00\x50\x07")
assert len(res) == 9
assert offset == 0

Expand All @@ -59,32 +63,33 @@ def test_to_bytes():
with pytest.raises(ValueError):
to_bytes([1, 2, 3, 0], 0)


def test_to_ints():
res, offset, _ = to_ints(b'\x6C', 2)
res, offset, _ = to_ints(b"\x6C", 2)
assert res == [1, 2, 3, 0]
assert offset == 8
res, offset, _ = to_ints(b'\x6C', 2, 3)

res, offset, _ = to_ints(b"\x6C", 2, 3)
assert res == [1, 2, 3]
assert offset == 6
res, offset, _ = to_ints(b'\x00\x70\x05\x00\x30\x01', 12)

res, offset, _ = to_ints(b"\x00\x70\x05\x00\x30\x01", 12)
assert res == [7, 5, 3, 1]
assert offset == 48

res, offset, _ = to_ints(b'\x6C', 2, 3, 2)
res, offset, _ = to_ints(b"\x6C", 2, 3, 2)
assert res == [2, 3, 0]
assert offset == 6 + 2

res, offset, _ = to_ints(b'\x00\x07\x00\x50\x03', 12, 3, 4)
res, offset, _ = to_ints(b"\x00\x07\x00\x50\x03", 12, 3, 4)
assert res == [7, 5, 3]
assert offset == 36 + 4

res, offset, _ = to_ints(b'\xFF\xF9\xFF\xBF\xFD', 12, 3, 4, signed=True)
res, offset, _ = to_ints(b"\xFF\xF9\xFF\xBF\xFD", 12, 3, 4, signed=True)
assert res == [-7, -5, -3]
assert offset == 36 + 4

arry = bytearray(b'\x6E\x40')
arry = bytearray(b"\x6E\x40")
res, offset, arry = to_ints(arry, 2, 1)
assert res == [1]
assert offset == 2
Expand All @@ -101,23 +106,23 @@ def test_to_ints():
assert res == [2]
assert offset == 8


# Invalid bit width
with pytest.raises(ValueError):
to_ints(b'\x01', 0)
to_ints(b"\x01", 0)

# Incomplete value
with pytest.raises(ValueError):
to_ints(b'\x01', 3)
to_ints(b"\x01", 3)

# Insufficient data
with pytest.raises(ValueError):
to_ints(b'\x01\x02', 3)
to_ints(b"\x01\x02", 3)


def test_ndtp_payload_broadband():
bit_width = 12
sample_rate = 3,
signed = False,
sample_rate = 3
signed = False
channels = [
NDTPPayloadBroadband.ChannelData(
channel_id=0,
Expand All @@ -130,14 +135,15 @@ def test_ndtp_payload_broadband():
NDTPPayloadBroadband.ChannelData(
channel_id=2,
channel_data=[3000, 2000, 1000],
)
),
]

payload = NDTPPayloadBroadband(signed, bit_width, sample_rate, channels)
p = payload.pack()

u = NDTPPayloadBroadband.unpack(p)
assert u.bit_width == bit_width
assert u.signed == signed
assert len(u.channels) == 3

assert u.channels[0].channel_id == 0
Expand All @@ -149,11 +155,11 @@ def test_ndtp_payload_broadband():
assert u.channels[2].channel_id == 2
assert u.channels[2].channel_data == [3000, 2000, 1000]

assert p[0] == bit_width
assert p[0] >> 1 == bit_width

assert (p[1] << 16) | (p[2] << 8) | p[3] == 3
p = p[4:]
p = p[6:]

unpacked, offset, p = to_ints(p, bit_width=24, count=1)
assert unpacked[0] == 0
assert offset == 24
Expand Down Expand Up @@ -189,7 +195,7 @@ def test_ndtp_payload_broadband():
unpacked, offset, p = to_ints(p, bit_width=bit_width, count=3, start_bit=offset)
assert unpacked == [3000, 2000, 1000]
assert offset == 36


def test_ndtp_payload_spiketrain():
samples = [0, 1, 2, 3, 2]
Expand All @@ -206,6 +212,7 @@ def test_ndtp_payload_spiketrain():
with pytest.raises(ValueError):
payload.pack()


def test_ndtp_header():
header = NDTPHeader(DataType.kBroadband, 1234567890, 42)
packed = header.pack()
Expand All @@ -214,16 +221,17 @@ def test_ndtp_header():

# Invalid version
with pytest.raises(ValueError):
NDTPHeader.unpack(b'\x00' + packed[1:])
NDTPHeader.unpack(b"\x00" + packed[1:])

# Data too smol
with pytest.raises(ValueError):
NDTPHeader.unpack(
struct.pack("<B", NDTP_VERSION) +
struct.pack("<I", DataType.kBroadband) +
struct.pack("<Q", 123)
struct.pack("<B", NDTP_VERSION)
+ struct.pack("<I", DataType.kBroadband)
+ struct.pack("<Q", 123)
)


def test_ndtp_message():
header = NDTPHeader(DataType.kBroadband, timestamp=1234567890, seq_number=42)
payload = NDTPPayloadBroadband(
Expand All @@ -233,24 +241,22 @@ def test_ndtp_message():
channels=[
NDTPPayloadBroadband.ChannelData(
channel_id=c,
channel_data=[c*100 for _ in range(c+1)],
) for c in range(3)
]
channel_data=[c * 100 for _ in range(c + 1)],
)
for c in range(3)
],
)
message = NDTPMessage(header, payload)

packed = message.pack()
unpacked = NDTPMessage.unpack(packed)

assert unpacked.header == message.header
assert isinstance(unpacked.payload, NDTPPayloadBroadband)
assert unpacked.payload == message.payload


header = NDTPHeader(DataType.kSpiketrain, timestamp=1234567890, seq_number=42)
payload = NDTPPayloadSpiketrain(
spike_counts=[1, 2, 3, 2, 1]
)
payload = NDTPPayloadSpiketrain(spike_counts=[1, 2, 3, 2, 1])
message = NDTPMessage(header, payload)

packed = message.pack()
Expand All @@ -261,7 +267,8 @@ def test_ndtp_message():
assert unpacked.payload == message.payload

with pytest.raises(ValueError):
NDTPMessage.unpack(b'\x00' * (NDTPHeader.STRUCT.size + 8)) # Invalid data type
NDTPMessage.unpack(b"\x00" * (NDTPHeader.STRUCT.size + 8)) # Invalid data type


if __name__ == "__main__":
pytest.main()
pytest.main()
50 changes: 26 additions & 24 deletions synapse/tests/test_stream_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,27 @@ def test_packing_broadband_data():
node = StreamOut(id=1)

bdata = ElectricalBroadbandData(
bit_width = 12,
t0 = 1234567890,
sample_rate=3,
channels = [
ElectricalBroadbandData.ChannelData(
channel_id=0,
channel_data=[1000, 2000, 3000],
),
ElectricalBroadbandData.ChannelData(
channel_id=1,
channel_data=[1234, 1234, 1234, 1234],
),
ElectricalBroadbandData.ChannelData(
channel_id=2,
channel_data=[3000, 2000, 1000, 2000, 3000],
)
]
bit_width=12,
signed=True,
t0=1234567890,
sample_rate=3,
channels=[
ElectricalBroadbandData.ChannelData(
channel_id=0,
channel_data=[1000, 2000, 3000],
),
ElectricalBroadbandData.ChannelData(
channel_id=1,
channel_data=[1234, 1234, 1234, 1234],
),
ElectricalBroadbandData.ChannelData(
channel_id=2,
channel_data=[3000, 2000, 1000, 2000, 3000],
),
],
)

packed = node._pack(DataType.kBroadband, bdata)[0]
packed = node._pack(bdata)[0]
unpacked = NDTPMessage.unpack(packed)

assert unpacked.header.timestamp == bdata.t0
Expand All @@ -37,22 +38,23 @@ def test_packing_broadband_data():

for i in range(len(bdata.channels)):
assert unpacked.payload.channels[i].channel_id == bdata.channels[i].channel_id
assert unpacked.payload.channels[i].channel_data == bdata.channels[i].channel_data

assert (
unpacked.payload.channels[i].channel_data == bdata.channels[i].channel_data
)


def test_packing_spiketrain_data():
node = StreamOut(id=1)

sdata = SpiketrainData(
t0 = 1234567890,
spike_counts = [0, 1, 2, 3, 2, 1, 0],
t0=1234567890,
spike_counts=[0, 1, 2, 3, 2, 1, 0],
)

packed = node._pack(DataType.kSpiketrain, sdata)[0]
packed = node._pack(sdata)[0]
unpacked = NDTPMessage.unpack(packed)

assert unpacked.header.timestamp == sdata.t0
assert len(unpacked.payload.spike_counts) == len(sdata.spike_counts)

assert unpacked.payload.spike_counts == sdata.spike_counts

Loading

0 comments on commit 5b5a0d1

Please sign in to comment.