// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "OSFiber.hpp" // Must come first. See OSFiber_ucontext.hpp.

#include "Scheduler.hpp"

#include "Debug.hpp"
#include "Defer.hpp"
#include "Thread.hpp"
#include "Trace.hpp"

#if defined(_WIN32)
#include <intrin.h> // __nop()
#endif

// Compile-time switch for scheduler event tracing.
// Set to a non-zero value to make TRACE() emit scoped trace events.
#define ENABLE_TRACE_EVENTS 0

// TRACE(...) forwards its arguments to YARN_SCOPED_EVENT() when
// ENABLE_TRACE_EVENTS is non-zero, and expands to nothing otherwise.
#if ENABLE_TRACE_EVENTS
#define TRACE(...) YARN_SCOPED_EVENT(__VA_ARGS__)
#else
#define TRACE(...)
#endif

namespace
{

// Removes the element at the front of queue and returns it.
// The element is moved out of the queue instead of being copied.
// No emptiness check is performed; callers are expected to pass a non-empty
// queue.
template <typename T>
inline T take(std::queue<T>& queue)
{
    T frontItem = std::move(queue.front());
    queue.pop();
    return frontItem;
}

// Executes a single no-op instruction: inline assembly on non-Windows
// targets, the __nop() compiler intrinsic on Windows.
inline void nop()
{
#if !defined(_WIN32)
    __asm__ __volatile__ ("nop");
#else
    __nop();
#endif
}

} // anonymous namespace

