From d414e902114afe6742ddd998cd6e06aa3ff5bdc4 Mon Sep 17 00:00:00 2001 From: Abi Hafshin Alfarouq Date: Tue, 5 Mar 2024 18:18:07 +0700 Subject: [PATCH] vsmigx: allow fp16 input & output (#86) --- vsmigx/vs_migraphx.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vsmigx/vs_migraphx.cpp b/vsmigx/vs_migraphx.cpp index ebfe98b..dcd17da 100644 --- a/vsmigx/vs_migraphx.cpp +++ b/vsmigx/vs_migraphx.cpp @@ -63,6 +63,7 @@ static void setDimensions( std::unique_ptr & vi, const std::array & input_shape, const std::array & output_shape, + int bitsPerSample, VSCore * core, const VSAPI * vsapi ) noexcept { @@ -71,9 +72,9 @@ static void setDimensions( vi->width *= output_shape[3] / input_shape[3]; if (output_shape[1] == 1) { - vi->format = vsapi->registerFormat(cmGray, stFloat, 32, 0, 0, core); + vi->format = vsapi->registerFormat(cmGray, stFloat, bitsPerSample, 0, 0, core); } else if (output_shape[1] == 3) { - vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core); + vi->format = vsapi->registerFormat(cmRGB, stFloat, bitsPerSample, 0, 0, core); } } @@ -723,8 +724,8 @@ static void VS_CC vsMIGXCreate( } migraphx_shape_datatype_t type; checkError(migraphx_shape_type(&type, input_shape)); - if (type != migraphx_shape_float_type) { - return set_error("input type must be float"); + if (type != migraphx_shape_float_type && type != migraphx_shape_half_type) { + return set_error("input type must be float or half"); } const size_t * lengths; size_t ndim; @@ -769,6 +770,7 @@ static void VS_CC vsMIGXCreate( size_t output_size; const_migraphx_shape_t output_shape; + int bitsPerSample; { migraphx_shapes_t output_shapes; checkError(migraphx_program_get_output_shapes(&output_shapes, d->program)); @@ -786,9 +788,10 @@ static void VS_CC vsMIGXCreate( } migraphx_shape_datatype_t type; checkError(migraphx_shape_type(&type, output_shape)); - if (type != migraphx_shape_float_type) { - return set_error("output type must be float"); + if (type != migraphx_shape_float_type && type != migraphx_shape_half_type) { + return set_error("output type must be float or half"); } + bitsPerSample = type == migraphx_shape_float_type ? 32 : 16; const size_t * lengths; size_t ndim; checkError(migraphx_shape_lengths(&lengths, &ndim, output_shape)); @@ -838,7 +841,7 @@ static void VS_CC vsMIGXCreate( return set_error("\"num_streams\" must be 1 for now"); } - setDimensions(d->out_vi, d->src_tile_shape, d->dst_tile_shape, core, vsapi); + setDimensions(d->out_vi, d->src_tile_shape, d->dst_tile_shape, bitsPerSample, core, vsapi); // per-stream context d->instances.reserve(num_streams);