From 5b5a0d1058bd6780582abf3780c195c7136703fc Mon Sep 17 00:00:00 2001 From: Emma Zhou Date: Tue, 24 Sep 2024 18:28:23 -0700 Subject: [PATCH] =?UTF-8?q?fix=20tests=3F=20=F0=9F=98=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- synapse/server/nodes/spectral_filter.py | 1 + .../simulator/nodes/electrical_broadband.py | 1 + synapse/tests/test_ndtp.py | 123 +++++++++--------- synapse/tests/test_stream_out.py | 50 +++---- synapse/utils/ndtp.py | 28 ++-- synapse/utils/types.py | 3 + 6 files changed, 115 insertions(+), 91 deletions(-) diff --git a/synapse/server/nodes/spectral_filter.py b/synapse/server/nodes/spectral_filter.py index d9e9c33..5dfe246 100644 --- a/synapse/server/nodes/spectral_filter.py +++ b/synapse/server/nodes/spectral_filter.py @@ -90,6 +90,7 @@ def run(self): filtered_data = ElectricalBroadbandData( bit_width=data.bit_width, + signed=data.signed, sample_rate=data.sample_rate, t0=data.t0, channels=self.apply_filter(data.channels), diff --git a/synapse/simulator/nodes/electrical_broadband.py b/synapse/simulator/nodes/electrical_broadband.py index 3e17465..9f9742d 100644 --- a/synapse/simulator/nodes/electrical_broadband.py +++ b/synapse/simulator/nodes/electrical_broadband.py @@ -47,6 +47,7 @@ def run(self): data = ElectricalBroadbandData( bit_width=bit_width, + signed=True, sample_rate=sample_rate, t0=t0, channels=[ diff --git a/synapse/tests/test_ndtp.py b/synapse/tests/test_ndtp.py index 56ff950..29f5be2 100644 --- a/synapse/tests/test_ndtp.py +++ b/synapse/tests/test_ndtp.py @@ -1,5 +1,4 @@ import pytest -import math import struct from synapse.api.datatype_pb2 import DataType from synapse.utils.ndtp import ( @@ -14,40 +13,45 @@ def test_to_bytes(): - assert to_bytes([1, 2, 3, 0], bit_width=2) == (bytearray(b'\x6C'), 0) + assert to_bytes([1, 2, 3, 0], bit_width=2) == (bytearray(b"\x6C"), 0) - assert to_bytes([1, 2, 3, 2, 1], bit_width=2) == (bytearray(b'\x6E\x40'), 2) + assert to_bytes([1, 2, 3, 2, 1], bit_width=2) == (bytearray(b"\x6E\x40"), 2) - assert to_bytes([7, 5, 3, 1], bit_width=12) == (bytearray(b'\x00\x70\x05\x00\x30\x01'), 0) + assert to_bytes([7, 5, 3, 1], bit_width=12) == ( + bytearray(b"\x00\x70\x05\x00\x30\x01"), + 0, + ) + + assert to_bytes([-7, -5, -3, -1], bit_width=12, signed=True) == ( + bytearray(b"\xFF\x9F\xFB\xFF\xDF\xFF"), + 0, + ) - assert to_bytes([-7, -5, -3, -1], bit_width=12, signed=True) == (bytearray(b'\xFF\x9F\xFB\xFF\xDF\xFF'), 0) - assert to_bytes( - [7, 5, 3], - bit_width=12, - existing=bytearray(b'\x01\x00'), - writing_bit_offset=4 - ) == (bytearray(b'\x01\x00\x07\x00\x50\x03'), 0) - + [7, 5, 3], bit_width=12, existing=bytearray(b"\x01\x00"), writing_bit_offset=4 + ) == (bytearray(b"\x01\x00\x07\x00\x50\x03"), 0) + assert to_bytes( [-7, -5, -3], bit_width=12, - existing=bytearray(b'\x01\x00'), + existing=bytearray(b"\x01\x00"), writing_bit_offset=4, - signed=True - ) == (bytearray(b'\x01\x0F\xF9\xFF\xBF\xFD'), 0) + signed=True, + ) == (bytearray(b"\x01\x0F\xF9\xFF\xBF\xFD"), 0) - assert to_bytes([7, 5, 3], bit_width=12) == (bytearray(b'\x00p\x05\x000'), 4) + assert to_bytes([7, 5, 3], bit_width=12) == (bytearray(b"\x00p\x05\x000"), 4) - assert to_bytes([1, 2, 3, 4], bit_width=8) == (bytearray(b'\x01\x02\x03\x04'), 0) + assert to_bytes([1, 2, 3, 4], bit_width=8) == (bytearray(b"\x01\x02\x03\x04"), 0) res, offset = to_bytes([7, 5, 3], bit_width=12) - assert res == bytearray(b'\x00p\x05\x000') + assert res == bytearray(b"\x00p\x05\x000") assert len(res) == 5 assert offset == 4 - res, offset = to_bytes([3, 5, 7], bit_width=12, existing=res, writing_bit_offset=offset) - assert res == bytearray(b'\x00\x70\x05\x00\x30\x03\x00\x50\x07') + res, offset = to_bytes( + [3, 5, 7], bit_width=12, existing=res, writing_bit_offset=offset + ) + assert res == bytearray(b"\x00\x70\x05\x00\x30\x03\x00\x50\x07") assert len(res) == 9 assert offset == 0 @@ -59,32 +63,33 @@ def test_to_bytes(): with pytest.raises(ValueError): to_bytes([1, 2, 3, 0], 0) + def test_to_ints(): - res, offset, _ = to_ints(b'\x6C', 2) + res, offset, _ = to_ints(b"\x6C", 2) assert res == [1, 2, 3, 0] assert offset == 8 - - res, offset, _ = to_ints(b'\x6C', 2, 3) + + res, offset, _ = to_ints(b"\x6C", 2, 3) assert res == [1, 2, 3] assert offset == 6 - - res, offset, _ = to_ints(b'\x00\x70\x05\x00\x30\x01', 12) + + res, offset, _ = to_ints(b"\x00\x70\x05\x00\x30\x01", 12) assert res == [7, 5, 3, 1] assert offset == 48 - res, offset, _ = to_ints(b'\x6C', 2, 3, 2) + res, offset, _ = to_ints(b"\x6C", 2, 3, 2) assert res == [2, 3, 0] assert offset == 6 + 2 - res, offset, _ = to_ints(b'\x00\x07\x00\x50\x03', 12, 3, 4) + res, offset, _ = to_ints(b"\x00\x07\x00\x50\x03", 12, 3, 4) assert res == [7, 5, 3] assert offset == 36 + 4 - res, offset, _ = to_ints(b'\xFF\xF9\xFF\xBF\xFD', 12, 3, 4, signed=True) + res, offset, _ = to_ints(b"\xFF\xF9\xFF\xBF\xFD", 12, 3, 4, signed=True) assert res == [-7, -5, -3] assert offset == 36 + 4 - arry = bytearray(b'\x6E\x40') + arry = bytearray(b"\x6E\x40") res, offset, arry = to_ints(arry, 2, 1) assert res == [1] assert offset == 2 @@ -101,23 +106,23 @@ def test_to_ints(): assert res == [2] assert offset == 8 - # Invalid bit width with pytest.raises(ValueError): - to_ints(b'\x01', 0) - + to_ints(b"\x01", 0) + # Incomplete value with pytest.raises(ValueError): - to_ints(b'\x01', 3) - + to_ints(b"\x01", 3) + # Insufficient data with pytest.raises(ValueError): - to_ints(b'\x01\x02', 3) + to_ints(b"\x01\x02", 3) + def test_ndtp_payload_broadband(): bit_width = 12 - sample_rate = 3, - signed = False, + sample_rate = 3 + signed = False channels = [ NDTPPayloadBroadband.ChannelData( channel_id=0, @@ -130,7 +135,7 @@ def test_ndtp_payload_broadband(): NDTPPayloadBroadband.ChannelData( channel_id=2, channel_data=[3000, 2000, 1000], - ) + ), ] payload = NDTPPayloadBroadband(signed, bit_width, sample_rate, channels) @@ -138,6 +143,7 @@ def test_ndtp_payload_broadband(): u = NDTPPayloadBroadband.unpack(p) assert u.bit_width == bit_width + assert u.signed == signed assert len(u.channels) == 3 assert u.channels[0].channel_id == 0 @@ -149,11 +155,11 @@ def test_ndtp_payload_broadband(): assert u.channels[2].channel_id == 2 assert u.channels[2].channel_data == [3000, 2000, 1000] - assert p[0] == bit_width - + assert p[0] >> 1 == bit_width + assert (p[1] << 16) | (p[2] << 8) | p[3] == 3 - p = p[4:] - + p = p[6:] + unpacked, offset, p = to_ints(p, bit_width=24, count=1) assert unpacked[0] == 0 assert offset == 24 @@ -189,7 +195,7 @@ def test_ndtp_payload_broadband(): unpacked, offset, p = to_ints(p, bit_width=bit_width, count=3, start_bit=offset) assert unpacked == [3000, 2000, 1000] assert offset == 36 - + def test_ndtp_payload_spiketrain(): samples = [0, 1, 2, 3, 2] @@ -206,6 +212,7 @@ def test_ndtp_payload_spiketrain(): with pytest.raises(ValueError): payload.pack() + def test_ndtp_header(): header = NDTPHeader(DataType.kBroadband, 1234567890, 42) packed = header.pack() @@ -214,16 +221,17 @@ def test_ndtp_header(): # Invalid version with pytest.raises(ValueError): - NDTPHeader.unpack(b'\x00' + packed[1:]) + NDTPHeader.unpack(b"\x00" + packed[1:]) # Data too smol with pytest.raises(ValueError): NDTPHeader.unpack( - struct.pack(" str: Can append to an existing byte array, and will correctly handle the case where the end of the existing array is not byte aligned (and may contain a partial byte at the end). """ + + def to_bytes( values: List[int], bit_width: int, @@ -63,7 +65,9 @@ def to_bytes( min_value = -(1 << (bit_width - 1)) max_value = (1 << (bit_width - 1)) - 1 if value < min_value or value > max_value: - raise ValueError(f"signed value {value} doesn't fit in {bit_width} bits") + raise ValueError( + f"signed value {value} doesn't fit in {bit_width} bits" + ) # Convert to two's complement representation if value < 0: value = (1 << bit_width) + value @@ -72,7 +76,9 @@ def to_bytes( raise ValueError("unsigned packing specified, but value is negative") if value >= (1 << bit_width): - raise ValueError(f"unsigned value {value} doesn't fit in {bit_width} bits") + raise ValueError( + f"unsigned value {value} doesn't fit in {bit_width} bits" + ) remaining_bits = bit_width while remaining_bits > 0: @@ -111,9 +117,13 @@ def to_bytes( import math + def to_ints( - data: bytes, bit_width: int, count: int = 0, start_bit: int = 0, - signed: bool = False + data: bytes, + bit_width: int, + count: int = 0, + start_bit: int = 0, + signed: bool = False, ) -> Tuple[List[int], int, bytes]: if bit_width <= 0: raise ValueError("bit width must be > 0") @@ -193,7 +203,9 @@ def pack(self): # first bit of the payload is the signed bool # remaining 7 bits are the bit width - payload += struct.pack("> 1 signed = (struct.unpack(" bytes: ), payload=NDTPPayloadBroadband( bit_width=self.bit_width, + signed=self.signed, sample_rate=self.sample_rate, channels=[ NDTPPayloadBroadband.ChannelData( @@ -56,6 +58,7 @@ def from_ndtp_message(msg: NDTPMessage) -> "ElectricalBroadbandData": return ElectricalBroadbandData( t0=msg.header.timestamp, bit_width=msg.payload.bit_width, + signed=msg.payload.signed, sample_rate=msg.payload.sample_rate, channels=[ ElectricalBroadbandData.ChannelData(