namespace yarn {

////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
// The scheduler bound to the calling thread, or nullptr if none is bound.
// Each thread has its own copy (thread_local).
thread_local Scheduler* Scheduler::bound = nullptr;

// Returns the scheduler bound to the calling thread, or nullptr if no
// scheduler is currently bound to it.
Scheduler* Scheduler::get()
{
    return bound;
}

void Scheduler::bind()
{
    YARN_ASSERT(bound == nullptr, "Scheduler already bound");
    bound = this;
    {
        std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
        auto worker = std::unique_ptr<Worker>(new Worker(this, Worker::Mode::SingleThreaded, 0));
        worker->start();
        auto tid = std::this_thread::get_id();
        singleThreadedWorkers.emplace(tid, std::move(worker));
    }
}

// Unbinds the scheduler currently bound to the calling thread.
// The thread's single-threaded worker is removed from the scheduler, flushed
// of any pending work, and stopped. Asserts that a scheduler is bound.
void Scheduler::unbind()
{
    YARN_ASSERT(bound != nullptr, "No scheduler bound");

    // Take ownership of this thread's worker while holding the map's mutex;
    // flushing and stopping happen after the mutex has been released.
    std::unique_ptr<Worker> detached;
    {
        std::lock_guard<std::mutex> guard(bound->singleThreadedWorkerMutex);
        auto& workers = bound->singleThreadedWorkers;
        auto entry = workers.find(std::this_thread::get_id());
        YARN_ASSERT(entry != workers.end(), "singleThreadedWorker not found");
        detached = std::move(entry->second);
        workers.erase(entry);
    }

    detached->flush();
    detached->stop();
    bound = nullptr;
}

// Constructs a scheduler with the round-robin enqueue index at zero and every
// spinning-worker slot set to -1 (the value enqueue() treats as "no spinning
// worker in this slot").
Scheduler::Scheduler()
    : nextEnqueueIndex(0)
{
    const size_t slotCount = spinningWorkers.size();
    for (size_t slot = 0; slot < slotCount; ++slot)
    {
        spinningWorkers[slot] = -1;
    }
}

// Destroys the scheduler. Asserts that no thread still has it bound, then
// stops and deletes all worker threads by setting the worker count to zero.
Scheduler::~Scheduler()
{
    {
        std::lock_guard<std::mutex> guard(singleThreadedWorkerMutex);
        YARN_ASSERT(singleThreadedWorkers.empty(), "Scheduler still bound on %d threads", int(singleThreadedWorkers.size()));
    }

    setWorkerThreadCount(0);
}

// Stores func as the thread initializer. Multi-threaded workers call the
// stored function, if it is non-empty, at the start of their thread.
// The assignment is performed while holding threadInitFuncMutex.
void Scheduler::setThreadInitializer(const std::function<void()>& func)
{
    std::lock_guard<std::mutex> guard(threadInitFuncMutex);
    threadInitFunc = func;
}

// Returns a reference to the stored thread initializer function.
// NOTE(review): threadInitFuncMutex is only held until this function returns,
// so the returned reference is read by the caller without the lock. Returning
// by value would close that gap but requires changing the declaration.
const std::function<void()>& Scheduler::getThreadInitializer()
{
    std::lock_guard<std::mutex> guard(threadInitFuncMutex);
    return threadInitFunc;
}

// Grows or shrinks the set of multi-threaded workers to newCount.
// When shrinking, the surplus workers (highest indices first) are all stopped
// before any of them is deleted. When growing, the new workers are all
// constructed and the count is published before any of them is started.
// NOTE(review): newCount is not checked against the capacity of
// workerThreads here - confirm callers respect that limit.
void Scheduler::setWorkerThreadCount(int newCount)
{
    YARN_ASSERT(newCount >= 0, "count must be positive");
    const auto previousCount = numWorkerThreads;

    // Shrink: stop every surplus worker...
    int idx = previousCount;
    while (idx > newCount)
    {
        workerThreads[--idx]->stop();
    }

    // ...and only then release them.
    idx = previousCount;
    while (idx > newCount)
    {
        delete workerThreads[--idx];
    }

    // Grow: construct the additional workers.
    for (idx = previousCount; idx < newCount; ++idx)
    {
        workerThreads[idx] = new Worker(this, Worker::Mode::MultiThreaded, idx);
    }

    numWorkerThreads = newCount;

    // Start the additional workers once the new count is in place.
    for (idx = previousCount; idx < newCount; ++idx)
    {
        workerThreads[idx]->start();
    }
}

// Returns the number of multi-threaded workers, as last set by
// setWorkerThreadCount().
int Scheduler::getWorkerThreadCount()
{
    return numWorkerThreads;
}

// Queues task for execution.
// With worker threads available, the task is handed to a worker that recently
// started spinning if one is recorded, otherwise to the next worker in
// round-robin order; workers whose queue lock is contended are skipped.
// With no worker threads, the task is queued on the single-threaded worker
// registered for the calling thread (asserts that one exists).
void Scheduler::enqueue(Task&& task)
{
    const bool hasWorkerThreads = numWorkerThreads > 0;
    if (!hasWorkerThreads)
    {
        auto self = std::this_thread::get_id();
        std::lock_guard<std::mutex> guard(singleThreadedWorkerMutex);
        auto entry = singleThreadedWorkers.find(self);
        YARN_ASSERT(entry != singleThreadedWorkers.end(), "singleThreadedWorker not found");
        entry->second->enqueue(std::move(task));
        return;
    }

    for (;;)
    {
        // Look at the most recently recorded spinning-worker slot first, and
        // clear it so the same worker is not picked twice.
        auto slot = --nextSpinningWorkerIdx % spinningWorkers.size();
        auto target = spinningWorkers[slot].exchange(-1);
        if (target < 0)
        {
            // The slot was empty: fall back to round-robin selection.
            target = nextEnqueueIndex++ % numWorkerThreads;
        }

        auto candidate = workerThreads[target];
        if (!candidate->tryLock())
        {
            // Queue lock is busy; try another worker.
            continue;
        }

        candidate->enqueueAndUnlock(std::move(task));
        return;
    }
}

// Attempts to take one task from the worker selected by from (modulo the
// number of worker threads) on behalf of thief.
// Returns true and stores the task in out on success. Returns false if there
// are no worker threads, if the selected worker is thief itself, or if that
// worker's dequeue() yields nothing.
bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out)
{
    const bool hasWorkerThreads = numWorkerThreads > 0;
    if (!hasWorkerThreads)
    {
        return false;
    }

    auto victim = workerThreads[from % numWorkerThreads];
    if (victim == thief)
    {
        return false;
    }

    return victim->dequeue(out);
}

// Records that the worker with the given id has started spinning for work, by
// writing the id into the next spinning-worker slot. enqueue() consults these
// slots to prefer workers that recently began spinning.
void Scheduler::onBeginSpinning(int workerId)
{
    const auto slot = nextSpinningWorkerIdx++ % spinningWorkers.size();
    spinningWorkers[slot] = workerId;
}

////////////////////////////////////////////////////////////////////////////////
// Fiber
////////////////////////////////////////////////////////////////////////////////
// Constructs a fiber with the given id that wraps the OS fiber impl.
// The fiber is associated with the worker that is current on the constructing
// thread; asserts that such a worker exists.
Scheduler::Fiber::Fiber(OSFiber* impl, uint32_t id) :
    id(id), impl(impl), worker(Scheduler::Worker::getCurrent())
{
    YARN_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}

// Destroys the fiber and deletes the OS fiber it wraps.
Scheduler::Fiber::~Fiber()
{
    delete impl;
}

