diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 678e6d49af..709c3db3ff 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -763,32 +763,24 @@ def predict(self, x): n_inputs = len(self.get_input_variables()) n_outputs = len(self.get_output_variables()) - curr_dir = os.getcwd() - os.chdir(self.config.get_output_dir() + '/firmware') - output = [] if n_samples == 1 and n_inputs == 1: x = [x] - try: - for i in range(n_samples): - predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()] - if n_inputs == 1: - inp = [np.asarray(x[i])] - else: - inp = [np.asarray(xj[i]) for xj in x] - argtuple = inp - argtuple += predictions - argtuple = tuple(argtuple) - top_function(*argtuple) - output.append(predictions) - - # Convert to list of numpy arrays (one for each output) - output = [ - np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs) - ] - finally: - os.chdir(curr_dir) + for i in range(n_samples): + predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()] + if n_inputs == 1: + inp = [np.asarray(x[i])] + else: + inp = [np.asarray(xj[i]) for xj in x] + argtuple = inp + argtuple += predictions + argtuple = tuple(argtuple) + top_function(*argtuple) + output.append(predictions) + + # Convert to list of numpy arrays (one for each output) + output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)] if n_samples == 1 and n_outputs == 1: return output[0][0] diff --git a/hls4ml/templates/catapult/myproject_bridge.cpp b/hls4ml/templates/catapult/myproject_bridge.cpp index f1326a1faf..9937adcf89 100755 --- a/hls4ml/templates/catapult/myproject_bridge.cpp +++ b/hls4ml/templates/catapult/myproject_bridge.cpp @@ -6,7 +6,7 @@ #include #include -static std::string s_weights_dir = "weights"; +// hls-fpga-machine-learning insert weights dir const char *get_weights_dir() { return s_weights_dir.c_str(); } diff --git a/hls4ml/templates/vivado/build_lib.sh b/hls4ml/templates/vivado/build_lib.sh index 8b2daf185f..f5f2431ee4 100755 --- a/hls4ml/templates/vivado/build_lib.sh +++ b/hls4ml/templates/vivado/build_lib.sh @@ -11,7 +11,8 @@ LDFLAGS= INCFLAGS="-Ifirmware/ap_types/" PROJECT=myproject LIB_STAMP=mystamp -WEIGHTS_DIR="\"weights\"" +BASEDIR="$(cd "$(dirname "$0")" && pwd)" +WEIGHTS_DIR="\"${BASEDIR}/firmware/weights\"" ${CC} ${CFLAGS} ${INCFLAGS} -D WEIGHTS_DIR=${WEIGHTS_DIR} -c firmware/${PROJECT}.cpp -o ${PROJECT}.o ${CC} ${CFLAGS} ${INCFLAGS} -D WEIGHTS_DIR=${WEIGHTS_DIR} -c ${PROJECT}_bridge.cpp -o ${PROJECT}_bridge.o diff --git a/hls4ml/writer/catapult_writer.py b/hls4ml/writer/catapult_writer.py index 396ecb968e..7db1063206 100755 --- a/hls4ml/writer/catapult_writer.py +++ b/hls4ml/writer/catapult_writer.py @@ -676,6 +676,9 @@ def write_bridge(self, model): newline = line.replace('MYPROJECT', format(model.config.get_project_name().upper())) elif 'myproject' in line: newline = line.replace('myproject', format(model.config.get_project_name())) + elif '// hls-fpga-machine-learning insert weights dir' in line: + weights_dir = (Path(fout.name).parent / 'firmware/weights').resolve() + newline = f'static std::string s_weights_dir = "{weights_dir}";\n' elif '// hls-fpga-machine-learning insert bram' in line: newline = line for bram in model_brams: