-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
47 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ hjson | |
jsonref | ||
jsonschema | ||
mako | ||
mkdocs-material | ||
progressbar2 | ||
tabulate | ||
yamllint | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Author: Luca Colagrande <[email protected]> | ||
"""Convenience functions for data generation scripts.""" | ||
|
||
|
||
import struct | ||
from datetime import datetime | ||
|
@@ -11,42 +13,65 @@ | |
|
||
|
||
def emit_license(): | ||
"""Emit license header. | ||
Returns: | ||
A header string. | ||
""" | ||
|
||
s = (f"// Copyright {datetime.now().year} ETH Zurich and University of Bologna.\n" | ||
f"// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n" | ||
f"// SPDX-License-Identifier: Apache-2.0\n") | ||
return s | ||
|
||
|
||
# Enum value can be a string or an integer, this function uniformizes the result to integers only | ||
def integer_precision_t(prec): | ||
def _integer_precision_t(prec): | ||
if isinstance(prec, str): | ||
return {'FP64': 8, 'FP32': 4, 'FP16': 2, 'FP8': 1}[prec] | ||
else: | ||
return prec | ||
|
||
|
||
def torch_type_from_precision_t(prec): | ||
"""Convert `precision_t` type to PyTorch type. | ||
Args: | ||
prec: A value of type `precision_t`. Accepts both enum strings | ||
(e.g. "FP64") and integer enumeration values (e.g. 8). | ||
""" | ||
ctype_to_torch_type_map = { | ||
8: torch.float64, | ||
4: torch.float32, | ||
2: torch.float16, | ||
1: None | ||
} | ||
return ctype_to_torch_type_map[integer_precision_t(prec)] | ||
return ctype_to_torch_type_map[_integer_precision_t(prec)] | ||
|
||
|
||
# Returns the C type representing a floating-point value of the specified precision | ||
def ctype_from_precision_t(prec): | ||
"""Convert `precision_t` type to a C type string. | ||
Args: | ||
prec: A value of type `precision_t`. Accepts both enum strings | ||
(e.g. "FP64") and integer enumeration values (e.g. 8). | ||
""" | ||
precision_t_to_ctype_map = { | ||
8: 'double', | ||
4: 'float', | ||
2: '__fp16', | ||
1: '__fp8' | ||
} | ||
return precision_t_to_ctype_map[integer_precision_t(prec)] | ||
return precision_t_to_ctype_map[_integer_precision_t(prec)] | ||
|
||
|
||
def flatten(array): | ||
"""Flatten various array types with a homogeneous API. | ||
Args: | ||
array: Can be a Numpy array, a PyTorch tensor or a nested list. | ||
""" | ||
if isinstance(array, np.ndarray): | ||
return array.flatten() | ||
elif isinstance(array, torch.Tensor): | ||
|
@@ -55,7 +80,7 @@ def flatten(array): | |
return np.array(array).flatten() | ||
|
||
|
||
def variable_attributes(alignment=None, section=None): | ||
def _variable_attributes(alignment=None, section=None): | ||
attributes = '' | ||
if alignment: | ||
attributes = f'__attribute__ ((aligned ({alignment})))' | ||
|
@@ -64,16 +89,16 @@ def variable_attributes(alignment=None, section=None): | |
return attributes | ||
|
||
|
||
def alias_dtype(dtype): | ||
def _alias_dtype(dtype): | ||
if dtype == '__fp8': | ||
return 'char' | ||
else: | ||
return dtype | ||
|
||
|
||
def format_array_declaration(dtype, uid, shape, alignment=None, section=None): | ||
attributes = variable_attributes(alignment, section) | ||
s = f'{alias_dtype(dtype)} {uid}' | ||
attributes = _variable_attributes(alignment, section) | ||
s = f'{_alias_dtype(dtype)} {uid}' | ||
for dim in shape: | ||
s += f'[{dim}]' | ||
if attributes: | ||
|
@@ -143,6 +168,18 @@ def format_ifdef_wrapper(macro, body): | |
|
||
|
||
def from_buffer(byte_array, ctype='uint32_t'): | ||
"""Get structured data from raw bytes. | ||
If `ctype` is a C type string, it returns a homogeneous list of the | ||
specified type from the raw data. | ||
Alternatively, a dictionary can be passed to `ctype` to extract a | ||
struct from the raw data. In this case, it returns a dictionary with | ||
the same keys as in `ctype`. The values in the `ctype` dictionary | ||
should be format strings compatible with Python's `struct` library. | ||
The order of the keys in the `ctype` dictionary should reflect the | ||
order in which the variables appear in the raw data. | ||
""" | ||
# Types which have a direct correspondence in Numpy | ||
NP_DTYPE_FROM_CTYPE = { | ||
'uint32_t': np.uint32, | ||
|