// Returns the fiber currently executing on the calling thread's worker, or
// nullptr if the calling thread has no current worker.
Scheduler::Fiber* Scheduler::Fiber::current()
{
    auto boundWorker = Scheduler::Worker::getCurrent();
    if (boundWorker == nullptr)
    {
        return nullptr;
    }
    return boundWorker->getCurrentFiber();
}

// Queues this fiber on the worker it was created on, so that the worker will
// switch back to it.
void Scheduler::Fiber::schedule()
{
    worker->enqueue(this);
}

// Suspends this fiber and lets its worker run something else.
// Must be called on the worker's currently executing fiber (Worker::yield()
// asserts this). Emits a "YIELD" scoped trace event.
void Scheduler::Fiber::yield()
{
    YARN_SCOPED_EVENT("YIELD");
    worker->yield(this);
}

// Transfers execution from this fiber to the fiber to.
// Does nothing when to is this fiber.
void Scheduler::Fiber::switchTo(Fiber* to)
{
    if (to == this)
    {
        return;
    }
    impl->switchTo(to->impl);
}

// Allocates a new fiber with the given id, backed by a new OS fiber that has
// a stack of stackSize and runs func. The caller owns the returned fiber.
Scheduler::Fiber* Scheduler::Fiber::create(uint32_t id, size_t stackSize, const std::function<void()>& func)
{
    OSFiber* osFiber = OSFiber::createFiber(stackSize, func);
    return new Fiber(osFiber, id);
}

