diff --git a/src/nnfusion/common/type/element_type.cpp b/src/nnfusion/common/type/element_type.cpp index 590fdd2b6..156ec4e8a 100644 --- a/src/nnfusion/common/type/element_type.cpp +++ b/src/nnfusion/common/type/element_type.cpp @@ -59,7 +59,7 @@ bool element::Type::nnfusion_element_type_to_dtype_string(const element::Type& n std::string& dtype) { if (ng_et == element::boolean) - dtype = "char"; + dtype = "int"; else if (ng_et == element::character) dtype = "char"; else if (ng_et == element::f16) diff --git a/src/nnfusion/core/graph/gnode.cpp b/src/nnfusion/core/graph/gnode.cpp index ca6078711..f7b224041 100644 --- a/src/nnfusion/core/graph/gnode.cpp +++ b/src/nnfusion/core/graph/gnode.cpp @@ -416,6 +416,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr graph) m_op_ctxs.push_back(ctx); } + std::unordered_map, std::unordered_map> input_id_map; // Register input tensors for (const auto& m_node : m_order_nodes) { @@ -430,6 +431,7 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr graph) set_input(input_id, m_node->get_inputs().at(in_edge->get_dst_input())); graph->add_edge( in_edge->get_src(), in_edge->get_src_output(), shared_from_this(), input_id); + input_id_map[m_node][in_edge->get_dst_input()] = input_id; } } // Add control-edges as inputs of fused node @@ -461,6 +463,29 @@ void FusedGNode::set_inputs_and_outputs(std::shared_ptr graph) has_output = true; set_output(get_output_size(), m_node->get_outputs().at(out_edge->get_src_output())); + + // get inplace annotation + auto op = std::dynamic_pointer_cast(m_node->get_op_ptr()); + auto op_annotations = op->get_op_annotations(); + if (op_annotations) + { + auto oi_pairs = op_annotations->get_in_place_oi_pairs(); + for (auto oi_pair : oi_pairs) + { + auto iter = input_id_map.find(m_node); + if (iter != input_id_map.end() && iter->second.count(oi_pair.input) > 0) + { + auto fused_op = + std::dynamic_pointer_cast(shared_from_this()->get_op_ptr()); + AddInplace(fused_op, + get_output_size() - 1, + iter->second[oi_pair.input], + oi_pair.destructive, + oi_pair.force_inplace); + //NNFUSION_LOG(INFO) << "========================: node=" << m_node->get_op_type() << ", oi: <" << oi_pair.output << ", " << oi_pair.input << ">"; + } + } + } } graph->add_edge(shared_from_this(), get_output_size() - 1, diff --git a/src/nnfusion/core/kernels/common_langunit.cpp b/src/nnfusion/core/kernels/common_langunit.cpp index 6e897fa17..f258369a2 100644 --- a/src/nnfusion/core/kernels/common_langunit.cpp +++ b/src/nnfusion/core/kernels/common_langunit.cpp @@ -23,7 +23,7 @@ LU_DEFINE(header::chrono, "#include \n"); LU_DEFINE(header::ctime, "#include \n"); LU_DEFINE(header::limits, "#include \n"); LU_DEFINE(header::iostream, "#include \n"); -LU_DEFINE(header::windows, "#include \n"); +LU_DEFINE(header::windows, "#define NOMINMAX\n#include \n"); LU_DEFINE(header::unordered_map, "#include \n"); LU_DEFINE(header::torch_extension, "#include \n"); diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp index 040349588..5fca1dd00 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_cudnn.cpp @@ -201,11 +201,12 @@ LanguageUnit_p cuda::get_cudnn_convolution_descriptor(const Shape& padding, << "window_dilation_strides_int, CUDNN_CROSS_CORRELATION, " << data_type << "));\n"; } - if(type == nnfusion::element::f16){ + if (type == nnfusion::element::f16) + { // half precision, use tensor core lu << "CUDNN_SAFE_CALL(cudnnSetConvolutionMathType(" << desc << ", " - << "CUDNN_TENSOR_OP_MATH" - << "));\n"; + << "CUDNN_TENSOR_OP_MATH" + << "));\n"; } return _lu; diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp index dff602dc3..1741dbf40 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp @@ -103,8 +103,9 @@ namespace nnfusion @hCublas@, @transA@, @transB@, @m@, @n@, @k@, &alpha, input1, @lda@, @stride_a@, input0, @ldb@, @stride_b@, &beta, output0, @ldc@, @stride_c@, @batch@)); - )" : - R"( + )" + : + R"( static const float alpha = @alpha@F, beta = @beta@F; // if (!@hCublas@) // CUBLAS_SAFE_CALL(@api_create@(&@hCublas@)); @@ -116,7 +117,9 @@ namespace nnfusion { {"hCublas", "cublas_handle"}, {"api_create", "cublasCreate"}, - {"api_exec", dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" : "cublasSgemmStridedBatched"}, + {"api_exec", + dtype == nnfusion::element::f16 ? "cublasHgemmStridedBatched" + : "cublasSgemmStridedBatched"}, {"transA", transB ? "CUBLAS_OP_T" : "CUBLAS_OP_N"}, {"transB", transA ? "CUBLAS_OP_T" : "CUBLAS_OP_N"}, {"alpha", alpha}, diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp index 42404ac7b..a6c66dfdc 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_norm.cpp @@ -171,10 +171,14 @@ LanguageUnit_p cuda::BatchNormNCHW::emit_function_body() /* * todo: may have better solution, details in https://github.com/microsoft/nnfusion/issues/434 * */ - if(dtype == nnfusion::element::f16){ - lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], __hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(" + if (dtype == nnfusion::element::f16) + { + lu << "output0[st + i] = __hadd(input1[c_id] , __hdiv(__hmul(input0[c_id], " + "__hsub(input2[st + i], input3[c_id])), sqrtf(__hadd(__float2half(" << epsilon << "), input4[c_id]))));\n"; - }else{ + } + else + { lu << "(input1[c_id] + (input0[c_id] * " "(input2[st + i] - input3[c_id]) / sqrtf(" << epsilon << " + input4[c_id])));\n"; diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp index 8e7b7a735..a78071f18 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp @@ -207,161 +207,165 @@ LanguageUnit_p cuda::Dot::emit_function_body() else if (dtype == element::f16) { // case 1: Scalar * Tensor - if (arg0_shape.empty() || arg1_shape.empty()) - { - auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape); - size_t count = nnfusion::shape_size(second); + if (arg0_shape.empty() || arg1_shape.empty()) + { + auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape); + size_t count = nnfusion::shape_size(second); - string firstarg = (arg0_shape.empty() ? "input1" : "input0"); - string secondarg = (arg0_shape.empty() ? "input0" : "input1"); + string firstarg = (arg0_shape.empty() ? "input1" : "input0"); + string secondarg = (arg0_shape.empty() ? "input0" : "input1"); - lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; + lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; - lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0` - lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n"; - } + lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count + << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0` + lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count + << "));\n"; + } // // case 2: 1d Dot - else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes)) - { - for (int i = 0; i < arg0_shape.size(); i++) - { - if (arg0_shape[i] != arg1_shape[i]) - { - std::vector arg_vec{"arg0", "arg1"}; - std::vector shape_vec{arg0_shape, arg1_shape}; - - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; - } - } - - size_t count = nnfusion::shape_size(arg0_shape); - lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; - - lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count - << ", static_cast(input0), 1, static_cast(input1), 1, " - "static_cast(output0)));\n"; - } - // // matrix * vector - else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1)) - { - lu << "const float alpha = 1.0;\n const float beta = 0;\n"; - lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, "; - if (trans_A) - lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", "; - else - lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", "; - lu << " &alpha," - << " static_cast(input0)," << arg0_shape[1] << ", " - << " static_cast(input1)," - << " 1," - << " &beta," - << " static_cast(output0)," - << " 1));\n"; - } - else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) && - (trans_A || trans_B)) - { - int m = trans_B ? arg1_shape[0] : arg1_shape[1]; - int n = trans_A ? arg0_shape[1] : arg0_shape[0]; - int k = trans_A ? arg0_shape[0] : arg0_shape[1]; - - lu << "const half alpha = 1.0;\nconst half beta = 0;\n"; - - lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," - << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") - << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << "," - << " " << n << "," - << " " << k << "," - << " &alpha," - << " static_cast(input1)," - << " " << arg1_shape[1] << "," - << " static_cast(input0)," - << " " << arg0_shape[1] << "," - << " &beta," - << " static_cast(output0)," - << " " << m << "));\n"; - } else { - size_t axes_for_m_count = arg0_shape.size() - reduction_axes; - size_t axes_for_n_count = arg1_shape.size() - reduction_axes; - size_t axes_for_k_count = reduction_axes; - size_t m = 1; - size_t n = 1; - size_t k = 1; - - // check if input and output size correct - // check and calculate k for arg0 and arg1 - size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k - size_t arg1_k_idx = 0; // first axe in arg1 for k - - for (size_t i = 0; i < axes_for_k_count; i++) + else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes)) { - k *= arg0_shape[arg0_k_idx]; - if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++]) + for (int i = 0; i < arg0_shape.size(); i++) { - std::vector arg_vec{"arg0", "arg1"}; - std::vector shape_vec{arg0_shape, arg1_shape}; + if (arg0_shape[i] != arg1_shape[i]) + { + std::vector arg_vec{"arg0", "arg1"}; + std::vector shape_vec{arg0_shape, arg1_shape}; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } } + + size_t count = nnfusion::shape_size(arg0_shape); + lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n"; + + lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count + << ", static_cast(input0), 1, static_cast(input1), 1, " + "static_cast(output0)));\n"; } - // check and calculate m for arg0 and out - size_t arg0_m_idx = 0; // first axe in arg0 for m - size_t out_m_idx = 0; // first axe in out for m - for (size_t i = 0; i < axes_for_m_count; i++) + // // matrix * vector + else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1)) { - m *= arg0_shape[arg0_m_idx]; - if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++]) - { - std::vector arg_vec{"arg0", "output"}; - std::vector shape_vec{arg0_shape, out_shape}; + lu << "const float alpha = 1.0;\n const float beta = 0;\n"; + lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, "; + if (trans_A) + lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", "; + else + lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", "; + lu << " &alpha," + << " static_cast(input0)," << arg0_shape[1] << ", " + << " static_cast(input1)," + << " 1," + << " &beta," + << " static_cast(output0)," + << " 1));\n"; + } + else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) && + (trans_A || trans_B)) + { + int m = trans_B ? arg1_shape[0] : arg1_shape[1]; + int n = trans_A ? arg0_shape[1] : arg0_shape[0]; + int k = trans_A ? arg0_shape[0] : arg0_shape[1]; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; - } + lu << "const half alpha = 1.0;\nconst half beta = 0;\n"; + + lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," + << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") + << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << "," + << " " << n << "," + << " " << k << "," + << " &alpha," + << " static_cast(input1)," + << " " << arg1_shape[1] << "," + << " static_cast(input0)," + << " " << arg0_shape[1] << "," + << " &beta," + << " static_cast(output0)," + << " " << m << "));\n"; } - // check and calculate n for arg1 and out - size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n - size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n - for (size_t i = 0; i < axes_for_n_count; i++) + else { - n *= arg1_shape[arg1_n_idx]; - if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++]) + size_t axes_for_m_count = arg0_shape.size() - reduction_axes; + size_t axes_for_n_count = arg1_shape.size() - reduction_axes; + size_t axes_for_k_count = reduction_axes; + size_t m = 1; + size_t n = 1; + size_t k = 1; + + // check if input and output size correct + // check and calculate k for arg0 and arg1 + size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k + size_t arg1_k_idx = 0; // first axe in arg1 for k + + for (size_t i = 0; i < axes_for_k_count; i++) { - std::vector arg_vec{"arg1", "output"}; - std::vector shape_vec{arg1_shape, out_shape}; + k *= arg0_shape[arg0_k_idx]; + if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++]) + { + std::vector arg_vec{"arg0", "arg1"}; + std::vector shape_vec{arg0_shape, arg1_shape}; - NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " - << nnfusion::join(shape_vec) << " respectively, at Node " - << m_context->gnode->get_name() - << ", do not match for dot op"; + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } } - } + // check and calculate m for arg0 and out + size_t arg0_m_idx = 0; // first axe in arg0 for m + size_t out_m_idx = 0; // first axe in out for m + for (size_t i = 0; i < axes_for_m_count; i++) + { + m *= arg0_shape[arg0_m_idx]; + if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++]) + { + std::vector arg_vec{"arg0", "output"}; + std::vector shape_vec{arg0_shape, out_shape}; + + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } + } + // check and calculate n for arg1 and out + size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n + size_t out_n_idx = axes_for_m_count; // first axe in arg1 for n + for (size_t i = 0; i < axes_for_n_count; i++) + { + n *= arg1_shape[arg1_n_idx]; + if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++]) + { + std::vector arg_vec{"arg1", "output"}; + std::vector shape_vec{arg1_shape, out_shape}; - lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n"; - - lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," - << " CUBLAS_OP_N," - << " CUBLAS_OP_N," - << " " << n << "," - << " " << m << "," - << " " << k << "," - << " &alpha," - << " static_cast(input1)," - << " " << n << "," - << " static_cast(input0)," - << " " << k << "," - << " &beta," - << " static_cast(output0)," - << " " << n << "));\n"; - } + NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with " + << nnfusion::join(shape_vec) << " respectively, at Node " + << m_context->gnode->get_name() + << ", do not match for dot op"; + } + } + + lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n"; + + lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle," + << " CUBLAS_OP_N," + << " CUBLAS_OP_N," + << " " << n << "," + << " " << m << "," + << " " << k << "," + << " &alpha," + << " static_cast(input1)," + << " " << n << "," + << " static_cast(input0)," + << " " << k << "," + << " &beta," + << " static_cast(output0)," + << " " << n << "));\n"; + } } else { diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp index 5c9146afb..4c47ba346 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp @@ -230,7 +230,10 @@ for (int tidx = thread_idx; tidx < width; tidx += block_size) { val = reduceSum(val, thread_idx, block_size, shm); if (thread_idx == 0) output0[block_idx] = val; )", - {{"width", width}, {"block_size", expected_block_size}, {"warp_size", 32},{"dataType", dtype==nnfusion::element::f16? "half" : "float"}}); + {{"width", width}, + {"block_size", expected_block_size}, + {"warp_size", 32}, + {"dataType", dtype == nnfusion::element::f16 ? "half" : "float"}}); lu << code << "\n"; return _lu; @@ -582,7 +585,6 @@ if (thread_idx == 0) output0[block_idx] = val; m_gridDim = dim3(1, 1, 1); m_blockDim = dim3(block_size_x, 1, 1); } - } } diff --git a/src/nnfusion/core/kernels/hlsl/hlsl_kernel_emitter.hpp b/src/nnfusion/core/kernels/hlsl/hlsl_kernel_emitter.hpp index acd79a5cb..772684bea 100644 --- a/src/nnfusion/core/kernels/hlsl/hlsl_kernel_emitter.hpp +++ b/src/nnfusion/core/kernels/hlsl/hlsl_kernel_emitter.hpp @@ -69,7 +69,6 @@ namespace nnfusion kernel_info = nnfusion::kernels::AntaresKEImp::get_kernel_info(antares_code); - NNFUSION_CHECK(!kernel_info.empty()); process_antares_kernel_info(); } diff --git a/src/nnfusion/core/operators/generic_op/generic_op.hpp b/src/nnfusion/core/operators/generic_op/generic_op.hpp index 30ba4eeb2..9e3d616a8 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op.hpp +++ b/src/nnfusion/core/operators/generic_op/generic_op.hpp @@ -264,7 +264,9 @@ namespace nnfusion std::vector shape_def; for (int d = 0; d < shape.size(); d++) { - shape_def.push_back(shape[d] == 0 ? "1" : ("N" + to_string(d))); + // Tensor with shape [0] is treated as scalar value and convert its shape to [1] + shape_def.push_back((shape.size() == 1 && shape[d] == 0) ? "1" + : ("N" + to_string(d))); } return shape_def; } diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp index ad5336cb4..de87f56f0 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Concat.cpp @@ -46,11 +46,20 @@ REGISTER_OP(Concat) R"( @input@@input_layout@.when(@dim@ < @offset@, @recursive@) )"; auto final_input_template = R"(@input@@input_layout@)"; std::string inputs_body = R"(@recursive@)"; + + size_t num_valid_inputs = 0; + for (int in_id = 0; in_id < curr->get_input_size(); ++in_id) + if (curr->get_input_shape(in_id)[axis] > 0) + num_valid_inputs++; + + size_t processed_inputs = 0; for (int in_id = 0; in_id < curr->get_input_size(); ++in_id) { std::vector in_data_layout(data_layout); in_data_layout[axis] = in_data_layout[axis] + " - " + to_string(offset); auto dim_size = curr->get_input_shape(in_id)[axis]; + if (dim_size == 0) + continue; offset += dim_size; op::OpConfig::any in_config; @@ -58,9 +67,10 @@ REGISTER_OP(Concat) in_config["input_layout"] = vector_to_string>(in_data_layout); in_config["dim"] = data_layout[axis]; in_config["offset"] = offset; + processed_inputs++; std::string cur_body; - if (in_id != curr->get_input_size() - 1) + if (processed_inputs < num_valid_inputs) cur_body = op::create_code_from_template(recursive_input_template, in_config); else cur_body = op::create_code_from_template(final_input_template, in_config); diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp index 5c7e0e907..6d2b07dbd 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Convolution.cpp @@ -30,55 +30,139 @@ REGISTER_OP(Convolution) }) */ .translate_v2([](std::shared_ptr curr) -> std::string { - auto ir_template = - R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )"; - auto _op = static_pointer_cast(curr->get_op_ptr()); NNFUSION_CHECK_NOT_NULLPTR(_op) << "Node type is not " << curr->get_op_ptr()->get_op_type(); - const auto& dilation_h = _op->get_window_dilation_strides()[0]; - const auto& dilation_w = _op->get_window_dilation_strides()[1]; - const auto& stride_h = _op->get_window_movement_strides()[0]; - const auto& stride_w = _op->get_window_movement_strides()[1]; - const auto& is_nchw = _op->get_data_format() == "NCHW"; - const auto& padding_below = _op->get_padding_below(); - const auto& padding_above = _op->get_padding_above(); - const auto& padding_h = _op->get_padding_below()[0]; - const auto& padding_w = _op->get_padding_below()[1]; - const auto& kernel_size_h = - is_nchw ? curr->get_input_shape(1)[2] : curr->get_input_shape(1)[0]; - const auto& kernel_size_w = - is_nchw ? curr->get_input_shape(1)[3] : curr->get_input_shape(1)[1]; - const auto& in_shape = curr->get_input_shape(0); - const auto& out_shape = curr->get_output_shape(0); - const std::string data_format = is_nchw ? "nchw" : "nhwc"; - NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet."; - NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet."; - NNFUSION_CHECK(padding_below == padding_above) - << "Asymetric padding is not supported by now."; - nnfusion::op::OpConfig::any config; - std::string HO = "-@pad_0@ + KH + HO * " + to_string(stride_h); - std::string WO = "-@pad_1@ + KW + WO * " + to_string(stride_w); - std::string shape_template = - is_nchw ? "[N, C, " + HO + ", " + WO + "]" : "[N, " + HO + ", " + WO + ", C]"; - config["input1_layout"] = is_nchw ? "[F, C, KH, KW]" : "[KH, KW, C, F]"; - config["output0_layout"] = is_nchw ? "[N, F, HO, WO]" : "[N, HO, WO, F]"; - config["height"] = is_nchw ? out_shape[2] : out_shape[1]; - config["width"] = is_nchw ? out_shape[3] : out_shape[2]; - config["pad_0"] = to_string(padding_h); - config["pad_1"] = to_string(padding_w); - config["input0_layout"] = op::create_code_from_template(shape_template, config); - std::string pad_cond; - if (padding_h || padding_w) + if (_op->get_data_format() == "NCHW" || _op->get_data_format() == "NHWC") // Conv2D { - config["in_height"] = is_nchw ? in_shape[2] : in_shape[1]; - config["in_width"] = is_nchw ? in_shape[3] : in_shape[2]; - auto pad_template = ".when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO + - " >= 0, " + WO + - " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))"; - pad_cond = op::create_code_from_template(pad_template, config); + auto ir_template = + R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where HO in @height@, WO in @width@; )"; + + const auto& dilation_h = _op->get_window_dilation_strides()[0]; + const auto& dilation_w = _op->get_window_dilation_strides()[1]; + const auto& stride_h = _op->get_window_movement_strides()[0]; + const auto& stride_w = _op->get_window_movement_strides()[1]; + const auto& is_nchw = _op->get_data_format() == "NCHW"; + const auto& padding_below = _op->get_padding_below(); + const auto& padding_above = _op->get_padding_above(); + const auto& padding_h = _op->get_padding_below()[0]; + const auto& padding_w = _op->get_padding_below()[1]; + const auto& kernel_size_h = + is_nchw ? curr->get_input_shape(1)[2] : curr->get_input_shape(1)[0]; + const auto& kernel_size_w = + is_nchw ? curr->get_input_shape(1)[3] : curr->get_input_shape(1)[1]; + const auto& in_shape = curr->get_input_shape(0); + const auto& out_shape = curr->get_output_shape(0); + const std::string data_format = is_nchw ? "nchw" : "nhwc"; + if (dilation_h != 1 || dilation_w != 1) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Not support other dilation yet."; + return ""; + } + if (padding_below != padding_above) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Asymetric padding is not supported by now."; + return ""; + } + // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(padding_below == padding_above) + // << "Asymetric padding is not supported by now."; + nnfusion::op::OpConfig::any config; + std::string HO = "-@pad_0@ + KH + HO * " + to_string(stride_h); + std::string WO = "-@pad_1@ + KW + WO * " + to_string(stride_w); + std::string shape_template = + is_nchw ? "[N, C, " + HO + ", " + WO + "]" : "[N, " + HO + ", " + WO + ", C]"; + config["input1_layout"] = is_nchw ? "[F, C, KH, KW]" : "[KH, KW, C, F]"; + config["output0_layout"] = is_nchw ? "[N, F, HO, WO]" : "[N, HO, WO, F]"; + config["height"] = is_nchw ? out_shape[2] : out_shape[1]; + config["width"] = is_nchw ? out_shape[3] : out_shape[2]; + config["pad_0"] = to_string(padding_h); + config["pad_1"] = to_string(padding_w); + config["input0_layout"] = op::create_code_from_template(shape_template, config); + + std::string pad_cond; + if (padding_h || padding_w) + { + config["in_height"] = is_nchw ? in_shape[2] : in_shape[1]; + config["in_width"] = is_nchw ? in_shape[3] : in_shape[2]; + auto pad_template = + ".when([" + HO + " >= 0, " + HO + " < @in_height@, " + WO + " >= 0, " + WO + + " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))"; + pad_cond = op::create_code_from_template(pad_template, config); + } + config["pad_cond"] = pad_cond; + + return op::create_code_from_template(ir_template, config); + } + else if (_op->get_data_format() == "NCDHW") // Conv3D + { + auto ir_template = + R"( @output0@@output0_layout@ +=! @input0@@input0_layout@@pad_cond@ * @input1@@input1_layout@ where DO in @depth@, HO in @height@, WO in @width@; )"; + + const auto& dilation_d = _op->get_window_dilation_strides()[0]; + const auto& dilation_h = _op->get_window_dilation_strides()[1]; + const auto& dilation_w = _op->get_window_dilation_strides()[2]; + const auto& stride_d = _op->get_window_movement_strides()[0]; + const auto& stride_h = _op->get_window_movement_strides()[1]; + const auto& stride_w = _op->get_window_movement_strides()[2]; + const auto& padding_below = _op->get_padding_below(); + const auto& padding_above = _op->get_padding_above(); + const auto& padding_d = _op->get_padding_below()[0]; + const auto& padding_h = _op->get_padding_below()[1]; + const auto& padding_w = _op->get_padding_below()[2]; + const auto& kernel_size_d = curr->get_input_shape(1)[2]; + const auto& kernel_size_h = curr->get_input_shape(1)[3]; + const auto& kernel_size_w = curr->get_input_shape(1)[4]; + const auto& in_shape = curr->get_input_shape(0); + const auto& out_shape = curr->get_output_shape(0); + const std::string data_format = "NCDHW"; + if (dilation_d != 1 || dilation_h != 1 || dilation_w != 1) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Not support other dilation yet."; + return ""; + } + if (padding_below != padding_above) + { + NNFUSION_LOG(NNFUSION_WARNING) << "Asymetric padding is not supported by now."; + return ""; + } + // NNFUSION_CHECK(dilation_d == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(dilation_h == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(dilation_w == 1) << "Not support other dilation yet."; + // NNFUSION_CHECK(padding_below == padding_above) + // << "Asymetric padding is not supported by now."; + nnfusion::op::OpConfig::any config; + std::string DO = "-@pad_0@ + KD + DO * " + to_string(stride_d); + std::string HO = "-@pad_1@ + KH + HO * " + to_string(stride_h); + std::string WO = "-@pad_2@ + KW + WO * " + to_string(stride_w); + std::string shape_template = "[N, C, " + DO + ", " + HO + ", " + WO + "]"; + config["input1_layout"] = "[F, C, KD, KH, KW]"; + config["output0_layout"] = "[N, F, DO, HO, WO]"; + config["depth"] = out_shape[2]; + config["height"] = out_shape[3]; + config["width"] = out_shape[4]; + config["pad_0"] = to_string(padding_d); + config["pad_1"] = to_string(padding_h); + config["pad_2"] = to_string(padding_w); + config["input0_layout"] = op::create_code_from_template(shape_template, config); + + std::string pad_cond; + if (padding_d || padding_h || padding_w) + { + config["in_depth"] = in_shape[2]; + config["in_height"] = in_shape[3]; + config["in_width"] = in_shape[4]; + auto pad_template = + ".when([" + DO + " >= 0, " + DO + " < @in_depth@, " + HO + " >= 0, " + HO + + " < @in_height@, " + WO + " >= 0, " + WO + + " < @in_width@], const(0.0).cast(@input0@@input0_layout@.dtype()))"; + pad_cond = op::create_code_from_template(pad_template, config); + } + config["pad_cond"] = pad_cond; + + return op::create_code_from_template(ir_template, config); } - config["pad_cond"] = pad_cond; - return op::create_code_from_template(ir_template, config); + return ""; }); diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp index 6174bc824..8a2fcc44e 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/ScatterND.cpp @@ -34,7 +34,7 @@ REGISTER_OP(ScatterND) { auto temp = batch_dims; temp.push_back(to_string(i)); - output_layout.push_back("input1" + + output_layout.push_back("@input1@" + vector_to_string>(temp)); } else diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp index 75f02a837..4bfc4fa03 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Slice.cpp @@ -36,10 +36,12 @@ REGISTER_OP(Slice) auto step = steps[d]; auto start = starts[d]; auto end = ends[d]; - auto range = (u_int64_t)ceil((double)(end-start)/(double)step); - input_layout.push_back((step == 1? output_layout[d] : output_layout[d] + " * " + to_string(step)) + " + " + to_string(start)); - slice_dims += (slice_dims.empty() ? "" : " , ") + output_layout[d] + - " in " + to_string(range); + auto range = (u_int64_t)ceil((double)(end - start) / (double)step); + input_layout.push_back( + (step == 1 ? output_layout[d] : output_layout[d] + " * " + to_string(step)) + + " + " + to_string(start)); + slice_dims += + (slice_dims.empty() ? "" : " , ") + output_layout[d] + " in " + to_string(range); } auto expression_code = op::create_code_from_template( diff --git a/src/nnfusion/core/operators/util/validation_util.cpp b/src/nnfusion/core/operators/util/validation_util.cpp index 1cbaa7456..758810682 100644 --- a/src/nnfusion/core/operators/util/validation_util.cpp +++ b/src/nnfusion/core/operators/util/validation_util.cpp @@ -176,26 +176,33 @@ std::tuple << "), padding above (" << data_padding_above << "), filter strides (" << filter_strides << "), and filter dilation (" << filter_dilation << ") do not match."; - OP_VALIDATION(op, data_format == "NCW" || data_format == "NCHW" || data_format == "NHWC") - << "data format must be Conv1D: NCW, Conv2D: NCHW or NHWC."; + OP_VALIDATION(op, + data_format == "NCW" || data_format == "NCHW" || data_format == "NHWC" || + data_format == "NCDHW") + << "data format must be Conv1D: NCW, Conv2D: NCHW or NHWC, Conv3D: NCDHW."; nnfusion::Dimension batch_size = (data_batch_shape.rank().is_static() ? data_batch_shape[0] : nnfusion::Dimension::dynamic()); nnfusion::Dimension data_channel_count = (data_batch_shape.rank().is_static() - ? (data_format == "NCW" || data_format == "NCHW") ? data_batch_shape[1] - : data_batch_shape[3] + ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? data_batch_shape[1] + : data_batch_shape[3] : nnfusion::Dimension::dynamic()); nnfusion::PartialShape data_spatial_shape(nnfusion::PartialShape::dynamic(spatial_rank)); nnfusion::Dimension filter_output_channel_count = (filters_shape.rank().is_static() - ? (data_format == "NCW" || data_format == "NCHW") ? filters_shape[0] : filters_shape[3] + ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? filters_shape[0] + : filters_shape[3] : nnfusion::Dimension::dynamic()); nnfusion::Dimension filter_input_channel_count = (filters_shape.rank().is_static() - ? (data_format == "NCW" || data_format == "NCHW") ? filters_shape[1] : filters_shape[2] + ? (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? filters_shape[1] + : filters_shape[2] : nnfusion::Dimension::dynamic()); nnfusion::PartialShape filter_spatial_shape(nnfusion::PartialShape::dynamic(spatial_rank)); @@ -207,16 +214,18 @@ std::tuple { if (data_batch_shape.rank().is_static()) { - data_spatial_shape[i] = (data_format == "NCW" || data_format == "NCHW") - ? data_batch_shape[i + 2] - : data_batch_shape[i + 1]; + data_spatial_shape[i] = + (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? data_batch_shape[i + 2] + : data_batch_shape[i + 1]; } if (filters_shape.rank().is_static()) { - filter_spatial_shape[i] = (data_format == "NCW" || data_format == "NCHW") - ? filters_shape[i + 2] - : filters_shape[i]; + filter_spatial_shape[i] = + (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") + ? filters_shape[i + 2] + : filters_shape[i]; } } @@ -253,7 +262,7 @@ std::tuple nnfusion::PartialShape batch_output_shape(nnfusion::PartialShape::dynamic(spatial_rank + 2)); - if (data_format == "NCW" || data_format == "NCHW") + if (data_format == "NCW" || data_format == "NCHW" || data_format == "NCDHW") { batch_output_shape[0] = batch_size; batch_output_shape[1] = filter_output_channel_count; diff --git a/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp index 216c8efe9..cc2d6509e 100644 --- a/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/base_codegen_pass.cpp @@ -297,22 +297,33 @@ std::pair return std::make_pair(lup_alloc, lup_free); } -nnfusion::LanguageUnit_p BaseCodegenPass::codegen_mem_ref(KernelEmitter::Pointer kernel) +nnfusion::LanguageUnit_p BaseCodegenPass::codegen_mem_ref(nnfusion::ir::Instruction::Pointer ins) { - if (!kernel || FLAGS_fcustomized_mem_imp) + auto kernel = ins->getKernel(); + if (!kernel || FLAGS_fcustomized_mem_imp || ins->getGNode()->get_op_type() == "Result") return nullptr; LanguageUnit_p _lu(new LanguageUnit(kernel->get_function_name() + "_mem_ref")); auto& lu = *_lu; bool empty = true; - if (auto annotations = kernel->m_context->annotations) + if ((*ins)["InplaceTensorMapping"].is_valid()) { - for (auto oi_pair : annotations->get_in_place_oi_pairs()) + auto in_place_outputs = + (*ins)["InplaceTensorMapping"] + .as, + std::pair, size_t>>>(); + for (auto output : kernel->m_context->outputs) { - if (oi_pair.force_inplace == true) + if (is_ref_tensor(ins, output)) { - auto input = kernel->m_context->inputs[oi_pair.input]; - auto output = kernel->m_context->outputs[oi_pair.output]; - lu << output->get_name() << " = " << input->get_name() << ";\n"; + auto parent_tensor = in_place_outputs.at(output).first; + size_t tensor_offset = in_place_outputs.at(output).second; + + auto root_tensor = parent_tensor->get_root_tensor() + ? parent_tensor->get_root_tensor() + : parent_tensor; + lu << output->get_name() << " = " << root_tensor->get_name() + << ((tensor_offset > 0) ? (" + " + std::to_string(tensor_offset)) : ("")) + << ";\n"; empty = false; } } @@ -323,6 +334,23 @@ nnfusion::LanguageUnit_p BaseCodegenPass::codegen_mem_ref(KernelEmitter::Pointer return _lu; } +bool BaseCodegenPass::is_ref_tensor(nnfusion::ir::Instruction::Pointer ins, + shared_ptr output) +{ + if ((*ins)["InplaceTensorMapping"].is_valid()) + { + auto in_place_outputs = + (*ins)["InplaceTensorMapping"] + .as, + std::pair, size_t>>>(); + // input tensor is unallocated (e.g., Parameter), need to assign address at runtime + if (in_place_outputs.count(output) > 0 && + (in_place_outputs.at(output).first)->get_pool_offset() == SIZE_MAX) + return true; + } + return false; +} + LanguageUnit_p BaseCodegenPass::codegen_device_type() { auto lu_devtype = make_shared("device_type"); @@ -346,7 +374,7 @@ LanguageUnit_p BaseCodegenPass::codegen_workspace_size(std::shared_ptrmax_allocated(); } - *lu_workspace << "int64_t get_workspace_size()\n{\n"; + *lu_workspace << "size_t get_workspace_size()\n{\n"; *lu_workspace << " return " << total_alloc << ";\n"; *lu_workspace << "}\n"; return lu_workspace; diff --git a/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp b/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp index 40cab2b48..515e3c20e 100644 --- a/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp +++ b/src/nnfusion/engine/pass/codegen/base_codegen_pass.hpp @@ -92,7 +92,10 @@ namespace nnfusion virtual NNFusion_DeviceType device_type() { return NNFusion_DeviceType::UNKNOWN; } virtual std::pair get_customized_mem_imp(nnfusion::ir::Instruction::Pointer ins); - LanguageUnit_p codegen_mem_ref(KernelEmitter::Pointer kernel); + LanguageUnit_p codegen_mem_ref(nnfusion::ir::Instruction::Pointer ins); + // check if an output tensor of ins is ref_tensor, that needs to assign address at runtime + bool is_ref_tensor(nnfusion::ir::Instruction::Pointer ins, + shared_ptr out); LanguageUnit_p codegen_device_type(); LanguageUnit_p codegen_workspace_size(std::shared_ptr tu); CodeGenerator::Pointer projgen; diff --git a/src/nnfusion/engine/pass/codegen/cpu_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cpu_codegen_pass.cpp index e8e2005eb..aaa33b60b 100644 --- a/src/nnfusion/engine/pass/codegen/cpu_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/cpu_codegen_pass.cpp @@ -373,7 +373,7 @@ void CpuCodegenPass::create_header_file(std::shared_ptr ctx, // if (device_type() == CUDA_GPU || device_type() == ROCM_GPU) // lu_header << header::cuda->get_code(); lu_header << "extern \"C\" int get_device_type();\n"; - lu_header << "extern \"C\" int64_t get_workspace_size();\n"; + lu_header << "extern \"C\" size_t get_workspace_size();\n"; lu_header << "extern \"C\" int kernel_entry("; std::string params = get_kernel_entry_paras(tu); lu_header << params; diff --git a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp index 2537142bf..b16aa950f 100644 --- a/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/cuda_codegen_pass.cpp @@ -325,15 +325,32 @@ bool CudaCodegenPass::collect_funcs(std::shared_ptr ctx, // todo: this hack is to eliminate d2d copy caused by extern result memory if (FLAGS_fextern_result_memory && gnode) { + size_t non_control_edge = 0; + std::shared_ptr out_edge; for (size_t i = 0; i < gnode->get_out_edges().size(); i++) { - if (gnode->get_out_edges()[i]->get_dst()->get_op_ptr()->is_output()) + if (!gnode->get_out_edges()[i]->is_control_edge()) { - std::shared_ptr output = gnode->get_out_edges()[i]->get_dst(); + non_control_edge++; + out_edge = gnode->get_out_edges()[i]; + if (non_control_edge > 1) + break; + } + } + + // inplace the result tensor into kernel only if there is one out edge + if (non_control_edge == 1) + { + auto out_tensor = kernel->m_context->outputs[out_edge->get_src_output()]; + if (out_edge->get_dst()->get_op_ptr()->is_output() && + !is_ref_tensor(ins, out_tensor)) + { + std::shared_ptr output = out_edge->get_dst(); std::string in_name = output->get_input_tensor(0).get_name(); std::string out_name = output->get_output_tensor(0).get_name(); int pos = call_str.find(", " + in_name); call_str.replace(pos, in_name.size() + 2, ", " + out_name); + (*output)["is_eliminative"] = true; } } } @@ -716,9 +733,9 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru } } - auto mem_ref = codegen_mem_ref(kernel); + auto mem_ref = codegen_mem_ref(ins); if (mem_ref != nullptr) - lu << codegen_mem_ref(kernel)->get_code(); + lu << codegen_mem_ref(ins)->get_code(); if (ins->name() == "Memcpy") { @@ -757,15 +774,16 @@ nnfusion::LanguageUnit_p CudaCodegenPass::func_call_codegen(nnfusion::ir::Instru } else { - if (ins->getKernel()->is_eliminative()) - { - lu << "// eliminated: " << func_call; - } - // todo: this hack is to eliminate d2d copy caused by extern result memory - else if (FLAGS_fextern_result_memory && gnode && gnode->get_op_ptr()->is_output()) + if (ins->getKernel()->is_eliminative() || + (*(ins->getGNode()))["is_eliminative"].is_valid_as()) { lu << "// eliminated: " << func_call; } + // // todo: this hack is to eliminate d2d copy caused by extern result memory + // else if (FLAGS_fextern_result_memory && gnode && gnode->get_op_ptr()->is_output()) + // { + // lu << "// eliminated: " << func_call; + // } else { @@ -1116,7 +1134,7 @@ void CudaCodegenPass::create_header_file(std::shared_ptr ctx lu_header << header::cuda_fp16->get_code(); lu_header << "extern \"C\" int get_device_type();\n"; - lu_header << "extern \"C\" int64_t get_workspace_size();\n"; + lu_header << "extern \"C\" size_t get_workspace_size();\n"; lu_header << "extern \"C\" int kernel_entry"; if (FLAGS_fhost_entry) lu_header << "_host"; @@ -1371,7 +1389,7 @@ cmake_minimum_required(VERSION 3.5) SET(SRC "nnfusion_rt.cu" CACHE STRING "codegen source file") SET(TARGET_NAME "nnfusion_naive_rt" CACHE STRING "codegen target name") -SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" CACHE STRING "target architecture") +SET(CUDA_ARCH "-gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86" CACHE STRING "target architecture") if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) diff --git a/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp index 630099989..26f17d374 100644 --- a/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/hlsl_cpp_codegen_pass.cpp @@ -25,6 +25,7 @@ DECLARE_bool(fhost_entry); DECLARE_string(fantares_perf_file); DECLARE_bool(ffunction_codegen); DEFINE_bool(fhlsl_descriptor_heap, false, "enable DirectX descriptor heap"); +DEFINE_bool(fhlsl_free_memory, false, "free host memory after coping to device"); void HLSLCPPCodegenPass::initialize(std::shared_ptr ctx, std::shared_ptr tu) @@ -68,6 +69,7 @@ void HLSLCPPCodegenPass::initialize(std::shared_ptr ctx, else lu_init_begin << "\nvoid hlsl_init()\n{\n"; + lu_init_begin << "dxModuleSetCompat(\"cs_6_2\");\n"; if (FLAGS_fhlsl_descriptor_heap) { lu_init_begin << "dxInit(1);\n"; @@ -285,7 +287,7 @@ bool HLSLCPPCodegenPass::collect_funcs(std::shared_ptr ctx, if (FLAGS_fcustomized_mem_imp) lup_func_calls->unit_vec.push_back(get_customized_mem_imp(ins).first); - auto mem_ref = codegen_mem_ref(kernel); + auto mem_ref = codegen_mem_ref(ins); if (mem_ref != nullptr) lup_func_calls->unit_vec.push_back(mem_ref); lup_func_calls->unit_vec.push_back(kernel_func_call); @@ -325,8 +327,13 @@ void HLSLCPPCodegenPass::create_header_file(std::shared_ptr lup_header->require(macro::RUNTIME_API); auto& lu_header = *lup_header; + lu_header << "#include \"half.hpp\";\n"; + lu_header << "using DebugDataType = half_float::half;\n"; + lu_header << "//using DebugDataType = int64_t;\n"; + lu_header << "using namespace half_float;\n"; + lu_header << "extern \"C\" RUNTIME_API int get_device_type();\n"; - lu_header << "extern \"C\" RUNTIME_API int64_t get_workspace_size();\n"; + lu_header << "extern \"C\" RUNTIME_API size_t get_workspace_size();\n"; lu_header << "extern \"C\" RUNTIME_API int kernel_entry"; if (FLAGS_fhost_entry) lu_header << "_host"; @@ -383,6 +390,129 @@ void HLSLCPPCodegenPass::create_main_file(std::shared_ptr ct lu_ << "int main()"; lu_.block_begin(); + if (FLAGS_fhlsl_free_memory) + { + lu_ << "\nhlsl_init();\n\n"; + + for (size_t i = 0; i < tu->arg.size(); i++) + { + auto& tensor = *tu->arg[i]; + //malloc host input arg + lu_ << "//input argument\n"; + lu_ << tensor.get_element_type().c_type_string() << "* " << tensor.get_name() + << "_host = new " << tensor.get_element_type().c_type_string() << "[" + << tensor.get_tensor_layout()->get_size() << "];\n"; + if (!FLAGS_fhost_entry) + { + lu_ << "void* " << tensor.get_name() << " = dxMemAlloc(sizeof(" + << tensor.get_element_type().c_type_string() << ") * " + << tensor.get_tensor_layout()->get_size() << ");\n"; + } + lu_ << "for (int i = 0; i < " << tensor.get_tensor_layout()->get_size() << "; ++i) " + << tensor.get_name() << "_host[i]= 1;\n"; + if (!FLAGS_fhost_entry) + { + lu_ << "dxMemcpyHtoDAsync(" << tensor.get_name() << ", " << tensor.get_name() + << "_host, sizeof(" << tensor.get_element_type().c_type_string() << ") * " + << tensor.get_tensor_layout()->get_size() << ", 0);\n"; + + lu_ << "dxStreamSynchronize(0);\n"; + + lu_ << "delete " << tensor.get_name() << "_host;\n"; + } + } + + for (size_t i = 0; i < tu->out.size(); i++) + { + auto& tensor = *tu->out[i]; + //malloc host output arg + lu_ << "//output argument\n"; + lu_ << tensor.get_element_type().c_type_string() << "* " << tensor.get_name() + << "_host = new " << tensor.get_element_type().c_type_string() << "[" + << tensor.get_tensor_layout()->get_size() << "];\n"; + lu_ << "void* " << tensor.get_name() << ";\n"; + if (FLAGS_fextern_result_memory && !FLAGS_fhost_entry) + { + lu_ << tensor.get_name() << " = dxMemAlloc(sizeof(" + << tensor.get_element_type().c_type_string() << ") * " + << tensor.get_tensor_layout()->get_size() << ");\n"; + } + } + + lu_ << "int steps = 100;\n"; + lu_ << "auto start = std::chrono::high_resolution_clock::now();\n"; + lu_ << "for (int i = 0; i < steps; i++)\n "; + if (FLAGS_fhost_entry) + { + std::string args = get_kernel_entry_args(tu, true); + lu_ << "kernel_entry_host(" << args << ");\n"; + } + else + { + std::string args = get_kernel_entry_args(tu, false); + //lu_ << get_h2dcopy(tu)->get_code(); + lu_ << "kernel_entry(" << args << ");\n"; + } + lu_ << get_sync()->get_code(); + lu_ << "auto end = std::chrono::high_resolution_clock::now();\n"; + lu_ << "auto duration = std::chrono::duration_cast(end - " + "start);\n"; + lu_ << "OutputDebugStringA(\"Time: \%f ms\\n\", duration.count() / 1000.0 / steps);\n"; + + if (!FLAGS_fhost_entry) + { + lu_ << get_d2hcopy(tu)->get_code(); + lu_ << get_sync()->get_code(); + } + lu_ << "std::string result;\n"; + for (size_t i = 0; i < tu->out.size(); i++) + { + auto& tensor = *tu->out[i]; + // lu_ << "std::cout << \"" << tensor.get_name() << "_host = [\" << " << tensor.get_name() + // << "_host[0] << \", \" << " << tensor.get_name() << "_host[1] << \", .., \" << " + // << tensor.get_name() << "_host[" << tensor.get_tensor_layout()->get_size() + // << "-1] << \"]\" << std::endl;"; + size_t num = std::min(size_t(10), tensor.get_tensor_layout()->get_size()); + if (num == 1) + { + lu_ << "result = \"" << tensor.get_name() << "_host = [\" + std::to_string(" + << tensor.get_name() << "_host[0]) + \"]\\n\";\n"; + } + else + { + lu_ << "result = \"" << tensor.get_name() << "_host = ["; + for (size_t j = 0; j < num; j++) + { + lu_ << "\" + std::to_string(" << tensor.get_name() << "_host[" << j + << "]) + \", "; + } + lu_ << ".., \" + std::to_string(" << tensor.get_name() << "_host[" + << tensor.get_tensor_layout()->get_size() << "-1]) + \"]\\n\";\n"; + } + lu_ << "OutputDebugStringA(result.c_str());\n"; + } + + lu_ << "\n//free context\n"; + if (!FLAGS_fhost_entry) + { + for (size_t i = 0; i < tu->arg.size(); i++) + { + auto& tensor = *tu->arg[i]; + lu_ << "dxMemFree(" << tensor.get_name() << ");\n"; + } + } + + if (FLAGS_fextern_result_memory && !FLAGS_fhost_entry) + { + for (size_t i = 0; i < tu->out.size(); i++) + { + auto& tensor = *tu->out[i]; + lu_ << "dxMemFree(" << tensor.get_name() << ");\n"; + } + } + lu_ << "hlsl_free();\n\n"; + } + else { lu_ << "\nhlsl_init();\n\n"; diff --git a/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp b/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp index 6ba1914cd..b2e60cdd4 100644 --- a/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp +++ b/src/nnfusion/engine/pass/codegen/hlsl_cs_codegen_pass.cpp @@ -240,7 +240,7 @@ bool HLSLCSCodegenPass::collect_funcs(std::shared_ptr ctx, std::make_shared(fu->call_unit->get_symbol(), call_str); if (FLAGS_fcustomized_mem_imp) lup_func_calls->unit_vec.push_back(get_customized_mem_imp(ins).first); - auto mem_ref = codegen_mem_ref(kernel); + auto mem_ref = codegen_mem_ref(ins); if (mem_ref != nullptr) lup_func_calls->unit_vec.push_back(mem_ref); lup_func_calls->unit_vec.push_back(kernel_func_call); diff --git a/src/nnfusion/engine/pass/extract_graph_signature.cpp b/src/nnfusion/engine/pass/extract_graph_signature.cpp index ce0e200bc..f537121f8 100644 --- a/src/nnfusion/engine/pass/extract_graph_signature.cpp +++ b/src/nnfusion/engine/pass/extract_graph_signature.cpp @@ -142,7 +142,8 @@ bool ExtractGraphSignature::extract_args(std::shared_ptr ctx const element::Type& et = tv->get_element_type(); string type; - if(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)){ + if (!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)) + { NNFUSION_LOG(ERROR) << "Get element type failed"; return false; } @@ -188,7 +189,8 @@ bool ExtractGraphSignature::extract_output(std::shared_ptr c tu->out.push_back(tv); string type; - if(!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)){ + if (!element::Type::nnfusion_element_type_to_dtype_string(tv->get_element_type(), type)) + { NNFUSION_LOG(ERROR) << "Get element type failed"; return false; } diff --git a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp index b71167944..e173f85c3 100644 --- a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp +++ b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp @@ -284,7 +284,7 @@ void KernelTuning::submit_tuning_batch_asyc( NNFUSION_CHECK(n_device_type != UNKNOWN); auto ir = nnfusion::op::get_translation(gnode); - // NNFUSION_LOG(INFO) << gnode->get_op_type() << ", ir: " << ir; + NNFUSION_LOG(INFO) << gnode->get_op_type() << " " << gnode->get_name() << ", ir: " << ir; if (!ir.empty()) { auto status = std::make_shared(gnode); diff --git a/src/nnfusion/engine/pass/tensor/tensor_memory_layout.cpp b/src/nnfusion/engine/pass/tensor/tensor_memory_layout.cpp index 6b063e754..fac78f9f3 100644 --- a/src/nnfusion/engine/pass/tensor/tensor_memory_layout.cpp +++ b/src/nnfusion/engine/pass/tensor/tensor_memory_layout.cpp @@ -83,24 +83,24 @@ bool AssignTensorMemoryLayout::run(std::shared_ptr ctx, unordered_set> newlist(alloc_temp); // todo: this hack is to eliminate d2d copy caused by extern result memory bool skip = false; - if (FLAGS_fextern_result_memory && gnode) - { - bool all_users_are_result = true; - for (size_t i = 0; i < gnode->get_out_edges().size(); i++) - { - auto dst = gnode->get_out_edges()[i]->get_dst(); - - if (dst && !dst->get_op_ptr()->is_output()) - { - all_users_are_result = false; - break; - } - } - if (all_users_are_result) - { - skip = true; - } - } + // if (FLAGS_fextern_result_memory && gnode) + // { + // bool all_users_are_result = true; + // for (size_t i = 0; i < gnode->get_out_edges().size(); i++) + // { + // auto dst = gnode->get_out_edges()[i]->get_dst(); + + // if (dst && !dst->get_op_ptr()->is_output()) + // { + // all_users_are_result = false; + // break; + // } + // } + // if (all_users_are_result) + // { + // skip = true; + // } + // } // The output of output nodes refers to the input, so there is NO need // to allocate memory space for output of output nodes. if (!skip && (!gnode || !gnode->get_op_ptr()->is_output() || diff --git a/src/nnfusion/frontend/onnx_import/op/constant.hpp b/src/nnfusion/frontend/onnx_import/op/constant.hpp index f3dd0bfc9..8c4fbdbe3 100644 --- a/src/nnfusion/frontend/onnx_import/op/constant.hpp +++ b/src/nnfusion/frontend/onnx_import/op/constant.hpp @@ -50,6 +50,7 @@ namespace nnfusion const element::Type&, const Tensor&)>> the_map = {{element::f32, __make_constant_op}, {element::f64, __make_constant_op}, + {element::boolean, __make_constant_op}, {element::i32, __make_constant_op}, {element::i64, __make_constant_op}, {element::u32, __make_constant_op}, diff --git a/src/nnfusion/frontend/onnx_import/op/conv.cpp b/src/nnfusion/frontend/onnx_import/op/conv.cpp index cb98dfd1d..a57341ebf 100644 --- a/src/nnfusion/frontend/onnx_import/op/conv.cpp +++ b/src/nnfusion/frontend/onnx_import/op/conv.cpp @@ -78,10 +78,10 @@ namespace nnfusion { conv_data_format = "NCHW"; } - // else if (data_shape.size() == 5) - // { - // conv_data_format = "NCDHW"; - // } + else if (data_shape.size() == 5) + { + conv_data_format = "NCDHW"; + } else { NNFUSION_CHECK_FAIL() << "Convolution with dimensions of " @@ -168,7 +168,7 @@ namespace nnfusion strides, dilations, padding_below, padding_above, conv_data_format); conv_node = m_graph->add_node_and_edge(conv_op, {data, filters}); } - else + else if (conv_data_format == "NCHW") { // split data and filters for group conv std::size_t n_data_channels{data_shape.at(1)}; @@ -264,6 +264,10 @@ namespace nnfusion convolution_nodes); } } + else + { + NNFUSION_CHECK_FAIL() << "Not support this Convolution yet."; + } // add bias if (input_indexes.size() == 3) diff --git a/src/nnfusion/frontend/onnx_import/op/pad.hpp b/src/nnfusion/frontend/onnx_import/op/pad.hpp index a89a09133..7c65adb38 100644 --- a/src/nnfusion/frontend/onnx_import/op/pad.hpp +++ b/src/nnfusion/frontend/onnx_import/op/pad.hpp @@ -22,16 +22,17 @@ namespace nnfusion /* * since opset 11, 'pads' and 'value' have been moved from attributes to inputs * */ - if (node_proto.attribute_size() == 1){ - + if (node_proto.attribute_size() == 1) + { auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0); auto padding_gnode = GetInputNode(all_ng_nodes, node_proto, 1); std::vector paddings; bool status = GetValueFromNGraphOp(padding_gnode, &paddings); NNFUSION_CHECK(status); - NNFUSION_CHECK(paddings.size() % 2 == 0) - << "Constant node for paddings does not have an even number of elements"; + NNFUSION_CHECK(paddings.size() % 2 == 0) << "Constant node for paddings " + "does not have an even number " + "of elements"; nnfusion::Shape padding_below(paddings.size() / 2); nnfusion::Shape padding_above(paddings.size() / 2); @@ -48,10 +49,11 @@ namespace nnfusion std::make_shared(input_gnode->get_element_type(), nnfusion::Shape{}, std::vector{"0"}); - auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); + auto pad_val_gnode = + m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); - auto pad_op = - std::make_shared(padding_below, padding_above, padding_interior); + auto pad_op = std::make_shared( + padding_below, padding_above, padding_interior); pad_op->set_name(node_proto.output(0)); auto pad_gnode = @@ -59,7 +61,9 @@ namespace nnfusion NamedNodeVector ret{{node_proto.output(0), pad_gnode}}; return ret; - }else{ + } + else + { cout << "meet pad op" << endl; /* for pad op, 0: mode, 1: pads, 2: constant * we can use attr.name() to get the name of the attr @@ -70,7 +74,8 @@ namespace nnfusion auto input_gnode = GetInputNode(all_ng_nodes, node_proto, 0); const onnx::AttributeProto& modeAttr = node_proto.attribute(0); cout << modeAttr.name() << endl; - if(modeAttr.s() != "constant") NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s(); + if (modeAttr.s() != "constant") + NNFUSION_CHECK_FAIL() << "unsupported padding type: " << modeAttr.s(); const onnx::AttributeProto& padAttr = node_proto.attribute(1); cout << padAttr.name() << endl; for (int i = 0; i < 8; ++i) @@ -85,7 +90,8 @@ namespace nnfusion std::make_shared(input_gnode->get_element_type(), nnfusion::Shape{}, std::vector{"0"}); - auto pad_val_gnode = m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); + auto pad_val_gnode = + m_graph->add_node_and_edge(pad_val_op, GNodeVector({})); nnfusion::Shape padding_below(4); nnfusion::Shape padding_above(4); nnfusion::Shape padding_interior(4); @@ -95,11 +101,11 @@ namespace nnfusion for (int i = 0; i < 4; ++i) { padding_below[i] = padAttr.ints(i); - padding_above[i] = padAttr.ints(i+4); + padding_above[i] = padAttr.ints(i + 4); } - auto pad_op = - std::make_shared(padding_below, padding_above, padding_interior); + auto pad_op = std::make_shared( + padding_below, padding_above, padding_interior); pad_op->set_name(node_proto.output(0)); auto pad_gnode = diff --git a/src/nnfusion/frontend/onnx_import/op/where.cpp b/src/nnfusion/frontend/onnx_import/op/where.cpp index 2d184858c..fb7b8bb6d 100644 --- a/src/nnfusion/frontend/onnx_import/op/where.cpp +++ b/src/nnfusion/frontend/onnx_import/op/where.cpp @@ -20,6 +20,8 @@ //---------------------------------------------------------------------------------------------- #include "where.hpp" +#include "nnfusion/core/graph/util/autobroadcast.hpp" +#include "nnfusion/core/graph/util/numpy_transpose.hpp" #include "nnfusion/core/operators/generic_op/generic_op.hpp" namespace nnfusion @@ -39,6 +41,12 @@ namespace nnfusion auto x_gnode = input_indices[1]; auto y_gnode = input_indices[2]; + std::tie(x_gnode, y_gnode) = + graph::numpy_broadcast(std::make_pair(x_gnode, y_gnode), m_graph); + + std::tie(x_gnode, cond_gnode) = + graph::numpy_broadcast(std::make_pair(x_gnode, cond_gnode), m_graph); + auto node_name = node_proto.output(0); nnfusion::op::OpConfig::any op_config; diff --git a/src/nnfusion/frontend/onnx_import/util/util.cpp b/src/nnfusion/frontend/onnx_import/util/util.cpp index f07c48c1c..680be072e 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.cpp +++ b/src/nnfusion/frontend/onnx_import/util/util.cpp @@ -91,7 +91,7 @@ namespace nnfusion switch (onnx_et) { case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: - return make_constant_op(element::boolean, shape, tensor); + return make_constant_op(element::i32, shape, tensor); case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: return make_constant_op(element::f32, shape, tensor); case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: diff --git a/src/nnfusion/frontend/onnx_import/util/util.hpp b/src/nnfusion/frontend/onnx_import/util/util.hpp index 23a136f01..a427f8c83 100644 --- a/src/nnfusion/frontend/onnx_import/util/util.hpp +++ b/src/nnfusion/frontend/onnx_import/util/util.hpp @@ -61,36 +61,6 @@ namespace nnfusion return std::vector(); } - template <> - inline std::vector get_data(const onnx::TensorProto& tensor) - { - if (tensor.has_raw_data()) - { - return __get_raw_data(tensor.raw_data()); - } - switch (tensor.data_type()) - { - case onnx::TensorProto_DataType_DOUBLE: - return __get_data(tensor.double_data()); - case onnx::TensorProto_DataType_FLOAT: - case onnx::TensorProto_DataType_FLOAT16: - return __get_data(tensor.float_data()); - case onnx::TensorProto_DataType_INT32: - return __get_data(tensor.int32_data()); - case onnx::TensorProto_DataType_INT64: - return __get_data(tensor.int64_data()); - case onnx::TensorProto_DataType_UINT64: - return __get_data(tensor.uint64_data()); - default: - NNFUSION_CHECK_FAIL() - << "invalid data type: " - << onnx::TensorProto_DataType_Name( - static_cast(tensor.data_type())); - break; - } - return std::vector(); - } - template <> inline std::vector get_data(const onnx::TensorProto& tensor) { @@ -143,24 +113,97 @@ namespace nnfusion } template <> - inline std::vector get_data(const onnx::TensorProto& tensor){ - + inline std::vector get_data(const onnx::TensorProto& tensor) + { if (tensor.has_raw_data()) { - return __get_raw_data(tensor.raw_data()); + return __get_raw_data(tensor.raw_data()); + } + else + { + NNFUSION_LOG(NNFUSION_WARNING) << "Have no raw data" << endl; + } + + if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT16) + { + nnfusion::Shape shape{std::begin(tensor.dims()), std::end(tensor.dims())}; + size_t num_element = shape_size(shape); + std::vector raw_data = __get_data(tensor.int32_data()); + std::vector ret((num_element + 1) / 2); + uint32_t* src_p = (uint32_t*)raw_data.data(); + uint16_t* dst_p = (uint16_t*)ret.data(); + for (size_t i = 0; i < num_element; i++) + { + NNFUSION_CHECK((src_p[i] & 0xFFFF0000) == 0); + dst_p[i] = src_p[i] & 0x0000FFFF; + } + if (num_element % 2 == 1) + { + dst_p[num_element] = 0; + } + + return ret; + } + if (tensor.data_type() == onnx::TensorProto_DataType_FLOAT) + { + return __get_data(tensor.float_data()); + } + if (tensor.data_type() == onnx::TensorProto_DataType_INT32) + { + return __get_data(tensor.int32_data()); + } + if (tensor.data_type() == onnx::TensorProto_DataType_INT64) + { + return __get_data(tensor.int64_data()); + } + if (tensor.data_type() == onnx::TensorProto_DataType_UINT64) + { + return __get_data(tensor.uint64_data()); } + NNFUSION_CHECK_FAIL() + << "invalid data type: " + << onnx::TensorProto_DataType_Name( + static_cast(tensor.data_type())); + return std::vector(); } template <> inline std::vector get_data(const onnx::TensorProto& tensor) { - if (tensor.has_raw_data()) + if (tensor.data_type() == onnx::TensorProto_DataType_BOOL) { - return __get_raw_data(tensor.raw_data()); + // onnx store bool in byte + std::vector res; + if (tensor.has_raw_data()) + { + res = __get_raw_data(tensor.raw_data()); + } + else + { + res = __get_data(tensor.int32_data()); + } + NNFUSION_CHECK(res.size() > 0); + nnfusion::Shape shape{std::begin(tensor.dims()), std::end(tensor.dims())}; + size_t num_element = shape_size(shape); + char* raw_p = reinterpret_cast(res.data()); + std::vector ret; + ret.reserve(num_element); + for (size_t i = 0; i < num_element; i++) + { + ret.push_back(raw_p[i]); + } + return ret; } if (tensor.data_type() == onnx::TensorProto_DataType_INT32) { - return __get_data(tensor.int32_data()); + if (tensor.has_raw_data()) + { + return __get_raw_data(tensor.raw_data()); + } + else + { + return __get_data(tensor.int32_data()); + } } NNFUSION_CHECK_FAIL() << "invalid data type: " << onnx::TensorProto_DataType_Name( diff --git a/src/python/example/main.py b/src/python/example/main.py new file mode 100644 index 000000000..48a36f953 --- /dev/null +++ b/src/python/example/main.py @@ -0,0 +1,84 @@ +#!D:\project\transfer_xbox\python\tools\python.exe +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from logging import raiseExceptions +import time +import argparse +import numpy as np +import torch +# torch.manual_seed(0) +import torch.nn as nn +import torch.nn.functional as F + +from nnfusion.executor import Executor +from nnfusion.session import generate_sample +from nnfusion.data_format import cast_pytorch_tensor, cast_hlsl_tensor, HLSLTensor + + +def inference(nnf_model_path, total_iter): + assert total_iter >= 1 + executor = Executor(nnf_model_path) + input_dict, output_dict = {}, {} + if executor.host_mode: + # host mode leverage pytorch tensor as storage + for input in executor.get_inputs(): + input_dict[input.name] = cast_pytorch_tensor(generate_sample(input)) + for output in executor.get_outputs(): + output_dict[output.name] = cast_pytorch_tensor(generate_sample(output)) + else: + if executor.device_type == 0: + # cuda device + for input in executor.get_inputs(): + input_dict[input.name] = cast_pytorch_tensor(generate_sample(input, "cuda")) + for output in executor.get_outputs(): + output_dict[output.name] = cast_pytorch_tensor(generate_sample(output, "cuda")) + elif executor.device_type == 3: + # hlsl device + for input in executor.get_inputs(): + input_dict[input.name] = cast_hlsl_tensor(HLSLTensor(generate_sample(input))) + for output in executor.get_outputs(): + output_dict[output.name] = cast_hlsl_tensor(HLSLTensor(generate_sample(output))) + else: + raise Exception("only support device kernel_entry on cuda/hlsl backend.") + + + # warm up + for _ in range(5): + executor(input_dict, output_dict) + for k, v in output_dict.items(): + print(f"{k} = {v.reference}") + + # evaluate + print(f"Begin evaluation of {total_iter} iters") + start = time.time() + perf_list = [] + for _ in range(total_iter): + start_i = time.time() + executor(input_dict, output_dict) + end_i = time.time() + #print(end_i - start_i) + perf_list.append(end_i - start_i) + end = time.time() + + latency_ms = np.array(perf_list) * 1000 + batch_size = list(input_dict.values())[0].shape[0] + print(f"average_latency = {np.mean(latency_ms)} ms") + print(f"latency_50 = {np.percentile(latency_ms, 50)} ms") + print(f"latency_75 = {np.percentile(latency_ms, 75)} ms") + print(f"latency_90 = {np.percentile(latency_ms, 90)} ms") + print(f"latency_95 = {np.percentile(latency_ms, 95)} ms") + print(f"latency_99 = {np.percentile(latency_ms, 99)} ms") + print(f"throughput = {batch_size * (1000.0 / np.mean(latency_ms))} sample/s") + print(f"total elaspe {end - start} s") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--nnf_model_path', type=str) + parser.add_argument('--total_iter', type=int, default=1) + args = parser.parse_args() + inference(args.nnf_model_path, args.total_iter) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/python/nnfusion/data_format.py b/src/python/nnfusion/data_format.py index 93ce52904..9450ad02e 100644 --- a/src/python/nnfusion/data_format.py +++ b/src/python/nnfusion/data_format.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. import ctypes +from numpy import dtype +import torch from . import dtypes @@ -46,13 +48,74 @@ def dtype(self): def reference(self): return self._reference +class HLSLTensor(object): + antares_lib = None + + @classmethod + def init_antares_lib(cls, antares_dll_path): + if cls.antares_lib is None: + # cls.antares_lib = ctypes.cdll.LoadLibrary(r"D:\project\nnfusion_rt_pow\nnfusion_rt\dxcompute_codegen\Direct3DWinNN_seperate_dll\x64\Release\antares.dll") + cls.antares_lib = ctypes.cdll.LoadLibrary(antares_dll_path) + # alloc + cls.antares_lib.dxMemAlloc.argtypes = [ctypes.c_uint64] + cls.antares_lib.dxMemAlloc.restype = ctypes.c_void_p + # free + cls.antares_lib.dxMemFree.argtypes = [ctypes.c_void_p] + cls.antares_lib.dxMemFree.restype = ctypes.c_int32 + # H2D + cls.antares_lib.dxMemcpyHtoDAsync.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p] + cls.antares_lib.dxMemcpyHtoDAsync.restype = ctypes.c_int32 + # D2H + cls.antares_lib.dxMemcpyDtoHAsync.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64, ctypes.c_void_p] + cls.antares_lib.dxMemcpyDtoHAsync.restype = ctypes.c_int32 + # Sync + cls.antares_lib.dxStreamSynchronize.argtypes = [ctypes.c_void_p] + cls.antares_lib.dxStreamSynchronize.restype = ctypes.c_int32 + return + + def __init__(self, pytorch_tensor) -> None: + if self.antares_lib is None: + raise Exception("Please init antares lib firstly(e.g. creating a executor instance antomatically init antares lib") + pytorch_tensor = pytorch_tensor.contiguous() + self.shape = pytorch_tensor.shape + self.pt_type = str(pytorch_tensor.dtype).split(".")[-1] + self.dtype = dtypes.str2type[self.pt_type].type_str + num_element = pytorch_tensor.numel() + element_size = pytorch_tensor.element_size() + self.size = num_element * element_size + self.pointer = self.antares_lib.dxMemAlloc(self.size) + self.antares_lib.dxMemcpyHtoDAsync(self.pointer, ctypes.cast(pytorch_tensor.data_ptr(), ctypes.c_void_p), self.size, None) + self.antares_lib.dxStreamSynchronize(None) + + + def __del__(self): + if hasattr(self, "pointer") and self.pointer: + self.antares_lib.dxMemFree(self.pointer) + self.pointer == ctypes.c_void_p(None) + + def __str__(self): + return self.to_pytorch_tensor().__str__() + + def to_pytorch_tensor(self): + res = torch.empty(self.shape, dtype=dtypes.str2type[self.dtype].torch_type) + self.antares_lib.dxMemcpyDtoHAsync(ctypes.cast(res.data_ptr(), ctypes.c_void_p), self.pointer, self.size, None) + self.antares_lib.dxStreamSynchronize(None) + return res + +def cast_hlsl_tensor(hlsl_tensor): + pointer_type = ctypes.POINTER(dtypes.str2type[hlsl_tensor.dtype].c_type) + pointer = ctypes.cast(hlsl_tensor.pointer, pointer_type) + shape = hlsl_tensor.shape + dtype = hlsl_tensor.dtype + reference = hlsl_tensor + return DataFormat(pointer, pointer_type, shape, dtype, reference) def cast_pytorch_tensor(pytorch_tensor): if not pytorch_tensor.is_contiguous(): raise Exception( "Cannot cast incontiguous tensor, please use tensor.detach().clone().contiguous() before casting." ) - tensor_addr = pytorch_tensor.storage().data_ptr() + tensor_addr = pytorch_tensor.data_ptr() shape = pytorch_tensor.shape dtype = str(pytorch_tensor.dtype).split(".")[-1] pointer_type = ctypes.POINTER(dtypes.str2type[dtype].c_type) diff --git a/src/python/nnfusion/dtypes.py b/src/python/nnfusion/dtypes.py index 1632b60b5..911c1b192 100644 --- a/src/python/nnfusion/dtypes.py +++ b/src/python/nnfusion/dtypes.py @@ -39,9 +39,9 @@ "uint8": TypeObject._make(["uint8", ctypes.c_uint8, torch.uint8, numpy.uint8]), "uint16": - TypeObject._make(["uint8", ctypes.c_uint16, None, numpy.uint16]), + TypeObject._make(["uint16", ctypes.c_uint16, None, numpy.uint16]), "uint32": - TypeObject._make(["uint8", ctypes.c_uint32, None, numpy.uint32]), + TypeObject._make(["uint32", ctypes.c_uint32, None, numpy.uint32]), "uint64": - TypeObject._make(["uint8", ctypes.c_uint64, None, numpy.uint64]), + TypeObject._make(["uint64", ctypes.c_uint64, None, numpy.uint64]), } diff --git a/src/python/nnfusion/executor.py b/src/python/nnfusion/executor.py index 6653a985d..79fb374fd 100644 --- a/src/python/nnfusion/executor.py +++ b/src/python/nnfusion/executor.py @@ -4,10 +4,9 @@ import json import os import platform - import torch -from .data_format import cast_pytorch_tensor +from .data_format import HLSLTensor, cast_pytorch_tensor from .description import IODescription from .utils import cd @@ -98,6 +97,8 @@ def __init__(self, nnf_rt_dir, device=None): # prepare init/free/kernel_entry self.init_flag = False + if os.path.exists(os.path.join(nnf_rt_dir, "antares.dll")): + HLSLTensor.init_antares_lib(os.path.join(nnf_rt_dir, "antares.dll")) # dxil.dll and dxcompiler.dll must be manually imported if os.path.exists(os.path.join(nnf_rt_dir, "dxil.dll")): ctypes.cdll.LoadLibrary(os.path.join(nnf_rt_dir, "dxil.dll")) @@ -106,8 +107,10 @@ def __init__(self, nnf_rt_dir, device=None): self.libnnf = ctypes.cdll.LoadLibrary(self.libnnf_path) if hasattr(self.libnnf, "kernel_entry_host"): self.kernel_entry = self.libnnf.kernel_entry_host + self.host_mode = True elif hasattr(self.libnnf, "kernel_entry"): self.kernel_entry = self.libnnf.kernel_entry + self.host_mode = False else: raise Exception("No kernel_entry found in nnfusion_rt") device_type = self.get_device_type() @@ -180,18 +183,7 @@ def __call__(self, *args, **kwargs): # self.feed_tensors(*args, **kwargs) self.feed_data(*args, **kwargs) - def feed_data(self, inputs, outputs, strict=True): - """ - Execute the kernel_entry in nnf runtime - - Parameters: - inputs: a dict from name to nnf DataFormat - outputs: a dict from name to nnf DataFormat - strict: False if allow unused inputs/outputs - - Returns: - None - """ + def _dict_to_pointer_list(self, inputs, outputs, strict=True): signature = [None] * (len(self.input_descs) + len(self.output_descs)) params = [None] * (len(self.input_descs) + len(self.output_descs)) for name, data_format in inputs.items(): @@ -223,6 +215,21 @@ def feed_data(self, inputs, outputs, strict=True): else: if strict: raise Exception(f"Unused output {name}") + return signature, params + + def feed_data(self, inputs, outputs, strict=True): + """ + Execute the kernel_entry in nnf runtime + + Parameters: + inputs: a dict from name to nnf DataFormat + outputs: a dict from name to nnf DataFormat + strict: False if allow unused inputs/outputs + + Returns: + None + """ + signature, params = self._dict_to_pointer_list(inputs, outputs, strict=strict) self.feed_pointers(signature, params) def feed_pointers(self, signature, params): @@ -233,7 +240,7 @@ def _maybe_reserve_mem(self, device): get_workspace_size = getattr(self.libnnf, 'get_workspace_size', None) if get_workspace_size is None: return None - get_workspace_size.restype = ctypes.c_int64 + get_workspace_size.restype = ctypes.c_size_t n_byte = get_workspace_size() if not n_byte: return None diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12APIWrapper.cpp b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12APIWrapper.cpp similarity index 84% rename from src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12APIWrapper.cpp rename to src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12APIWrapper.cpp index 87e13b14a..9f84f904a 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12APIWrapper.cpp +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12APIWrapper.cpp @@ -8,13 +8,23 @@ #include #include #include +#include +#include #define _USE_GPU_TIMER_ #define _USE_DXC_ +#define ANTARES_EXPORTS + #include "D3D12Util.h" #include "D3D12APIWrapper.h" +#if _DEBUG +#define DEBUG_PRINT(msg) (fprintf(stderr, "[DEBUG] %s\n", msg), fflush(stderr)) +#else +#define DEBUG_PRINT(msg) +#endif + namespace { static bool _USE_DESCRIPTOR_HEAP_ = false; @@ -75,16 +85,34 @@ namespace { } }; + struct VectorHasher { + int operator()(const std::vector& V) const { + int hash = V.size(); + for (auto& i : V) { + hash ^= (i ^ (i >> 32)) + 0x9e3779b9L; + } + return hash; + } + }; + + std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); // Handles case where 'to' is a substring of 'from' + } + return str; + } + struct dx_shader_t { int block[3], thread[3]; std::vector inputs, outputs; std::string source; - CD3DX12_SHADER_BYTECODE bytecode; + std::unordered_map, ComPtr, VectorHasher> pPSO_ht; // bytecode_ht; // Added D3D12 resource ptr. ComPtr pRootSignature; - ComPtr pPSO; }; // Stream is wrapper of resources for record and execute commands. @@ -126,11 +154,8 @@ namespace { std::vector queryHeapsNeedToResolve; }; -#ifdef _DEBUG - static std::shared_ptr device = std::make_shared(true, true); -#else - static std::shared_ptr device = std::make_shared(false, false); -#endif + + static std::shared_ptr device; static void* defaultStream = nullptr; @@ -175,6 +200,16 @@ namespace { int dxInit(int flags) { + DEBUG_PRINT(__func__); + + if (device == nullptr) { +#ifdef _DEBUG + device = std::make_shared(true, true); +#else + device = std::make_shared(false, false); +#endif + } + if (!defaultStream) { // flags = 1: enable descriptor heap, no logging @@ -193,16 +228,55 @@ int dxInit(int flags) } int dxFinalize() { + DEBUG_PRINT(__func__); + device = nullptr; defaultStream = nullptr; return 0; } +static std::unordered_map> unused_buffers; +static std::unordered_map buffer_slots; + +inline size_t compute_slotsize(size_t &value) { + static const int tab64[64] = { + 63, 0, 58, 1, 59, 47, 53, 2, + 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, + 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, + 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, + 44, 24, 15, 8, 23, 7, 6, 5 }; + + value -= 1; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + + size_t slot_id = tab64[((uint64_t)((value - (value >> 1)) * 0x07EDD5E59A4E28C2LLU)) >> 58]; + value += 1; + return slot_id; +} + void* dxMemAlloc(size_t bytes) { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; + auto slot_id = compute_slotsize(bytes); + auto& slot = unused_buffers[slot_id]; + if (slot.size()) { + void* buff = slot.back(); + slot.pop_back(); + return buff; + } + auto buff = new dx_buffer_t(); buff->size = bytes; device->CreateGPUOnlyResource(bytes, &buff->handle); @@ -211,20 +285,30 @@ void* dxMemAlloc(size_t bytes) void* virtualPtr = VirtualAlloc(nullptr, bytes, MEM_RESERVE, PAGE_NOACCESS); assert(virtualPtr != nullptr); + buffer_slots[virtualPtr] = slot_id; memBlocks[virtualPtr] = buff; return virtualPtr; } -int dxMemFree(void* vPtr) +int dxMemFree(void* virtualPtr) { - VirtualFree(vPtr, 0, MEM_RELEASE); - memBlocks.erase(vPtr); + DEBUG_PRINT(__func__); + + auto it = buffer_slots.find(virtualPtr); + assert(it != buffer_slots.end()); + unused_buffers[it->second].push_back(virtualPtr); + return 0; + + VirtualFree(virtualPtr, 0, MEM_RELEASE); + memBlocks.erase(virtualPtr); return 0; } void* dxShaderLoad_v2(const char* shader_src) { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; @@ -241,27 +325,6 @@ void* dxShaderLoad_v2(const char* shader_src) dx_shader_t* handle = new dx_shader_t; handle->source = source; -#ifdef _USE_DXC_ - // Use cs_6_0 since dxc only supports cs_6_0 or higher shader models. - auto computeShader = antares::DXCompiler::Get()->Compile(source.data(), (uint32_t)source.size(), L"CSMain", L"cs_6_0"); - if (computeShader != nullptr) - handle->bytecode = CD3DX12_SHADER_BYTECODE(computeShader->GetBufferPointer(), computeShader->GetBufferSize()); - else - abort(); -#else - ComPtr computeShader = nullptr, errMsg = nullptr; - if (D3DCompile(source.data(), source.size(), NULL, NULL, NULL, "CSMain", "cs_5_1", 0, 0, &computeShader, &errMsg) >= 0 && computeShader != nullptr) - handle->bytecode = CD3DX12_SHADER_BYTECODE(computeShader.Get()); - else { - auto error_message = (char*)errMsg->GetBufferPointer(); - fprintf(stderr, "[ERROR] D3D12: Shader Compile Failed: %s\n", error_message); - } -#endif - if (computeShader == nullptr) { - //delete handle; - return nullptr; - } - std::string str_params; std::vector arr_params, in_params, out_params; bool legacy_format = (source.size() >= 3 && source.substr(0, 3) == "///"); @@ -321,8 +384,6 @@ void* dxShaderLoad_v2(const char* shader_src) auto& hd = handle; ComPtr& m_computeRootSignature = hd->pRootSignature; - ComPtr& m_computeState = hd->pPSO; - D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc{}; CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC computeRootSignatureDesc; std::vector computeRootParameters; @@ -356,21 +417,20 @@ void* dxShaderLoad_v2(const char* shader_src) IFE(D3DX12SerializeVersionedRootSignature(&computeRootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1_1, &signature, &error)); IFE(device->pDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_GRAPHICS_PPV_ARGS(m_computeRootSignature.ReleaseAndGetAddressOf()))); - - computePsoDesc.CS = hd->bytecode; - computePsoDesc.pRootSignature = m_computeRootSignature.Get(); - IFE(device->pDevice->CreateComputePipelineState(&computePsoDesc, IID_GRAPHICS_PPV_ARGS(m_computeState.ReleaseAndGetAddressOf()))); - return handle; } void dxShaderUnload(void* hShader) { + DEBUG_PRINT(__func__); + free(hShader); } void* dxModuleLoad(const char* module_src) { + DEBUG_PRINT(__func__); + std::string source; const char proto[] = "file://"; if (strncmp(module_src, proto, sizeof(proto) - 1) == 0) { @@ -401,6 +461,8 @@ void* dxModuleLoad(const char* module_src) void dxModuleUnload(void* hModule) { + DEBUG_PRINT(__func__); + auto& hShaderDict = *(std::unordered_map*)hModule; for (auto& it : hShaderDict) dxShaderUnload(it.second); @@ -409,6 +471,8 @@ void dxModuleUnload(void* hModule) void* dxModuleGetShader(void* hModule, const char* fname) { + DEBUG_PRINT(__func__); + auto& dict = *(std::unordered_map*)hModule; auto it = dict.find(fname); return it != dict.end() ? it->second : nullptr; @@ -416,6 +480,8 @@ void* dxModuleGetShader(void* hModule, const char* fname) void* dxStreamCreate() { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; @@ -446,6 +512,8 @@ void* dxStreamCreate() int dxStreamDestroy(void* hStream) { + DEBUG_PRINT(__func__); + if (hStream != nullptr) delete (dx_stream_t*)hStream; return 0; @@ -453,6 +521,8 @@ int dxStreamDestroy(void* hStream) int dxStreamSubmit(void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -480,6 +550,8 @@ int dxStreamSubmit(void* hStream) int dxStreamSynchronize(void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -499,6 +571,8 @@ int dxStreamSynchronize(void* hStream) int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -522,6 +596,8 @@ int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream) int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void *hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -563,6 +639,8 @@ int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void *hStream) int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -599,15 +677,58 @@ int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream) return dxStreamSynchronize(hStream); } -int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) +static std::wstring default_compat = L"cs_6_0"; + +int dxModuleSetCompat(const char* compat_name) { + std::wstring_convert> converter; + ::default_compat = converter.from_bytes(compat_name); + return 0; +} + +int dxShaderLaunchAsyncExt(void* hShader, void** buffers, int n, int blocks, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; - auto hd = (dx_shader_t*)hShader; auto pStream = (dx_stream_t*)hStream; assert(pStream->state == dx_stream_t::State::INRECORD); + n -= hd->inputs.size() + hd->outputs.size(); + n = max(0, n); + std::vector pargs(n); + for (int i = 0, j = hd->inputs.size() + hd->outputs.size(); i < n; ++i, ++j) + pargs[i] = (size_t)buffers[j]; + auto pso_iter = hd->pPSO_ht.find(pargs); + if (pso_iter == hd->pPSO_ht.end()) { + std::string src = hd->source; + for (int i = 0; i < n; ++i) + src = ReplaceAll(src, "@" + std::to_string(i) + "@", std::to_string(pargs[i])); + CD3DX12_SHADER_BYTECODE bytecode; +#ifdef _USE_DXC_ + // Use cs_6_0 since dxc only supports cs_6_0 or higher shader models. + auto computeShader = antares::DXCompiler::Get()->Compile(src.data(), (uint32_t)src.size(), L"CSMain", default_compat.c_str()); + if (computeShader != nullptr) + bytecode = CD3DX12_SHADER_BYTECODE(computeShader->GetBufferPointer(), computeShader->GetBufferSize()); +#else + ComPtr computeShader = nullptr, errMsg = nullptr; + if (D3DCompile(source.data(), source.size(), NULL, NULL, NULL, "CSMain", "cs_5_1", 0, 0, &computeShader, &errMsg) >= 0 && computeShader != nullptr) + bytecode = CD3DX12_SHADER_BYTECODE(computeShader.Get()); +#endif + if (computeShader == nullptr) { + //delete handle; + IFE(-1); + } + + ComPtr& m_computeState = hd->pPSO_ht[pargs]; + D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc{}; + computePsoDesc.CS = bytecode; + computePsoDesc.pRootSignature = hd->pRootSignature.Get(); + IFE(device->pDevice->CreateComputePipelineState(&computePsoDesc, IID_GRAPHICS_PPV_ARGS(m_computeState.ReleaseAndGetAddressOf()))); + pso_iter = hd->pPSO_ht.find(pargs); + } + std::vector devicePtrs; std::vector offsets; devicePtrs.reserve(hd->inputs.size() + hd->outputs.size()); @@ -636,7 +757,7 @@ int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) } pStream->pCmdList->SetComputeRootSignature(hd->pRootSignature.Get()); - pStream->pCmdList->SetPipelineState(hd->pPSO.Get()); + pStream->pCmdList->SetPipelineState(pso_iter->second.Get()); if (_USE_DESCRIPTOR_HEAP_) @@ -696,15 +817,22 @@ int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) // Set StartTimer here to only consider kernel execution time. device->StartTimer(pStream->pCmdList.Get(), m_nTimerIndex); #endif - pStream->pCmdList->Dispatch(hd->block[0], hd->block[1], hd->block[2]); + pStream->pCmdList->Dispatch(blocks >= 0 ? blocks : hd->block[0], hd->block[1], hd->block[2]); #ifdef _USE_GPU_TIMER_ device->StopTimer(pStream->pCmdList.Get(), m_nTimerIndex); #endif return 0; } +int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) +{ + return dxShaderLaunchAsyncExt(hShader, buffers, 0, -1, hStream); +} + void* dxEventCreate() { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; @@ -767,6 +895,8 @@ void* dxEventCreate() int dxEventDestroy(void* hEvent) { + DEBUG_PRINT(__func__); + if (hEvent == nullptr) return -1; @@ -779,6 +909,8 @@ int dxEventDestroy(void* hEvent) int dxEventRecord(void* hEvent, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -802,6 +934,8 @@ int dxEventRecord(void* hEvent, void* hStream) float dxEventElapsedSecond(void* hStart, void* hStop) { + DEBUG_PRINT(__func__); + auto pQueryStart = (antares::dx_query_t*)hStart; auto pQueryEnd = (antares::dx_query_t*)hStop; @@ -809,7 +943,11 @@ float dxEventElapsedSecond(void* hStart, void* hStop) uint64_t* pData; uint64_t timeStampStart = 0; uint64_t timeStampEnd = 0; - IFE(device->globalQueryHeaps[pQueryStart->heapIdx].pReadbackBuffer->Map(0, nullptr, reinterpret_cast(&pData))); + + HRESULT res = device->globalQueryHeaps[pQueryStart->heapIdx].pReadbackBuffer->Map(0, nullptr, reinterpret_cast(&pData)); + if (res < 0) + return -1.0f; + timeStampStart = pData[pQueryStart->queryIdxInHeap]; if (pQueryEnd->heapIdx == pQueryStart->heapIdx) diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12APIWrapper.h b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12APIWrapper.h new file mode 100644 index 000000000..41a3b5b46 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12APIWrapper.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef __ANTARES_D3D12_WRAPPER__ +#define __ANTARES_D3D12_WRAPPER__ + +#ifdef ANTARES_EXPORTS +#define ANTARES_API __declspec(dllexport) +#else +#define ANTARES_API __declspec(dllimport) +#endif + +#define __EXPORT__ extern "C" + +__EXPORT__ ANTARES_API int dxInit(int flags); +__EXPORT__ ANTARES_API int dxFinalize(); + +__EXPORT__ ANTARES_API void* dxStreamCreate(); +__EXPORT__ ANTARES_API int dxStreamDestroy(void* hStream); +__EXPORT__ ANTARES_API int dxStreamSubmit(void* hStream); +__EXPORT__ ANTARES_API int dxStreamSynchronize(void* hStream); + +__EXPORT__ ANTARES_API void* dxMemAlloc(size_t bytes); +__EXPORT__ ANTARES_API int dxMemFree(void* dptr); +__EXPORT__ ANTARES_API int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void* hStream); +__EXPORT__ ANTARES_API int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream); +__EXPORT__ ANTARES_API int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream); + +__EXPORT__ ANTARES_API int dxModuleSetCompat(const char* compat_name); +__EXPORT__ ANTARES_API void* dxModuleLoad(const char* module_src); +__EXPORT__ ANTARES_API void* dxModuleGetShader(void *hModule, const char* fname); +__EXPORT__ ANTARES_API void dxModuleUnload(void* hModule); + +__EXPORT__ ANTARES_API void* dxShaderLoad_v2(const char* shader_src); +__EXPORT__ ANTARES_API int dxShaderLaunchAsyncExt(void* hShader, void** buffers, int n, int blocks, void* hStream); +__EXPORT__ ANTARES_API int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream); +__EXPORT__ ANTARES_API void dxShaderUnload(void* hShader); + +__EXPORT__ ANTARES_API void* dxEventCreate(); +__EXPORT__ ANTARES_API int dxEventRecord(void* hEvent, void* hStream); +__EXPORT__ ANTARES_API float dxEventElapsedSecond(void* hStart, void* hStop); +__EXPORT__ ANTARES_API int dxEventDestroy(void* hEvent); + +#endif \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12Util.h b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12Util.h similarity index 99% rename from src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12Util.h rename to src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12Util.h index 0cd0ad39c..0d7f59416 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12Util.h +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/D3D12Util.h @@ -1121,6 +1121,8 @@ namespace antares { args_i.push_back(args[i].c_str()); } #endif + if (std::wstring(profile) != std::wstring(L"cs_6_0")) + args_i.push_back(L"-enable-16bit-types"); args_i.push_back(NULL); // Just set a random name "ShaderFile" // const WCHAR* args[] = { L"-enable-templates", L"-enable-16bit-types", NULL }; // TODO: will be supported in HLSL 2021 & cs_6_2 diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj new file mode 100644 index 000000000..7fd237563 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj @@ -0,0 +1,170 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 16.0 + Win32Proj + {6c8fb0d7-c5a4-4f39-a9b0-b8f4854e8225} + antares + 10.0 + + + + DynamicLibrary + true + v142 + Unicode + + + DynamicLibrary + false + v142 + true + Unicode + + + DynamicLibrary + true + v142 + Unicode + + + DynamicLibrary + false + v142 + true + Unicode + + + + + + + + + + + + + + + + + + + + + true + antares + + + false + antares + + + true + antares + + + false + antares + + + + Level3 + true + WIN32;_DEBUG;ANTARES_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) + true + Use + pch.h + + + Windows + true + false + + + + + Level3 + true + true + true + WIN32;NDEBUG;ANTARES_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) + true + Use + pch.h + + + Windows + true + true + true + false + + + + + Level3 + true + _DEBUG;ANTARES_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) + true + NotUsing + pch.h + + + Windows + true + false + + + + + Level3 + true + true + true + NDEBUG;ANTARES_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions) + true + NotUsing + pch.h + + + Windows + true + true + true + false + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj.filters b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj.filters new file mode 100644 index 000000000..bbf3e14ef --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj.filters @@ -0,0 +1,33 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + + + + Header Files + + + Header Files + + + + + Source Files + + + \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj.user b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj.user new file mode 100644 index 000000000..88a550947 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/antares.vcxproj.user @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/cpp.hint b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/cpp.hint new file mode 100644 index 000000000..4a4e1c4a8 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/cpp.hint @@ -0,0 +1,2 @@ +#define ANTARES_API __declspec(dllexport) +#define ANTARES_API __declspec(dllimport) diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/update.bat b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/update.bat new file mode 100644 index 000000000..056ae1a02 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/antares/update.bat @@ -0,0 +1,10 @@ +@echo off + +curl -LOs https://raw.githubusercontent.com/microsoft/antares/v0.3.x/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.h && echo updated D3D12APIWrapper.h + +curl -LOs https://raw.githubusercontent.com/microsoft/antares/v0.3.x/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp && echo updated D3D12APIWrapper.cpp + +curl -LOs https://raw.githubusercontent.com/microsoft/antares/v0.3.x/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12Util.h && echo updated D3D12Util.h + +echo finished! +pause diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/build.py b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/build.py deleted file mode 100644 index 240df325e..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/build.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import sys -import shutil -import winreg -import argparse -import logging -import subprocess - -logging.basicConfig(level="INFO") -logger = logging.getLogger(__name__) - - -def find_vs_path(): - # something like r"C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160" - version = ["2019", "2017", "2015"] - license = ["Enterprise", "Professional", "Community"] - default_path = r"C:\Program Files (x86)\Microsoft Visual Studio" - for v in version: - v_path = os.path.join(default_path, v) - if not os.path.isdir(v_path): - continue - for l in license: - l_path = os.path.join(v_path, l) - if not os.path.isdir(l_path): - continue - logger.info(f"Find Visual Studio in {l_path}") - return l_path - return "" - - -def copy_to(src, dst): - assert os.path.exists(src), f"File not found: {src}" - if os.path.isfile(dst): - os.remove(dst) - if os.path.isdir(dst): - shutil.rmtree(dst) - if os.path.isfile(src): - shutil.copyfile(src, dst) - if os.path.isdir(src): - shutil.copytree(src, dst) - - -def copy_to_output(output_dir, build_type, platform): - os.makedirs(output_dir, exist_ok=True) - nnf_desktop_dir = r".\nnf_desktop_example" - hlsl_path = os.path.join(nnf_desktop_dir, "HLSL") - const_path = os.path.join(nnf_desktop_dir, "Constant") - para_info = os.path.join(nnf_desktop_dir, "para_info.json") - nnf_exe = os.path.join( - nnf_desktop_dir, platform if "x64" in platform else "", build_type, "nnf_desktop_example.exe") - dxcompiler_lib = os.path.join( - nnf_desktop_dir, platform if "x64" in platform else "", build_type, "dxcompiler.dll") - dxil_lib = os.path.join( - nnf_desktop_dir, platform if "x64" in platform else "", build_type, "dxil.dll") - - runtime_dir = r".\runtime" - nnf_lib = os.path.join( - runtime_dir, platform if "x64" in platform else "", build_type, "nnfusion_rt.dll") - - if os.path.exists(hlsl_path): - copy_to(hlsl_path, os.path.join(output_dir, "HLSL")) - if os.path.exists(const_path): - copy_to(const_path, os.path.join(output_dir, "Constant")) - copy_to(para_info, os.path.join(output_dir, "para_info.json")) - copy_to(nnf_exe, os.path.join(output_dir, "nnf_desktop_example.exe")) - copy_to(nnf_lib, os.path.join(output_dir, "nnfusion_rt.dll")) - copy_to(dxcompiler_lib, os.path.join(output_dir, "dxcompiler.dll")) - copy_to(dxil_lib, os.path.join(output_dir, "dxil.dll")) - - -def setup_parser(): - parser = argparse.ArgumentParser() - parser.add_argument("-v", "--vs_path", default="", - help="visual studio install path") - parser.add_argument("-t", "--build_type", default="Release") - parser.add_argument("-p", "--platform", default="x64") - parser.add_argument("-o", "--output", default="./build") - return parser - - -def main(): - parser = setup_parser() - args = parser.parse_args() - vs_path = args.vs_path if args.vs_path != "" else find_vs_path() - build_type = args.build_type - platform = args.platform - output = args.output - - assert vs_path != "", "please specify vs install path by -v" - msbuild_exe = os.path.join(vs_path, r"MSBuild\Current\Bin\MSBuild.exe") - assert os.path.isfile( - msbuild_exe), f"MSBuild.exe not found in {msbuild_exe}" - - try: - subprocess.check_output([msbuild_exe, r".\nnf_desktop_example\nnf_desktop_example.vcxproj", - f"/property:Configuration={build_type}", f"/property:Platform={platform}"], stderr=subprocess.STDOUT, encoding="utf8") - except subprocess.CalledProcessError as e: - logger.error(e.output) - sys.exit(1) - - copy_to_output(output, build_type, platform) - logger.info(f"Build successfully, output dir: {output}") - - -if __name__ == '__main__': - main() diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example.sln b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example.sln index 6e23e220f..942c0d8b9 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example.sln +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example.sln @@ -4,8 +4,16 @@ Microsoft Visual Studio Solution File, Format Version 12.00 VisualStudioVersion = 16.0.31025.109 MinimumVisualStudioVersion = 10.0.40219.1 Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "nnf_desktop_example", "nnf_desktop_example\nnf_desktop_example.vcxproj", "{9E8C851A-9066-42A5-B160-F279A099A6C7}" + ProjectSection(ProjectDependencies) = postProject + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225} = {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225} + EndProjectSection EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "nnfusion_rt", "runtime\runtime.vcxproj", "{6FCA8D82-F478-46DA-B473-E3A8DEC9B312}" + ProjectSection(ProjectDependencies) = postProject + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225} = {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "antares", "antares\antares.vcxproj", "{6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -31,6 +39,14 @@ Global {6FCA8D82-F478-46DA-B473-E3A8DEC9B312}.Release|x64.Build.0 = Release|x64 {6FCA8D82-F478-46DA-B473-E3A8DEC9B312}.Release|x86.ActiveCfg = Release|Win32 {6FCA8D82-F478-46DA-B473-E3A8DEC9B312}.Release|x86.Build.0 = Release|Win32 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Debug|x64.ActiveCfg = Debug|x64 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Debug|x64.Build.0 = Debug|x64 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Debug|x86.ActiveCfg = Debug|Win32 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Debug|x86.Build.0 = Debug|Win32 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Release|x64.ActiveCfg = Release|x64 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Release|x64.Build.0 = Release|x64 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Release|x86.ActiveCfg = Release|Win32 + {6C8FB0D7-C5A4-4F39-A9B0-B8F4854E8225}.Release|x86.Build.0 = Release|Win32 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example/nnf_desktop_example.vcxproj b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example/nnf_desktop_example.vcxproj index 0ee4a0bc8..aa4fa67d5 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example/nnf_desktop_example.vcxproj +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/nnf_desktop_example/nnf_desktop_example.vcxproj @@ -126,15 +126,23 @@ copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxil.dll" "$(TargetDir)" /Ytrue _DEBUG;_CONSOLE;%(PreprocessorDefinitions) true - ../runtime;%(AdditionalIncludeDirectories) + ../runtime;../antares;%(AdditionalIncludeDirectories) Console true + antares.lib;%(AdditionalDependencies) + ..\$(IntDir);%(AdditionalLibraryDirectories) copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxcompiler.dll" "$(TargetDir)" /Y -copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxil.dll" "$(TargetDir)" /Y + +copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxil.dll" "$(TargetDir)" /Y +copy .\para_info.json "$(TargetDir)" +IF EXIST .\antares_perf.csv copy .\antares_perf.csv "$(TargetDir)" +IF EXIST .\HLSL xcopy /i /y /e .\HLSL "$(TargetDir)\HLSL" +IF EXIST .\Constant xcopy /i /y /e .\Constant "$(TargetDir)\Constant" + @@ -145,17 +153,25 @@ copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxil.dll" "$(TargetDir)" /Ytrue NDEBUG;_CONSOLE;%(PreprocessorDefinitions) true - ../runtime;%(AdditionalIncludeDirectories) + ../runtime;../antares;%(AdditionalIncludeDirectories) Console true true true + antares.lib;%(AdditionalDependencies) + ..\$(IntDir);%(AdditionalLibraryDirectories) copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxcompiler.dll" "$(TargetDir)" /Y -copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxil.dll" "$(TargetDir)" /Y + +copy "$(WindowsSdkDir)Redist\D3D\$(PlatformTarget)\dxil.dll" "$(TargetDir)" /Y +copy .\para_info.json "$(TargetDir)" +IF EXIST .\antares_perf.csv copy .\antares_perf.csv "$(TargetDir)" +IF EXIST .\HLSL xcopy /i /y /e .\HLSL "$(TargetDir)\HLSL" +IF EXIST .\Constant xcopy /i /y /e .\Constant "$(TargetDir)\Constant" + diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12APIWrapper.h b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12APIWrapper.h deleted file mode 100644 index 0cc522f74..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12APIWrapper.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef __ANTARES_D3D12_WRAPPER__ -#define __ANTARES_D3D12_WRAPPER__ - -#define __EXPORT__ extern "C" __declspec(dllexport) - -__EXPORT__ int dxInit(int flags); -__EXPORT__ int dxFinalize(); - -__EXPORT__ void* dxStreamCreate(); -__EXPORT__ int dxStreamDestroy(void* hStream); -__EXPORT__ int dxStreamSubmit(void* hStream); -__EXPORT__ int dxStreamSynchronize(void* hStream); - -__EXPORT__ void* dxMemAlloc(size_t bytes); -__EXPORT__ int dxMemFree(void* dptr); -__EXPORT__ int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void* hStream); -__EXPORT__ int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream); -__EXPORT__ int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream); - -__EXPORT__ void* dxModuleLoad(const char* module_src); -__EXPORT__ void* dxModuleGetShader(void *hModule, const char* fname); -__EXPORT__ void dxModuleUnload(void* hModule); - -__EXPORT__ void* dxShaderLoad_v2(const char* shader_src); -__EXPORT__ int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream); -__EXPORT__ void dxShaderUnload(void* hShader); - -__EXPORT__ void* dxEventCreate(); -__EXPORT__ int dxEventRecord(void* hEvent, void* hStream); -__EXPORT__ float dxEventElapsedSecond(void* hStart, void* hStop); -__EXPORT__ int dxEventDestroy(void* hEvent); - -#endif \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/half.hpp b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/half.hpp new file mode 100644 index 000000000..d0a882dd6 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/half.hpp @@ -0,0 +1,4601 @@ +// half - IEEE 754-based half-precision floating-point library. +// +// Copyright (c) 2012-2021 Christian Rau +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Version 2.2.0 + +/// \file +/// Main header file for half-precision functionality. + +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +#define HALF_GCC_VERSION (__GNUC__*100+__GNUC_MINOR__) + +#if defined(__INTEL_COMPILER) + #define HALF_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICC) + #define HALF_ICC_VERSION __ICC +#elif defined(__ICL) + #define HALF_ICC_VERSION __ICL +#else + #define HALF_ICC_VERSION 0 +#endif + +// check C++11 language features +#if defined(__clang__) // clang + #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ + #if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +#elif defined(__GNUC__) // gcc + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L + #if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #endif + #define HALF_TWOS_COMPLEMENT_INT 1 +#elif defined(_MSC_VER) // Visual C++ + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #define HALF_TWOS_COMPLEMENT_INT 1 + #define HALF_POP_WARNINGS 1 + #pragma warning(push) + #pragma warning(disable : 4099 4127 4146) //struct vs class, constant in if, negative unsigned +#endif + +// check C++11 library features +#include +#if defined(_LIBCPP_VERSION) // libc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #ifndef HALF_ENABLE_CPP11_CFENV + #define HALF_ENABLE_CPP11_CFENV 1 + #endif + #endif +#elif defined(__GLIBCXX__) // libstdc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifdef __clang__ + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) + #define HALF_ENABLE_CPP11_CFENV 1 + #endif + #else + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) + #define HALF_ENABLE_CPP11_CFENV 1 + #endif + #endif + #endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ + #if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) + #define HALF_ENABLE_CPP11_CFENV 1 + #endif +#endif +#undef HALF_GCC_VERSION +#undef HALF_ICC_VERSION + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? +#define HALF_ERRHANDLING (HALF_ERRHANDLING_FLAGS||HALF_ERRHANDLING_ERRNO||HALF_ERRHANDLING_FENV||HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING + #define HALF_UNUSED_NOERR(name) name +#else + #define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR + #define HALF_CONSTEXPR constexpr + #define HALF_CONSTEXPR_CONST constexpr + #if HALF_ERRHANDLING + #define HALF_CONSTEXPR_NOERR + #else + #define HALF_CONSTEXPR_NOERR constexpr + #endif +#else + #define HALF_CONSTEXPR + #define HALF_CONSTEXPR_CONST const + #define HALF_CONSTEXPR_NOERR +#endif + +// support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT + #define HALF_NOEXCEPT noexcept + #define HALF_NOTHROW noexcept +#else + #define HALF_NOEXCEPT + #define HALF_NOTHROW throw() +#endif + +// support thread storage +#if HALF_ENABLE_CPP11_THREAD_LOCAL + #define HALF_THREAD_LOCAL thread_local +#else + #define HALF_THREAD_LOCAL static +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS + #include +#endif +#if HALF_ENABLE_CPP11_CSTDINT + #include +#endif +#if HALF_ERRHANDLING_ERRNO + #include +#endif +#if HALF_ENABLE_CPP11_CFENV + #include +#endif +#if HALF_ENABLE_CPP11_HASH + #include +#endif + + +#ifndef HALF_ENABLE_F16C_INTRINSICS + /// Enable F16C intruction set intrinsics. + /// Defining this to 1 enables the use of [F16C compiler intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between + /// half-precision and single-precision values which may result in improved performance. This will not perform additional checks + /// for support of the F16C instruction set, so an appropriate target platform is required when enabling this feature. + /// + /// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which some compilers do on supporting platforms. + #define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif +#if HALF_ENABLE_F16C_INTRINSICS + #include +#endif + +#ifdef HALF_DOXYGEN_ONLY +/// Type for internal floating-point computations. +/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to override the internal +/// half-precision implementation to use this type for computing arithmetic operations and mathematical function (if available). +/// This can result in improved performance for arithmetic operators and mathematical functions but might cause results to +/// deviate from the specified half-precision rounding mode and inhibits proper detection of half-precision exceptions. +#define HALF_ARITHMETIC_TYPE (undefined) + +/// Enable internal exception flags. +/// Defining this to 1 causes operations on half-precision values to raise internal floating-point exception flags according to +/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). +#define HALF_ERRHANDLING_FLAGS 0 + +/// Enable exception propagation to `errno`. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point exceptions to +/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will propagate domain errors as +/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow errors as +/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be propagated. +#define HALF_ERRHANDLING_ERRNO 0 + +/// Enable exception propagation to built-in floating-point platform. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point exceptions to the built-in +/// single- and double-precision implementation's exception flags using the +/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from ``. However, this +/// does not work in reverse and single- or double-precision exceptions will not raise the corresponding half-precision +/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. +#define HALF_ERRHANDLING_FENV 0 + +/// Throw C++ exception on domain errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified message on domain errors. +#define HALF_ERRHANDLING_THROW_INVALID (undefined) + +/// Throw C++ exception on pole errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified message on pole errors. +#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) + +/// Throw C++ exception on overflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified message on overflows. +#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) + +/// Throw C++ exception on underflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the specified message on underflows. +#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) + +/// Throw C++ exception on rounding errors. +/// Defining this to 1 causes operations on half-precision values to throw a +/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified message on general rounding errors. +#define HALF_ERRHANDLING_THROW_INEXACT (undefined) +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in addition. +/// These will be raised after any possible handling of the underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. +/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions in addition. +/// These will be raised after any possible handling of the underflow exception. +/// +/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be raised *only* when the result +/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) subnormal result. +#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 +#endif + +/// Default rounding mode. +/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and more precise types +/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic operations and mathematical +/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes using their respective +/// constants or the equivalent values of +/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest (default) +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest representable value. It can even +/// be set to [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) to synchronize +/// the rounding mode with that of the built-in single-precision implementation (which is likely `std::round_to_nearest`, though). +#ifndef HALF_ROUND_STYLE + #define HALF_ROUND_STYLE 1 // = std::round_to_nearest +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +/// +/// **See also:** Documentation for [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a separate +/// half-precision multiplication followed by an addition, which is always the case. +/// +/// **See also:** Documentation for [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) +#define FP_FAST_FMAH 1 + +/// Half rounding mode. +/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode used for +/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). +/// +/// **See also:** Documentation for [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 + #define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN + #define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL + #define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO + #define FP_ZERO 1 +#endif +#ifndef FP_NAN + #define FP_NAN 2 +#endif +#ifndef FP_INFINITE + #define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL + #define FP_NORMAL 4 +#endif + +#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) + #define FE_INVALID 0x10 + #define FE_DIVBYZERO 0x08 + #define FE_OVERFLOW 0x04 + #define FE_UNDERFLOW 0x02 + #define FE_INEXACT 0x01 + #define FE_ALL_EXCEPT (FE_INVALID|FE_DIVBYZERO|FE_OVERFLOW|FE_UNDERFLOW|FE_INEXACT) +#endif + + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float +{ + class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS + /// Library-defined half-precision literals. + /// Import this namespace to enable half-precision floating-point literals: + /// ~~~~{.cpp} + /// using namespace half_float::literal; + /// half_float::half = 4.2_h; + /// ~~~~ + namespace literal + { + half operator "" _h(long double); + } +#endif + + /// \internal + /// \brief Implementation details. + namespace detail + { + #if HALF_ENABLE_CPP11_TYPE_TRAITS + /// Conditional type. + template struct conditional : std::conditional {}; + + /// Helper for tag dispatching. + template struct bool_type : std::integral_constant {}; + using std::true_type; + using std::false_type; + + /// Type traits for floating-point types. + template struct is_float : std::is_floating_point {}; + #else + /// Conditional type. + template struct conditional { typedef T type; }; + template struct conditional { typedef F type; }; + + /// Helper for tag dispatching. + template struct bool_type {}; + typedef bool_type true_type; + typedef bool_type false_type; + + /// Type traits for floating-point types. + template struct is_float : false_type {}; + template struct is_float : is_float {}; + template struct is_float : is_float {}; + template struct is_float : is_float {}; + template<> struct is_float : true_type {}; + template<> struct is_float : true_type {}; + template<> struct is_float : true_type {}; + #endif + + /// Type traits for floating-point bits. + template struct bits { typedef unsigned char type; }; + template struct bits : bits {}; + template struct bits : bits {}; + template struct bits : bits {}; + + #if HALF_ENABLE_CPP11_CSTDINT + /// Unsigned integer of (at least) 16 bits width. + typedef std::uint_least16_t uint16; + + /// Fastest unsigned integer of (at least) 32 bits width. + typedef std::uint_fast32_t uint32; + + /// Fastest signed integer of (at least) 32 bits width. + typedef std::int_fast32_t int32; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits { typedef std::uint_least32_t type; }; + + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { typedef std::uint_least64_t type; }; + #else + /// Unsigned integer of (at least) 16 bits width. + typedef unsigned short uint16; + + /// Fastest unsigned integer of (at least) 32 bits width. + typedef unsigned long uint32; + + /// Fastest unsigned integer of (at least) 32 bits width. + typedef long int32; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits : conditional::digits>=32,unsigned int,unsigned long> {}; + + #if HALF_ENABLE_CPP11_LONG_LONG + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits : conditional::digits>=64,unsigned long,unsigned long long> {}; + #else + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { typedef unsigned long type; }; + #endif + #endif + + #ifdef HALF_ARITHMETIC_TYPE + /// Type to use for arithmetic computations and mathematic functions internally. + typedef HALF_ARITHMETIC_TYPE internal_t; + #endif + + /// Tag type for binary construction. + struct binary_t {}; + + /// Tag for binary construction. + HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + + /// \name Implementation defined classification and arithmetic + /// \{ + + /// Check for infinity. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if infinity + /// \retval false else + template bool builtin_isinf(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); + #elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); + #else + return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); + #endif + } + + /// Check for NaN. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if not a number + /// \retval false else + template bool builtin_isnan(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); + #elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; + #else + return arg != arg; + #endif + } + + /// Check sign. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if signbit set + /// \retval false else + template bool builtin_signbit(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); + #else + return arg < T() || (arg == T() && T(1)/arg < T()); + #endif + } + + /// Platform-independent sign mask. + /// \param arg integer value in two's complement + /// \retval -1 if \a arg negative + /// \retval 0 if \a arg positive + inline uint32 sign_mask(uint32 arg) + { + static const int N = std::numeric_limits::digits - 1; + #if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; + #else + return -((arg>>N)&1); + #endif + } + + /// Platform-independent arithmetic right shift. + /// \param arg integer value in two's complement + /// \param i shift amount (at most 31) + /// \return \a arg right shifted for \a i bits with possible sign extension + inline uint32 arithmetic_shift(uint32 arg, int i) + { + #if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; + #else + return static_cast(arg)/(static_cast(1)<>(std::numeric_limits::digits-1))&1); + #endif + } + + /// \} + /// \name Error handling + /// \{ + + /// Internal exception flags. + /// \return reference to global exception flags + inline int& errflags() { HALF_THREAD_LOCAL int flags = 0; return flags; } + + /// Raise floating-point exception. + /// \param flags exceptions to raise + /// \param cond condition to raise exceptions for + inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) + { + #if HALF_ERRHANDLING + if(!cond) + return; + #if HALF_ERRHANDLING_FLAGS + errflags() |= flags; + #endif + #if HALF_ERRHANDLING_ERRNO + if(flags & FE_INVALID) + errno = EDOM; + else if(flags & (FE_DIVBYZERO|FE_OVERFLOW|FE_UNDERFLOW)) + errno = ERANGE; + #endif + #if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV + std::feraiseexcept(flags); + #endif + #ifdef HALF_ERRHANDLING_THROW_INVALID + if(flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); + #endif + #ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if(flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); + #endif + #ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if(flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); + #endif + #ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if(flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); + #endif + #ifdef HALF_ERRHANDLING_THROW_INEXACT + if(flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); + #endif + #if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); + #endif + #if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); + #endif + #endif + } + + /// Check and signal for any NaN. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \retval true if either \a x or \a y is NaN + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, (x&0x7FFF)>0x7C00 || (y&0x7FFF)>0x7C00); + #endif + return (x&0x7FFF) > 0x7C00 || (y&0x7FFF) > 0x7C00; + } + + /// Signal and silence signaling NaN. + /// \param nan half-precision NaN value + /// \return quiet NaN + /// \exception FE_INVALID if \a nan is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, !(nan&0x200)); + #endif + return nan | 0x200; + } + + /// Signal and silence signaling NaNs. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \return quiet NaN + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, ((x&0x7FFF)>0x7C00 && !(x&0x200)) || ((y&0x7FFF)>0x7C00 && !(y&0x200))); + #endif + return ((x&0x7FFF)>0x7C00) ? (x|0x200) : (y|0x200); + } + + /// Signal and silence signaling NaNs. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \param z third half-precision value to check + /// \return quiet NaN + /// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, ((x&0x7FFF)>0x7C00 && !(x&0x200)) || ((y&0x7FFF)>0x7C00 && !(y&0x200)) || ((z&0x7FFF)>0x7C00 && !(z&0x200))); + #endif + return ((x&0x7FFF)>0x7C00) ? (x|0x200) : ((y&0x7FFF)>0x7C00) ? (y|0x200) : (z|0x200); + } + + /// Select value or signaling NaN. + /// \param x preferred half-precision value + /// \param y ignored half-precision value except for signaling NaN + /// \return \a y if signaling NaN, \a x otherwise + /// \exception FE_INVALID if \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) + { + #if HALF_ERRHANDLING + return (((y&0x7FFF)>0x7C00) && !(y&0x200)) ? signal(y) : x; + #else + return x; + #endif + } + + /// Raise domain error and return NaN. + /// return quiet NaN + /// \exception FE_INVALID + inline HALF_CONSTEXPR_NOERR unsigned int invalid() + { + #if HALF_ERRHANDLING + raise(FE_INVALID); + #endif + return 0x7FFF; + } + + /// Raise pole error and return infinity. + /// \param sign half-precision value with sign bit only + /// \return half-precision infinity with sign of \a sign + /// \exception FE_DIVBYZERO + inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) + { + #if HALF_ERRHANDLING + raise(FE_DIVBYZERO); + #endif + return sign | 0x7C00; + } + + /// Check value for underflow. + /// \param arg non-zero half-precision value to check + /// \return \a arg + /// \exception FE_UNDERFLOW if arg is subnormal + inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) + { + #if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + raise(FE_UNDERFLOW, !(arg&0x7C00)); + #endif + return arg; + } + + /// \} + /// \name Conversion and rounding + /// \{ + + /// Half-precision overflow. + /// \tparam R rounding mode to use + /// \param sign half-precision value with sign bit only + /// \return rounded overflowing half-precision value + /// \exception FE_OVERFLOW + template HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) + { + #if HALF_ERRHANDLING + raise(FE_OVERFLOW); + #endif + return (R==std::round_toward_infinity) ? (sign+0x7C00-(sign>>15)) : + (R==std::round_toward_neg_infinity) ? (sign+0x7BFF+(sign>>15)) : + (R==std::round_toward_zero) ? (sign|0x7BFF) : + (sign|0x7C00); + } + + /// Half-precision underflow. + /// \tparam R rounding mode to use + /// \param sign half-precision value with sign bit only + /// \return rounded underflowing half-precision value + /// \exception FE_UNDERFLOW + template HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) + { + #if HALF_ERRHANDLING + raise(FE_UNDERFLOW); + #endif + return (R==std::round_toward_infinity) ? (sign+1-(sign>>15)) : + (R==std::round_toward_neg_infinity) ? (sign+(sign>>15)) : + sign; + } + + /// Round half-precision number. + /// \tparam R rounding mode to use + /// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results + /// \param value finite half-precision number to round + /// \param g guard bit (most significant discarded bit) + /// \param s sticky bit (or of all but the most significant discarded bits) + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded or \a I is `true` + template HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) + { + #if HALF_ERRHANDLING + value += (R==std::round_to_nearest) ? (g&(s|value)) : + (R==std::round_toward_infinity) ? (~(value>>15)&(g|s)) : + (R==std::round_toward_neg_infinity) ? ((value>>15)&(g|s)) : 0; + if((value&0x7C00) == 0x7C00) + raise(FE_OVERFLOW); + else if(value & 0x7C00) + raise(FE_INEXACT, I || (g|s)!=0); + else + raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g|s)!=0); + return value; + #else + return (R==std::round_to_nearest) ? (value+(g&(s|value))) : + (R==std::round_toward_infinity) ? (value+(~(value>>15)&(g|s))) : + (R==std::round_toward_neg_infinity) ? (value+((value>>15)&(g|s))) : + value; + #endif + } + + /// Round half-precision number to nearest integer value. + /// \tparam R rounding mode to use + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it + /// \param value half-precision value to round + /// \return half-precision bits for nearest integral value + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded and \a I is `true` + template unsigned int integral(unsigned int value) + { + unsigned int abs = value & 0x7FFF; + if(abs < 0x3C00) + { + raise(FE_INEXACT, I); + return ((R==std::round_to_nearest) ? (0x3C00&-static_cast(abs>=(0x3800+E))) : + (R==std::round_toward_infinity) ? (0x3C00&-(~(value>>15)&(abs!=0))) : + (R==std::round_toward_neg_infinity) ? (0x3C00&-static_cast(value>0x8000)) : + 0) | (value&0x8000); + } + if(abs >= 0x6400) + return (abs>0x7C00) ? signal(value) : value; + unsigned int exp = 25 - (abs>>10), mask = (1<>exp)&E)) : + (R==std::round_toward_infinity) ? (mask&((value>>15)-1)) : + (R==std::round_toward_neg_infinity) ? (mask&-(value>>15)) : + 0) + value) & ~mask; + } + + /// Convert fixed point to half-precision floating-point. + /// \tparam R rounding mode to use + /// \tparam F number of fractional bits in [11,31] + /// \tparam S `true` for signed, `false` for unsigned + /// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F + /// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results + /// \param m mantissa in Q1.F fixed point format + /// \param exp biased exponent - 1 + /// \param sign half-precision value with sign bit only + /// \param s sticky bit (or of all but the most significant already discarded bits) + /// \return value converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded or \a I is `true` + template unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) + { + if(S) + { + uint32 msign = sign_mask(m); + m = (m^msign) - msign; + sign = msign & 0x8000; + } + if(N) + for(; m<(static_cast(1)<(sign+(m>>(F-10-exp)), (m>>(F-11-exp))&1, s|((m&((static_cast(1)<<(F-11-exp))-1))!=0)); + return rounded(sign+(exp<<10)+(m>>(F-10)), (m>>(F-11))&1, s|((m&((static_cast(1)<<(F-11))-1))!=0)); + } + + /// Convert IEEE single-precision to half-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \tparam R rounding mode to use + /// \param value single-precision value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(float value, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), + (R==std::round_to_nearest) ? _MM_FROUND_TO_NEAREST_INT : + (R==std::round_toward_zero) ? _MM_FROUND_TO_ZERO : + (R==std::round_toward_infinity) ? _MM_FROUND_TO_POS_INF : + (R==std::round_toward_neg_infinity) ? _MM_FROUND_TO_NEG_INF : + _MM_FROUND_CUR_DIRECTION)); + #else + bits::type fbits; + std::memcpy(&fbits, &value, sizeof(float)); + #if 1 + unsigned int sign = (fbits>>16) & 0x8000; + fbits &= 0x7FFFFFFF; + if(fbits >= 0x7F800000) + return sign | 0x7C00 | ((fbits>0x7F800000) ? (0x200|((fbits>>13)&0x3FF)) : 0); + if(fbits >= 0x47800000) + return overflow(sign); + if(fbits >= 0x38800000) + return rounded(sign|(((fbits>>23)-112)<<10)|((fbits>>13)&0x3FF), (fbits>>12)&1, (fbits&0xFFF)!=0); + if(fbits >= 0x33000000) + { + int i = 125 - (fbits>>23); + fbits = (fbits&0x7FFFFF) | 0x800000; + return rounded(sign|(fbits>>(i+1)), (fbits>>i)&1, (fbits&((static_cast(1)<(sign); + return sign; + #else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, + 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, + 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7C00, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, + 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, + 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00 }; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13 }; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits|((exp!=0)<<23)) & -static_cast(exp!=0xFF); + return rounded(base_table[sexp]+(fbits>>i), (m>>(i-1))&1, (((static_cast(1)<<(i-1))-1)&m)!=0); + #endif + #endif + } + + /// Convert IEEE double-precision to half-precision. + /// \tparam R rounding mode to use + /// \param value double-precision value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(double value, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + if(R == std::round_indeterminate) + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); + #endif + bits::type dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi>>16) & 0x8000; + hi &= 0x7FFFFFFF; + if(hi >= 0x7FF00000) + return sign | 0x7C00 | ((dbits&0xFFFFFFFFFFFFF) ? (0x200|((hi>>10)&0x3FF)) : 0); + if(hi >= 0x40F00000) + return overflow(sign); + if(hi >= 0x3F100000) + return rounded(sign|(((hi>>20)-1008)<<10)|((hi>>10)&0x3FF), (hi>>9)&1, ((hi&0x1FF)|lo)!=0); + if(hi >= 0x3E600000) + { + int i = 1018 - (hi>>20); + hi = (hi&0xFFFFF) | 0x100000; + return rounded(sign|(hi>>(i+1)), (hi>>i)&1, ((hi&((static_cast(1)<(sign); + return sign; + } + + /// Convert non-IEEE floating-point to half-precision. + /// \tparam R rounding mode to use + /// \tparam T source type (builtin floating-point type) + /// \param value floating-point value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(T value, ...) + { + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + return overflow(hbits); + if(exp < -13) + value = std::ldexp(value, 25); + else + { + value = std::ldexp(value, 12-exp); + hbits |= ((exp+13)<<10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits+(m>>1), m&1, frac!=T()); + } + + /// Convert floating-point to half-precision. + /// \tparam R rounding mode to use + /// \tparam T source type (builtin floating-point type) + /// \param value floating-point value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half(T value) + { + return float2half_impl(value, bool_type::is_iec559&&sizeof(typename bits::type)==sizeof(T)>()); + } + + /// Convert integer to half-precision floating-point. + /// \tparam R rounding mode to use + /// \tparam T type to convert (builtin integer type) + /// \param value integral value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int int2half(T value) + { + unsigned int bits = static_cast(value<0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m<0x400; m<<=1,--exp) ; + for(; m>0x7FF; m>>=1,++exp) ; + bits |= (exp<<10) + m; + return (exp>24) ? rounded(bits, (value>>(exp-25))&1, (((1<<(exp-25))-1)&value)!=0) : bits; + } + + /// Convert half-precision to IEEE single-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \param value half-precision value to convert + /// \return single-precision value + inline float half2float_impl(unsigned int value, float, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); + #else + #if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } + #else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, + 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, + 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, + 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, + 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, + 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, + 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000, + 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, + 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, + 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, + 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, + 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, + 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, + 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, + 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, + 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, + 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, + 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, + 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, + 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, + 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, + 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, + 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, + 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, + 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, + 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, + 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, + 0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, + 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, + 0x38140000, 0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, + 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, + 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, + 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, + 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, + 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, + 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, + 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, + 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, + 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, + 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, + 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, + 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, + 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, + 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, + 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, + 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, + 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 0x38674000, 0x38678000, 0x3867C000, + 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, + 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, + 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, + 0x38740000, 0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, + 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, + 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, + 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, + 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, + 0x38080000, 0x38082000, 0x38084000, 0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, + 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, + 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, + 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, + 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, + 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, + 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, + 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, + 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, + 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, + 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, + 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, + 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, + 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, + 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 0x3831E000, + 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, + 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, + 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, + 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, + 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, + 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, + 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, + 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, + 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, + 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, + 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, + 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, + 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, + 0x38560000, 0x38562000, 0x38564000, 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, + 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, + 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, + 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000, + 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, + 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, + 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, + 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, + 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, + 0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, + 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, + 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, + 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, + 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000 }; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, + 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, + 0x80000000, 0x80800000, 0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000 }; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024 }; + bits::type fbits = mantissa_table[offset_table[value>>10]+(value&0x3FF)] + exponent_table[value>>10]; + #endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; + #endif + } + + /// Convert half-precision to IEEE double-precision. + /// \param value half-precision value to convert + /// \return double-precision value + inline double half2float_impl(unsigned int value, double, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); + #else + uint32 hi = static_cast(value&0x8000) << 16; + unsigned int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,hi-=0x100000) ; + hi += static_cast(abs) << 10; + } + bits::type dbits = static_cast::type>(hi) << 32; + double out; + std::memcpy(&out, &dbits, sizeof(double)); + return out; + #endif + } + + /// Convert half-precision to non-IEEE floating-point. + /// \tparam T type to convert to (builtin integer type) + /// \param value half-precision value to convert + /// \return floating-point value + template T half2float_impl(unsigned int value, T, ...) + { + T out; + unsigned int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = (std::numeric_limits::has_signaling_NaN && !(abs&0x200)) ? std::numeric_limits::signaling_NaN() : + std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs&0x3FF)|0x400), (abs>>10)-25); + else + out = std::ldexp(static_cast(abs), -24); + return (value&0x8000) ? -out : out; + } + + /// Convert half-precision to floating-point. + /// \tparam T type to convert to (builtin integer type) + /// \param value half-precision value to convert + /// \return floating-point value + template T half2float(unsigned int value) + { + return half2float_impl(value, T(), bool_type::is_iec559&&sizeof(typename bits::type)==sizeof(T)>()); + } + + /// Convert half-precision floating-point to integer. + /// \tparam R rounding mode to use + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it + /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits) + /// \param value half-precision value to convert + /// \return rounded integer value + /// \exception FE_INVALID if value is not representable in type \a T + /// \exception FE_INEXACT if value had to be rounded and \a I is `true` + template T half2int(unsigned int value) + { + unsigned int abs = value & 0x7FFF; + if(abs >= 0x7C00) + { + raise(FE_INVALID); + return (value&0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); + } + if(abs < 0x3800) + { + raise(FE_INEXACT, I); + return (R==std::round_toward_infinity) ? T(~(value>>15)&(abs!=0)) : + (R==std::round_toward_neg_infinity) ? -T(value>0x8000) : + T(); + } + int exp = 25 - (abs>>10); + unsigned int m = (value&0x3FF) | 0x400; + int32 i = static_cast((exp<=0) ? (m<<-exp) : ((m+( + (R==std::round_to_nearest) ? ((1<<(exp-1))-(~(m>>exp)&E)) : + (R==std::round_toward_infinity) ? (((1<>15)-1)) : + (R==std::round_toward_neg_infinity) ? (((1<>15)) : 0))>>exp)); + if((!std::numeric_limits::is_signed && (value&0x8000)) || (std::numeric_limits::digits<16 && + ((value&0x8000) ? (-i::min()) : (i>std::numeric_limits::max())))) + raise(FE_INVALID); + else if(I && exp > 0 && (m&((1<((value&0x8000) ? -i : i); + } + + /// \} + /// \name Mathematics + /// \{ + + /// upper part of 64-bit multiplication. + /// \tparam R rounding mode to use + /// \param x first factor + /// \param y second factor + /// \return upper 32 bit of \a x * \a y + template uint32 mulhi(uint32 x, uint32 y) + { + uint32 xy = (x>>16) * (y&0xFFFF), yx = (x&0xFFFF) * (y>>16), c = (xy&0xFFFF) + (yx&0xFFFF) + (((x&0xFFFF)*(y&0xFFFF))>>16); + return (x>>16)*(y>>16) + (xy>>16) + (yx>>16) + (c>>16) + + ((R==std::round_to_nearest) ? ((c>>15)&1) : (R==std::round_toward_infinity) ? ((c&0xFFFF)!=0) : 0); + } + + /// 64-bit multiplication. + /// \param x first factor + /// \param y second factor + /// \return upper 32 bit of \a x * \a y rounded to nearest + inline uint32 multiply64(uint32 x, uint32 y) + { + #if HALF_ENABLE_CPP11_LONG_LONG + return static_cast((static_cast(x)*static_cast(y)+0x80000000)>>32); + #else + return mulhi(x, y); + #endif + } + + /// 64-bit division. + /// \param x upper 32 bit of dividend + /// \param y divisor + /// \param s variable to store sticky bit for rounding + /// \return (\a x << 32) / \a y + inline uint32 divide64(uint32 x, uint32 y, int &s) + { + #if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long xx = static_cast(x) << 32; + return s = (xx%y!=0), static_cast(xx/y); + #else + y >>= 1; + uint32 rem = x, div = 0; + for(unsigned int i=0; i<32; ++i) + { + div <<= 1; + if(rem >= y) + { + rem -= y; + div |= 1; + } + rem <<= 1; + } + return s = rem > 1, div; + #endif + } + + /// Half precision positive modulus. + /// \tparam Q `true` to compute full quotient, `false` else + /// \tparam R `true` to compute signed remainder, `false` for positive remainder + /// \param x first operand as positive finite half-precision value + /// \param y second operand as positive finite half-precision value + /// \param quo adress to store quotient at, `nullptr` if \a Q `false` + /// \return modulus of \a x / \a y + template unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) + { + unsigned int q = 0; + if(x > y) + { + int absx = x, absy = y, expx = 0, expy = 0; + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + for(int d=expx-expy; d; --d) + { + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + ++q; + } + if(Q) + { + q &= (1<<(std::numeric_limits::digits-1)) - 1; + if(!mx) + return *quo = q, 0; + } + for(; mx<0x400; mx<<=1,--expy) ; + x = (expy>0) ? ((expy<<10)|(mx&0x3FF)) : (mx>>(1-expy)); + } + if(R) + { + unsigned int a, b; + if(y < 0x800) + { + a = (x<0x400) ? (x<<1) : (x+0x400); + b = y; + } + else + { + a = x; + b = y - 0x400; + } + if(a > b || (a == b && (q&1))) + { + int exp = (y>>10) + (y<=0x3FF), d = exp - (x>>10) - (x<=0x3FF); + int m = (((y&0x3FF)|((y>0x3FF)<<10))<<1) - (((x&0x3FF)|((x>0x3FF)<<10))<<(1-d)); + for(; m<0x800 && exp>1; m<<=1,--exp) ; + x = 0x8000 + ((exp-1)<<10) + (m>>1); + q += Q; + } + } + if(Q) + *quo = q; + return x; + } + + /// Fixed point square root. + /// \tparam F number of fractional bits + /// \param r radicand in Q1.F fixed point format + /// \param exp exponent + /// \return square root as Q1.F/2 + template uint32 sqrt(uint32 &r, int &exp) + { + int i = exp & 1; + r <<= i; + exp = (exp-i) / 2; + uint32 m = 0; + for(uint32 bit=static_cast(1)<>=2) + { + if(r < m+bit) + m >>= 1; + else + { + r -= m + bit; + m = (m>>1) + bit; + } + } + return m; + } + + /// Fixed point binary exponential. + /// This uses the BKM algorithm in E-mode. + /// \param m exponent in [0,1) as Q0.31 + /// \param n number of iterations (at most 32) + /// \return 2 ^ \a m as Q1.31 + inline uint32 exp2(uint32 m, unsigned int n = 32) + { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, 0x016FE50B, + 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, + 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, 0x00000003, 0x00000001 }; + if(!m) + return 0x80000000; + uint32 mx = 0x80000000, my = 0; + for(unsigned int i=1; i> i; + } + } + return mx; + } + + /// Fixed point binary logarithm. + /// This uses the BKM algorithm in L-mode. + /// \param m mantissa in [1,2) as Q1.30 + /// \param n number of iterations (at most 32) + /// \return log2(\a m) as Q0.31 + inline uint32 log2(uint32 m, unsigned int n = 32) + { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, 0x016FE50B, + 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, + 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, 0x00000003, 0x00000001 }; + if(m == 0x40000000) + return 0; + uint32 mx = 0x40000000, my = 0; + for(unsigned int i=1; i>i); + if(mz <= m) + { + mx = mz; + my += logs[i]; + } + } + return my; + } + + /// Fixed point sine and cosine. + /// This uses the CORDIC algorithm in rotation mode. + /// \param mz angle in [-pi/2,pi/2] as Q1.30 + /// \param n number of iterations (at most 31) + /// \return sine and cosine of \a mz as Q1.30 + inline std::pair sincos(uint32 mz, unsigned int n = 31) + { + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, 0x007FFF55, + 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, 0x00010000, 0x00008000, + 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, 0x00000200, 0x00000100, 0x00000080, + 0x00000040, 0x00000020, 0x00000010, 0x00000008, 0x00000004, 0x00000002, 0x00000001 }; + uint32 mx = 0x26DD3B6A, my = 0; + for(unsigned int i=0; i0x3FF)<<10); + int exp = (abs>>10) + (abs<=0x3FF) - 15; + if(abs < 0x3A48) + return k = 0, m << (exp+20); + #if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL<<(62-exp)) - 1, yi = (y+(mask>>1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f>>63); + k = static_cast(yi>>(62-exp)); + return (multiply64(static_cast((sign ? -f : f)>>(31-exp)), 0xC90FDAA2)^sign) - sign; + #else + uint32 yh = m*0xA2F98 + mulhi(m, 0x36E4E442), yl = (m*0x36E4E442) & 0xFFFFFFFF; + uint32 mask = (static_cast(1)<<(30-exp)) - 1, yi = (yh+(mask>>1)) & ~mask, sign = -static_cast(yi>yh); + k = static_cast(yi>>(30-exp)); + uint32 fh = (yh^sign) + (yi^~sign) - ~sign, fl = (yl^sign) - sign; + return (multiply64((exp>-1) ? (((fh<<(1+exp))&0xFFFFFFFF)|((fl&0xFFFFFFFF)>>(31-exp))) : fh, 0xC90FDAA2)^sign) - sign; + #endif + } + + /// Get arguments for atan2 function. + /// \param abs half-precision floating-point value + /// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 + inline std::pair atan2_args(unsigned int abs) + { + int exp = -15; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + uint32 my = ((abs&0x3FF)|0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - ((rexp>-31) ? ((r>>-rexp)|((r&((static_cast(1)<<-rexp)-1))!=0)) : 1); + for(rexp=0; r<0x40000000; r<<=1,--rexp) ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if(d < 0) + return std::make_pair((d<-14) ? ((my>>(-d-14))+((my>>(-d-15))&1)) : (my<<(14+d)), (mx<<14)+(r<<13)/mx); + if(d > 0) + return std::make_pair(my<<14, (d>14) ? ((mx>>(d-14))+((mx>>(d-15))&1)) : ((d==14) ? mx : ((mx<<(14-d))+(r<<(13-d))/mx))); + return std::make_pair(my<<13, (mx<<13)+(r<<12)/mx); + } + + /// Get exponentials for hyperbolic computation + /// \param abs half-precision floating-point value + /// \param exp variable to take unbiased exponent of larger result + /// \param n number of BKM iterations (at most 32) + /// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent + inline std::pair hyperbolic_args(unsigned int abs, int &exp, unsigned int n = 32) + { + uint32 mx = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29), my; + int e = (abs>>10) + (abs<=0x3FF); + if(e < 14) + { + exp = 0; + mx >>= 14 - e; + } + else + { + exp = mx >> (45-e); + mx = (mx<<(e-14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if(mx > 0x80000000) + { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } + else + my = mx; + return std::make_pair(mx, (d<31) ? ((my>>d)|((my&((static_cast(1)< unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0, unsigned int n = 32) + { + if(esign) + { + exp = -exp - (m!=0); + if(exp < -25) + return underflow(sign); + else if(exp == -25) + return rounded(sign, 1, m!=0); + } + else if(exp > 15) + return overflow(sign); + if(!m) + return sign | (((exp+=15)>0) ? (exp<<10) : check_underflow(0x200>>-exp)); + m = exp2(m, n); + int s = 0; + if(esign) + m = divide64(0x80000000, m, s); + return fixed2half(m, exp+14, sign, s); + } + + /// Postprocessing for binary logarithm. + /// \tparam R rounding mode to use + /// \tparam L logarithm for base transformation as Q1.31 + /// \param m fractional part of logarithm as Q0.31 + /// \param ilog signed integer part of logarithm + /// \param exp biased exponent of result + /// \param sign sign bit of result + /// \return value base-transformed and converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) + { + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog)<<27)+(m>>4))^msign) - msign; + if(!m) + return 0; + for(; m<0x80000000; m<<=1,--exp) ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if(exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); + } + + /// Hypotenuse square root and postprocessing. + /// \tparam R rounding mode to use + /// \param r mantissa as Q2.30 + /// \param exp biased exponent + /// \return square root converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int hypot_post(uint32 r, int exp) + { + int i = r >> 31; + if((exp+=i) > 46) + return overflow(); + if(exp < -34) + return underflow(); + r = (r>>i) | (r&i); + uint32 m = sqrt<30>(r, exp+=15); + return fixed2half(m, exp-1, 0, r!=0); + } + + /// Division and postprocessing for tangents. + /// \tparam R rounding mode to use + /// \param my dividend as Q1.31 + /// \param mx divisor as Q1.31 + /// \param exp biased exponent of result + /// \param sign sign bit of result + /// \return quotient converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) + { + int i = my >= mx, s; + exp += i; + if(exp > 29) + return overflow(sign); + if(exp < -11) + return underflow(sign); + uint32 m = divide64(my>>(i+1), mx, s); + return fixed2half(m, exp, sign, s); + } + + /// Area function and postprocessing. + /// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = log(x+sqrt(x^2+|-1))`. + /// \tparam R rounding mode to use + /// \tparam S `true` for asinh, `false` for acosh + /// \param arg half-precision argument + /// \return asinh|acosh(\a arg) converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int area(unsigned int arg) + { + int abs = arg & 0x7FFF, expx = (abs>>10) + (abs<=0x3FF) - 15, expy = -15, ilog, i; + uint32 mx = static_cast((abs&0x3FF)|((abs>0x3FF)<<10)) << 20, my, r; + for(; abs<0x400; abs<<=1,--expy) ; + expy += abs >> 10; + r = ((abs&0x3FF)|0x400) << 5; + r *= r; + i = r >> 31; + expy = 2*expy + i; + r >>= i; + if(S) + { + if(expy < 0) + { + r = 0x40000000 + ((expy>-30) ? ((r>>-expy)|((r&((static_cast(1)<<-expy)-1))!=0)) : 1); + expy = 0; + } + else + { + r += 0x40000000 >> expy; + i = r >> 31; + r = (r>>i) | (r&i); + expy += i; + } + } + else + { + r -= 0x40000000 >> expy; + for(; r<0x40000000; r<<=1,--expy) ; + } + my = sqrt<30>(r, expy); + my = (my<<15) + (r<<14)/my; + if(S) + { + mx >>= expy - expx; + ilog = expy; + } + else + { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = my >> 31; + static const int G = S && (R==std::round_to_nearest); + return log2_post(log2(my>>i, 26+S+G)+(G<<3), ilog+i, 17, arg&(static_cast(S)<<15)); + } + + /// Class for 1.31 unsigned floating-point computation + struct f31 + { + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) + { + for(; abs<0x400; abs<<=1,--exp) ; + m = static_cast((abs&0x3FF)|0x400) << 21; + exp += (abs>>10); + } + + /// Addition operator. + /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) + { + if(b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d<32) ? (b.m>>d) : 0); + int i = (m&0xFFFFFFFF) < a.m; + return f31(((m+i)>>i)|0x80000000, a.exp+i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) + { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d<32) ? (b.m>>d) : 0); + if(!m) + return f31(0, -32); + for(; m<0x80000000; m<<=1,--exp) ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) + { + uint32 m = multiply64(a.m, b.m); + int i = m >> 31; + return f31(m<<(1-i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) + { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m+i)>>i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. + }; + + /// Error function and postprocessing. + /// This computes the value directly in Q1.31 using the approximations given + /// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). + /// \tparam R rounding mode to use + /// \tparam C `true` for comlementary error function, `false` else + /// \param arg half-precision function argument + /// \return approximated value of error function in half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int erf(unsigned int arg) + { + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), t = f31(0x80000000, 0) / (f31(0x80000000, 0)+f31(0xA7BA054A, -2)*x), t2 = t * t; + f31 e = ((f31(0x87DC2213, 0)*t2+f31(0xB5F0E2AE, 0))*t2+f31(0x82790637, -2)-(f31(0xBA00E2B8, 0)*t2+f31(0x91A98E62, -2))*t) * t / + ((x2.exp<0) ? f31(exp2((x2.exp>-32) ? (x2.m>>-x2.exp) : 0, 30), 0) : f31(exp2((x2.m<>(31-x2.exp))); + return (!C || sign) ? fixed2half(0x80000000-(e.m>>(C-e.exp)), 14+C, sign&(C-1U)) : + (e.exp<-25) ? underflow() : fixed2half(e.m>>1, e.exp+14, 0, e.m&1); + } + + /// Gamma function and postprocessing. + /// This approximates the value of either the gamma function or its logarithm directly in Q1.31. + /// \tparam R rounding mode to use + /// \tparam L `true` for lograithm of gamma function, `false` for gamma function + /// \param arg half-precision floating-point value + /// \return lgamma/tgamma(\a arg) in half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if \a arg is not a positive integer + template unsigned int gamma(unsigned int arg) + { +/* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, 0.0114684895434781459556 }; + double t = arg + 4.65, s = p[0]; + for(unsigned int i=0; i<5; ++i) + s += p[i+1] / (arg+i); + return std::log(s) + (arg-0.5)*std::log(t) - t; +*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), x = sign ? (z+f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), s = + f31(0xA06C9901, 1) + f31(0xBBE654E2, -7)/(x+f31(0x80000000, 2)) + f31(0xA1CE6098, 6)/(x+f31(0x80000000, 1)) + + f31(0xE1868CB7, 7)/x - f31(0x8625E279, 8)/(x+f31(0x80000000, 0)) - f31(0xA03E158F, 2)/(x+f31(0xC0000000, 1)); + int i = (s.exp>=2) + (s.exp>=4) + (s.exp>=8) + (s.exp>=16); + s = f31((static_cast(s.exp)<<(31-i))+(log2(s.m>>1, 28)>>i), i) / lbe; + if(x.exp != -1 || x.m != 0x80000000) + { + i = (t.exp>=2) + (t.exp>=4) + (t.exp>=8); + f31 l = f31((static_cast(t.exp)<<(31-i))+(log2(t.m>>1, 30)>>i), i) / lbe; + s = (x.exp<-1) ? (s-(f31(0x80000000, -1)-x)*l) : (s+(x-f31(0x80000000, -1))*l); + } + s = x.exp ? (s-t) : (t-s); + if(bsign) + { + if(z.exp >= 0) + { + sign &= (L|((z.m>>(31-z.exp))&1)) - 1; + for(z=f31((z.m<<(1+z.exp))&0xFFFFFFFF, -1); z.m<0x80000000; z.m<<=1,--z.exp) ; + } + if(z.exp == -1) + z = f31(0x80000000, 0) - z; + if(z.exp < -1) + { + z = z * pi; + z.m = sincos(z.m>>(1-z.exp), 30).first; + for(z.exp=1; z.m<0x80000000; z.m<<=1,--z.exp) ; + } + else + z = f31(0x80000000, 0); + } + if(L) + { + if(bsign) + { + f31 l(0x92868247, 0); + if(z.exp < 0) + { + uint32 m = log2((z.m+1)>>1, 27); + z = f31(-((static_cast(z.exp)<<26)+(m>>5)), 5); + for(; z.m<0x80000000; z.m<<=1,--z.exp) ; + l = l + z / lbe; + } + sign = static_cast(x.exp&&(l.exp(x.exp==0) << 15; + if(s.exp < -24) + return underflow(sign); + if(s.exp > 15) + return overflow(sign); + } + } + else + { + s = s * lbe; + uint32 m; + if(s.exp < 0) + { + m = s.m >> -s.exp; + s.exp = 0; + } + else + { + m = (s.m<>(31-s.exp)); + } + s.m = exp2(m, 27); + if(!x.exp) + s = f31(0x80000000, 0) / s; + if(bsign) + { + if(z.exp < 0) + s = s * z; + s = pi / s; + if(s.exp < -24) + return underflow(sign); + } + else if(z.exp > 0 && !(z.m&((1<<(31-z.exp))-1))) + return ((s.exp+14)<<10) + (s.m>>21); + if(s.exp > 15) + return overflow(sign); + } + return fixed2half(s.m, s.exp+14, sign); + } + /// \} + + template struct half_caster; + } + + /// Half-precision floating-point type. + /// This class implements an IEEE-conformant half-precision floating-point type with the usual arithmetic + /// operators and conversions. It is implicitly convertible to single-precision floating-point, which makes artihmetic + /// expressions and functions with mixed-type operands to be of the most precise operand type. + /// + /// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and + /// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which + /// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the + /// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of + /// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most + /// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit + /// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if + /// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this should be the case on + /// nearly any reasonable platform. + /// + /// So if your C++ implementation is not totally exotic or imposes special alignment requirements, it is a reasonable + /// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE representation. + class half + { + public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + explicit half(float rhs) : data_(static_cast(detail::float2half(rhs))) {} + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half& operator=(float rhs) { data_ = static_cast(detail::float2half(rhs)); return *this; } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half& operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half& operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half& operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half& operator/=(half rhs) { return *this = *this / rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator/=(float rhs) { return *this = *this / rhs; } + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) { half out(*this); ++*this; return out; } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) { half out(*this); --*this; return out; } + /// \} + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT : data_(static_cast(bits)) {} + + /// Internal binary representation + detail::uint16 data_; + + #ifndef HALF_DOXYGEN_ONLY + friend HALF_CONSTEXPR_NOERR bool operator==(half, half); + friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); + friend HALF_CONSTEXPR half operator-(half); + friend half operator+(half, half); + friend half operator-(half, half); + friend half operator*(half, half); + friend half operator/(half, half); + template friend std::basic_ostream& operator<<(std::basic_ostream&, half); + template friend std::basic_istream& operator>>(std::basic_istream&, half&); + friend HALF_CONSTEXPR half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int*); + friend half fma(half, half, half); + friend HALF_CONSTEXPR_NOERR half fmax(half, half); + friend HALF_CONSTEXPR_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char*); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half rsqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half*, half*); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long lrint(half); + friend half nearbyint(half); + #ifdef HALF_ENABLE_CPP11_LONG_LONG + friend long long llround(half); + friend long long llrint(half); + #endif + friend half frexp(half, int*); + friend half scalbln(half, long); + friend half modf(half, half*); + friend int ilogb(half); + friend half logb(half); + friend half nextafter(half, half); + friend half nexttoward(half, long double); + friend HALF_CONSTEXPR half copysign(half, half); + friend HALF_CONSTEXPR int fpclassify(half); + friend HALF_CONSTEXPR bool isfinite(half); + friend HALF_CONSTEXPR bool isinf(half); + friend HALF_CONSTEXPR bool isnan(half); + friend HALF_CONSTEXPR bool isnormal(half); + friend HALF_CONSTEXPR bool signbit(half); + friend HALF_CONSTEXPR bool isgreater(half, half); + friend HALF_CONSTEXPR bool isgreaterequal(half, half); + friend HALF_CONSTEXPR bool isless(half, half); + friend HALF_CONSTEXPR bool islessequal(half, half); + friend HALF_CONSTEXPR bool islessgreater(half, half); + template friend struct detail::half_caster; + friend class std::numeric_limits; + #if HALF_ENABLE_CPP11_HASH + friend struct std::hash; + #endif + #if HALF_ENABLE_CPP11_USER_LITERALS + friend half literal::operator "" _h(long double); + #endif + #endif + }; + +#if HALF_ENABLE_CPP11_USER_LITERALS + namespace literal + { + /// Half literal. + /// While this returns a properly rounded half-precision value, half literals can unfortunately not be constant + /// expressions due to rather involved conversions. So don't expect this to be a literal literal without involving + /// conversion operations at runtime. It is a convenience feature, not a performance optimization. + /// \param value literal value + /// \return half with of given value (possibly rounded) + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator "" _h(long double value) { return half(detail::binary, detail::float2half(value)); } + } +#endif + + namespace detail + { + /// Helper class for half casts. + /// This class template has to be specialized for all valid cast arguments to define an appropriate static + /// `cast` member function and a corresponding `type` member denoting its return type. + /// \tparam T destination type + /// \tparam U source type + /// \tparam R rounding mode to use + template struct half_caster {}; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); + #endif + + static half cast(U arg) { return cast_impl(arg, is_float()); }; + + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } + }; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); + #endif + + static T cast(half arg) { return cast_impl(arg, is_float()); } + + private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } + }; + template struct half_caster + { + static half cast(half arg) { return arg; } + }; + } +} + +/// Extensions to the C++ standard library. +namespace std +{ + /// Numeric limits for half-precision floats. + /// **See also:** Documentation for [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) + template<> class numeric_limits + { + public: + /// Is template specialization. + static HALF_CONSTEXPR_CONST bool is_specialized = true; + + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not an integer type. + static HALF_CONSTEXPR_CONST bool is_integer = false; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// Has a finite set of values. + static HALF_CONSTEXPR_CONST bool is_bounded = true; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports signaling NaNs. + static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Supports no denormalization detection. + static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; + + #if HALF_ERRHANDLING_THROWS + static HALF_CONSTEXPR_CONST bool traps = true; + #else + /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is acitvated. + static HALF_CONSTEXPR_CONST bool traps = false; + #endif + + /// Does not support no pre-rounding underflow detection. + static HALF_CONSTEXPR_CONST bool tinyness_before = false; + + /// Rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. + static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0400); } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0xFBFF); } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7BFF); } + + /// Difference between 1 and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x1400); } + + /// Maximum rounding error in ULP (units in the last place). + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW + { return half_float::half(half_float::detail::binary, (round_style==std::round_to_nearest) ? 0x3800 : 0x3C00); } + + /// Positive infinity. + static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7C00); } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7FFF); } + + /// Signaling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7DFF); } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0001); } + }; + +#if HALF_ENABLE_CPP11_HASH + /// Hash function for half-precision floats. + /// This is only defined if C++11 `std::hash` is supported and enabled. + /// + /// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) + template<> struct hash + { + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const { return hash()(arg.data_&-static_cast(arg.data_!=0x8000)); } + }; +#endif +} + +namespace half_float +{ + /// \anchor compop + /// \name Comparison operators + /// \{ + + /// Comparison for equality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && (x.data_==y.data_ || !((x.data_|y.data_)&0x7FFF)); + } + + /// Comparison for inequality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) + { + return detail::compsignal(x.data_, y.data_) || (x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF)); + } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) < ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// Comparison for greater than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) > ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) <= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// Comparison for greater equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) >= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// \} + /// \anchor arithmetics + /// \name Arithmetic operators + /// \{ + + /// Identity. + /// \param arg operand + /// \return unchanged operand + inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + + /// Negation. + /// \param arg operand + /// \return negated operand + inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_^0x8000); } + + /// Addition. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return sum of half expressions + /// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator+(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)+detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_^y.data_)&0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : (absy!=0x7C00) ? x.data_ : + (sub && absx==0x7C00) ? detail::invalid() : y.data_); + if(!absx) + return absy ? y : half(detail::binary, (half::round_style==std::round_toward_neg_infinity) ? (x.data_|y.data_) : (x.data_&y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy>absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx>>10) + (absx<=0x3FF), d = exp - (absy>>10) - (absy<=0x3FF), mx = ((absx&0x3FF)|((absx>0x3FF)<<10)) << 3, my; + if(d < 13) + { + my = ((absy&0x3FF)|((absy>0x3FF)<<10)) << 3; + my = (my>>d) | ((my&((1<(half::round_style==std::round_toward_neg_infinity)<<15); + for(; mx<0x2000 && exp>1; mx<<=1,--exp) ; + } + else + { + mx += my; + int i = mx >> 14; + if((exp+=i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx>>i) | (mx&i); + } + return half(detail::binary, detail::rounded(sign+((exp-1)<<10)+(mx>>3), (mx>>2)&1, (mx&0x3)!=0)); + #endif + } + + /// Subtraction. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return difference of half expressions + /// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator-(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)-detail::half2float(y.data_))); + #else + return x + -y; + #endif + } + + /// Multiplication. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return product of half expressions + /// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator*(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)*detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_^y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + ((absx==0x7C00 && !absy)||(absy==0x7C00 && !absx)) ? detail::invalid() : (sign|0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + detail::uint32 m = static_cast((absx&0x3FF)|0x400) * static_cast((absy&0x3FF)|0x400); + int i = m >> 21, s = m & i; + exp += (absx>>10) + (absy>>10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, detail::fixed2half(m>>i, exp, sign, s)); + #endif + } + + /// Division. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return quotient of half expressions + /// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is signaling NaN + /// \exception FE_DIVBYZERO if dividing finite value by 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator/(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)/detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_^y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==absy) ? detail::invalid() : (sign|((absx==0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? sign : detail::invalid()); + if(!absy) + return half(detail::binary, detail::pole(sign)); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,++exp) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + int i = mx < my; + exp += (absx>>10) - (absy>>10) - i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, detail::fixed2half(mx/my, exp, sign, mx%my!=0)); + #endif + } + + /// \} + /// \anchor streaming + /// \name Input and output + /// \{ + + /// Output operator. + /// This uses the built-in functionality for streaming out floating-point numbers. + /// \param out output stream to write into + /// \param arg half expression to write + /// \return reference to output stream + template std::basic_ostream& operator<<(std::basic_ostream &out, half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); + #else + return out << detail::half2float(arg.data_); + #endif + } + + /// Input operator. + /// This uses the built-in functionality for streaming in floating-point numbers, specifically double precision floating + /// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the input string is first + /// rounded to double precision using the underlying platform's current floating-point rounding mode before being rounded + /// to half-precision using the library's half-precision rounding mode. + /// \param in input stream to read from + /// \param arg half to read into + /// \return reference to input stream + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template std::basic_istream& operator>>(std::basic_istream &in, half &arg) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; + #else + double f; + #endif + if(in >> f) + arg.data_ = detail::float2half(f); + return in; + } + + /// \} + /// \anchor basic + /// \name Basic mathematical operations + /// \{ + + /// Absolute value. + /// **See also:** Documentation for [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). + /// \param arg operand + /// \return absolute value of \a arg + inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_&0x7FFF); } + + /// Absolute value. + /// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). + /// \param arg operand + /// \return absolute value of \a arg + inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } + + /// Remainder of division. + /// **See also:** Documentation for [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). + /// \param x first operand + /// \param y second operand + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half fmod(half x, half y) + { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(!absx) + return x; + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign|detail::mod(absx, absy)); + } + + /// Remainder of division. + /// **See also:** Documentation for [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). + /// \param x first operand + /// \param y second operand + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half remainder(half x, half y) + { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign^detail::mod(absx, absy)); + } + + /// Remainder of division. + /// **See also:** Documentation for [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). + /// \param x first operand + /// \param y second operand + /// \param quo address to store some bits of quotient at + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half remquo(half x, half y, int *quo) + { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value^y.data_)&0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); + } + + /// Fused multiply add. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return ( \a x * \a y ) + \a z rounded as one operation. + /// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet NaN and no argument is a signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition + inline half fma(half x, half y, half z) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_), fz = detail::half2float(z.data_); + #if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); + #else + return half(detail::binary, detail::float2half(fx*fy+fz)); + #endif + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_^y.data_) & 0x8000; + bool sub = ((sign^z.data_)&0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx>0x7C00 || absy>0x7C00 || absz>0x7C00) ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) : + (absx==0x7C00) ? half(detail::binary, (!absy || (sub && absz==0x7C00)) ? detail::invalid() : (sign|0x7C00)) : + (absy==0x7C00) ? half(detail::binary, (!absx || (sub && absz==0x7C00)) ? detail::invalid() : (sign|0x7C00)) : z; + if(!absx || !absy) + return absz ? z : half(detail::binary, (half::round_style==std::round_toward_neg_infinity) ? (z.data_|sign) : (z.data_&sign)); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + detail::uint32 m = static_cast((absx&0x3FF)|0x400) * static_cast((absy&0x3FF)|0x400); + int i = m >> 21; + exp += (absx>>10) + (absy>>10) + i; + m <<= 3 - i; + if(absz) + { + int expz = 0; + for(; absz<0x400; absz<<=1,--expz) ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz&0x3FF)|0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) + { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d<23) ? ((mz>>d)|((mz&((static_cast(1)<(half::round_style==std::round_toward_neg_infinity)<<15); + for(; m<0x800000; m<<=1,--exp) ; + } + else + { + m += mz; + i = m >> 24; + m = (m>>i) | (m&i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, detail::fixed2half(m, exp-1, sign)); + #endif + } + + /// Maximum of half expressions. + /// **See also:** Documentation for [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). + /// \param x first operand + /// \param y second operand + /// \return maximum of operands, ignoring quiet NaNs + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) + { + return half(detail::binary, (!isnan(y) && (isnan(x) || (x.data_^(0x8000|(0x8000-(x.data_>>15)))) < + (y.data_^(0x8000|(0x8000-(y.data_>>15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); + } + + /// Minimum of half expressions. + /// **See also:** Documentation for [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). + /// \param x first operand + /// \param y second operand + /// \return minimum of operands, ignoring quiet NaNs + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) + { + return half(detail::binary, (!isnan(y) && (isnan(x) || (x.data_^(0x8000|(0x8000-(x.data_>>15)))) > + (y.data_^(0x8000|(0x8000-(y.data_>>15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); + } + + /// Positive difference. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). + /// \param x first operand + /// \param y second operand + /// \return \a x - \a y or 0 if difference negative + /// \exception FE_... according to operator-(half,half) + inline half fdim(half x, half y) + { + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_^(0x8000|(0x8000-(x.data_>>15)))) <= (y.data_^(0x8000|(0x8000-(y.data_>>15)))) ? half(detail::binary, 0) : (x-y); + } + + /// Get NaN value. + /// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). + /// \param arg string code + /// \return quiet NaN + inline half nanh(const char *arg) + { + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); + } + + /// \} + /// \anchor exponential + /// \name Exponential functions + /// \{ + + /// Exponential function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). + /// \param arg function argument + /// \return e raised to \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half exp(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::exp(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, e = (abs>>10) + (abs<=0x3FF), exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00&((arg.data_>>15)-1U)) : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, (arg.data_&0x8000) ? detail::underflow() : detail::overflow()); + detail::uint32 m = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29); + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45-e); + m = (m<<(e-14)) & 0x7FFFFFFF; + } + return half(detail::binary, detail::exp2_post(m, exp, (arg.data_&0x8000)!=0, 0, 26)); + #endif + } + + /// Binary exponential. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). + /// \param arg function argument + /// \return 2 raised to \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half exp2(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::exp2(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, e = (abs>>10) + (abs<=0x3FF), exp = (abs&0x3FF) + ((abs>0x3FF)<<10); + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00&((arg.data_>>15)-1U)) : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, (arg.data_&0x8000) ? detail::underflow() : detail::overflow()); + return half(detail::binary, detail::exp2_post( + (static_cast(exp)<<(6+e))&0x7FFFFFFF, exp>>(25-e), (arg.data_&0x8000)!=0, 0, 28)); + #endif + } + + /// Exponential minus one. + /// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for `std::round_to_nearest` + /// and in <1% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). + /// \param arg function argument + /// \return e raised to \a arg and subtracted by 1 + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half expm1(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::expm1(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000, e = (abs>>10) + (abs<=0x3FF), exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00+(sign>>1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, (arg.data_&0x8000) ? detail::rounded(0xBBFF, 1, 1) : detail::overflow()); + detail::uint32 m = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29); + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45-e); + m = (m<<(e-14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) + { + int s = 0; + if(m > 0x80000000) + { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - ((m>>exp)|((m&((static_cast(1)<>exp) : 1; + for(exp+=14; m<0x80000000 && exp; m<<=1,--exp) ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::rounded(sign+(exp<<10)+(m>>21), (m>>20)&1, (m&0xFFFFF)!=0)); + #endif + } + + /// Natural logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). + /// \param arg function argument + /// \return logarithm of \a arg to base e + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::log(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + return half(detail::binary, detail::log2_post( + detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 27)+8, exp, 17)); + #endif + } + + /// Common logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). + /// \param arg function argument + /// \return logarithm of \a arg to base 10 + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log10(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::log10(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) + { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + return half(detail::binary, detail::log2_post( + detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 27)+8, exp, 16)); + #endif + } + + /// Binary logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). + /// \param arg function argument + /// \return logarithm of \a arg to base 2 + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log2(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::log2(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs<0x400; abs<<=1,--exp) ; + exp += (abs>>10); + if(!(abs&0x3FF)) + { + unsigned int value = static_cast(exp<0) << 15, m = std::abs(exp) << 6; + for(exp=18; m<0x400; m<<=1,--exp) ; + return half(detail::binary, value+(exp<<10)+m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), m = + (((ilog<<27)+(detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 28)>>4))^sign) - sign; + if(!m) + return half(detail::binary, 0); + for(exp=14; m<0x8000000 && exp; m<<=1,--exp) ; + for(; m>0xFFFFFFF; m>>=1,++exp) + s |= m & 1; + return half(detail::binary, detail::fixed2half(m, exp, sign&0x8000, s)); + #endif + } + + /// Natural logarithm plus one. + /// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for `std::round_to_nearest` + /// and in ~1% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). + /// \param arg function argument + /// \return logarithm of \a arg plus 1 to base e + /// \exception FE_INVALID for signaling NaN or argument <-1 + /// \exception FE_DIVBYZERO for -1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log1p(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::log1p(detail::half2float(arg.data_)))); + #else + if(arg.data_ >= 0xBC00) + return half(detail::binary, (arg.data_==0xBC00) ? detail::pole(0x8000) : (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs&0x3FF)|0x400) << 20; + if(arg.data_ & 0x8000) + { + m = 0x40000000 - (m>>-exp); + for(exp=0; m<0x40000000; m<<=1,--exp) ; + } + else + { + if(exp < 0) + { + m = 0x40000000 + (m>>-exp); + exp = 0; + } + else + { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, detail::log2_post(detail::log2(m), exp, 17)); + #endif + } + + /// \} + /// \anchor power + /// \name Power functions + /// \{ + + /// Square root. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). + /// \param arg function argument + /// \return square root of \a arg + /// \exception FE_INVALID for signaling NaN and negative arguments + /// \exception FE_INEXACT according to rounding + inline half sqrt(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sqrt(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_>0x8000) ? detail::invalid() : arg.data_); + for(; abs<0x400; abs<<=1,--exp) ; + detail::uint32 r = static_cast((abs&0x3FF)|0x400) << 10, m = detail::sqrt<20>(r, exp+=abs>>10); + return half(detail::binary, detail::rounded((exp<<10)+(m&0x3FF), r>m, r!=0)); + #endif + } + + /// Inverse square root. + /// This function is exact to rounding for all rounding modes and thus generally more accurate than directly computing + /// 1 / sqrt(\a arg) in half-precision, in addition to also being faster. + /// \param arg function argument + /// \return reciprocal of square root of \a arg + /// \exception FE_INVALID for signaling NaN and negative arguments + /// \exception FE_INEXACT according to rounding + inline half rsqrt(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::internal_t(1)/std::sqrt(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, bias = 0x4000; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_>0x8000) ? + detail::invalid() : !abs ? detail::pole(arg.data_&0x8000) : 0); + for(; abs<0x400; abs<<=1,bias-=0x400) ; + unsigned int frac = (abs+=bias) & 0x7FF; + if(frac == 0x400) + return half(detail::binary, 0x7A00-(abs>>1)); + if((half::round_style == std::round_to_nearest && (frac == 0x3FE || frac == 0x76C)) || + (half::round_style != std::round_to_nearest && (frac == 0x15A || frac == 0x3FC || frac == 0x401 || frac == 0x402 || frac == 0x67B))) + return pow(arg, half(detail::binary, 0xB800)); + detail::uint32 f = 0x17376 - abs, mx = (abs&0x3FF) | 0x400, my = ((f>>1)&0x3FF) | 0x400, mz = my * my; + int expy = (f>>11) - 31, expx = 32 - (abs>>10), i = mz >> 21; + for(mz=0x60000000-(((mz>>i)*mx)>>(expx-2*expy-i)); mz<0x40000000; mz<<=1,--expy) ; + i = (my*=mz>>10) >> 31; + expy += i; + my = (my>>(20+i)) + 1; + i = (mz=my*my) >> 21; + for(mz=0x60000000-(((mz>>i)*mx)>>(expx-2*expy-i)); mz<0x40000000; mz<<=1,--expy) ; + i = (my*=(mz>>10)+1) >> 31; + return half(detail::binary, detail::fixed2half(my>>i, expy+i+14)); + #endif + } + + /// Cubic root. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). + /// \param arg function argument + /// \return cubic root of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT according to rounding + inline half cbrt(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::cbrt(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1, --exp); + detail::uint32 ilog = exp + (abs>>10), sign = detail::sign_mask(ilog), f, m = + (((ilog<<27)+(detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 24)>>4))^sign) - sign; + for(exp=2; m<0x80000000; m<<=1,--exp) ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m<> (31-exp); + } + m = detail::exp2(f, (half::round_style==std::round_to_nearest) ? 29 : 26); + if(sign) + { + if(m > 0x80000000) + { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, (half::round_style==std::round_to_nearest) ? + detail::fixed2half(m, exp+14, arg.data_&0x8000) : + detail::fixed2half((m+0x80)>>8, exp+14, arg.data_&0x8000)); + #endif + } + + /// Hypotenuse function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). + /// \param x first argument + /// \param y second argument + /// \return square root of sum of squares without internal over- or underflows + /// \exception FE_INVALID if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root + inline half hypot(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_); + #if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); + #else + return half(detail::binary, detail::float2half(std::sqrt(fx*fx+fy*fy))); + #endif + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx==0x7C00) ? detail::select(0x7C00, y.data_) : + (absy==0x7C00) ? detail::select(0x7C00, x.data_) : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2*(expx+(absx>>10)) - 15 + ix; + expy = 2*(expy+(absy>>10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d<30) ? ((my>>d)|((my&((static_cast(1)<(mx+my, expx)); + #endif + } + + /// Hypotenuse function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). + /// \param x first argument + /// \param y second argument + /// \param z third argument + /// \return square root of sum of squares without internal over- or underflows + /// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root + inline half hypot(half x, half y, half z) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_), fz = detail::half2float(z.data_); + return half(detail::binary, detail::float2half(std::sqrt(fx*fx+fy*fy+fz*fz))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, (absx==0x7C00) ? detail::select(0x7C00, detail::select(y.data_, z.data_)) : + (absy==0x7C00) ? detail::select(0x7C00, detail::select(x.data_, z.data_)) : + (absz==0x7C00) ? detail::select(0x7C00, detail::select(x.data_, y.data_)) : + detail::signal(x.data_, y.data_, z.data_)); + if(absz > absy) + std::swap(absy, absz); + if(absy > absx) + std::swap(absx, absy); + if(absz > absy) + std::swap(absy, absz); + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + for(; absz<0x400; absz<<=1,--expz) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400, mz = (absz&0x3FF) | 0x400; + mx *= mx; + my *= my; + mz *= mz; + int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; + expx = 2*(expx+(absx>>10)) - 15 + ix; + expy = 2*(expy+(absy>>10)) - 15 + iy; + expz = 2*(expz+(absz>>10)) - 15 + iz; + mx <<= 10 - ix; + my <<= 10 - iy; + mz <<= 10 - iz; + int d = expy - expz; + mz = (d<30) ? ((mz>>d)|((mz&((static_cast(1)<>1) | (my&1); + if(++expy > expx) + { + std::swap(mx, my); + std::swap(expx, expy); + } + } + d = expx - expy; + my = (d<30) ? ((my>>d)|((my&((static_cast(1)<(mx+my, expx)); + #endif + } + + /// Power function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in ~0.00025% of inputs. + /// + /// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). + /// \param x base + /// \param y exponent + /// \return \a x raised to \a y + /// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y is finite and not integral + /// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half pow(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::pow(detail::half2float(x.data_), detail::half2float(y.data_)))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; + if(!absy || x.data_ == 0x3C00) + return half(detail::binary, detail::select(0x3C00, (x.data_==0x3C00) ? y.data_ : x.data_)); + bool is_int = absy >= 0x6400 || (absy>=0x3C00 && !(absy&((1<<(25-(absy>>10)))-1))); + unsigned int sign = x.data_ & (static_cast((absy<0x6800)&&is_int&&((absy>>(25-(absy>>10)))&1))<<15); + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absy==0x7C00) ? ((absx==0x3C00) ? 0x3C00 : (!absx && y.data_==0xFC00) ? detail::pole() : + (0x7C00&-((y.data_>>15)^(absx>0x3C00)))) : (sign|(0x7C00&((y.data_>>15)-1U)))); + if(!absx) + return half(detail::binary, (y.data_&0x8000) ? detail::pole(sign) : sign); + if((x.data_&0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign|0x3C00); + switch(y.data_) + { + case 0x3800: return sqrt(x); + case 0x3C00: return half(detail::binary, detail::check_underflow(x.data_)); + case 0x4000: return x * x; + case 0xBC00: return half(detail::binary, 0x3C00) / x; + } + for(; absx<0x400; absx<<=1,--exp) ; + detail::uint32 ilog = exp + (absx>>10), msign = detail::sign_mask(ilog), f, m = + (((ilog<<27)+((detail::log2(static_cast((absx&0x3FF)|0x400)<<20)+8)>>4))^msign) - msign; + for(exp=-11; m<0x80000000; m<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + m = detail::multiply64(m, static_cast((absy&0x3FF)|0x400)<<21); + int i = m >> 31; + exp += (absy>>10) + i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m<> (31-exp); + } + return half(detail::binary, detail::exp2_post(f, exp, ((msign&1)^(y.data_>>15))!=0, sign)); + #endif + } + + /// \} + /// \anchor trigonometric + /// \name Trigonometric functions + /// \{ + + /// Compute sine and cosine simultaneously. + /// This returns the same results as sin() and cos() but is faster than calling each function individually. + /// + /// This function is exact to rounding for all rounding modes. + /// \param arg function argument + /// \param sin variable to take sine of \a arg + /// \param cos variable to take cosine of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline void sincos(half arg, half *sin, half *cos) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); + #else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + else if(!abs) + { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } + else if(abs < 0x2500) + { + *sin = half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } + else + { + if(half::round_style != std::round_to_nearest) + { + switch(abs) + { + case 0x48B7: + *sin = half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half(detail::binary, detail::rounded((arg.data_&0x8000)|0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) + { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, detail::fixed2half((sc.first^-static_cast(sign))+sign)); + *cos = half(detail::binary, detail::fixed2half(sc.second)); + } + #endif + } + + /// Sine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). + /// \param arg function argument + /// \return sine value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sin(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sin(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x48B7: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x1D07, 1, 1)); + case 0x6A64: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x3BFE, 1, 1)); + case 0x6D8C: return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k>>1)&1)^(arg.data_>>15)); + return half(detail::binary, detail::fixed2half((((k&1) ? sc.second : sc.first)^sign) - sign)); + #endif + } + + /// Cosine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). + /// \param arg function argument + /// \return cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cos(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::cos(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k>>1)^k)&1); + return half(detail::binary, detail::fixed2half((((k&1) ? sc.first : sc.second)^sign) - sign)); + #endif + } + + /// Tangent function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). + /// \param arg function argument + /// \return tangent value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tan(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::tan(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x658C: return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x07E6, 1, 1)); + case 0x7330: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first^signy) - signy, mx = (sc.second^signx) - signx; + for(; my<0x80000000; my<<=1,--exp) ; + for(; mx<0x80000000; mx<<=1,++exp) ; + return half(detail::binary, detail::tangent_post(my, mx, exp, (signy^signx^arg.data_)&0x8000)); + #endif + } + + /// Arc sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). + /// \param arg function argument + /// \return arc sine value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half asin(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::asin(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (abs>0x3C00) ? detail::invalid() : + detail::rounded(sign|0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_+1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(sc.first, sc.second, (half::round_style==std::round_to_nearest) ? 27 : 26); + return half(detail::binary, detail::fixed2half(m, 14, sign)); + #endif + } + + /// Arc cosine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). + /// \param arg function argument + /// \return arc cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half acos(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::acos(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (abs>0x3C00) ? detail::invalid() : + sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, detail::fixed2half(sign ? (0xC90FDAA2-m) : m, 15, 0, sign)); + #endif + } + + /// Arc tangent function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). + /// \param arg function argument + /// \return arc tangent value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atan(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::atan(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::rounded(sign|0x3E48, 0, 1) : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + int exp = (abs>>10) + (abs<=0x3FF); + detail::uint32 my = (abs&0x3FF) | ((abs>0x3FF)<<10); + detail::uint32 m = (exp>15) ? detail::atan2(my<<19, 0x20000000>>(exp-15), (half::round_style==std::round_to_nearest) ? 26 : 24) : + detail::atan2(my<<(exp+4), 0x20000000, (half::round_style==std::round_to_nearest) ? 30 : 28); + return half(detail::binary, detail::fixed2half(m, 14, sign)); + #endif + } + + /// Arc tangent function. + /// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for `std::round_to_nearest`, + /// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). + /// \param y numerator + /// \param x denominator + /// \return arc tangent value + /// \exception FE_INVALID if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atan2(half y, half x) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::atan2(detail::half2float(y.data_), detail::half2float(x.data_)))); + #else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, (absx<0x7C00) ? detail::rounded(signy|0x3E48, 0, 1) : + signx ? detail::rounded(signy|0x40B6, 0, 1) : + detail::rounded(signy|0x3A48, 0, 1)); + return (x.data_==0x7C00) ? half(detail::binary, signy) : half(detail::binary, detail::rounded(signy|0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, detail::rounded(signy|0x4248, 0, 1)) : y; + if(!absx) + return half(detail::binary, detail::rounded(signy|0x3E48, 0, 1)); + int d = (absy>>10) + (absy<=0x3FF) - (absx>>10) - (absx<=0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy|0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy|0x4248, 0, 1)); + if(!signx && d < ((half::round_style==std::round_toward_zero) ? -15 : -9)) + { + for(; absy<0x400; absy<<=1,--d) ; + detail::uint32 mx = ((absx<<1)&0x7FF) | 0x800, my = ((absy<<1)&0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, detail::fixed2half(my/mx, d+14, signy, my%mx!=0)); + } + detail::uint32 m = detail::atan2( ((absy&0x3FF)|((absy>0x3FF)<<10))<<(19+((d<0) ? d : (d>0) ? 0 : -1)), + ((absx&0x3FF)|((absx>0x3FF)<<10))<<(19-((d>0) ? d : (d<0) ? 0 : 1))); + return half(detail::binary, detail::fixed2half(signx ? (0xC90FDAA2-m) : m, 15, signy, signx)); + #endif + } + + /// \} + /// \anchor hyperbolic + /// \name Hyperbolic functions + /// \{ + + /// Hyperbolic sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). + /// \param arg function argument + /// \return hyperbolic sine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sinh(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sinh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, (half::round_style==std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp+=13; m<0x80000000 && exp; m<<=1,--exp) ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, detail::fixed2half(m, exp, sign)); + #endif + } + + /// Hyperbolic cosine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). + /// \param arg function argument + /// \return hyperbolic cosine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cosh(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::cosh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = detail::hyperbolic_args(abs, exp, (half::round_style==std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m&0xFFFFFFFF) >> 31; + m = (m>>i) | (m&i) | 0x80000000; + if((exp+=13+i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::fixed2half(m, exp)); + #endif + } + + /// Hyperbolic tangent. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). + /// \param arg function argument + /// \return hyperbolic tangent value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tanh(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::tanh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_-0x4000)); + if(abs >= 0x4500) + return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_-3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style!=std::round_to_nearest), mx = mm.first + mm.second, i = (~mx&0xFFFFFFFF) >> 31; + for(exp=13; my<0x80000000; my<<=1,--exp) ; + mx = (mx>>i) | 0x80000000; + return half(detail::binary, detail::tangent_post(my, mx, exp-i, arg.data_&0x8000)); + #endif + } + + /// Hyperbolic area sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). + /// \param arg function argument + /// \return area sine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half asinh(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::asinh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: return half(detail::binary, detail::rounded(arg.data_-13, 1, 1)); + case 0x3B5B: return half(detail::binary, detail::rounded(arg.data_-197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); + #endif + } + + /// Hyperbolic area cosine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). + /// \param arg function argument + /// \return area cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or arguments <1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half acosh(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::acosh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if((arg.data_&0x8000) || abs < 0x3C00) + return half(detail::binary, (abs<=0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); + #endif + } + + /// Hyperbolic area tangent. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). + /// \param arg function argument + /// \return area tangent value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_DIVBYZERO for +/-1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atanh(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::atanh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, (abs==0x3C00) ? detail::pole(arg.data_&0x8000) : (abs<=0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs&0x3FF)|((abs>0x3FF)<<10)) << ((abs>>10)+(abs<=0x3FF)+6), my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx<0x80000000; mx<<=1,++exp) ; + int i = my >= mx, s; + return half(detail::binary, detail::log2_post(detail::log2( + (detail::divide64(my>>i, mx, s)+1)>>1, 27)+0x10, exp+i-1, 16, arg.data_&0x8000)); + #endif + } + + /// \} + /// \anchor special + /// \name Error and gamma functions + /// \{ + + /// Error function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% of inputs. + /// + /// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). + /// \param arg function argument + /// \return error function value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half erf(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::erf(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs>=0x7C00) ? half(detail::binary, (abs==0x7C00) ? (arg.data_-0x4000) : detail::signal(arg.data_)) : arg; + if(abs >= 0x4200) + return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); + #endif + } + + /// Complementary error function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% of inputs. + /// + /// **See also:** Documentation for [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). + /// \param arg function argument + /// \return 1 minus error function value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half erfc(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::erfc(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00) + return (abs>=0x7C00) ? half(detail::binary, (abs==0x7C00) ? (sign>>1) : detail::signal(arg.data_)) : arg; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x4400) + return half(detail::binary, detail::rounded((sign>>1)-(sign>>15), sign>>15, 1)); + return half(detail::binary, detail::erf(arg.data_)); + #endif + } + + /// Natural logarithm of gamma function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in ~0.025% of inputs. + /// + /// **See also:** Documentation for [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). + /// \param arg function argument + /// \return natural logarith of gamma function for \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_DIVBYZERO for 0 or negative integer arguments + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half lgamma(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::lgamma(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if(!abs || arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs&((1<<(25-(abs>>10)))-1)))) + return half(detail::binary, detail::pole()); + if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, detail::gamma(arg.data_)); + #endif + } + + /// Gamma function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.25% of inputs. + /// + /// **See also:** Documentation for [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). + /// \param arg function argument + /// \return gamma function value of \a arg + /// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tgamma(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::tgamma(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs&((1<<(25-(abs>>10)))-1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half(detail::binary, detail::underflow((1-((abs>>(25-(abs>>10)))&1))<<15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); + #endif + } + + /// \} + /// \anchor rounding + /// \name Rounding + /// \{ + + /// Nearest integer not less than half value. + /// **See also:** Documentation for [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). + /// \param arg half to round + /// \return nearest integer not less than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half ceil(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer not greater than half value. + /// **See also:** Documentation for [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). + /// \param arg half to round + /// \return nearest integer not greater than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half floor(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer not greater in magnitude than half value. + /// **See also:** Documentation for [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). + /// \param arg half to round + /// \return nearest integer not greater in magnitude than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half trunc(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer. + /// **See also:** Documentation for [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half round(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer. + /// **See also:** Documentation for [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID if value is not representable as `long` + inline long lround(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half rint(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID if value is not representable as `long` + /// \exception FE_INEXACT if value had to be rounded + inline long lrint(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID for signaling NaN + inline half nearbyint(half arg) { return half(detail::binary, detail::integral(arg.data_)); } +#if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer. + /// **See also:** Documentation for [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID if value is not representable as `long long` + inline long long llround(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID if value is not representable as `long long` + /// \exception FE_INEXACT if value had to be rounded + inline long long llrint(half arg) { return detail::half2int(arg.data_); } +#endif + + /// \} + /// \anchor float + /// \name Floating point manipulation + /// \{ + + /// Decompress floating-point number. + /// **See also:** Documentation for [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return significant in range [0.5, 1) + /// \exception FE_INVALID for signaling NaN + inline half frexp(half arg, int *exp) + { + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00 || !abs) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--*exp) ; + *exp += (abs>>10) - 14; + return half(detail::binary, (arg.data_&0x8000)|0x3800|(abs&0x3FF)); + } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half scalbln(half arg, long exp) + { + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00 || !abs) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + else if(exp > 0) + return half(detail::binary, sign|(exp<<10)|(abs&0x3FF)); + unsigned int m = (abs&0x3FF) | 0x400; + return half(detail::binary, detail::rounded(sign|(m>>(1-exp)), (m>>-exp)&1, (m&((1<<-exp)-1))!=0)); + } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + + /// Extract integer and fractional parts. + /// **See also:** Documentation for [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part + /// \exception FE_INVALID for signaling NaN + inline half modf(half arg, half *iptr) + { + unsigned int abs = arg.data_ & 0x7FFF; + if(abs > 0x7C00) + { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if(abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_&0x8000); + if(abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1<<(25-exp)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(detail::binary, arg.data_&0x8000); + for(; m<0x400; m<<=1,--exp) ; + return half(detail::binary, (arg.data_&0x8000)|(exp<<10)|(m&0x3FF)); + } + + /// Extract exponent. + /// **See also:** Documentation for [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). + /// \param arg number to query + /// \return floating-point exponent + /// \retval FP_ILOGB0 for zero + /// \retval FP_ILOGBNAN for NaN + /// \retval INT_MAX for infinity + /// \exception FE_INVALID for 0 or infinite values + inline int ilogb(half arg) + { + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs==0x7C00) ? INT_MAX : FP_ILOGBNAN; + } + for(exp=(abs>>10)-15; abs<0x200; abs<<=1,--exp) ; + return exp; + } + + /// Extract exponent. + /// **See also:** Documentation for [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). + /// \param arg number to query + /// \return floating-point exponent + /// \exception FE_INVALID for signaling NaN + /// \exception FE_DIVBYZERO for 0 + inline half logb(half arg) + { + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp=(abs>>10)-15; abs<0x200; abs<<=1,--exp) ; + unsigned int value = static_cast(exp<0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6; + for(exp=18; m<0x400; m<<=1,--exp) ; + value |= (exp<<10) + m; + } + return half(detail::binary, value); + } + + /// Next representable value. + /// **See also:** Documentation for [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW for infinite result from finite argument + /// \exception FE_UNDERFLOW for subnormal result + inline half nextafter(half from, half to) + { + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs|tabs)) + return to; + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_&0x8000)+1); + } + unsigned int out = from.data_ + (((from.data_>>15)^static_cast( + (from.data_^(0x8000|(0x8000-(from.data_>>15))))<(to.data_^(0x8000|(0x8000-(to.data_>>15))))))<<1) - 1; + detail::raise(FE_OVERFLOW, fabs<0x7C00 && (out&0x7C00)==0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out&0x7C00)<0x400); + return half(detail::binary, out); + } + + /// Next representable value. + /// **See also:** Documentation for [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW for infinite result from finite argument + /// \exception FE_UNDERFLOW for subnormal result + inline half nexttoward(half from, long double to) + { + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast(detail::builtin_signbit(to))<<15)+1); + } + unsigned int out = from.data_ + (((from.data_>>15)^static_cast(lfrom 0x7C00; } + + /// Check if normal number. + /// **See also:** Documentation for [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). + /// \param arg number to check + /// \retval true if normal number + /// \retval false if either subnormal, zero, infinity or NaN + inline HALF_CONSTEXPR bool isnormal(half arg) { return ((arg.data_&0x7C00)!=0) & ((arg.data_&0x7C00)!=0x7C00); } + + /// Check sign. + /// **See also:** Documentation for [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). + /// \param arg number to check + /// \retval true for negative number + /// \retval false for positive number + inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_&0x8000) != 0; } + + /// \} + /// \anchor compfunc + /// \name Comparison + /// \{ + + /// Quiet comparison for greater than. + /// **See also:** Documentation for [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + inline HALF_CONSTEXPR bool isgreater(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) > ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for greater equal. + /// **See also:** Documentation for [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) >= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for less than. + /// **See also:** Documentation for [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + inline HALF_CONSTEXPR bool isless(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) < ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for less equal. + /// **See also:** Documentation for [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + inline HALF_CONSTEXPR bool islessequal(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) <= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comarison for less or greater. + /// **See also:** Documentation for [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). + /// \param x first operand + /// \param y second operand + /// \retval true if either less or greater + /// \retval false else + inline HALF_CONSTEXPR bool islessgreater(half x, half y) + { + return x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF) && !isnan(x) && !isnan(y); + } + + /// Quiet check if unordered. + /// **See also:** Documentation for [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). + /// \param x first operand + /// \param y second operand + /// \retval true if unordered (one or two NaN operands) + /// \retval false else + inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + + /// \} + /// \anchor casting + /// \name Casting + /// \{ + + /// Cast to or from half-precision floating-point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the default rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s returns the argument unmodified. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + /// \exception FE_INVALID if \a T is integer type and result is not representable as \a T + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template T half_cast(U arg) { return detail::half_caster::cast(arg); } + + /// Cast to or from half-precision floating-point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the specified rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s returns the argument unmodified. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam R rounding mode to use. + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + /// \exception FE_INVALID if \a T is integer type and result is not representable as \a T + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template T half_cast(U arg) { return detail::half_caster::cast(arg); } + /// \} + + /// \} + /// \anchor errors + /// \name Error handling + /// \{ + + /// Clear exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). + /// \param excepts OR of exceptions to clear + /// \retval 0 all selected flags cleared successfully + inline int feclearexcept(int excepts) { detail::errflags() &= ~excepts; return 0; } + + /// Test exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). + /// \param excepts OR of exceptions to test + /// \return OR of selected exceptions if raised + inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + + /// Raise exception flags. + /// This raises the specified floating point exceptions and also invokes any additional automatic exception handling as + /// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). + /// \param excepts OR of exceptions to raise + /// \retval 0 all selected exceptions raised successfully + inline int feraiseexcept(int excepts) { detail::errflags() |= excepts; detail::raise(excepts); return 0; } + + /// Save exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). + /// \param flagp adress to store flag state at + /// \param excepts OR of flags to save + /// \retval 0 for success + inline int fegetexceptflag(int *flagp, int excepts) { *flagp = detail::errflags() & excepts; return 0; } + + /// Restore exception flags. + /// This only copies the specified exception state (including unset flags) without incurring any additional exception handling. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). + /// \param flagp adress to take flag state from + /// \param excepts OR of flags to restore + /// \retval 0 for success + inline int fesetexceptflag(const int *flagp, int excepts) { detail::errflags() = (detail::errflags()|(*flagp&excepts)) & (*flagp|~excepts); return 0; } + + /// Throw C++ exceptions based on set exception flags. + /// This function manually throws a corresponding C++ exception if one of the specified flags is set, + /// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// \param excepts OR of exceptions to test + /// \param msg error message to use for exception description + /// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set + /// \throw std::overflow_error if `FE_OVERFLOW` is selected and set + /// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set + /// \throw std::range_error if `FE_INEXACT` is selected and set + inline void fethrowexcept(int excepts, const char *msg = "") + { + excepts &= detail::errflags(); + if(excepts & (FE_INVALID|FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); + } + /// \} +} + + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS + #pragma warning(pop) + #undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj index 35632915f..52c9deede 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj @@ -129,11 +129,14 @@ true NotUsing pch.h + ../antares;%(AdditionalIncludeDirectories) Windows true false + antares.lib;%(AdditionalDependencies) + ..\$(IntDir);%(AdditionalLibraryDirectories) @@ -146,6 +149,7 @@ true NotUsing pch.h + ../antares;%(AdditionalIncludeDirectories) Windows @@ -153,18 +157,17 @@ true true false + antares.lib;%(AdditionalDependencies) + ..\$(IntDir);%(AdditionalLibraryDirectories) - - - diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj.filters b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj.filters index 82eca40df..87feee260 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj.filters +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/runtime.vcxproj.filters @@ -21,19 +21,10 @@ Header Files - - Header Files - - - Header Files - Source Files - - Source Files - \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/update.bat b/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/update.bat deleted file mode 100644 index aa139794d..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/update.bat +++ /dev/null @@ -1,13 +0,0 @@ -@echo off - -echo update D3D12APIWrapper.h -curl -LOs https://github.com/microsoft/antares/blob/master/backends/c-hlsl/evaluator/AntaresHlslLib/D3D12APIWrapper.h - -echo update D3D12APIWrapper.cpp -curl -LOs https://github.com/microsoft/antares/blob/master/backends/c-hlsl/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp - -echo update D3D12Antares.h -curl -LOs https://github.com/microsoft/antares/blob/master/backends/c-hlsl/evaluator/AntaresHlslLib/D3D12Antares.h - -echo finished! -pause diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12APIWrapper.cpp b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12APIWrapper.cpp similarity index 84% rename from src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12APIWrapper.cpp rename to src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12APIWrapper.cpp index 87e13b14a..9f84f904a 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12APIWrapper.cpp +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12APIWrapper.cpp @@ -8,13 +8,23 @@ #include #include #include +#include +#include #define _USE_GPU_TIMER_ #define _USE_DXC_ +#define ANTARES_EXPORTS + #include "D3D12Util.h" #include "D3D12APIWrapper.h" +#if _DEBUG +#define DEBUG_PRINT(msg) (fprintf(stderr, "[DEBUG] %s\n", msg), fflush(stderr)) +#else +#define DEBUG_PRINT(msg) +#endif + namespace { static bool _USE_DESCRIPTOR_HEAP_ = false; @@ -75,16 +85,34 @@ namespace { } }; + struct VectorHasher { + int operator()(const std::vector& V) const { + int hash = V.size(); + for (auto& i : V) { + hash ^= (i ^ (i >> 32)) + 0x9e3779b9L; + } + return hash; + } + }; + + std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); // Handles case where 'to' is a substring of 'from' + } + return str; + } + struct dx_shader_t { int block[3], thread[3]; std::vector inputs, outputs; std::string source; - CD3DX12_SHADER_BYTECODE bytecode; + std::unordered_map, ComPtr, VectorHasher> pPSO_ht; // bytecode_ht; // Added D3D12 resource ptr. ComPtr pRootSignature; - ComPtr pPSO; }; // Stream is wrapper of resources for record and execute commands. @@ -126,11 +154,8 @@ namespace { std::vector queryHeapsNeedToResolve; }; -#ifdef _DEBUG - static std::shared_ptr device = std::make_shared(true, true); -#else - static std::shared_ptr device = std::make_shared(false, false); -#endif + + static std::shared_ptr device; static void* defaultStream = nullptr; @@ -175,6 +200,16 @@ namespace { int dxInit(int flags) { + DEBUG_PRINT(__func__); + + if (device == nullptr) { +#ifdef _DEBUG + device = std::make_shared(true, true); +#else + device = std::make_shared(false, false); +#endif + } + if (!defaultStream) { // flags = 1: enable descriptor heap, no logging @@ -193,16 +228,55 @@ int dxInit(int flags) } int dxFinalize() { + DEBUG_PRINT(__func__); + device = nullptr; defaultStream = nullptr; return 0; } +static std::unordered_map> unused_buffers; +static std::unordered_map buffer_slots; + +inline size_t compute_slotsize(size_t &value) { + static const int tab64[64] = { + 63, 0, 58, 1, 59, 47, 53, 2, + 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, + 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, + 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, + 44, 24, 15, 8, 23, 7, 6, 5 }; + + value -= 1; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + + size_t slot_id = tab64[((uint64_t)((value - (value >> 1)) * 0x07EDD5E59A4E28C2LLU)) >> 58]; + value += 1; + return slot_id; +} + void* dxMemAlloc(size_t bytes) { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; + auto slot_id = compute_slotsize(bytes); + auto& slot = unused_buffers[slot_id]; + if (slot.size()) { + void* buff = slot.back(); + slot.pop_back(); + return buff; + } + auto buff = new dx_buffer_t(); buff->size = bytes; device->CreateGPUOnlyResource(bytes, &buff->handle); @@ -211,20 +285,30 @@ void* dxMemAlloc(size_t bytes) void* virtualPtr = VirtualAlloc(nullptr, bytes, MEM_RESERVE, PAGE_NOACCESS); assert(virtualPtr != nullptr); + buffer_slots[virtualPtr] = slot_id; memBlocks[virtualPtr] = buff; return virtualPtr; } -int dxMemFree(void* vPtr) +int dxMemFree(void* virtualPtr) { - VirtualFree(vPtr, 0, MEM_RELEASE); - memBlocks.erase(vPtr); + DEBUG_PRINT(__func__); + + auto it = buffer_slots.find(virtualPtr); + assert(it != buffer_slots.end()); + unused_buffers[it->second].push_back(virtualPtr); + return 0; + + VirtualFree(virtualPtr, 0, MEM_RELEASE); + memBlocks.erase(virtualPtr); return 0; } void* dxShaderLoad_v2(const char* shader_src) { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; @@ -241,27 +325,6 @@ void* dxShaderLoad_v2(const char* shader_src) dx_shader_t* handle = new dx_shader_t; handle->source = source; -#ifdef _USE_DXC_ - // Use cs_6_0 since dxc only supports cs_6_0 or higher shader models. - auto computeShader = antares::DXCompiler::Get()->Compile(source.data(), (uint32_t)source.size(), L"CSMain", L"cs_6_0"); - if (computeShader != nullptr) - handle->bytecode = CD3DX12_SHADER_BYTECODE(computeShader->GetBufferPointer(), computeShader->GetBufferSize()); - else - abort(); -#else - ComPtr computeShader = nullptr, errMsg = nullptr; - if (D3DCompile(source.data(), source.size(), NULL, NULL, NULL, "CSMain", "cs_5_1", 0, 0, &computeShader, &errMsg) >= 0 && computeShader != nullptr) - handle->bytecode = CD3DX12_SHADER_BYTECODE(computeShader.Get()); - else { - auto error_message = (char*)errMsg->GetBufferPointer(); - fprintf(stderr, "[ERROR] D3D12: Shader Compile Failed: %s\n", error_message); - } -#endif - if (computeShader == nullptr) { - //delete handle; - return nullptr; - } - std::string str_params; std::vector arr_params, in_params, out_params; bool legacy_format = (source.size() >= 3 && source.substr(0, 3) == "///"); @@ -321,8 +384,6 @@ void* dxShaderLoad_v2(const char* shader_src) auto& hd = handle; ComPtr& m_computeRootSignature = hd->pRootSignature; - ComPtr& m_computeState = hd->pPSO; - D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc{}; CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC computeRootSignatureDesc; std::vector computeRootParameters; @@ -356,21 +417,20 @@ void* dxShaderLoad_v2(const char* shader_src) IFE(D3DX12SerializeVersionedRootSignature(&computeRootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1_1, &signature, &error)); IFE(device->pDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_GRAPHICS_PPV_ARGS(m_computeRootSignature.ReleaseAndGetAddressOf()))); - - computePsoDesc.CS = hd->bytecode; - computePsoDesc.pRootSignature = m_computeRootSignature.Get(); - IFE(device->pDevice->CreateComputePipelineState(&computePsoDesc, IID_GRAPHICS_PPV_ARGS(m_computeState.ReleaseAndGetAddressOf()))); - return handle; } void dxShaderUnload(void* hShader) { + DEBUG_PRINT(__func__); + free(hShader); } void* dxModuleLoad(const char* module_src) { + DEBUG_PRINT(__func__); + std::string source; const char proto[] = "file://"; if (strncmp(module_src, proto, sizeof(proto) - 1) == 0) { @@ -401,6 +461,8 @@ void* dxModuleLoad(const char* module_src) void dxModuleUnload(void* hModule) { + DEBUG_PRINT(__func__); + auto& hShaderDict = *(std::unordered_map*)hModule; for (auto& it : hShaderDict) dxShaderUnload(it.second); @@ -409,6 +471,8 @@ void dxModuleUnload(void* hModule) void* dxModuleGetShader(void* hModule, const char* fname) { + DEBUG_PRINT(__func__); + auto& dict = *(std::unordered_map*)hModule; auto it = dict.find(fname); return it != dict.end() ? it->second : nullptr; @@ -416,6 +480,8 @@ void* dxModuleGetShader(void* hModule, const char* fname) void* dxStreamCreate() { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; @@ -446,6 +512,8 @@ void* dxStreamCreate() int dxStreamDestroy(void* hStream) { + DEBUG_PRINT(__func__); + if (hStream != nullptr) delete (dx_stream_t*)hStream; return 0; @@ -453,6 +521,8 @@ int dxStreamDestroy(void* hStream) int dxStreamSubmit(void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -480,6 +550,8 @@ int dxStreamSubmit(void* hStream) int dxStreamSynchronize(void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -499,6 +571,8 @@ int dxStreamSynchronize(void* hStream) int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -522,6 +596,8 @@ int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream) int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void *hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -563,6 +639,8 @@ int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void *hStream) int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -599,15 +677,58 @@ int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream) return dxStreamSynchronize(hStream); } -int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) +static std::wstring default_compat = L"cs_6_0"; + +int dxModuleSetCompat(const char* compat_name) { + std::wstring_convert> converter; + ::default_compat = converter.from_bytes(compat_name); + return 0; +} + +int dxShaderLaunchAsyncExt(void* hShader, void** buffers, int n, int blocks, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; - auto hd = (dx_shader_t*)hShader; auto pStream = (dx_stream_t*)hStream; assert(pStream->state == dx_stream_t::State::INRECORD); + n -= hd->inputs.size() + hd->outputs.size(); + n = max(0, n); + std::vector pargs(n); + for (int i = 0, j = hd->inputs.size() + hd->outputs.size(); i < n; ++i, ++j) + pargs[i] = (size_t)buffers[j]; + auto pso_iter = hd->pPSO_ht.find(pargs); + if (pso_iter == hd->pPSO_ht.end()) { + std::string src = hd->source; + for (int i = 0; i < n; ++i) + src = ReplaceAll(src, "@" + std::to_string(i) + "@", std::to_string(pargs[i])); + CD3DX12_SHADER_BYTECODE bytecode; +#ifdef _USE_DXC_ + // Use cs_6_0 since dxc only supports cs_6_0 or higher shader models. + auto computeShader = antares::DXCompiler::Get()->Compile(src.data(), (uint32_t)src.size(), L"CSMain", default_compat.c_str()); + if (computeShader != nullptr) + bytecode = CD3DX12_SHADER_BYTECODE(computeShader->GetBufferPointer(), computeShader->GetBufferSize()); +#else + ComPtr computeShader = nullptr, errMsg = nullptr; + if (D3DCompile(source.data(), source.size(), NULL, NULL, NULL, "CSMain", "cs_5_1", 0, 0, &computeShader, &errMsg) >= 0 && computeShader != nullptr) + bytecode = CD3DX12_SHADER_BYTECODE(computeShader.Get()); +#endif + if (computeShader == nullptr) { + //delete handle; + IFE(-1); + } + + ComPtr& m_computeState = hd->pPSO_ht[pargs]; + D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc{}; + computePsoDesc.CS = bytecode; + computePsoDesc.pRootSignature = hd->pRootSignature.Get(); + IFE(device->pDevice->CreateComputePipelineState(&computePsoDesc, IID_GRAPHICS_PPV_ARGS(m_computeState.ReleaseAndGetAddressOf()))); + pso_iter = hd->pPSO_ht.find(pargs); + } + std::vector devicePtrs; std::vector offsets; devicePtrs.reserve(hd->inputs.size() + hd->outputs.size()); @@ -636,7 +757,7 @@ int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) } pStream->pCmdList->SetComputeRootSignature(hd->pRootSignature.Get()); - pStream->pCmdList->SetPipelineState(hd->pPSO.Get()); + pStream->pCmdList->SetPipelineState(pso_iter->second.Get()); if (_USE_DESCRIPTOR_HEAP_) @@ -696,15 +817,22 @@ int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) // Set StartTimer here to only consider kernel execution time. device->StartTimer(pStream->pCmdList.Get(), m_nTimerIndex); #endif - pStream->pCmdList->Dispatch(hd->block[0], hd->block[1], hd->block[2]); + pStream->pCmdList->Dispatch(blocks >= 0 ? blocks : hd->block[0], hd->block[1], hd->block[2]); #ifdef _USE_GPU_TIMER_ device->StopTimer(pStream->pCmdList.Get(), m_nTimerIndex); #endif return 0; } +int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream) +{ + return dxShaderLaunchAsyncExt(hShader, buffers, 0, -1, hStream); +} + void* dxEventCreate() { + DEBUG_PRINT(__func__); + if (dxInit(0) != 0) return nullptr; @@ -767,6 +895,8 @@ void* dxEventCreate() int dxEventDestroy(void* hEvent) { + DEBUG_PRINT(__func__); + if (hEvent == nullptr) return -1; @@ -779,6 +909,8 @@ int dxEventDestroy(void* hEvent) int dxEventRecord(void* hEvent, void* hStream) { + DEBUG_PRINT(__func__); + if (!hStream) hStream = defaultStream; @@ -802,6 +934,8 @@ int dxEventRecord(void* hEvent, void* hStream) float dxEventElapsedSecond(void* hStart, void* hStop) { + DEBUG_PRINT(__func__); + auto pQueryStart = (antares::dx_query_t*)hStart; auto pQueryEnd = (antares::dx_query_t*)hStop; @@ -809,7 +943,11 @@ float dxEventElapsedSecond(void* hStart, void* hStop) uint64_t* pData; uint64_t timeStampStart = 0; uint64_t timeStampEnd = 0; - IFE(device->globalQueryHeaps[pQueryStart->heapIdx].pReadbackBuffer->Map(0, nullptr, reinterpret_cast(&pData))); + + HRESULT res = device->globalQueryHeaps[pQueryStart->heapIdx].pReadbackBuffer->Map(0, nullptr, reinterpret_cast(&pData)); + if (res < 0) + return -1.0f; + timeStampStart = pData[pQueryStart->queryIdxInHeap]; if (pQueryEnd->heapIdx == pQueryStart->heapIdx) diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12APIWrapper.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12APIWrapper.h new file mode 100644 index 000000000..41a3b5b46 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12APIWrapper.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef __ANTARES_D3D12_WRAPPER__ +#define __ANTARES_D3D12_WRAPPER__ + +#ifdef ANTARES_EXPORTS +#define ANTARES_API __declspec(dllexport) +#else +#define ANTARES_API __declspec(dllimport) +#endif + +#define __EXPORT__ extern "C" + +__EXPORT__ ANTARES_API int dxInit(int flags); +__EXPORT__ ANTARES_API int dxFinalize(); + +__EXPORT__ ANTARES_API void* dxStreamCreate(); +__EXPORT__ ANTARES_API int dxStreamDestroy(void* hStream); +__EXPORT__ ANTARES_API int dxStreamSubmit(void* hStream); +__EXPORT__ ANTARES_API int dxStreamSynchronize(void* hStream); + +__EXPORT__ ANTARES_API void* dxMemAlloc(size_t bytes); +__EXPORT__ ANTARES_API int dxMemFree(void* dptr); +__EXPORT__ ANTARES_API int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void* hStream); +__EXPORT__ ANTARES_API int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream); +__EXPORT__ ANTARES_API int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream); + +__EXPORT__ ANTARES_API int dxModuleSetCompat(const char* compat_name); +__EXPORT__ ANTARES_API void* dxModuleLoad(const char* module_src); +__EXPORT__ ANTARES_API void* dxModuleGetShader(void *hModule, const char* fname); +__EXPORT__ ANTARES_API void dxModuleUnload(void* hModule); + +__EXPORT__ ANTARES_API void* dxShaderLoad_v2(const char* shader_src); +__EXPORT__ ANTARES_API int dxShaderLaunchAsyncExt(void* hShader, void** buffers, int n, int blocks, void* hStream); +__EXPORT__ ANTARES_API int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream); +__EXPORT__ ANTARES_API void dxShaderUnload(void* hShader); + +__EXPORT__ ANTARES_API void* dxEventCreate(); +__EXPORT__ ANTARES_API int dxEventRecord(void* hEvent, void* hStream); +__EXPORT__ ANTARES_API float dxEventElapsedSecond(void* hStart, void* hStop); +__EXPORT__ ANTARES_API int dxEventDestroy(void* hEvent); + +#endif \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12Util.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12Util.h similarity index 99% rename from src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12Util.h rename to src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12Util.h index 0cd0ad39c..0d7f59416 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DWinNN/runtime/D3D12Util.h +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/D3D12Util.h @@ -1121,6 +1121,8 @@ namespace antares { args_i.push_back(args[i].c_str()); } #endif + if (std::wstring(profile) != std::wstring(L"cs_6_0")) + args_i.push_back(L"-enable-16bit-types"); args_i.push_back(NULL); // Just set a random name "ShaderFile" // const WCHAR* args[] = { L"-enable-templates", L"-enable-16bit-types", NULL }; // TODO: will be supported in HLSL 2021 & cs_6_2 diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj new file mode 100644 index 000000000..87e579f20 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj @@ -0,0 +1,293 @@ + + + + + Debug + Gaming.Xbox.Scarlett.x64 + + + Profile + Gaming.Xbox.Scarlett.x64 + + + Release + Gaming.Xbox.Scarlett.x64 + + + Release + Gaming.Xbox.XboxOne.x64 + + + Profile + Gaming.Xbox.XboxOne.x64 + + + Debug + Gaming.Xbox.XboxOne.x64 + + + + + + + + + + + + antares + {38f33782-704d-424c-8eac-249afa785995} + en-US + Win32Proj + + 15.0 + Native + x64 + + + + DynamicLibrary + v142 + false + true + Unicode + false + false + + + DynamicLibrary + v142 + false + true + Unicode + false + false + + + DynamicLibrary + v142 + false + true + Unicode + false + false + + + DynamicLibrary + v142 + false + true + Unicode + false + false + + + DynamicLibrary + v142 + true + Unicode + false + false + + + DynamicLibrary + v142 + true + Unicode + false + false + + + + + + + + + + + + + + + + + + + + + + + + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkLibPath) + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkIncludeRoot) + $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) + false + + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkLibPath) + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkIncludeRoot) + $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) + false + + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkLibPath) + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkIncludeRoot) + $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) + false + + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkLibPath) + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkIncludeRoot) + $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) + false + + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkLibPath) + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkIncludeRoot) + $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) + true + + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkLibPath) + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) + $(Console_SdkIncludeRoot) + $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) + true + + + + $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + true + Windows + true + true + + + Use + pch.h + MaxSpeed + NDEBUG;_USRDLL;%(PreprocessorDefinitions) + Level4 + true + true + true + + + + + $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + true + Windows + true + true + + + NotUsing + pch.h + MaxSpeed + NDEBUG;ANTARES_EXPORTS;_USRDLL;%(PreprocessorDefinitions) + Level4 + true + true + true + + + + + + + + + $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + true + Windows + true + true + + + Use + pch.h + MaxSpeed + NDEBUG;PROFILE;_USRDLL;%(PreprocessorDefinitions) + Level4 + true + true + true + + + + + $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + true + Windows + true + true + + + NotUsing + pch.h + MaxSpeed + NDEBUG;PROFILE;_USRDLL;%(PreprocessorDefinitions) + Level4 + true + true + true + + + + + + + + + $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + Windows + true + + + pch.h + Use + false + Level4 + Disabled + _DEBUG;_USRDLL;%(PreprocessorDefinitions) + true + + + + + $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + Windows + true + + + pch.h + NotUsing + false + Level4 + Disabled + _DEBUG;ANTARES_EXPORTS;_USRDLL;%(PreprocessorDefinitions) + true + + + + + + + + + + \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj.filters b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj.filters new file mode 100644 index 000000000..31f43c1fe --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj.filters @@ -0,0 +1,33 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Header Files + + + Header Files + + + Header Files + + + + + Source Files + + + \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj.user b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj.user new file mode 100644 index 000000000..88a550947 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/antares.vcxproj.user @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/pch.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/pch.h similarity index 100% rename from src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/pch.h rename to src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/pch.h diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/update.bat b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/update.bat new file mode 100644 index 000000000..056ae1a02 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/antares/update.bat @@ -0,0 +1,10 @@ +@echo off + +curl -LOs https://raw.githubusercontent.com/microsoft/antares/v0.3.x/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.h && echo updated D3D12APIWrapper.h + +curl -LOs https://raw.githubusercontent.com/microsoft/antares/v0.3.x/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp && echo updated D3D12APIWrapper.cpp + +curl -LOs https://raw.githubusercontent.com/microsoft/antares/v0.3.x/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12Util.h && echo updated D3D12Util.h + +echo finished! +pause diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/build.py b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/build.py deleted file mode 100644 index e1f47b687..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/build.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -import sys -import shutil -import winreg -import argparse -import logging -import subprocess - -logging.basicConfig(level="INFO") -logger = logging.getLogger(__name__) - - -def find_vs_path(): - # something like r"C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160" - version = ["2019", "2017", "2015"] - license = ["Enterprise", "Professional", "Community"] - default_path = r"C:\Program Files (x86)\Microsoft Visual Studio" - for v in version: - v_path = os.path.join(default_path, v) - if not os.path.isdir(v_path): - continue - for l in license: - l_path = os.path.join(v_path, l) - if not os.path.isdir(l_path): - continue - logger.info(f"Find Visual Studio in {l_path}") - return l_path - return "" - - -def copy_to(src, dst): - assert os.path.exists(src), f"File not found: {src}" - if os.path.isfile(dst): - os.remove(dst) - if os.path.isdir(dst): - shutil.rmtree(dst) - if os.path.isfile(src): - shutil.copyfile(src, dst) - if os.path.isdir(src): - shutil.copytree(src, dst) - - -def copy_to_output(output_dir, build_type, platform): - os.makedirs(output_dir, exist_ok=True) - nnf_xbox_dir = r".\nnf_xbox_example" - hlsl_path = os.path.join(nnf_xbox_dir, "HLSL") - const_path = os.path.join(nnf_xbox_dir, "Constant") - para_info = os.path.join(nnf_xbox_dir, "para_info.json") - nnf_exe = os.path.join( - nnf_xbox_dir, platform if "x64" in platform else "", build_type, "nnf_xbox_example.exe") - deps_dir = os.path.join( - nnf_xbox_dir, platform if "x64" in platform else "", r"Layout\Image\Loose") - deps = [] - for file_name in os.listdir(deps_dir): - if file_name.endswith(".dll"): - deps.append(file_name) - - runtime_dir = r".\runtime" - nnf_lib = os.path.join( - runtime_dir, platform if "x64" in platform else "", build_type, "nnfusion_rt.dll") - - if os.path.exists(hlsl_path): - copy_to(hlsl_path, os.path.join(output_dir, "HLSL")) - if os.path.exists(const_path): - copy_to(const_path, os.path.join(output_dir, "Constant")) - copy_to(para_info, os.path.join(output_dir, "para_info.json")) - copy_to(nnf_exe, os.path.join(output_dir, "nnf_xbox_example.exe")) - copy_to(nnf_lib, os.path.join(output_dir, "nnfusion_rt.dll")) - for file_name in deps: - copy_to(os.path.join(deps_dir, file_name), - os.path.join(output_dir, file_name)) - - -def setup_parser(): - parser = argparse.ArgumentParser() - parser.add_argument("-v", "--vs_path", default="", - help="visual studio install path") - parser.add_argument("-t", "--build_type", default="Release") - parser.add_argument("-p", "--platform", default="Gaming.Xbox.Scarlett.x64") - parser.add_argument("-o", "--output", default="./build") - return parser - - -def main(): - parser = setup_parser() - args = parser.parse_args() - vs_path = args.vs_path if args.vs_path != "" else find_vs_path() - build_type = args.build_type - platform = args.platform - output = args.output - - assert vs_path != "", "please specify vs install path by -v" - msbuild_exe = os.path.join(vs_path, r"MSBuild\Current\Bin\MSBuild.exe") - assert os.path.isfile( - msbuild_exe), f"MSBuild.exe not found in {msbuild_exe}" - - try: - subprocess.check_output([msbuild_exe, r".\nnf_xbox_example\nnf_xbox_example.vcxproj", - f"/property:Configuration={build_type}", f"/property:Platform={platform}"], stderr=subprocess.STDOUT, encoding="utf8") - except subprocess.CalledProcessError as e: - logger.error(e.output) - sys.exit(1) - - copy_to_output(output, build_type, platform) - logger.info(f"Build successfully, output dir: {output}") - - -if __name__ == '__main__': - main() diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example.sln b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example.sln index 9da84ab6d..9d52b7544 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example.sln +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example.sln @@ -4,8 +4,16 @@ Microsoft Visual Studio Solution File, Format Version 12.00 VisualStudioVersion = 16.0.31025.109 MinimumVisualStudioVersion = 10.0.40219.1 Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "nnf_xbox_example", "nnf_xbox_example\nnf_xbox_example.vcxproj", "{5054F7A7-E4B3-4AEC-AF86-9BEB27FCAAAC}" + ProjectSection(ProjectDependencies) = postProject + {38F33782-704D-424C-8EAC-249AFA785995} = {38F33782-704D-424C-8EAC-249AFA785995} + EndProjectSection EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "runtime", "runtime\runtime.vcxproj", "{76E5C805-8662-4A30-AF0A-21518FEEED41}" +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "nnfusion_rt", "runtime\runtime.vcxproj", "{76E5C805-8662-4A30-AF0A-21518FEEED41}" + ProjectSection(ProjectDependencies) = postProject + {38F33782-704D-424C-8EAC-249AFA785995} = {38F33782-704D-424C-8EAC-249AFA785995} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "antares", "antares\antares.vcxproj", "{38F33782-704D-424C-8EAC-249AFA785995}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -47,6 +55,18 @@ Global {76E5C805-8662-4A30-AF0A-21518FEEED41}.Release|Gaming.Xbox.Scarlett.x64.Build.0 = Release|Gaming.Xbox.Scarlett.x64 {76E5C805-8662-4A30-AF0A-21518FEEED41}.Release|Gaming.Xbox.XboxOne.x64.ActiveCfg = Release|Gaming.Xbox.XboxOne.x64 {76E5C805-8662-4A30-AF0A-21518FEEED41}.Release|Gaming.Xbox.XboxOne.x64.Build.0 = Release|Gaming.Xbox.XboxOne.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Debug|Gaming.Xbox.Scarlett.x64.ActiveCfg = Debug|Gaming.Xbox.Scarlett.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Debug|Gaming.Xbox.Scarlett.x64.Build.0 = Debug|Gaming.Xbox.Scarlett.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Debug|Gaming.Xbox.XboxOne.x64.ActiveCfg = Debug|Gaming.Xbox.XboxOne.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Debug|Gaming.Xbox.XboxOne.x64.Build.0 = Debug|Gaming.Xbox.XboxOne.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Profile|Gaming.Xbox.Scarlett.x64.ActiveCfg = Profile|Gaming.Xbox.Scarlett.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Profile|Gaming.Xbox.Scarlett.x64.Build.0 = Profile|Gaming.Xbox.Scarlett.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Profile|Gaming.Xbox.XboxOne.x64.ActiveCfg = Profile|Gaming.Xbox.XboxOne.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Profile|Gaming.Xbox.XboxOne.x64.Build.0 = Profile|Gaming.Xbox.XboxOne.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Release|Gaming.Xbox.Scarlett.x64.ActiveCfg = Release|Gaming.Xbox.Scarlett.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Release|Gaming.Xbox.Scarlett.x64.Build.0 = Release|Gaming.Xbox.Scarlett.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Release|Gaming.Xbox.XboxOne.x64.ActiveCfg = Release|Gaming.Xbox.XboxOne.x64 + {38F33782-704D-424C-8EAC-249AFA785995}.Release|Gaming.Xbox.XboxOne.x64.Build.0 = Release|Gaming.Xbox.XboxOne.x64 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example/nnf_xbox_example.vcxproj b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example/nnf_xbox_example.vcxproj index 39ca97d52..7aefa6260 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example/nnf_xbox_example.vcxproj +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/nnf_xbox_example/nnf_xbox_example.vcxproj @@ -126,7 +126,8 @@ $(Console_SdkIncludeRoot) $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) false - D:\GameOS.xvd + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) @@ -143,7 +144,8 @@ $(Console_SdkIncludeRoot) $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) false - D:\GameOS.xvd + + $(Console_SdkLibPath);$(Console_SdkWindowsMetadataPath) @@ -160,7 +162,8 @@ $(Console_SdkIncludeRoot) $(Console_SdkRoot)bin;$(Console_SdkToolPath);$(ExecutablePath) true - D:\GameOS.xvd + + @@ -184,11 +187,12 @@ - $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + $(Console_Libs);antares.lib;%(XboxExtensionsDependencies);%(AdditionalDependencies) true Console true true + ..\$(IntDir);%(AdditionalLibraryDirectories) NotUsing @@ -199,8 +203,15 @@ true true true - ..\runtime;%(AdditionalIncludeDirectories) + ..\runtime;..\antares;%(AdditionalIncludeDirectories) + + copy .\para_info.json "$(TargetDir)" +IF EXIST .\antares_perf.csv copy .\antares_perf.csv "$(TargetDir)" +IF EXIST .\HLSL xcopy /i /y /e .\HLSL "$(TargetDir)\HLSL" +IF EXIST .\Constant xcopy /i /y /e .\Constant "$(TargetDir)\Constant" + + @@ -224,11 +235,12 @@ - $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + $(Console_Libs);antares.lib;%(XboxExtensionsDependencies);%(AdditionalDependencies) true Console true true + ..\$(IntDir);%(AdditionalLibraryDirectories) NotUsing @@ -239,8 +251,15 @@ true true true - ..\runtime;%(AdditionalIncludeDirectories) + ..\runtime;..\antares;%(AdditionalIncludeDirectories) + + copy .\para_info.json "$(TargetDir)" +IF EXIST .\antares_perf.csv copy .\antares_perf.csv "$(TargetDir)" +IF EXIST .\HLSL xcopy /i /y /e .\HLSL "$(TargetDir)\HLSL" +IF EXIST .\Constant xcopy /i /y /e .\Constant "$(TargetDir)\Constant" + + @@ -261,9 +280,10 @@ - $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + $(Console_Libs);antares.lib;%(XboxExtensionsDependencies);%(AdditionalDependencies) Console true + ..\$(IntDir);%(AdditionalLibraryDirectories) pch.h @@ -273,8 +293,15 @@ Disabled _DEBUG;%(PreprocessorDefinitions) true - ..\runtime;%(AdditionalIncludeDirectories) + ..\runtime;..\antares;%(AdditionalIncludeDirectories) + + copy .\para_info.json "$(TargetDir)" +IF EXIST .\antares_perf.csv copy .\antares_perf.csv "$(TargetDir)" +IF EXIST .\HLSL xcopy /i /y /e .\HLSL "$(TargetDir)\HLSL" +IF EXIST .\Constant xcopy /i /y /e .\Constant "$(TargetDir)\Constant" + + @@ -286,6 +313,11 @@ PreserveNewest + + + PreserveNewest + + {76e5c805-8662-4a30-af0a-21518feeed41} diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12APIWrapper.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12APIWrapper.h deleted file mode 100644 index 0cc522f74..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/D3D12APIWrapper.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef __ANTARES_D3D12_WRAPPER__ -#define __ANTARES_D3D12_WRAPPER__ - -#define __EXPORT__ extern "C" __declspec(dllexport) - -__EXPORT__ int dxInit(int flags); -__EXPORT__ int dxFinalize(); - -__EXPORT__ void* dxStreamCreate(); -__EXPORT__ int dxStreamDestroy(void* hStream); -__EXPORT__ int dxStreamSubmit(void* hStream); -__EXPORT__ int dxStreamSynchronize(void* hStream); - -__EXPORT__ void* dxMemAlloc(size_t bytes); -__EXPORT__ int dxMemFree(void* dptr); -__EXPORT__ int dxMemcpyHtoDAsync(void* dst, void* src, size_t bytes, void* hStream); -__EXPORT__ int dxMemcpyDtoHAsync(void* dst, void* src, size_t bytes, void* hStream); -__EXPORT__ int dxMemcpyDtoDAsync(void* dst, void* src, size_t bytes, void* hStream); - -__EXPORT__ void* dxModuleLoad(const char* module_src); -__EXPORT__ void* dxModuleGetShader(void *hModule, const char* fname); -__EXPORT__ void dxModuleUnload(void* hModule); - -__EXPORT__ void* dxShaderLoad_v2(const char* shader_src); -__EXPORT__ int dxShaderLaunchAsync(void* hShader, void** buffers, void* hStream); -__EXPORT__ void dxShaderUnload(void* hShader); - -__EXPORT__ void* dxEventCreate(); -__EXPORT__ int dxEventRecord(void* hEvent, void* hStream); -__EXPORT__ float dxEventElapsedSecond(void* hStart, void* hStop); -__EXPORT__ int dxEventDestroy(void* hEvent); - -#endif \ No newline at end of file diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/DeviceResources.cpp b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/DeviceResources.cpp deleted file mode 100644 index 19d1f3740..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/DeviceResources.cpp +++ /dev/null @@ -1,422 +0,0 @@ -// -// DeviceResources.cpp - A wrapper for the Direct3D 12.X device and swapchain -// - -#include "pch.h" -#include "DeviceResources.h" - -using namespace DirectX; -using namespace DX; - -using Microsoft::WRL::ComPtr; - -// Constructor for DeviceResources. -DeviceResources::DeviceResources( - DXGI_FORMAT backBufferFormat, - DXGI_FORMAT depthBufferFormat, - UINT backBufferCount, - unsigned int flags) noexcept(false) : - m_backBufferIndex(0), - m_fenceValues{}, - m_framePipelineToken{}, - m_rtvDescriptorSize(0), - m_screenViewport{}, - m_scissorRect{}, - m_backBufferFormat(backBufferFormat), - m_depthBufferFormat(depthBufferFormat), - m_backBufferCount(backBufferCount), - m_window(nullptr), - m_d3dFeatureLevel(D3D_FEATURE_LEVEL_12_0), - m_outputSize{0, 0, 1920, 1080}, - m_options(flags) -{ - if (backBufferCount < 2 || backBufferCount > MAX_BACK_BUFFER_COUNT) - { - throw std::out_of_range("invalid backBufferCount"); - } -} - -// Destructor for DeviceResources. -DeviceResources::~DeviceResources() -{ - // Ensure that the GPU is no longer referencing resources that are about to be destroyed. - WaitForGpu(); - - // Ensure we present a blank screen before cleaning up resources. - if (m_commandQueue) - { - (void)m_commandQueue->PresentX(0, nullptr, nullptr); - } -} - -// Configures the Direct3D device, and stores handles to it and the device context. -void DeviceResources::CreateDeviceResources() -{ - // Create the DX12 API device object. - D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {}; - params.Version = D3D12_SDK_VERSION; - -#if defined(_DEBUG) - // Enable the debug layer. - params.ProcessDebugFlags = D3D12_PROCESS_DEBUG_FLAG_DEBUG_LAYER_ENABLED; -#elif defined(PROFILE) - // Enable the instrumented driver. - params.ProcessDebugFlags = D3D12XBOX_PROCESS_DEBUG_FLAG_INSTRUMENTED; -#endif - - params.GraphicsCommandQueueRingSizeBytes = static_cast(D3D12XBOX_DEFAULT_SIZE_BYTES); - params.GraphicsScratchMemorySizeBytes = static_cast(D3D12XBOX_DEFAULT_SIZE_BYTES); - params.ComputeScratchMemorySizeBytes = static_cast(D3D12XBOX_DEFAULT_SIZE_BYTES); - - HRESULT hr = D3D12XboxCreateDevice( - nullptr, - ¶ms, - IID_GRAPHICS_PPV_ARGS(m_d3dDevice.ReleaseAndGetAddressOf())); -#ifdef _DEBUG - if (hr == D3D12_ERROR_DRIVER_VERSION_MISMATCH) - { -#ifdef _GAMING_XBOX_SCARLETT - OutputDebugStringA("ERROR: Running a d3d12_xs.lib (Scarlett) linked binary on an Xbox One is not supported\n"); -#else - OutputDebugStringA("ERROR: Running a d3d12_x.lib (Xbox One) linked binary on a Scarlett device is not supported\n"); -#endif - } -#endif - ThrowIfFailed(hr); - - m_d3dDevice->SetName(L"DeviceResources"); - - // Create the command queue. - D3D12_COMMAND_QUEUE_DESC queueDesc = {}; - queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; - queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - - ThrowIfFailed(m_d3dDevice->CreateCommandQueue(&queueDesc, IID_GRAPHICS_PPV_ARGS(m_commandQueue.ReleaseAndGetAddressOf()))); - - m_commandQueue->SetName(L"DeviceResources"); - - // Create descriptor heaps for render target views and depth stencil views. - D3D12_DESCRIPTOR_HEAP_DESC rtvDescriptorHeapDesc = {}; - rtvDescriptorHeapDesc.NumDescriptors = m_backBufferCount; - rtvDescriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_RTV; - - ThrowIfFailed(m_d3dDevice->CreateDescriptorHeap(&rtvDescriptorHeapDesc, IID_GRAPHICS_PPV_ARGS(m_rtvDescriptorHeap.ReleaseAndGetAddressOf()))); - - m_rtvDescriptorHeap->SetName(L"DeviceResources"); - - m_rtvDescriptorSize = m_d3dDevice->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_RTV); - - if (m_depthBufferFormat != DXGI_FORMAT_UNKNOWN) - { - D3D12_DESCRIPTOR_HEAP_DESC dsvDescriptorHeapDesc = {}; - dsvDescriptorHeapDesc.NumDescriptors = 1; - dsvDescriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_DSV; - - ThrowIfFailed(m_d3dDevice->CreateDescriptorHeap(&dsvDescriptorHeapDesc, IID_GRAPHICS_PPV_ARGS(m_dsvDescriptorHeap.ReleaseAndGetAddressOf()))); - - m_dsvDescriptorHeap->SetName(L"DeviceResources"); - } - - // Create a command allocator for each back buffer that will be rendered to. - for (UINT n = 0; n < m_backBufferCount; n++) - { - ThrowIfFailed(m_d3dDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_GRAPHICS_PPV_ARGS(m_commandAllocators[n].ReleaseAndGetAddressOf()))); - - wchar_t name[25] = {}; - swprintf_s(name, L"Render target %u", n); - m_commandAllocators[n]->SetName(name); - } - - // Create a command list for recording graphics commands. - ThrowIfFailed(m_d3dDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, m_commandAllocators[0].Get(), nullptr, IID_GRAPHICS_PPV_ARGS(m_commandList.ReleaseAndGetAddressOf()))); - ThrowIfFailed(m_commandList->Close()); - - m_commandList->SetName(L"DeviceResources"); - - // Create a fence for tracking GPU execution progress. - ThrowIfFailed(m_d3dDevice->CreateFence(m_fenceValues[m_backBufferIndex], D3D12_FENCE_FLAG_NONE, IID_GRAPHICS_PPV_ARGS(m_fence.ReleaseAndGetAddressOf()))); - m_fenceValues[m_backBufferIndex]++; - - m_fence->SetName(L"DeviceResources"); - - m_fenceEvent.Attach(CreateEventEx(nullptr, nullptr, 0, EVENT_MODIFY_STATE | SYNCHRONIZE)); - if (!m_fenceEvent.IsValid()) - { - throw std::exception("CreateEvent"); - } - - if (m_options & c_Enable4K_UHD) - { - switch (XSystemGetDeviceType()) - { - case XSystemDeviceType::XboxOne: - case XSystemDeviceType::XboxOneS: - case XSystemDeviceType::XboxScarlettLockhart /* Xbox Series S */: - m_options &= ~c_Enable4K_UHD; -#ifdef _DEBUG - OutputDebugStringA("INFO: Swapchain using 1080p (1920 x 1080)\n"); -#endif - break; - - default: - m_outputSize = { 0, 0, 3840, 2160 }; -#ifdef _DEBUG - OutputDebugStringA("INFO: Swapchain using 4k (3840 x 2160)\n"); -#endif - break; - } - } - - RegisterFrameEvents(); -} - -// These resources need to be recreated every time the window size is changed. -void DeviceResources::CreateWindowSizeDependentResources() -{ - if (!m_window) - { - throw std::exception("Call SetWindow with a valid window handle"); - } - - // Wait until all previous GPU work is complete. - WaitForGpu(); - - // Ensure we present a blank screen before cleaning up resources. - ThrowIfFailed(m_commandQueue->PresentX(0, nullptr, nullptr)); - - // Release resources that are tied to the swap chain and update fence values. - for (UINT n = 0; n < m_backBufferCount; n++) - { - m_renderTargets[n].Reset(); - m_fenceValues[n] = m_fenceValues[m_backBufferIndex]; - } - - // Determine the render target size in pixels. - const UINT backBufferWidth = std::max(static_cast(m_outputSize.right - m_outputSize.left), 1u); - const UINT backBufferHeight = std::max(static_cast(m_outputSize.bottom - m_outputSize.top), 1u); - - // Obtain the back buffers for this window which will be the final render targets - // and create render target views for each of them. - CD3DX12_HEAP_PROPERTIES swapChainHeapProperties(D3D12_HEAP_TYPE_DEFAULT); - - D3D12_RESOURCE_DESC swapChainBufferDesc = CD3DX12_RESOURCE_DESC::Tex2D( - m_backBufferFormat, - backBufferWidth, - backBufferHeight, - 1, // This resource has only one texture. - 1 // Use a single mipmap level. - ); - swapChainBufferDesc.Flags |= D3D12_RESOURCE_FLAG_ALLOW_RENDER_TARGET; - - D3D12_CLEAR_VALUE swapChainOptimizedClearValue = {}; - swapChainOptimizedClearValue.Format = m_backBufferFormat; - - for (UINT n = 0; n < m_backBufferCount; n++) - { - ThrowIfFailed(m_d3dDevice->CreateCommittedResource( - &swapChainHeapProperties, - D3D12_HEAP_FLAG_ALLOW_DISPLAY, - &swapChainBufferDesc, - D3D12_RESOURCE_STATE_PRESENT, - &swapChainOptimizedClearValue, - IID_GRAPHICS_PPV_ARGS(m_renderTargets[n].GetAddressOf()))); - - wchar_t name[25] = {}; - swprintf_s(name, L"Render target %u", n); - m_renderTargets[n]->SetName(name); - - D3D12_RENDER_TARGET_VIEW_DESC rtvDesc = {}; - rtvDesc.Format = m_backBufferFormat; - rtvDesc.ViewDimension = D3D12_RTV_DIMENSION_TEXTURE2D; - - CD3DX12_CPU_DESCRIPTOR_HANDLE rtvDescriptor( - m_rtvDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), - static_cast(n), m_rtvDescriptorSize); - m_d3dDevice->CreateRenderTargetView(m_renderTargets[n].Get(), &rtvDesc, rtvDescriptor); - } - - // Reset the index to the current back buffer. - m_backBufferIndex = 0; - - if (m_depthBufferFormat != DXGI_FORMAT_UNKNOWN) - { - // Allocate a 2-D surface as the depth/stencil buffer and create a depth/stencil view - // on this surface. - CD3DX12_HEAP_PROPERTIES depthHeapProperties(D3D12_HEAP_TYPE_DEFAULT); - - D3D12_RESOURCE_DESC depthStencilDesc = CD3DX12_RESOURCE_DESC::Tex2D( - m_depthBufferFormat, - backBufferWidth, - backBufferHeight, - 1, // This depth stencil view has only one texture. - 1 // Use a single mipmap level. - ); - depthStencilDesc.Flags |= D3D12_RESOURCE_FLAG_ALLOW_DEPTH_STENCIL; - - D3D12_CLEAR_VALUE depthOptimizedClearValue = {}; - depthOptimizedClearValue.Format = m_depthBufferFormat; - depthOptimizedClearValue.DepthStencil.Depth = 1.0f; - depthOptimizedClearValue.DepthStencil.Stencil = 0; - - ThrowIfFailed(m_d3dDevice->CreateCommittedResource( - &depthHeapProperties, - D3D12_HEAP_FLAG_NONE, - &depthStencilDesc, - D3D12_RESOURCE_STATE_DEPTH_WRITE, - &depthOptimizedClearValue, - IID_GRAPHICS_PPV_ARGS(m_depthStencil.ReleaseAndGetAddressOf()) - )); - - m_depthStencil->SetName(L"Depth stencil"); - - D3D12_DEPTH_STENCIL_VIEW_DESC dsvDesc = {}; - dsvDesc.Format = m_depthBufferFormat; - dsvDesc.ViewDimension = D3D12_DSV_DIMENSION_TEXTURE2D; - - m_d3dDevice->CreateDepthStencilView(m_depthStencil.Get(), &dsvDesc, m_dsvDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); - } - - // Set the 3D rendering viewport and scissor rectangle to target the entire window. - m_screenViewport.TopLeftX = m_screenViewport.TopLeftY = 0.f; - m_screenViewport.Width = static_cast(backBufferWidth); - m_screenViewport.Height = static_cast(backBufferHeight); - m_screenViewport.MinDepth = D3D12_MIN_DEPTH; - m_screenViewport.MaxDepth = D3D12_MAX_DEPTH; - - m_scissorRect.left = m_scissorRect.top = 0; - m_scissorRect.right = static_cast(backBufferWidth); - m_scissorRect.bottom = static_cast(backBufferHeight); -} - -// Prepare the command list and render target for rendering. -void DeviceResources::Prepare(D3D12_RESOURCE_STATES beforeState, D3D12_RESOURCE_STATES afterState) -{ - // Wait until frame start is signaled - m_framePipelineToken = D3D12XBOX_FRAME_PIPELINE_TOKEN_NULL; - ThrowIfFailed(m_d3dDevice->WaitFrameEventX(D3D12XBOX_FRAME_EVENT_ORIGIN, INFINITE, nullptr, D3D12XBOX_WAIT_FRAME_EVENT_FLAG_NONE, &m_framePipelineToken)); - - // Reset command list and allocator. - ThrowIfFailed(m_commandAllocators[m_backBufferIndex]->Reset()); - ThrowIfFailed(m_commandList->Reset(m_commandAllocators[m_backBufferIndex].Get(), nullptr)); - - if (beforeState != afterState) - { - // Transition the render target into the correct state to allow for drawing into it. - D3D12_RESOURCE_BARRIER barrier = CD3DX12_RESOURCE_BARRIER::Transition(m_renderTargets[m_backBufferIndex].Get(), - beforeState, afterState); - m_commandList->ResourceBarrier(1, &barrier); - } -} - -// Present the contents of the swap chain to the screen. -void DeviceResources::Present(D3D12_RESOURCE_STATES beforeState) -{ - if (beforeState != D3D12_RESOURCE_STATE_PRESENT) - { - // Transition the render target to the state that allows it to be presented to the display. - D3D12_RESOURCE_BARRIER barrier = CD3DX12_RESOURCE_BARRIER::Transition(m_renderTargets[m_backBufferIndex].Get(), beforeState, D3D12_RESOURCE_STATE_PRESENT); - m_commandList->ResourceBarrier(1, &barrier); - } - - // Send the command list off to the GPU for processing. - ThrowIfFailed(m_commandList->Close()); - m_commandQueue->ExecuteCommandLists(1, CommandListCast(m_commandList.GetAddressOf())); - - // Present the backbuffer using the PresentX API. - D3D12XBOX_PRESENT_PLANE_PARAMETERS planeParameters = {}; - planeParameters.Token = m_framePipelineToken; - planeParameters.ResourceCount = 1; - planeParameters.ppResources = m_renderTargets[m_backBufferIndex].GetAddressOf(); - - ThrowIfFailed( - m_commandQueue->PresentX(1, &planeParameters, nullptr) - ); - - // Xbox One apps do not need to handle DXGI_ERROR_DEVICE_REMOVED or DXGI_ERROR_DEVICE_RESET. - - MoveToNextFrame(); -} - -// Handle GPU suspend/resume -void DeviceResources::Suspend() -{ - m_commandQueue->SuspendX(0); -} - -void DeviceResources::Resume() -{ - m_commandQueue->ResumeX(); - - RegisterFrameEvents(); -} - -// Wait for pending GPU work to complete. -void DeviceResources::WaitForGpu() noexcept -{ - if (m_commandQueue && m_fence && m_fenceEvent.IsValid()) - { - // Schedule a Signal command in the GPU queue. - UINT64 fenceValue = m_fenceValues[m_backBufferIndex]; - if (SUCCEEDED(m_commandQueue->Signal(m_fence.Get(), fenceValue))) - { - // Wait until the Signal has been processed. - if (SUCCEEDED(m_fence->SetEventOnCompletion(fenceValue, m_fenceEvent.Get()))) - { - WaitForSingleObjectEx(m_fenceEvent.Get(), INFINITE, FALSE); - - // Increment the fence value for the current frame. - m_fenceValues[m_backBufferIndex]++; - } - } - } -} - -// Prepare to render the next frame. -void DeviceResources::MoveToNextFrame() -{ - // Schedule a Signal command in the queue. - const UINT64 currentFenceValue = m_fenceValues[m_backBufferIndex]; - ThrowIfFailed(m_commandQueue->Signal(m_fence.Get(), currentFenceValue)); - - // Update the back buffer index. - m_backBufferIndex = (m_backBufferIndex + 1) % m_backBufferCount; - - // If the next frame is not ready to be rendered yet, wait until it is ready. - if (m_fence->GetCompletedValue() < m_fenceValues[m_backBufferIndex]) - { - ThrowIfFailed(m_fence->SetEventOnCompletion(m_fenceValues[m_backBufferIndex], m_fenceEvent.Get())); - WaitForSingleObjectEx(m_fenceEvent.Get(), INFINITE, FALSE); - } - - // Set the fence value for the next frame. - m_fenceValues[m_backBufferIndex] = currentFenceValue + 1; -} - -// Set frame interval and register for frame events -void DeviceResources::RegisterFrameEvents() -{ - // First, retrieve the underlying DXGI device from the D3D device. - ComPtr dxgiDevice; - ThrowIfFailed(m_d3dDevice.As(&dxgiDevice)); - - // Identify the physical adapter (GPU or card) this device is running on. - ComPtr dxgiAdapter; - ThrowIfFailed(dxgiDevice->GetAdapter(dxgiAdapter.GetAddressOf())); - - // Retrieve the outputs for the adapter. - ComPtr dxgiOutput; - ThrowIfFailed(dxgiAdapter->EnumOutputs(0, dxgiOutput.GetAddressOf())); - - // Set frame interval and register for frame events - ThrowIfFailed(m_d3dDevice->SetFrameIntervalX( - dxgiOutput.Get(), - D3D12XBOX_FRAME_INTERVAL_60_HZ, - m_backBufferCount - 1u /* Allow n-1 frames of latency */, - D3D12XBOX_FRAME_INTERVAL_FLAG_NONE)); - - ThrowIfFailed(m_d3dDevice->ScheduleFrameEventX( - D3D12XBOX_FRAME_EVENT_ORIGIN, - 0U, - nullptr, - D3D12XBOX_SCHEDULE_FRAME_EVENT_FLAG_NONE)); -} diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/DeviceResources.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/DeviceResources.h deleted file mode 100644 index 47f5f3eec..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/DeviceResources.h +++ /dev/null @@ -1,112 +0,0 @@ -// -// DeviceResources.h - A wrapper for the Direct3D 12.X device and swapchain -// - -#pragma once - -namespace DX -{ - // Controls all the DirectX device resources. - class DeviceResources - { - public: - static const unsigned int c_Enable4K_UHD = 0x1; - - DeviceResources(DXGI_FORMAT backBufferFormat = DXGI_FORMAT_B8G8R8A8_UNORM, - DXGI_FORMAT depthBufferFormat = DXGI_FORMAT_D32_FLOAT, - UINT backBufferCount = 2, - unsigned int flags = 0) noexcept(false); - ~DeviceResources(); - - DeviceResources(DeviceResources&&) = default; - DeviceResources& operator= (DeviceResources&&) = default; - - DeviceResources(DeviceResources const&) = delete; - DeviceResources& operator= (DeviceResources const&) = delete; - - void CreateDeviceResources(); - void CreateWindowSizeDependentResources(); - void SetWindow(HWND window) noexcept { m_window = window; } - void Prepare(D3D12_RESOURCE_STATES beforeState = D3D12_RESOURCE_STATE_PRESENT, - D3D12_RESOURCE_STATES afterState = D3D12_RESOURCE_STATE_RENDER_TARGET); - void Present(D3D12_RESOURCE_STATES beforeState = D3D12_RESOURCE_STATE_RENDER_TARGET); - void Suspend(); - void Resume(); - void WaitForGpu() noexcept; - - // Device Accessors. - RECT GetOutputSize() const noexcept { return m_outputSize; } - - // Direct3D Accessors. - auto GetD3DDevice() const noexcept { return m_d3dDevice.Get(); } - HWND GetWindow() const noexcept { return m_window; } - D3D_FEATURE_LEVEL GetDeviceFeatureLevel() const noexcept { return m_d3dFeatureLevel; } - ID3D12Resource* GetRenderTarget() const noexcept { return m_renderTargets[m_backBufferIndex].Get(); } - ID3D12Resource* GetDepthStencil() const noexcept { return m_depthStencil.Get(); } - ID3D12CommandQueue* GetCommandQueue() const noexcept { return m_commandQueue.Get(); } - ID3D12CommandAllocator* GetCommandAllocator() const noexcept { return m_commandAllocators[m_backBufferIndex].Get(); } - auto GetCommandList() const noexcept { return m_commandList.Get(); } - DXGI_FORMAT GetBackBufferFormat() const noexcept { return m_backBufferFormat; } - DXGI_FORMAT GetDepthBufferFormat() const noexcept { return m_depthBufferFormat; } - D3D12_VIEWPORT GetScreenViewport() const noexcept { return m_screenViewport; } - D3D12_RECT GetScissorRect() const noexcept { return m_scissorRect; } - UINT GetCurrentFrameIndex() const noexcept { return m_backBufferIndex; } - UINT GetBackBufferCount() const noexcept { return m_backBufferCount; } - unsigned int GetDeviceOptions() const noexcept { return m_options; } - - CD3DX12_CPU_DESCRIPTOR_HANDLE GetRenderTargetView() const noexcept - { - return CD3DX12_CPU_DESCRIPTOR_HANDLE( - m_rtvDescriptorHeap->GetCPUDescriptorHandleForHeapStart(), - static_cast(m_backBufferIndex), m_rtvDescriptorSize); - } - CD3DX12_CPU_DESCRIPTOR_HANDLE GetDepthStencilView() const noexcept - { - return CD3DX12_CPU_DESCRIPTOR_HANDLE(m_dsvDescriptorHeap->GetCPUDescriptorHandleForHeapStart()); - } - - private: - void MoveToNextFrame(); - void RegisterFrameEvents(); - - static const size_t MAX_BACK_BUFFER_COUNT = 3; - - UINT m_backBufferIndex; - - // Direct3D objects. - Microsoft::WRL::ComPtr m_d3dDevice; - Microsoft::WRL::ComPtr m_commandList; - Microsoft::WRL::ComPtr m_commandQueue; - Microsoft::WRL::ComPtr m_commandAllocators[MAX_BACK_BUFFER_COUNT]; - - // Swap chain objects. - Microsoft::WRL::ComPtr m_renderTargets[MAX_BACK_BUFFER_COUNT]; - Microsoft::WRL::ComPtr m_depthStencil; - - // Presentation fence objects. - Microsoft::WRL::ComPtr m_fence; - UINT64 m_fenceValues[MAX_BACK_BUFFER_COUNT]; - Microsoft::WRL::Wrappers::Event m_fenceEvent; - D3D12XBOX_FRAME_PIPELINE_TOKEN m_framePipelineToken; - - // Direct3D rendering objects. - Microsoft::WRL::ComPtr m_rtvDescriptorHeap; - Microsoft::WRL::ComPtr m_dsvDescriptorHeap; - UINT m_rtvDescriptorSize; - D3D12_VIEWPORT m_screenViewport; - D3D12_RECT m_scissorRect; - - // Direct3D properties. - DXGI_FORMAT m_backBufferFormat; - DXGI_FORMAT m_depthBufferFormat; - UINT m_backBufferCount; - - // Cached device properties. - HWND m_window; - D3D_FEATURE_LEVEL m_d3dFeatureLevel; - RECT m_outputSize; - - // DeviceResources options (see flags above) - unsigned int m_options; - }; -} diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/StepTimer.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/StepTimer.h deleted file mode 100644 index c9fa16197..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/StepTimer.h +++ /dev/null @@ -1,189 +0,0 @@ -// -// StepTimer.h - A simple timer that provides elapsed time information -// - -#pragma once - -#include -#include -#include - -namespace DX -{ - // Helper class for animation and simulation timing. - class StepTimer - { - public: - StepTimer() noexcept(false) : - m_elapsedTicks(0), - m_totalTicks(0), - m_leftOverTicks(0), - m_frameCount(0), - m_framesPerSecond(0), - m_framesThisSecond(0), - m_qpcSecondCounter(0), - m_isFixedTimeStep(false), - m_targetElapsedTicks(TicksPerSecond / 60) - { - if (!QueryPerformanceFrequency(&m_qpcFrequency)) - { - throw std::exception( "QueryPerformanceFrequency" ); - } - - if (!QueryPerformanceCounter(&m_qpcLastTime)) - { - throw std::exception( "QueryPerformanceCounter" ); - } - - // Initialize max delta to 1/10 of a second. - m_qpcMaxDelta = static_cast(m_qpcFrequency.QuadPart / 10); - } - - // Get elapsed time since the previous Update call. - uint64_t GetElapsedTicks() const noexcept { return m_elapsedTicks; } - double GetElapsedSeconds() const noexcept { return TicksToSeconds(m_elapsedTicks); } - - // Get total time since the start of the program. - uint64_t GetTotalTicks() const noexcept { return m_totalTicks; } - double GetTotalSeconds() const noexcept { return TicksToSeconds(m_totalTicks); } - - // Get total number of updates since start of the program. - uint32_t GetFrameCount() const noexcept { return m_frameCount; } - - // Get the current framerate. - uint32_t GetFramesPerSecond() const noexcept { return m_framesPerSecond; } - - // Set whether to use fixed or variable timestep mode. - void SetFixedTimeStep(bool isFixedTimestep) noexcept { m_isFixedTimeStep = isFixedTimestep; } - - // Set how often to call Update when in fixed timestep mode. - void SetTargetElapsedTicks(uint64_t targetElapsed) noexcept { m_targetElapsedTicks = targetElapsed; } - void SetTargetElapsedSeconds(double targetElapsed) noexcept { m_targetElapsedTicks = SecondsToTicks(targetElapsed); } - - // Integer format represents time using 10,000,000 ticks per second. - static const uint64_t TicksPerSecond = 10000000; - - static constexpr double TicksToSeconds(uint64_t ticks) noexcept { return static_cast(ticks) / TicksPerSecond; } - static constexpr uint64_t SecondsToTicks(double seconds) noexcept { return static_cast(seconds * TicksPerSecond); } - - // After an intentional timing discontinuity (for instance a blocking IO operation) - // call this to avoid having the fixed timestep logic attempt a set of catch-up - // Update calls. - - void ResetElapsedTime() - { - if (!QueryPerformanceCounter(&m_qpcLastTime)) - { - throw std::exception("QueryPerformanceCounter"); - } - - m_leftOverTicks = 0; - m_framesPerSecond = 0; - m_framesThisSecond = 0; - m_qpcSecondCounter = 0; - } - - // Update timer state, calling the specified Update function the appropriate number of times. - template - void Tick(const TUpdate& update) - { - // Query the current time. - LARGE_INTEGER currentTime; - - if (!QueryPerformanceCounter(¤tTime)) - { - throw std::exception( "QueryPerformanceCounter" ); - } - - uint64_t timeDelta = static_cast(currentTime.QuadPart - m_qpcLastTime.QuadPart); - - m_qpcLastTime = currentTime; - m_qpcSecondCounter += timeDelta; - - // Clamp excessively large time deltas (e.g. after paused in the debugger). - if (timeDelta > m_qpcMaxDelta) - { - timeDelta = m_qpcMaxDelta; - } - - // Convert QPC units into a canonical tick format. This cannot overflow due to the previous clamp. - timeDelta *= TicksPerSecond; - timeDelta /= static_cast(m_qpcFrequency.QuadPart); - - uint32_t lastFrameCount = m_frameCount; - - if (m_isFixedTimeStep) - { - // Fixed timestep update logic - - // If the app is running very close to the target elapsed time (within 1/4 of a millisecond) just clamp - // the clock to exactly match the target value. This prevents tiny and irrelevant errors - // from accumulating over time. Without this clamping, a game that requested a 60 fps - // fixed update, running with vsync enabled on a 59.94 NTSC display, would eventually - // accumulate enough tiny errors that it would drop a frame. It is better to just round - // small deviations down to zero to leave things running smoothly. - - if (static_cast(std::abs(static_cast(timeDelta - m_targetElapsedTicks))) < TicksPerSecond / 4000) - { - timeDelta = m_targetElapsedTicks; - } - - m_leftOverTicks += timeDelta; - - while (m_leftOverTicks >= m_targetElapsedTicks) - { - m_elapsedTicks = m_targetElapsedTicks; - m_totalTicks += m_targetElapsedTicks; - m_leftOverTicks -= m_targetElapsedTicks; - m_frameCount++; - - update(); - } - } - else - { - // Variable timestep update logic. - m_elapsedTicks = timeDelta; - m_totalTicks += timeDelta; - m_leftOverTicks = 0; - m_frameCount++; - - update(); - } - - // Track the current framerate. - if (m_frameCount != lastFrameCount) - { - m_framesThisSecond++; - } - - if (m_qpcSecondCounter >= static_cast(m_qpcFrequency.QuadPart)) - { - m_framesPerSecond = m_framesThisSecond; - m_framesThisSecond = 0; - m_qpcSecondCounter %= static_cast(m_qpcFrequency.QuadPart); - } - } - - private: - // Source timing data uses QPC units. - LARGE_INTEGER m_qpcFrequency; - LARGE_INTEGER m_qpcLastTime; - uint64_t m_qpcMaxDelta; - - // Derived timing data uses a canonical tick format. - uint64_t m_elapsedTicks; - uint64_t m_totalTicks; - uint64_t m_leftOverTicks; - - // Members for tracking the framerate. - uint32_t m_frameCount; - uint32_t m_framesPerSecond; - uint32_t m_framesThisSecond; - uint64_t m_qpcSecondCounter; - - // Members for configuring fixed timestep mode. - bool m_isFixedTimeStep; - uint64_t m_targetElapsedTicks; - }; -} diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/half.hpp b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/half.hpp new file mode 100644 index 000000000..d0a882dd6 --- /dev/null +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/half.hpp @@ -0,0 +1,4601 @@ +// half - IEEE 754-based half-precision floating-point library. +// +// Copyright (c) 2012-2021 Christian Rau +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Version 2.2.0 + +/// \file +/// Main header file for half-precision functionality. + +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +#define HALF_GCC_VERSION (__GNUC__*100+__GNUC_MINOR__) + +#if defined(__INTEL_COMPILER) + #define HALF_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICC) + #define HALF_ICC_VERSION __ICC +#elif defined(__ICL) + #define HALF_ICC_VERSION __ICL +#else + #define HALF_ICC_VERSION 0 +#endif + +// check C++11 language features +#if defined(__clang__) // clang + #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ + #if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif +#elif defined(__GNUC__) // gcc + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L + #if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #endif + #define HALF_TWOS_COMPLEMENT_INT 1 +#elif defined(_MSC_VER) // Visual C++ + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) + #define HALF_ENABLE_CPP11_THREAD_LOCAL 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) + #define HALF_ENABLE_CPP11_USER_LITERALS 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) + #define HALF_ENABLE_CPP11_CONSTEXPR 1 + #endif + #if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) + #define HALF_ENABLE_CPP11_NOEXCEPT 1 + #endif + #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 + #endif + #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) + #define HALF_ENABLE_CPP11_LONG_LONG 1 + #endif + #define HALF_TWOS_COMPLEMENT_INT 1 + #define HALF_POP_WARNINGS 1 + #pragma warning(push) + #pragma warning(disable : 4099 4127 4146) //struct vs class, constant in if, negative unsigned +#endif + +// check C++11 library features +#include +#if defined(_LIBCPP_VERSION) // libc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #ifndef HALF_ENABLE_CPP11_CSTDINT + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #ifndef HALF_ENABLE_CPP11_CMATH + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #ifndef HALF_ENABLE_CPP11_HASH + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #ifndef HALF_ENABLE_CPP11_CFENV + #define HALF_ENABLE_CPP11_CFENV 1 + #endif + #endif +#elif defined(__GLIBCXX__) // libstdc++ + #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 + #ifdef __clang__ + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) + #define HALF_ENABLE_CPP11_CFENV 1 + #endif + #else + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) + #define HALF_ENABLE_CPP11_CFENV 1 + #endif + #endif + #endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ + #if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) + #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 + #endif + #if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) + #define HALF_ENABLE_CPP11_CSTDINT 1 + #endif + #if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) + #define HALF_ENABLE_CPP11_HASH 1 + #endif + #if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) + #define HALF_ENABLE_CPP11_CMATH 1 + #endif + #if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) + #define HALF_ENABLE_CPP11_CFENV 1 + #endif +#endif +#undef HALF_GCC_VERSION +#undef HALF_ICC_VERSION + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? +#define HALF_ERRHANDLING (HALF_ERRHANDLING_FLAGS||HALF_ERRHANDLING_ERRNO||HALF_ERRHANDLING_FENV||HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING + #define HALF_UNUSED_NOERR(name) name +#else + #define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR + #define HALF_CONSTEXPR constexpr + #define HALF_CONSTEXPR_CONST constexpr + #if HALF_ERRHANDLING + #define HALF_CONSTEXPR_NOERR + #else + #define HALF_CONSTEXPR_NOERR constexpr + #endif +#else + #define HALF_CONSTEXPR + #define HALF_CONSTEXPR_CONST const + #define HALF_CONSTEXPR_NOERR +#endif + +// support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT + #define HALF_NOEXCEPT noexcept + #define HALF_NOTHROW noexcept +#else + #define HALF_NOEXCEPT + #define HALF_NOTHROW throw() +#endif + +// support thread storage +#if HALF_ENABLE_CPP11_THREAD_LOCAL + #define HALF_THREAD_LOCAL thread_local +#else + #define HALF_THREAD_LOCAL static +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS + #include +#endif +#if HALF_ENABLE_CPP11_CSTDINT + #include +#endif +#if HALF_ERRHANDLING_ERRNO + #include +#endif +#if HALF_ENABLE_CPP11_CFENV + #include +#endif +#if HALF_ENABLE_CPP11_HASH + #include +#endif + + +#ifndef HALF_ENABLE_F16C_INTRINSICS + /// Enable F16C intruction set intrinsics. + /// Defining this to 1 enables the use of [F16C compiler intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between + /// half-precision and single-precision values which may result in improved performance. This will not perform additional checks + /// for support of the F16C instruction set, so an appropriate target platform is required when enabling this feature. + /// + /// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which some compilers do on supporting platforms. + #define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif +#if HALF_ENABLE_F16C_INTRINSICS + #include +#endif + +#ifdef HALF_DOXYGEN_ONLY +/// Type for internal floating-point computations. +/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to override the internal +/// half-precision implementation to use this type for computing arithmetic operations and mathematical function (if available). +/// This can result in improved performance for arithmetic operators and mathematical functions but might cause results to +/// deviate from the specified half-precision rounding mode and inhibits proper detection of half-precision exceptions. +#define HALF_ARITHMETIC_TYPE (undefined) + +/// Enable internal exception flags. +/// Defining this to 1 causes operations on half-precision values to raise internal floating-point exception flags according to +/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). +#define HALF_ERRHANDLING_FLAGS 0 + +/// Enable exception propagation to `errno`. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point exceptions to +/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will propagate domain errors as +/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow errors as +/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be propagated. +#define HALF_ERRHANDLING_ERRNO 0 + +/// Enable exception propagation to built-in floating-point platform. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point exceptions to the built-in +/// single- and double-precision implementation's exception flags using the +/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from ``. However, this +/// does not work in reverse and single- or double-precision exceptions will not raise the corresponding half-precision +/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. +#define HALF_ERRHANDLING_FENV 0 + +/// Throw C++ exception on domain errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified message on domain errors. +#define HALF_ERRHANDLING_THROW_INVALID (undefined) + +/// Throw C++ exception on pole errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified message on pole errors. +#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) + +/// Throw C++ exception on overflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified message on overflows. +#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) + +/// Throw C++ exception on underflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the specified message on underflows. +#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) + +/// Throw C++ exception on rounding errors. +/// Defining this to 1 causes operations on half-precision values to throw a +/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified message on general rounding errors. +#define HALF_ERRHANDLING_THROW_INEXACT (undefined) +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in addition. +/// These will be raised after any possible handling of the underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. +/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions in addition. +/// These will be raised after any possible handling of the underflow exception. +/// +/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be raised *only* when the result +/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) subnormal result. +#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 +#endif + +/// Default rounding mode. +/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and more precise types +/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic operations and mathematical +/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes using their respective +/// constants or the equivalent values of +/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest (default) +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest representable value. It can even +/// be set to [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) to synchronize +/// the rounding mode with that of the built-in single-precision implementation (which is likely `std::round_to_nearest`, though). +#ifndef HALF_ROUND_STYLE + #define HALF_ROUND_STYLE 1 // = std::round_to_nearest +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +/// +/// **See also:** Documentation for [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a separate +/// half-precision multiplication followed by an addition, which is always the case. +/// +/// **See also:** Documentation for [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) +#define FP_FAST_FMAH 1 + +/// Half rounding mode. +/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode used for +/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). +/// +/// **See also:** Documentation for [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 + #define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN + #define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL + #define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO + #define FP_ZERO 1 +#endif +#ifndef FP_NAN + #define FP_NAN 2 +#endif +#ifndef FP_INFINITE + #define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL + #define FP_NORMAL 4 +#endif + +#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) + #define FE_INVALID 0x10 + #define FE_DIVBYZERO 0x08 + #define FE_OVERFLOW 0x04 + #define FE_UNDERFLOW 0x02 + #define FE_INEXACT 0x01 + #define FE_ALL_EXCEPT (FE_INVALID|FE_DIVBYZERO|FE_OVERFLOW|FE_UNDERFLOW|FE_INEXACT) +#endif + + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float +{ + class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS + /// Library-defined half-precision literals. + /// Import this namespace to enable half-precision floating-point literals: + /// ~~~~{.cpp} + /// using namespace half_float::literal; + /// half_float::half = 4.2_h; + /// ~~~~ + namespace literal + { + half operator "" _h(long double); + } +#endif + + /// \internal + /// \brief Implementation details. + namespace detail + { + #if HALF_ENABLE_CPP11_TYPE_TRAITS + /// Conditional type. + template struct conditional : std::conditional {}; + + /// Helper for tag dispatching. + template struct bool_type : std::integral_constant {}; + using std::true_type; + using std::false_type; + + /// Type traits for floating-point types. + template struct is_float : std::is_floating_point {}; + #else + /// Conditional type. + template struct conditional { typedef T type; }; + template struct conditional { typedef F type; }; + + /// Helper for tag dispatching. + template struct bool_type {}; + typedef bool_type true_type; + typedef bool_type false_type; + + /// Type traits for floating-point types. + template struct is_float : false_type {}; + template struct is_float : is_float {}; + template struct is_float : is_float {}; + template struct is_float : is_float {}; + template<> struct is_float : true_type {}; + template<> struct is_float : true_type {}; + template<> struct is_float : true_type {}; + #endif + + /// Type traits for floating-point bits. + template struct bits { typedef unsigned char type; }; + template struct bits : bits {}; + template struct bits : bits {}; + template struct bits : bits {}; + + #if HALF_ENABLE_CPP11_CSTDINT + /// Unsigned integer of (at least) 16 bits width. + typedef std::uint_least16_t uint16; + + /// Fastest unsigned integer of (at least) 32 bits width. + typedef std::uint_fast32_t uint32; + + /// Fastest signed integer of (at least) 32 bits width. + typedef std::int_fast32_t int32; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits { typedef std::uint_least32_t type; }; + + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { typedef std::uint_least64_t type; }; + #else + /// Unsigned integer of (at least) 16 bits width. + typedef unsigned short uint16; + + /// Fastest unsigned integer of (at least) 32 bits width. + typedef unsigned long uint32; + + /// Fastest unsigned integer of (at least) 32 bits width. + typedef long int32; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits : conditional::digits>=32,unsigned int,unsigned long> {}; + + #if HALF_ENABLE_CPP11_LONG_LONG + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits : conditional::digits>=64,unsigned long,unsigned long long> {}; + #else + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { typedef unsigned long type; }; + #endif + #endif + + #ifdef HALF_ARITHMETIC_TYPE + /// Type to use for arithmetic computations and mathematic functions internally. + typedef HALF_ARITHMETIC_TYPE internal_t; + #endif + + /// Tag type for binary construction. + struct binary_t {}; + + /// Tag for binary construction. + HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + + /// \name Implementation defined classification and arithmetic + /// \{ + + /// Check for infinity. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if infinity + /// \retval false else + template bool builtin_isinf(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); + #elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); + #else + return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); + #endif + } + + /// Check for NaN. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if not a number + /// \retval false else + template bool builtin_isnan(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); + #elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; + #else + return arg != arg; + #endif + } + + /// Check sign. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if signbit set + /// \retval false else + template bool builtin_signbit(T arg) + { + #if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); + #else + return arg < T() || (arg == T() && T(1)/arg < T()); + #endif + } + + /// Platform-independent sign mask. + /// \param arg integer value in two's complement + /// \retval -1 if \a arg negative + /// \retval 0 if \a arg positive + inline uint32 sign_mask(uint32 arg) + { + static const int N = std::numeric_limits::digits - 1; + #if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; + #else + return -((arg>>N)&1); + #endif + } + + /// Platform-independent arithmetic right shift. + /// \param arg integer value in two's complement + /// \param i shift amount (at most 31) + /// \return \a arg right shifted for \a i bits with possible sign extension + inline uint32 arithmetic_shift(uint32 arg, int i) + { + #if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; + #else + return static_cast(arg)/(static_cast(1)<>(std::numeric_limits::digits-1))&1); + #endif + } + + /// \} + /// \name Error handling + /// \{ + + /// Internal exception flags. + /// \return reference to global exception flags + inline int& errflags() { HALF_THREAD_LOCAL int flags = 0; return flags; } + + /// Raise floating-point exception. + /// \param flags exceptions to raise + /// \param cond condition to raise exceptions for + inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) + { + #if HALF_ERRHANDLING + if(!cond) + return; + #if HALF_ERRHANDLING_FLAGS + errflags() |= flags; + #endif + #if HALF_ERRHANDLING_ERRNO + if(flags & FE_INVALID) + errno = EDOM; + else if(flags & (FE_DIVBYZERO|FE_OVERFLOW|FE_UNDERFLOW)) + errno = ERANGE; + #endif + #if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV + std::feraiseexcept(flags); + #endif + #ifdef HALF_ERRHANDLING_THROW_INVALID + if(flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); + #endif + #ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if(flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); + #endif + #ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if(flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); + #endif + #ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if(flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); + #endif + #ifdef HALF_ERRHANDLING_THROW_INEXACT + if(flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); + #endif + #if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); + #endif + #if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); + #endif + #endif + } + + /// Check and signal for any NaN. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \retval true if either \a x or \a y is NaN + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, (x&0x7FFF)>0x7C00 || (y&0x7FFF)>0x7C00); + #endif + return (x&0x7FFF) > 0x7C00 || (y&0x7FFF) > 0x7C00; + } + + /// Signal and silence signaling NaN. + /// \param nan half-precision NaN value + /// \return quiet NaN + /// \exception FE_INVALID if \a nan is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, !(nan&0x200)); + #endif + return nan | 0x200; + } + + /// Signal and silence signaling NaNs. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \return quiet NaN + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, ((x&0x7FFF)>0x7C00 && !(x&0x200)) || ((y&0x7FFF)>0x7C00 && !(y&0x200))); + #endif + return ((x&0x7FFF)>0x7C00) ? (x|0x200) : (y|0x200); + } + + /// Signal and silence signaling NaNs. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \param z third half-precision value to check + /// \return quiet NaN + /// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) + { + #if HALF_ERRHANDLING + raise(FE_INVALID, ((x&0x7FFF)>0x7C00 && !(x&0x200)) || ((y&0x7FFF)>0x7C00 && !(y&0x200)) || ((z&0x7FFF)>0x7C00 && !(z&0x200))); + #endif + return ((x&0x7FFF)>0x7C00) ? (x|0x200) : ((y&0x7FFF)>0x7C00) ? (y|0x200) : (z|0x200); + } + + /// Select value or signaling NaN. + /// \param x preferred half-precision value + /// \param y ignored half-precision value except for signaling NaN + /// \return \a y if signaling NaN, \a x otherwise + /// \exception FE_INVALID if \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) + { + #if HALF_ERRHANDLING + return (((y&0x7FFF)>0x7C00) && !(y&0x200)) ? signal(y) : x; + #else + return x; + #endif + } + + /// Raise domain error and return NaN. + /// return quiet NaN + /// \exception FE_INVALID + inline HALF_CONSTEXPR_NOERR unsigned int invalid() + { + #if HALF_ERRHANDLING + raise(FE_INVALID); + #endif + return 0x7FFF; + } + + /// Raise pole error and return infinity. + /// \param sign half-precision value with sign bit only + /// \return half-precision infinity with sign of \a sign + /// \exception FE_DIVBYZERO + inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) + { + #if HALF_ERRHANDLING + raise(FE_DIVBYZERO); + #endif + return sign | 0x7C00; + } + + /// Check value for underflow. + /// \param arg non-zero half-precision value to check + /// \return \a arg + /// \exception FE_UNDERFLOW if arg is subnormal + inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) + { + #if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + raise(FE_UNDERFLOW, !(arg&0x7C00)); + #endif + return arg; + } + + /// \} + /// \name Conversion and rounding + /// \{ + + /// Half-precision overflow. + /// \tparam R rounding mode to use + /// \param sign half-precision value with sign bit only + /// \return rounded overflowing half-precision value + /// \exception FE_OVERFLOW + template HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) + { + #if HALF_ERRHANDLING + raise(FE_OVERFLOW); + #endif + return (R==std::round_toward_infinity) ? (sign+0x7C00-(sign>>15)) : + (R==std::round_toward_neg_infinity) ? (sign+0x7BFF+(sign>>15)) : + (R==std::round_toward_zero) ? (sign|0x7BFF) : + (sign|0x7C00); + } + + /// Half-precision underflow. + /// \tparam R rounding mode to use + /// \param sign half-precision value with sign bit only + /// \return rounded underflowing half-precision value + /// \exception FE_UNDERFLOW + template HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) + { + #if HALF_ERRHANDLING + raise(FE_UNDERFLOW); + #endif + return (R==std::round_toward_infinity) ? (sign+1-(sign>>15)) : + (R==std::round_toward_neg_infinity) ? (sign+(sign>>15)) : + sign; + } + + /// Round half-precision number. + /// \tparam R rounding mode to use + /// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results + /// \param value finite half-precision number to round + /// \param g guard bit (most significant discarded bit) + /// \param s sticky bit (or of all but the most significant discarded bits) + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded or \a I is `true` + template HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) + { + #if HALF_ERRHANDLING + value += (R==std::round_to_nearest) ? (g&(s|value)) : + (R==std::round_toward_infinity) ? (~(value>>15)&(g|s)) : + (R==std::round_toward_neg_infinity) ? ((value>>15)&(g|s)) : 0; + if((value&0x7C00) == 0x7C00) + raise(FE_OVERFLOW); + else if(value & 0x7C00) + raise(FE_INEXACT, I || (g|s)!=0); + else + raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g|s)!=0); + return value; + #else + return (R==std::round_to_nearest) ? (value+(g&(s|value))) : + (R==std::round_toward_infinity) ? (value+(~(value>>15)&(g|s))) : + (R==std::round_toward_neg_infinity) ? (value+((value>>15)&(g|s))) : + value; + #endif + } + + /// Round half-precision number to nearest integer value. + /// \tparam R rounding mode to use + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it + /// \param value half-precision value to round + /// \return half-precision bits for nearest integral value + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded and \a I is `true` + template unsigned int integral(unsigned int value) + { + unsigned int abs = value & 0x7FFF; + if(abs < 0x3C00) + { + raise(FE_INEXACT, I); + return ((R==std::round_to_nearest) ? (0x3C00&-static_cast(abs>=(0x3800+E))) : + (R==std::round_toward_infinity) ? (0x3C00&-(~(value>>15)&(abs!=0))) : + (R==std::round_toward_neg_infinity) ? (0x3C00&-static_cast(value>0x8000)) : + 0) | (value&0x8000); + } + if(abs >= 0x6400) + return (abs>0x7C00) ? signal(value) : value; + unsigned int exp = 25 - (abs>>10), mask = (1<>exp)&E)) : + (R==std::round_toward_infinity) ? (mask&((value>>15)-1)) : + (R==std::round_toward_neg_infinity) ? (mask&-(value>>15)) : + 0) + value) & ~mask; + } + + /// Convert fixed point to half-precision floating-point. + /// \tparam R rounding mode to use + /// \tparam F number of fractional bits in [11,31] + /// \tparam S `true` for signed, `false` for unsigned + /// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F + /// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results + /// \param m mantissa in Q1.F fixed point format + /// \param exp biased exponent - 1 + /// \param sign half-precision value with sign bit only + /// \param s sticky bit (or of all but the most significant already discarded bits) + /// \return value converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded or \a I is `true` + template unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) + { + if(S) + { + uint32 msign = sign_mask(m); + m = (m^msign) - msign; + sign = msign & 0x8000; + } + if(N) + for(; m<(static_cast(1)<(sign+(m>>(F-10-exp)), (m>>(F-11-exp))&1, s|((m&((static_cast(1)<<(F-11-exp))-1))!=0)); + return rounded(sign+(exp<<10)+(m>>(F-10)), (m>>(F-11))&1, s|((m&((static_cast(1)<<(F-11))-1))!=0)); + } + + /// Convert IEEE single-precision to half-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \tparam R rounding mode to use + /// \param value single-precision value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(float value, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), + (R==std::round_to_nearest) ? _MM_FROUND_TO_NEAREST_INT : + (R==std::round_toward_zero) ? _MM_FROUND_TO_ZERO : + (R==std::round_toward_infinity) ? _MM_FROUND_TO_POS_INF : + (R==std::round_toward_neg_infinity) ? _MM_FROUND_TO_NEG_INF : + _MM_FROUND_CUR_DIRECTION)); + #else + bits::type fbits; + std::memcpy(&fbits, &value, sizeof(float)); + #if 1 + unsigned int sign = (fbits>>16) & 0x8000; + fbits &= 0x7FFFFFFF; + if(fbits >= 0x7F800000) + return sign | 0x7C00 | ((fbits>0x7F800000) ? (0x200|((fbits>>13)&0x3FF)) : 0); + if(fbits >= 0x47800000) + return overflow(sign); + if(fbits >= 0x38800000) + return rounded(sign|(((fbits>>23)-112)<<10)|((fbits>>13)&0x3FF), (fbits>>12)&1, (fbits&0xFFF)!=0); + if(fbits >= 0x33000000) + { + int i = 125 - (fbits>>23); + fbits = (fbits&0x7FFFFF) | 0x800000; + return rounded(sign|(fbits>>(i+1)), (fbits>>i)&1, (fbits&((static_cast(1)<(sign); + return sign; + #else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, + 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, + 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7C00, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, + 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, + 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00 }; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13 }; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits|((exp!=0)<<23)) & -static_cast(exp!=0xFF); + return rounded(base_table[sexp]+(fbits>>i), (m>>(i-1))&1, (((static_cast(1)<<(i-1))-1)&m)!=0); + #endif + #endif + } + + /// Convert IEEE double-precision to half-precision. + /// \tparam R rounding mode to use + /// \param value double-precision value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(double value, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + if(R == std::round_indeterminate) + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); + #endif + bits::type dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi>>16) & 0x8000; + hi &= 0x7FFFFFFF; + if(hi >= 0x7FF00000) + return sign | 0x7C00 | ((dbits&0xFFFFFFFFFFFFF) ? (0x200|((hi>>10)&0x3FF)) : 0); + if(hi >= 0x40F00000) + return overflow(sign); + if(hi >= 0x3F100000) + return rounded(sign|(((hi>>20)-1008)<<10)|((hi>>10)&0x3FF), (hi>>9)&1, ((hi&0x1FF)|lo)!=0); + if(hi >= 0x3E600000) + { + int i = 1018 - (hi>>20); + hi = (hi&0xFFFFF) | 0x100000; + return rounded(sign|(hi>>(i+1)), (hi>>i)&1, ((hi&((static_cast(1)<(sign); + return sign; + } + + /// Convert non-IEEE floating-point to half-precision. + /// \tparam R rounding mode to use + /// \tparam T source type (builtin floating-point type) + /// \param value floating-point value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(T value, ...) + { + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + return overflow(hbits); + if(exp < -13) + value = std::ldexp(value, 25); + else + { + value = std::ldexp(value, 12-exp); + hbits |= ((exp+13)<<10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits+(m>>1), m&1, frac!=T()); + } + + /// Convert floating-point to half-precision. + /// \tparam R rounding mode to use + /// \tparam T source type (builtin floating-point type) + /// \param value floating-point value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half(T value) + { + return float2half_impl(value, bool_type::is_iec559&&sizeof(typename bits::type)==sizeof(T)>()); + } + + /// Convert integer to half-precision floating-point. + /// \tparam R rounding mode to use + /// \tparam T type to convert (builtin integer type) + /// \param value integral value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int int2half(T value) + { + unsigned int bits = static_cast(value<0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m<0x400; m<<=1,--exp) ; + for(; m>0x7FF; m>>=1,++exp) ; + bits |= (exp<<10) + m; + return (exp>24) ? rounded(bits, (value>>(exp-25))&1, (((1<<(exp-25))-1)&value)!=0) : bits; + } + + /// Convert half-precision to IEEE single-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \param value half-precision value to convert + /// \return single-precision value + inline float half2float_impl(unsigned int value, float, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); + #else + #if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } + #else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, + 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, + 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, + 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, + 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, + 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, + 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000, + 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, + 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, + 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, + 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, + 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, + 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, + 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, + 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, + 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, + 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, + 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, + 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, + 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, + 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, + 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, + 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, + 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, + 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, + 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, + 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, + 0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, + 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, + 0x38140000, 0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, + 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, + 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, + 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, + 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, + 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, + 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, + 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, + 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, + 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, + 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, + 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, + 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, + 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, + 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, + 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, + 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, + 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 0x38674000, 0x38678000, 0x3867C000, + 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, + 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, + 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, + 0x38740000, 0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, + 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, + 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, + 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, + 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, + 0x38080000, 0x38082000, 0x38084000, 0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, + 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, + 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, + 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, + 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, + 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, + 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, + 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, + 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, + 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, + 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, + 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, + 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, + 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, + 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 0x3831E000, + 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, + 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, + 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, + 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, + 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, + 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, + 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, + 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, + 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, + 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, + 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, + 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, + 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, + 0x38560000, 0x38562000, 0x38564000, 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, + 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, + 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, + 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000, + 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, + 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, + 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, + 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, + 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, + 0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, + 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, + 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, + 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, + 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000 }; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, + 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, + 0x80000000, 0x80800000, 0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000 }; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024 }; + bits::type fbits = mantissa_table[offset_table[value>>10]+(value&0x3FF)] + exponent_table[value>>10]; + #endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; + #endif + } + + /// Convert half-precision to IEEE double-precision. + /// \param value half-precision value to convert + /// \return double-precision value + inline double half2float_impl(unsigned int value, double, true_type) + { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); + #else + uint32 hi = static_cast(value&0x8000) << 16; + unsigned int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,hi-=0x100000) ; + hi += static_cast(abs) << 10; + } + bits::type dbits = static_cast::type>(hi) << 32; + double out; + std::memcpy(&out, &dbits, sizeof(double)); + return out; + #endif + } + + /// Convert half-precision to non-IEEE floating-point. + /// \tparam T type to convert to (builtin integer type) + /// \param value half-precision value to convert + /// \return floating-point value + template T half2float_impl(unsigned int value, T, ...) + { + T out; + unsigned int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = (std::numeric_limits::has_signaling_NaN && !(abs&0x200)) ? std::numeric_limits::signaling_NaN() : + std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs&0x3FF)|0x400), (abs>>10)-25); + else + out = std::ldexp(static_cast(abs), -24); + return (value&0x8000) ? -out : out; + } + + /// Convert half-precision to floating-point. + /// \tparam T type to convert to (builtin integer type) + /// \param value half-precision value to convert + /// \return floating-point value + template T half2float(unsigned int value) + { + return half2float_impl(value, T(), bool_type::is_iec559&&sizeof(typename bits::type)==sizeof(T)>()); + } + + /// Convert half-precision floating-point to integer. + /// \tparam R rounding mode to use + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it + /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits) + /// \param value half-precision value to convert + /// \return rounded integer value + /// \exception FE_INVALID if value is not representable in type \a T + /// \exception FE_INEXACT if value had to be rounded and \a I is `true` + template T half2int(unsigned int value) + { + unsigned int abs = value & 0x7FFF; + if(abs >= 0x7C00) + { + raise(FE_INVALID); + return (value&0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); + } + if(abs < 0x3800) + { + raise(FE_INEXACT, I); + return (R==std::round_toward_infinity) ? T(~(value>>15)&(abs!=0)) : + (R==std::round_toward_neg_infinity) ? -T(value>0x8000) : + T(); + } + int exp = 25 - (abs>>10); + unsigned int m = (value&0x3FF) | 0x400; + int32 i = static_cast((exp<=0) ? (m<<-exp) : ((m+( + (R==std::round_to_nearest) ? ((1<<(exp-1))-(~(m>>exp)&E)) : + (R==std::round_toward_infinity) ? (((1<>15)-1)) : + (R==std::round_toward_neg_infinity) ? (((1<>15)) : 0))>>exp)); + if((!std::numeric_limits::is_signed && (value&0x8000)) || (std::numeric_limits::digits<16 && + ((value&0x8000) ? (-i::min()) : (i>std::numeric_limits::max())))) + raise(FE_INVALID); + else if(I && exp > 0 && (m&((1<((value&0x8000) ? -i : i); + } + + /// \} + /// \name Mathematics + /// \{ + + /// upper part of 64-bit multiplication. + /// \tparam R rounding mode to use + /// \param x first factor + /// \param y second factor + /// \return upper 32 bit of \a x * \a y + template uint32 mulhi(uint32 x, uint32 y) + { + uint32 xy = (x>>16) * (y&0xFFFF), yx = (x&0xFFFF) * (y>>16), c = (xy&0xFFFF) + (yx&0xFFFF) + (((x&0xFFFF)*(y&0xFFFF))>>16); + return (x>>16)*(y>>16) + (xy>>16) + (yx>>16) + (c>>16) + + ((R==std::round_to_nearest) ? ((c>>15)&1) : (R==std::round_toward_infinity) ? ((c&0xFFFF)!=0) : 0); + } + + /// 64-bit multiplication. + /// \param x first factor + /// \param y second factor + /// \return upper 32 bit of \a x * \a y rounded to nearest + inline uint32 multiply64(uint32 x, uint32 y) + { + #if HALF_ENABLE_CPP11_LONG_LONG + return static_cast((static_cast(x)*static_cast(y)+0x80000000)>>32); + #else + return mulhi(x, y); + #endif + } + + /// 64-bit division. + /// \param x upper 32 bit of dividend + /// \param y divisor + /// \param s variable to store sticky bit for rounding + /// \return (\a x << 32) / \a y + inline uint32 divide64(uint32 x, uint32 y, int &s) + { + #if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long xx = static_cast(x) << 32; + return s = (xx%y!=0), static_cast(xx/y); + #else + y >>= 1; + uint32 rem = x, div = 0; + for(unsigned int i=0; i<32; ++i) + { + div <<= 1; + if(rem >= y) + { + rem -= y; + div |= 1; + } + rem <<= 1; + } + return s = rem > 1, div; + #endif + } + + /// Half precision positive modulus. + /// \tparam Q `true` to compute full quotient, `false` else + /// \tparam R `true` to compute signed remainder, `false` for positive remainder + /// \param x first operand as positive finite half-precision value + /// \param y second operand as positive finite half-precision value + /// \param quo adress to store quotient at, `nullptr` if \a Q `false` + /// \return modulus of \a x / \a y + template unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) + { + unsigned int q = 0; + if(x > y) + { + int absx = x, absy = y, expx = 0, expy = 0; + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + for(int d=expx-expy; d; --d) + { + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + ++q; + } + if(Q) + { + q &= (1<<(std::numeric_limits::digits-1)) - 1; + if(!mx) + return *quo = q, 0; + } + for(; mx<0x400; mx<<=1,--expy) ; + x = (expy>0) ? ((expy<<10)|(mx&0x3FF)) : (mx>>(1-expy)); + } + if(R) + { + unsigned int a, b; + if(y < 0x800) + { + a = (x<0x400) ? (x<<1) : (x+0x400); + b = y; + } + else + { + a = x; + b = y - 0x400; + } + if(a > b || (a == b && (q&1))) + { + int exp = (y>>10) + (y<=0x3FF), d = exp - (x>>10) - (x<=0x3FF); + int m = (((y&0x3FF)|((y>0x3FF)<<10))<<1) - (((x&0x3FF)|((x>0x3FF)<<10))<<(1-d)); + for(; m<0x800 && exp>1; m<<=1,--exp) ; + x = 0x8000 + ((exp-1)<<10) + (m>>1); + q += Q; + } + } + if(Q) + *quo = q; + return x; + } + + /// Fixed point square root. + /// \tparam F number of fractional bits + /// \param r radicand in Q1.F fixed point format + /// \param exp exponent + /// \return square root as Q1.F/2 + template uint32 sqrt(uint32 &r, int &exp) + { + int i = exp & 1; + r <<= i; + exp = (exp-i) / 2; + uint32 m = 0; + for(uint32 bit=static_cast(1)<>=2) + { + if(r < m+bit) + m >>= 1; + else + { + r -= m + bit; + m = (m>>1) + bit; + } + } + return m; + } + + /// Fixed point binary exponential. + /// This uses the BKM algorithm in E-mode. + /// \param m exponent in [0,1) as Q0.31 + /// \param n number of iterations (at most 32) + /// \return 2 ^ \a m as Q1.31 + inline uint32 exp2(uint32 m, unsigned int n = 32) + { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, 0x016FE50B, + 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, + 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, 0x00000003, 0x00000001 }; + if(!m) + return 0x80000000; + uint32 mx = 0x80000000, my = 0; + for(unsigned int i=1; i> i; + } + } + return mx; + } + + /// Fixed point binary logarithm. + /// This uses the BKM algorithm in L-mode. + /// \param m mantissa in [1,2) as Q1.30 + /// \param n number of iterations (at most 32) + /// \return log2(\a m) as Q0.31 + inline uint32 log2(uint32 m, unsigned int n = 32) + { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, 0x016FE50B, + 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, + 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, 0x00000003, 0x00000001 }; + if(m == 0x40000000) + return 0; + uint32 mx = 0x40000000, my = 0; + for(unsigned int i=1; i>i); + if(mz <= m) + { + mx = mz; + my += logs[i]; + } + } + return my; + } + + /// Fixed point sine and cosine. + /// This uses the CORDIC algorithm in rotation mode. + /// \param mz angle in [-pi/2,pi/2] as Q1.30 + /// \param n number of iterations (at most 31) + /// \return sine and cosine of \a mz as Q1.30 + inline std::pair sincos(uint32 mz, unsigned int n = 31) + { + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, 0x007FFF55, + 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, 0x00010000, 0x00008000, + 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, 0x00000200, 0x00000100, 0x00000080, + 0x00000040, 0x00000020, 0x00000010, 0x00000008, 0x00000004, 0x00000002, 0x00000001 }; + uint32 mx = 0x26DD3B6A, my = 0; + for(unsigned int i=0; i0x3FF)<<10); + int exp = (abs>>10) + (abs<=0x3FF) - 15; + if(abs < 0x3A48) + return k = 0, m << (exp+20); + #if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL<<(62-exp)) - 1, yi = (y+(mask>>1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f>>63); + k = static_cast(yi>>(62-exp)); + return (multiply64(static_cast((sign ? -f : f)>>(31-exp)), 0xC90FDAA2)^sign) - sign; + #else + uint32 yh = m*0xA2F98 + mulhi(m, 0x36E4E442), yl = (m*0x36E4E442) & 0xFFFFFFFF; + uint32 mask = (static_cast(1)<<(30-exp)) - 1, yi = (yh+(mask>>1)) & ~mask, sign = -static_cast(yi>yh); + k = static_cast(yi>>(30-exp)); + uint32 fh = (yh^sign) + (yi^~sign) - ~sign, fl = (yl^sign) - sign; + return (multiply64((exp>-1) ? (((fh<<(1+exp))&0xFFFFFFFF)|((fl&0xFFFFFFFF)>>(31-exp))) : fh, 0xC90FDAA2)^sign) - sign; + #endif + } + + /// Get arguments for atan2 function. + /// \param abs half-precision floating-point value + /// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 + inline std::pair atan2_args(unsigned int abs) + { + int exp = -15; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + uint32 my = ((abs&0x3FF)|0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - ((rexp>-31) ? ((r>>-rexp)|((r&((static_cast(1)<<-rexp)-1))!=0)) : 1); + for(rexp=0; r<0x40000000; r<<=1,--rexp) ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if(d < 0) + return std::make_pair((d<-14) ? ((my>>(-d-14))+((my>>(-d-15))&1)) : (my<<(14+d)), (mx<<14)+(r<<13)/mx); + if(d > 0) + return std::make_pair(my<<14, (d>14) ? ((mx>>(d-14))+((mx>>(d-15))&1)) : ((d==14) ? mx : ((mx<<(14-d))+(r<<(13-d))/mx))); + return std::make_pair(my<<13, (mx<<13)+(r<<12)/mx); + } + + /// Get exponentials for hyperbolic computation + /// \param abs half-precision floating-point value + /// \param exp variable to take unbiased exponent of larger result + /// \param n number of BKM iterations (at most 32) + /// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent + inline std::pair hyperbolic_args(unsigned int abs, int &exp, unsigned int n = 32) + { + uint32 mx = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29), my; + int e = (abs>>10) + (abs<=0x3FF); + if(e < 14) + { + exp = 0; + mx >>= 14 - e; + } + else + { + exp = mx >> (45-e); + mx = (mx<<(e-14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if(mx > 0x80000000) + { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } + else + my = mx; + return std::make_pair(mx, (d<31) ? ((my>>d)|((my&((static_cast(1)< unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0, unsigned int n = 32) + { + if(esign) + { + exp = -exp - (m!=0); + if(exp < -25) + return underflow(sign); + else if(exp == -25) + return rounded(sign, 1, m!=0); + } + else if(exp > 15) + return overflow(sign); + if(!m) + return sign | (((exp+=15)>0) ? (exp<<10) : check_underflow(0x200>>-exp)); + m = exp2(m, n); + int s = 0; + if(esign) + m = divide64(0x80000000, m, s); + return fixed2half(m, exp+14, sign, s); + } + + /// Postprocessing for binary logarithm. + /// \tparam R rounding mode to use + /// \tparam L logarithm for base transformation as Q1.31 + /// \param m fractional part of logarithm as Q0.31 + /// \param ilog signed integer part of logarithm + /// \param exp biased exponent of result + /// \param sign sign bit of result + /// \return value base-transformed and converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) + { + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog)<<27)+(m>>4))^msign) - msign; + if(!m) + return 0; + for(; m<0x80000000; m<<=1,--exp) ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if(exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); + } + + /// Hypotenuse square root and postprocessing. + /// \tparam R rounding mode to use + /// \param r mantissa as Q2.30 + /// \param exp biased exponent + /// \return square root converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int hypot_post(uint32 r, int exp) + { + int i = r >> 31; + if((exp+=i) > 46) + return overflow(); + if(exp < -34) + return underflow(); + r = (r>>i) | (r&i); + uint32 m = sqrt<30>(r, exp+=15); + return fixed2half(m, exp-1, 0, r!=0); + } + + /// Division and postprocessing for tangents. + /// \tparam R rounding mode to use + /// \param my dividend as Q1.31 + /// \param mx divisor as Q1.31 + /// \param exp biased exponent of result + /// \param sign sign bit of result + /// \return quotient converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) + { + int i = my >= mx, s; + exp += i; + if(exp > 29) + return overflow(sign); + if(exp < -11) + return underflow(sign); + uint32 m = divide64(my>>(i+1), mx, s); + return fixed2half(m, exp, sign, s); + } + + /// Area function and postprocessing. + /// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = log(x+sqrt(x^2+|-1))`. + /// \tparam R rounding mode to use + /// \tparam S `true` for asinh, `false` for acosh + /// \param arg half-precision argument + /// \return asinh|acosh(\a arg) converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int area(unsigned int arg) + { + int abs = arg & 0x7FFF, expx = (abs>>10) + (abs<=0x3FF) - 15, expy = -15, ilog, i; + uint32 mx = static_cast((abs&0x3FF)|((abs>0x3FF)<<10)) << 20, my, r; + for(; abs<0x400; abs<<=1,--expy) ; + expy += abs >> 10; + r = ((abs&0x3FF)|0x400) << 5; + r *= r; + i = r >> 31; + expy = 2*expy + i; + r >>= i; + if(S) + { + if(expy < 0) + { + r = 0x40000000 + ((expy>-30) ? ((r>>-expy)|((r&((static_cast(1)<<-expy)-1))!=0)) : 1); + expy = 0; + } + else + { + r += 0x40000000 >> expy; + i = r >> 31; + r = (r>>i) | (r&i); + expy += i; + } + } + else + { + r -= 0x40000000 >> expy; + for(; r<0x40000000; r<<=1,--expy) ; + } + my = sqrt<30>(r, expy); + my = (my<<15) + (r<<14)/my; + if(S) + { + mx >>= expy - expx; + ilog = expy; + } + else + { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = my >> 31; + static const int G = S && (R==std::round_to_nearest); + return log2_post(log2(my>>i, 26+S+G)+(G<<3), ilog+i, 17, arg&(static_cast(S)<<15)); + } + + /// Class for 1.31 unsigned floating-point computation + struct f31 + { + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) + { + for(; abs<0x400; abs<<=1,--exp) ; + m = static_cast((abs&0x3FF)|0x400) << 21; + exp += (abs>>10); + } + + /// Addition operator. + /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) + { + if(b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d<32) ? (b.m>>d) : 0); + int i = (m&0xFFFFFFFF) < a.m; + return f31(((m+i)>>i)|0x80000000, a.exp+i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) + { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d<32) ? (b.m>>d) : 0); + if(!m) + return f31(0, -32); + for(; m<0x80000000; m<<=1,--exp) ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) + { + uint32 m = multiply64(a.m, b.m); + int i = m >> 31; + return f31(m<<(1-i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) + { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m+i)>>i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. + }; + + /// Error function and postprocessing. + /// This computes the value directly in Q1.31 using the approximations given + /// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). + /// \tparam R rounding mode to use + /// \tparam C `true` for comlementary error function, `false` else + /// \param arg half-precision function argument + /// \return approximated value of error function in half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int erf(unsigned int arg) + { + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), t = f31(0x80000000, 0) / (f31(0x80000000, 0)+f31(0xA7BA054A, -2)*x), t2 = t * t; + f31 e = ((f31(0x87DC2213, 0)*t2+f31(0xB5F0E2AE, 0))*t2+f31(0x82790637, -2)-(f31(0xBA00E2B8, 0)*t2+f31(0x91A98E62, -2))*t) * t / + ((x2.exp<0) ? f31(exp2((x2.exp>-32) ? (x2.m>>-x2.exp) : 0, 30), 0) : f31(exp2((x2.m<>(31-x2.exp))); + return (!C || sign) ? fixed2half(0x80000000-(e.m>>(C-e.exp)), 14+C, sign&(C-1U)) : + (e.exp<-25) ? underflow() : fixed2half(e.m>>1, e.exp+14, 0, e.m&1); + } + + /// Gamma function and postprocessing. + /// This approximates the value of either the gamma function or its logarithm directly in Q1.31. + /// \tparam R rounding mode to use + /// \tparam L `true` for lograithm of gamma function, `false` for gamma function + /// \param arg half-precision floating-point value + /// \return lgamma/tgamma(\a arg) in half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if \a arg is not a positive integer + template unsigned int gamma(unsigned int arg) + { +/* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, 0.0114684895434781459556 }; + double t = arg + 4.65, s = p[0]; + for(unsigned int i=0; i<5; ++i) + s += p[i+1] / (arg+i); + return std::log(s) + (arg-0.5)*std::log(t) - t; +*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), x = sign ? (z+f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), s = + f31(0xA06C9901, 1) + f31(0xBBE654E2, -7)/(x+f31(0x80000000, 2)) + f31(0xA1CE6098, 6)/(x+f31(0x80000000, 1)) + + f31(0xE1868CB7, 7)/x - f31(0x8625E279, 8)/(x+f31(0x80000000, 0)) - f31(0xA03E158F, 2)/(x+f31(0xC0000000, 1)); + int i = (s.exp>=2) + (s.exp>=4) + (s.exp>=8) + (s.exp>=16); + s = f31((static_cast(s.exp)<<(31-i))+(log2(s.m>>1, 28)>>i), i) / lbe; + if(x.exp != -1 || x.m != 0x80000000) + { + i = (t.exp>=2) + (t.exp>=4) + (t.exp>=8); + f31 l = f31((static_cast(t.exp)<<(31-i))+(log2(t.m>>1, 30)>>i), i) / lbe; + s = (x.exp<-1) ? (s-(f31(0x80000000, -1)-x)*l) : (s+(x-f31(0x80000000, -1))*l); + } + s = x.exp ? (s-t) : (t-s); + if(bsign) + { + if(z.exp >= 0) + { + sign &= (L|((z.m>>(31-z.exp))&1)) - 1; + for(z=f31((z.m<<(1+z.exp))&0xFFFFFFFF, -1); z.m<0x80000000; z.m<<=1,--z.exp) ; + } + if(z.exp == -1) + z = f31(0x80000000, 0) - z; + if(z.exp < -1) + { + z = z * pi; + z.m = sincos(z.m>>(1-z.exp), 30).first; + for(z.exp=1; z.m<0x80000000; z.m<<=1,--z.exp) ; + } + else + z = f31(0x80000000, 0); + } + if(L) + { + if(bsign) + { + f31 l(0x92868247, 0); + if(z.exp < 0) + { + uint32 m = log2((z.m+1)>>1, 27); + z = f31(-((static_cast(z.exp)<<26)+(m>>5)), 5); + for(; z.m<0x80000000; z.m<<=1,--z.exp) ; + l = l + z / lbe; + } + sign = static_cast(x.exp&&(l.exp(x.exp==0) << 15; + if(s.exp < -24) + return underflow(sign); + if(s.exp > 15) + return overflow(sign); + } + } + else + { + s = s * lbe; + uint32 m; + if(s.exp < 0) + { + m = s.m >> -s.exp; + s.exp = 0; + } + else + { + m = (s.m<>(31-s.exp)); + } + s.m = exp2(m, 27); + if(!x.exp) + s = f31(0x80000000, 0) / s; + if(bsign) + { + if(z.exp < 0) + s = s * z; + s = pi / s; + if(s.exp < -24) + return underflow(sign); + } + else if(z.exp > 0 && !(z.m&((1<<(31-z.exp))-1))) + return ((s.exp+14)<<10) + (s.m>>21); + if(s.exp > 15) + return overflow(sign); + } + return fixed2half(s.m, s.exp+14, sign); + } + /// \} + + template struct half_caster; + } + + /// Half-precision floating-point type. + /// This class implements an IEEE-conformant half-precision floating-point type with the usual arithmetic + /// operators and conversions. It is implicitly convertible to single-precision floating-point, which makes artihmetic + /// expressions and functions with mixed-type operands to be of the most precise operand type. + /// + /// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and + /// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which + /// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the + /// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of + /// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most + /// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit + /// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if + /// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this should be the case on + /// nearly any reasonable platform. + /// + /// So if your C++ implementation is not totally exotic or imposes special alignment requirements, it is a reasonable + /// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE representation. + class half + { + public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + explicit half(float rhs) : data_(static_cast(detail::float2half(rhs))) {} + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half& operator=(float rhs) { data_ = static_cast(detail::float2half(rhs)); return *this; } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half& operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half& operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half& operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half& operator/=(half rhs) { return *this = *this / rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator/=(float rhs) { return *this = *this / rhs; } + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) { half out(*this); ++*this; return out; } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) { half out(*this); --*this; return out; } + /// \} + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT : data_(static_cast(bits)) {} + + /// Internal binary representation + detail::uint16 data_; + + #ifndef HALF_DOXYGEN_ONLY + friend HALF_CONSTEXPR_NOERR bool operator==(half, half); + friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); + friend HALF_CONSTEXPR half operator-(half); + friend half operator+(half, half); + friend half operator-(half, half); + friend half operator*(half, half); + friend half operator/(half, half); + template friend std::basic_ostream& operator<<(std::basic_ostream&, half); + template friend std::basic_istream& operator>>(std::basic_istream&, half&); + friend HALF_CONSTEXPR half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int*); + friend half fma(half, half, half); + friend HALF_CONSTEXPR_NOERR half fmax(half, half); + friend HALF_CONSTEXPR_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char*); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half rsqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half*, half*); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long lrint(half); + friend half nearbyint(half); + #ifdef HALF_ENABLE_CPP11_LONG_LONG + friend long long llround(half); + friend long long llrint(half); + #endif + friend half frexp(half, int*); + friend half scalbln(half, long); + friend half modf(half, half*); + friend int ilogb(half); + friend half logb(half); + friend half nextafter(half, half); + friend half nexttoward(half, long double); + friend HALF_CONSTEXPR half copysign(half, half); + friend HALF_CONSTEXPR int fpclassify(half); + friend HALF_CONSTEXPR bool isfinite(half); + friend HALF_CONSTEXPR bool isinf(half); + friend HALF_CONSTEXPR bool isnan(half); + friend HALF_CONSTEXPR bool isnormal(half); + friend HALF_CONSTEXPR bool signbit(half); + friend HALF_CONSTEXPR bool isgreater(half, half); + friend HALF_CONSTEXPR bool isgreaterequal(half, half); + friend HALF_CONSTEXPR bool isless(half, half); + friend HALF_CONSTEXPR bool islessequal(half, half); + friend HALF_CONSTEXPR bool islessgreater(half, half); + template friend struct detail::half_caster; + friend class std::numeric_limits; + #if HALF_ENABLE_CPP11_HASH + friend struct std::hash; + #endif + #if HALF_ENABLE_CPP11_USER_LITERALS + friend half literal::operator "" _h(long double); + #endif + #endif + }; + +#if HALF_ENABLE_CPP11_USER_LITERALS + namespace literal + { + /// Half literal. + /// While this returns a properly rounded half-precision value, half literals can unfortunately not be constant + /// expressions due to rather involved conversions. So don't expect this to be a literal literal without involving + /// conversion operations at runtime. It is a convenience feature, not a performance optimization. + /// \param value literal value + /// \return half with of given value (possibly rounded) + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator "" _h(long double value) { return half(detail::binary, detail::float2half(value)); } + } +#endif + + namespace detail + { + /// Helper class for half casts. + /// This class template has to be specialized for all valid cast arguments to define an appropriate static + /// `cast` member function and a corresponding `type` member denoting its return type. + /// \tparam T destination type + /// \tparam U source type + /// \tparam R rounding mode to use + template struct half_caster {}; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); + #endif + + static half cast(U arg) { return cast_impl(arg, is_float()); }; + + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } + }; + template struct half_caster + { + #if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); + #endif + + static T cast(half arg) { return cast_impl(arg, is_float()); } + + private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } + }; + template struct half_caster + { + static half cast(half arg) { return arg; } + }; + } +} + +/// Extensions to the C++ standard library. +namespace std +{ + /// Numeric limits for half-precision floats. + /// **See also:** Documentation for [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) + template<> class numeric_limits + { + public: + /// Is template specialization. + static HALF_CONSTEXPR_CONST bool is_specialized = true; + + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not an integer type. + static HALF_CONSTEXPR_CONST bool is_integer = false; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// Has a finite set of values. + static HALF_CONSTEXPR_CONST bool is_bounded = true; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports signaling NaNs. + static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Supports no denormalization detection. + static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; + + #if HALF_ERRHANDLING_THROWS + static HALF_CONSTEXPR_CONST bool traps = true; + #else + /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is acitvated. + static HALF_CONSTEXPR_CONST bool traps = false; + #endif + + /// Does not support no pre-rounding underflow detection. + static HALF_CONSTEXPR_CONST bool tinyness_before = false; + + /// Rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. + static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0400); } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0xFBFF); } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7BFF); } + + /// Difference between 1 and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x1400); } + + /// Maximum rounding error in ULP (units in the last place). + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW + { return half_float::half(half_float::detail::binary, (round_style==std::round_to_nearest) ? 0x3800 : 0x3C00); } + + /// Positive infinity. + static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7C00); } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7FFF); } + + /// Signaling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x7DFF); } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW { return half_float::half(half_float::detail::binary, 0x0001); } + }; + +#if HALF_ENABLE_CPP11_HASH + /// Hash function for half-precision floats. + /// This is only defined if C++11 `std::hash` is supported and enabled. + /// + /// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) + template<> struct hash + { + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const { return hash()(arg.data_&-static_cast(arg.data_!=0x8000)); } + }; +#endif +} + +namespace half_float +{ + /// \anchor compop + /// \name Comparison operators + /// \{ + + /// Comparison for equality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && (x.data_==y.data_ || !((x.data_|y.data_)&0x7FFF)); + } + + /// Comparison for inequality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) + { + return detail::compsignal(x.data_, y.data_) || (x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF)); + } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) < ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// Comparison for greater than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) > ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) <= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// Comparison for greater equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) + { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) >= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + + /// \} + /// \anchor arithmetics + /// \name Arithmetic operators + /// \{ + + /// Identity. + /// \param arg operand + /// \return unchanged operand + inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + + /// Negation. + /// \param arg operand + /// \return negated operand + inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_^0x8000); } + + /// Addition. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return sum of half expressions + /// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator+(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)+detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_^y.data_)&0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : (absy!=0x7C00) ? x.data_ : + (sub && absx==0x7C00) ? detail::invalid() : y.data_); + if(!absx) + return absy ? y : half(detail::binary, (half::round_style==std::round_toward_neg_infinity) ? (x.data_|y.data_) : (x.data_&y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy>absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx>>10) + (absx<=0x3FF), d = exp - (absy>>10) - (absy<=0x3FF), mx = ((absx&0x3FF)|((absx>0x3FF)<<10)) << 3, my; + if(d < 13) + { + my = ((absy&0x3FF)|((absy>0x3FF)<<10)) << 3; + my = (my>>d) | ((my&((1<(half::round_style==std::round_toward_neg_infinity)<<15); + for(; mx<0x2000 && exp>1; mx<<=1,--exp) ; + } + else + { + mx += my; + int i = mx >> 14; + if((exp+=i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx>>i) | (mx&i); + } + return half(detail::binary, detail::rounded(sign+((exp-1)<<10)+(mx>>3), (mx>>2)&1, (mx&0x3)!=0)); + #endif + } + + /// Subtraction. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return difference of half expressions + /// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator-(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)-detail::half2float(y.data_))); + #else + return x + -y; + #endif + } + + /// Multiplication. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return product of half expressions + /// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator*(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)*detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_^y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + ((absx==0x7C00 && !absy)||(absy==0x7C00 && !absx)) ? detail::invalid() : (sign|0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + detail::uint32 m = static_cast((absx&0x3FF)|0x400) * static_cast((absy&0x3FF)|0x400); + int i = m >> 21, s = m & i; + exp += (absx>>10) + (absy>>10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, detail::fixed2half(m>>i, exp, sign, s)); + #endif + } + + /// Division. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return quotient of half expressions + /// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is signaling NaN + /// \exception FE_DIVBYZERO if dividing finite value by 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator/(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)/detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_^y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==absy) ? detail::invalid() : (sign|((absx==0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? sign : detail::invalid()); + if(!absy) + return half(detail::binary, detail::pole(sign)); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,++exp) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + int i = mx < my; + exp += (absx>>10) - (absy>>10) - i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, detail::fixed2half(mx/my, exp, sign, mx%my!=0)); + #endif + } + + /// \} + /// \anchor streaming + /// \name Input and output + /// \{ + + /// Output operator. + /// This uses the built-in functionality for streaming out floating-point numbers. + /// \param out output stream to write into + /// \param arg half expression to write + /// \return reference to output stream + template std::basic_ostream& operator<<(std::basic_ostream &out, half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); + #else + return out << detail::half2float(arg.data_); + #endif + } + + /// Input operator. + /// This uses the built-in functionality for streaming in floating-point numbers, specifically double precision floating + /// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the input string is first + /// rounded to double precision using the underlying platform's current floating-point rounding mode before being rounded + /// to half-precision using the library's half-precision rounding mode. + /// \param in input stream to read from + /// \param arg half to read into + /// \return reference to input stream + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template std::basic_istream& operator>>(std::basic_istream &in, half &arg) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; + #else + double f; + #endif + if(in >> f) + arg.data_ = detail::float2half(f); + return in; + } + + /// \} + /// \anchor basic + /// \name Basic mathematical operations + /// \{ + + /// Absolute value. + /// **See also:** Documentation for [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). + /// \param arg operand + /// \return absolute value of \a arg + inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_&0x7FFF); } + + /// Absolute value. + /// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). + /// \param arg operand + /// \return absolute value of \a arg + inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } + + /// Remainder of division. + /// **See also:** Documentation for [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). + /// \param x first operand + /// \param y second operand + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half fmod(half x, half y) + { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(!absx) + return x; + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign|detail::mod(absx, absy)); + } + + /// Remainder of division. + /// **See also:** Documentation for [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). + /// \param x first operand + /// \param y second operand + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half remainder(half x, half y) + { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign^detail::mod(absx, absy)); + } + + /// Remainder of division. + /// **See also:** Documentation for [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). + /// \param x first operand + /// \param y second operand + /// \param quo address to store some bits of quotient at + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half remquo(half x, half y, int *quo) + { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value^y.data_)&0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); + } + + /// Fused multiply add. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return ( \a x * \a y ) + \a z rounded as one operation. + /// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet NaN and no argument is a signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition + inline half fma(half x, half y, half z) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_), fz = detail::half2float(z.data_); + #if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); + #else + return half(detail::binary, detail::float2half(fx*fy+fz)); + #endif + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_^y.data_) & 0x8000; + bool sub = ((sign^z.data_)&0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx>0x7C00 || absy>0x7C00 || absz>0x7C00) ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) : + (absx==0x7C00) ? half(detail::binary, (!absy || (sub && absz==0x7C00)) ? detail::invalid() : (sign|0x7C00)) : + (absy==0x7C00) ? half(detail::binary, (!absx || (sub && absz==0x7C00)) ? detail::invalid() : (sign|0x7C00)) : z; + if(!absx || !absy) + return absz ? z : half(detail::binary, (half::round_style==std::round_toward_neg_infinity) ? (z.data_|sign) : (z.data_&sign)); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + detail::uint32 m = static_cast((absx&0x3FF)|0x400) * static_cast((absy&0x3FF)|0x400); + int i = m >> 21; + exp += (absx>>10) + (absy>>10) + i; + m <<= 3 - i; + if(absz) + { + int expz = 0; + for(; absz<0x400; absz<<=1,--expz) ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz&0x3FF)|0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) + { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d<23) ? ((mz>>d)|((mz&((static_cast(1)<(half::round_style==std::round_toward_neg_infinity)<<15); + for(; m<0x800000; m<<=1,--exp) ; + } + else + { + m += mz; + i = m >> 24; + m = (m>>i) | (m&i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, detail::fixed2half(m, exp-1, sign)); + #endif + } + + /// Maximum of half expressions. + /// **See also:** Documentation for [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). + /// \param x first operand + /// \param y second operand + /// \return maximum of operands, ignoring quiet NaNs + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) + { + return half(detail::binary, (!isnan(y) && (isnan(x) || (x.data_^(0x8000|(0x8000-(x.data_>>15)))) < + (y.data_^(0x8000|(0x8000-(y.data_>>15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); + } + + /// Minimum of half expressions. + /// **See also:** Documentation for [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). + /// \param x first operand + /// \param y second operand + /// \return minimum of operands, ignoring quiet NaNs + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) + { + return half(detail::binary, (!isnan(y) && (isnan(x) || (x.data_^(0x8000|(0x8000-(x.data_>>15)))) > + (y.data_^(0x8000|(0x8000-(y.data_>>15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); + } + + /// Positive difference. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). + /// \param x first operand + /// \param y second operand + /// \return \a x - \a y or 0 if difference negative + /// \exception FE_... according to operator-(half,half) + inline half fdim(half x, half y) + { + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_^(0x8000|(0x8000-(x.data_>>15)))) <= (y.data_^(0x8000|(0x8000-(y.data_>>15)))) ? half(detail::binary, 0) : (x-y); + } + + /// Get NaN value. + /// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). + /// \param arg string code + /// \return quiet NaN + inline half nanh(const char *arg) + { + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); + } + + /// \} + /// \anchor exponential + /// \name Exponential functions + /// \{ + + /// Exponential function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). + /// \param arg function argument + /// \return e raised to \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half exp(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::exp(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, e = (abs>>10) + (abs<=0x3FF), exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00&((arg.data_>>15)-1U)) : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, (arg.data_&0x8000) ? detail::underflow() : detail::overflow()); + detail::uint32 m = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29); + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45-e); + m = (m<<(e-14)) & 0x7FFFFFFF; + } + return half(detail::binary, detail::exp2_post(m, exp, (arg.data_&0x8000)!=0, 0, 26)); + #endif + } + + /// Binary exponential. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). + /// \param arg function argument + /// \return 2 raised to \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half exp2(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::exp2(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, e = (abs>>10) + (abs<=0x3FF), exp = (abs&0x3FF) + ((abs>0x3FF)<<10); + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00&((arg.data_>>15)-1U)) : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, (arg.data_&0x8000) ? detail::underflow() : detail::overflow()); + return half(detail::binary, detail::exp2_post( + (static_cast(exp)<<(6+e))&0x7FFFFFFF, exp>>(25-e), (arg.data_&0x8000)!=0, 0, 28)); + #endif + } + + /// Exponential minus one. + /// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for `std::round_to_nearest` + /// and in <1% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). + /// \param arg function argument + /// \return e raised to \a arg and subtracted by 1 + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half expm1(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::expm1(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000, e = (abs>>10) + (abs<=0x3FF), exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00+(sign>>1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, (arg.data_&0x8000) ? detail::rounded(0xBBFF, 1, 1) : detail::overflow()); + detail::uint32 m = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29); + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45-e); + m = (m<<(e-14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) + { + int s = 0; + if(m > 0x80000000) + { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - ((m>>exp)|((m&((static_cast(1)<>exp) : 1; + for(exp+=14; m<0x80000000 && exp; m<<=1,--exp) ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::rounded(sign+(exp<<10)+(m>>21), (m>>20)&1, (m&0xFFFFF)!=0)); + #endif + } + + /// Natural logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). + /// \param arg function argument + /// \return logarithm of \a arg to base e + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::log(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + return half(detail::binary, detail::log2_post( + detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 27)+8, exp, 17)); + #endif + } + + /// Common logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). + /// \param arg function argument + /// \return logarithm of \a arg to base 10 + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log10(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::log10(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) + { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + return half(detail::binary, detail::log2_post( + detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 27)+8, exp, 16)); + #endif + } + + /// Binary logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). + /// \param arg function argument + /// \return logarithm of \a arg to base 2 + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log2(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::log2(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs<0x400; abs<<=1,--exp) ; + exp += (abs>>10); + if(!(abs&0x3FF)) + { + unsigned int value = static_cast(exp<0) << 15, m = std::abs(exp) << 6; + for(exp=18; m<0x400; m<<=1,--exp) ; + return half(detail::binary, value+(exp<<10)+m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), m = + (((ilog<<27)+(detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 28)>>4))^sign) - sign; + if(!m) + return half(detail::binary, 0); + for(exp=14; m<0x8000000 && exp; m<<=1,--exp) ; + for(; m>0xFFFFFFF; m>>=1,++exp) + s |= m & 1; + return half(detail::binary, detail::fixed2half(m, exp, sign&0x8000, s)); + #endif + } + + /// Natural logarithm plus one. + /// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for `std::round_to_nearest` + /// and in ~1% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). + /// \param arg function argument + /// \return logarithm of \a arg plus 1 to base e + /// \exception FE_INVALID for signaling NaN or argument <-1 + /// \exception FE_DIVBYZERO for -1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log1p(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::log1p(detail::half2float(arg.data_)))); + #else + if(arg.data_ >= 0xBC00) + return half(detail::binary, (arg.data_==0xBC00) ? detail::pole(0x8000) : (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs&0x3FF)|0x400) << 20; + if(arg.data_ & 0x8000) + { + m = 0x40000000 - (m>>-exp); + for(exp=0; m<0x40000000; m<<=1,--exp) ; + } + else + { + if(exp < 0) + { + m = 0x40000000 + (m>>-exp); + exp = 0; + } + else + { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, detail::log2_post(detail::log2(m), exp, 17)); + #endif + } + + /// \} + /// \anchor power + /// \name Power functions + /// \{ + + /// Square root. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). + /// \param arg function argument + /// \return square root of \a arg + /// \exception FE_INVALID for signaling NaN and negative arguments + /// \exception FE_INEXACT according to rounding + inline half sqrt(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sqrt(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_>0x8000) ? detail::invalid() : arg.data_); + for(; abs<0x400; abs<<=1,--exp) ; + detail::uint32 r = static_cast((abs&0x3FF)|0x400) << 10, m = detail::sqrt<20>(r, exp+=abs>>10); + return half(detail::binary, detail::rounded((exp<<10)+(m&0x3FF), r>m, r!=0)); + #endif + } + + /// Inverse square root. + /// This function is exact to rounding for all rounding modes and thus generally more accurate than directly computing + /// 1 / sqrt(\a arg) in half-precision, in addition to also being faster. + /// \param arg function argument + /// \return reciprocal of square root of \a arg + /// \exception FE_INVALID for signaling NaN and negative arguments + /// \exception FE_INEXACT according to rounding + inline half rsqrt(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::internal_t(1)/std::sqrt(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, bias = 0x4000; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_>0x8000) ? + detail::invalid() : !abs ? detail::pole(arg.data_&0x8000) : 0); + for(; abs<0x400; abs<<=1,bias-=0x400) ; + unsigned int frac = (abs+=bias) & 0x7FF; + if(frac == 0x400) + return half(detail::binary, 0x7A00-(abs>>1)); + if((half::round_style == std::round_to_nearest && (frac == 0x3FE || frac == 0x76C)) || + (half::round_style != std::round_to_nearest && (frac == 0x15A || frac == 0x3FC || frac == 0x401 || frac == 0x402 || frac == 0x67B))) + return pow(arg, half(detail::binary, 0xB800)); + detail::uint32 f = 0x17376 - abs, mx = (abs&0x3FF) | 0x400, my = ((f>>1)&0x3FF) | 0x400, mz = my * my; + int expy = (f>>11) - 31, expx = 32 - (abs>>10), i = mz >> 21; + for(mz=0x60000000-(((mz>>i)*mx)>>(expx-2*expy-i)); mz<0x40000000; mz<<=1,--expy) ; + i = (my*=mz>>10) >> 31; + expy += i; + my = (my>>(20+i)) + 1; + i = (mz=my*my) >> 21; + for(mz=0x60000000-(((mz>>i)*mx)>>(expx-2*expy-i)); mz<0x40000000; mz<<=1,--expy) ; + i = (my*=(mz>>10)+1) >> 31; + return half(detail::binary, detail::fixed2half(my>>i, expy+i+14)); + #endif + } + + /// Cubic root. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). + /// \param arg function argument + /// \return cubic root of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT according to rounding + inline half cbrt(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::cbrt(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1, --exp); + detail::uint32 ilog = exp + (abs>>10), sign = detail::sign_mask(ilog), f, m = + (((ilog<<27)+(detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 24)>>4))^sign) - sign; + for(exp=2; m<0x80000000; m<<=1,--exp) ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m<> (31-exp); + } + m = detail::exp2(f, (half::round_style==std::round_to_nearest) ? 29 : 26); + if(sign) + { + if(m > 0x80000000) + { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, (half::round_style==std::round_to_nearest) ? + detail::fixed2half(m, exp+14, arg.data_&0x8000) : + detail::fixed2half((m+0x80)>>8, exp+14, arg.data_&0x8000)); + #endif + } + + /// Hypotenuse function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). + /// \param x first argument + /// \param y second argument + /// \return square root of sum of squares without internal over- or underflows + /// \exception FE_INVALID if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root + inline half hypot(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_); + #if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); + #else + return half(detail::binary, detail::float2half(std::sqrt(fx*fx+fy*fy))); + #endif + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx==0x7C00) ? detail::select(0x7C00, y.data_) : + (absy==0x7C00) ? detail::select(0x7C00, x.data_) : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2*(expx+(absx>>10)) - 15 + ix; + expy = 2*(expy+(absy>>10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d<30) ? ((my>>d)|((my&((static_cast(1)<(mx+my, expx)); + #endif + } + + /// Hypotenuse function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). + /// \param x first argument + /// \param y second argument + /// \param z third argument + /// \return square root of sum of squares without internal over- or underflows + /// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root + inline half hypot(half x, half y, half z) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_), fz = detail::half2float(z.data_); + return half(detail::binary, detail::float2half(std::sqrt(fx*fx+fy*fy+fz*fz))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, (absx==0x7C00) ? detail::select(0x7C00, detail::select(y.data_, z.data_)) : + (absy==0x7C00) ? detail::select(0x7C00, detail::select(x.data_, z.data_)) : + (absz==0x7C00) ? detail::select(0x7C00, detail::select(x.data_, y.data_)) : + detail::signal(x.data_, y.data_, z.data_)); + if(absz > absy) + std::swap(absy, absz); + if(absy > absx) + std::swap(absx, absy); + if(absz > absy) + std::swap(absy, absz); + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + for(; absz<0x400; absz<<=1,--expz) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400, mz = (absz&0x3FF) | 0x400; + mx *= mx; + my *= my; + mz *= mz; + int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; + expx = 2*(expx+(absx>>10)) - 15 + ix; + expy = 2*(expy+(absy>>10)) - 15 + iy; + expz = 2*(expz+(absz>>10)) - 15 + iz; + mx <<= 10 - ix; + my <<= 10 - iy; + mz <<= 10 - iz; + int d = expy - expz; + mz = (d<30) ? ((mz>>d)|((mz&((static_cast(1)<>1) | (my&1); + if(++expy > expx) + { + std::swap(mx, my); + std::swap(expx, expy); + } + } + d = expx - expy; + my = (d<30) ? ((my>>d)|((my&((static_cast(1)<(mx+my, expx)); + #endif + } + + /// Power function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in ~0.00025% of inputs. + /// + /// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). + /// \param x base + /// \param y exponent + /// \return \a x raised to \a y + /// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y is finite and not integral + /// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half pow(half x, half y) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::pow(detail::half2float(x.data_), detail::half2float(y.data_)))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; + if(!absy || x.data_ == 0x3C00) + return half(detail::binary, detail::select(0x3C00, (x.data_==0x3C00) ? y.data_ : x.data_)); + bool is_int = absy >= 0x6400 || (absy>=0x3C00 && !(absy&((1<<(25-(absy>>10)))-1))); + unsigned int sign = x.data_ & (static_cast((absy<0x6800)&&is_int&&((absy>>(25-(absy>>10)))&1))<<15); + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absy==0x7C00) ? ((absx==0x3C00) ? 0x3C00 : (!absx && y.data_==0xFC00) ? detail::pole() : + (0x7C00&-((y.data_>>15)^(absx>0x3C00)))) : (sign|(0x7C00&((y.data_>>15)-1U)))); + if(!absx) + return half(detail::binary, (y.data_&0x8000) ? detail::pole(sign) : sign); + if((x.data_&0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign|0x3C00); + switch(y.data_) + { + case 0x3800: return sqrt(x); + case 0x3C00: return half(detail::binary, detail::check_underflow(x.data_)); + case 0x4000: return x * x; + case 0xBC00: return half(detail::binary, 0x3C00) / x; + } + for(; absx<0x400; absx<<=1,--exp) ; + detail::uint32 ilog = exp + (absx>>10), msign = detail::sign_mask(ilog), f, m = + (((ilog<<27)+((detail::log2(static_cast((absx&0x3FF)|0x400)<<20)+8)>>4))^msign) - msign; + for(exp=-11; m<0x80000000; m<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + m = detail::multiply64(m, static_cast((absy&0x3FF)|0x400)<<21); + int i = m >> 31; + exp += (absy>>10) + i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m<> (31-exp); + } + return half(detail::binary, detail::exp2_post(f, exp, ((msign&1)^(y.data_>>15))!=0, sign)); + #endif + } + + /// \} + /// \anchor trigonometric + /// \name Trigonometric functions + /// \{ + + /// Compute sine and cosine simultaneously. + /// This returns the same results as sin() and cos() but is faster than calling each function individually. + /// + /// This function is exact to rounding for all rounding modes. + /// \param arg function argument + /// \param sin variable to take sine of \a arg + /// \param cos variable to take cosine of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline void sincos(half arg, half *sin, half *cos) + { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); + #else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + else if(!abs) + { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } + else if(abs < 0x2500) + { + *sin = half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } + else + { + if(half::round_style != std::round_to_nearest) + { + switch(abs) + { + case 0x48B7: + *sin = half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half(detail::binary, detail::rounded((arg.data_&0x8000)|0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) + { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, detail::fixed2half((sc.first^-static_cast(sign))+sign)); + *cos = half(detail::binary, detail::fixed2half(sc.second)); + } + #endif + } + + /// Sine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). + /// \param arg function argument + /// \return sine value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sin(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sin(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x48B7: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x1D07, 1, 1)); + case 0x6A64: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x3BFE, 1, 1)); + case 0x6D8C: return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k>>1)&1)^(arg.data_>>15)); + return half(detail::binary, detail::fixed2half((((k&1) ? sc.second : sc.first)^sign) - sign)); + #endif + } + + /// Cosine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). + /// \param arg function argument + /// \return cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cos(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::cos(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k>>1)^k)&1); + return half(detail::binary, detail::fixed2half((((k&1) ? sc.first : sc.second)^sign) - sign)); + #endif + } + + /// Tangent function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). + /// \param arg function argument + /// \return tangent value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tan(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::tan(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x658C: return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x07E6, 1, 1)); + case 0x7330: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first^signy) - signy, mx = (sc.second^signx) - signx; + for(; my<0x80000000; my<<=1,--exp) ; + for(; mx<0x80000000; mx<<=1,++exp) ; + return half(detail::binary, detail::tangent_post(my, mx, exp, (signy^signx^arg.data_)&0x8000)); + #endif + } + + /// Arc sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). + /// \param arg function argument + /// \return arc sine value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half asin(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::asin(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (abs>0x3C00) ? detail::invalid() : + detail::rounded(sign|0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_+1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(sc.first, sc.second, (half::round_style==std::round_to_nearest) ? 27 : 26); + return half(detail::binary, detail::fixed2half(m, 14, sign)); + #endif + } + + /// Arc cosine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). + /// \param arg function argument + /// \return arc cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half acos(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::acos(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (abs>0x3C00) ? detail::invalid() : + sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, detail::fixed2half(sign ? (0xC90FDAA2-m) : m, 15, 0, sign)); + #endif + } + + /// Arc tangent function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). + /// \param arg function argument + /// \return arc tangent value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atan(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::atan(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::rounded(sign|0x3E48, 0, 1) : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + int exp = (abs>>10) + (abs<=0x3FF); + detail::uint32 my = (abs&0x3FF) | ((abs>0x3FF)<<10); + detail::uint32 m = (exp>15) ? detail::atan2(my<<19, 0x20000000>>(exp-15), (half::round_style==std::round_to_nearest) ? 26 : 24) : + detail::atan2(my<<(exp+4), 0x20000000, (half::round_style==std::round_to_nearest) ? 30 : 28); + return half(detail::binary, detail::fixed2half(m, 14, sign)); + #endif + } + + /// Arc tangent function. + /// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for `std::round_to_nearest`, + /// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). + /// \param y numerator + /// \param x denominator + /// \return arc tangent value + /// \exception FE_INVALID if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atan2(half y, half x) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::atan2(detail::half2float(y.data_), detail::half2float(x.data_)))); + #else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, (absx<0x7C00) ? detail::rounded(signy|0x3E48, 0, 1) : + signx ? detail::rounded(signy|0x40B6, 0, 1) : + detail::rounded(signy|0x3A48, 0, 1)); + return (x.data_==0x7C00) ? half(detail::binary, signy) : half(detail::binary, detail::rounded(signy|0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, detail::rounded(signy|0x4248, 0, 1)) : y; + if(!absx) + return half(detail::binary, detail::rounded(signy|0x3E48, 0, 1)); + int d = (absy>>10) + (absy<=0x3FF) - (absx>>10) - (absx<=0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy|0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy|0x4248, 0, 1)); + if(!signx && d < ((half::round_style==std::round_toward_zero) ? -15 : -9)) + { + for(; absy<0x400; absy<<=1,--d) ; + detail::uint32 mx = ((absx<<1)&0x7FF) | 0x800, my = ((absy<<1)&0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, detail::fixed2half(my/mx, d+14, signy, my%mx!=0)); + } + detail::uint32 m = detail::atan2( ((absy&0x3FF)|((absy>0x3FF)<<10))<<(19+((d<0) ? d : (d>0) ? 0 : -1)), + ((absx&0x3FF)|((absx>0x3FF)<<10))<<(19-((d>0) ? d : (d<0) ? 0 : 1))); + return half(detail::binary, detail::fixed2half(signx ? (0xC90FDAA2-m) : m, 15, signy, signx)); + #endif + } + + /// \} + /// \anchor hyperbolic + /// \name Hyperbolic functions + /// \{ + + /// Hyperbolic sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). + /// \param arg function argument + /// \return hyperbolic sine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sinh(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sinh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, (half::round_style==std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp+=13; m<0x80000000 && exp; m<<=1,--exp) ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, detail::fixed2half(m, exp, sign)); + #endif + } + + /// Hyperbolic cosine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). + /// \param arg function argument + /// \return hyperbolic cosine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cosh(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::cosh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = detail::hyperbolic_args(abs, exp, (half::round_style==std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m&0xFFFFFFFF) >> 31; + m = (m>>i) | (m&i) | 0x80000000; + if((exp+=13+i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::fixed2half(m, exp)); + #endif + } + + /// Hyperbolic tangent. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). + /// \param arg function argument + /// \return hyperbolic tangent value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tanh(half arg) + { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::tanh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_-0x4000)); + if(abs >= 0x4500) + return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_-3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style!=std::round_to_nearest), mx = mm.first + mm.second, i = (~mx&0xFFFFFFFF) >> 31; + for(exp=13; my<0x80000000; my<<=1,--exp) ; + mx = (mx>>i) | 0x80000000; + return half(detail::binary, detail::tangent_post(my, mx, exp-i, arg.data_&0x8000)); + #endif + } + + /// Hyperbolic area sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). + /// \param arg function argument + /// \return area sine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half asinh(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::asinh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: return half(detail::binary, detail::rounded(arg.data_-13, 1, 1)); + case 0x3B5B: return half(detail::binary, detail::rounded(arg.data_-197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); + #endif + } + + /// Hyperbolic area cosine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). + /// \param arg function argument + /// \return area cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or arguments <1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half acosh(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::acosh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if((arg.data_&0x8000) || abs < 0x3C00) + return half(detail::binary, (abs<=0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); + #endif + } + + /// Hyperbolic area tangent. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). + /// \param arg function argument + /// \return area tangent value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_DIVBYZERO for +/-1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atanh(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::atanh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, (abs==0x3C00) ? detail::pole(arg.data_&0x8000) : (abs<=0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs&0x3FF)|((abs>0x3FF)<<10)) << ((abs>>10)+(abs<=0x3FF)+6), my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx<0x80000000; mx<<=1,++exp) ; + int i = my >= mx, s; + return half(detail::binary, detail::log2_post(detail::log2( + (detail::divide64(my>>i, mx, s)+1)>>1, 27)+0x10, exp+i-1, 16, arg.data_&0x8000)); + #endif + } + + /// \} + /// \anchor special + /// \name Error and gamma functions + /// \{ + + /// Error function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% of inputs. + /// + /// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). + /// \param arg function argument + /// \return error function value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half erf(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::erf(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs>=0x7C00) ? half(detail::binary, (abs==0x7C00) ? (arg.data_-0x4000) : detail::signal(arg.data_)) : arg; + if(abs >= 0x4200) + return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); + #endif + } + + /// Complementary error function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% of inputs. + /// + /// **See also:** Documentation for [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). + /// \param arg function argument + /// \return 1 minus error function value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half erfc(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::erfc(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00) + return (abs>=0x7C00) ? half(detail::binary, (abs==0x7C00) ? (sign>>1) : detail::signal(arg.data_)) : arg; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x4400) + return half(detail::binary, detail::rounded((sign>>1)-(sign>>15), sign>>15, 1)); + return half(detail::binary, detail::erf(arg.data_)); + #endif + } + + /// Natural logarithm of gamma function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in ~0.025% of inputs. + /// + /// **See also:** Documentation for [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). + /// \param arg function argument + /// \return natural logarith of gamma function for \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_DIVBYZERO for 0 or negative integer arguments + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half lgamma(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::lgamma(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if(!abs || arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs&((1<<(25-(abs>>10)))-1)))) + return half(detail::binary, detail::pole()); + if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, detail::gamma(arg.data_)); + #endif + } + + /// Gamma function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.25% of inputs. + /// + /// **See also:** Documentation for [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). + /// \param arg function argument + /// \return gamma function value of \a arg + /// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tgamma(half arg) + { + #if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::tgamma(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs&((1<<(25-(abs>>10)))-1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half(detail::binary, detail::underflow((1-((abs>>(25-(abs>>10)))&1))<<15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); + #endif + } + + /// \} + /// \anchor rounding + /// \name Rounding + /// \{ + + /// Nearest integer not less than half value. + /// **See also:** Documentation for [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). + /// \param arg half to round + /// \return nearest integer not less than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half ceil(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer not greater than half value. + /// **See also:** Documentation for [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). + /// \param arg half to round + /// \return nearest integer not greater than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half floor(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer not greater in magnitude than half value. + /// **See also:** Documentation for [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). + /// \param arg half to round + /// \return nearest integer not greater in magnitude than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half trunc(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer. + /// **See also:** Documentation for [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half round(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer. + /// **See also:** Documentation for [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID if value is not representable as `long` + inline long lround(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half rint(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID if value is not representable as `long` + /// \exception FE_INEXACT if value had to be rounded + inline long lrint(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID for signaling NaN + inline half nearbyint(half arg) { return half(detail::binary, detail::integral(arg.data_)); } +#if HALF_ENABLE_CPP11_LONG_LONG + /// Nearest integer. + /// **See also:** Documentation for [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID if value is not representable as `long long` + inline long long llround(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID if value is not representable as `long long` + /// \exception FE_INEXACT if value had to be rounded + inline long long llrint(half arg) { return detail::half2int(arg.data_); } +#endif + + /// \} + /// \anchor float + /// \name Floating point manipulation + /// \{ + + /// Decompress floating-point number. + /// **See also:** Documentation for [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return significant in range [0.5, 1) + /// \exception FE_INVALID for signaling NaN + inline half frexp(half arg, int *exp) + { + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00 || !abs) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--*exp) ; + *exp += (abs>>10) - 14; + return half(detail::binary, (arg.data_&0x8000)|0x3800|(abs&0x3FF)); + } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half scalbln(half arg, long exp) + { + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00 || !abs) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + else if(exp > 0) + return half(detail::binary, sign|(exp<<10)|(abs&0x3FF)); + unsigned int m = (abs&0x3FF) | 0x400; + return half(detail::binary, detail::rounded(sign|(m>>(1-exp)), (m>>-exp)&1, (m&((1<<-exp)-1))!=0)); + } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + + /// Extract integer and fractional parts. + /// **See also:** Documentation for [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part + /// \exception FE_INVALID for signaling NaN + inline half modf(half arg, half *iptr) + { + unsigned int abs = arg.data_ & 0x7FFF; + if(abs > 0x7C00) + { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if(abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_&0x8000); + if(abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1<<(25-exp)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(detail::binary, arg.data_&0x8000); + for(; m<0x400; m<<=1,--exp) ; + return half(detail::binary, (arg.data_&0x8000)|(exp<<10)|(m&0x3FF)); + } + + /// Extract exponent. + /// **See also:** Documentation for [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). + /// \param arg number to query + /// \return floating-point exponent + /// \retval FP_ILOGB0 for zero + /// \retval FP_ILOGBNAN for NaN + /// \retval INT_MAX for infinity + /// \exception FE_INVALID for 0 or infinite values + inline int ilogb(half arg) + { + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs==0x7C00) ? INT_MAX : FP_ILOGBNAN; + } + for(exp=(abs>>10)-15; abs<0x200; abs<<=1,--exp) ; + return exp; + } + + /// Extract exponent. + /// **See also:** Documentation for [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). + /// \param arg number to query + /// \return floating-point exponent + /// \exception FE_INVALID for signaling NaN + /// \exception FE_DIVBYZERO for 0 + inline half logb(half arg) + { + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp=(abs>>10)-15; abs<0x200; abs<<=1,--exp) ; + unsigned int value = static_cast(exp<0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6; + for(exp=18; m<0x400; m<<=1,--exp) ; + value |= (exp<<10) + m; + } + return half(detail::binary, value); + } + + /// Next representable value. + /// **See also:** Documentation for [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW for infinite result from finite argument + /// \exception FE_UNDERFLOW for subnormal result + inline half nextafter(half from, half to) + { + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs|tabs)) + return to; + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_&0x8000)+1); + } + unsigned int out = from.data_ + (((from.data_>>15)^static_cast( + (from.data_^(0x8000|(0x8000-(from.data_>>15))))<(to.data_^(0x8000|(0x8000-(to.data_>>15))))))<<1) - 1; + detail::raise(FE_OVERFLOW, fabs<0x7C00 && (out&0x7C00)==0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out&0x7C00)<0x400); + return half(detail::binary, out); + } + + /// Next representable value. + /// **See also:** Documentation for [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW for infinite result from finite argument + /// \exception FE_UNDERFLOW for subnormal result + inline half nexttoward(half from, long double to) + { + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast(detail::builtin_signbit(to))<<15)+1); + } + unsigned int out = from.data_ + (((from.data_>>15)^static_cast(lfrom 0x7C00; } + + /// Check if normal number. + /// **See also:** Documentation for [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). + /// \param arg number to check + /// \retval true if normal number + /// \retval false if either subnormal, zero, infinity or NaN + inline HALF_CONSTEXPR bool isnormal(half arg) { return ((arg.data_&0x7C00)!=0) & ((arg.data_&0x7C00)!=0x7C00); } + + /// Check sign. + /// **See also:** Documentation for [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). + /// \param arg number to check + /// \retval true for negative number + /// \retval false for positive number + inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_&0x8000) != 0; } + + /// \} + /// \anchor compfunc + /// \name Comparison + /// \{ + + /// Quiet comparison for greater than. + /// **See also:** Documentation for [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + inline HALF_CONSTEXPR bool isgreater(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) > ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for greater equal. + /// **See also:** Documentation for [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) >= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for less than. + /// **See also:** Documentation for [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + inline HALF_CONSTEXPR bool isless(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) < ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for less equal. + /// **See also:** Documentation for [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + inline HALF_CONSTEXPR bool islessequal(half x, half y) + { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) <= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comarison for less or greater. + /// **See also:** Documentation for [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). + /// \param x first operand + /// \param y second operand + /// \retval true if either less or greater + /// \retval false else + inline HALF_CONSTEXPR bool islessgreater(half x, half y) + { + return x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF) && !isnan(x) && !isnan(y); + } + + /// Quiet check if unordered. + /// **See also:** Documentation for [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). + /// \param x first operand + /// \param y second operand + /// \retval true if unordered (one or two NaN operands) + /// \retval false else + inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + + /// \} + /// \anchor casting + /// \name Casting + /// \{ + + /// Cast to or from half-precision floating-point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the default rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s returns the argument unmodified. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + /// \exception FE_INVALID if \a T is integer type and result is not representable as \a T + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template T half_cast(U arg) { return detail::half_caster::cast(arg); } + + /// Cast to or from half-precision floating-point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the specified rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s returns the argument unmodified. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam R rounding mode to use. + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + /// \exception FE_INVALID if \a T is integer type and result is not representable as \a T + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template T half_cast(U arg) { return detail::half_caster::cast(arg); } + /// \} + + /// \} + /// \anchor errors + /// \name Error handling + /// \{ + + /// Clear exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). + /// \param excepts OR of exceptions to clear + /// \retval 0 all selected flags cleared successfully + inline int feclearexcept(int excepts) { detail::errflags() &= ~excepts; return 0; } + + /// Test exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). + /// \param excepts OR of exceptions to test + /// \return OR of selected exceptions if raised + inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + + /// Raise exception flags. + /// This raises the specified floating point exceptions and also invokes any additional automatic exception handling as + /// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). + /// \param excepts OR of exceptions to raise + /// \retval 0 all selected exceptions raised successfully + inline int feraiseexcept(int excepts) { detail::errflags() |= excepts; detail::raise(excepts); return 0; } + + /// Save exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). + /// \param flagp adress to store flag state at + /// \param excepts OR of flags to save + /// \retval 0 for success + inline int fegetexceptflag(int *flagp, int excepts) { *flagp = detail::errflags() & excepts; return 0; } + + /// Restore exception flags. + /// This only copies the specified exception state (including unset flags) without incurring any additional exception handling. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). + /// \param flagp adress to take flag state from + /// \param excepts OR of flags to restore + /// \retval 0 for success + inline int fesetexceptflag(const int *flagp, int excepts) { detail::errflags() = (detail::errflags()|(*flagp&excepts)) & (*flagp|~excepts); return 0; } + + /// Throw C++ exceptions based on set exception flags. + /// This function manually throws a corresponding C++ exception if one of the specified flags is set, + /// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// \param excepts OR of exceptions to test + /// \param msg error message to use for exception description + /// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set + /// \throw std::overflow_error if `FE_OVERFLOW` is selected and set + /// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set + /// \throw std::range_error if `FE_INEXACT` is selected and set + inline void fethrowexcept(int excepts, const char *msg = "") + { + excepts &= detail::errflags(); + if(excepts & (FE_INVALID|FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); + } + /// \} +} + + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS + #pragma warning(pop) + #undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/pch.cpp b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/pch.cpp deleted file mode 100644 index 97b544ec1..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/pch.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// -// pch.cpp -// Include the standard header and generate the precompiled header. -// - -#include "pch.h" diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/resource.h b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/resource.h deleted file mode 100644 index 0bf240968..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/resource.h +++ /dev/null @@ -1,14 +0,0 @@ -//{{NO_DEPENDENCIES}} -// Microsoft Visual C++ generated include file. -// Used by nnf_xbox_example.rc - -// Next default values for new objects -// -#ifdef APSTUDIO_INVOKED -#ifndef APSTUDIO_READONLY_SYMBOLS -#define _APS_NEXT_RESOURCE_VALUE 101 -#define _APS_NEXT_COMMAND_VALUE 40001 -#define _APS_NEXT_CONTROL_VALUE 1001 -#define _APS_NEXT_SYMED_VALUE 101 -#endif -#endif diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj index 0d759c310..cf5f02404 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj @@ -27,18 +27,9 @@ - - - - - - - - - @@ -196,11 +187,12 @@ - $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + $(Console_Libs);antares.lib;%(XboxExtensionsDependencies);%(AdditionalDependencies) true Windows true true + ..\$(IntDir);%(AdditionalLibraryDirectories) NotUsing @@ -211,6 +203,7 @@ true true true + ..\antares @@ -234,11 +227,12 @@ - $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + $(Console_Libs);antares.lib;%(XboxExtensionsDependencies);%(AdditionalDependencies) true Windows true true + ..\$(IntDir);%(AdditionalLibraryDirectories) NotUsing @@ -249,6 +243,7 @@ true true true + ..\antares @@ -269,9 +264,10 @@ - $(Console_Libs);%(XboxExtensionsDependencies);%(AdditionalDependencies) + $(Console_Libs);antares.lib;%(XboxExtensionsDependencies);%(AdditionalDependencies) Windows true + ..\$(IntDir);%(AdditionalLibraryDirectories) pch.h @@ -281,6 +277,7 @@ Disabled _DEBUG;RUNTIME_EXPORTS;_USRDLL;%(PreprocessorDefinitions) true + ..\antares diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj.filters b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj.filters index e48812cf7..dd018aaf9 100644 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj.filters +++ b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/runtime.vcxproj.filters @@ -15,38 +15,11 @@ - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - Header Files - - Header Files - - - Source Files - - - Source Files - - - Source Files - Source Files diff --git a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/update.bat b/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/update.bat deleted file mode 100644 index aa139794d..000000000 --- a/src/tools/nnfusion/templates/dxcompute/Direct3DXBoxNN/runtime/update.bat +++ /dev/null @@ -1,13 +0,0 @@ -@echo off - -echo update D3D12APIWrapper.h -curl -LOs https://github.com/microsoft/antares/blob/master/backends/c-hlsl/evaluator/AntaresHlslLib/D3D12APIWrapper.h - -echo update D3D12APIWrapper.cpp -curl -LOs https://github.com/microsoft/antares/blob/master/backends/c-hlsl/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp - -echo update D3D12Antares.h -curl -LOs https://github.com/microsoft/antares/blob/master/backends/c-hlsl/evaluator/AntaresHlslLib/D3D12Antares.h - -echo finished! -pause diff --git a/src/tools/nnfusion/templates/dxcompute/DxCompute.vcxproj b/src/tools/nnfusion/templates/dxcompute/DxCompute.vcxproj deleted file mode 100644 index 23b8799e2..000000000 --- a/src/tools/nnfusion/templates/dxcompute/DxCompute.vcxproj +++ /dev/null @@ -1,65 +0,0 @@ - - - - - Release - x64 - - - - 16.0 - {38E1DCE3-5652-4456-BDD2-DE0CD06805AC} - Win32Proj - DxCompute - 10.0 - - - - Application - false - v142 - true - Unicode - - - - - - - - - - - - false - - - - - - Level3 - true - true - true - NDEBUG;_CONSOLE;%(PreprocessorDefinitions) - true - - - Console - true - true - true - - - - - - - - - - - - - - diff --git a/src/tools/nnfusion/templates/dxcompute/d3dx12_helper.h b/src/tools/nnfusion/templates/dxcompute/d3dx12_helper.h deleted file mode 100644 index 38413d4e5..000000000 --- a/src/tools/nnfusion/templates/dxcompute/d3dx12_helper.h +++ /dev/null @@ -1,3637 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#define _CRT_RAND_S -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#pragma once - -#pragma comment(lib, "d3d12.lib") -#pragma comment(lib, "dxgi.lib") -#pragma comment(lib, "d3dcompiler.lib") - -using namespace std; -using namespace Microsoft::WRL; - - -#define IFE(x) ((FAILED(x)) ? (printf("Error-line: (%s) %d\n", __FILE__, __LINE__), _exit(1), 0): 1) - - -inline void IFB(HRESULT hr) -{ - if (FAILED(hr)) - { - DebugBreak(); - } -} - -inline const D3D12_COMMAND_QUEUE_DESC D3D12CommandQueueDesc(D3D12_COMMAND_LIST_TYPE type, D3D12_COMMAND_QUEUE_FLAGS flags = D3D12_COMMAND_QUEUE_FLAG_NONE, UINT nodeMask = 0, INT priority = 0) -{ - D3D12_COMMAND_QUEUE_DESC desc = { - type, - priority, - flags, - nodeMask - }; - return desc; -} - -inline const D3D12_HEAP_PROPERTIES D3D12HeapProperties( - D3D12_HEAP_TYPE heapType, - D3D12_CPU_PAGE_PROPERTY pageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN, - D3D12_MEMORY_POOL memoryPoolType = D3D12_MEMORY_POOL_UNKNOWN, - UINT creationNodeMask = 0, - UINT visibleNodeMask = 0 -) -{ - D3D12_HEAP_PROPERTIES heapProperties = { - heapType, - pageProperty, - memoryPoolType, - creationNodeMask, - visibleNodeMask - }; - return heapProperties; -} - -inline const D3D12_RESOURCE_DESC D3D12BufferResourceDesc( - UINT64 width, - UINT height = 1, - D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, - UINT64 alignment = 0 -) -{ - D3D12_RESOURCE_DESC resourceDesc = { - D3D12_RESOURCE_DIMENSION_BUFFER, - alignment, - width, - height, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - {1, 0}, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - flags - }; - - return resourceDesc; -} - -struct D3DDevice -{ - ComPtr pDxgiFactory; - ComPtr pDevice; - ComPtr pCommandQueue; - ComPtr pCommandAllocator; - ComPtr pFence; - HANDLE event; - uint64_t fenceValue = 0; - bool bEnableDebugLayer = false; - bool bEnableGPUValidation = false; - D3DDevice(bool EnableDebugLayer = false, bool EnableGPUValidation = false) - { - bEnableDebugLayer = EnableDebugLayer; - bEnableGPUValidation = EnableGPUValidation; - } - void InitD3DDevice() - { - IFE(CreateDXGIFactory1(IID_PPV_ARGS(&pDxgiFactory))); - - if (D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&pDevice))) - { - ComPtr pAdapter; - IFE(pDxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&pAdapter))); - IFE(D3D12CreateDevice(pAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&pDevice))); - } - } - - void Init() - { - // Enable debug layer - ComPtr pDebug; - if (bEnableDebugLayer && SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&pDebug)))) - { - pDebug->EnableDebugLayer(); - - ComPtr pDebug1; - if (bEnableGPUValidation && SUCCEEDED((pDebug->QueryInterface(IID_PPV_ARGS(&pDebug1))))) - { - pDebug1->SetEnableGPUBasedValidation(true); - } - } - - InitD3DDevice(); - - // Create a command queue - D3D12_COMMAND_QUEUE_DESC commandQueueDesc = D3D12CommandQueueDesc(D3D12_COMMAND_LIST_TYPE_COMPUTE); - IFE(pDevice->CreateCommandQueue(&commandQueueDesc, IID_PPV_ARGS(&pCommandQueue))); - - // Create a command allocator - IFE(pDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&pCommandAllocator))); - - // Create a CPU-GPU synchronization event - event = CreateEvent(nullptr, FALSE, FALSE, nullptr); - - // Create a fence to allow GPU to signal upon completion of execution - IFE(pDevice->CreateFence(fenceValue, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&pFence))); - } - - void AwaitExecution() - { - ++fenceValue; - IFE(pCommandQueue->Signal(pFence.Get(), fenceValue)); - - IFE(pFence->SetEventOnCompletion(fenceValue, event)); - - DWORD retVal = WaitForSingleObject(event, INFINITE); - if (retVal != WAIT_OBJECT_0) - { - DebugBreak(); - } - } - - inline void CreateCommittedResource( - const D3D12_HEAP_PROPERTIES& heapProperties, - const D3D12_RESOURCE_DESC& resourceDesc, - D3D12_RESOURCE_STATES initialState, - ID3D12Resource** ppResource - ) - { - IFE(pDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - initialState, - nullptr, - IID_PPV_ARGS(ppResource) - )); - } - inline void CreateGPUOnlyResource(UINT64 size, ID3D12Resource** ppResource) - { - CreateCommittedResource( - D3D12HeapProperties(D3D12_HEAP_TYPE_DEFAULT), - D3D12BufferResourceDesc(size, 1, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS), - D3D12_RESOURCE_STATE_COMMON, - ppResource - ); - } - inline void CreateUploadBuffer(UINT64 size, ID3D12Resource** ppResource) - { - CreateCommittedResource( - D3D12HeapProperties(D3D12_HEAP_TYPE_UPLOAD), - D3D12BufferResourceDesc(size), - D3D12_RESOURCE_STATE_GENERIC_READ, - ppResource - ); - } - - inline void CreateReadbackBuffer(UINT64 size, ID3D12Resource** ppResource) - { - CreateCommittedResource( - D3D12HeapProperties(D3D12_HEAP_TYPE_READBACK), - D3D12BufferResourceDesc(size), - D3D12_RESOURCE_STATE_COPY_DEST, - ppResource - ); - } - - inline void CreateDefaultBuffer(UINT64 size, ID3D12Resource** ppResource) - { - CreateCommittedResource( - D3D12HeapProperties(D3D12_HEAP_TYPE_DEFAULT), - D3D12BufferResourceDesc(size, 1, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS), - D3D12_RESOURCE_STATE_COMMON, - ppResource - ); - } - - void MapAndCopyToResource(ID3D12Resource* pResource, const void* pSrc, UINT64 numBytes) - { - D3D12_RANGE range = { 0, static_cast(numBytes) }; - void* pData; - IFE(pResource->Map(0, &range, reinterpret_cast(&pData))); - memcpy(pData, pSrc, static_cast(numBytes)); - pResource->Unmap(0, &range); - } - - void MapCopyFromResource(ID3D12Resource* pResource, void* pDest, UINT64 numBytes) - { - D3D12_RANGE range = { 0, static_cast(numBytes) }; - void* pData; - IFE(pResource->Map(0, &range, reinterpret_cast(&pData))); - memcpy(pDest, pData, static_cast(numBytes)); - range.End = 0; - pResource->Unmap(0, &range); - } -}; - - -//********************************************************* -// -// Copyright (c) Microsoft. All rights reserved. -// This code is licensed under the MIT License (MIT). -// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF -// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY -// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR -// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. -// -//********************************************************* - -#ifndef __D3DX12_H__ -#define __D3DX12_H__ - -#include "d3d12.h" - -#if defined( __cplusplus ) - -struct CD3DX12_DEFAULT {}; -extern const DECLSPEC_SELECTANY CD3DX12_DEFAULT D3D12_DEFAULT; - -inline bool operator==(const D3D12_VIEWPORT& l, const D3D12_VIEWPORT& r) -{ - return l.TopLeftX == r.TopLeftX && l.TopLeftY == r.TopLeftY && l.Width == r.Width && - l.Height == r.Height && l.MinDepth == r.MinDepth && l.MaxDepth == r.MaxDepth; -} - -inline bool operator!=(const D3D12_VIEWPORT& l, const D3D12_VIEWPORT& r) -{ - return !(l == r); -} - -struct CD3DX12_RECT : public D3D12_RECT -{ - CD3DX12_RECT() = default; - explicit CD3DX12_RECT(const D3D12_RECT& o) : - D3D12_RECT(o) - {} - explicit CD3DX12_RECT( - LONG Left, - LONG Top, - LONG Right, - LONG Bottom) - { - left = Left; - top = Top; - right = Right; - bottom = Bottom; - } -}; - -struct CD3DX12_VIEWPORT : public D3D12_VIEWPORT -{ - CD3DX12_VIEWPORT() = default; - explicit CD3DX12_VIEWPORT(const D3D12_VIEWPORT& o) : - D3D12_VIEWPORT(o) - {} - explicit CD3DX12_VIEWPORT( - FLOAT topLeftX, - FLOAT topLeftY, - FLOAT width, - FLOAT height, - FLOAT minDepth = D3D12_MIN_DEPTH, - FLOAT maxDepth = D3D12_MAX_DEPTH) - { - TopLeftX = topLeftX; - TopLeftY = topLeftY; - Width = width; - Height = height; - MinDepth = minDepth; - MaxDepth = maxDepth; - } - explicit CD3DX12_VIEWPORT( - _In_ ID3D12Resource* pResource, - UINT mipSlice = 0, - FLOAT topLeftX = 0.0f, - FLOAT topLeftY = 0.0f, - FLOAT minDepth = D3D12_MIN_DEPTH, - FLOAT maxDepth = D3D12_MAX_DEPTH) - { - auto Desc = pResource->GetDesc(); - const UINT64 SubresourceWidth = Desc.Width >> mipSlice; - const UINT64 SubresourceHeight = Desc.Height >> mipSlice; - switch (Desc.Dimension) - { - case D3D12_RESOURCE_DIMENSION_BUFFER: - TopLeftX = topLeftX; - TopLeftY = 0.0f; - Width = Desc.Width - topLeftX; - Height = 1.0f; - break; - case D3D12_RESOURCE_DIMENSION_TEXTURE1D: - TopLeftX = topLeftX; - TopLeftY = 0.0f; - Width = (SubresourceWidth ? SubresourceWidth : 1.0f) - topLeftX; - Height = 1.0f; - break; - case D3D12_RESOURCE_DIMENSION_TEXTURE2D: - case D3D12_RESOURCE_DIMENSION_TEXTURE3D: - TopLeftX = topLeftX; - TopLeftY = topLeftY; - Width = (SubresourceWidth ? SubresourceWidth : 1.0f) - topLeftX; - Height = (SubresourceHeight ? SubresourceHeight : 1.0f) - topLeftY; - break; - default: break; - } - - MinDepth = minDepth; - MaxDepth = maxDepth; - } -}; - -struct CD3DX12_BOX : public D3D12_BOX -{ - CD3DX12_BOX() = default; - explicit CD3DX12_BOX(const D3D12_BOX& o) : - D3D12_BOX(o) - {} - explicit CD3DX12_BOX( - LONG Left, - LONG Right) - { - left = static_cast(Left); - top = 0; - front = 0; - right = static_cast(Right); - bottom = 1; - back = 1; - } - explicit CD3DX12_BOX( - LONG Left, - LONG Top, - LONG Right, - LONG Bottom) - { - left = static_cast(Left); - top = static_cast(Top); - front = 0; - right = static_cast(Right); - bottom = static_cast(Bottom); - back = 1; - } - explicit CD3DX12_BOX( - LONG Left, - LONG Top, - LONG Front, - LONG Right, - LONG Bottom, - LONG Back) - { - left = static_cast(Left); - top = static_cast(Top); - front = static_cast(Front); - right = static_cast(Right); - bottom = static_cast(Bottom); - back = static_cast(Back); - } -}; -inline bool operator==(const D3D12_BOX& l, const D3D12_BOX& r) -{ - return l.left == r.left && l.top == r.top && l.front == r.front && - l.right == r.right && l.bottom == r.bottom && l.back == r.back; -} -inline bool operator!=(const D3D12_BOX& l, const D3D12_BOX& r) -{ - return !(l == r); -} - -struct CD3DX12_DEPTH_STENCIL_DESC : public D3D12_DEPTH_STENCIL_DESC -{ - CD3DX12_DEPTH_STENCIL_DESC() = default; - explicit CD3DX12_DEPTH_STENCIL_DESC(const D3D12_DEPTH_STENCIL_DESC& o) : - D3D12_DEPTH_STENCIL_DESC(o) - {} - explicit CD3DX12_DEPTH_STENCIL_DESC(CD3DX12_DEFAULT) - { - DepthEnable = TRUE; - DepthWriteMask = D3D12_DEPTH_WRITE_MASK_ALL; - DepthFunc = D3D12_COMPARISON_FUNC_LESS; - StencilEnable = FALSE; - StencilReadMask = D3D12_DEFAULT_STENCIL_READ_MASK; - StencilWriteMask = D3D12_DEFAULT_STENCIL_WRITE_MASK; - const D3D12_DEPTH_STENCILOP_DESC defaultStencilOp = - { D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_COMPARISON_FUNC_ALWAYS }; - FrontFace = defaultStencilOp; - BackFace = defaultStencilOp; - } - explicit CD3DX12_DEPTH_STENCIL_DESC( - BOOL depthEnable, - D3D12_DEPTH_WRITE_MASK depthWriteMask, - D3D12_COMPARISON_FUNC depthFunc, - BOOL stencilEnable, - UINT8 stencilReadMask, - UINT8 stencilWriteMask, - D3D12_STENCIL_OP frontStencilFailOp, - D3D12_STENCIL_OP frontStencilDepthFailOp, - D3D12_STENCIL_OP frontStencilPassOp, - D3D12_COMPARISON_FUNC frontStencilFunc, - D3D12_STENCIL_OP backStencilFailOp, - D3D12_STENCIL_OP backStencilDepthFailOp, - D3D12_STENCIL_OP backStencilPassOp, - D3D12_COMPARISON_FUNC backStencilFunc) - { - DepthEnable = depthEnable; - DepthWriteMask = depthWriteMask; - DepthFunc = depthFunc; - StencilEnable = stencilEnable; - StencilReadMask = stencilReadMask; - StencilWriteMask = stencilWriteMask; - FrontFace.StencilFailOp = frontStencilFailOp; - FrontFace.StencilDepthFailOp = frontStencilDepthFailOp; - FrontFace.StencilPassOp = frontStencilPassOp; - FrontFace.StencilFunc = frontStencilFunc; - BackFace.StencilFailOp = backStencilFailOp; - BackFace.StencilDepthFailOp = backStencilDepthFailOp; - BackFace.StencilPassOp = backStencilPassOp; - BackFace.StencilFunc = backStencilFunc; - } -}; - -struct CD3DX12_DEPTH_STENCIL_DESC1 : public D3D12_DEPTH_STENCIL_DESC1 -{ - CD3DX12_DEPTH_STENCIL_DESC1() = default; - explicit CD3DX12_DEPTH_STENCIL_DESC1(const D3D12_DEPTH_STENCIL_DESC1& o) : - D3D12_DEPTH_STENCIL_DESC1(o) - {} - explicit CD3DX12_DEPTH_STENCIL_DESC1(const D3D12_DEPTH_STENCIL_DESC& o) - { - DepthEnable = o.DepthEnable; - DepthWriteMask = o.DepthWriteMask; - DepthFunc = o.DepthFunc; - StencilEnable = o.StencilEnable; - StencilReadMask = o.StencilReadMask; - StencilWriteMask = o.StencilWriteMask; - FrontFace.StencilFailOp = o.FrontFace.StencilFailOp; - FrontFace.StencilDepthFailOp = o.FrontFace.StencilDepthFailOp; - FrontFace.StencilPassOp = o.FrontFace.StencilPassOp; - FrontFace.StencilFunc = o.FrontFace.StencilFunc; - BackFace.StencilFailOp = o.BackFace.StencilFailOp; - BackFace.StencilDepthFailOp = o.BackFace.StencilDepthFailOp; - BackFace.StencilPassOp = o.BackFace.StencilPassOp; - BackFace.StencilFunc = o.BackFace.StencilFunc; - DepthBoundsTestEnable = FALSE; - } - explicit CD3DX12_DEPTH_STENCIL_DESC1(CD3DX12_DEFAULT) - { - DepthEnable = TRUE; - DepthWriteMask = D3D12_DEPTH_WRITE_MASK_ALL; - DepthFunc = D3D12_COMPARISON_FUNC_LESS; - StencilEnable = FALSE; - StencilReadMask = D3D12_DEFAULT_STENCIL_READ_MASK; - StencilWriteMask = D3D12_DEFAULT_STENCIL_WRITE_MASK; - const D3D12_DEPTH_STENCILOP_DESC defaultStencilOp = - { D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_COMPARISON_FUNC_ALWAYS }; - FrontFace = defaultStencilOp; - BackFace = defaultStencilOp; - DepthBoundsTestEnable = FALSE; - } - explicit CD3DX12_DEPTH_STENCIL_DESC1( - BOOL depthEnable, - D3D12_DEPTH_WRITE_MASK depthWriteMask, - D3D12_COMPARISON_FUNC depthFunc, - BOOL stencilEnable, - UINT8 stencilReadMask, - UINT8 stencilWriteMask, - D3D12_STENCIL_OP frontStencilFailOp, - D3D12_STENCIL_OP frontStencilDepthFailOp, - D3D12_STENCIL_OP frontStencilPassOp, - D3D12_COMPARISON_FUNC frontStencilFunc, - D3D12_STENCIL_OP backStencilFailOp, - D3D12_STENCIL_OP backStencilDepthFailOp, - D3D12_STENCIL_OP backStencilPassOp, - D3D12_COMPARISON_FUNC backStencilFunc, - BOOL depthBoundsTestEnable) - { - DepthEnable = depthEnable; - DepthWriteMask = depthWriteMask; - DepthFunc = depthFunc; - StencilEnable = stencilEnable; - StencilReadMask = stencilReadMask; - StencilWriteMask = stencilWriteMask; - FrontFace.StencilFailOp = frontStencilFailOp; - FrontFace.StencilDepthFailOp = frontStencilDepthFailOp; - FrontFace.StencilPassOp = frontStencilPassOp; - FrontFace.StencilFunc = frontStencilFunc; - BackFace.StencilFailOp = backStencilFailOp; - BackFace.StencilDepthFailOp = backStencilDepthFailOp; - BackFace.StencilPassOp = backStencilPassOp; - BackFace.StencilFunc = backStencilFunc; - DepthBoundsTestEnable = depthBoundsTestEnable; - } - operator D3D12_DEPTH_STENCIL_DESC() const - { - D3D12_DEPTH_STENCIL_DESC D; - D.DepthEnable = DepthEnable; - D.DepthWriteMask = DepthWriteMask; - D.DepthFunc = DepthFunc; - D.StencilEnable = StencilEnable; - D.StencilReadMask = StencilReadMask; - D.StencilWriteMask = StencilWriteMask; - D.FrontFace.StencilFailOp = FrontFace.StencilFailOp; - D.FrontFace.StencilDepthFailOp = FrontFace.StencilDepthFailOp; - D.FrontFace.StencilPassOp = FrontFace.StencilPassOp; - D.FrontFace.StencilFunc = FrontFace.StencilFunc; - D.BackFace.StencilFailOp = BackFace.StencilFailOp; - D.BackFace.StencilDepthFailOp = BackFace.StencilDepthFailOp; - D.BackFace.StencilPassOp = BackFace.StencilPassOp; - D.BackFace.StencilFunc = BackFace.StencilFunc; - return D; - } -}; - -struct CD3DX12_BLEND_DESC : public D3D12_BLEND_DESC -{ - CD3DX12_BLEND_DESC() = default; - explicit CD3DX12_BLEND_DESC(const D3D12_BLEND_DESC& o) : - D3D12_BLEND_DESC(o) - {} - explicit CD3DX12_BLEND_DESC(CD3DX12_DEFAULT) - { - AlphaToCoverageEnable = FALSE; - IndependentBlendEnable = FALSE; - const D3D12_RENDER_TARGET_BLEND_DESC defaultRenderTargetBlendDesc = - { - FALSE,FALSE, - D3D12_BLEND_ONE, D3D12_BLEND_ZERO, D3D12_BLEND_OP_ADD, - D3D12_BLEND_ONE, D3D12_BLEND_ZERO, D3D12_BLEND_OP_ADD, - D3D12_LOGIC_OP_NOOP, - D3D12_COLOR_WRITE_ENABLE_ALL, - }; - for (UINT i = 0; i < D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT; ++i) - RenderTarget[i] = defaultRenderTargetBlendDesc; - } -}; - -struct CD3DX12_RASTERIZER_DESC : public D3D12_RASTERIZER_DESC -{ - CD3DX12_RASTERIZER_DESC() = default; - explicit CD3DX12_RASTERIZER_DESC(const D3D12_RASTERIZER_DESC& o) : - D3D12_RASTERIZER_DESC(o) - {} - explicit CD3DX12_RASTERIZER_DESC(CD3DX12_DEFAULT) - { - FillMode = D3D12_FILL_MODE_SOLID; - CullMode = D3D12_CULL_MODE_BACK; - FrontCounterClockwise = FALSE; - DepthBias = D3D12_DEFAULT_DEPTH_BIAS; - DepthBiasClamp = D3D12_DEFAULT_DEPTH_BIAS_CLAMP; - SlopeScaledDepthBias = D3D12_DEFAULT_SLOPE_SCALED_DEPTH_BIAS; - DepthClipEnable = TRUE; - MultisampleEnable = FALSE; - AntialiasedLineEnable = FALSE; - ForcedSampleCount = 0; - ConservativeRaster = D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF; - } - explicit CD3DX12_RASTERIZER_DESC( - D3D12_FILL_MODE fillMode, - D3D12_CULL_MODE cullMode, - BOOL frontCounterClockwise, - INT depthBias, - FLOAT depthBiasClamp, - FLOAT slopeScaledDepthBias, - BOOL depthClipEnable, - BOOL multisampleEnable, - BOOL antialiasedLineEnable, - UINT forcedSampleCount, - D3D12_CONSERVATIVE_RASTERIZATION_MODE conservativeRaster) - { - FillMode = fillMode; - CullMode = cullMode; - FrontCounterClockwise = frontCounterClockwise; - DepthBias = depthBias; - DepthBiasClamp = depthBiasClamp; - SlopeScaledDepthBias = slopeScaledDepthBias; - DepthClipEnable = depthClipEnable; - MultisampleEnable = multisampleEnable; - AntialiasedLineEnable = antialiasedLineEnable; - ForcedSampleCount = forcedSampleCount; - ConservativeRaster = conservativeRaster; - } -}; - -struct CD3DX12_RESOURCE_ALLOCATION_INFO : public D3D12_RESOURCE_ALLOCATION_INFO -{ - CD3DX12_RESOURCE_ALLOCATION_INFO() = default; - explicit CD3DX12_RESOURCE_ALLOCATION_INFO(const D3D12_RESOURCE_ALLOCATION_INFO& o) : - D3D12_RESOURCE_ALLOCATION_INFO(o) - {} - CD3DX12_RESOURCE_ALLOCATION_INFO( - UINT64 size, - UINT64 alignment) - { - SizeInBytes = size; - Alignment = alignment; - } -}; - -struct CD3DX12_HEAP_PROPERTIES : public D3D12_HEAP_PROPERTIES -{ - CD3DX12_HEAP_PROPERTIES() = default; - explicit CD3DX12_HEAP_PROPERTIES(const D3D12_HEAP_PROPERTIES& o) : - D3D12_HEAP_PROPERTIES(o) - {} - CD3DX12_HEAP_PROPERTIES( - D3D12_CPU_PAGE_PROPERTY cpuPageProperty, - D3D12_MEMORY_POOL memoryPoolPreference, - UINT creationNodeMask = 1, - UINT nodeMask = 1) - { - Type = D3D12_HEAP_TYPE_CUSTOM; - CPUPageProperty = cpuPageProperty; - MemoryPoolPreference = memoryPoolPreference; - CreationNodeMask = creationNodeMask; - VisibleNodeMask = nodeMask; - } - explicit CD3DX12_HEAP_PROPERTIES( - D3D12_HEAP_TYPE type, - UINT creationNodeMask = 1, - UINT nodeMask = 1) - { - Type = type; - CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; - MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; - CreationNodeMask = creationNodeMask; - VisibleNodeMask = nodeMask; - } - bool IsCPUAccessible() const - { - return Type == D3D12_HEAP_TYPE_UPLOAD || Type == D3D12_HEAP_TYPE_READBACK || (Type == D3D12_HEAP_TYPE_CUSTOM && - (CPUPageProperty == D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE || CPUPageProperty == D3D12_CPU_PAGE_PROPERTY_WRITE_BACK)); - } -}; -inline bool operator==(const D3D12_HEAP_PROPERTIES& l, const D3D12_HEAP_PROPERTIES& r) -{ - return l.Type == r.Type && l.CPUPageProperty == r.CPUPageProperty && - l.MemoryPoolPreference == r.MemoryPoolPreference && - l.CreationNodeMask == r.CreationNodeMask && - l.VisibleNodeMask == r.VisibleNodeMask; -} -inline bool operator!=(const D3D12_HEAP_PROPERTIES& l, const D3D12_HEAP_PROPERTIES& r) -{ - return !(l == r); -} - -struct CD3DX12_HEAP_DESC : public D3D12_HEAP_DESC -{ - CD3DX12_HEAP_DESC() = default; - explicit CD3DX12_HEAP_DESC(const D3D12_HEAP_DESC& o) : - D3D12_HEAP_DESC(o) - {} - CD3DX12_HEAP_DESC( - UINT64 size, - D3D12_HEAP_PROPERTIES properties, - UINT64 alignment = 0, - D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE) - { - SizeInBytes = size; - Properties = properties; - Alignment = alignment; - Flags = flags; - } - CD3DX12_HEAP_DESC( - UINT64 size, - D3D12_HEAP_TYPE type, - UINT64 alignment = 0, - D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE) - { - SizeInBytes = size; - Properties = CD3DX12_HEAP_PROPERTIES(type); - Alignment = alignment; - Flags = flags; - } - CD3DX12_HEAP_DESC( - UINT64 size, - D3D12_CPU_PAGE_PROPERTY cpuPageProperty, - D3D12_MEMORY_POOL memoryPoolPreference, - UINT64 alignment = 0, - D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE) - { - SizeInBytes = size; - Properties = CD3DX12_HEAP_PROPERTIES(cpuPageProperty, memoryPoolPreference); - Alignment = alignment; - Flags = flags; - } - CD3DX12_HEAP_DESC( - const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, - D3D12_HEAP_PROPERTIES properties, - D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE) - { - SizeInBytes = resAllocInfo.SizeInBytes; - Properties = properties; - Alignment = resAllocInfo.Alignment; - Flags = flags; - } - CD3DX12_HEAP_DESC( - const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, - D3D12_HEAP_TYPE type, - D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE) - { - SizeInBytes = resAllocInfo.SizeInBytes; - Properties = CD3DX12_HEAP_PROPERTIES(type); - Alignment = resAllocInfo.Alignment; - Flags = flags; - } - CD3DX12_HEAP_DESC( - const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, - D3D12_CPU_PAGE_PROPERTY cpuPageProperty, - D3D12_MEMORY_POOL memoryPoolPreference, - D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE) - { - SizeInBytes = resAllocInfo.SizeInBytes; - Properties = CD3DX12_HEAP_PROPERTIES(cpuPageProperty, memoryPoolPreference); - Alignment = resAllocInfo.Alignment; - Flags = flags; - } - bool IsCPUAccessible() const - { - return static_cast(&Properties)->IsCPUAccessible(); - } -}; -inline bool operator==(const D3D12_HEAP_DESC& l, const D3D12_HEAP_DESC& r) -{ - return l.SizeInBytes == r.SizeInBytes && - l.Properties == r.Properties && - l.Alignment == r.Alignment && - l.Flags == r.Flags; -} -inline bool operator!=(const D3D12_HEAP_DESC& l, const D3D12_HEAP_DESC& r) -{ - return !(l == r); -} - -struct CD3DX12_CLEAR_VALUE : public D3D12_CLEAR_VALUE -{ - CD3DX12_CLEAR_VALUE() = default; - explicit CD3DX12_CLEAR_VALUE(const D3D12_CLEAR_VALUE& o) : - D3D12_CLEAR_VALUE(o) - {} - CD3DX12_CLEAR_VALUE( - DXGI_FORMAT format, - const FLOAT color[4]) - { - Format = format; - memcpy(Color, color, sizeof(Color)); - } - CD3DX12_CLEAR_VALUE( - DXGI_FORMAT format, - FLOAT depth, - UINT8 stencil) - { - Format = format; - memset(&Color, 0, sizeof(Color)); - /* Use memcpy to preserve NAN values */ - memcpy(&DepthStencil.Depth, &depth, sizeof(depth)); - DepthStencil.Stencil = stencil; - } -}; - -struct CD3DX12_RANGE : public D3D12_RANGE -{ - CD3DX12_RANGE() = default; - explicit CD3DX12_RANGE(const D3D12_RANGE& o) : - D3D12_RANGE(o) - {} - CD3DX12_RANGE( - SIZE_T begin, - SIZE_T end) - { - Begin = begin; - End = end; - } -}; - -struct CD3DX12_RANGE_UINT64 : public D3D12_RANGE_UINT64 -{ - CD3DX12_RANGE_UINT64() = default; - explicit CD3DX12_RANGE_UINT64(const D3D12_RANGE_UINT64& o) : - D3D12_RANGE_UINT64(o) - {} - CD3DX12_RANGE_UINT64( - UINT64 begin, - UINT64 end) - { - Begin = begin; - End = end; - } -}; - -struct CD3DX12_SUBRESOURCE_RANGE_UINT64 : public D3D12_SUBRESOURCE_RANGE_UINT64 -{ - CD3DX12_SUBRESOURCE_RANGE_UINT64() = default; - explicit CD3DX12_SUBRESOURCE_RANGE_UINT64(const D3D12_SUBRESOURCE_RANGE_UINT64& o) : - D3D12_SUBRESOURCE_RANGE_UINT64(o) - {} - CD3DX12_SUBRESOURCE_RANGE_UINT64( - UINT subresource, - const D3D12_RANGE_UINT64& range) - { - Subresource = subresource; - Range = range; - } - CD3DX12_SUBRESOURCE_RANGE_UINT64( - UINT subresource, - UINT64 begin, - UINT64 end) - { - Subresource = subresource; - Range.Begin = begin; - Range.End = end; - } -}; - -struct CD3DX12_SHADER_BYTECODE : public D3D12_SHADER_BYTECODE -{ - CD3DX12_SHADER_BYTECODE() = default; - explicit CD3DX12_SHADER_BYTECODE(const D3D12_SHADER_BYTECODE& o) : - D3D12_SHADER_BYTECODE(o) - {} - CD3DX12_SHADER_BYTECODE( - _In_ ID3DBlob* pShaderBlob) - { - pShaderBytecode = pShaderBlob->GetBufferPointer(); - BytecodeLength = pShaderBlob->GetBufferSize(); - } - CD3DX12_SHADER_BYTECODE( - const void* _pShaderBytecode, - SIZE_T bytecodeLength) - { - pShaderBytecode = _pShaderBytecode; - BytecodeLength = bytecodeLength; - } -}; - -struct CD3DX12_TILED_RESOURCE_COORDINATE : public D3D12_TILED_RESOURCE_COORDINATE -{ - CD3DX12_TILED_RESOURCE_COORDINATE() = default; - explicit CD3DX12_TILED_RESOURCE_COORDINATE(const D3D12_TILED_RESOURCE_COORDINATE& o) : - D3D12_TILED_RESOURCE_COORDINATE(o) - {} - CD3DX12_TILED_RESOURCE_COORDINATE( - UINT x, - UINT y, - UINT z, - UINT subresource) - { - X = x; - Y = y; - Z = z; - Subresource = subresource; - } -}; - -struct CD3DX12_TILE_REGION_SIZE : public D3D12_TILE_REGION_SIZE -{ - CD3DX12_TILE_REGION_SIZE() = default; - explicit CD3DX12_TILE_REGION_SIZE(const D3D12_TILE_REGION_SIZE& o) : - D3D12_TILE_REGION_SIZE(o) - {} - CD3DX12_TILE_REGION_SIZE( - UINT numTiles, - BOOL useBox, - UINT width, - UINT16 height, - UINT16 depth) - { - NumTiles = numTiles; - UseBox = useBox; - Width = width; - Height = height; - Depth = depth; - } -}; - -struct CD3DX12_SUBRESOURCE_TILING : public D3D12_SUBRESOURCE_TILING -{ - CD3DX12_SUBRESOURCE_TILING() = default; - explicit CD3DX12_SUBRESOURCE_TILING(const D3D12_SUBRESOURCE_TILING& o) : - D3D12_SUBRESOURCE_TILING(o) - {} - CD3DX12_SUBRESOURCE_TILING( - UINT widthInTiles, - UINT16 heightInTiles, - UINT16 depthInTiles, - UINT startTileIndexInOverallResource) - { - WidthInTiles = widthInTiles; - HeightInTiles = heightInTiles; - DepthInTiles = depthInTiles; - StartTileIndexInOverallResource = startTileIndexInOverallResource; - } -}; - -struct CD3DX12_TILE_SHAPE : public D3D12_TILE_SHAPE -{ - CD3DX12_TILE_SHAPE() = default; - explicit CD3DX12_TILE_SHAPE(const D3D12_TILE_SHAPE& o) : - D3D12_TILE_SHAPE(o) - {} - CD3DX12_TILE_SHAPE( - UINT widthInTexels, - UINT heightInTexels, - UINT depthInTexels) - { - WidthInTexels = widthInTexels; - HeightInTexels = heightInTexels; - DepthInTexels = depthInTexels; - } -}; - -struct CD3DX12_RESOURCE_BARRIER : public D3D12_RESOURCE_BARRIER -{ - CD3DX12_RESOURCE_BARRIER() = default; - explicit CD3DX12_RESOURCE_BARRIER(const D3D12_RESOURCE_BARRIER& o) : - D3D12_RESOURCE_BARRIER(o) - {} - static inline CD3DX12_RESOURCE_BARRIER Transition( - _In_ ID3D12Resource* pResource, - D3D12_RESOURCE_STATES stateBefore, - D3D12_RESOURCE_STATES stateAfter, - UINT subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES, - D3D12_RESOURCE_BARRIER_FLAGS flags = D3D12_RESOURCE_BARRIER_FLAG_NONE) - { - CD3DX12_RESOURCE_BARRIER result = {}; - D3D12_RESOURCE_BARRIER& barrier = result; - result.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION; - result.Flags = flags; - barrier.Transition.pResource = pResource; - barrier.Transition.StateBefore = stateBefore; - barrier.Transition.StateAfter = stateAfter; - barrier.Transition.Subresource = subresource; - return result; - } - static inline CD3DX12_RESOURCE_BARRIER Aliasing( - _In_ ID3D12Resource* pResourceBefore, - _In_ ID3D12Resource* pResourceAfter) - { - CD3DX12_RESOURCE_BARRIER result = {}; - D3D12_RESOURCE_BARRIER& barrier = result; - result.Type = D3D12_RESOURCE_BARRIER_TYPE_ALIASING; - barrier.Aliasing.pResourceBefore = pResourceBefore; - barrier.Aliasing.pResourceAfter = pResourceAfter; - return result; - } - static inline CD3DX12_RESOURCE_BARRIER UAV( - _In_ ID3D12Resource* pResource) - { - CD3DX12_RESOURCE_BARRIER result = {}; - D3D12_RESOURCE_BARRIER& barrier = result; - result.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV; - barrier.UAV.pResource = pResource; - return result; - } -}; - -struct CD3DX12_PACKED_MIP_INFO : public D3D12_PACKED_MIP_INFO -{ - CD3DX12_PACKED_MIP_INFO() = default; - explicit CD3DX12_PACKED_MIP_INFO(const D3D12_PACKED_MIP_INFO& o) : - D3D12_PACKED_MIP_INFO(o) - {} - CD3DX12_PACKED_MIP_INFO( - UINT8 numStandardMips, - UINT8 numPackedMips, - UINT numTilesForPackedMips, - UINT startTileIndexInOverallResource) - { - NumStandardMips = numStandardMips; - NumPackedMips = numPackedMips; - NumTilesForPackedMips = numTilesForPackedMips; - StartTileIndexInOverallResource = startTileIndexInOverallResource; - } -}; - -struct CD3DX12_SUBRESOURCE_FOOTPRINT : public D3D12_SUBRESOURCE_FOOTPRINT -{ - CD3DX12_SUBRESOURCE_FOOTPRINT() = default; - explicit CD3DX12_SUBRESOURCE_FOOTPRINT(const D3D12_SUBRESOURCE_FOOTPRINT& o) : - D3D12_SUBRESOURCE_FOOTPRINT(o) - {} - CD3DX12_SUBRESOURCE_FOOTPRINT( - DXGI_FORMAT format, - UINT width, - UINT height, - UINT depth, - UINT rowPitch) - { - Format = format; - Width = width; - Height = height; - Depth = depth; - RowPitch = rowPitch; - } - explicit CD3DX12_SUBRESOURCE_FOOTPRINT( - const D3D12_RESOURCE_DESC& resDesc, - UINT rowPitch) - { - Format = resDesc.Format; - Width = UINT(resDesc.Width); - Height = resDesc.Height; - Depth = (resDesc.Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE3D ? resDesc.DepthOrArraySize : 1); - RowPitch = rowPitch; - } -}; - -struct CD3DX12_TEXTURE_COPY_LOCATION : public D3D12_TEXTURE_COPY_LOCATION -{ - CD3DX12_TEXTURE_COPY_LOCATION() = default; - explicit CD3DX12_TEXTURE_COPY_LOCATION(const D3D12_TEXTURE_COPY_LOCATION& o) : - D3D12_TEXTURE_COPY_LOCATION(o) - {} - CD3DX12_TEXTURE_COPY_LOCATION(_In_ ID3D12Resource* pRes) - { - pResource = pRes; - Type = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX; - PlacedFootprint = {}; - } - CD3DX12_TEXTURE_COPY_LOCATION(_In_ ID3D12Resource* pRes, D3D12_PLACED_SUBRESOURCE_FOOTPRINT const& Footprint) - { - pResource = pRes; - Type = D3D12_TEXTURE_COPY_TYPE_PLACED_FOOTPRINT; - PlacedFootprint = Footprint; - } - CD3DX12_TEXTURE_COPY_LOCATION(_In_ ID3D12Resource* pRes, UINT Sub) - { - pResource = pRes; - Type = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX; - PlacedFootprint = {}; - SubresourceIndex = Sub; - } -}; - -struct CD3DX12_DESCRIPTOR_RANGE : public D3D12_DESCRIPTOR_RANGE -{ - CD3DX12_DESCRIPTOR_RANGE() = default; - explicit CD3DX12_DESCRIPTOR_RANGE(const D3D12_DESCRIPTOR_RANGE& o) : - D3D12_DESCRIPTOR_RANGE(o) - {} - CD3DX12_DESCRIPTOR_RANGE( - D3D12_DESCRIPTOR_RANGE_TYPE rangeType, - UINT numDescriptors, - UINT baseShaderRegister, - UINT registerSpace = 0, - UINT offsetInDescriptorsFromTableStart = - D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) - { - Init(rangeType, numDescriptors, baseShaderRegister, registerSpace, offsetInDescriptorsFromTableStart); - } - - inline void Init( - D3D12_DESCRIPTOR_RANGE_TYPE rangeType, - UINT numDescriptors, - UINT baseShaderRegister, - UINT registerSpace = 0, - UINT offsetInDescriptorsFromTableStart = - D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) - { - Init(*this, rangeType, numDescriptors, baseShaderRegister, registerSpace, offsetInDescriptorsFromTableStart); - } - - static inline void Init( - _Out_ D3D12_DESCRIPTOR_RANGE& range, - D3D12_DESCRIPTOR_RANGE_TYPE rangeType, - UINT numDescriptors, - UINT baseShaderRegister, - UINT registerSpace = 0, - UINT offsetInDescriptorsFromTableStart = - D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) - { - range.RangeType = rangeType; - range.NumDescriptors = numDescriptors; - range.BaseShaderRegister = baseShaderRegister; - range.RegisterSpace = registerSpace; - range.OffsetInDescriptorsFromTableStart = offsetInDescriptorsFromTableStart; - } -}; - -struct CD3DX12_ROOT_DESCRIPTOR_TABLE : public D3D12_ROOT_DESCRIPTOR_TABLE -{ - CD3DX12_ROOT_DESCRIPTOR_TABLE() = default; - explicit CD3DX12_ROOT_DESCRIPTOR_TABLE(const D3D12_ROOT_DESCRIPTOR_TABLE& o) : - D3D12_ROOT_DESCRIPTOR_TABLE(o) - {} - CD3DX12_ROOT_DESCRIPTOR_TABLE( - UINT numDescriptorRanges, - _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* _pDescriptorRanges) - { - Init(numDescriptorRanges, _pDescriptorRanges); - } - - inline void Init( - UINT numDescriptorRanges, - _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* _pDescriptorRanges) - { - Init(*this, numDescriptorRanges, _pDescriptorRanges); - } - - static inline void Init( - _Out_ D3D12_ROOT_DESCRIPTOR_TABLE& rootDescriptorTable, - UINT numDescriptorRanges, - _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* _pDescriptorRanges) - { - rootDescriptorTable.NumDescriptorRanges = numDescriptorRanges; - rootDescriptorTable.pDescriptorRanges = _pDescriptorRanges; - } -}; - -struct CD3DX12_ROOT_CONSTANTS : public D3D12_ROOT_CONSTANTS -{ - CD3DX12_ROOT_CONSTANTS() = default; - explicit CD3DX12_ROOT_CONSTANTS(const D3D12_ROOT_CONSTANTS& o) : - D3D12_ROOT_CONSTANTS(o) - {} - CD3DX12_ROOT_CONSTANTS( - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0) - { - Init(num32BitValues, shaderRegister, registerSpace); - } - - inline void Init( - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0) - { - Init(*this, num32BitValues, shaderRegister, registerSpace); - } - - static inline void Init( - _Out_ D3D12_ROOT_CONSTANTS& rootConstants, - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0) - { - rootConstants.Num32BitValues = num32BitValues; - rootConstants.ShaderRegister = shaderRegister; - rootConstants.RegisterSpace = registerSpace; - } -}; - -struct CD3DX12_ROOT_DESCRIPTOR : public D3D12_ROOT_DESCRIPTOR -{ - CD3DX12_ROOT_DESCRIPTOR() = default; - explicit CD3DX12_ROOT_DESCRIPTOR(const D3D12_ROOT_DESCRIPTOR& o) : - D3D12_ROOT_DESCRIPTOR(o) - {} - CD3DX12_ROOT_DESCRIPTOR( - UINT shaderRegister, - UINT registerSpace = 0) - { - Init(shaderRegister, registerSpace); - } - - inline void Init( - UINT shaderRegister, - UINT registerSpace = 0) - { - Init(*this, shaderRegister, registerSpace); - } - - static inline void Init(_Out_ D3D12_ROOT_DESCRIPTOR& table, UINT shaderRegister, UINT registerSpace = 0) - { - table.ShaderRegister = shaderRegister; - table.RegisterSpace = registerSpace; - } -}; - -struct CD3DX12_ROOT_PARAMETER : public D3D12_ROOT_PARAMETER -{ - CD3DX12_ROOT_PARAMETER() = default; - explicit CD3DX12_ROOT_PARAMETER(const D3D12_ROOT_PARAMETER& o) : - D3D12_ROOT_PARAMETER(o) - {} - - static inline void InitAsDescriptorTable( - _Out_ D3D12_ROOT_PARAMETER& rootParam, - UINT numDescriptorRanges, - _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* pDescriptorRanges, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR_TABLE::Init(rootParam.DescriptorTable, numDescriptorRanges, pDescriptorRanges); - } - - static inline void InitAsConstants( - _Out_ D3D12_ROOT_PARAMETER& rootParam, - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_CONSTANTS::Init(rootParam.Constants, num32BitValues, shaderRegister, registerSpace); - } - - static inline void InitAsConstantBufferView( - _Out_ D3D12_ROOT_PARAMETER& rootParam, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR::Init(rootParam.Descriptor, shaderRegister, registerSpace); - } - - static inline void InitAsShaderResourceView( - _Out_ D3D12_ROOT_PARAMETER& rootParam, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR::Init(rootParam.Descriptor, shaderRegister, registerSpace); - } - - static inline void InitAsUnorderedAccessView( - _Out_ D3D12_ROOT_PARAMETER& rootParam, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR::Init(rootParam.Descriptor, shaderRegister, registerSpace); - } - - inline void InitAsDescriptorTable( - UINT numDescriptorRanges, - _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* pDescriptorRanges, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsDescriptorTable(*this, numDescriptorRanges, pDescriptorRanges, visibility); - } - - inline void InitAsConstants( - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsConstants(*this, num32BitValues, shaderRegister, registerSpace, visibility); - } - - inline void InitAsConstantBufferView( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsConstantBufferView(*this, shaderRegister, registerSpace, visibility); - } - - inline void InitAsShaderResourceView( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsShaderResourceView(*this, shaderRegister, registerSpace, visibility); - } - - inline void InitAsUnorderedAccessView( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsUnorderedAccessView(*this, shaderRegister, registerSpace, visibility); - } -}; - -struct CD3DX12_STATIC_SAMPLER_DESC : public D3D12_STATIC_SAMPLER_DESC -{ - CD3DX12_STATIC_SAMPLER_DESC() = default; - explicit CD3DX12_STATIC_SAMPLER_DESC(const D3D12_STATIC_SAMPLER_DESC& o) : - D3D12_STATIC_SAMPLER_DESC(o) - {} - CD3DX12_STATIC_SAMPLER_DESC( - UINT shaderRegister, - D3D12_FILTER filter = D3D12_FILTER_ANISOTROPIC, - D3D12_TEXTURE_ADDRESS_MODE addressU = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - D3D12_TEXTURE_ADDRESS_MODE addressV = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - D3D12_TEXTURE_ADDRESS_MODE addressW = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - FLOAT mipLODBias = 0, - UINT maxAnisotropy = 16, - D3D12_COMPARISON_FUNC comparisonFunc = D3D12_COMPARISON_FUNC_LESS_EQUAL, - D3D12_STATIC_BORDER_COLOR borderColor = D3D12_STATIC_BORDER_COLOR_OPAQUE_WHITE, - FLOAT minLOD = 0.f, - FLOAT maxLOD = D3D12_FLOAT32_MAX, - D3D12_SHADER_VISIBILITY shaderVisibility = D3D12_SHADER_VISIBILITY_ALL, - UINT registerSpace = 0) - { - Init( - shaderRegister, - filter, - addressU, - addressV, - addressW, - mipLODBias, - maxAnisotropy, - comparisonFunc, - borderColor, - minLOD, - maxLOD, - shaderVisibility, - registerSpace); - } - - static inline void Init( - _Out_ D3D12_STATIC_SAMPLER_DESC& samplerDesc, - UINT shaderRegister, - D3D12_FILTER filter = D3D12_FILTER_ANISOTROPIC, - D3D12_TEXTURE_ADDRESS_MODE addressU = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - D3D12_TEXTURE_ADDRESS_MODE addressV = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - D3D12_TEXTURE_ADDRESS_MODE addressW = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - FLOAT mipLODBias = 0, - UINT maxAnisotropy = 16, - D3D12_COMPARISON_FUNC comparisonFunc = D3D12_COMPARISON_FUNC_LESS_EQUAL, - D3D12_STATIC_BORDER_COLOR borderColor = D3D12_STATIC_BORDER_COLOR_OPAQUE_WHITE, - FLOAT minLOD = 0.f, - FLOAT maxLOD = D3D12_FLOAT32_MAX, - D3D12_SHADER_VISIBILITY shaderVisibility = D3D12_SHADER_VISIBILITY_ALL, - UINT registerSpace = 0) - { - samplerDesc.ShaderRegister = shaderRegister; - samplerDesc.Filter = filter; - samplerDesc.AddressU = addressU; - samplerDesc.AddressV = addressV; - samplerDesc.AddressW = addressW; - samplerDesc.MipLODBias = mipLODBias; - samplerDesc.MaxAnisotropy = maxAnisotropy; - samplerDesc.ComparisonFunc = comparisonFunc; - samplerDesc.BorderColor = borderColor; - samplerDesc.MinLOD = minLOD; - samplerDesc.MaxLOD = maxLOD; - samplerDesc.ShaderVisibility = shaderVisibility; - samplerDesc.RegisterSpace = registerSpace; - } - inline void Init( - UINT shaderRegister, - D3D12_FILTER filter = D3D12_FILTER_ANISOTROPIC, - D3D12_TEXTURE_ADDRESS_MODE addressU = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - D3D12_TEXTURE_ADDRESS_MODE addressV = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - D3D12_TEXTURE_ADDRESS_MODE addressW = D3D12_TEXTURE_ADDRESS_MODE_WRAP, - FLOAT mipLODBias = 0, - UINT maxAnisotropy = 16, - D3D12_COMPARISON_FUNC comparisonFunc = D3D12_COMPARISON_FUNC_LESS_EQUAL, - D3D12_STATIC_BORDER_COLOR borderColor = D3D12_STATIC_BORDER_COLOR_OPAQUE_WHITE, - FLOAT minLOD = 0.f, - FLOAT maxLOD = D3D12_FLOAT32_MAX, - D3D12_SHADER_VISIBILITY shaderVisibility = D3D12_SHADER_VISIBILITY_ALL, - UINT registerSpace = 0) - { - Init( - *this, - shaderRegister, - filter, - addressU, - addressV, - addressW, - mipLODBias, - maxAnisotropy, - comparisonFunc, - borderColor, - minLOD, - maxLOD, - shaderVisibility, - registerSpace); - } - -}; - -struct CD3DX12_ROOT_SIGNATURE_DESC : public D3D12_ROOT_SIGNATURE_DESC -{ - CD3DX12_ROOT_SIGNATURE_DESC() = default; - explicit CD3DX12_ROOT_SIGNATURE_DESC(const D3D12_ROOT_SIGNATURE_DESC& o) : - D3D12_ROOT_SIGNATURE_DESC(o) - {} - CD3DX12_ROOT_SIGNATURE_DESC( - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - Init(numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); - } - CD3DX12_ROOT_SIGNATURE_DESC(CD3DX12_DEFAULT) - { - Init(0, nullptr, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE); - } - - inline void Init( - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - Init(*this, numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); - } - - static inline void Init( - _Out_ D3D12_ROOT_SIGNATURE_DESC& desc, - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - desc.NumParameters = numParameters; - desc.pParameters = _pParameters; - desc.NumStaticSamplers = numStaticSamplers; - desc.pStaticSamplers = _pStaticSamplers; - desc.Flags = flags; - } -}; - -struct CD3DX12_DESCRIPTOR_RANGE1 : public D3D12_DESCRIPTOR_RANGE1 -{ - CD3DX12_DESCRIPTOR_RANGE1() = default; - explicit CD3DX12_DESCRIPTOR_RANGE1(const D3D12_DESCRIPTOR_RANGE1& o) : - D3D12_DESCRIPTOR_RANGE1(o) - {} - CD3DX12_DESCRIPTOR_RANGE1( - D3D12_DESCRIPTOR_RANGE_TYPE rangeType, - UINT numDescriptors, - UINT baseShaderRegister, - UINT registerSpace = 0, - D3D12_DESCRIPTOR_RANGE_FLAGS flags = D3D12_DESCRIPTOR_RANGE_FLAG_NONE, - UINT offsetInDescriptorsFromTableStart = - D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) - { - Init(rangeType, numDescriptors, baseShaderRegister, registerSpace, flags, offsetInDescriptorsFromTableStart); - } - - inline void Init( - D3D12_DESCRIPTOR_RANGE_TYPE rangeType, - UINT numDescriptors, - UINT baseShaderRegister, - UINT registerSpace = 0, - D3D12_DESCRIPTOR_RANGE_FLAGS flags = D3D12_DESCRIPTOR_RANGE_FLAG_NONE, - UINT offsetInDescriptorsFromTableStart = - D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) - { - Init(*this, rangeType, numDescriptors, baseShaderRegister, registerSpace, flags, offsetInDescriptorsFromTableStart); - } - - static inline void Init( - _Out_ D3D12_DESCRIPTOR_RANGE1& range, - D3D12_DESCRIPTOR_RANGE_TYPE rangeType, - UINT numDescriptors, - UINT baseShaderRegister, - UINT registerSpace = 0, - D3D12_DESCRIPTOR_RANGE_FLAGS flags = D3D12_DESCRIPTOR_RANGE_FLAG_NONE, - UINT offsetInDescriptorsFromTableStart = - D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) - { - range.RangeType = rangeType; - range.NumDescriptors = numDescriptors; - range.BaseShaderRegister = baseShaderRegister; - range.RegisterSpace = registerSpace; - range.Flags = flags; - range.OffsetInDescriptorsFromTableStart = offsetInDescriptorsFromTableStart; - } -}; - -struct CD3DX12_ROOT_DESCRIPTOR_TABLE1 : public D3D12_ROOT_DESCRIPTOR_TABLE1 -{ - CD3DX12_ROOT_DESCRIPTOR_TABLE1() = default; - explicit CD3DX12_ROOT_DESCRIPTOR_TABLE1(const D3D12_ROOT_DESCRIPTOR_TABLE1& o) : - D3D12_ROOT_DESCRIPTOR_TABLE1(o) - {} - CD3DX12_ROOT_DESCRIPTOR_TABLE1( - UINT numDescriptorRanges, - _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* _pDescriptorRanges) - { - Init(numDescriptorRanges, _pDescriptorRanges); - } - - inline void Init( - UINT numDescriptorRanges, - _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* _pDescriptorRanges) - { - Init(*this, numDescriptorRanges, _pDescriptorRanges); - } - - static inline void Init( - _Out_ D3D12_ROOT_DESCRIPTOR_TABLE1& rootDescriptorTable, - UINT numDescriptorRanges, - _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* _pDescriptorRanges) - { - rootDescriptorTable.NumDescriptorRanges = numDescriptorRanges; - rootDescriptorTable.pDescriptorRanges = _pDescriptorRanges; - } -}; - -struct CD3DX12_ROOT_DESCRIPTOR1 : public D3D12_ROOT_DESCRIPTOR1 -{ - CD3DX12_ROOT_DESCRIPTOR1() = default; - explicit CD3DX12_ROOT_DESCRIPTOR1(const D3D12_ROOT_DESCRIPTOR1& o) : - D3D12_ROOT_DESCRIPTOR1(o) - {} - CD3DX12_ROOT_DESCRIPTOR1( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE) - { - Init(shaderRegister, registerSpace, flags); - } - - inline void Init( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE) - { - Init(*this, shaderRegister, registerSpace, flags); - } - - static inline void Init( - _Out_ D3D12_ROOT_DESCRIPTOR1& table, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE) - { - table.ShaderRegister = shaderRegister; - table.RegisterSpace = registerSpace; - table.Flags = flags; - } -}; - -struct CD3DX12_ROOT_PARAMETER1 : public D3D12_ROOT_PARAMETER1 -{ - CD3DX12_ROOT_PARAMETER1() = default; - explicit CD3DX12_ROOT_PARAMETER1(const D3D12_ROOT_PARAMETER1& o) : - D3D12_ROOT_PARAMETER1(o) - {} - - static inline void InitAsDescriptorTable( - _Out_ D3D12_ROOT_PARAMETER1& rootParam, - UINT numDescriptorRanges, - _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* pDescriptorRanges, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR_TABLE1::Init(rootParam.DescriptorTable, numDescriptorRanges, pDescriptorRanges); - } - - static inline void InitAsConstants( - _Out_ D3D12_ROOT_PARAMETER1& rootParam, - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_CONSTANTS::Init(rootParam.Constants, num32BitValues, shaderRegister, registerSpace); - } - - static inline void InitAsConstantBufferView( - _Out_ D3D12_ROOT_PARAMETER1& rootParam, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR1::Init(rootParam.Descriptor, shaderRegister, registerSpace, flags); - } - - static inline void InitAsShaderResourceView( - _Out_ D3D12_ROOT_PARAMETER1& rootParam, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR1::Init(rootParam.Descriptor, shaderRegister, registerSpace, flags); - } - - static inline void InitAsUnorderedAccessView( - _Out_ D3D12_ROOT_PARAMETER1& rootParam, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV; - rootParam.ShaderVisibility = visibility; - CD3DX12_ROOT_DESCRIPTOR1::Init(rootParam.Descriptor, shaderRegister, registerSpace, flags); - } - - inline void InitAsDescriptorTable( - UINT numDescriptorRanges, - _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* pDescriptorRanges, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsDescriptorTable(*this, numDescriptorRanges, pDescriptorRanges, visibility); - } - - inline void InitAsConstants( - UINT num32BitValues, - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsConstants(*this, num32BitValues, shaderRegister, registerSpace, visibility); - } - - inline void InitAsConstantBufferView( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsConstantBufferView(*this, shaderRegister, registerSpace, flags, visibility); - } - - inline void InitAsShaderResourceView( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsShaderResourceView(*this, shaderRegister, registerSpace, flags, visibility); - } - - inline void InitAsUnorderedAccessView( - UINT shaderRegister, - UINT registerSpace = 0, - D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, - D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) - { - InitAsUnorderedAccessView(*this, shaderRegister, registerSpace, flags, visibility); - } -}; - -struct CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC : public D3D12_VERSIONED_ROOT_SIGNATURE_DESC -{ - CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC() = default; - explicit CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(const D3D12_VERSIONED_ROOT_SIGNATURE_DESC& o) : - D3D12_VERSIONED_ROOT_SIGNATURE_DESC(o) - {} - explicit CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(const D3D12_ROOT_SIGNATURE_DESC& o) - { - Version = D3D_ROOT_SIGNATURE_VERSION_1_0; - Desc_1_0 = o; - } - explicit CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(const D3D12_ROOT_SIGNATURE_DESC1& o) - { - Version = D3D_ROOT_SIGNATURE_VERSION_1_1; - Desc_1_1 = o; - } - CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC( - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - Init_1_0(numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); - } - CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC( - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER1* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - Init_1_1(numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); - } - CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(CD3DX12_DEFAULT) - { - Init_1_1(0, nullptr, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE); - } - - inline void Init_1_0( - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - Init_1_0(*this, numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); - } - - static inline void Init_1_0( - _Out_ D3D12_VERSIONED_ROOT_SIGNATURE_DESC& desc, - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_0; - desc.Desc_1_0.NumParameters = numParameters; - desc.Desc_1_0.pParameters = _pParameters; - desc.Desc_1_0.NumStaticSamplers = numStaticSamplers; - desc.Desc_1_0.pStaticSamplers = _pStaticSamplers; - desc.Desc_1_0.Flags = flags; - } - - inline void Init_1_1( - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER1* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - Init_1_1(*this, numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); - } - - static inline void Init_1_1( - _Out_ D3D12_VERSIONED_ROOT_SIGNATURE_DESC& desc, - UINT numParameters, - _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER1* _pParameters, - UINT numStaticSamplers = 0, - _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, - D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) - { - desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1; - desc.Desc_1_1.NumParameters = numParameters; - desc.Desc_1_1.pParameters = _pParameters; - desc.Desc_1_1.NumStaticSamplers = numStaticSamplers; - desc.Desc_1_1.pStaticSamplers = _pStaticSamplers; - desc.Desc_1_1.Flags = flags; - } -}; - -struct CD3DX12_CPU_DESCRIPTOR_HANDLE : public D3D12_CPU_DESCRIPTOR_HANDLE -{ - CD3DX12_CPU_DESCRIPTOR_HANDLE() = default; - explicit CD3DX12_CPU_DESCRIPTOR_HANDLE(const D3D12_CPU_DESCRIPTOR_HANDLE& o) : - D3D12_CPU_DESCRIPTOR_HANDLE(o) - {} - CD3DX12_CPU_DESCRIPTOR_HANDLE(CD3DX12_DEFAULT) { ptr = 0; } - CD3DX12_CPU_DESCRIPTOR_HANDLE(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& other, INT offsetScaledByIncrementSize) - { - InitOffsetted(other, offsetScaledByIncrementSize); - } - CD3DX12_CPU_DESCRIPTOR_HANDLE(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& other, INT offsetInDescriptors, UINT descriptorIncrementSize) - { - InitOffsetted(other, offsetInDescriptors, descriptorIncrementSize); - } - CD3DX12_CPU_DESCRIPTOR_HANDLE& Offset(INT offsetInDescriptors, UINT descriptorIncrementSize) - { - ptr = SIZE_T(INT64(ptr) + INT64(offsetInDescriptors) * INT64(descriptorIncrementSize)); - return *this; - } - CD3DX12_CPU_DESCRIPTOR_HANDLE& Offset(INT offsetScaledByIncrementSize) - { - ptr = SIZE_T(INT64(ptr) + INT64(offsetScaledByIncrementSize)); - return *this; - } - bool operator==(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& other) const - { - return (ptr == other.ptr); - } - bool operator!=(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& other) const - { - return (ptr != other.ptr); - } - CD3DX12_CPU_DESCRIPTOR_HANDLE& operator=(const D3D12_CPU_DESCRIPTOR_HANDLE& other) - { - ptr = other.ptr; - return *this; - } - - inline void InitOffsetted(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& base, INT offsetScaledByIncrementSize) - { - InitOffsetted(*this, base, offsetScaledByIncrementSize); - } - - inline void InitOffsetted(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& base, INT offsetInDescriptors, UINT descriptorIncrementSize) - { - InitOffsetted(*this, base, offsetInDescriptors, descriptorIncrementSize); - } - - static inline void InitOffsetted(_Out_ D3D12_CPU_DESCRIPTOR_HANDLE& handle, _In_ const D3D12_CPU_DESCRIPTOR_HANDLE& base, INT offsetScaledByIncrementSize) - { - handle.ptr = SIZE_T(INT64(base.ptr) + INT64(offsetScaledByIncrementSize)); - } - - static inline void InitOffsetted(_Out_ D3D12_CPU_DESCRIPTOR_HANDLE& handle, _In_ const D3D12_CPU_DESCRIPTOR_HANDLE& base, INT offsetInDescriptors, UINT descriptorIncrementSize) - { - handle.ptr = SIZE_T(INT64(base.ptr) + INT64(offsetInDescriptors) * INT64(descriptorIncrementSize)); - } -}; - -struct CD3DX12_GPU_DESCRIPTOR_HANDLE : public D3D12_GPU_DESCRIPTOR_HANDLE -{ - CD3DX12_GPU_DESCRIPTOR_HANDLE() = default; - explicit CD3DX12_GPU_DESCRIPTOR_HANDLE(const D3D12_GPU_DESCRIPTOR_HANDLE& o) : - D3D12_GPU_DESCRIPTOR_HANDLE(o) - {} - CD3DX12_GPU_DESCRIPTOR_HANDLE(CD3DX12_DEFAULT) { ptr = 0; } - CD3DX12_GPU_DESCRIPTOR_HANDLE(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& other, INT offsetScaledByIncrementSize) - { - InitOffsetted(other, offsetScaledByIncrementSize); - } - CD3DX12_GPU_DESCRIPTOR_HANDLE(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& other, INT offsetInDescriptors, UINT descriptorIncrementSize) - { - InitOffsetted(other, offsetInDescriptors, descriptorIncrementSize); - } - CD3DX12_GPU_DESCRIPTOR_HANDLE& Offset(INT offsetInDescriptors, UINT descriptorIncrementSize) - { - ptr = UINT64(INT64(ptr) + INT64(offsetInDescriptors) * INT64(descriptorIncrementSize)); - return *this; - } - CD3DX12_GPU_DESCRIPTOR_HANDLE& Offset(INT offsetScaledByIncrementSize) - { - ptr = UINT64(INT64(ptr) + INT64(offsetScaledByIncrementSize)); - return *this; - } - inline bool operator==(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& other) const - { - return (ptr == other.ptr); - } - inline bool operator!=(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& other) const - { - return (ptr != other.ptr); - } - CD3DX12_GPU_DESCRIPTOR_HANDLE& operator=(const D3D12_GPU_DESCRIPTOR_HANDLE& other) - { - ptr = other.ptr; - return *this; - } - - inline void InitOffsetted(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& base, INT offsetScaledByIncrementSize) - { - InitOffsetted(*this, base, offsetScaledByIncrementSize); - } - - inline void InitOffsetted(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& base, INT offsetInDescriptors, UINT descriptorIncrementSize) - { - InitOffsetted(*this, base, offsetInDescriptors, descriptorIncrementSize); - } - - static inline void InitOffsetted(_Out_ D3D12_GPU_DESCRIPTOR_HANDLE& handle, _In_ const D3D12_GPU_DESCRIPTOR_HANDLE& base, INT offsetScaledByIncrementSize) - { - handle.ptr = UINT64(INT64(base.ptr) + INT64(offsetScaledByIncrementSize)); - } - - static inline void InitOffsetted(_Out_ D3D12_GPU_DESCRIPTOR_HANDLE& handle, _In_ const D3D12_GPU_DESCRIPTOR_HANDLE& base, INT offsetInDescriptors, UINT descriptorIncrementSize) - { - handle.ptr = UINT64(INT64(base.ptr) + INT64(offsetInDescriptors) * INT64(descriptorIncrementSize)); - } -}; - -inline UINT D3D12CalcSubresource(UINT MipSlice, UINT ArraySlice, UINT PlaneSlice, UINT MipLevels, UINT ArraySize) -{ - return MipSlice + ArraySlice * MipLevels + PlaneSlice * MipLevels * ArraySize; -} - -template -inline void D3D12DecomposeSubresource(UINT Subresource, UINT MipLevels, UINT ArraySize, _Out_ T& MipSlice, _Out_ U& ArraySlice, _Out_ V& PlaneSlice) -{ - MipSlice = static_cast(Subresource % MipLevels); - ArraySlice = static_cast((Subresource / MipLevels) % ArraySize); - PlaneSlice = static_cast(Subresource / (MipLevels * ArraySize)); -} - -inline UINT8 D3D12GetFormatPlaneCount( - _In_ ID3D12Device* pDevice, - DXGI_FORMAT Format -) -{ - D3D12_FEATURE_DATA_FORMAT_INFO formatInfo = { Format, 0 }; - if (FAILED(pDevice->CheckFeatureSupport(D3D12_FEATURE_FORMAT_INFO, &formatInfo, sizeof(formatInfo)))) - { - return 0; - } - return formatInfo.PlaneCount; -} - -struct CD3DX12_RESOURCE_DESC : public D3D12_RESOURCE_DESC -{ - CD3DX12_RESOURCE_DESC() = default; - explicit CD3DX12_RESOURCE_DESC(const D3D12_RESOURCE_DESC& o) : - D3D12_RESOURCE_DESC(o) - {} - CD3DX12_RESOURCE_DESC( - D3D12_RESOURCE_DIMENSION dimension, - UINT64 alignment, - UINT64 width, - UINT height, - UINT16 depthOrArraySize, - UINT16 mipLevels, - DXGI_FORMAT format, - UINT sampleCount, - UINT sampleQuality, - D3D12_TEXTURE_LAYOUT layout, - D3D12_RESOURCE_FLAGS flags) - { - Dimension = dimension; - Alignment = alignment; - Width = width; - Height = height; - DepthOrArraySize = depthOrArraySize; - MipLevels = mipLevels; - Format = format; - SampleDesc.Count = sampleCount; - SampleDesc.Quality = sampleQuality; - Layout = layout; - Flags = flags; - } - static inline CD3DX12_RESOURCE_DESC Buffer( - const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, - D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE) - { - return CD3DX12_RESOURCE_DESC(D3D12_RESOURCE_DIMENSION_BUFFER, resAllocInfo.Alignment, resAllocInfo.SizeInBytes, - 1, 1, 1, DXGI_FORMAT_UNKNOWN, 1, 0, D3D12_TEXTURE_LAYOUT_ROW_MAJOR, flags); - } - static inline CD3DX12_RESOURCE_DESC Buffer( - UINT64 width, - D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, - UINT64 alignment = 0) - { - return CD3DX12_RESOURCE_DESC(D3D12_RESOURCE_DIMENSION_BUFFER, alignment, width, 1, 1, 1, - DXGI_FORMAT_UNKNOWN, 1, 0, D3D12_TEXTURE_LAYOUT_ROW_MAJOR, flags); - } - static inline CD3DX12_RESOURCE_DESC Tex1D( - DXGI_FORMAT format, - UINT64 width, - UINT16 arraySize = 1, - UINT16 mipLevels = 0, - D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, - D3D12_TEXTURE_LAYOUT layout = D3D12_TEXTURE_LAYOUT_UNKNOWN, - UINT64 alignment = 0) - { - return CD3DX12_RESOURCE_DESC(D3D12_RESOURCE_DIMENSION_TEXTURE1D, alignment, width, 1, arraySize, - mipLevels, format, 1, 0, layout, flags); - } - static inline CD3DX12_RESOURCE_DESC Tex2D( - DXGI_FORMAT format, - UINT64 width, - UINT height, - UINT16 arraySize = 1, - UINT16 mipLevels = 0, - UINT sampleCount = 1, - UINT sampleQuality = 0, - D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, - D3D12_TEXTURE_LAYOUT layout = D3D12_TEXTURE_LAYOUT_UNKNOWN, - UINT64 alignment = 0) - { - return CD3DX12_RESOURCE_DESC(D3D12_RESOURCE_DIMENSION_TEXTURE2D, alignment, width, height, arraySize, - mipLevels, format, sampleCount, sampleQuality, layout, flags); - } - static inline CD3DX12_RESOURCE_DESC Tex3D( - DXGI_FORMAT format, - UINT64 width, - UINT height, - UINT16 depth, - UINT16 mipLevels = 0, - D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, - D3D12_TEXTURE_LAYOUT layout = D3D12_TEXTURE_LAYOUT_UNKNOWN, - UINT64 alignment = 0) - { - return CD3DX12_RESOURCE_DESC(D3D12_RESOURCE_DIMENSION_TEXTURE3D, alignment, width, height, depth, - mipLevels, format, 1, 0, layout, flags); - } - inline UINT16 Depth() const - { - return (Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE3D ? DepthOrArraySize : 1); - } - inline UINT16 ArraySize() const - { - return (Dimension != D3D12_RESOURCE_DIMENSION_TEXTURE3D ? DepthOrArraySize : 1); - } - inline UINT8 PlaneCount(_In_ ID3D12Device* pDevice) const - { - return D3D12GetFormatPlaneCount(pDevice, Format); - } - inline UINT Subresources(_In_ ID3D12Device* pDevice) const - { - return MipLevels * ArraySize() * PlaneCount(pDevice); - } - inline UINT CalcSubresource(UINT MipSlice, UINT ArraySlice, UINT PlaneSlice) - { - return D3D12CalcSubresource(MipSlice, ArraySlice, PlaneSlice, MipLevels, ArraySize()); - } -}; -inline bool operator==(const D3D12_RESOURCE_DESC& l, const D3D12_RESOURCE_DESC& r) -{ - return l.Dimension == r.Dimension && - l.Alignment == r.Alignment && - l.Width == r.Width && - l.Height == r.Height && - l.DepthOrArraySize == r.DepthOrArraySize && - l.MipLevels == r.MipLevels && - l.Format == r.Format && - l.SampleDesc.Count == r.SampleDesc.Count && - l.SampleDesc.Quality == r.SampleDesc.Quality && - l.Layout == r.Layout && - l.Flags == r.Flags; -} -inline bool operator!=(const D3D12_RESOURCE_DESC& l, const D3D12_RESOURCE_DESC& r) -{ - return !(l == r); -} - -struct CD3DX12_VIEW_INSTANCING_DESC : public D3D12_VIEW_INSTANCING_DESC -{ - CD3DX12_VIEW_INSTANCING_DESC() = default; - explicit CD3DX12_VIEW_INSTANCING_DESC(const D3D12_VIEW_INSTANCING_DESC& o) : - D3D12_VIEW_INSTANCING_DESC(o) - {} - explicit CD3DX12_VIEW_INSTANCING_DESC(CD3DX12_DEFAULT) - { - ViewInstanceCount = 0; - pViewInstanceLocations = nullptr; - Flags = D3D12_VIEW_INSTANCING_FLAG_NONE; - } - explicit CD3DX12_VIEW_INSTANCING_DESC( - UINT InViewInstanceCount, - const D3D12_VIEW_INSTANCE_LOCATION* InViewInstanceLocations, - D3D12_VIEW_INSTANCING_FLAGS InFlags) - { - ViewInstanceCount = InViewInstanceCount; - pViewInstanceLocations = InViewInstanceLocations; - Flags = InFlags; - } -}; - -// Row-by-row memcpy -inline void MemcpySubresource( - _In_ const D3D12_MEMCPY_DEST* pDest, - _In_ const D3D12_SUBRESOURCE_DATA* pSrc, - SIZE_T RowSizeInBytes, - UINT NumRows, - UINT NumSlices) -{ - for (UINT z = 0; z < NumSlices; ++z) - { - auto pDestSlice = reinterpret_cast(pDest->pData) + pDest->SlicePitch * z; - auto pSrcSlice = reinterpret_cast(pSrc->pData) + pSrc->SlicePitch * LONG_PTR(z); - for (UINT y = 0; y < NumRows; ++y) - { - memcpy(pDestSlice + pDest->RowPitch * y, - pSrcSlice + pSrc->RowPitch * LONG_PTR(y), - RowSizeInBytes); - } - } -} - -// Returns required size of a buffer to be used for data upload -inline UINT64 GetRequiredIntermediateSize( - _In_ ID3D12Resource* pDestinationResource, - _In_range_(0, D3D12_REQ_SUBRESOURCES) UINT FirstSubresource, - _In_range_(0, D3D12_REQ_SUBRESOURCES - FirstSubresource) UINT NumSubresources) -{ - auto Desc = pDestinationResource->GetDesc(); - UINT64 RequiredSize = 0; - - ID3D12Device* pDevice = nullptr; - pDestinationResource->GetDevice(IID_ID3D12Device, reinterpret_cast(&pDevice)); - pDevice->GetCopyableFootprints(&Desc, FirstSubresource, NumSubresources, 0, nullptr, nullptr, nullptr, &RequiredSize); - pDevice->Release(); - - return RequiredSize; -} - -// All arrays must be populated (e.g. by calling GetCopyableFootprints) -inline UINT64 UpdateSubresources( - _In_ ID3D12GraphicsCommandList* pCmdList, - _In_ ID3D12Resource* pDestinationResource, - _In_ ID3D12Resource* pIntermediate, - _In_range_(0, D3D12_REQ_SUBRESOURCES) UINT FirstSubresource, - _In_range_(0, D3D12_REQ_SUBRESOURCES - FirstSubresource) UINT NumSubresources, - UINT64 RequiredSize, - _In_reads_(NumSubresources) const D3D12_PLACED_SUBRESOURCE_FOOTPRINT* pLayouts, - _In_reads_(NumSubresources) const UINT* pNumRows, - _In_reads_(NumSubresources) const UINT64* pRowSizesInBytes, - _In_reads_(NumSubresources) const D3D12_SUBRESOURCE_DATA* pSrcData) -{ - // Minor validation - auto IntermediateDesc = pIntermediate->GetDesc(); - auto DestinationDesc = pDestinationResource->GetDesc(); - if (IntermediateDesc.Dimension != D3D12_RESOURCE_DIMENSION_BUFFER || - IntermediateDesc.Width < RequiredSize + pLayouts[0].Offset || - RequiredSize > SIZE_T(-1) || - (DestinationDesc.Dimension == D3D12_RESOURCE_DIMENSION_BUFFER && - (FirstSubresource != 0 || NumSubresources != 1))) - { - return 0; - } - - BYTE* pData; - HRESULT hr = pIntermediate->Map(0, nullptr, reinterpret_cast(&pData)); - if (FAILED(hr)) - { - return 0; - } - - for (UINT i = 0; i < NumSubresources; ++i) - { - if (pRowSizesInBytes[i] > SIZE_T(-1)) return 0; - D3D12_MEMCPY_DEST DestData = { pData + pLayouts[i].Offset, pLayouts[i].Footprint.RowPitch, SIZE_T(pLayouts[i].Footprint.RowPitch) * SIZE_T(pNumRows[i]) }; - MemcpySubresource(&DestData, &pSrcData[i], static_cast(pRowSizesInBytes[i]), pNumRows[i], pLayouts[i].Footprint.Depth); - } - pIntermediate->Unmap(0, nullptr); - - if (DestinationDesc.Dimension == D3D12_RESOURCE_DIMENSION_BUFFER) - { - pCmdList->CopyBufferRegion( - pDestinationResource, 0, pIntermediate, pLayouts[0].Offset, pLayouts[0].Footprint.Width); - } - else - { - for (UINT i = 0; i < NumSubresources; ++i) - { - CD3DX12_TEXTURE_COPY_LOCATION Dst(pDestinationResource, i + FirstSubresource); - CD3DX12_TEXTURE_COPY_LOCATION Src(pIntermediate, pLayouts[i]); - pCmdList->CopyTextureRegion(&Dst, 0, 0, 0, &Src, nullptr); - } - } - return RequiredSize; -} - -// Heap-allocating UpdateSubresources implementation -inline UINT64 UpdateSubresources( - _In_ ID3D12GraphicsCommandList* pCmdList, - _In_ ID3D12Resource* pDestinationResource, - _In_ ID3D12Resource* pIntermediate, - UINT64 IntermediateOffset, - _In_range_(0, D3D12_REQ_SUBRESOURCES) UINT FirstSubresource, - _In_range_(0, D3D12_REQ_SUBRESOURCES - FirstSubresource) UINT NumSubresources, - _In_reads_(NumSubresources) D3D12_SUBRESOURCE_DATA* pSrcData) -{ - UINT64 RequiredSize = 0; - UINT64 MemToAlloc = static_cast(sizeof(D3D12_PLACED_SUBRESOURCE_FOOTPRINT) + sizeof(UINT) + sizeof(UINT64)) * NumSubresources; - if (MemToAlloc > SIZE_MAX) - { - return 0; - } - void* pMem = HeapAlloc(GetProcessHeap(), 0, static_cast(MemToAlloc)); - if (pMem == nullptr) - { - return 0; - } - auto pLayouts = reinterpret_cast(pMem); - UINT64* pRowSizesInBytes = reinterpret_cast(pLayouts + NumSubresources); - UINT* pNumRows = reinterpret_cast(pRowSizesInBytes + NumSubresources); - - auto Desc = pDestinationResource->GetDesc(); - ID3D12Device* pDevice = nullptr; - pDestinationResource->GetDevice(IID_ID3D12Device, reinterpret_cast(&pDevice)); - pDevice->GetCopyableFootprints(&Desc, FirstSubresource, NumSubresources, IntermediateOffset, pLayouts, pNumRows, pRowSizesInBytes, &RequiredSize); - pDevice->Release(); - - UINT64 Result = UpdateSubresources(pCmdList, pDestinationResource, pIntermediate, FirstSubresource, NumSubresources, RequiredSize, pLayouts, pNumRows, pRowSizesInBytes, pSrcData); - HeapFree(GetProcessHeap(), 0, pMem); - return Result; -} - -// Stack-allocating UpdateSubresources implementation -template -inline UINT64 UpdateSubresources( - _In_ ID3D12GraphicsCommandList* pCmdList, - _In_ ID3D12Resource* pDestinationResource, - _In_ ID3D12Resource* pIntermediate, - UINT64 IntermediateOffset, - _In_range_(0, MaxSubresources) UINT FirstSubresource, - _In_range_(1, MaxSubresources - FirstSubresource) UINT NumSubresources, - _In_reads_(NumSubresources) D3D12_SUBRESOURCE_DATA* pSrcData) -{ - UINT64 RequiredSize = 0; - D3D12_PLACED_SUBRESOURCE_FOOTPRINT Layouts[MaxSubresources]; - UINT NumRows[MaxSubresources]; - UINT64 RowSizesInBytes[MaxSubresources]; - - auto Desc = pDestinationResource->GetDesc(); - ID3D12Device* pDevice = nullptr; - pDestinationResource->GetDevice(IID_ID3D12Device, reinterpret_cast(&pDevice)); - pDevice->GetCopyableFootprints(&Desc, FirstSubresource, NumSubresources, IntermediateOffset, Layouts, NumRows, RowSizesInBytes, &RequiredSize); - pDevice->Release(); - - return UpdateSubresources(pCmdList, pDestinationResource, pIntermediate, FirstSubresource, NumSubresources, RequiredSize, Layouts, NumRows, RowSizesInBytes, pSrcData); -} - -inline bool D3D12IsLayoutOpaque(D3D12_TEXTURE_LAYOUT Layout) -{ - return Layout == D3D12_TEXTURE_LAYOUT_UNKNOWN || Layout == D3D12_TEXTURE_LAYOUT_64KB_UNDEFINED_SWIZZLE; -} - -template -inline ID3D12CommandList* const* CommandListCast(t_CommandListType* const* pp) -{ - // This cast is useful for passing strongly typed command list pointers into - // ExecuteCommandLists. - // This cast is valid as long as the const-ness is respected. D3D12 APIs do - // respect the const-ness of their arguments. - return reinterpret_cast(pp); -} - -// D3D12 exports a new method for serializing root signatures in the Windows 10 Anniversary Update. -// To help enable root signature 1.1 features when they are available and not require maintaining -// two code paths for building root signatures, this helper method reconstructs a 1.0 signature when -// 1.1 is not supported. -inline HRESULT D3DX12SerializeVersionedRootSignature( - _In_ const D3D12_VERSIONED_ROOT_SIGNATURE_DESC* pRootSignatureDesc, - D3D_ROOT_SIGNATURE_VERSION MaxVersion, - _Outptr_ ID3DBlob** ppBlob, - _Always_(_Outptr_opt_result_maybenull_) ID3DBlob** ppErrorBlob) -{ - if (ppErrorBlob != nullptr) - { - *ppErrorBlob = nullptr; - } - - switch (MaxVersion) - { - case D3D_ROOT_SIGNATURE_VERSION_1_0: - switch (pRootSignatureDesc->Version) - { - case D3D_ROOT_SIGNATURE_VERSION_1_0: - return D3D12SerializeRootSignature(&pRootSignatureDesc->Desc_1_0, D3D_ROOT_SIGNATURE_VERSION_1, ppBlob, ppErrorBlob); - - case D3D_ROOT_SIGNATURE_VERSION_1_1: - { - HRESULT hr = S_OK; - const D3D12_ROOT_SIGNATURE_DESC1& desc_1_1 = pRootSignatureDesc->Desc_1_1; - - const SIZE_T ParametersSize = sizeof(D3D12_ROOT_PARAMETER) * desc_1_1.NumParameters; - void* pParameters = (ParametersSize > 0) ? HeapAlloc(GetProcessHeap(), 0, ParametersSize) : nullptr; - if (ParametersSize > 0 && pParameters == nullptr) - { - hr = E_OUTOFMEMORY; - } - auto pParameters_1_0 = reinterpret_cast(pParameters); - - if (SUCCEEDED(hr)) - { - for (UINT n = 0; n < desc_1_1.NumParameters; n++) - { - __analysis_assume(ParametersSize == sizeof(D3D12_ROOT_PARAMETER) * desc_1_1.NumParameters); - pParameters_1_0[n].ParameterType = desc_1_1.pParameters[n].ParameterType; - pParameters_1_0[n].ShaderVisibility = desc_1_1.pParameters[n].ShaderVisibility; - - switch (desc_1_1.pParameters[n].ParameterType) - { - case D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS: - pParameters_1_0[n].Constants.Num32BitValues = desc_1_1.pParameters[n].Constants.Num32BitValues; - pParameters_1_0[n].Constants.RegisterSpace = desc_1_1.pParameters[n].Constants.RegisterSpace; - pParameters_1_0[n].Constants.ShaderRegister = desc_1_1.pParameters[n].Constants.ShaderRegister; - break; - - case D3D12_ROOT_PARAMETER_TYPE_CBV: - case D3D12_ROOT_PARAMETER_TYPE_SRV: - case D3D12_ROOT_PARAMETER_TYPE_UAV: - pParameters_1_0[n].Descriptor.RegisterSpace = desc_1_1.pParameters[n].Descriptor.RegisterSpace; - pParameters_1_0[n].Descriptor.ShaderRegister = desc_1_1.pParameters[n].Descriptor.ShaderRegister; - break; - - case D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE: - const D3D12_ROOT_DESCRIPTOR_TABLE1& table_1_1 = desc_1_1.pParameters[n].DescriptorTable; - - const SIZE_T DescriptorRangesSize = sizeof(D3D12_DESCRIPTOR_RANGE) * table_1_1.NumDescriptorRanges; - void* pDescriptorRanges = (DescriptorRangesSize > 0 && SUCCEEDED(hr)) ? HeapAlloc(GetProcessHeap(), 0, DescriptorRangesSize) : nullptr; - if (DescriptorRangesSize > 0 && pDescriptorRanges == nullptr) - { - hr = E_OUTOFMEMORY; - } - auto pDescriptorRanges_1_0 = reinterpret_cast(pDescriptorRanges); - - if (SUCCEEDED(hr)) - { - for (UINT x = 0; x < table_1_1.NumDescriptorRanges; x++) - { - __analysis_assume(DescriptorRangesSize == sizeof(D3D12_DESCRIPTOR_RANGE) * table_1_1.NumDescriptorRanges); - pDescriptorRanges_1_0[x].BaseShaderRegister = table_1_1.pDescriptorRanges[x].BaseShaderRegister; - pDescriptorRanges_1_0[x].NumDescriptors = table_1_1.pDescriptorRanges[x].NumDescriptors; - pDescriptorRanges_1_0[x].OffsetInDescriptorsFromTableStart = table_1_1.pDescriptorRanges[x].OffsetInDescriptorsFromTableStart; - pDescriptorRanges_1_0[x].RangeType = table_1_1.pDescriptorRanges[x].RangeType; - pDescriptorRanges_1_0[x].RegisterSpace = table_1_1.pDescriptorRanges[x].RegisterSpace; - } - } - - D3D12_ROOT_DESCRIPTOR_TABLE& table_1_0 = pParameters_1_0[n].DescriptorTable; - table_1_0.NumDescriptorRanges = table_1_1.NumDescriptorRanges; - table_1_0.pDescriptorRanges = pDescriptorRanges_1_0; - } - } - } - - if (SUCCEEDED(hr)) - { - CD3DX12_ROOT_SIGNATURE_DESC desc_1_0(desc_1_1.NumParameters, pParameters_1_0, desc_1_1.NumStaticSamplers, desc_1_1.pStaticSamplers, desc_1_1.Flags); - hr = D3D12SerializeRootSignature(&desc_1_0, D3D_ROOT_SIGNATURE_VERSION_1, ppBlob, ppErrorBlob); - } - - if (pParameters) - { - for (UINT n = 0; n < desc_1_1.NumParameters; n++) - { - if (desc_1_1.pParameters[n].ParameterType == D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE) - { - HeapFree(GetProcessHeap(), 0, reinterpret_cast(const_cast(pParameters_1_0[n].DescriptorTable.pDescriptorRanges))); - } - } - HeapFree(GetProcessHeap(), 0, pParameters); - } - return hr; - } - } - break; - - case D3D_ROOT_SIGNATURE_VERSION_1_1: - return D3D12SerializeVersionedRootSignature(pRootSignatureDesc, ppBlob, ppErrorBlob); - } - - return E_INVALIDARG; -} - -struct CD3DX12_RT_FORMAT_ARRAY : public D3D12_RT_FORMAT_ARRAY -{ - CD3DX12_RT_FORMAT_ARRAY() = default; - explicit CD3DX12_RT_FORMAT_ARRAY(const D3D12_RT_FORMAT_ARRAY& o) - : D3D12_RT_FORMAT_ARRAY(o) - {} - explicit CD3DX12_RT_FORMAT_ARRAY(_In_reads_(NumFormats) const DXGI_FORMAT* pFormats, UINT NumFormats) - { - NumRenderTargets = NumFormats; - memcpy(RTFormats, pFormats, sizeof(RTFormats)); - // assumes ARRAY_SIZE(pFormats) == ARRAY_SIZE(RTFormats) - } -}; - -// Pipeline State Stream Helpers - -// Stream Subobjects, i.e. elements of a stream - -struct DefaultSampleMask { operator UINT() { return UINT_MAX; } }; -struct DefaultSampleDesc { operator DXGI_SAMPLE_DESC() { return DXGI_SAMPLE_DESC{ 1, 0 }; } }; - -#pragma warning(push) -#pragma warning(disable : 4324) -template -class alignas(void*) CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT -{ -private: - D3D12_PIPELINE_STATE_SUBOBJECT_TYPE _Type; - InnerStructType _Inner; -public: - CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT() noexcept : _Type(Type), _Inner(DefaultArg()) {} - CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT(InnerStructType const& i) : _Type(Type), _Inner(i) {} - CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT& operator=(InnerStructType const& i) { _Type = Type; _Inner = i; return *this; } - operator InnerStructType const& () const { return _Inner; } - operator InnerStructType& () { return _Inner; } - InnerStructType* operator&() { return &_Inner; } - InnerStructType const* operator&() const { return &_Inner; } -}; -#pragma warning(pop) -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_PIPELINE_STATE_FLAGS, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS> CD3DX12_PIPELINE_STATE_STREAM_FLAGS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< UINT, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK> CD3DX12_PIPELINE_STATE_STREAM_NODE_MASK; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< ID3D12RootSignature*, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE> CD3DX12_PIPELINE_STATE_STREAM_ROOT_SIGNATURE; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_INPUT_LAYOUT_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_INPUT_LAYOUT> CD3DX12_PIPELINE_STATE_STREAM_INPUT_LAYOUT; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_INDEX_BUFFER_STRIP_CUT_VALUE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_IB_STRIP_CUT_VALUE> CD3DX12_PIPELINE_STATE_STREAM_IB_STRIP_CUT_VALUE; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_PRIMITIVE_TOPOLOGY_TYPE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY> CD3DX12_PIPELINE_STATE_STREAM_PRIMITIVE_TOPOLOGY; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VS> CD3DX12_PIPELINE_STATE_STREAM_VS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_GS> CD3DX12_PIPELINE_STATE_STREAM_GS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_STREAM_OUTPUT_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_STREAM_OUTPUT> CD3DX12_PIPELINE_STATE_STREAM_STREAM_OUTPUT; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_HS> CD3DX12_PIPELINE_STATE_STREAM_HS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DS> CD3DX12_PIPELINE_STATE_STREAM_DS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS> CD3DX12_PIPELINE_STATE_STREAM_PS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CS> CD3DX12_PIPELINE_STATE_STREAM_CS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_BLEND_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_BLEND_DESC; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_DEPTH_STENCIL_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_DEPTH_STENCIL_DESC1, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL1; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< DXGI_FORMAT, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT> CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL_FORMAT; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_RASTERIZER_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_RASTERIZER; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_RT_FORMAT_ARRAY, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS> CD3DX12_PIPELINE_STATE_STREAM_RENDER_TARGET_FORMATS; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< DXGI_SAMPLE_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC, DefaultSampleDesc> CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_DESC; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< UINT, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK, DefaultSampleMask> CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_MASK; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_CACHED_PIPELINE_STATE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO> CD3DX12_PIPELINE_STATE_STREAM_CACHED_PSO; -typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_VIEW_INSTANCING_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VIEW_INSTANCING, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_VIEW_INSTANCING; - -// Stream Parser Helpers - -struct ID3DX12PipelineParserCallbacks -{ - // Subobject Callbacks - virtual void FlagsCb(D3D12_PIPELINE_STATE_FLAGS) {} - virtual void NodeMaskCb(UINT) {} - virtual void RootSignatureCb(ID3D12RootSignature*) {} - virtual void InputLayoutCb(const D3D12_INPUT_LAYOUT_DESC&) {} - virtual void IBStripCutValueCb(D3D12_INDEX_BUFFER_STRIP_CUT_VALUE) {} - virtual void PrimitiveTopologyTypeCb(D3D12_PRIMITIVE_TOPOLOGY_TYPE) {} - virtual void VSCb(const D3D12_SHADER_BYTECODE&) {} - virtual void GSCb(const D3D12_SHADER_BYTECODE&) {} - virtual void StreamOutputCb(const D3D12_STREAM_OUTPUT_DESC&) {} - virtual void HSCb(const D3D12_SHADER_BYTECODE&) {} - virtual void DSCb(const D3D12_SHADER_BYTECODE&) {} - virtual void PSCb(const D3D12_SHADER_BYTECODE&) {} - virtual void CSCb(const D3D12_SHADER_BYTECODE&) {} - virtual void BlendStateCb(const D3D12_BLEND_DESC&) {} - virtual void DepthStencilStateCb(const D3D12_DEPTH_STENCIL_DESC&) {} - virtual void DepthStencilState1Cb(const D3D12_DEPTH_STENCIL_DESC1&) {} - virtual void DSVFormatCb(DXGI_FORMAT) {} - virtual void RasterizerStateCb(const D3D12_RASTERIZER_DESC&) {} - virtual void RTVFormatsCb(const D3D12_RT_FORMAT_ARRAY&) {} - virtual void SampleDescCb(const DXGI_SAMPLE_DESC&) {} - virtual void SampleMaskCb(UINT) {} - virtual void ViewInstancingCb(const D3D12_VIEW_INSTANCING_DESC&) {} - virtual void CachedPSOCb(const D3D12_CACHED_PIPELINE_STATE&) {} - - // Error Callbacks - virtual void ErrorBadInputParameter(UINT /*ParameterIndex*/) {} - virtual void ErrorDuplicateSubobject(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE /*DuplicateType*/) {} - virtual void ErrorUnknownSubobject(UINT /*UnknownTypeValue*/) {} - - virtual ~ID3DX12PipelineParserCallbacks() = default; -}; - -// CD3DX12_PIPELINE_STATE_STREAM1 Works on RS3+ (where there is a new view instancing subobject). -// Use CD3DX12_PIPELINE_STATE_STREAM for RS2+ support. -struct CD3DX12_PIPELINE_STATE_STREAM1 -{ - CD3DX12_PIPELINE_STATE_STREAM1() = default; - CD3DX12_PIPELINE_STATE_STREAM1(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& Desc) - : Flags(Desc.Flags) - , NodeMask(Desc.NodeMask) - , pRootSignature(Desc.pRootSignature) - , InputLayout(Desc.InputLayout) - , IBStripCutValue(Desc.IBStripCutValue) - , PrimitiveTopologyType(Desc.PrimitiveTopologyType) - , VS(Desc.VS) - , GS(Desc.GS) - , StreamOutput(Desc.StreamOutput) - , HS(Desc.HS) - , DS(Desc.DS) - , PS(Desc.PS) - , BlendState(CD3DX12_BLEND_DESC(Desc.BlendState)) - , DepthStencilState(CD3DX12_DEPTH_STENCIL_DESC1(Desc.DepthStencilState)) - , DSVFormat(Desc.DSVFormat) - , RasterizerState(CD3DX12_RASTERIZER_DESC(Desc.RasterizerState)) - , RTVFormats(CD3DX12_RT_FORMAT_ARRAY(Desc.RTVFormats, Desc.NumRenderTargets)) - , SampleDesc(Desc.SampleDesc) - , SampleMask(Desc.SampleMask) - , CachedPSO(Desc.CachedPSO) - , ViewInstancingDesc(CD3DX12_VIEW_INSTANCING_DESC(CD3DX12_DEFAULT())) - {} - CD3DX12_PIPELINE_STATE_STREAM1(const D3D12_COMPUTE_PIPELINE_STATE_DESC& Desc) - : Flags(Desc.Flags) - , NodeMask(Desc.NodeMask) - , pRootSignature(Desc.pRootSignature) - , CS(CD3DX12_SHADER_BYTECODE(Desc.CS)) - , CachedPSO(Desc.CachedPSO) - { - static_cast(DepthStencilState).DepthEnable = false; - } - CD3DX12_PIPELINE_STATE_STREAM_FLAGS Flags; - CD3DX12_PIPELINE_STATE_STREAM_NODE_MASK NodeMask; - CD3DX12_PIPELINE_STATE_STREAM_ROOT_SIGNATURE pRootSignature; - CD3DX12_PIPELINE_STATE_STREAM_INPUT_LAYOUT InputLayout; - CD3DX12_PIPELINE_STATE_STREAM_IB_STRIP_CUT_VALUE IBStripCutValue; - CD3DX12_PIPELINE_STATE_STREAM_PRIMITIVE_TOPOLOGY PrimitiveTopologyType; - CD3DX12_PIPELINE_STATE_STREAM_VS VS; - CD3DX12_PIPELINE_STATE_STREAM_GS GS; - CD3DX12_PIPELINE_STATE_STREAM_STREAM_OUTPUT StreamOutput; - CD3DX12_PIPELINE_STATE_STREAM_HS HS; - CD3DX12_PIPELINE_STATE_STREAM_DS DS; - CD3DX12_PIPELINE_STATE_STREAM_PS PS; - CD3DX12_PIPELINE_STATE_STREAM_CS CS; - CD3DX12_PIPELINE_STATE_STREAM_BLEND_DESC BlendState; - CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL1 DepthStencilState; - CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL_FORMAT DSVFormat; - CD3DX12_PIPELINE_STATE_STREAM_RASTERIZER RasterizerState; - CD3DX12_PIPELINE_STATE_STREAM_RENDER_TARGET_FORMATS RTVFormats; - CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_DESC SampleDesc; - CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_MASK SampleMask; - CD3DX12_PIPELINE_STATE_STREAM_CACHED_PSO CachedPSO; - CD3DX12_PIPELINE_STATE_STREAM_VIEW_INSTANCING ViewInstancingDesc; - D3D12_GRAPHICS_PIPELINE_STATE_DESC GraphicsDescV0() const - { - D3D12_GRAPHICS_PIPELINE_STATE_DESC D; - D.Flags = this->Flags; - D.NodeMask = this->NodeMask; - D.pRootSignature = this->pRootSignature; - D.InputLayout = this->InputLayout; - D.IBStripCutValue = this->IBStripCutValue; - D.PrimitiveTopologyType = this->PrimitiveTopologyType; - D.VS = this->VS; - D.GS = this->GS; - D.StreamOutput = this->StreamOutput; - D.HS = this->HS; - D.DS = this->DS; - D.PS = this->PS; - D.BlendState = this->BlendState; - D.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(D3D12_DEPTH_STENCIL_DESC1(this->DepthStencilState)); - D.DSVFormat = this->DSVFormat; - D.RasterizerState = this->RasterizerState; - D.NumRenderTargets = D3D12_RT_FORMAT_ARRAY(this->RTVFormats).NumRenderTargets; - memcpy(D.RTVFormats, D3D12_RT_FORMAT_ARRAY(this->RTVFormats).RTFormats, sizeof(D.RTVFormats)); - D.SampleDesc = this->SampleDesc; - D.SampleMask = this->SampleMask; - D.CachedPSO = this->CachedPSO; - return D; - } - D3D12_COMPUTE_PIPELINE_STATE_DESC ComputeDescV0() const - { - D3D12_COMPUTE_PIPELINE_STATE_DESC D; - D.Flags = this->Flags; - D.NodeMask = this->NodeMask; - D.pRootSignature = this->pRootSignature; - D.CS = this->CS; - D.CachedPSO = this->CachedPSO; - return D; - } -}; - -// CD3DX12_PIPELINE_STATE_STREAM works on RS2+ but does not support new subobject(s) added in RS3+. -// See CD3DX12_PIPELINE_STATE_STREAM1 for instance. -struct CD3DX12_PIPELINE_STATE_STREAM -{ - CD3DX12_PIPELINE_STATE_STREAM() = default; - CD3DX12_PIPELINE_STATE_STREAM(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& Desc) - : Flags(Desc.Flags) - , NodeMask(Desc.NodeMask) - , pRootSignature(Desc.pRootSignature) - , InputLayout(Desc.InputLayout) - , IBStripCutValue(Desc.IBStripCutValue) - , PrimitiveTopologyType(Desc.PrimitiveTopologyType) - , VS(Desc.VS) - , GS(Desc.GS) - , StreamOutput(Desc.StreamOutput) - , HS(Desc.HS) - , DS(Desc.DS) - , PS(Desc.PS) - , BlendState(CD3DX12_BLEND_DESC(Desc.BlendState)) - , DepthStencilState(CD3DX12_DEPTH_STENCIL_DESC1(Desc.DepthStencilState)) - , DSVFormat(Desc.DSVFormat) - , RasterizerState(CD3DX12_RASTERIZER_DESC(Desc.RasterizerState)) - , RTVFormats(CD3DX12_RT_FORMAT_ARRAY(Desc.RTVFormats, Desc.NumRenderTargets)) - , SampleDesc(Desc.SampleDesc) - , SampleMask(Desc.SampleMask) - , CachedPSO(Desc.CachedPSO) - {} - CD3DX12_PIPELINE_STATE_STREAM(const D3D12_COMPUTE_PIPELINE_STATE_DESC& Desc) - : Flags(Desc.Flags) - , NodeMask(Desc.NodeMask) - , pRootSignature(Desc.pRootSignature) - , CS(CD3DX12_SHADER_BYTECODE(Desc.CS)) - , CachedPSO(Desc.CachedPSO) - {} - CD3DX12_PIPELINE_STATE_STREAM_FLAGS Flags; - CD3DX12_PIPELINE_STATE_STREAM_NODE_MASK NodeMask; - CD3DX12_PIPELINE_STATE_STREAM_ROOT_SIGNATURE pRootSignature; - CD3DX12_PIPELINE_STATE_STREAM_INPUT_LAYOUT InputLayout; - CD3DX12_PIPELINE_STATE_STREAM_IB_STRIP_CUT_VALUE IBStripCutValue; - CD3DX12_PIPELINE_STATE_STREAM_PRIMITIVE_TOPOLOGY PrimitiveTopologyType; - CD3DX12_PIPELINE_STATE_STREAM_VS VS; - CD3DX12_PIPELINE_STATE_STREAM_GS GS; - CD3DX12_PIPELINE_STATE_STREAM_STREAM_OUTPUT StreamOutput; - CD3DX12_PIPELINE_STATE_STREAM_HS HS; - CD3DX12_PIPELINE_STATE_STREAM_DS DS; - CD3DX12_PIPELINE_STATE_STREAM_PS PS; - CD3DX12_PIPELINE_STATE_STREAM_CS CS; - CD3DX12_PIPELINE_STATE_STREAM_BLEND_DESC BlendState; - CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL1 DepthStencilState; - CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL_FORMAT DSVFormat; - CD3DX12_PIPELINE_STATE_STREAM_RASTERIZER RasterizerState; - CD3DX12_PIPELINE_STATE_STREAM_RENDER_TARGET_FORMATS RTVFormats; - CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_DESC SampleDesc; - CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_MASK SampleMask; - CD3DX12_PIPELINE_STATE_STREAM_CACHED_PSO CachedPSO; - D3D12_GRAPHICS_PIPELINE_STATE_DESC GraphicsDescV0() const - { - D3D12_GRAPHICS_PIPELINE_STATE_DESC D; - D.Flags = this->Flags; - D.NodeMask = this->NodeMask; - D.pRootSignature = this->pRootSignature; - D.InputLayout = this->InputLayout; - D.IBStripCutValue = this->IBStripCutValue; - D.PrimitiveTopologyType = this->PrimitiveTopologyType; - D.VS = this->VS; - D.GS = this->GS; - D.StreamOutput = this->StreamOutput; - D.HS = this->HS; - D.DS = this->DS; - D.PS = this->PS; - D.BlendState = this->BlendState; - D.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(D3D12_DEPTH_STENCIL_DESC1(this->DepthStencilState)); - D.DSVFormat = this->DSVFormat; - D.RasterizerState = this->RasterizerState; - D.NumRenderTargets = D3D12_RT_FORMAT_ARRAY(this->RTVFormats).NumRenderTargets; - memcpy(D.RTVFormats, D3D12_RT_FORMAT_ARRAY(this->RTVFormats).RTFormats, sizeof(D.RTVFormats)); - D.SampleDesc = this->SampleDesc; - D.SampleMask = this->SampleMask; - D.CachedPSO = this->CachedPSO; - return D; - } - D3D12_COMPUTE_PIPELINE_STATE_DESC ComputeDescV0() const - { - D3D12_COMPUTE_PIPELINE_STATE_DESC D; - D.Flags = this->Flags; - D.NodeMask = this->NodeMask; - D.pRootSignature = this->pRootSignature; - D.CS = this->CS; - D.CachedPSO = this->CachedPSO; - return D; - } -}; - -struct CD3DX12_PIPELINE_STATE_STREAM_PARSE_HELPER : public ID3DX12PipelineParserCallbacks -{ - CD3DX12_PIPELINE_STATE_STREAM1 PipelineStream; - CD3DX12_PIPELINE_STATE_STREAM_PARSE_HELPER() noexcept - : SeenDSS(false) - { - // Adjust defaults to account for absent members. - PipelineStream.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE; - - // Depth disabled if no DSV format specified. - static_cast(PipelineStream.DepthStencilState).DepthEnable = false; - } - - // ID3DX12PipelineParserCallbacks - void FlagsCb(D3D12_PIPELINE_STATE_FLAGS Flags) override { PipelineStream.Flags = Flags; } - void NodeMaskCb(UINT NodeMask) override { PipelineStream.NodeMask = NodeMask; } - void RootSignatureCb(ID3D12RootSignature* pRootSignature) override { PipelineStream.pRootSignature = pRootSignature; } - void InputLayoutCb(const D3D12_INPUT_LAYOUT_DESC& InputLayout) override { PipelineStream.InputLayout = InputLayout; } - void IBStripCutValueCb(D3D12_INDEX_BUFFER_STRIP_CUT_VALUE IBStripCutValue) override { PipelineStream.IBStripCutValue = IBStripCutValue; } - void PrimitiveTopologyTypeCb(D3D12_PRIMITIVE_TOPOLOGY_TYPE PrimitiveTopologyType) override { PipelineStream.PrimitiveTopologyType = PrimitiveTopologyType; } - void VSCb(const D3D12_SHADER_BYTECODE& VS) override { PipelineStream.VS = VS; } - void GSCb(const D3D12_SHADER_BYTECODE& GS) override { PipelineStream.GS = GS; } - void StreamOutputCb(const D3D12_STREAM_OUTPUT_DESC& StreamOutput) override { PipelineStream.StreamOutput = StreamOutput; } - void HSCb(const D3D12_SHADER_BYTECODE& HS) override { PipelineStream.HS = HS; } - void DSCb(const D3D12_SHADER_BYTECODE& DS) override { PipelineStream.DS = DS; } - void PSCb(const D3D12_SHADER_BYTECODE& PS) override { PipelineStream.PS = PS; } - void CSCb(const D3D12_SHADER_BYTECODE& CS) override { PipelineStream.CS = CS; } - void BlendStateCb(const D3D12_BLEND_DESC& BlendState) override { PipelineStream.BlendState = CD3DX12_BLEND_DESC(BlendState); } - void DepthStencilStateCb(const D3D12_DEPTH_STENCIL_DESC& DepthStencilState) override - { - PipelineStream.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(DepthStencilState); - SeenDSS = true; - } - void DepthStencilState1Cb(const D3D12_DEPTH_STENCIL_DESC1& DepthStencilState) override - { - PipelineStream.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(DepthStencilState); - SeenDSS = true; - } - void DSVFormatCb(DXGI_FORMAT DSVFormat) override - { - PipelineStream.DSVFormat = DSVFormat; - if (!SeenDSS && DSVFormat != DXGI_FORMAT_UNKNOWN) - { - // Re-enable depth for the default state. - static_cast(PipelineStream.DepthStencilState).DepthEnable = true; - } - } - void RasterizerStateCb(const D3D12_RASTERIZER_DESC& RasterizerState) override { PipelineStream.RasterizerState = CD3DX12_RASTERIZER_DESC(RasterizerState); } - void RTVFormatsCb(const D3D12_RT_FORMAT_ARRAY& RTVFormats) override { PipelineStream.RTVFormats = RTVFormats; } - void SampleDescCb(const DXGI_SAMPLE_DESC& SampleDesc) override { PipelineStream.SampleDesc = SampleDesc; } - void SampleMaskCb(UINT SampleMask) override { PipelineStream.SampleMask = SampleMask; } - void ViewInstancingCb(const D3D12_VIEW_INSTANCING_DESC& ViewInstancingDesc) override { PipelineStream.ViewInstancingDesc = CD3DX12_VIEW_INSTANCING_DESC(ViewInstancingDesc); } - void CachedPSOCb(const D3D12_CACHED_PIPELINE_STATE& CachedPSO) override { PipelineStream.CachedPSO = CachedPSO; } - -private: - bool SeenDSS; -}; - -inline D3D12_PIPELINE_STATE_SUBOBJECT_TYPE D3DX12GetBaseSubobjectType(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE SubobjectType) -{ - switch (SubobjectType) - { - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1: - return D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL; - default: - return SubobjectType; - } -} - -inline HRESULT D3DX12ParsePipelineStream(const D3D12_PIPELINE_STATE_STREAM_DESC& Desc, ID3DX12PipelineParserCallbacks* pCallbacks) -{ - if (pCallbacks == nullptr) - { - return E_INVALIDARG; - } - - if (Desc.SizeInBytes == 0 || Desc.pPipelineStateSubobjectStream == nullptr) - { - pCallbacks->ErrorBadInputParameter(1); // first parameter issue - return E_INVALIDARG; - } - - bool SubobjectSeen[D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MAX_VALID] = {}; - for (SIZE_T CurOffset = 0, SizeOfSubobject = 0; CurOffset < Desc.SizeInBytes; CurOffset += SizeOfSubobject) - { - BYTE* pStream = static_cast(Desc.pPipelineStateSubobjectStream) + CurOffset; - auto SubobjectType = *reinterpret_cast(pStream); - if (SubobjectType < 0 || SubobjectType >= D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MAX_VALID) - { - pCallbacks->ErrorUnknownSubobject(SubobjectType); - return E_INVALIDARG; - } - if (SubobjectSeen[D3DX12GetBaseSubobjectType(SubobjectType)]) - { - pCallbacks->ErrorDuplicateSubobject(SubobjectType); - return E_INVALIDARG; // disallow subobject duplicates in a stream - } - SubobjectSeen[SubobjectType] = true; - switch (SubobjectType) - { - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE: - pCallbacks->RootSignatureCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::pRootSignature); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VS: - pCallbacks->VSCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::VS); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS: - pCallbacks->PSCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::PS); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DS: - pCallbacks->DSCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::DS); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_HS: - pCallbacks->HSCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::HS); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_GS: - pCallbacks->GSCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::GS); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CS: - pCallbacks->CSCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::CS); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_STREAM_OUTPUT: - pCallbacks->StreamOutputCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::StreamOutput); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND: - pCallbacks->BlendStateCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::BlendState); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK: - pCallbacks->SampleMaskCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::SampleMask); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER: - pCallbacks->RasterizerStateCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::RasterizerState); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL: - pCallbacks->DepthStencilStateCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1: - pCallbacks->DepthStencilState1Cb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::DepthStencilState); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_INPUT_LAYOUT: - pCallbacks->InputLayoutCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::InputLayout); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_IB_STRIP_CUT_VALUE: - pCallbacks->IBStripCutValueCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::IBStripCutValue); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY: - pCallbacks->PrimitiveTopologyTypeCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::PrimitiveTopologyType); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS: - pCallbacks->RTVFormatsCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::RTVFormats); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT: - pCallbacks->DSVFormatCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::DSVFormat); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC: - pCallbacks->SampleDescCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::SampleDesc); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK: - pCallbacks->NodeMaskCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::NodeMask); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO: - pCallbacks->CachedPSOCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::CachedPSO); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS: - pCallbacks->FlagsCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::Flags); - break; - case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VIEW_INSTANCING: - pCallbacks->ViewInstancingCb(*reinterpret_cast(pStream)); - SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM1::ViewInstancingDesc); - break; - default: - pCallbacks->ErrorUnknownSubobject(SubobjectType); - return E_INVALIDARG; - break; - } - } - - return S_OK; -} - -inline bool operator==(const D3D12_CLEAR_VALUE& a, const D3D12_CLEAR_VALUE& b) -{ - if (a.Format != b.Format) return false; - if (a.Format == DXGI_FORMAT_D24_UNORM_S8_UINT - || a.Format == DXGI_FORMAT_D16_UNORM - || a.Format == DXGI_FORMAT_D32_FLOAT - || a.Format == DXGI_FORMAT_D32_FLOAT_S8X24_UINT) - { - return (a.DepthStencil.Depth == b.DepthStencil.Depth) && - (a.DepthStencil.Stencil == b.DepthStencil.Stencil); - } - else { - return (a.Color[0] == b.Color[0]) && - (a.Color[1] == b.Color[1]) && - (a.Color[2] == b.Color[2]) && - (a.Color[3] == b.Color[3]); - } -} -inline bool operator==(const D3D12_RENDER_PASS_BEGINNING_ACCESS_CLEAR_PARAMETERS& a, const D3D12_RENDER_PASS_BEGINNING_ACCESS_CLEAR_PARAMETERS& b) -{ - return a.ClearValue == b.ClearValue; -} -inline bool operator==(const D3D12_RENDER_PASS_ENDING_ACCESS_RESOLVE_PARAMETERS& a, const D3D12_RENDER_PASS_ENDING_ACCESS_RESOLVE_PARAMETERS& b) -{ - if (a.pSrcResource != b.pSrcResource) return false; - if (a.pDstResource != b.pDstResource) return false; - if (a.SubresourceCount != b.SubresourceCount) return false; - if (a.Format != b.Format) return false; - if (a.ResolveMode != b.ResolveMode) return false; - if (a.PreserveResolveSource != b.PreserveResolveSource) return false; - return true; -} -inline bool operator==(const D3D12_RENDER_PASS_BEGINNING_ACCESS& a, const D3D12_RENDER_PASS_BEGINNING_ACCESS& b) -{ - if (a.Type != b.Type) return false; - if (a.Type == D3D12_RENDER_PASS_BEGINNING_ACCESS_TYPE_CLEAR && !(a.Clear == b.Clear)) return false; - return true; -} -inline bool operator==(const D3D12_RENDER_PASS_ENDING_ACCESS& a, const D3D12_RENDER_PASS_ENDING_ACCESS& b) -{ - if (a.Type != b.Type) return false; - if (a.Type == D3D12_RENDER_PASS_ENDING_ACCESS_TYPE_RESOLVE && !(a.Resolve == b.Resolve)) return false; - return true; -} -inline bool operator==(const D3D12_RENDER_PASS_RENDER_TARGET_DESC& a, const D3D12_RENDER_PASS_RENDER_TARGET_DESC& b) -{ - if (a.cpuDescriptor.ptr != b.cpuDescriptor.ptr) return false; - if (!(a.BeginningAccess == b.BeginningAccess)) return false; - if (!(a.EndingAccess == b.EndingAccess)) return false; - return true; -} -inline bool operator==(const D3D12_RENDER_PASS_DEPTH_STENCIL_DESC& a, const D3D12_RENDER_PASS_DEPTH_STENCIL_DESC& b) -{ - if (a.cpuDescriptor.ptr != b.cpuDescriptor.ptr) return false; - if (!(a.DepthBeginningAccess == b.DepthBeginningAccess)) return false; - if (!(a.StencilBeginningAccess == b.StencilBeginningAccess)) return false; - if (!(a.DepthEndingAccess == b.DepthEndingAccess)) return false; - if (!(a.StencilEndingAccess == b.StencilEndingAccess)) return false; - return true; -} - - -#ifndef D3DX12_NO_STATE_OBJECT_HELPERS - -//================================================================================================ -// D3DX12 State Object Creation Helpers -// -// Helper classes for creating new style state objects out of an arbitrary set of subobjects. -// Uses STL -// -// Start by instantiating CD3DX12_STATE_OBJECT_DESC (see it's public methods). -// One of its methods is CreateSubobject(), which has a comment showing a couple of options for -// defining subobjects using the helper classes for each subobject (CD3DX12_DXIL_LIBRARY_SUBOBJECT -// etc.). The subobject helpers each have methods specific to the subobject for configuring it's -// contents. -// -//================================================================================================ -#include -#include -#include -#include -#include - -class CD3DX12_STATE_OBJECT_DESC -{ -public: - CD3DX12_STATE_OBJECT_DESC() - { - Init(D3D12_STATE_OBJECT_TYPE_COLLECTION); - } - CD3DX12_STATE_OBJECT_DESC(D3D12_STATE_OBJECT_TYPE Type) - { - Init(Type); - } - void SetStateObjectType(D3D12_STATE_OBJECT_TYPE Type) { m_Desc.Type = Type; } - operator const D3D12_STATE_OBJECT_DESC& () - { - // Do final preparation work - m_RepointedAssociations.clear(); - m_SubobjectArray.clear(); - m_SubobjectArray.reserve(m_Desc.NumSubobjects); - // Flatten subobjects into an array (each flattened subobject still has a - // member that's a pointer to it's desc that's not flattened) - for (auto Iter = m_SubobjectList.begin(); - Iter != m_SubobjectList.end(); Iter++) - { - m_SubobjectArray.push_back(*Iter); - // Store new location in array so we can redirect pointers contained in subobjects - Iter->pSubobjectArrayLocation = &m_SubobjectArray.back(); - } - // For subobjects with pointer fields, create a new copy of those subobject definitions - // with fixed pointers - for (UINT i = 0; i < m_Desc.NumSubobjects; i++) - { - if (m_SubobjectArray[i].Type == D3D12_STATE_SUBOBJECT_TYPE_SUBOBJECT_TO_EXPORTS_ASSOCIATION) - { - auto pOriginalSubobjectAssociation = - reinterpret_cast(m_SubobjectArray[i].pDesc); - D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION Repointed = *pOriginalSubobjectAssociation; - auto pWrapper = - static_cast(pOriginalSubobjectAssociation->pSubobjectToAssociate); - Repointed.pSubobjectToAssociate = pWrapper->pSubobjectArrayLocation; - m_RepointedAssociations.push_back(Repointed); - m_SubobjectArray[i].pDesc = &m_RepointedAssociations.back(); - } - } - // Below: using ugly way to get pointer in case .data() is not defined - m_Desc.pSubobjects = m_Desc.NumSubobjects ? &m_SubobjectArray[0] : nullptr; - return m_Desc; - } - operator const D3D12_STATE_OBJECT_DESC* () - { - // Cast calls the above final preparation work - return &static_cast(*this); - } - - // CreateSubobject creates a sububject helper (e.g. CD3DX12_HIT_GROUP_SUBOBJECT) - // whose lifetime is owned by this class. - // e.g. - // - // CD3DX12_STATE_OBJECT_DESC Collection1(D3D12_STATE_OBJECT_TYPE_COLLECTION); - // auto Lib0 = Collection1.CreateSubobject(); - // Lib0->SetDXILLibrary(&pMyAppDxilLibs[0]); - // Lib0->DefineExport(L"rayGenShader0"); // in practice these export listings might be - // // data/engine driven - // etc. - // - // Alternatively, users can instantiate sububject helpers explicitly, such as via local - // variables instead, passing the state object desc that should point to it into the helper - // constructor (or call mySubobjectHelper.AddToStateObject(Collection1)). - // In this alternative scenario, the user must keep the subobject alive as long as the state - // object it is associated with is alive, else it's pointer references will be stale. - // e.g. - // - // CD3DX12_STATE_OBJECT_DESC RaytracingState2(D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE); - // CD3DX12_DXIL_LIBRARY_SUBOBJECT LibA(RaytracingState2); - // LibA.SetDXILLibrary(&pMyAppDxilLibs[4]); // not manually specifying exports - // // - meaning all exports in the libraries - // // are exported - // etc. - - template - T* CreateSubobject() - { - T* pSubobject = new T(*this); - m_OwnedSubobjectHelpers.emplace_back(pSubobject); - return pSubobject; - } - -private: - D3D12_STATE_SUBOBJECT* TrackSubobject(D3D12_STATE_SUBOBJECT_TYPE Type, void* pDesc) - { - SUBOBJECT_WRAPPER Subobject; - Subobject.pSubobjectArrayLocation = nullptr; - Subobject.Type = Type; - Subobject.pDesc = pDesc; - m_SubobjectList.push_back(Subobject); - m_Desc.NumSubobjects++; - return &m_SubobjectList.back(); - } - void Init(D3D12_STATE_OBJECT_TYPE Type) - { - SetStateObjectType(Type); - m_Desc.pSubobjects = nullptr; - m_Desc.NumSubobjects = 0; - m_SubobjectList.clear(); - m_SubobjectArray.clear(); - m_RepointedAssociations.clear(); - } - typedef struct SUBOBJECT_WRAPPER : public D3D12_STATE_SUBOBJECT - { - D3D12_STATE_SUBOBJECT* pSubobjectArrayLocation; // new location when flattened into array - // for repointing pointers in subobjects - } SUBOBJECT_WRAPPER; - D3D12_STATE_OBJECT_DESC m_Desc; - std::list m_SubobjectList; // Pointers to list nodes handed out so - // these can be edited live - std::vector m_SubobjectArray; // Built at the end, copying list contents - - std::list - m_RepointedAssociations; // subobject type that contains pointers to other subobjects, - // repointed to flattened array - - class StringContainer - { - public: - LPCWSTR LocalCopy(LPCWSTR string, bool bSingleString = false) - { - if (string) - { - if (bSingleString) - { - m_Strings.clear(); - m_Strings.push_back(string); - } - else - { - m_Strings.push_back(string); - } - return m_Strings.back().c_str(); - } - else - { - return nullptr; - } - } - void clear() { m_Strings.clear(); } - private: - std::list m_Strings; - }; - - class SUBOBJECT_HELPER_BASE - { - public: - SUBOBJECT_HELPER_BASE() { Init(); } - virtual ~SUBOBJECT_HELPER_BASE() {} - virtual D3D12_STATE_SUBOBJECT_TYPE Type() const = 0; - void AddToStateObject(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - m_pSubobject = ContainingStateObject.TrackSubobject(Type(), Data()); - } - protected: - virtual void* Data() = 0; - void Init() { m_pSubobject = nullptr; } - D3D12_STATE_SUBOBJECT* m_pSubobject; - }; - -#if(__cplusplus >= 201103L) - std::list> m_OwnedSubobjectHelpers; -#else - class OWNED_HELPER - { - public: - OWNED_HELPER(const SUBOBJECT_HELPER_BASE* pHelper) { m_pHelper = pHelper; } - ~OWNED_HELPER() { delete m_pHelper; } - const SUBOBJECT_HELPER_BASE* m_pHelper; - }; - - std::list m_OwnedSubobjectHelpers; -#endif - - friend class CD3DX12_DXIL_LIBRARY_SUBOBJECT; - friend class CD3DX12_EXISTING_COLLECTION_SUBOBJECT; - friend class CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT; - friend class CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION; - friend class CD3DX12_HIT_GROUP_SUBOBJECT; - friend class CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT; - friend class CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT; - friend class CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT; - friend class CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT; - friend class CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT; - friend class CD3DX12_NODE_MASK_SUBOBJECT; -}; - -class CD3DX12_DXIL_LIBRARY_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_DXIL_LIBRARY_SUBOBJECT() - { - Init(); - } - CD3DX12_DXIL_LIBRARY_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetDXILLibrary(D3D12_SHADER_BYTECODE* pCode) - { - static const D3D12_SHADER_BYTECODE Default = {}; - m_Desc.DXILLibrary = pCode ? *pCode : Default; - } - void DefineExport( - LPCWSTR Name, - LPCWSTR ExportToRename = nullptr, - D3D12_EXPORT_FLAGS Flags = D3D12_EXPORT_FLAG_NONE) - { - D3D12_EXPORT_DESC Export; - Export.Name = m_Strings.LocalCopy(Name); - Export.ExportToRename = m_Strings.LocalCopy(ExportToRename); - Export.Flags = Flags; - m_Exports.push_back(Export); - m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined - m_Desc.NumExports = static_cast(m_Exports.size()); - } - template - void DefineExports(LPCWSTR(&Exports)[N]) - { - for (UINT i = 0; i < N; i++) - { - DefineExport(Exports[i]); - } - } - void DefineExports(LPCWSTR* Exports, UINT N) - { - for (UINT i = 0; i < N; i++) - { - DefineExport(Exports[i]); - } - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_DXIL_LIBRARY_DESC& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - m_Strings.clear(); - m_Exports.clear(); - } - void* Data() { return &m_Desc; } - D3D12_DXIL_LIBRARY_DESC m_Desc; - CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; - std::vector m_Exports; -}; - -class CD3DX12_EXISTING_COLLECTION_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_EXISTING_COLLECTION_SUBOBJECT() - { - Init(); - } - CD3DX12_EXISTING_COLLECTION_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetExistingCollection(ID3D12StateObject* pExistingCollection) - { - m_Desc.pExistingCollection = pExistingCollection; - m_CollectionRef = pExistingCollection; - } - void DefineExport( - LPCWSTR Name, - LPCWSTR ExportToRename = nullptr, - D3D12_EXPORT_FLAGS Flags = D3D12_EXPORT_FLAG_NONE) - { - D3D12_EXPORT_DESC Export; - Export.Name = m_Strings.LocalCopy(Name); - Export.ExportToRename = m_Strings.LocalCopy(ExportToRename); - Export.Flags = Flags; - m_Exports.push_back(Export); - m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined - m_Desc.NumExports = static_cast(m_Exports.size()); - } - template - void DefineExports(LPCWSTR(&Exports)[N]) - { - for (UINT i = 0; i < N; i++) - { - DefineExport(Exports[i]); - } - } - void DefineExports(LPCWSTR* Exports, UINT N) - { - for (UINT i = 0; i < N; i++) - { - DefineExport(Exports[i]); - } - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_EXISTING_COLLECTION; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_EXISTING_COLLECTION_DESC& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - m_CollectionRef = nullptr; - m_Strings.clear(); - m_Exports.clear(); - } - void* Data() { return &m_Desc; } - D3D12_EXISTING_COLLECTION_DESC m_Desc; - Microsoft::WRL::ComPtr m_CollectionRef; - CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; - std::vector m_Exports; -}; - -class CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT() - { - Init(); - } - CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetSubobjectToAssociate(const D3D12_STATE_SUBOBJECT& SubobjectToAssociate) - { - m_Desc.pSubobjectToAssociate = &SubobjectToAssociate; - } - void AddExport(LPCWSTR Export) - { - m_Desc.NumExports++; - m_Exports.push_back(m_Strings.LocalCopy(Export)); - m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined - } - template - void AddExports(LPCWSTR(&Exports)[N]) - { - for (UINT i = 0; i < N; i++) - { - AddExport(Exports[i]); - } - } - void AddExports(LPCWSTR* Exports, UINT N) - { - for (UINT i = 0; i < N; i++) - { - AddExport(Exports[i]); - } - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_SUBOBJECT_TO_EXPORTS_ASSOCIATION; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - m_Strings.clear(); - m_Exports.clear(); - } - void* Data() { return &m_Desc; } - D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION m_Desc; - CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; - std::vector m_Exports; -}; - -class CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION() - { - Init(); - } - CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetSubobjectNameToAssociate(LPCWSTR SubobjectToAssociate) - { - m_Desc.SubobjectToAssociate = m_SubobjectName.LocalCopy(SubobjectToAssociate, true); - } - void AddExport(LPCWSTR Export) - { - m_Desc.NumExports++; - m_Exports.push_back(m_Strings.LocalCopy(Export)); - m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined - } - template - void AddExports(LPCWSTR(&Exports)[N]) - { - for (UINT i = 0; i < N; i++) - { - AddExport(Exports[i]); - } - } - void AddExports(LPCWSTR* Exports, UINT N) - { - for (UINT i = 0; i < N; i++) - { - AddExport(Exports[i]); - } - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - m_Strings.clear(); - m_SubobjectName.clear(); - m_Exports.clear(); - } - void* Data() { return &m_Desc; } - D3D12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION m_Desc; - CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; - CD3DX12_STATE_OBJECT_DESC::StringContainer m_SubobjectName; - std::vector m_Exports; -}; - -class CD3DX12_HIT_GROUP_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_HIT_GROUP_SUBOBJECT() - { - Init(); - } - CD3DX12_HIT_GROUP_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetHitGroupExport(LPCWSTR exportName) - { - m_Desc.HitGroupExport = m_Strings[0].LocalCopy(exportName, true); - } - void SetHitGroupType(D3D12_HIT_GROUP_TYPE Type) { m_Desc.Type = Type; } - void SetAnyHitShaderImport(LPCWSTR importName) - { - m_Desc.AnyHitShaderImport = m_Strings[1].LocalCopy(importName, true); - } - void SetClosestHitShaderImport(LPCWSTR importName) - { - m_Desc.ClosestHitShaderImport = m_Strings[2].LocalCopy(importName, true); - } - void SetIntersectionShaderImport(LPCWSTR importName) - { - m_Desc.IntersectionShaderImport = m_Strings[3].LocalCopy(importName, true); - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_HIT_GROUP_DESC& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - for (UINT i = 0; i < m_NumStrings; i++) - { - m_Strings[i].clear(); - } - } - void* Data() { return &m_Desc; } - D3D12_HIT_GROUP_DESC m_Desc; - static const UINT m_NumStrings = 4; - CD3DX12_STATE_OBJECT_DESC::StringContainer - m_Strings[m_NumStrings]; // one string for every entrypoint name -}; - -class CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT() - { - Init(); - } - CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void Config(UINT MaxPayloadSizeInBytes, UINT MaxAttributeSizeInBytes) - { - m_Desc.MaxPayloadSizeInBytes = MaxPayloadSizeInBytes; - m_Desc.MaxAttributeSizeInBytes = MaxAttributeSizeInBytes; - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_RAYTRACING_SHADER_CONFIG& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - } - void* Data() { return &m_Desc; } - D3D12_RAYTRACING_SHADER_CONFIG m_Desc; -}; - -class CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT() - { - Init(); - } - CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void Config(UINT MaxTraceRecursionDepth) - { - m_Desc.MaxTraceRecursionDepth = MaxTraceRecursionDepth; - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_RAYTRACING_PIPELINE_CONFIG& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - } - void* Data() { return &m_Desc; } - D3D12_RAYTRACING_PIPELINE_CONFIG m_Desc; -}; - -class CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT() - { - Init(); - } - CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetRootSignature(ID3D12RootSignature* pRootSig) - { - m_pRootSig = pRootSig; - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator ID3D12RootSignature* () const { return m_pRootSig.Get(); } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_pRootSig = nullptr; - } - void* Data() { return m_pRootSig.GetAddressOf(); } - Microsoft::WRL::ComPtr m_pRootSig; -}; - -class CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT() - { - Init(); - } - CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetRootSignature(ID3D12RootSignature* pRootSig) - { - m_pRootSig = pRootSig; - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_LOCAL_ROOT_SIGNATURE; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator ID3D12RootSignature* () const { return m_pRootSig.Get(); } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_pRootSig = nullptr; - } - void* Data() { return m_pRootSig.GetAddressOf(); } - Microsoft::WRL::ComPtr m_pRootSig; -}; - -class CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT() - { - Init(); - } - CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetFlags(D3D12_STATE_OBJECT_FLAGS Flags) - { - m_Desc.Flags = Flags; - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_STATE_OBJECT_CONFIG; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_STATE_OBJECT_CONFIG& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - } - void* Data() { return &m_Desc; } - D3D12_STATE_OBJECT_CONFIG m_Desc; -}; - -class CD3DX12_NODE_MASK_SUBOBJECT - : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE -{ -public: - CD3DX12_NODE_MASK_SUBOBJECT() - { - Init(); - } - CD3DX12_NODE_MASK_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) - { - Init(); - AddToStateObject(ContainingStateObject); - } - void SetNodeMask(UINT NodeMask) - { - m_Desc.NodeMask = NodeMask; - } - D3D12_STATE_SUBOBJECT_TYPE Type() const - { - return D3D12_STATE_SUBOBJECT_TYPE_NODE_MASK; - } - operator const D3D12_STATE_SUBOBJECT& () const { return *m_pSubobject; } - operator const D3D12_NODE_MASK& () const { return m_Desc; } -private: - void Init() - { - SUBOBJECT_HELPER_BASE::Init(); - m_Desc = {}; - } - void* Data() { return &m_Desc; } - D3D12_NODE_MASK m_Desc; -}; - -#endif // #ifndef D3DX12_NO_STATE_OBJECT_HELPERS - -#endif // defined( __cplusplus ) - -#endif //__D3DX12_H__ diff --git a/src/tools/nnfusion/templates/dxcompute/d3dx12_nnfusion.h b/src/tools/nnfusion/templates/dxcompute/d3dx12_nnfusion.h deleted file mode 100644 index b2b977575..000000000 --- a/src/tools/nnfusion/templates/dxcompute/d3dx12_nnfusion.h +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "d3dx12_helper.h" -#include - -#define ASSERT(x) ((x) ? (printf("Error-line: (%s) %d\n", __FILE__, __LINE__), _exit(1), 0): 1) - -namespace nnfusion_dml -{ - template - std::string read_file(P& printer, const T& name) - { - std::ifstream t(name, ios_base::binary); - if (t.fail()) - { - printer << "[Error] Cannot find file from: `" << name - << "`, please copy the full codegen folder!" << std::endl; - ASSERT(0); - } - std::string str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); - return std::move(str); - } - - template - std::vector load_data(const std::string& name, size_t num_elements, const T defval = 1) - { - std::vector ret(num_elements * sizeof(T)); - if (name == "") - { - auto hptr = (T*)ret.data(); - std::fill(hptr, hptr + num_elements, defval); - } - else - { - auto str = read_file(std::cout, "Constant\\" + name); - assert(str.size() == num_elements * sizeof(T)); - memcpy(ret.data(), str.data(), str.size()); - } - return std::move(ret); - } - - static std::vector cmdQueue, preloadQueue; - static std::map> computeShaderDict; - static std::map> profCostDict; - static unsigned long totalGPUMemoryAccess = 0; - - class NNfusionTensor - { - ComPtr deviceGPUSrcX; - std::vector shape; - size_t type_size; - - public: - NNfusionTensor(D3DDevice& device, const std::vector& shape, size_t type_size) - : shape(shape) - , type_size(type_size) - { - size_t size = type_size * NumElements(); - size = ((size - 1) | 1023) + 1; - totalGPUMemoryAccess += size; - device.CreateGPUOnlyResource(size, &deviceGPUSrcX); - } - - size_t NumElements() const - { - return std::accumulate(shape.begin(), shape.end(), 1LU, std::multiplies()); - } - - size_t TypeSize() const { return type_size; } - ComPtr Data() const { return deviceGPUSrcX; } - std::vector Shape() const { return shape; } - }; - - - class NNfusionMemcpy - { - ComPtr deviceGPUSrcX; - ComPtr deviceCPUSrcX; - ComPtr m_computeCommandList; - size_t bufferSize, elements; - - public: - NNfusionMemcpy(D3DDevice& device, - NNfusionTensor& dst, - const std::vector &src, bool preload = false) - { - elements = dst.NumElements(); - bufferSize = dst.TypeSize() * dst.NumElements(); - bufferSize = ((bufferSize - 1) | 1023) + 1; - - deviceGPUSrcX = dst.Data(); - device.CreateUploadBuffer(bufferSize, &deviceCPUSrcX); - device.MapAndCopyToResource(deviceCPUSrcX.Get(), src.data(), src.size()); - - IFE(device.pDevice->CreateCommandList(0, - D3D12_COMMAND_LIST_TYPE_COMPUTE, - device.pCommandAllocator.Get(), - nullptr, - IID_PPV_ARGS(&m_computeCommandList))); - m_computeCommandList->CopyResource(deviceGPUSrcX.Get(), deviceCPUSrcX.Get()); - m_computeCommandList->Close(); - - if (preload) - { - preloadQueue.push_back(Launch()); - return; - } - cmdQueue.push_back(Launch()); - } - - NNfusionMemcpy(D3DDevice& device, - void* dst, - NNfusionTensor& src) - { - elements = src.NumElements(); - bufferSize = src.TypeSize() * src.NumElements(); - bufferSize = ((bufferSize - 1) | 1023) + 1; - - deviceGPUSrcX = src.Data(); - device.CreateReadbackBuffer(bufferSize, &deviceCPUSrcX); - - IFE(device.pDevice->CreateCommandList(0, - D3D12_COMMAND_LIST_TYPE_COMPUTE, - device.pCommandAllocator.Get(), - nullptr, - IID_PPV_ARGS(&m_computeCommandList))); - m_computeCommandList->CopyResource(deviceCPUSrcX.Get(), deviceGPUSrcX.Get()); - m_computeCommandList->Close(); - cmdQueue.push_back(Launch()); - } - - ID3D12GraphicsCommandList* Launch() { return m_computeCommandList.Get(); } - template - void PrintStageBuffer(D3DDevice& device, const std::string& name) - { - assert(bufferSize % sizeof(T) == 0); - std::vector dst(bufferSize / sizeof(T)); - device.MapCopyFromResource(deviceCPUSrcX.Get(), dst.data(), bufferSize); - T* buffer = (T*)dst.data(); - std::cout << "Result(" << name << ") = {"; - - constexpr size_t most_display = 6L; - for (int i = 0; i < min(elements, most_display); ++i) - { - if (i) - std::cout << ", "; - std::cout << dst[i]; - } - if (elements > most_display) - { - std::cout << " .., " << dst[elements - 1]; - } - std::cout << "}\n" << std::endl; - } - }; - - class NNfusionOperator - { - ComPtr m_computeCommandList; - - ComPtr m_computeRootSignature; - ComPtr computeShader; - ComPtr m_computeState; - D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc; - - LPCWSTR hlsl_source; - - public: - NNfusionOperator(D3DDevice& device, - const std::vector& inputs, - const std::vector& outputs, - LPCWSTR hlsl_source) - : hlsl_source(hlsl_source) - { - -#define _USE_DECRIPTOR_HEAP_ - -#ifdef _USE_DECRIPTOR_HEAP_ - - struct DescHeap { - ComPtr heap; - D3D12_CPU_DESCRIPTOR_HANDLE cpuHandle; - UINT nDescStep, offsetRecord; - }; - - static DescHeap globalDescHeap; - static bool initHeap = false; - if (!globalDescHeap.nDescStep) { - initHeap = true; - auto InitDescriptorHeap = [](ID3D12Device* pDevice, D3D12_DESCRIPTOR_HEAP_TYPE type, UINT nDescriptors) - { - D3D12_DESCRIPTOR_HEAP_DESC desc; - memset(&desc, 0, sizeof(desc)); - ZeroMemory(&desc, sizeof(desc)); - desc.NumDescriptors = nDescriptors; - desc.Type = type; - desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - ComPtr pDescHeap; - IFE(pDevice->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&pDescHeap))); - - globalDescHeap.nDescStep = pDevice->GetDescriptorHandleIncrementSize(type); - globalDescHeap.heap = pDescHeap; - globalDescHeap.cpuHandle = pDescHeap->GetCPUDescriptorHandleForHeapStart(); - globalDescHeap.offsetRecord = 0; - }; - - const UINT MAX_HEAP_SIZE = (1U << 20); - InitDescriptorHeap(device.pDevice.Get(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, MAX_HEAP_SIZE); - assert(globalDescHeap.nDescStep > 0); - } - - std::vector argsOffset; - // Prepare Heap Argument Offset - for (int i = 0; i < inputs.size(); ++i) { - D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc; - ZeroMemory(&srvDesc, sizeof(srvDesc)); - srvDesc.Format = DXGI_FORMAT_UNKNOWN; - srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; - srvDesc.Buffer.FirstElement = 0; - srvDesc.Buffer.NumElements = inputs[i].NumElements(); - srvDesc.Buffer.StructureByteStride = inputs[i].TypeSize(); - srvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; - - device.pDevice->CreateShaderResourceView(inputs[i].Data().Get(), &srvDesc, globalDescHeap.cpuHandle); - globalDescHeap.cpuHandle.ptr += globalDescHeap.nDescStep; - argsOffset.push_back(globalDescHeap.offsetRecord++); - assert(globalDescHeap.offsetRecord <= MAX_HEAP_SIZE); - } - for (int i = 0; i < outputs.size(); ++i) { - D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc; - ZeroMemory(&uavDesc, sizeof(uavDesc)); - uavDesc.Format = DXGI_FORMAT_UNKNOWN; - uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER; - uavDesc.Buffer.FirstElement = 0; - uavDesc.Buffer.NumElements = outputs[i].NumElements(); - uavDesc.Buffer.StructureByteStride = outputs[i].TypeSize(); - device.pDevice->CreateUnorderedAccessView(outputs[i].Data().Get(), nullptr, &uavDesc, globalDescHeap.cpuHandle); - globalDescHeap.cpuHandle.ptr += globalDescHeap.nDescStep; - argsOffset.push_back(globalDescHeap.offsetRecord++); - assert(globalDescHeap.offsetRecord <= MAX_HEAP_SIZE); - } - - // Prepare Root - std::vector computeRootParameters(1); - CD3DX12_DESCRIPTOR_RANGE1 ranges[2]; - // D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE is needed to disable unproper driver optimization. - ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, inputs.size(), 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE, argsOffset[0]); - ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, outputs.size(), 0, 0, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE, argsOffset[inputs.size()]); - - computeRootParameters[0].InitAsDescriptorTable(2, ranges); - CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC computeRootSignatureDesc; - computeRootSignatureDesc.Init_1_1((UINT)computeRootParameters.size(), - computeRootParameters.data()); -#else - // Prepare Root - std::vector computeRootParameters(inputs.size() + - outputs.size()); - for (int i = 0; i < inputs.size(); ++i) - computeRootParameters[i].InitAsShaderResourceView(i); - for (int i = 0; i < outputs.size(); ++i) - computeRootParameters[inputs.size() + i].InitAsUnorderedAccessView(i); - - CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC computeRootSignatureDesc; - computeRootSignatureDesc.Init_1_1(computeRootParameters.size(), - computeRootParameters.data()); -#endif - - ComPtr signature; - ComPtr error; - - IFE(D3DX12SerializeVersionedRootSignature( - &computeRootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1_1, &signature, &error)); - IFE(device.pDevice->CreateRootSignature(0, - signature->GetBufferPointer(), - signature->GetBufferSize(), - IID_PPV_ARGS(&m_computeRootSignature))); - - auto path = std::wstring(L"HLSL\\") + hlsl_source; - std::wcout << L"[Info] Loading HLSL data from: `" << path << L"` .." << std::endl; - auto str = read_file(std::wcout, path); - int at_bx = str.find("// [thread_extent] blockIdx.x = "), blockX = (at_bx >= 0) ? std::atoi(str.data() + at_bx + sizeof("// [thread_extent] blockIdx.x = ") - 1) : 1; - int at_by = str.find("// [thread_extent] blockIdx.y = "), blockY = (at_by >= 0) ? std::atoi(str.data() + at_by + sizeof("// [thread_extent] blockIdx.y = ") - 1) : 1; - int at_bz = str.find("// [thread_extent] blockIdx.z = "), blockZ = (at_bz >= 0) ? std::atoi(str.data() + at_bz + sizeof("// [thread_extent] blockIdx.z = ") - 1) : 1; - std::vector threads = { (UINT)blockX, (UINT)blockY, (UINT)blockZ }; - - auto it = computeShaderDict.find(hlsl_source); - if (it == computeShaderDict.end()) - { - IFE(D3DCompileFromFile( - path.c_str(), NULL, NULL, "CSMain", "cs_5_0", 0, 0, &computeShader, NULL)); - computeShaderDict[hlsl_source] = computeShader; - } - else - computeShader = it->second; - - computePsoDesc.pRootSignature = m_computeRootSignature.Get(); - computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(computeShader.Get()); - - IFE(device.pDevice->CreateComputePipelineState(&computePsoDesc, - IID_PPV_ARGS(&m_computeState))); - IFE(device.pDevice->CreateCommandList(0, - D3D12_COMMAND_LIST_TYPE_COMPUTE, - device.pCommandAllocator.Get(), - m_computeState.Get(), - IID_PPV_ARGS(&m_computeCommandList))); - - m_computeCommandList->SetComputeRootSignature(m_computeRootSignature.Get()); - -#ifdef _USE_DECRIPTOR_HEAP_ - ID3D12DescriptorHeap* pHeaps[] = { globalDescHeap.heap.Get() }; - m_computeCommandList->SetDescriptorHeaps(1, pHeaps); - m_computeCommandList->SetComputeRootDescriptorTable(0, globalDescHeap.heap->GetGPUDescriptorHandleForHeapStart()); -#else - for (int i = 0; i < inputs.size(); ++i) - m_computeCommandList->SetComputeRootShaderResourceView( - i, inputs[i].Data()->GetGPUVirtualAddress()); - for (int i = 0; i < outputs.size(); ++i) - m_computeCommandList->SetComputeRootUnorderedAccessView( - inputs.size() + i, outputs[i].Data()->GetGPUVirtualAddress()); -#endif - m_computeCommandList->Dispatch(threads[0], threads[1], threads[2]); - IFE(m_computeCommandList->Close()); - - cmdQueue.push_back(Launch()); - - if (!profCostDict.count(hlsl_source)) { - std::chrono::high_resolution_clock::time_point t1 = std::chrono::high_resolution_clock::now(); - constexpr int NUM_STEPS = 10; - for (int i = 0; i < NUM_STEPS; i++) - { - device.pCommandQueue->ExecuteCommandLists(1, cmdQueue.data() + cmdQueue.size() - 1); - device.AwaitExecution(); - } - std::chrono::high_resolution_clock::time_point t2 = std::chrono::high_resolution_clock::now(); - double sec = std::chrono::duration_cast>(t2 - t1).count() / - NUM_STEPS; - profCostDict[hlsl_source] = { sec, 1 }; - } - else - profCostDict[hlsl_source].second++; - } - - ID3D12GraphicsCommandList* Launch() { return m_computeCommandList.Get(); } - }; -} - diff --git a/src/tools/nnfusion/templates/dxcompute/make.bat b/src/tools/nnfusion/templates/dxcompute/make.bat deleted file mode 100644 index dab0fa5f6..000000000 --- a/src/tools/nnfusion/templates/dxcompute/make.bat +++ /dev/null @@ -1,16 +0,0 @@ -@echo off - -echo Downloading dependencies .. -curl -LOs https://github.com/microsoft/antares/raw/library/antares_hlsl_v0.1_x64.dll - -echo Compiling nnfusion_rt .. -C:\Windows\Microsoft.NET\Framework64\v4.0.30319\csc.exe nnfusion_rt.cs - -echo Compiling finished! -pause - -echo Executing program .. -nnfusion_rt - -echo Program finished! -pause diff --git a/src/tools/nnfusion/templates/dxcompute/run_graph.cpp b/src/tools/nnfusion/templates/dxcompute/run_graph.cpp deleted file mode 100644 index b5dc5e35f..000000000 --- a/src/tools/nnfusion/templates/dxcompute/run_graph.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#include "d3dx12_nnfusion.h" - -int main(int argc, char** argv) -{ - D3DDevice device(false, false); - device.Init(); - - using namespace nnfusion_dml; - -#include "nnfusion_rt.h" - - auto evaluateQueue = [&](const std::vector& cmdQueue, const char* qtype) { - std::chrono::high_resolution_clock::time_point t1 = - std::chrono::high_resolution_clock::now(); - constexpr int NUM_STEPS = 10; - for (int i = 0; i < NUM_STEPS; i++) - { - device.pCommandQueue->ExecuteCommandLists(cmdQueue.size(), cmdQueue.data()); - device.AwaitExecution(); - } - std::chrono::high_resolution_clock::time_point t2 = - std::chrono::high_resolution_clock::now(); - printf("DxCompute Time per Run for [%s] = %g sec.\n", - qtype, - std::chrono::duration_cast>(t2 - t1).count() / - NUM_STEPS); - }; - - if (profCostDict.size() > 0) - { - double evaluate_sum = 0.0; - std::multimap orderedProf; - for (auto& it : profCostDict) - { - double timecost = it.second.first * it.second.second; - evaluate_sum += timecost; - orderedProf.insert(std::make_pair(timecost, it.first)); - } - for (auto it = orderedProf.rbegin(); it != orderedProf.rend(); ++it) - { - auto ratio = std::to_wstring(it->first * 1e2 / evaluate_sum); - if (ratio.size() > 6) - ratio = ratio.substr(0, 6); - printf("%8ls%% %6d %4.8lf\t%ls\n", - ratio.c_str(), - profCostDict[it->second].second, - it->first, - it->second.c_str()); - } - printf("DxCompute Time per Run for [Profile Sum] = %g sec.\n", evaluate_sum); - } - - evaluateQueue(cmdQueue, "Standard Queue"); - printf("Total GPU Memory Allocated = %g MB\n", totalGPUMemoryAccess / double(1024 * 1024)); - system("pause"); - return 0; -}