Add hint on import failure
calad0i committed Dec 16, 2024
1 parent 014c1db commit d3c8881
Showing 8 changed files with 34 additions and 39 deletions.
47 changes: 15 additions & 32 deletions hls4ml/converters/__init__.py
@@ -1,6 +1,5 @@
 import importlib
 import os
-import warnings

 import yaml

@@ -10,33 +9,19 @@
 from hls4ml.converters.keras_to_hls import get_supported_keras_layers  # noqa: F401
 from hls4ml.converters.keras_to_hls import parse_keras_model  # noqa: F401
 from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler
+from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers  # noqa: F401
+from hls4ml.converters.onnx_to_hls import parse_onnx_model  # noqa: F401
+from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler
+from hls4ml.converters.pytorch_to_hls import (  # noqa: F401
+    get_supported_pytorch_layers,
+    pytorch_to_hls,
+    register_pytorch_layer_handler,
+)
 from hls4ml.model import ModelGraph
 from hls4ml.utils.config import create_config
+from hls4ml.utils.dependency import requires
 from hls4ml.utils.symbolic_utils import LUTFunction

-# ----------Make converters available if the libraries can be imported----------#
-try:
-    from hls4ml.converters.pytorch_to_hls import (  # noqa: F401
-        get_supported_pytorch_layers,
-        pytorch_to_hls,
-        register_pytorch_layer_handler,
-    )
-
-    __pytorch_enabled__ = True
-except ImportError:
-    warnings.warn("WARNING: Pytorch converter is not enabled!", stacklevel=1)
-    __pytorch_enabled__ = False
-
-try:
-    from hls4ml.converters.onnx_to_hls import get_supported_onnx_layers  # noqa: F401
-    from hls4ml.converters.onnx_to_hls import onnx_to_hls, register_onnx_layer_handler
-
-    __onnx_enabled__ = True
-except ImportError:
-    warnings.warn("WARNING: ONNX converter is not enabled!", stacklevel=1)
-    __onnx_enabled__ = False
-
 # ----------Layer handling register----------#
 model_types = ['keras', 'pytorch', 'onnx']

@@ -51,7 +36,7 @@
             # and has 'handles' attribute
             # and is defined in this module (i.e., not imported)
             if callable(func) and hasattr(func, 'handles') and func.__module__ == lib.__name__:
-                for layer in func.handles:
+                for layer in func.handles:  # type: ignore
                     if model_type == 'keras':
                         register_keras_layer_handler(layer, func)
                     elif model_type == 'pytorch':

@@ -124,15 +109,9 @@ def convert_from_config(config):

     model = None
     if 'OnnxModel' in yamlConfig:
-        if __onnx_enabled__:
-            model = onnx_to_hls(yamlConfig)
-        else:
-            raise Exception("ONNX not found. Please install ONNX.")
+        model = onnx_to_hls(yamlConfig)
     elif 'PytorchModel' in yamlConfig:
-        if __pytorch_enabled__:
-            model = pytorch_to_hls(yamlConfig)
-        else:
-            raise Exception("PyTorch not found. Please install PyTorch.")
+        model = pytorch_to_hls(yamlConfig)
     else:
         model = keras_to_hls(yamlConfig)

@@ -174,6 +153,7 @@ def _check_model_config(model_config):
     return model_config


+@requires('_keras')
 def convert_from_keras_model(
     model,
     output_dir='my-hls-test',

@@ -237,6 +217,7 @@ def convert_from_keras_model(
     return keras_to_hls(config)


+@requires('_torch')
 def convert_from_pytorch_model(
     model,
     output_dir='my-hls-test',

@@ -308,6 +289,7 @@ def convert_from_pytorch_model(
     return pytorch_to_hls(config)


+@requires('onnx')
 def convert_from_onnx_model(
     model,
     output_dir='my-hls-test',

@@ -371,6 +353,7 @@ def convert_from_onnx_model(
     return onnx_to_hls(config)


+@requires('sr')
 def convert_from_symbolic_expression(
     expr,
     n_symbols=None,
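
The @requires(...) decorators used throughout this commit come from hls4ml.utils.dependency, which already exists in the repository and is not touched by this diff. As a rough, hedged sketch only — the real implementation, and the mapping from names such as '_torch', 'sr' or 'quantus-report' to installable packages, may differ — a requires-style decorator could look like this:

# Hypothetical sketch of a requires-style decorator; NOT the actual
# hls4ml.utils.dependency.requires, which is not shown in this diff.
# It defers the dependency check until the wrapped callable is invoked
# and turns a missing package into an ImportError with an install hint.
import functools
import importlib.util


def requires(name):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Assumption: a leading underscore marks a built-in extra
            # (e.g. '_torch'); strip it to obtain a module name to probe.
            module = name.lstrip('_')
            if importlib.util.find_spec(module) is None:
                raise ImportError(
                    f"{func.__name__} requires the optional dependency '{module}'; "
                    "install the matching hls4ml extra to use it."
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator

Note that the commit also applies the decorator to a class (PyTorchFileReader) and to __init__ methods, which a full implementation would need to support; the sketch above only covers plain functions.
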
4 changes: 4 additions & 0 deletions hls4ml/converters/onnx_to_hls.py
@@ -1,4 +1,5 @@
 from hls4ml.model import ModelGraph
+from hls4ml.utils.dependency import requires


 # ----------------------Helpers---------------------

@@ -17,6 +18,7 @@ def replace_char_inconsitency(name):
     return name.replace('.', '_')


+@requires('onnx')
 def get_onnx_attribute(operation, name, default=None):
     from onnx import helper

@@ -73,6 +75,7 @@ def get_input_shape(graph, node):
     return rv


+@requires('onnx')
 def get_constant_value(graph, constant_name):
     tensor = next((x for x in graph.initializer if x.name == constant_name), None)
     from onnx import numpy_helper

@@ -258,6 +261,7 @@ def parse_onnx_model(onnx_model):
     return layer_list, input_layers, output_layers


+@requires('onnx')
 def onnx_to_hls(config):
     """Convert onnx model to hls model from configuration.
4 changes: 4 additions & 0 deletions hls4ml/converters/pytorch_to_hls.py
@@ -1,4 +1,5 @@
 from hls4ml.model import ModelGraph
+from hls4ml.utils.dependency import requires


 class PyTorchModelReader:

@@ -22,6 +23,7 @@ def get_weights_data(self, layer_name, var_name):
         return data


+@requires('_torch')
 class PyTorchFileReader(PyTorchModelReader):  # Inherit get_weights_data method
     def __init__(self, config):
         import torch

@@ -103,6 +105,7 @@ def decorator(function):
 # ----------------------------------------------------------------


+@requires('_torch')
 def parse_pytorch_model(config, verbose=True):
     """Convert PyTorch model to hls4ml ModelGraph.

@@ -368,6 +371,7 @@ def parse_pytorch_model(config, verbose=True):
     return layer_list, input_layers


+@requires('_torch')
 def pytorch_to_hls(config):
     layer_list, input_layers = parse_pytorch_model(config)
     print('Creating HLS model')
7 changes: 0 additions & 7 deletions hls4ml/model/__init__.py
@@ -1,8 +1 @@
 from hls4ml.model.graph import HLSConfig, ModelGraph  # noqa: F401
-
-try:
-    from hls4ml.model import profiling  # noqa: F401
-
-    __profiling_enabled__ = True
-except ImportError:
-    __profiling_enabled__ = False
4 changes: 4 additions & 0 deletions hls4ml/model/quantizers.py
@@ -14,6 +14,7 @@
     SaturationMode,
     XnorPrecisionType,
 )
+from hls4ml.utils.dependency import requires


 class Quantizer:

@@ -84,6 +85,7 @@ class QKerasQuantizer(Quantizer):
         config (dict): Config of the QKeras quantizer to wrap.
     """

+    @requires('qkeras')
     def __init__(self, config):
         from qkeras.quantizers import get_quantizer

@@ -131,6 +133,7 @@ class QKerasBinaryQuantizer(Quantizer):
         config (dict): Config of the QKeras quantizer to wrap.
     """

+    @requires('qkeras')
     def __init__(self, config, xnor=False):
         from qkeras.quantizers import get_quantizer

@@ -155,6 +158,7 @@ class QKerasPO2Quantizer(Quantizer):
         config (dict): Config of the QKeras quantizer to wrap.
     """

+    @requires('qkeras')
     def __init__(self, config):
         from qkeras.quantizers import get_quantizer
4 changes: 4 additions & 0 deletions hls4ml/report/quartus_report.py
@@ -2,6 +2,8 @@
 import webbrowser
 from ast import literal_eval

+from hls4ml.utils.dependency import requires
+

 def parse_quartus_report(hls_dir, write_to_file=True):
     '''

@@ -39,6 +41,7 @@ def parse_quartus_report(hls_dir, write_to_file=True):
     return results


+@requires('quantus-report')
 def read_quartus_report(hls_dir, open_browser=False):
     '''
     Parse and print the Quartus report to print the report. Optionally open a browser.

@@ -89,6 +92,7 @@ def _find_project_dir(hls_dir):
     return top_func_name + '-fpga.prj'


+@requires('quantus-report')
 def read_js_object(js_script):
     '''
     Reads the JavaScript file and return a dictionary of variables definded in the script.
2 changes: 2 additions & 0 deletions hls4ml/utils/config.py
@@ -1,6 +1,7 @@
 import json

 import hls4ml
+from hls4ml.utils.dependency import requires


 def create_config(output_dir='my-hls-test', project_name='myproject', backend='Vivado', version='1.0.0', **kwargs):

@@ -44,6 +45,7 @@ def create_config(output_dir='my-hls-test', project_name='myproject', backend='Vivado', version='1.0.0', **kwargs):
     return config


+@requires('qkeras')
 def _get_precision_from_quantizer(quantizer):
     if isinstance(quantizer, str):
         import qkeras
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ optional-dependencies.doc = [
   "sphinx-rtd-theme",
 ]
 optional-dependencies.HGQ = [ "hgq~=0.2.0" ]
+optional-dependencies.onnx = [ "onnx>=1.4" ]
 optional-dependencies.optimization = [
   "keras-tuner==1.1.3",
   "ortools==9.4.1874",
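
With the new optional-dependencies.onnx group, ONNX support becomes an installable extra (for example: pip install 'hls4ml[onnx]'). A minimal, self-contained probe of that setup — package and extra names are taken from the diff above; the exact hint text hls4ml itself emits may differ:

# Probe for the optional 'onnx' package that the new pyproject extra installs.
import importlib.util

if importlib.util.find_spec('onnx') is None:
    print("ONNX converter unavailable; install it with: pip install 'hls4ml[onnx]'")
else:
    print("onnx found; the ONNX converter entry points can be used.")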
