Subzero: implement coroutines for Win32
Coroutines are emulated using Win32 fibers: the generated coroutine body runs
on its own fiber, and each yield switches execution back to the caller's
(main) fiber.
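
For context, this is roughly how a Reactor coroutine exercises the three
generated entry points (begin, await, destroy); a minimal sketch following
the Coroutine<T> wrapper API in Reactor:

    Coroutine<int()> function;
    {
        Yield(Int(1));
        Yield(Int(2));
    }

    // operator() runs the generated begin function on the routine fiber,
    // up to the first Yield().
    auto coroutine = function();

    int out = 0;
    while(coroutine->await(out))  // await: copy promise to 'out', resume fiber
    {
        // out is 1, then 2
    }
    // Destruction runs the generated destroy function, deleting the fiber.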
Bug: b/145754674
Change-Id: I3f4bf29d26a75a2386ed812dd821d8a7a8276305
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/40548
Tested-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Reactor/SubzeroReactor.cpp b/src/Reactor/SubzeroReactor.cpp
index 5d057b0..15b0eec 100644
--- a/src/Reactor/SubzeroReactor.cpp
+++ b/src/Reactor/SubzeroReactor.cpp
@@ -51,8 +51,85 @@
// Subzero utility functions
// These functions only accept and return Subzero (Ice) types, and do not access any globals.
+namespace {
namespace sz {
-static Ice::Constant *getConstantPointer(Ice::GlobalContext *context, void const *ptr)
+void replaceEntryNode(Ice::Cfg *function, Ice::CfgNode *newEntryNode)
+{
+ ASSERT_MSG(function->getEntryNode() != nullptr, "Function should have an entry node");
+
+ if(function->getEntryNode() == newEntryNode)
+ {
+ return;
+ }
+
+ // Make this the new entry node
+ function->setEntryNode(newEntryNode);
+
+ // Reorder nodes so that new entry block comes first. This is required
+ // by Cfg::renumberInstructions, which expects the first node in the list
+ // to be the entry node.
+ {
+ auto nodes = function->getNodes();
+
+ // TODO(amaiorano): Fast path if newEntryNode is last? Can avoid linear search.
+
+ auto iter = std::find(nodes.begin(), nodes.end(), newEntryNode);
+ ASSERT_MSG(iter != nodes.end(), "New node should be in the function's node list");
+
+ nodes.erase(iter);
+ nodes.insert(nodes.begin(), newEntryNode);
+
+		// swapNodes replaces the function's node list with the given one and renumbers
+		// the nodes, so our new entry node becomes node 0 and the previous entry node 1.
+ function->swapNodes(nodes);
+ }
+}
+
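+// Creates a function with the given signature, containing a single empty entry node.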
+Ice::Cfg *createFunction(Ice::GlobalContext *context, Ice::Type returnType, const std::vector<Ice::Type> &paramTypes)
+{
+ uint32_t sequenceNumber = 0;
+ auto function = Ice::Cfg::create(context, sequenceNumber).release();
+
+ Ice::CfgLocalAllocatorScope allocScope{ function };
+
+ for(auto type : paramTypes)
+ {
+ Ice::Variable *arg = function->makeVariable(type);
+ function->addArg(arg);
+ }
+
+ Ice::CfgNode *node = function->makeNode();
+ function->setEntryNode(node);
+
+ return function;
+}
+
+Ice::Type getPointerType(Ice::Type elementType)
+{
+ if(sizeof(void *) == 8)
+ {
+ return Ice::IceType_i64;
+ }
+ else
+ {
+ return Ice::IceType_i32;
+ }
+}
+
+Ice::Variable *allocateStackVariable(Ice::Cfg *function, Ice::Type type, int arraySize = 0)
+{
+ int typeSize = Ice::typeWidthInBytes(type);
+ int totalSize = typeSize * (arraySize ? arraySize : 1);
+
+ auto bytes = Ice::ConstantInteger32::create(function->getContext(), Ice::IceType_i32, totalSize);
+ auto address = function->makeVariable(getPointerType(type));
+ auto alloca = Ice::InstAlloca::create(function, address, bytes, typeSize);
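+	// Prepend the alloca to the entry block so the stack space is reserved before
+	// any other generated instructions execute.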
+ function->getEntryNode()->getInsts().push_front(alloca);
+
+ return address;
+}
+
+Ice::Constant *getConstantPointer(Ice::GlobalContext *context, void const *ptr)
{
if(sizeof(void *) == 8)
{
@@ -64,8 +141,38 @@
}
}
+// Wrapper for emitting calls to C functions, using Ice types.
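+// The C function signature determines the return type; arguments are passed as
+// already-built Ice variables, e.g. sz::Call(function, bb, coro::setDone, handle).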
+template<typename Return, typename... CArgs, typename... RArgs>
+Ice::Variable *Call(Ice::Cfg *function, Ice::CfgNode *basicBlock, Return(fptr)(CArgs...), RArgs &&... args)
+{
+ Ice::Type retTy = T(rr::CToReactorT<Return>::getType());
+
+ // Subzero doesn't support boolean return values. Replace with an i32.
+ if(retTy == Ice::IceType_i1)
+ {
+ retTy = Ice::IceType_i32;
+ }
+
+ Ice::Variable *ret = nullptr;
+ if(retTy != Ice::IceType_void)
+ {
+ ret = function->makeVariable(retTy);
+ }
+
+ std::initializer_list<Ice::Variable *> iceArgs = { std::forward<RArgs>(args)... };
+
+ auto call = Ice::InstCall::create(function, iceArgs.size(), ret, getConstantPointer(function->getContext(), reinterpret_cast<void const *>(fptr)), false);
+ for(auto arg : iceArgs)
+ {
+ call->addArg(arg);
+ }
+
+ basicBlock->appendInst(call);
+ return ret;
+}
+
// Returns a non-const variable copy of const v
-static Ice::Variable *createUnconstCast(Ice::Cfg *function, Ice::CfgNode *basicBlock, Ice::Constant *v)
+Ice::Variable *createUnconstCast(Ice::Cfg *function, Ice::CfgNode *basicBlock, Ice::Constant *v)
{
Ice::Variable *result = function->makeVariable(v->getType());
Ice::InstCast *cast = Ice::InstCast::create(function, Ice::InstCast::Bitcast, result, v);
@@ -73,7 +180,7 @@
return result;
}
-static Ice::Variable *createLoad(Ice::Cfg *function, Ice::CfgNode *basicBlock, Ice::Operand *ptr, Ice::Type type, unsigned int align)
+Ice::Variable *createLoad(Ice::Cfg *function, Ice::CfgNode *basicBlock, Ice::Operand *ptr, Ice::Type type, unsigned int align)
{
// TODO(b/148272103): InstLoad assumes that a constant ptr is an offset, rather than an
// absolute address. We circumvent this by casting to a non-const variable, and loading
@@ -91,9 +198,12 @@
}
} // namespace sz
+} // namespace
+
namespace rr {
class ELFMemoryStreamer;
-}
+class CoroutineGenerator;
+} // namespace rr
namespace {
@@ -119,6 +229,10 @@
Ice::ELFFileStreamer *elfFile = nullptr;
Ice::Fdstream *out = nullptr;
+// Coroutine globals, valid between createCoroutine() and acquireCoroutine()
+rr::Type *coroYieldType = nullptr;
+std::shared_ptr<rr::CoroutineGenerator> coroGen;
+
} // Anonymous namespace
namespace {
@@ -232,7 +346,11 @@
}
const Capabilities Caps = {
+#if defined(_WIN32)
+ true, // CoroutinesSupported
+#else
false, // CoroutinesSupported
+#endif
};
enum EmulatedType
@@ -274,11 +392,27 @@
return reinterpret_cast<Type *>(t);
}
+std::vector<Ice::Type> T(const std::vector<Type *> &types)
+{
+ std::vector<Ice::Type> result;
+ result.reserve(types.size());
+ for(auto &t : types)
+ {
+ result.push_back(T(t));
+ }
+ return result;
+}
+
Value *V(Ice::Operand *v)
{
return reinterpret_cast<Value *>(v);
}
+Ice::Operand *V(Value *v)
+{
+ return reinterpret_cast<Ice::Variable *>(v);
+}
+
BasicBlock *B(Ice::CfgNode *b)
{
return reinterpret_cast<BasicBlock *>(b);
@@ -303,6 +437,14 @@
return Ice::typeWidthInBytes(T(type));
}
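+// Appends a void return to the current basic block, unless it already ends in a ret.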
+static void createRetVoidIfNoRet()
+{
+ if(::basicBlock->getInsts().empty() || ::basicBlock->getInsts().back().getKind() != Ice::Inst::Ret)
+ {
+ Nucleus::createRetVoid();
+ }
+}
+
using ElfHeader = std::conditional<sizeof(void *) == 8, Elf64_Ehdr, Elf32_Ehdr>::type;
using SectionHeader = std::conditional<sizeof(void *) == 8, Elf64_Shdr, Elf32_Shdr>::type;
@@ -462,7 +604,7 @@
return symbolValue;
}
-void *loadImage(uint8_t *const elfImage, size_t &codeSize)
+void *loadImage(uint8_t *const elfImage, size_t &codeSize, const char *functionName = nullptr)
{
ElfHeader *elfHeader = (ElfHeader *)elfImage;
@@ -496,6 +638,15 @@
{
if(sectionHeader[i].sh_flags & SHF_EXECINSTR)
{
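+				// Each function is emitted into its own section (see setFunctionSections),
+				// which lets us locate a specific function's code by section name.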
+ auto getCurrSectionName = [&]() {
+ auto sectionNameOffset = sectionHeader[elfHeader->e_shstrndx].sh_offset + sectionHeader[i].sh_name;
+ return reinterpret_cast<const char *>(elfImage + sectionNameOffset);
+ };
+ if(functionName && strstr(getCurrSectionName(), functionName) == nullptr)
+ {
+ continue;
+ }
+
entry = elfImage + sectionHeader[i].sh_offset;
codeSize = sectionHeader[i].sh_size;
}
@@ -593,22 +744,27 @@
void seek(uint64_t Off) override { position = Off; }
- const void *finalizeEntryBegin()
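+	// Loads the named function's code from its ELF section and returns its entry address.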
+ const void *getEntryByName(const char *name)
{
- position = std::numeric_limits<std::size_t>::max(); // Can't stream more data after this
-
size_t codeSize = 0;
- const void *entry = loadImage(&buffer[0], codeSize);
+ const void *entry = loadImage(&buffer[0], codeSize, name);
- protectMemoryPages(&buffer[0], buffer.size(), PERMISSION_READ | PERMISSION_EXECUTE);
#if defined(_WIN32)
FlushInstructionCache(GetCurrentProcess(), NULL, 0);
#else
__builtin___clear_cache((char *)entry, (char *)entry + codeSize);
#endif
+
return entry;
}
+ void finalize()
+ {
+ position = std::numeric_limits<std::size_t>::max(); // Can't stream more data after this
+
+ protectMemoryPages(&buffer[0], buffer.size(), PERMISSION_READ | PERMISSION_EXECUTE);
+ }
+
void setEntry(int index, const void *func)
{
ASSERT(func);
@@ -664,6 +820,9 @@
Flags.setVerbose(subzeroDumpEnabled ? Ice::IceV_Most : Ice::IceV_None);
Flags.setDisableHybridAssembly(true);
+ // Emit functions into separate sections in the ELF so we can find them by name
+ Flags.setFunctionSections(true);
+
static llvm::raw_os_ostream cout(std::cout);
static llvm::raw_os_ostream cerr(std::cerr);
@@ -691,13 +850,24 @@
Nucleus::~Nucleus()
{
delete ::routine;
+ ::routine = nullptr;
delete ::allocator;
+ ::allocator = nullptr;
+
delete ::function;
+ ::function = nullptr;
+
delete ::context;
+ ::context = nullptr;
delete ::elfFile;
+ ::elfFile = nullptr;
+
delete ::out;
+ ::out = nullptr;
+
+ ::basicBlock = nullptr;
::codegenMutex.unlock();
}
@@ -721,56 +891,89 @@
return ::defaultConfig();
}
-std::shared_ptr<Routine> Nucleus::acquireRoutine(const char *name, const Config::Edit &cfgEdit /* = Config::Edit::None */)
+// Lowers the input functions to executable binary code in memory, and returns a
+// Routine with the entry points to these functions.
+template<size_t Count>
+static std::shared_ptr<Routine> acquireRoutine(Ice::Cfg *const (&functions)[Count], const char *const (&names)[Count], const Config::Edit &cfgEdit)
{
+ // This logic is modeled after the IceCompiler, as well as GlobalContext::translateFunctions
+ // and GlobalContext::emitItems.
+
if(subzeroDumpEnabled)
{
// Output dump strings immediately, rather than once buffer is full. Useful for debugging.
- context->getStrDump().SetUnbuffered();
- }
-
- if(basicBlock->getInsts().empty() || basicBlock->getInsts().back().getKind() != Ice::Inst::Ret)
- {
- createRetVoid();
- }
-
- ::function->setFunctionName(Ice::GlobalString::createWithString(::context, name));
-
- rr::optimize(::function);
-
- ::function->computeInOutEdges();
- ASSERT(!::function->hasError());
-
- ::function->translate();
- ASSERT(!::function->hasError());
-
- auto globals = ::function->getGlobalInits();
-
- if(globals && !globals->empty())
- {
- ::context->getGlobals()->merge(globals.get());
+ ::context->getStrDump().SetUnbuffered();
}
::context->emitFileHeader();
- if(subzeroEmitTextAsm)
+ // Translate
+
+ for(size_t i = 0; i < Count; ++i)
{
- ::function->emit();
+ Ice::Cfg *currFunc = functions[i];
+
+ // Install function allocator in TLS for Cfg-specific container allocators
+ Ice::CfgLocalAllocatorScope allocScope(currFunc);
+
+ currFunc->setFunctionName(Ice::GlobalString::createWithString(::context, names[i]));
+
+ rr::optimize(currFunc);
+
+ currFunc->computeInOutEdges();
+ ASSERT(!currFunc->hasError());
+
+ currFunc->translate();
+ ASSERT(!currFunc->hasError());
+
+ currFunc->getAssembler<>()->setInternal(currFunc->getInternal());
+
+ if(subzeroEmitTextAsm)
+ {
+ currFunc->emit();
+ }
+
+ currFunc->emitIAS();
}
- ::function->emitIAS();
- auto assembler = ::function->releaseAssembler();
+ // Emit items
+
+ ::context->lowerGlobals("");
+
auto objectWriter = ::context->getObjectWriter();
- assembler->alignFunction();
- objectWriter->writeFunctionCode(::function->getFunctionName(), false, assembler.get());
+
+ for(size_t i = 0; i < Count; ++i)
+ {
+ Ice::Cfg *currFunc = functions[i];
+
+ // Accumulate globals from functions to emit into the "last" section at the end
+ auto globals = currFunc->getGlobalInits();
+ if(globals && !globals->empty())
+ {
+ ::context->getGlobals()->merge(globals.get());
+ }
+
+ auto assembler = currFunc->releaseAssembler();
+ assembler->alignFunction();
+ objectWriter->writeFunctionCode(currFunc->getFunctionName(), currFunc->getInternal(), assembler.get());
+ }
+
::context->lowerGlobals("last");
::context->lowerConstants();
::context->lowerJumpTables();
+
objectWriter->setUndefinedSyms(::context->getConstantExternSyms());
+ ::context->emitTargetRODataSections();
objectWriter->writeNonUserSections();
- const void *entryBegin = ::routine->finalizeEntryBegin();
- ::routine->setEntry(Nucleus::CoroutineEntryBegin, entryBegin);
+	// Done compiling functions; get entry pointers to each of them.
+ for(size_t i = 0; i < Count; ++i)
+ {
+ const void *entry = ::routine->getEntryByName(names[i]);
+ ::routine->setEntry(i, entry);
+ }
+
+ ::routine->finalize();
Routine *handoffRoutine = ::routine;
::routine = nullptr;
@@ -778,6 +981,12 @@
return std::shared_ptr<Routine>(handoffRoutine);
}
+std::shared_ptr<Routine> Nucleus::acquireRoutine(const char *name, const Config::Edit &cfgEdit /* = Config::Edit::None */)
+{
+ createRetVoidIfNoRet();
+ return rr::acquireRoutine({ ::function }, { name }, cfgEdit);
+}
+
Value *Nucleus::allocateStackVariable(Type *t, int arraySize)
{
Ice::Type type = T(t);
@@ -811,21 +1020,21 @@
::basicBlock = basicBlock;
}
-void Nucleus::createFunction(Type *ReturnType, std::vector<Type *> &Params)
+void Nucleus::createFunction(Type *returnType, const std::vector<Type *> &paramTypes)
{
- uint32_t sequenceNumber = 0;
- ::function = Ice::Cfg::create(::context, sequenceNumber).release();
+ ASSERT(::function == nullptr);
+ ASSERT(::allocator == nullptr);
+ ASSERT(::basicBlock == nullptr);
+
+ ::function = sz::createFunction(::context, T(returnType), T(paramTypes));
+
+	// NOTE: The scoped allocator sets the TLS allocator to the one owned by the function.
+	// This global scope becomes invalid if another scoped allocator is created, e.g. when
+	// creating the await and destroy functions for coroutines; in that case, we must create
+	// a new scoped allocator for ::function again.
+	// TODO: Get rid of this global, and create a scoped allocator in every Nucleus function instead.
::allocator = new Ice::CfgLocalAllocatorScope(::function);
- for(Type *type : Params)
- {
- Ice::Variable *arg = ::function->makeVariable(T(type));
- ::function->addArg(arg);
- }
-
- Ice::CfgNode *node = ::function->makeNode();
- ::function->setEntryNode(node);
- ::basicBlock = node;
+ ::basicBlock = ::function->getEntryNode();
}
Value *Nucleus::getArgument(unsigned int index)
@@ -1152,7 +1361,7 @@
{
ASSERT(value->getType() == T(type));
- auto store = Ice::InstStore::create(::function, value, ptr, align);
+ auto store = Ice::InstStore::create(::function, V(value), V(ptr), align);
::basicBlock->appendInst(store);
}
@@ -1556,14 +1765,7 @@
Type *Nucleus::getPointerType(Type *ElementType)
{
- if(sizeof(void *) == 8)
- {
- return T(Ice::IceType_i64);
- }
- else
- {
- return T(Ice::IceType_i32);
- }
+ return T(sz::getPointerType(T(ElementType)));
}
Value *Nucleus::createNullValue(Type *Ty)
@@ -2899,11 +3101,11 @@
Value *e;
int swizzle[16] = { 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23 };
Value *b = Nucleus::createBitCast(a, Byte16::getType());
- Value *c = Nucleus::createShuffleVector(b, V(Nucleus::createNullValue(Byte16::getType())), swizzle);
+ Value *c = Nucleus::createShuffleVector(b, Nucleus::createNullValue(Byte16::getType()), swizzle);
int swizzle2[8] = { 0, 8, 1, 9, 2, 10, 3, 11 };
Value *d = Nucleus::createBitCast(c, Short8::getType());
- e = Nucleus::createShuffleVector(d, V(Nucleus::createNullValue(Short8::getType())), swizzle2);
+ e = Nucleus::createShuffleVector(d, Nucleus::createNullValue(Short8::getType()), swizzle2);
Value *f = Nucleus::createBitCast(e, Int4::getType());
storeValue(f);
@@ -3879,34 +4081,507 @@
void EmitDebugVariable(Value *value) {}
void FlushDebug() {}
-void Nucleus::createCoroutine(Type *YieldType, std::vector<Type *> &Params)
+namespace {
+namespace coro {
+
+using FiberHandle = void *;
+
+// Instance data per generated coroutine.
+// This is the "handle" type used for Coroutine functions.
+// Lifetime: from the first yield until the generated CoroutineEntryDestroy function is called.
+struct CoroutineData
{
- // Subzero currently only supports coroutines as functions (i.e. that do not yield)
- createFunction(YieldType, Params);
+ FiberHandle mainFiber{};
+ FiberHandle routineFiber{};
+ bool convertedFiber = false;
+
+ // Variables used by coroutines
+ bool done = false;
+ void *promisePtr = nullptr;
+};
+
+CoroutineData *createCoroutineData()
+{
+ return new CoroutineData{};
+}
+
+void destroyCoroutineData(CoroutineData *coroData)
+{
+ delete coroData;
+}
+
+void convertThreadToMainFiber(Nucleus::CoroutineHandle handle)
+{
+#if defined(_WIN32)
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+
+ coroData->mainFiber = ::ConvertThreadToFiber(nullptr);
+
+ if(coroData->mainFiber)
+ {
+ coroData->convertedFiber = true;
+ }
+ else
+ {
+		// We're probably already on a fiber, so just grab it, and remember that we
+		// didn't convert it, so that we don't convert it back to a thread later.
+ coroData->mainFiber = GetCurrentFiber();
+ coroData->convertedFiber = false;
+ }
+ ASSERT(coroData->mainFiber);
+#else
+ UNIMPLEMENTED("convertThreadToMainFiber not implemented for current platform");
+#endif
+}
+
+void convertMainFiberToThread(Nucleus::CoroutineHandle handle)
+{
+#if defined(_WIN32)
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+
+ ASSERT(coroData->mainFiber);
+
+ if(coroData->convertedFiber)
+ {
+ ::ConvertFiberToThread();
+ coroData->mainFiber = nullptr;
+ }
+#else
+ UNIMPLEMENTED("convertMainFiberToThread not implemented for current platform");
+#endif
+}
+
+using FiberFunc = std::function<void()>;
+
+void createRoutineFiber(Nucleus::CoroutineHandle handle, FiberFunc *fiberFunc)
+{
+#if defined(_WIN32)
+ struct Invoker
+ {
+ FiberFunc func;
+
+ static VOID __stdcall fiberEntry(LPVOID lpParameter)
+ {
+ auto *func = reinterpret_cast<FiberFunc *>(lpParameter);
+ (*func)();
+ }
+ };
+
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+
+ constexpr SIZE_T StackSize = 2 * 1024 * 1024;
+ coroData->routineFiber = ::CreateFiber(StackSize, &Invoker::fiberEntry, fiberFunc);
+ ASSERT(coroData->routineFiber);
+#else
+ UNIMPLEMENTED("createRoutineFiber not implemented for current platform");
+#endif
+}
+
+void deleteRoutineFiber(Nucleus::CoroutineHandle handle)
+{
+#if defined(_WIN32)
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+ ASSERT(coroData->routineFiber);
+ ::DeleteFiber(coroData->routineFiber);
+ coroData->routineFiber = nullptr;
+#else
+ UNIMPLEMENTED("deleteRoutineFiber not implemented for current platform");
+#endif
+}
+
+void switchToMainFiber(Nucleus::CoroutineHandle handle)
+{
+#if defined(_WIN32)
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+
+ // Win32
+ ASSERT(coroData->mainFiber);
+ ::SwitchToFiber(coroData->mainFiber);
+#else
+ UNIMPLEMENTED("switchToMainFiber not implemented for current platform");
+#endif
+}
+
+void switchToRoutineFiber(Nucleus::CoroutineHandle handle)
+{
+#if defined(_WIN32)
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+
+ // Win32
+ ASSERT(coroData->routineFiber);
+ ::SwitchToFiber(coroData->routineFiber);
+#else
+ UNIMPLEMENTED("switchToRoutineFiber not implemented for current platform");
+#endif
+}
+
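+// The coroutine handle is passed to the generated begin function through this TLS
+// variable, set just before switching to the routine fiber, so that the begin
+// function can keep the user-declared parameter list.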
+namespace detail {
+thread_local rr::Nucleus::CoroutineHandle coroHandle{};
+} // namespace detail
+
+void setHandleParam(Nucleus::CoroutineHandle handle)
+{
+ ASSERT(!detail::coroHandle);
+ detail::coroHandle = handle;
+}
+
+Nucleus::CoroutineHandle getHandleParam()
+{
+ ASSERT(detail::coroHandle);
+ auto handle = detail::coroHandle;
+ detail::coroHandle = {};
+ return handle;
+}
+
+void setDone(Nucleus::CoroutineHandle handle)
+{
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+ ASSERT(!coroData->done); // Should be called once
+ coroData->done = true;
+}
+
+bool isDone(Nucleus::CoroutineHandle handle)
+{
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+ return coroData->done;
+}
+
+void setPromisePtr(Nucleus::CoroutineHandle handle, void *promisePtr)
+{
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+ coroData->promisePtr = promisePtr;
+}
+
+void *getPromisePtr(Nucleus::CoroutineHandle handle)
+{
+ auto *coroData = reinterpret_cast<CoroutineData *>(handle);
+ return coroData->promisePtr;
+}
+
+} // namespace coro
+} // namespace
+
+// Used to generate coroutines.
+// Lifetime: from the first yield until acquireCoroutine.
+class CoroutineGenerator
+{
+public:
+ CoroutineGenerator()
+ {
+ }
+
+ // Inserts instructions at the top of the current function to make it a coroutine.
+ void generateCoroutineBegin()
+ {
+ // Begin building the main coroutine_begin() function.
+ // We insert these instructions at the top of the entry node,
+ // before existing reactor-generated instructions.
+
+ // CoroutineHandle coroutine_begin(<Arguments>)
+ // {
+ // this->handle = coro::getHandleParam();
+ //
+ // YieldType promise;
+ // coro::setPromisePtr(handle, &promise); // For await
+ //
+ // ... <REACTOR CODE> ...
+ //
+
+ // Save original entry block and current block, and create a new entry block and make it current.
+ // This new block will be used to inject code above the begin routine's existing code. We make
+ // this block branch to the original entry block as the last instruction.
+ auto origEntryBB = ::function->getEntryNode();
+ auto origCurrBB = ::basicBlock;
+ auto newBB = ::function->makeNode();
+ sz::replaceEntryNode(::function, newBB);
+ ::basicBlock = newBB;
+
+ // this->handle = coro::getHandleParam();
+ this->handle = sz::Call(::function, ::basicBlock, coro::getHandleParam);
+
+ // YieldType promise;
+ // coro::setPromisePtr(handle, &promise); // For await
+ this->promise = sz::allocateStackVariable(::function, T(::coroYieldType));
+ sz::Call(::function, ::basicBlock, coro::setPromisePtr, this->handle, this->promise);
+
+ // Branch to original entry block
+ auto br = Ice::InstBr::create(::function, origEntryBB);
+ ::basicBlock->appendInst(br);
+
+ // Restore current block for future instructions
+ ::basicBlock = origCurrBB;
+ }
+
+ // Adds instructions for Yield() calls at the current location of the main coroutine function.
+ void generateYield(Value *val)
+ {
+ // ... <REACTOR CODE> ...
+ //
+ // promise = val;
+ // coro::switchToMainFiber(handle);
+ //
+ // ... <REACTOR CODE> ...
+
+ Nucleus::createStore(val, V(this->promise), ::coroYieldType);
+ sz::Call(::function, ::basicBlock, coro::switchToMainFiber, this->handle);
+ }
+
+ // Adds instructions at the end of the current main coroutine function to end the coroutine.
+ void generateCoroutineEnd()
+ {
+ // ... <REACTOR CODE> ...
+ //
+ // coro::setDone(handle);
+		// coro::switchToMainFiber(handle);
+ // // Unreachable
+ // }
+ //
+
+ sz::Call(::function, ::basicBlock, coro::setDone, this->handle);
+
+		// A Win32 fiber function must not return, otherwise it tears down the thread
+		// it's running on. So we add code to switch back to the main fiber instead.
+ sz::Call(::function, ::basicBlock, coro::switchToMainFiber, this->handle);
+ }
+
+ using FunctionUniquePtr = std::unique_ptr<Ice::Cfg>;
+
+ // Generates the await function for the current coroutine.
+ // Cannot use Nucleus functions that modify ::function and ::basicBlock.
+ static FunctionUniquePtr generateAwaitFunction()
+ {
+ // bool coroutine_await(CoroutineHandle handle, YieldType* out)
+ // {
+	//     if (coro::isDone(handle))
+ // {
+ // return false;
+ // }
+ // else // resume
+ // {
+ // YieldType* promise = coro::getPromisePtr(handle);
+ // *out = *promise;
+ // coro::switchToRoutineFiber(handle);
+ // return true;
+ // }
+ // }
+
+	// Subzero doesn't support bool (IceType_i1) as a return type
+ const Ice::Type ReturnType = Ice::IceType_i32;
+ const Ice::Type YieldPtrType = sz::getPointerType(T(::coroYieldType));
+ const Ice::Type HandleType = sz::getPointerType(Ice::IceType_void);
+
+ Ice::Cfg *awaitFunc = sz::createFunction(::context, ReturnType, std::vector<Ice::Type>{ HandleType, YieldPtrType });
+ Ice::CfgLocalAllocatorScope scopedAlloc{ awaitFunc };
+
+ Ice::Variable *handle = awaitFunc->getArgs()[0];
+ Ice::Variable *outPtr = awaitFunc->getArgs()[1];
+
+ auto doneBlock = awaitFunc->makeNode();
+ {
+ // return false;
+ Ice::InstRet *ret = Ice::InstRet::create(awaitFunc, ::context->getConstantInt32(0));
+ doneBlock->appendInst(ret);
+ }
+
+ auto resumeBlock = awaitFunc->makeNode();
+ {
+ // YieldType* promise = coro::getPromisePtr(handle);
+ Ice::Variable *promise = sz::Call(awaitFunc, resumeBlock, coro::getPromisePtr, handle);
+
+ // *out = *promise;
+ // Load promise value
+ Ice::Variable *promiseVal = awaitFunc->makeVariable(T(::coroYieldType));
+ auto load = Ice::InstLoad::create(awaitFunc, promiseVal, promise);
+ resumeBlock->appendInst(load);
+ // Then store it in output param
+ auto store = Ice::InstStore::create(awaitFunc, promiseVal, outPtr);
+ resumeBlock->appendInst(store);
+
+ // coro::switchToRoutineFiber(handle);
+ sz::Call(awaitFunc, resumeBlock, coro::switchToRoutineFiber, handle);
+
+ // return true;
+ Ice::InstRet *ret = Ice::InstRet::create(awaitFunc, ::context->getConstantInt32(1));
+ resumeBlock->appendInst(ret);
+ }
+
+	// if (coro::isDone(handle))
+ // {
+ // <doneBlock>
+ // }
+ // else // resume
+ // {
+ // <resumeBlock>
+ // }
+ Ice::CfgNode *bb = awaitFunc->getEntryNode();
+	Ice::Variable *done = sz::Call(awaitFunc, bb, coro::isDone, handle);
+ auto br = Ice::InstBr::create(awaitFunc, done, doneBlock, resumeBlock);
+ bb->appendInst(br);
+
+ return FunctionUniquePtr{ awaitFunc };
+ }
+
+ // Generates the destroy function for the current coroutine.
+ // Cannot use Nucleus functions that modify ::function and ::basicBlock.
+ static FunctionUniquePtr generateDestroyFunction()
+ {
+ // void coroutine_destroy(Nucleus::CoroutineHandle handle)
+ // {
+	//     coro::convertMainFiberToThread(handle);
+ // coro::deleteRoutineFiber(handle);
+ // coro::destroyCoroutineData(handle);
+ // return;
+ // }
+
+ const Ice::Type ReturnType = Ice::IceType_void;
+ const Ice::Type HandleType = sz::getPointerType(Ice::IceType_void);
+
+ Ice::Cfg *destroyFunc = sz::createFunction(::context, ReturnType, std::vector<Ice::Type>{ HandleType });
+ Ice::CfgLocalAllocatorScope scopedAlloc{ destroyFunc };
+
+ Ice::Variable *handle = destroyFunc->getArgs()[0];
+
+ auto *bb = destroyFunc->getEntryNode();
+
+	// coro::convertMainFiberToThread(handle);
+ sz::Call(destroyFunc, bb, coro::convertMainFiberToThread, handle);
+
+ // coro::deleteRoutineFiber(handle);
+ sz::Call(destroyFunc, bb, coro::deleteRoutineFiber, handle);
+
+ // coro::destroyCoroutineData(handle);
+ sz::Call(destroyFunc, bb, coro::destroyCoroutineData, handle);
+
+ // return;
+ Ice::InstRet *ret = Ice::InstRet::create(destroyFunc);
+ bb->appendInst(ret);
+
+ return FunctionUniquePtr{ destroyFunc };
+ }
+
+private:
+ Ice::Variable *handle{};
+ Ice::Variable *promise{};
+};
+
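+// Runtime helper invoked when a coroutine is started: converts the calling
+// thread to a fiber, creates the routine fiber, and runs the generated begin
+// function on it until the first yield (or until it finishes).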
+static Nucleus::CoroutineHandle invokeCoroutineBegin(std::function<Nucleus::CoroutineHandle()> beginFunc)
+{
+ // This doubles up as our coroutine handle
+ auto coroData = coro::createCoroutineData();
+
+ // Convert current thread to a fiber so we can create new fibers and switch to them
+ coro::convertThreadToMainFiber(coroData);
+
+ coro::FiberFunc fiberFunc = [&]() {
+ // Store handle in TLS so that the coroutine can grab it right away, before
+ // any fiber switch occurs.
+ coro::setHandleParam(coroData);
+
+ // Invoke the begin function in the context of the routine fiber
+ beginFunc();
+
+		// Either it yielded, or it finished. In either case, we switch back to the main fiber.
+		// We must never return from this function, or the current thread will be destroyed.
+ coro::switchToMainFiber(coroData);
+ };
+
+ coro::createRoutineFiber(coroData, &fiberFunc);
+
+ // Fiber will now start running, executing the saved beginFunc
+ coro::switchToRoutineFiber(coroData);
+
+ return coroData;
+}
+
+void Nucleus::createCoroutine(Type *yieldType, const std::vector<Type *> &params)
+{
+ // Start by creating a regular function
+ createFunction(yieldType, params);
+
+ // Save in case yield() is called
+ ASSERT(::coroYieldType == nullptr); // Only one coroutine can be generated at once
+ ::coroYieldType = yieldType;
+}
+
+void Nucleus::yield(Value *val)
+{
+ Variable::materializeAll();
+
+ // On first yield, we start generating coroutine functions
+ if(!::coroGen)
+ {
+ ::coroGen = std::make_shared<CoroutineGenerator>();
+ ::coroGen->generateCoroutineBegin();
+ }
+
+ ASSERT(::coroGen);
+ ::coroGen->generateYield(val);
}
static bool coroutineEntryAwaitStub(Nucleus::CoroutineHandle, void *yieldValue)
{
return false;
}
-static void coroutineEntryDestroyStub(Nucleus::CoroutineHandle) {}
+
+static void coroutineEntryDestroyStub(Nucleus::CoroutineHandle handle)
+{
+}
std::shared_ptr<Routine> Nucleus::acquireCoroutine(const char *name, const Config::Edit &cfgEdit /* = Config::Edit::None */)
{
- // acquireRoutine sets the CoroutineEntryBegin entry
- auto coroutineEntry = acquireRoutine(name, cfgEdit);
+ if(::coroGen)
+ {
+ // Finish generating coroutine functions
+ {
+ Ice::CfgLocalAllocatorScope scopedAlloc{ ::function };
+ ::coroGen->generateCoroutineEnd();
+ createRetVoidIfNoRet();
+ }
- // For now, set the await and destroy entries to stubs, until we add proper coroutine support to the Subzero backend
- auto routine = std::static_pointer_cast<ELFMemoryStreamer>(coroutineEntry);
- routine->setEntry(Nucleus::CoroutineEntryAwait, reinterpret_cast<const void *>(&coroutineEntryAwaitStub));
- routine->setEntry(Nucleus::CoroutineEntryDestroy, reinterpret_cast<const void *>(&coroutineEntryDestroyStub));
+ auto awaitFunc = ::coroGen->generateAwaitFunction();
+ auto destroyFunc = ::coroGen->generateDestroyFunction();
- return coroutineEntry;
+ // At this point, we no longer need the CoroutineGenerator.
+ ::coroGen.reset();
+ ::coroYieldType = nullptr;
+
+ auto routine = rr::acquireRoutine({ ::function, awaitFunc.get(), destroyFunc.get() },
+ { name, "await", "destroy" },
+ cfgEdit);
+
+ return routine;
+ }
+ else
+ {
+ {
+ Ice::CfgLocalAllocatorScope scopedAlloc{ ::function };
+ createRetVoidIfNoRet();
+ }
+
+ ::coroYieldType = nullptr;
+
+ // Not an actual coroutine (no yields), so return stubs for await and destroy
+ auto routine = rr::acquireRoutine({ ::function }, { name }, cfgEdit);
+
+ auto routineImpl = std::static_pointer_cast<ELFMemoryStreamer>(routine);
+ routineImpl->setEntry(Nucleus::CoroutineEntryAwait, reinterpret_cast<const void *>(&coroutineEntryAwaitStub));
+ routineImpl->setEntry(Nucleus::CoroutineEntryDestroy, reinterpret_cast<const void *>(&coroutineEntryDestroyStub));
+ return routine;
+ }
}
-void Nucleus::yield(Value *val)
+Nucleus::CoroutineHandle Nucleus::invokeCoroutineBegin(Routine &routine, std::function<Nucleus::CoroutineHandle()> func)
{
- UNIMPLEMENTED("Yield");
+ const bool isCoroutine = routine.getEntry(Nucleus::CoroutineEntryAwait) != reinterpret_cast<const void *>(&coroutineEntryAwaitStub);
+
+ if(isCoroutine)
+ {
+ return rr::invokeCoroutineBegin(func);
+ }
+ else
+ {
+ // For regular routines, just invoke the begin func directly
+ return func();
+ }
}
} // namespace rr