Skip to content

Commit

Permalink
[SYCL][Graph] Make SYCL-Graph functions thread-safe (intel#10778)
Browse files Browse the repository at this point in the history
This PR makes the new APIs defined by
[sycl_ext_oneapi_graph](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_graph.asciidoc)
thread safe.


## Authors

Co-authored-by: Pablo Reble <[email protected]>
Co-authored-by: Julian Miller <[email protected]>
Co-authored-by: Ben Tracy <[email protected]>
Co-authored-by: Ewan Crawford <[email protected]>
Co-authored-by: Maxime France-Pillois
<[email protected]>
  • Loading branch information
EwanC authored Aug 23, 2023
1 parent fdee56c commit c8c64a6
Show file tree
Hide file tree
Showing 12 changed files with 697 additions and 8 deletions.
17 changes: 16 additions & 1 deletion sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ void exec_graph_impl::createCommandBuffers(sycl::device Device) {
}

exec_graph_impl::~exec_graph_impl() {
WriteLock LockImpl(MGraphImpl->MMutex);

// clear all recording queue if not done before (no call to end_recording)
MGraphImpl->clearQueues();

Expand All @@ -370,6 +372,8 @@ exec_graph_impl::~exec_graph_impl() {
sycl::event
exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
sycl::detail::CG::StorageInitHelper CGData) {
WriteLock Lock(MMutex);

auto CreateNewEvent([&]() {
auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
NewEvent->setContextImpl(Queue->getContextImplPtr());
Expand Down Expand Up @@ -483,6 +487,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}
Expand All @@ -494,6 +499,7 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
}

graph_impl::WriteLock Lock(impl->MMutex);
std::shared_ptr<detail::node_impl> NodeImpl =
impl->add(impl, CGF, {}, DepImpls);
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
Expand All @@ -505,13 +511,17 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) {
std::shared_ptr<detail::node_impl> ReceiverImpl =
sycl::detail::getSyclObjImpl(Dest);

graph_impl::WriteLock Lock(impl->MMutex);
SenderImpl->registerSuccessor(ReceiverImpl,
SenderImpl); // register successor
impl->removeRoot(ReceiverImpl); // remove receiver from root node list
}

/// Creates an executable command graph from this modifiable graph.
/// The property list argument is accepted but not used in this function.
/// @return An executable command_graph constructed from this graph's
/// implementation object and the context it reports.
command_graph<graph_state::executable>
modifiable_command_graph::finalize(const sycl::property_list &) const {
  // The graph is read and written while the executable graph is built, so
  // hold the exclusive (write) lock with full privileges for the whole scope.
  graph_impl::WriteLock GraphLock(impl->MMutex);
  return command_graph<graph_state::executable>{impl, impl->getContext()};
}
Expand Down Expand Up @@ -549,6 +559,7 @@ bool modifiable_command_graph::begin_recording(queue &RecordingQueue) {

if (QueueImpl->getCommandGraph() == nullptr) {
QueueImpl->setCommandGraph(impl);
graph_impl::WriteLock Lock(impl->MMutex);
impl->addQueue(QueueImpl);
return true;
}
Expand All @@ -570,12 +581,16 @@ bool modifiable_command_graph::begin_recording(
return QueueStateChanged;
}

/// Forwards to impl->clearQueues() and returns its result; no lock is taken
/// in this form of the function.
bool modifiable_command_graph::end_recording() { return impl->clearQueues(); }
/// Ends recording by clearing the queues registered on the graph.
/// @return The value returned by impl->clearQueues().
bool modifiable_command_graph::end_recording() {
  // Hold the exclusive graph lock while the queues are cleared.
  graph_impl::WriteLock GraphLock(impl->MMutex);
  const bool QueuesCleared = impl->clearQueues();
  return QueuesCleared;
}

bool modifiable_command_graph::end_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl && QueueImpl->getCommandGraph() == impl) {
QueueImpl->setCommandGraph(nullptr);
graph_impl::WriteLock Lock(impl->MMutex);
impl->removeQueue(QueueImpl);
return true;
}
Expand Down
171 changes: 167 additions & 4 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <functional>
#include <list>
#include <set>
#include <shared_mutex>

namespace sycl {
inline namespace _V1 {
Expand Down Expand Up @@ -167,6 +168,62 @@ class node_impl {
return nullptr;
}

/// Tests if this node is similar to Node.
/// Two nodes are similar when they have the same command-group type, the same
/// number of successors and the same number of predecessors. Kernel nodes must
/// additionally carry the same kernel name.
/// @param Node Node to compare this node against.
/// @return True if the two nodes are similar
bool isSimilar(std::shared_ptr<node_impl> Node) {
  const bool SameShape =
      MCGType == Node->MCGType &&
      MSuccessors.size() == Node->MSuccessors.size() &&
      MPredecessors.size() == Node->MPredecessors.size();
  if (!SameShape)
    return false;

  // Both nodes have the same type at this point, so checking this node's
  // type is enough to know whether the kernel names need comparing.
  if (MCGType != sycl::detail::CG::CGTYPE::Kernel)
    return true;

  auto *ThisKernel =
      static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
  auto *OtherKernel =
      static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
  return ThisKernel->MKernelName.compare(OtherKernel->MKernelName) == 0;
}

/// Recursive traversal of successor nodes checking for
/// equivalent node successions in Node.
/// This node is first compared with Node via isSimilar(); then every
/// successor of this node must find a successor of Node for which the same
/// check holds recursively.
/// @param Node pointer to the starting node for structure comparison
/// @return true if the same structure is found, false otherwise
bool checkNodeRecursive(std::shared_ptr<node_impl> Node) {
  // Compare the two nodes themselves before descending. Doing this up front
  // (instead of inside the successor loops) ensures nodes without successors
  // are compared too; otherwise differing leaf nodes would always "match".
  if (!isSimilar(Node))
    return false;

  for (const std::shared_ptr<node_impl> &SuccA : MSuccessors) {
    bool Matched = false;
    for (const std::shared_ptr<node_impl> &SuccB : Node->MSuccessors) {
      if (SuccA->checkNodeRecursive(SuccB)) {
        Matched = true;
        break;
      }
    }
    // One successor without an equivalent is enough to reject the structure.
    if (!Matched)
      return false;
  }

  return true;
}

/// Recursively counts this node plus the nodes reached by following its
/// successors.
/// NOTE(review): no visited set is kept, so a node reachable through several
/// paths appears to be counted once per path - confirm this is intended.
/// @return 1 for this node plus the sum of the counts of all its successors
size_t depthSearchCount() const {
  size_t Total = 1; // this node itself
  for (const auto &Child : MSuccessors)
    Total += Child->depthSearchCount();
  return Total;
}

private:
/// Creates a copy of the node's CG by casting to its actual type, then using
/// that to copy construct and create a new unique ptr from that copy.
Expand All @@ -180,17 +237,19 @@ class node_impl {
/// Implementation details of command_graph<modifiable>.
class graph_impl {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;

/// Protects all the fields that can be changed by class' methods.
mutable std::shared_mutex MMutex;

/// Constructor.
/// Copies the given context and device into MContext and MDevice and
/// value-initializes MRecordingQueues, MEventsMap and MInorderQueueMap.
/// @param SyclContext Context to use for graph.
/// @param SyclDevice Device to create nodes with.
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice)
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap(), MInorderQueueMap() {}

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);

/// Remove node from list of root nodes.
/// @param Root Node to remove from list of root nodes.
void removeRoot(const std::shared_ptr<node_impl> &Root);
Expand Down Expand Up @@ -264,6 +323,7 @@ class graph_impl {
/// @return Event associated with node.
std::shared_ptr<sycl::detail::event_impl>
getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
ReadLock Lock(MMutex);
if (auto EventImpl = std::find_if(
MEventsMap.begin(), MEventsMap.end(),
[NodeImpl](auto &it) { return it.second == NodeImpl; });
Expand Down Expand Up @@ -315,6 +375,95 @@ class graph_impl {
MInorderQueueMap[QueueWeakPtr] = Node;
}

/// Checks if the graph_impl of Graph has a similar structure to
/// the graph_impl of the caller.
/// Graphs are considered similar if they have the same number of nodes
/// of the same type with similar predecessor and successor nodes (number and
/// type). Two nodes are considered similar if they have the same
/// command-group type. For command-groups of type "kernel", the "signature"
/// of the kernel is also compared (i.e. the name of the command-group).
/// The checks run in this order: identity, context, device, MEventsMap size,
/// MInorderQueueMap size, MRoots size, then root-by-root comparison.
/// NOTE(review): a root of Graph is not marked as used once matched, so
/// several roots of this graph may match the same root of Graph - confirm
/// this is acceptable.
/// @param Graph the graph to compare with.
/// @param DebugPrint if set to true, throw an exception with additional debug
/// information about the spotted graph differences instead of returning false.
/// @return true if the two graphs are similar, false otherwise
bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
                         bool DebugPrint = false) const {
  // Single exit point for every detected difference: throws with Reason when
  // DebugPrint is set, otherwise evaluates to false.
  auto Mismatch = [DebugPrint](const char *Reason) -> bool {
    if (DebugPrint) {
      throw sycl::exception(sycl::make_error_code(errc::invalid), Reason);
    }
    return false;
  };

  // A graph is trivially similar to itself.
  if (this == Graph.get())
    return true;

  if (MContext != Graph->MContext)
    return Mismatch("MContext are not the same.");

  if (MDevice != Graph->MDevice)
    return Mismatch("MDevice are not the same.");

  if (MEventsMap.size() != Graph->MEventsMap.size())
    return Mismatch("MEventsMap sizes are not the same.");

  if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size())
    return Mismatch("MInorderQueueMap sizes are not the same.");

  if (MRoots.size() != Graph->MRoots.size())
    return Mismatch("MRoots sizes are not the same.");

  // Every root of this graph needs a root in Graph that is similar itself
  // and whose successor structure matches.
  size_t MatchedRoots = 0;
  for (const std::shared_ptr<node_impl> &RootA : MRoots) {
    for (const std::shared_ptr<node_impl> &RootB : Graph->MRoots) {
      if (RootA->isSimilar(RootB) && RootA->checkNodeRecursive(RootB)) {
        ++MatchedRoots;
        break;
      }
    }
  }

  if (MatchedRoots != MRoots.size())
    return Mismatch("Root Nodes do NOT match.");

  return true;
}

// Returns the number of nodes in the Graph
// @return Number of nodes in the Graph
size_t getNumberOfNodes() const {
size_t NumberOfNodes = 0;
for (const auto &Node : MRoots) {
NumberOfNodes += Node->depthSearchCount();
}
return NumberOfNodes;
}

private:
/// Context associated with this graph.
sycl::context MContext;
Expand All @@ -333,11 +482,21 @@ class graph_impl {
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MInorderQueueMap;

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);
};

/// Class representing the implementation of command_graph<executable>.
class exec_graph_impl {
public:
using ReadLock = std::shared_lock<std::shared_mutex>;
using WriteLock = std::unique_lock<std::shared_mutex>;

/// Protects all the fields that can be changed by class' methods.
mutable std::shared_mutex MMutex;

/// Constructor.
/// @param Context Context to create graph with.
/// @param GraphImpl Modifiable graph implementation to create with.
Expand Down Expand Up @@ -413,6 +572,10 @@ class exec_graph_impl {
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
/// Thread-safe implementation note: in the current implementation
/// multiple exec_graph_impl can reference the same graph_impl object.
/// This specificity must be taken into account when trying to lock
/// the graph_impl mutex from an exec_graph_impl to avoid deadlock.
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, sycl::detail::pi::PiExtCommandBuffer>
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ class queue_impl {

/// Stores Graph as the command graph associated with this queue.
/// @param Graph graph_impl pointer assigned to MGraph.
void setCommandGraph(
    std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
  // Update MGraph while holding the queue mutex.
  const std::lock_guard<std::mutex> Guard{MMutex};
  MGraph = Graph;
}

Expand Down
18 changes: 18 additions & 0 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,11 @@ event handler::finalize() {
std::shared_ptr<ext::oneapi::experimental::detail::node_impl> NodeImpl =
nullptr;

// GraphImpl is read and written in this scope so we lock this graph
// with full privileges.
ext::oneapi::experimental::detail::graph_impl::WriteLock Lock(
GraphImpl->MMutex);

// Create a new node in the graph representing this command-group
if (MQueue->isInOrder()) {
// In-order queues create implicit linear dependencies between nodes.
Expand Down Expand Up @@ -1332,15 +1337,28 @@ void handler::ext_oneapi_graph(
Graph) {
MCGType = detail::CG::ExecCommandBuffer;
auto GraphImpl = detail::getSyclObjImpl(Graph);
// GraphImpl is only read in this scope so we lock this graph for read only
ext::oneapi::experimental::detail::graph_impl::ReadLock Lock(
GraphImpl->MMutex);

std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> ParentGraph;
if (MQueue) {
ParentGraph = MQueue->getCommandGraph();
} else {
ParentGraph = MGraph;
}

ext::oneapi::experimental::detail::graph_impl::WriteLock ParentLock;
// If a parent graph is set that means we are adding or recording a subgraph
if (ParentGraph) {
// ParentGraph is read and written in this scope so we lock this graph
// with full privileges.
// We only lock for the Record & Replay API because the graph has already
// been locked if this function was called from the explicit API function add.
if (MQueue) {
ParentLock = ext::oneapi::experimental::detail::graph_impl::WriteLock(
ParentGraph->MMutex);
}
// Store the node representing the subgraph in the handler so that we can
// return it to the user later.
MSubgraphNode = ParentGraph->addSubgraphNodes(GraphImpl->getSchedule());
Expand Down
2 changes: 1 addition & 1 deletion sycl/test-e2e/Graph/Explicit/basic_usm.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// REQUIRES: level_zero, gpu
// RUN: %{build} -o %t.out
// RUN: %{build_pthread_inc} -o %t.out
// RUN: %{run} %t.out
// Extra run to check for leaks in Level Zero using ZE_DEBUG
// RUN: %if ext_oneapi_level_zero %{env ZE_DEBUG=4 %{run} %t.out 2>&1 | FileCheck %s %}
Expand Down
9 changes: 9 additions & 0 deletions sycl/test-e2e/Graph/Inputs/basic_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// and submission of the graph.

#include "../graph_common.hpp"
#include <thread>

int main() {
queue Queue;

using T = int;

const unsigned NumThreads = std::thread::hardware_concurrency();
std::vector<T> DataA(Size), DataB(Size), DataC(Size);

std::iota(DataA.begin(), DataA.end(), 1);
Expand All @@ -32,8 +34,15 @@ int main() {
// Add commands to graph
add_nodes(Graph, Queue, Size, PtrA, PtrB, PtrC);

Barrier SyncPoint{NumThreads};

auto GraphExec = Graph.finalize();

auto SubmitGraph = [&]() {
SyncPoint.wait();
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
};

event Event;
for (unsigned n = 0; n < Iterations; n++) {
Event = Queue.submit([&](handler &CGH) {
Expand Down
Loading

0 comments on commit c8c64a6

Please sign in to comment.