Squashed 'third_party/marl/' changes from ca8408f68..16e1dc37c
16e1dc37c Add missing copyright header to marl-config.cmake.in
09b759550 CMake: link with pthread in a CMake way (#108)
47afda398 Kokoro+ubuntu: enable x86 and clang testing
f421c8b4f Kokoro+ubuntu: Plumbing for x86 & clang
69797fcf0 Fixes for x86.
3d6365b82 Scheduler: add wait() overloads that do not take a lock
db1e8c767 Scheduler: block until all threads are unbound.
git-subtree-dir: third_party/marl
git-subtree-split: 16e1dc37c5e12c35b93272f0353d0e6e9f200a26
diff --git a/src/scheduler.cpp b/src/scheduler.cpp
index 1f10f3a..c97d9ec 100644
--- a/src/scheduler.cpp
+++ b/src/scheduler.cpp
@@ -99,29 +99,31 @@
MARL_ASSERT(bound == nullptr, "Scheduler already bound");
bound = this;
{
- std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
+ std::unique_lock<std::mutex> lock(singleThreadedWorkers.mutex);
auto worker =
allocator->make_unique<Worker>(this, Worker::Mode::SingleThreaded, -1);
worker->start();
auto tid = std::this_thread::get_id();
- singleThreadedWorkers.emplace(tid, std::move(worker));
+ singleThreadedWorkers.byTid.emplace(tid, std::move(worker));
}
}
void Scheduler::unbind() {
MARL_ASSERT(bound != nullptr, "No scheduler bound");
- Allocator::unique_ptr<Worker> worker;
- {
- std::unique_lock<std::mutex> lock(bound->singleThreadedWorkerMutex);
- auto tid = std::this_thread::get_id();
- auto it = bound->singleThreadedWorkers.find(tid);
- MARL_ASSERT(it != bound->singleThreadedWorkers.end(),
- "singleThreadedWorker not found");
- worker = std::move(it->second);
- bound->singleThreadedWorkers.erase(tid);
- }
- worker->flush();
+ auto worker = Scheduler::Worker::getCurrent();
worker->stop();
+ {
+ std::unique_lock<std::mutex> lock(bound->singleThreadedWorkers.mutex);
+ auto tid = std::this_thread::get_id();
+ auto it = bound->singleThreadedWorkers.byTid.find(tid);
+ MARL_ASSERT(it != bound->singleThreadedWorkers.byTid.end(),
+ "singleThreadedWorker not found");
+ MARL_ASSERT(it->second.get() == worker, "worker is not bound?");
+ bound->singleThreadedWorkers.byTid.erase(it);
+ if (bound->singleThreadedWorkers.byTid.size() == 0) {
+ bound->singleThreadedWorkers.unbind.notify_one();
+ }
+ }
bound = nullptr;
}
@@ -133,14 +135,12 @@
}
Scheduler::~Scheduler() {
-#if MARL_DEBUG_ENABLED
{
- std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
- MARL_ASSERT(singleThreadedWorkers.size() == 0,
- "Scheduler still bound on %d threads",
- int(singleThreadedWorkers.size()));
+ // Wait until all the single threaded workers have been unbound.
+ std::unique_lock<std::mutex> lock(singleThreadedWorkers.mutex);
+ singleThreadedWorkers.unbind.wait(
+ lock, [this] { return singleThreadedWorkers.byTid.size() == 0; });
}
-#endif // MARL_DEBUG_ENABLED
// Release all worker threads.
// This will wait for all in-flight tasks to complete before returning.
@@ -211,12 +211,9 @@
}
}
} else {
- auto tid = std::this_thread::get_id();
- std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
- auto it = singleThreadedWorkers.find(tid);
- MARL_ASSERT(it != singleThreadedWorkers.end(),
- "singleThreadedWorker not found");
- it->second->enqueue(std::move(task));
+ auto worker = Worker::getCurrent();
+ MARL_ASSERT(worker, "singleThreadedWorker not found");
+ worker->enqueue(std::move(task));
}
}
@@ -398,20 +395,32 @@
void Scheduler::Worker::stop() {
switch (mode) {
- case Mode::MultiThreaded:
+ case Mode::MultiThreaded: {
enqueue(Task([this] { shutdown = true; }, Task::Flags::SameThread));
thread.join();
break;
-
- case Mode::SingleThreaded:
+ }
+ case Mode::SingleThreaded: {
+ std::unique_lock<std::mutex> lock(work.mutex);
+ shutdown = true;
+ runUntilShutdown();
Worker::current = nullptr;
break;
-
+ }
default:
MARL_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
+bool Scheduler::Worker::wait(const TimePoint* timeout) {
+ DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
+ {
+ std::unique_lock<std::mutex> lock(work.mutex);
+ suspend(timeout);
+ }
+ return timeout == nullptr || std::chrono::system_clock::now() < *timeout;
+}
+
_Requires_lock_held_(waitLock)
bool Scheduler::Worker::wait(Fiber::Lock& waitLock,
const TimePoint* timeout,
@@ -463,6 +472,8 @@
// First wait until there's something else this worker can do.
waitForWork();
+ work.numBlockedFibers++;
+
if (work.fibers.size() > 0) {
// There's another fiber that has become unblocked, resume that.
work.num--;
@@ -480,6 +491,8 @@
switchToFiber(createWorkerFiber());
}
+ work.numBlockedFibers--;
+
setFiberState(currentFiber, Fiber::State::Running);
}
@@ -552,39 +565,24 @@
}
_Requires_lock_held_(work.mutex)
-void Scheduler::Worker::flush() {
- MARL_ASSERT(mode == Mode::SingleThreaded,
- "flush() can only be used on a single-threaded worker");
- std::unique_lock<std::mutex> lock(work.mutex);
- runUntilIdle();
+void Scheduler::Worker::run() {
+ if (mode == Mode::MultiThreaded) {
+ MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
+ // This is the entry point for a multi-threaded worker.
+ // Start with a regular condition-variable wait for work. This avoids
+ // starting the thread with a spinForWork().
+ work.wait([this] { return work.num > 0 || work.waiting || shutdown; });
+ }
+ ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
+ runUntilShutdown();
+ switchToFiber(mainFiber.get());
}
_Requires_lock_held_(work.mutex)
-void Scheduler::Worker::run() {
- switch (mode) {
- case Mode::MultiThreaded: {
- MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id),
- Fiber::current()->id);
- work.wait([this] { return work.num > 0 || work.waiting || shutdown; });
- while (!shutdown || work.num > 0 || numBlockedFibers() > 0U) {
- waitForWork();
- runUntilIdle();
- }
- Worker::current = nullptr;
- switchToFiber(mainFiber.get());
- break;
- }
- case Mode::SingleThreaded: {
- ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
- while (!shutdown) {
- runUntilIdle();
- idleFibers.emplace(currentFiber);
- switchToFiber(mainFiber.get());
- }
- break;
- }
- default:
- MARL_ASSERT(false, "Unknown mode: %d", int(mode));
+void Scheduler::Worker::runUntilShutdown() {
+ while (!shutdown || work.num > 0 || work.numBlockedFibers > 0U) {
+ waitForWork();
+ runUntilIdle();
}
}
@@ -604,7 +602,7 @@
}
work.wait([this] {
- return work.num > 0 || (shutdown && numBlockedFibers() == 0U);
+ return work.num > 0 || (shutdown && work.numBlockedFibers == 0U);
});
if (work.waiting) {
enqueueFiberTimeouts();