// Allocates a new fiber with the given id that wraps an OS fiber created from
// the calling thread. The caller owns the returned fiber.
Scheduler::Fiber* Scheduler::Fiber::createFromCurrentThread(uint32_t id)
{
    OSFiber* osFiber = OSFiber::createFiberFromCurrentThread();
    return new Fiber(osFiber, id);
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
// The worker associated with the calling thread, or nullptr if there is none.
// Set by Worker::start() and cleared again when the worker stops.
thread_local Scheduler::Worker* Scheduler::Worker::current = nullptr;

// Constructs a worker for scheduler with the given mode and id.
// The worker does nothing until start() is called.
Scheduler::Worker::Worker(Scheduler *scheduler, Mode mode, uint32_t id) : id(id), mode(mode), scheduler(scheduler) {}

// Starts the worker.
// A multi-threaded worker spawns its own thread, which names itself, runs the
// scheduler's thread initializer (if set), binds the scheduler and this
// worker to the thread, and then processes work in run() until shut down.
// A single-threaded worker attaches itself to the calling thread and wraps
// that thread in its main fiber; it does not process work here.
void Scheduler::Worker::start()
{
    if (mode == Mode::MultiThreaded)
    {
        thread = std::thread([this]
        {
            Thread::setName("Thread<%.2d>", int(id));

            const std::function<void()>& initializer = scheduler->getThreadInitializer();
            if (initializer)
            {
                initializer();
            }

            Scheduler::bound = scheduler;
            Worker::current = this;
            mainFiber.reset(Fiber::createFromCurrentThread(0));
            currentFiber = mainFiber.get();
            run();
            mainFiber.reset();
            Worker::current = nullptr;
        });
    }
    else if (mode == Mode::SingleThreaded)
    {
        Worker::current = this;
        mainFiber.reset(Fiber::createFromCurrentThread(0));
        currentFiber = mainFiber.get();
    }
    else
    {
        YARN_ASSERT(false, "Unknown mode: %d", int(mode));
    }
}

// Stops the worker.
// A multi-threaded worker is flagged for shutdown, woken with an empty task
// so that it observes the flag, and then joined. A single-threaded worker is
// simply detached from the calling thread.
void Scheduler::Worker::stop()
{
    if (mode == Mode::MultiThreaded)
    {
        shutdown = true;
        enqueue([]{}); // Ensure the worker is woken up to notice the shutdown.
        thread.join();
    }
    else if (mode == Mode::SingleThreaded)
    {
        Worker::current = nullptr;
    }
    else
    {
        YARN_ASSERT(false, "Unknown mode: %d", int(mode));
    }
}

// Suspends the fiber from (which must be the worker's current fiber) because
// it is blocked, and switches to something else once there is work:
//  1. a fiber that has been re-queued on this worker, else
//  2. an idle fiber that can be reused, else
//  3. a newly created worker fiber.
// Returns when from is switched back to.
void Scheduler::Worker::yield(Fiber *from)
{
    YARN_ASSERT(currentFiber == from, "Attempting to call yield from a non-current fiber");

    // Block until this worker has something else it can do.
    std::unique_lock<std::mutex> lock(work.mutex);
    waitForWork(lock);

    // Decide what to resume while the queues are still protected by the lock.
    Fiber* next = nullptr;
    bool spawnNewFiber = false;
    if (!work.fibers.empty())
    {
        // Another fiber has become unblocked.
        work.num--;
        next = take(work.fibers);
    }
    else if (!idleFibers.empty())
    {
        // An old fiber is available for reuse.
        next = take(idleFibers);
    }
    else
    {
        // Tasks to process, but no existing fiber to run them on.
        spawnNewFiber = true;
    }

    // The lock must be released before leaving this fiber; a new fiber, if
    // needed, is created only after the lock is released.
    lock.unlock();
    switchToFiber(spawnNewFiber ? createWorkerFiber() : next);
}

// Attempts to acquire the work queue mutex without blocking.
// Returns true if the mutex was acquired, in which case the caller is
// responsible for releasing it (enqueueAndUnlock() does so).
bool Scheduler::Worker::tryLock()
{
    return work.mutex.try_lock();
}

// Queues fiber to be resumed by this worker.
// If the worker had no pending work beforehand, it is notified after the
// queue lock has been released.
void Scheduler::Worker::enqueue(Fiber* fiber)
{
    bool queueWasEmpty = false;
    {
        std::lock_guard<std::mutex> guard(work.mutex);
        queueWasEmpty = (work.num == 0);
        work.fibers.push(fiber);
        work.num++;
    }

    if (queueWasEmpty)
    {
        work.added.notify_one();
    }
}

// Queues task on this worker, blocking until the work queue mutex can be
// acquired. The mutex is released by enqueueAndUnlock().
void Scheduler::Worker::enqueue(Task&& task)
{
    work.mutex.lock();
    enqueueAndUnlock(std::move(task));
}

// Queues task on this worker and releases the work queue mutex.
// The caller must already hold work.mutex (via tryLock() or an explicit
// lock). If the worker had no pending work beforehand, it is notified after
// the mutex has been released.
void Scheduler::Worker::enqueueAndUnlock(Task&& task)
{
    const bool queueWasEmpty = (work.num == 0);

    work.tasks.push(std::move(task));
    work.num++;
    work.mutex.unlock();

    if (queueWasEmpty)
    {
        work.added.notify_one();
    }
}

// Attempts to remove one task from this worker's queue without blocking.
// Returns true and stores the task in out on success. Returns false if the
// worker has no pending work, if the queue mutex is currently held by someone
// else, or if the pending work consists only of fibers.
bool Scheduler::Worker::dequeue(Task& out)
{
    // Check the work counter before touching the mutex, then only take the
    // mutex if that can be done without waiting.
    if (work.num.load() == 0 || !work.mutex.try_lock())
    {
        return false;
    }
    defer(work.mutex.unlock());

    if (work.tasks.empty())
    {
        return false;
    }

    work.num--;
    out = take(work.tasks);
    return true;
}

// Runs all queued fibers and tasks on the calling thread until this worker's
// queues are empty. Only valid for single-threaded workers (asserted).
void Scheduler::Worker::flush()
{
    YARN_ASSERT(mode == Mode::SingleThreaded, "flush() can only be used on a single-threaded worker");
    std::unique_lock<std::mutex> lock(work.mutex);
    runUntilIdle(lock);
}

// Work-processing loop. This is the body of every fiber made by
// createWorkerFiber(), and is also called directly on a multi-threaded
// worker's thread from start().
void Scheduler::Worker::run()
{
    switch (mode)
    {
    case Mode::MultiThreaded:
    {
        YARN_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
        {
            // The lock is scoped so that it is released before switching
            // fibers below.
            std::unique_lock<std::mutex> lock(work.mutex);
            // Sleep until the first piece of work arrives or shutdown is
            // requested.
            work.added.wait(lock, [this] { return work.num > 0 || shutdown; });
            while (!shutdown)
            {
                waitForWork(lock);
                runUntilIdle(lock);
            }
            Worker::current = nullptr;
        }
        // Shutting down: hand control back to the thread's main fiber.
        switchToFiber(mainFiber.get());
        break;
    }
    case Mode::SingleThreaded:
        while (!shutdown)
        {
            // Drain all pending work, then park this fiber as idle and return
            // control to the main fiber.
            flush();
            idleFibers.emplace(currentFiber);
            switchToFiber(mainFiber.get());
        }
        break;

    default:
        YARN_ASSERT(false, "Unknown mode: %d", int(mode));
    }
}

// Blocks until this worker has pending work or has been asked to shut down.
// lock must hold work.mutex on entry and holds it again on return. If there
// is no work on entry, the worker first registers itself as spinning and
// spins for work with the lock released, before sleeping on the condition
// variable.
_Requires_lock_held_(lock)
void Scheduler::Worker::waitForWork(std::unique_lock<std::mutex> &lock)
{
    YARN_ASSERT(work.num == work.fibers.size() + work.tasks.size(), "work.num out of sync");

    const bool nothingQueued = (work.num == 0);
    if (nothingQueued)
    {
        scheduler->onBeginSpinning(id);

        // Spin without the lock so that other threads can enqueue onto us.
        lock.unlock();
        spinForWork();
        lock.lock();
    }

    auto workOrShutdown = [this] { return work.num > 0 || shutdown; };
    work.added.wait(lock, workOrShutdown);
}

// Busy-waits for up to about one millisecond for work to arrive, instead of
// immediately sleeping on the condition variable.
// Returns early as soon as work.num becomes non-zero, or once a task has been
// stolen from another worker (the stolen task is queued on this worker).
// Called by waitForWork() with work.mutex released.
void Scheduler::Worker::spinForWork()
{
    TRACE("SPIN");
    Task stolen;

    // Measure the spin window with a monotonic clock.
    // std::chrono::high_resolution_clock may be an alias of the adjustable
    // system clock, in which case a wall-clock change during the spin could
    // stretch or cut short the window.
    using SpinClock = std::chrono::steady_clock;
    constexpr auto duration = std::chrono::milliseconds(1);
    const auto start = SpinClock::now();
    while (SpinClock::now() - start < duration)
    {
        for (int i = 0; i < 256; i++) // Empirically picked magic number!
        {
            nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
            nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
            nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
            nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
            if (work.num > 0)
            {
                // Work was enqueued on this worker while spinning.
                return;
            }
        }

        // Nothing arrived: try to take a task from the worker selected by
        // the value rng() produces.
        if (scheduler->stealWork(this, rng(), stolen))
        {
            std::unique_lock<std::mutex> lock(work.mutex);
            work.tasks.emplace(std::move(stolen));
            work.num++;
            return;
        }

        // Give other threads a chance to run before spinning again.
        std::this_thread::yield();
    }
}

// Runs queued fibers and tasks until both queues are empty.
// lock must hold work.mutex on entry and holds it again on return. The lock
// is released while switching to a fiber or executing a task, and re-acquired
// afterwards.
_Requires_lock_held_(lock)
void Scheduler::Worker::runUntilIdle(std::unique_lock<std::mutex> &lock)
{
    YARN_ASSERT(work.num == work.fibers.size() + work.tasks.size(), "work.num out of sync");
    while (work.fibers.size() > 0 || work.tasks.size() > 0)
    {
        // Note: we cannot take and store on the stack more than a single fiber
        // or task at a time, as the Fiber may yield and these items may get
        // held on suspended fiber stack.

        // Resume every queued fiber before starting a new task.
        while (work.fibers.size() > 0)
        {
            work.num--;
            auto fiber = take(work.fibers);
            lock.unlock();
            // Park the current fiber as idle so that it can be picked up
            // again later, then switch to the queued fiber.
            idleFibers.push(currentFiber);
            switchToFiber(fiber);
            lock.lock();
        }

        // Run at most one task per iteration, then re-check the fiber queue.
        if (work.tasks.size() > 0)
        {
            work.num--;
            auto task = take(work.tasks);
            lock.unlock();

            // Run the task.
            task();

            // std::function<> can carry arguments with complex destructors.
            // Ensure these are destructed outside of the lock.
            task = Task();

            lock.lock();
        }
    }
}

// Creates a new fiber that executes this worker's run() loop.
// The fiber's id is one greater than the number of worker fibers created so
// far. Ownership is held by workerFibers; the returned pointer is non-owning.
Scheduler::Fiber* Scheduler::Worker::createWorkerFiber()
{
    auto fiberId = workerFibers.size() + 1;
    auto created = Fiber::create(fiberId, FiberStackSize, [this] { run(); });
    workerFibers.push_back(std::unique_ptr<Fiber>(created));
    return created;
}

// Makes to the worker's current fiber and transfers execution to it from the
// fiber that was current until now.
void Scheduler::Worker::switchToFiber(Fiber* to)
{
    // currentFiber must be updated before the switch, as execution continues
    // on the other fiber from that point on.
    Fiber* const previous = currentFiber;
    currentFiber = to;
    previous->switchTo(to);
}

} // namespace yarn
