Skip to content

Commit

Permalink
[SYCL][Graph] Add exceptions on using spec constants and kernel bundles
Browse files Browse the repository at this point in the history
- Add exceptions when using spec constants and kernel bundles
- Unit tests for both of these.
- Refactored handler throwing code to remove templates.
- Moved and renamed unsupported features enum to graph header.
  • Loading branch information
Bensuo committed Aug 7, 2023
1 parent bc01f0f commit 077c65a
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 63 deletions.
32 changes: 32 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,38 @@ namespace oneapi {
namespace experimental {

namespace detail {
// List of sycl features and extensions which are not supported by graphs. Used
// for throwing errors when these features are used with graphs.
enum class UnsupportedGraphFeatures {
sycl_specialization_constants,
sycl_kernel_bundle,
sycl_ext_oneapi_kernel_properties,
sycl_ext_oneapi_enqueue_barrier,
sycl_ext_oneapi_memcpy2d,
sycl_ext_oneapi_device_global
};

constexpr const char *
UnsupportedFeatureToString(UnsupportedGraphFeatures Feature) {
using UGF = UnsupportedGraphFeatures;
switch (Feature) {
case UGF::sycl_specialization_constants:
return "Specialization Constants";
case UGF::sycl_kernel_bundle:
return "Kernel Bundles";
case UGF::sycl_ext_oneapi_kernel_properties:
return "sycl_ext_oneapi_kernel_properties";
case UGF::sycl_ext_oneapi_enqueue_barrier:
return "sycl_ext_oneapi_enqueue_barrier";
case UGF::sycl_ext_oneapi_memcpy2d:
return "sycl_ext_oneapi_memcpy2d";
case UGF::sycl_ext_oneapi_device_global:
return "sycl_ext_oneapi_device_global";
default:
return {};
}
}

class node_impl;
class graph_impl;
class exec_graph_impl;
Expand Down
51 changes: 27 additions & 24 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,6 @@ class pipe;
}

namespace ext::oneapi::experimental::detail {
// List of sycl experimental extensions
// This enum is used to define the extension from which a function is called.
// This is used in handler::throwIfGraphAssociated() to specify
// the message of the thrown expection.
enum SyclExtensions {
sycl_ext_oneapi_kernel_properties,
sycl_ext_oneapi_enqueue_barrier,
sycl_ext_oneapi_memcpy2d,
sycl_ext_oneapi_device_global
};

class graph_impl;
} // namespace ext::oneapi::experimental::detail
namespace detail {
Expand Down Expand Up @@ -1571,6 +1560,10 @@ class __SYCL_EXPORT handler {
void set_specialization_constant(
typename std::remove_reference_t<decltype(SpecName)>::value_type Value) {

throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_specialization_constants);

setStateSpecConstSet();

std::shared_ptr<detail::kernel_bundle_impl> KernelBundleImplPtr =
Expand All @@ -1585,6 +1578,10 @@ class __SYCL_EXPORT handler {
typename std::remove_reference_t<decltype(SpecName)>::value_type
get_specialization_constant() const {

throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_specialization_constants);

if (isStateExplicitKernelBundle())
throw sycl::exception(make_error_code(errc::invalid),
"Specialization constants cannot be read after "
Expand Down Expand Up @@ -2555,8 +2552,9 @@ class __SYCL_EXPORT handler {
/// until all commands previously submitted to this queue have entered the
/// complete state.
void ext_oneapi_barrier() {
throwIfGraphAssociated<ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_enqueue_barrier>();
throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_ext_oneapi_enqueue_barrier);
throwIfActionIsCreated();
setType(detail::CG::Barrier);
}
Expand Down Expand Up @@ -2642,8 +2640,9 @@ class __SYCL_EXPORT handler {
typename = std::enable_if_t<std::is_same_v<T, unsigned char>>>
void ext_oneapi_memcpy2d(void *Dest, size_t DestPitch, const void *Src,
size_t SrcPitch, size_t Width, size_t Height) {
throwIfGraphAssociated<ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_memcpy2d>();
throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_ext_oneapi_memcpy2d);
throwIfActionIsCreated();
if (Width > DestPitch)
throw sycl::exception(sycl::make_error_code(errc::invalid),
Expand Down Expand Up @@ -2822,8 +2821,9 @@ class __SYCL_EXPORT handler {
void memcpy(ext::oneapi::experimental::device_global<T, PropertyListT> &Dest,
const void *Src, size_t NumBytes = sizeof(T),
size_t DestOffset = 0) {
throwIfGraphAssociated<ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_device_global>();
throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_ext_oneapi_device_global);
if (sizeof(T) < DestOffset + NumBytes)
throw sycl::exception(make_error_code(errc::invalid),
"Copy to device_global is out of bounds.");
Expand Down Expand Up @@ -2856,8 +2856,9 @@ class __SYCL_EXPORT handler {
memcpy(void *Dest,
const ext::oneapi::experimental::device_global<T, PropertyListT> &Src,
size_t NumBytes = sizeof(T), size_t SrcOffset = 0) {
throwIfGraphAssociated<ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_device_global>();
throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_ext_oneapi_device_global);
if (sizeof(T) < SrcOffset + NumBytes)
throw sycl::exception(make_error_code(errc::invalid),
"Copy from device_global is out of bounds.");
Expand Down Expand Up @@ -3382,18 +3383,20 @@ class __SYCL_EXPORT handler {
template <typename PropertiesT>
std::enable_if_t<
ext::oneapi::experimental::is_property_list<PropertiesT>::value>
throwIfGraphAssociatedAndKernelProperties() {
throwIfGraphAssociatedAndKernelProperties() const {
if (!std::is_same_v<PropertiesT,
ext::oneapi::experimental::detail::empty_properties_t>)
throwIfGraphAssociated<ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_kernel_properties>();
throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_ext_oneapi_kernel_properties);
}

// Set value of the gpu cache configuration for the kernel.
void setKernelCacheConfig(sycl::detail::pi::PiKernelCacheConfig);

template <ext::oneapi::experimental::detail::SyclExtensions ExtensionT>
void throwIfGraphAssociated();
void throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures Feature)
const;
};
} // namespace _V1
} // namespace sycl
49 changes: 10 additions & 39 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,9 @@ void handler::verifyUsedKernelBundle(const std::string &KernelName) {
}

void handler::ext_oneapi_barrier(const std::vector<event> &WaitList) {
throwIfGraphAssociated<ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_enqueue_barrier>();
throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
sycl_ext_oneapi_enqueue_barrier);
throwIfActionIsCreated();
MCGType = detail::CG::BarrierWaitlist;
MEventsWaitWithBarrier.resize(WaitList.size());
Expand Down Expand Up @@ -1109,6 +1110,9 @@ void handler::ext_oneapi_signal_external_semaphore(
void handler::use_kernel_bundle(
const kernel_bundle<bundle_state::executable> &ExecBundle) {

throwIfGraphAssociated(ext::oneapi::experimental::detail::
UnsupportedGraphFeatures::sycl_kernel_bundle);

std::shared_ptr<detail::queue_impl> PrimaryQueue =
MImpl->MSubmissionPrimaryQueue;
if (PrimaryQueue->get_context() != ExecBundle.get_context())
Expand Down Expand Up @@ -1358,46 +1362,13 @@ handler::getCommandGraph() const {
return MQueue->getCommandGraph();
}

template void handler::throwIfGraphAssociated<
ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_kernel_properties>();
template void handler::throwIfGraphAssociated<
ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_enqueue_barrier>();
template void
handler::throwIfGraphAssociated<ext::oneapi::experimental::detail::
SyclExtensions::sycl_ext_oneapi_memcpy2d>();
template void handler::throwIfGraphAssociated<
ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_device_global>();

template <ext::oneapi::experimental::detail::SyclExtensions ExtensionT>
void handler::throwIfGraphAssociated() {
std::string ExceptionMsg = "";

if constexpr (ExtensionT ==
ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_kernel_properties) {
ExceptionMsg = "sycl_ext_oneapi_kernel_properties";
}
if constexpr (ExtensionT ==
ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_enqueue_barrier) {
ExceptionMsg = "sycl_ext_oneapi_enqueue_barrier";
}
if constexpr (ExtensionT == ext::oneapi::experimental::detail::
SyclExtensions::sycl_ext_oneapi_memcpy2d) {
ExceptionMsg = "sycl_ext_oneapi_memcpy2d";
}
if constexpr (ExtensionT ==
ext::oneapi::experimental::detail::SyclExtensions::
sycl_ext_oneapi_device_global) {
ExceptionMsg = "sycl_ext_oneapi_device_global";
}
void handler::throwIfGraphAssociated(
ext::oneapi::experimental::detail::UnsupportedGraphFeatures Feature) const {

if (MGraph || MQueue->getCommandGraph()) {
std::string FeatureString = UnsupportedFeatureToString(Feature);
throw sycl::exception(sycl::make_error_code(errc::invalid),
"The feature " + ExceptionMsg +
"The feature " + FeatureString +
" is not yet available "
"along with SYCL Graph extension.");
}
Expand Down
63 changes: 63 additions & 0 deletions sycl/unittests/Extensions/CommandGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
using namespace sycl;
using namespace sycl::ext::oneapi;

// Spec constant for testing.
constexpr specialization_id<int> SpecConst1{7};

namespace sycl {
inline namespace _V1 {
namespace detail {

// Necessary for get_specialization_constant() to work in unit tests.
template <> const char *get_spec_constant_symbolic_ID<SpecConst1>() {
return "SC1";
}
} // namespace detail
} // namespace _V1
} // namespace sycl

// anonymous namespace used to avoid code redundancy by defining functions
// used by multiple times by unitests.
// Defining anonymous namespace prevents from function naming conflits
Expand Down Expand Up @@ -1139,6 +1154,54 @@ TEST_F(CommandGraphTest, Memcpy2DExceptionCheck) {
sycl::free(USMMemDst, Queue);
}

// Tests that using specialization constants in a graph will throw.
TEST_F(CommandGraphTest, SpecializationConstant) {

ASSERT_THROW(
{
try {
Graph.add([&](handler &CGH) {
CGH.set_specialization_constant<SpecConst1>(8);
});
} catch (const sycl::exception &e) {
ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid));
throw;
}
},
sycl::exception);
ASSERT_THROW(
{
try {
Graph.add([&](handler &CGH) {
int Value = CGH.get_specialization_constant<SpecConst1>();
(void)Value;
});
} catch (const sycl::exception &e) {
ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid));
throw;
}
},
sycl::exception);
}

// Tests that using kernel bundles in a graph will throw.
TEST_F(CommandGraphTest, KernelBundle) {
sycl::kernel_bundle KernelBundle =
sycl::get_kernel_bundle<sycl::bundle_state::executable>(
Queue.get_context(), {Dev});

ASSERT_THROW(
{
try {
Graph.add([&](handler &CGH) { CGH.use_kernel_bundle(KernelBundle); });
} catch (const sycl::exception &e) {
ASSERT_EQ(e.code(), make_error_code(sycl::errc::invalid));
throw;
}
},
sycl::exception);
}

class MultiThreadGraphTest : public CommandGraphTest {
public:
MultiThreadGraphTest()
Expand Down

0 comments on commit 077c65a

Please sign in to comment.