Skip to content

Commit

Permalink
Fix floating point conversion in Python wrapper (#515)
Browse files Browse the repository at this point in the history
* Generator.py: use LF ending when run from windows

* Convert scalar input to float16 in cython

* Index buffer as uints in cython wrapper

* Update CHANGELOG and bump ver. 1.4.0
  • Loading branch information
vathomass authored Nov 10, 2023
1 parent bcd294a commit 564629c
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 36 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ Development version (next version)
- Fix a bug in the pre-processor that would cause issues on Arm GPUs
- Fix DLL install directory in mingw
- Added tuned parameters for various devices (see doc/tuning.md)
- Modifications to the python bindings (pyclblast)
* Convert float scalar values to cl_half for fp16 routines
* Amax/amin, max/min routines accept unsigned integer buffers for index
- Generator script now always use LF endings, independent of the platform

Version 1.6.1
- Fix pointer error in pyclblast on Arm
Expand Down
10 changes: 5 additions & 5 deletions scripts/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
HEADER_LINES = [129, 21, 133, 24, 29, 45, 29, 66, 40, 96, 21, 327]
HEADER_LINES = [129, 21, 133, 24, 29, 45, 29, 66, 40, 96, 21, 341]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232
Expand Down Expand Up @@ -215,7 +215,7 @@ def main(argv):
file_footer = original[-FOOTER_LINES[i]:]

# Re-writes the body of the file
with open(library_root + FILES[i], "w") as f:
with open(library_root + FILES[i], "w", newline="\n") as f:
body = ""
levels = [1, 2, 3] if (i == 4 or i == 5 or i == 6) else [1, 2, 3, 4]
for level in levels:
Expand Down Expand Up @@ -261,14 +261,14 @@ def main(argv):

# Correctness tests
filename = library_root + "/test/correctness/routines/" + routine_suffix
with open(filename, "w") as f:
with open(filename, "w", newline="\n") as f:
f.write(cpp.HEADER + "\n")
f.write(cpp.correctness_test(routine, level_string))
f.write(cpp.FOOTER)

# Performance tests
filename = library_root + "/test/performance/routines/" + routine_suffix
with open(filename, "w") as f:
with open(filename, "w", newline="\n") as f:
f.write(cpp.HEADER + "\n")
f.write(cpp.performance_test(routine, level_string))
f.write(cpp.FOOTER)
Expand All @@ -283,7 +283,7 @@ def main(argv):
file_footer = original[-FOOTER_LINES_DOC:]

# Outputs the API documentation
with open(filename, "w") as f:
with open(filename, "w", newline="\n") as f:

# Outputs the header
f.write("".join(file_header))
Expand Down
16 changes: 13 additions & 3 deletions scripts/generator/generator/pyclblast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os


NL = os.linesep
NL = '\n'
SEPARATOR = "####################################################################################################"


Expand Down Expand Up @@ -43,7 +43,7 @@ def scalar_cython_conversion(scalar, flavour):
if scalar_type in ["cl_double2", "double2"]:
return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)"
if scalar_type in ["cl_half", "half"]:
return "<cl_half>" + scalar
return "<cl_half>val_to_half(" + scalar + ")"
raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type))


Expand Down Expand Up @@ -82,8 +82,18 @@ def generate_pyx(routine):
result += NL

# Data types and checks
result += indent + "dtype = check_dtype([" + ", ".join(buffers) + "], "
int_buff = []
other_buff = []
for buf in buffers:
if buf in routine.index_buffers():
int_buff.append(buf)
else:
other_buff.append(buf)
result += indent + "dtype = check_dtype([" + ", ".join(other_buff) + "], "
result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
if int_buff:
result += indent + "check_dtype([" + ", ".join(int_buff) + "], "
result += "[" + ", ".join(['"uint16", "uint32", "uint64"']) + "])" + NL
for buf in buffers:
if buf in routine.buffers_vector():
result += indent + "check_vector("
Expand Down
2 changes: 1 addition & 1 deletion src/pyclblast/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ How to release a new version on PyPi
Following [the guide](https://packaging.python.org/tutorials/packaging-projects/), in essence doing (after changing the version number in `setup.py`):

python3 setup.py sdist bdist_wheel
python3 -m twine upload --repository pypi dist/pyclblast-1.3.2.tar.gz
python3 -m twine upload --repository pypi dist/pyclblast-1.4.0.tar.gz
# use '__token__' as username and supply the token from your PyPi account
11 changes: 9 additions & 2 deletions src/pyclblast/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from distutils.extension import Extension
from Cython.Distutils import build_ext
import platform
import numpy
import os

np_incdir = numpy.get_include()
np_libdir = os.path.join(np_incdir, '..', 'lib', '')

runtime_library_dirs = list()
if platform.system() == "Linux":
Expand All @@ -23,15 +28,17 @@
Extension(
"pyclblast",
["src/pyclblast.pyx"],
libraries=["clblast"],
libraries=["clblast", "npymath"],
runtime_library_dirs=runtime_library_dirs,
library_dirs=[np_libdir],
include_dirs=[np_incdir],
language="c++"
)
)

setup(
name="pyclblast",
version="1.3.2",
version="1.4.0",
author="Cedric Nugteren",
author_email="[email protected]",
url="https://github.com/CNugteren/CLBlast/blob/master/src/pyclblast",
Expand Down
Loading

0 comments on commit 564629c

Please sign in to comment.