Subzero: replace Win32 fibers with Marl for couroutines
* This change was authored by bclayton@, with some modifications.
* Replaces Win32 fiber implementation with marl tasks, making coroutines
work on all marl-supported platforms.
Bug: b/145754674
Change-Id: Ic3de82afc69549e1d56688c6faf8077a6f446ee0
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/41788
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
Tested-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 37789cc..0589d0d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1549,6 +1549,25 @@
target_link_libraries(llvm dl)
endif()
+
+###########################################################
+# marl
+###########################################################
+if(BUILD_MARL)
+ set(MARL_THIRD_PARTY_DIR ${THIRD_PARTY_DIR})
+ add_subdirectory(third_party/marl)
+endif()
+
+
+###########################################################
+# cppdap
+###########################################################
+if(SWIFTSHADER_BUILD_CPPDAP)
+ set(CPPDAP_THIRD_PARTY_DIR ${THIRD_PARTY_DIR})
+ add_subdirectory(${CPPDAP_DIR})
+endif()
+
+
###########################################################
# Subzero
###########################################################
@@ -1696,6 +1715,8 @@
)
target_link_libraries(ReactorSubzero SubzeroDependencies)
+ target_link_libraries(ReactorSubzero marl)
+
if(WIN32)
target_compile_definitions(ReactorSubzero PRIVATE SUBZERO_USE_MICROSOFT_ABI)
endif()
@@ -2157,16 +2178,6 @@
)
endif(SWIFTSHADER_BUILD_GLES_CM)
-if(BUILD_MARL)
- set(MARL_THIRD_PARTY_DIR ${THIRD_PARTY_DIR})
- add_subdirectory(third_party/marl)
-endif()
-
-if(SWIFTSHADER_BUILD_CPPDAP)
- set(CPPDAP_THIRD_PARTY_DIR ${THIRD_PARTY_DIR})
- add_subdirectory(${CPPDAP_DIR})
-endif()
-
if(SWIFTSHADER_BUILD_VULKAN)
add_library(vk_swiftshader SHARED ${VULKAN_LIST})
@@ -2322,9 +2333,9 @@
)
if(NOT WIN32 AND ${REACTOR_BACKEND} STREQUAL "Subzero")
- target_link_libraries(ReactorUnitTests ${Reactor} pthread dl)
+ target_link_libraries(ReactorUnitTests ${Reactor} marl pthread dl)
else()
- target_link_libraries(ReactorUnitTests ${Reactor})
+ target_link_libraries(ReactorUnitTests ${Reactor} marl)
endif()
set(GLES_UNITTESTS_LIST
@@ -2390,7 +2401,7 @@
)
add_executable(ReactorBenchmarks ${REACTOR_BENCHMARK_LIST})
- target_link_libraries(ReactorBenchmarks benchmark::benchmark ${Reactor})
+ target_link_libraries(ReactorBenchmarks benchmark::benchmark marl ${Reactor})
set_target_properties(ReactorBenchmarks PROPERTIES
COMPILE_OPTIONS "${SWIFTSHADER_COMPILE_OPTIONS};${WARNINGS_AS_ERRORS}"
diff --git a/src/Reactor/ReactorUnitTests.cpp b/src/Reactor/ReactorUnitTests.cpp
index bcfd13d..7e38d93 100644
--- a/src/Reactor/ReactorUnitTests.cpp
+++ b/src/Reactor/ReactorUnitTests.cpp
@@ -18,6 +18,9 @@
#include "gtest/gtest.h"
+#include "marl/defer.h"
+#include "marl/scheduler.h"
+
#include <array>
#include <cmath>
#include <thread>
@@ -155,6 +158,38 @@
return result;
}
+static const std::vector<int> fibonacci = {
+ 0,
+ 1,
+ 1,
+ 2,
+ 3,
+ 5,
+ 8,
+ 13,
+ 21,
+ 34,
+ 55,
+ 89,
+ 144,
+ 233,
+ 377,
+ 610,
+ 987,
+ 1597,
+ 2584,
+ 4181,
+ 6765,
+ 10946,
+ 17711,
+ 28657,
+ 46368,
+ 75025,
+ 121393,
+ 196418,
+ 317811,
+};
+
TEST(ReactorUnitTests, PrintPrimitiveTypes)
{
#if defined(ENABLE_RR_PRINT) && !defined(ENABLE_RR_EMIT_PRINT_LOCATION)
@@ -2070,6 +2105,30 @@
}
}
+TEST(ReactorUnitTests, Fibonacci)
+{
+ FunctionT<int(int)> function;
+ {
+ Int n = function.Arg<0>();
+ Int current = 0;
+ Int next = 1;
+ For(Int i = 0, i < n, i++)
+ {
+ auto tmp = current + next;
+ current = next;
+ next = tmp;
+ }
+ Return(current);
+ }
+
+ auto routine = function("one");
+
+ for(size_t i = 0; i < fibonacci.size(); i++)
+ {
+ EXPECT_EQ(routine(i), fibonacci[i]);
+ }
+}
+
TEST(ReactorUnitTests, Coroutines_Fibonacci)
{
if(!rr::Caps.CoroutinesSupported)
@@ -2078,6 +2137,11 @@
return;
}
+ marl::Scheduler scheduler;
+ scheduler.setWorkerThreadCount(8);
+ scheduler.bind();
+ defer(scheduler.unbind());
+
Coroutine<int()> function;
{
Yield(Int(0));
@@ -2095,45 +2159,11 @@
auto coroutine = function();
- int32_t expected[] = {
- 0,
- 1,
- 1,
- 2,
- 3,
- 5,
- 8,
- 13,
- 21,
- 34,
- 55,
- 89,
- 144,
- 233,
- 377,
- 610,
- 987,
- 1597,
- 2584,
- 4181,
- 6765,
- 10946,
- 17711,
- 28657,
- 46368,
- 75025,
- 121393,
- 196418,
- 317811,
- };
-
- auto count = sizeof(expected) / sizeof(expected[0]);
-
- for(size_t i = 0; i < count; i++)
+ for(size_t i = 0; i < fibonacci.size(); i++)
{
int out = 0;
EXPECT_EQ(coroutine->await(out), true);
- EXPECT_EQ(out, expected[i]);
+ EXPECT_EQ(out, fibonacci[i]);
}
}
@@ -2145,6 +2175,11 @@
return;
}
+ marl::Scheduler scheduler;
+ scheduler.setWorkerThreadCount(8);
+ scheduler.bind();
+ defer(scheduler.unbind());
+
Coroutine<uint8_t(uint8_t * data, int count)> function;
{
Pointer<Byte> data = function.Arg<0>();
@@ -2186,6 +2221,11 @@
return;
}
+ marl::Scheduler scheduler;
+ scheduler.setWorkerThreadCount(8);
+ scheduler.bind();
+ defer(scheduler.unbind());
+
Coroutine<int()> function;
{
Int4 a{ 1, 2, 3, 4 };
@@ -2220,6 +2260,11 @@
return;
}
+ marl::Scheduler scheduler;
+ scheduler.setWorkerThreadCount(8);
+ scheduler.bind();
+ defer(scheduler.unbind());
+
for(int i = 0; i < 2; ++i)
{
Coroutine<int()> function;
@@ -2244,6 +2289,11 @@
return;
}
+ marl::Scheduler scheduler;
+ scheduler.setWorkerThreadCount(8);
+ //scheduler.bind();
+ //defer(scheduler.unbind());
+
Coroutine<int()> function;
{
Yield(Int(0));
@@ -2262,53 +2312,22 @@
// Must call on same thread that creates the coroutine
function.finalize();
- constexpr int32_t expected[] = {
- 0,
- 1,
- 1,
- 2,
- 3,
- 5,
- 8,
- 13,
- 21,
- 34,
- 55,
- 89,
- 144,
- 233,
- 377,
- 610,
- 987,
- 1597,
- 2584,
- 4181,
- 6765,
- 10946,
- 17711,
- 28657,
- 46368,
- 75025,
- 121393,
- 196418,
- 317811,
- };
-
- constexpr auto count = sizeof(expected) / sizeof(expected[0]);
-
std::vector<std::thread> threads;
const size_t numThreads = 100;
for(size_t t = 0; t < numThreads; ++t)
{
threads.emplace_back([&] {
+ scheduler.bind();
+ defer(scheduler.unbind());
+
auto coroutine = function();
- for(size_t i = 0; i < count; i++)
+ for(size_t i = 0; i < fibonacci.size(); i++)
{
int out = 0;
EXPECT_EQ(coroutine->await(out), true);
- EXPECT_EQ(out, expected[i]);
+ EXPECT_EQ(out, fibonacci[i]);
}
});
}
diff --git a/src/Reactor/SubzeroReactor.cpp b/src/Reactor/SubzeroReactor.cpp
index 6602ddd..df12cb1 100644
--- a/src/Reactor/SubzeroReactor.cpp
+++ b/src/Reactor/SubzeroReactor.cpp
@@ -33,6 +33,8 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_os_ostream.h"
+#include "marl/event.h"
+
#if __has_feature(memory_sanitizer)
# include <sanitizer/msan_interface.h>
#endif
@@ -380,11 +382,7 @@
}
const Capabilities Caps = {
-#if defined(_WIN32)
true, // CoroutinesSupported
-#else
- false, // CoroutinesSupported
-#endif
};
enum EmulatedType
@@ -4464,19 +4462,15 @@
namespace {
namespace coro {
-using FiberHandle = void *;
-
// Instance data per generated coroutine
// This is the "handle" type used for Coroutine functions
// Lifetime: from yield to when CoroutineEntryDestroy generated function is called.
struct CoroutineData
{
- FiberHandle mainFiber{};
- FiberHandle routineFiber{};
- bool convertedFiber = false;
-
- // Variables used by coroutines
- bool done = false;
+ marl::Event suspended; // the coroutine is suspended on a yield()
+ marl::Event resumed; // the caller is suspended on an await()
+ marl::Event done{ marl::Event::Mode::Manual }; // the coroutine should stop at the next yield()
+ marl::Event terminated{ marl::Event::Mode::Manual }; // the coroutine has finished.
void *promisePtr = nullptr;
};
@@ -4490,108 +4484,36 @@
delete coroData;
}
-void convertThreadToMainFiber(Nucleus::CoroutineHandle handle)
+// suspend() pauses execution of the coroutine, and resumes execution from the
+// caller's call to await().
+// Returns true if await() is called again, or false if coroutine_destroy()
+// is called.
+bool suspend(Nucleus::CoroutineHandle handle)
{
-#if defined(_WIN32)
- auto *coroData = reinterpret_cast<CoroutineData *>(handle);
-
- coroData->mainFiber = ::ConvertThreadToFiber(nullptr);
-
- if(coroData->mainFiber)
- {
- coroData->convertedFiber = true;
- }
- else
- {
- // We're probably already on a fiber, so just grab it and remember that we didn't
- // convert it, so not to convert back to thread.
- coroData->mainFiber = GetCurrentFiber();
- coroData->convertedFiber = false;
- }
- ASSERT(coroData->mainFiber);
-#else
- UNIMPLEMENTED_NO_BUG("convertThreadToMainFiber not implemented for current platform");
-#endif
+ auto *data = reinterpret_cast<CoroutineData *>(handle);
+ data->suspended.signal();
+ data->resumed.wait();
+ return !data->done.test();
}
-void convertMainFiberToThread(Nucleus::CoroutineHandle handle)
+// resume() is called by await(), blocking until the coroutine calls yield()
+// or the coroutine terminates.
+void resume(Nucleus::CoroutineHandle handle)
{
-#if defined(_WIN32)
- auto *coroData = reinterpret_cast<CoroutineData *>(handle);
-
- ASSERT(coroData->mainFiber);
-
- if(coroData->convertedFiber)
- {
- ::ConvertFiberToThread();
- coroData->mainFiber = nullptr;
- }
-#else
- UNIMPLEMENTED_NO_BUG("convertMainFiberToThread not implemented for current platform");
-#endif
-}
-using FiberFunc = std::function<void()>;
-
-void createRoutineFiber(Nucleus::CoroutineHandle handle, FiberFunc *fiberFunc)
-{
-#if defined(_WIN32)
- struct Invoker
- {
- FiberFunc func;
-
- static VOID __stdcall fiberEntry(LPVOID lpParameter)
- {
- auto *func = reinterpret_cast<FiberFunc *>(lpParameter);
- (*func)();
- }
- };
-
- auto *coroData = reinterpret_cast<CoroutineData *>(handle);
-
- constexpr SIZE_T StackSize = 2 * 1024 * 1024;
- coroData->routineFiber = ::CreateFiber(StackSize, &Invoker::fiberEntry, fiberFunc);
- ASSERT(coroData->routineFiber);
-#else
- UNIMPLEMENTED_NO_BUG("createRoutineFiber not implemented for current platform");
-#endif
+ auto *data = reinterpret_cast<CoroutineData *>(handle);
+ data->resumed.signal();
+ data->suspended.wait();
}
-void deleteRoutineFiber(Nucleus::CoroutineHandle handle)
+// stop() is called by coroutine_destroy(), signalling that it's done, then blocks
+// until the coroutine ends, and deletes the coroutine data.
+void stop(Nucleus::CoroutineHandle handle)
{
-#if defined(_WIN32)
auto *coroData = reinterpret_cast<CoroutineData *>(handle);
- ASSERT(coroData->routineFiber);
- ::DeleteFiber(coroData->routineFiber);
- coroData->routineFiber = nullptr;
-#else
- UNIMPLEMENTED_NO_BUG("deleteRoutineFiber not implemented for current platform");
-#endif
-}
-
-void switchToMainFiber(Nucleus::CoroutineHandle handle)
-{
-#if defined(_WIN32)
- auto *coroData = reinterpret_cast<CoroutineData *>(handle);
-
- // Win32
- ASSERT(coroData->mainFiber);
- ::SwitchToFiber(coroData->mainFiber);
-#else
- UNIMPLEMENTED_NO_BUG("switchToMainFiber not implemented for current platform");
-#endif
-}
-
-void switchToRoutineFiber(Nucleus::CoroutineHandle handle)
-{
-#if defined(_WIN32)
- auto *coroData = reinterpret_cast<CoroutineData *>(handle);
-
- // Win32
- ASSERT(coroData->routineFiber);
- ::SwitchToFiber(coroData->routineFiber);
-#else
- UNIMPLEMENTED_NO_BUG("switchToRoutineFiber not implemented for current platform");
-#endif
+ coroData->done.signal(); // signal that the coroutine should stop at next (or current) yield.
+ coroData->resumed.signal(); // wake the coroutine if blocked on a yield.
+ coroData->terminated.wait(); // wait for the coroutine to return.
+ coro::destroyCoroutineData(coroData); // free the coroutine data.
}
namespace detail {
@@ -4612,17 +4534,10 @@
return handle;
}
-void setDone(Nucleus::CoroutineHandle handle)
-{
- auto *coroData = reinterpret_cast<CoroutineData *>(handle);
- ASSERT(!coroData->done); // Should be called once
- coroData->done = true;
-}
-
bool isDone(Nucleus::CoroutineHandle handle)
{
auto *coroData = reinterpret_cast<CoroutineData *>(handle);
- return coroData->done;
+ return coroData->done.test();
}
void setPromisePtr(Nucleus::CoroutineHandle handle, void *promisePtr)
@@ -4697,30 +4612,27 @@
// ... <REACTOR CODE> ...
//
// promise = val;
- // coro::switchToMainFiber(handle);
+ // if (!coro::suspend(handle)) {
+ // return false; // coroutine has been stopped by the caller.
+ // }
//
// ... <REACTOR CODE> ...
+ // promise = val;
Nucleus::createStore(val, V(this->promise), ::coroYieldType);
- sz::Call(::function, ::basicBlock, coro::switchToMainFiber, this->handle);
- }
- // Adds instructions at the end of the current main coroutine function to end the coroutine.
- void generateCoroutineEnd()
- {
+ // if (!coro::suspend(handle)) {
+ auto result = sz::Call(::function, ::basicBlock, coro::suspend, this->handle);
+ auto doneBlock = Nucleus::createBasicBlock();
+ auto resumeBlock = Nucleus::createBasicBlock();
+ Nucleus::createCondBr(V(result), resumeBlock, doneBlock);
+
+ // return false; // coroutine has been stopped by the caller.
+ ::basicBlock = doneBlock;
+ Nucleus::createRetVoid(); // coroutine return value is ignored.
+
// ... <REACTOR CODE> ...
- //
- // coro::setDone(handle);
- // coro::switchToMainFiber();
- // // Unreachable
- // }
- //
-
- sz::Call(::function, ::basicBlock, coro::setDone, this->handle);
-
- // A Win32 Fiber function must not end, otherwise it tears down the thread it's running on.
- // So we add code to switch back to the main thread.
- sz::Call(::function, ::basicBlock, coro::switchToMainFiber, this->handle);
+ ::basicBlock = resumeBlock;
}
using FunctionUniquePtr = std::unique_ptr<Ice::Cfg>;
@@ -4739,7 +4651,7 @@
// {
// YieldType* promise = coro::getPromisePtr(handle);
// *out = *promise;
- // coro::switchToRoutineFiber(handle);
+ // coro::resume(handle);
// return true;
// }
// }
@@ -4776,8 +4688,8 @@
auto store = Ice::InstStore::create(awaitFunc, promiseVal, outPtr);
resumeBlock->appendInst(store);
- // coro::switchToRoutineFiber(handle);
- sz::Call(awaitFunc, resumeBlock, coro::switchToRoutineFiber, handle);
+ // coro::resume(handle);
+ sz::Call(awaitFunc, resumeBlock, coro::resume, handle);
// return true;
Ice::InstRet *ret = Ice::InstRet::create(awaitFunc, ::context->getConstantInt32(1));
@@ -4806,9 +4718,7 @@
{
// void coroutine_destroy(Nucleus::CoroutineHandle handle)
// {
- // coro::convertMainFiberToThread(coroData);
- // coro::deleteRoutineFiber(handle);
- // coro::destroyCoroutineData(handle);
+ // coro::stop(handle); // signal and wait for coroutine to stop, and delete coroutine data
// return;
// }
@@ -4822,14 +4732,8 @@
auto *bb = destroyFunc->getEntryNode();
- // coro::convertMainFiberToThread(coroData);
- sz::Call(destroyFunc, bb, coro::convertMainFiberToThread, handle);
-
- // coro::deleteRoutineFiber(handle);
- sz::Call(destroyFunc, bb, coro::deleteRoutineFiber, handle);
-
- // coro::destroyCoroutineData(handle);
- sz::Call(destroyFunc, bb, coro::destroyCoroutineData, handle);
+ // coro::stop(handle); // signal and wait for coroutine to stop, and delete coroutine data
+ sz::Call(destroyFunc, bb, coro::stop, handle);
// return;
Ice::InstRet *ret = Ice::InstRet::create(destroyFunc);
@@ -4848,26 +4752,19 @@
// This doubles up as our coroutine handle
auto coroData = coro::createCoroutineData();
- // Convert current thread to a fiber so we can create new fibers and switch to them
- coro::convertThreadToMainFiber(coroData);
-
- coro::FiberFunc fiberFunc = [&]() {
+ marl::schedule([=] {
// Store handle in TLS so that the coroutine can grab it right away, before
// any fiber switch occurs.
coro::setHandleParam(coroData);
- // Invoke the begin function in the context of the routine fiber
beginFunc();
- // Either it yielded, or finished. In either case, we switch back to the main fiber.
- // We don't ever return from this function, or the current thread will be destroyed.
- coro::switchToMainFiber(coroData);
- };
+ coroData->done.signal(); // coroutine is done.
+ coroData->suspended.signal(); // resume any blocking await() call.
+ coroData->terminated.signal(); // signal that the coroutine data is ready for freeing.
+ });
- coro::createRoutineFiber(coroData, &fiberFunc);
-
- // Fiber will now start running, executing the saved beginFunc
- coro::switchToRoutineFiber(coroData);
+ coroData->suspended.wait(); // block until the first yield or coroutine end
return coroData;
}
@@ -4914,7 +4811,6 @@
// Finish generating coroutine functions
{
Ice::CfgLocalAllocatorScope scopedAlloc{ ::function };
- ::coroGen->generateCoroutineEnd();
createRetVoidIfNoRet();
}