blob: ca21cdd6019a397b1d293f196c2349d802d917c6 [file] [log] [blame]
// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "osfiber.h" // Must come first. See osfiber_ucontext.h.
#include "marl/scheduler.h"
#include "marl/debug.h"
#include "marl/defer.h"
#include "marl/thread.h"
#include "marl/trace.h"
#if defined(_WIN32)
#include <intrin.h> // __nop()
#endif
// Enable to trace scheduler events.
#define ENABLE_TRACE_EVENTS 0
#if ENABLE_TRACE_EVENTS
#define TRACE(...) MARL_SCOPED_EVENT(__VA_ARGS__)
#else
#define TRACE(...)
#endif
namespace {
template <typename T>
inline T take(std::queue<T>& queue) {
auto out = std::move(queue.front());
queue.pop();
return out;
}
inline void nop() {
#if defined(_WIN32)
__nop();
#else
__asm__ __volatile__("nop");
#endif
}
} // anonymous namespace
namespace marl {
////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
thread_local Scheduler* Scheduler::bound = nullptr;
Scheduler* Scheduler::get() {
return bound;
}
void Scheduler::bind() {
MARL_ASSERT(bound == nullptr, "Scheduler already bound");
bound = this;
{
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
auto worker = std::unique_ptr<Worker>(
new Worker(this, Worker::Mode::SingleThreaded, 0));
worker->start();
auto tid = std::this_thread::get_id();
singleThreadedWorkers.emplace(tid, std::move(worker));
}
}
void Scheduler::unbind() {
MARL_ASSERT(bound != nullptr, "No scheduler bound");
std::unique_ptr<Worker> worker;
{
std::unique_lock<std::mutex> lock(bound->singleThreadedWorkerMutex);
auto tid = std::this_thread::get_id();
auto it = bound->singleThreadedWorkers.find(tid);
MARL_ASSERT(it != bound->singleThreadedWorkers.end(),
"singleThreadedWorker not found");
worker = std::move(it->second);
bound->singleThreadedWorkers.erase(tid);
}
worker->flush();
worker->stop();
bound = nullptr;
}
Scheduler::Scheduler() {
for (size_t i = 0; i < spinningWorkers.size(); i++) {
spinningWorkers[i] = -1;
}
}
Scheduler::~Scheduler() {
{
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
MARL_ASSERT(singleThreadedWorkers.size() == 0,
"Scheduler still bound on %d threads",
int(singleThreadedWorkers.size()));
}
setWorkerThreadCount(0);
}
void Scheduler::setThreadInitializer(const std::function<void()>& func) {
std::unique_lock<std::mutex> lock(threadInitFuncMutex);
threadInitFunc = func;
}
const std::function<void()>& Scheduler::getThreadInitializer() {
std::unique_lock<std::mutex> lock(threadInitFuncMutex);
return threadInitFunc;
}
void Scheduler::setWorkerThreadCount(int newCount) {
MARL_ASSERT(newCount >= 0, "count must be positive");
auto oldCount = numWorkerThreads;
for (int idx = oldCount - 1; idx >= newCount; idx--) {
workerThreads[idx]->stop();
}
for (int idx = oldCount - 1; idx >= newCount; idx--) {
delete workerThreads[idx];
}
for (int idx = oldCount; idx < newCount; idx++) {
workerThreads[idx] = new Worker(this, Worker::Mode::MultiThreaded, idx);
}
numWorkerThreads = newCount;
for (int idx = oldCount; idx < newCount; idx++) {
workerThreads[idx]->start();
}
}
int Scheduler::getWorkerThreadCount() {
return numWorkerThreads;
}
void Scheduler::enqueue(Task&& task) {
if (numWorkerThreads > 0) {
while (true) {
// Prioritize workers that have recently started spinning.
auto i = --nextSpinningWorkerIdx % spinningWorkers.size();
auto idx = spinningWorkers[i].exchange(-1);
if (idx < 0) {
// If a spinning worker couldn't be found, round-robin the
// workers.
idx = nextEnqueueIndex++ % numWorkerThreads;
}
auto worker = workerThreads[idx];
if (worker->tryLock()) {
worker->enqueueAndUnlock(std::move(task));
return;
}
}
} else {
auto tid = std::this_thread::get_id();
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
auto it = singleThreadedWorkers.find(tid);
MARL_ASSERT(it != singleThreadedWorkers.end(),
"singleThreadedWorker not found");
it->second->enqueue(std::move(task));
}
}
bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
if (numWorkerThreads > 0) {
auto thread = workerThreads[from % numWorkerThreads];
if (thread != thief) {
if (thread->dequeue(out)) {
return true;
}
}
}
return false;
}
void Scheduler::onBeginSpinning(int workerId) {
auto idx = nextSpinningWorkerIdx++ % spinningWorkers.size();
spinningWorkers[idx] = workerId;
}
////////////////////////////////////////////////////////////////////////////////
// Fiber
////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(OSFiber* impl, uint32_t id)
: id(id), impl(impl), worker(Scheduler::Worker::getCurrent()) {
MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}
Scheduler::Fiber::~Fiber() {
delete impl;
}
Scheduler::Fiber* Scheduler::Fiber::current() {
auto worker = Scheduler::Worker::getCurrent();
return worker != nullptr ? worker->getCurrentFiber() : nullptr;
}
void Scheduler::Fiber::schedule() {
worker->enqueue(this);
}
void Scheduler::Fiber::yield() {
MARL_SCOPED_EVENT("YIELD");
worker->yield(this);
}
void Scheduler::Fiber::switchTo(Fiber* to) {
if (to != this) {
impl->switchTo(to->impl);
}
}
Scheduler::Fiber* Scheduler::Fiber::create(uint32_t id,
size_t stackSize,
const std::function<void()>& func) {
return new Fiber(OSFiber::createFiber(stackSize, func), id);
}
Scheduler::Fiber* Scheduler::Fiber::createFromCurrentThread(uint32_t id) {
return new Fiber(OSFiber::createFiberFromCurrentThread(), id);
}
////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
thread_local Scheduler::Worker* Scheduler::Worker::current = nullptr;
Scheduler::Worker::Worker(Scheduler* scheduler, Mode mode, uint32_t id)
: id(id), mode(mode), scheduler(scheduler) {}
void Scheduler::Worker::start() {
switch (mode) {
case Mode::MultiThreaded:
thread = std::thread([=] {
Thread::setName("Thread<%.2d>", int(id));
if (auto const& initFunc = scheduler->getThreadInitializer()) {
initFunc();
}
Scheduler::bound = scheduler;
Worker::current = this;
mainFiber.reset(Fiber::createFromCurrentThread(0));
currentFiber = mainFiber.get();
run();
mainFiber.reset();
Worker::current = nullptr;
});
break;
case Mode::SingleThreaded:
Worker::current = this;
mainFiber.reset(Fiber::createFromCurrentThread(0));
currentFiber = mainFiber.get();
break;
default:
MARL_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
void Scheduler::Worker::stop() {
switch (mode) {
case Mode::MultiThreaded:
shutdown = true;
enqueue([] {}); // Ensure the worker is woken up to notice the shutdown.
thread.join();
break;
case Mode::SingleThreaded:
Worker::current = nullptr;
break;
default:
MARL_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
void Scheduler::Worker::yield(Fiber* from) {
MARL_ASSERT(currentFiber == from,
"Attempting to call yield from a non-current fiber");
// Current fiber is yielding as it is blocked.
// First wait until there's something else this worker can do.
std::unique_lock<std::mutex> lock(work.mutex);
waitForWork(lock);
if (work.fibers.size() > 0) {
// There's another fiber that has become unblocked, resume that.
work.num--;
auto to = take(work.fibers);
lock.unlock();
switchToFiber(to);
} else if (idleFibers.size() > 0) {
// There's an old fiber we can reuse, resume that.
auto to = take(idleFibers);
lock.unlock();
switchToFiber(to);
} else {
// Tasks to process and no existing fibers to resume. Spawn a new fiber.
lock.unlock();
switchToFiber(createWorkerFiber());
}
}
bool Scheduler::Worker::tryLock() {
return work.mutex.try_lock();
}
void Scheduler::Worker::enqueue(Fiber* fiber) {
std::unique_lock<std::mutex> lock(work.mutex);
auto wasIdle = work.num == 0;
work.fibers.push(std::move(fiber));
work.num++;
lock.unlock();
if (wasIdle) {
work.added.notify_one();
}
}
void Scheduler::Worker::enqueue(Task&& task) {
work.mutex.lock();
enqueueAndUnlock(std::move(task));
}
void Scheduler::Worker::enqueueAndUnlock(Task&& task) {
auto wasIdle = work.num == 0;
work.tasks.push(std::move(task));
work.num++;
work.mutex.unlock();
if (wasIdle) {
work.added.notify_one();
}
}
bool Scheduler::Worker::dequeue(Task& out) {
if (work.num.load() == 0) {
return false;
}
if (!work.mutex.try_lock()) {
return false;
}
defer(work.mutex.unlock());
if (work.tasks.size() == 0) {
return false;
}
work.num--;
out = take(work.tasks);
return true;
}
void Scheduler::Worker::flush() {
MARL_ASSERT(mode == Mode::SingleThreaded,
"flush() can only be used on a single-threaded worker");
std::unique_lock<std::mutex> lock(work.mutex);
runUntilIdle(lock);
}
void Scheduler::Worker::run() {
switch (mode) {
case Mode::MultiThreaded: {
MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id),
Fiber::current()->id);
{
std::unique_lock<std::mutex> lock(work.mutex);
work.added.wait(lock, [this] { return work.num > 0 || shutdown; });
while (!shutdown) {
waitForWork(lock);
runUntilIdle(lock);
}
Worker::current = nullptr;
}
switchToFiber(mainFiber.get());
break;
}
case Mode::SingleThreaded:
while (!shutdown) {
flush();
idleFibers.emplace(currentFiber);
switchToFiber(mainFiber.get());
}
break;
default:
MARL_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
_Requires_lock_held_(lock) void Scheduler::Worker::waitForWork(
std::unique_lock<std::mutex>& lock) {
MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
"work.num out of sync");
if (work.num == 0) {
scheduler->onBeginSpinning(id);
lock.unlock();
spinForWork();
lock.lock();
}
work.added.wait(lock, [this] { return work.num > 0 || shutdown; });
}
void Scheduler::Worker::spinForWork() {
TRACE("SPIN");
Task stolen;
constexpr auto duration = std::chrono::milliseconds(1);
auto start = std::chrono::high_resolution_clock::now();
while (std::chrono::high_resolution_clock::now() - start < duration) {
for (int i = 0; i < 256; i++) // Empirically picked magic number!
{
// clang-format off
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
// clang-format on
if (work.num > 0) {
return;
}
}
if (scheduler->stealWork(this, rng(), stolen)) {
std::unique_lock<std::mutex> lock(work.mutex);
work.tasks.emplace(std::move(stolen));
work.num++;
return;
}
std::this_thread::yield();
}
}
_Requires_lock_held_(lock) void Scheduler::Worker::runUntilIdle(
std::unique_lock<std::mutex>& lock) {
MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
"work.num out of sync");
while (work.fibers.size() > 0 || work.tasks.size() > 0) {
// Note: we cannot take and store on the stack more than a single fiber
// or task at a time, as the Fiber may yield and these items may get
// held on suspended fiber stack.
while (work.fibers.size() > 0) {
work.num--;
auto fiber = take(work.fibers);
lock.unlock();
idleFibers.push(currentFiber);
switchToFiber(fiber);
lock.lock();
}
if (work.tasks.size() > 0) {
work.num--;
auto task = take(work.tasks);
lock.unlock();
// Run the task.
task();
// std::function<> can carry arguments with complex destructors.
// Ensure these are destructed outside of the lock.
task = Task();
lock.lock();
}
}
}
Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
auto id = static_cast<uint32_t>(workerFibers.size() + 1);
auto fiber = Fiber::create(id, FiberStackSize, [&] { run(); });
workerFibers.push_back(std::unique_ptr<Fiber>(fiber));
return fiber;
}
void Scheduler::Worker::switchToFiber(Fiber* to) {
auto from = currentFiber;
currentFiber = to;
from->switchTo(to);
}
} // namespace marl