Skip to content

Commit

Permalink
Simplify bit packing
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Dec 3, 2024
1 parent 234971f commit 91a727b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
13 changes: 6 additions & 7 deletions test/Ne16MemoryLayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,14 @@ def weightEncode(
# (cout, cinMajor, Bits, flattened spatial, cinMinor)
weight = weight.transpose(0, 1, 4, 3, 2)

# Prepare for packing
# (cout, cinMajor, Bits, flattened spatial, cinMinorBytes, 8)
cinMinorBytes = int(np.ceil(cinMinor / 8))
weight = np.stack(np.split(weight, cinMinorBytes, axis=-1), axis=-2)

# Pack
# (cout, cinMajor, Bits, flattened spatial, cinMinorBytes)
# Pack bits
# (-1, 8)
weight = weight.reshape(-1, 8)
# (-1, 1)
weight = np.packbits(weight, axis=-1, bitorder="little")

# Flatten the weights
# (-1, )
return weight.flatten()

@staticmethod
Expand Down
11 changes: 5 additions & 6 deletions test/NeurekaMemoryLayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,14 @@ def weightEncode(
cout * cinMajor, NeurekaMemoryLayout._WEIGHT_BANDWIDTH
) # cout*cinMajor, 256b

# Prepare for packing
# (-1, Weight Bandwidth Bytes, 8)
weightBandwidthBytes = int(np.ceil(NeurekaMemoryLayout._WEIGHT_BANDWIDTH / 8))
weight = np.stack(np.split(weight, weightBandwidthBytes, axis=-1), axis=-2)

# Pack bits
# (-1, Weight Bandwidth Bytes)
# (-1, 8)
weight = weight.reshape(-1, 8)
# (-1, 1)
weight = np.packbits(weight, axis=-1, bitorder="little")

# Flatten the weights
# (-1, )
return weight.flatten()

@staticmethod
Expand Down

0 comments on commit 91a727b

Please sign in to comment.