Skip to content

Commit

Permalink
Move hw params to pytest product
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Nov 19, 2023
1 parent 9349f3a commit 0511986
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions run/param_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import os
import pytest
import itertools
from qkeras import *
from tensorflow.keras.layers import Input
import sys
Expand All @@ -20,41 +22,45 @@
keras.utils.set_random_seed(0)
type_d = { 'np': {8: np.int8, 16: np.int16, 32: np.int32, 64: np.int64} }

'''
0. Specify Hardware
'''
hw = Hardware (
processing_elements = (8,24),
frequency_mhz = 250,
bits_input = 4,
bits_weights = 4,
bits_sum = 24,
bits_bias = 16,
max_batch_size = 64,
max_channels_in = 2048,
max_kernel_size = (13, 13),
max_image_size = (512,512),
ram_weights_depth = 20,
ram_edges_depth = 288,
axi_width = 64,
target_cpu_int_bits = 32,
valid_prob = 0.1,
ready_prob = 0.1,
data_dir = 'vectors',
)
hw.export_json()
hw = Hardware.from_json('hardware.json')
hw.export() # Generates: config_hw.svh, config_hw.tcl
hw.export_vivado_tcl(board='zcu104')


def test_dnn_engine():
def product_dict(**kwargs):
for instance in itertools.product(*(kwargs.values())):
yield dict(zip(kwargs.keys(), instance))

@pytest.mark.parametrize("PARAMS", list(product_dict(
processing_elements = [(8,24) ],
frequency_mhz = [ 250 ],
bits_input = [ 4 ],
bits_weights = [ 4 ],
bits_sum = [ 16 ],
bits_bias = [ 16 ],
max_batch_size = [ 64 ],
max_channels_in = [ 2048 ],
max_kernel_size = [ 13 ],
max_image_size = [ 512 ],
ram_weights_depth = [ 20 ],
ram_edges_depth = [ 288 ],
axi_width = [ 64 ],
target_cpu_int_bits = [ 32 ],
valid_prob = [ 0.1 ],
ready_prob = [ 0.1 ],
data_dir = ['vectors'],
)))
def test_dnn_engine(PARAMS):
'''
0. Specify Hardware
'''
hw = Hardware (**PARAMS)
hw.export_json()
hw = Hardware.from_json('hardware.json')
hw.export() # Generates: config_hw.svh, config_hw.tcl
hw.export_vivado_tcl(board='zcu104')


xq, kq, bq = f'quantized_bits({hw.X_BITS},0,False,True,1)', f'quantized_bits({hw.K_BITS},0,False,True,1)', f'quantized_bits({hw.B_BITS},0,False,True,1)'
inp = {'bits':hw.X_BITS, 'frac':hw.X_BITS-1}

'''
Build Model
1. Build Model
'''
input_shape = (8,18,18,3) # (XN, XH, XW, CI)
x = x_in = Input(input_shape[1:], name='input')
Expand Down

0 comments on commit 0511986

Please sign in to comment.