Skip to content

Commit

Permalink
Move scheduler / runtime testspy definitions to separate files
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Nov 26, 2024
1 parent 80a3f80 commit f385675
Show file tree
Hide file tree
Showing 16 changed files with 485 additions and 379 deletions.
9 changes: 6 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ set(CELERITY_FEATURE_LOCAL_ACCESSOR ON)
set(CELERITY_FEATURE_UNNAMED_KERNELS ON)

# Add header files to library so they show up in IDEs
file(GLOB_RECURSE ALL_HEADERS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h")
file(GLOB_RECURSE ALL_INCLUDES
"${CMAKE_CURRENT_SOURCE_DIR}/include/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/src/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/src/*.inl")

if(CMAKE_GENERATOR STREQUAL "Ninja")
# Force colored warnings in Ninja's output, if the compiler has -fdiagnostics-color support.
Expand Down Expand Up @@ -304,14 +307,14 @@ elseif(CELERITY_SYCL_IMPL STREQUAL "SimSYCL")
endif()

configure_file(include/version.h.in include/version.h @ONLY)
list(APPEND ALL_HEADERS "${CMAKE_CURRENT_BINARY_DIR}/include/version.h")
list(APPEND ALL_INCLUDES "${CMAKE_CURRENT_BINARY_DIR}/include/version.h")
list(APPEND PUBLIC_HEADERS "${CMAKE_CURRENT_BINARY_DIR}/include/version.h")

add_library(
celerity_runtime
STATIC
${SOURCES}
${ALL_HEADERS}
${ALL_INCLUDES}
)

set_property(TARGET celerity_runtime PROPERTY CXX_STANDARD "${CELERITY_CXX_STANDARD}")
Expand Down
100 changes: 29 additions & 71 deletions include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace detail {
*/
static void init(int* argc, char** argv[], const devices_or_selector& user_devices_or_selector = auto_select_devices{});

static bool has_instance() { return s_instance != nullptr; }
static bool has_instance() { return s_instance.m_impl != nullptr; }

static void shutdown();

Expand All @@ -34,100 +34,58 @@ namespace detail {
runtime(runtime&&) = delete;
runtime& operator=(const runtime&) = delete;
runtime& operator=(runtime&&) = delete;
~runtime() = default;

virtual ~runtime() = default;
task_id submit(raw_command_group&& cg);

virtual task_id submit(raw_command_group&& cg) = 0;
task_id fence(buffer_access access, std::unique_ptr<task_promise> fence_promise);

virtual task_id fence(buffer_access access, std::unique_ptr<task_promise> fence_promise) = 0;
task_id fence(host_object_effect effect, std::unique_ptr<task_promise> fence_promise);

virtual task_id fence(host_object_effect effect, std::unique_ptr<task_promise> fence_promise) = 0;
task_id sync(detail::epoch_action action);

virtual task_id sync(detail::epoch_action action) = 0;
void create_queue();

virtual void create_queue() = 0;
void destroy_queue();

virtual void destroy_queue() = 0;
allocation_id create_user_allocation(void* ptr);

virtual allocation_id create_user_allocation(void* ptr) = 0;
buffer_id create_buffer(const range<3>& range, size_t elem_size, size_t elem_align, allocation_id user_aid);

virtual buffer_id create_buffer(const range<3>& range, size_t elem_size, size_t elem_align, allocation_id user_aid) = 0;
void set_buffer_debug_name(buffer_id bid, const std::string& debug_name);

virtual void set_buffer_debug_name(buffer_id bid, const std::string& debug_name) = 0;
void destroy_buffer(buffer_id bid);

virtual void destroy_buffer(buffer_id bid) = 0;
host_object_id create_host_object(std::unique_ptr<host_object_instance> instance /* optional */);

virtual host_object_id create_host_object(std::unique_ptr<host_object_instance> instance /* optional */) = 0;
void destroy_host_object(host_object_id hoid);

virtual void destroy_host_object(host_object_id hoid) = 0;
reduction_id create_reduction(std::unique_ptr<reducer> reducer);

virtual reduction_id create_reduction(std::unique_ptr<reducer> reducer) = 0;
bool is_dry_run() const;

virtual bool is_dry_run() const = 0;
void set_scheduler_lookahead(experimental::lookahead lookahead);

virtual void set_scheduler_lookahead(experimental::lookahead lookahead) = 0;
void flush_scheduler();

virtual void flush_scheduler() = 0;
private:
class impl;

protected:
inline static bool s_mpi_initialized = false;
inline static bool s_mpi_finalized = false;
static bool s_mpi_initialized;
static bool s_mpi_finalized;

static bool s_test_mode;
static bool s_test_active;
static bool s_test_runtime_was_instantiated;

static void mpi_initialize_once(int* argc, char*** argv);
static void mpi_finalize_once();

static std::unique_ptr<runtime> s_instance;

runtime() = default;
static runtime s_instance;

// ------------------------------------------ TESTING UTILS ------------------------------------------
// We have to jump through some hoops to be able to re-initialize the runtime for unit testing.
// MPI does not like being initialized more than once per process, so we have to skip that part for
// re-initialization.
// ---------------------------------------------------------------------------------------------------
std::unique_ptr<impl> m_impl;

public:
// Switches to test mode, where MPI will be initialized through test_case_enter() instead of runtime::runtime(). Called on Catch2 startup.
static void test_mode_enter() {
assert(!s_mpi_initialized);
s_test_mode = true;
}

// Finalizes MPI if it was ever initialized in test mode. Called on Catch2 shutdown.
static void test_mode_exit() {
assert(s_test_mode && !s_test_active && !s_mpi_finalized);
if(s_mpi_initialized) mpi_finalize_once();
}

// Initializes MPI for tests, if it was not initialized before
static void test_require_mpi() {
assert(s_test_mode && !s_test_active);
if(!s_mpi_initialized) mpi_initialize_once(nullptr, nullptr);
}

// Allows the runtime to be transitively instantiated in tests. Called from runtime_fixture.
static void test_case_enter() {
assert(s_test_mode && !s_test_active && s_mpi_initialized && s_instance == nullptr);
s_test_active = true;
s_test_runtime_was_instantiated = false;
}

static bool test_runtime_was_instantiated() {
assert(s_test_mode);
return s_test_runtime_was_instantiated;
}

// Deletes the runtime instance, which happens only in tests. Called from runtime_fixture.
static void test_case_exit() {
assert(s_test_mode && s_test_active);
s_instance.reset(); // for when the test case explicitly initialized the runtime but did not successfully construct a queue / buffer / ...
s_test_active = false;
}

protected:
inline static bool s_test_mode = false;
inline static bool s_test_active = false;
inline static bool s_test_runtime_was_instantiated = false;
runtime() = default;
};

/// Returns the combined command graph of all nodes on node 0, an empty string on other nodes
Expand Down
128 changes: 0 additions & 128 deletions include/runtime_impl.h

This file was deleted.

18 changes: 1 addition & 17 deletions include/scheduler.h
Original file line number Diff line number Diff line change
@@ -1,27 +1,17 @@
#pragma once

#include "command_graph.h"
#include "command_graph_generator.h"
#include "instruction_graph_generator.h"
#include "ranges.h"
#include "types.h"

#include <cstddef>
#include <functional>
#include <memory>
#include <string>


namespace celerity::detail::scheduler_detail {

/// executed inside scheduler thread, making it safe to access scheduler members
struct test_state {
const command_graph* cdag = nullptr;
const instruction_graph* idag = nullptr;
experimental::lookahead lookahead = experimental::lookahead::automatic;
};
using test_inspector = std::function<void(const test_state&)>;

struct scheduler_impl;

} // namespace celerity::detail::scheduler_detail
Expand Down Expand Up @@ -77,15 +67,9 @@ class scheduler {
void flush_commands();

private:
struct test_threadless_tag {};
scheduler() = default; // used by scheduler_testspy

std::unique_ptr<scheduler_detail::scheduler_impl> m_impl;

// used in scheduler_testspy
scheduler(test_threadless_tag, size_t num_nodes, node_id local_node_id, const system_info& system_info, scheduler::delegate* delegate,
command_recorder* crec, instruction_recorder* irec, const policy_set& policy = {});
void test_scheduling_loop();
void test_inspect(scheduler_detail::test_inspector inspector);
};

} // namespace celerity::detail
Loading

0 comments on commit f385675

Please sign in to comment.