Vulkan: Fix use-after-destruction of VkFence event
Remove sw::TaskEvents and sw::WaitGroup, replace this with sw::CountedEvent.
See b/173784261 for details.
Fixes: b/173784261
Change-Id: I21fb69c810558a1929bba5cc46f106d9d4e51c4b
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/50628
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
diff --git a/src/Device/Renderer.cpp b/src/Device/Renderer.cpp
index 262446a..44a6fe8 100644
--- a/src/Device/Renderer.cpp
+++ b/src/Device/Renderer.cpp
@@ -181,7 +181,7 @@
}
void Renderer::draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex,
- TaskEvents *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
+ CountedEvent *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
PushConstantStorage const &pushConstants, bool update)
{
if(count == 0) { return; }
@@ -421,7 +421,7 @@
if(events)
{
- events->start();
+ events->add();
}
}
@@ -429,7 +429,7 @@
{
if(events)
{
- events->finish();
+ events->done();
events = nullptr;
}
diff --git a/src/Device/Renderer.hpp b/src/Device/Renderer.hpp
index ab19a8b..57eb7ba 100644
--- a/src/Device/Renderer.hpp
+++ b/src/Device/Renderer.hpp
@@ -43,11 +43,11 @@
namespace sw {
+class CountedEvent;
struct DrawCall;
class PixelShader;
class VertexShader;
struct Task;
-class TaskEvents;
class Resource;
struct Constants;
@@ -172,7 +172,7 @@
vk::ImageView *stencilBuffer;
vk::DescriptorSet::Array descriptorSetObjects;
const vk::PipelineLayout *pipelineLayout;
- TaskEvents *events;
+ sw::CountedEvent *events;
vk::Query *occlusionQuery;
@@ -210,7 +210,7 @@
bool hasOcclusionQuery() const { return occlusionQuery != nullptr; }
void draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex,
- TaskEvents *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
+ CountedEvent *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
PushConstantStorage const &pushConstants, bool update = true);
// Viewport & Clipper
diff --git a/src/System/CMakeLists.txt b/src/System/CMakeLists.txt
index 24aae13..d28feca 100644
--- a/src/System/CMakeLists.txt
+++ b/src/System/CMakeLists.txt
@@ -68,6 +68,11 @@
${ROOT_PROJECT_COMPILE_OPTIONS}
)
+target_link_libraries(vk_system
+ PUBLIC
+ marl
+)
+
target_link_options(vk_system
PUBLIC
${SWIFTSHADER_LINK_FLAGS}
diff --git a/src/System/Synchronization.hpp b/src/System/Synchronization.hpp
index 81794ab..8e9d63b 100644
--- a/src/System/Synchronization.hpp
+++ b/src/System/Synchronization.hpp
@@ -22,103 +22,85 @@
#ifndef sw_Synchronization_hpp
#define sw_Synchronization_hpp
+#include "Debug.hpp"
+
#include <assert.h>
#include <chrono>
#include <condition_variable>
#include <queue>
+#include "marl/event.h"
#include "marl/mutex.h"
+#include "marl/waitgroup.h"
namespace sw {
-// TaskEvents is an interface for notifying when tasks begin and end.
-// Tasks can be nested and/or overlapping.
-// TaskEvents is used for task queue synchronization.
-class TaskEvents
+// CountedEvent is an event that is signalled when the internal counter is
+// decremented and reaches zero.
+// The counter is incremented with calls to add() and decremented with calls to
+// done().
+class CountedEvent
{
public:
- // start() is called before a task begins.
- virtual void start() = 0;
- // finish() is called after a task ends. finish() must only be called after
- // a corresponding call to start().
- virtual void finish() = 0;
- // complete() is a helper for calling start() followed by finish().
- inline void complete()
+ // Constructs the CountedEvent with the initial signalled state set to the
+ // provided value.
+ CountedEvent(bool signalled = false)
+ : ev(marl::Event::Mode::Manual, signalled)
+ {}
+
+ // add() increments the internal counter.
+ // add() must not be called when the event is already signalled.
+ void add() const
{
- start();
- finish();
+ ASSERT(!ev.isSignalled());
+ wg.add();
}
-protected:
- virtual ~TaskEvents() = default;
-};
-
-// WaitGroup is a synchronization primitive that allows you to wait for
-// collection of asynchronous tasks to finish executing.
-// Call add() before each task begins, and then call done() when after each task
-// is finished.
-// At the same time, wait() can be used to block until all tasks have finished.
-// WaitGroup takes its name after Golang's sync.WaitGroup.
-class WaitGroup : public TaskEvents
-{
-public:
- // add() begins a new task.
- void add()
+ // done() decrements the internal counter, signalling the event if the new
+ // counter value is zero.
+ // done() must not be called when the event is already signalled.
+ void done() const
{
- marl::lock lock(mutex);
- ++count_;
- }
-
- // done() is called when a task of the WaitGroup has been completed.
- // Returns true if there are no more tasks currently running in the
- // WaitGroup.
- bool done()
- {
- marl::lock lock(mutex);
- assert(count_ > 0);
- --count_;
- if(count_ == 0)
+ ASSERT(!ev.isSignalled());
+ if(wg.done())
{
- condition.notify_all();
+ ev.signal();
}
- return count_ == 0;
}
- // wait() blocks until all the tasks have been finished.
- void wait()
+ // reset() clears the signal state.
+ // done() must not be called when the internal counter is non-zero.
+ void reset() const
{
- marl::lock lock(mutex);
- lock.wait(condition, [this]() REQUIRES(mutex) { return count_ == 0; });
+ ev.clear();
}
- // wait() blocks until all the tasks have been finished or the timeout
- // has been reached, returning true if all tasks have been completed, or
- // false if the timeout has been reached.
+ // signalled() returns the current signal state.
+ bool signalled() const
+ {
+ return ev.isSignalled();
+ }
+
+ // wait() waits until the event is signalled.
+ void wait() const
+ {
+ ev.wait();
+ }
+
+ // wait() waits until the event is signalled or the timeout is reached.
+ // If the timeout was reached, then wait() return false.
template<class CLOCK, class DURATION>
- bool wait(const std::chrono::time_point<CLOCK, DURATION> &timeout)
+ bool wait(const std::chrono::time_point<CLOCK, DURATION> &timeout) const
{
- marl::lock lock(mutex);
- return condition.wait_until(lock, timeout, [this]() REQUIRES(mutex) { return count_ == 0; });
+ return ev.wait_until(timeout);
}
- // count() returns the number of times add() has been called without a call
- // to done().
- // Note: No lock is held after count() returns, so the count may immediately
- // change after returning.
- int32_t count()
- {
- marl::lock lock(mutex);
- return count_;
- }
-
- // TaskEvents compliance
- void start() override { add(); }
- void finish() override { done(); }
+ // event() returns the internal marl event.
+ const marl::Event &event() { return ev; }
private:
- marl::mutex mutex;
- int32_t count_ GUARDED_BY(mutex) = 0;
- std::condition_variable condition;
+ const marl::WaitGroup wg;
+ const marl::Event ev;
};
// Chan is a thread-safe FIFO queue of type T.
diff --git a/src/Vulkan/VkCommandBuffer.hpp b/src/Vulkan/VkCommandBuffer.hpp
index aeaba58..e08f3ec 100644
--- a/src/Vulkan/VkCommandBuffer.hpp
+++ b/src/Vulkan/VkCommandBuffer.hpp
@@ -19,6 +19,7 @@
#include "VkDescriptorSet.hpp"
#include "VkObject.hpp"
#include "Device/Context.hpp"
+#include "System/Synchronization.hpp"
#include <memory>
#include <vector>
@@ -27,7 +28,6 @@
class Context;
class Renderer;
-class TaskEvents;
} // namespace sw
@@ -150,7 +150,7 @@
};
sw::Renderer *renderer = nullptr;
- sw::TaskEvents *events = nullptr;
+ sw::CountedEvent *events = nullptr;
RenderPass *renderPass = nullptr;
Framebuffer *renderPassFramebuffer = nullptr;
std::array<PipelineState, vk::VK_PIPELINE_BIND_POINT_RANGE_SIZE> pipelineState;
diff --git a/src/Vulkan/VkDevice.cpp b/src/Vulkan/VkDevice.cpp
index e673abe..7b65016 100644
--- a/src/Vulkan/VkDevice.cpp
+++ b/src/Vulkan/VkDevice.cpp
@@ -246,7 +246,7 @@
marl::containers::vector<marl::Event, 8> events;
for(uint32_t i = 0; i < fenceCount; i++)
{
- events.push_back(Cast(pFences[i])->getEvent());
+ events.push_back(Cast(pFences[i])->getCountedEvent()->event());
}
auto any = marl::Event::any(events.begin(), events.end());
diff --git a/src/Vulkan/VkFence.hpp b/src/Vulkan/VkFence.hpp
index 086eff4..e3bfcff 100644
--- a/src/Vulkan/VkFence.hpp
+++ b/src/Vulkan/VkFence.hpp
@@ -18,17 +18,13 @@
#include "VkObject.hpp"
#include "System/Synchronization.hpp"
-#include "marl/containers.h"
-#include "marl/event.h"
-#include "marl/waitgroup.h"
-
namespace vk {
-class Fence : public Object<Fence, VkFence>, public sw::TaskEvents
+class Fence : public Object<Fence, VkFence>
{
public:
Fence(const VkFenceCreateInfo *pCreateInfo, void *mem)
- : event(marl::Event::Mode::Manual, (pCreateInfo->flags & VK_FENCE_CREATE_SIGNALED_BIT) != 0)
+ : counted_event(std::make_shared<sw::CountedEvent>((pCreateInfo->flags & VK_FENCE_CREATE_SIGNALED_BIT) != 0))
{}
static size_t ComputeRequiredAllocationSize(const VkFenceCreateInfo *pCreateInfo)
@@ -38,49 +34,38 @@
void reset()
{
- event.clear();
+ counted_event->reset();
+ }
+
+ void complete()
+ {
+ counted_event->add();
+ counted_event->done();
}
VkResult getStatus()
{
- return event.isSignalled() ? VK_SUCCESS : VK_NOT_READY;
+ return counted_event->signalled() ? VK_SUCCESS : VK_NOT_READY;
}
VkResult wait()
{
- event.wait();
+ counted_event->wait();
return VK_SUCCESS;
}
template<class CLOCK, class DURATION>
VkResult wait(const std::chrono::time_point<CLOCK, DURATION> &timeout)
{
- return event.wait_until(timeout) ? VK_SUCCESS : VK_TIMEOUT;
+ return counted_event->wait(timeout) ? VK_SUCCESS : VK_TIMEOUT;
}
- const marl::Event &getEvent() const { return event; }
-
- // TaskEvents compliance
- void start() override
- {
- ASSERT(!event.isSignalled());
- wg.add();
- }
-
- void finish() override
- {
- ASSERT(!event.isSignalled());
- if(wg.done())
- {
- event.signal();
- }
- }
+ const std::shared_ptr<sw::CountedEvent> &getCountedEvent() const { return counted_event; };
private:
Fence(const Fence &) = delete;
- marl::WaitGroup wg;
- const marl::Event event;
+ const std::shared_ptr<sw::CountedEvent> counted_event;
};
static inline Fence *Cast(VkFence object)
diff --git a/src/Vulkan/VkQueue.cpp b/src/Vulkan/VkQueue.cpp
index ccb8cfe..fd8bc6d 100644
--- a/src/Vulkan/VkQueue.cpp
+++ b/src/Vulkan/VkQueue.cpp
@@ -102,11 +102,10 @@
Task task;
task.submitCount = submitCount;
task.pSubmits = DeepCopySubmitInfo(submitCount, pSubmits);
- task.events = fence;
-
- if(task.events)
+ if(fence)
{
- task.events->start();
+ task.events = fence->getCountedEvent();
+ task.events->add();
}
pending.put(task);
@@ -132,7 +131,7 @@
{
CommandBuffer::ExecutionState executionState;
executionState.renderer = renderer.get();
- executionState.events = task.events;
+ executionState.events = task.events.get();
for(uint32_t j = 0; j < submitInfo.commandBufferCount; j++)
{
vk::Cast(submitInfo.pCommandBuffers[j])->submit(executionState);
@@ -155,7 +154,7 @@
// TODO: fix renderer signaling so that work submitted separately from (but before) a fence
// is guaranteed complete by the time the fence signals.
renderer->synchronize();
- task.events->finish();
+ task.events->done();
}
}
@@ -187,14 +186,14 @@
VkResult Queue::waitIdle()
{
// Wait for task queue to flush.
- sw::WaitGroup wg;
- wg.add();
+ auto event = std::make_shared<sw::CountedEvent>();
+ event->add(); // done() is called at the end of submitQueue()
Task task;
- task.events = &wg;
+ task.events = event;
pending.put(task);
- wg.wait();
+ event->wait();
garbageCollect();
diff --git a/src/Vulkan/VkQueue.hpp b/src/Vulkan/VkQueue.hpp
index a436c8a..08ff88c 100644
--- a/src/Vulkan/VkQueue.hpp
+++ b/src/Vulkan/VkQueue.hpp
@@ -65,7 +65,7 @@
{
uint32_t submitCount = 0;
VkSubmitInfo *pSubmits = nullptr;
- sw::TaskEvents *events = nullptr;
+ std::shared_ptr<sw::CountedEvent> events;
enum Type
{
diff --git a/tests/SystemUnitTests/BUILD.gn b/tests/SystemUnitTests/BUILD.gn
index f33d4d0..b3256e0 100644
--- a/tests/SystemUnitTests/BUILD.gn
+++ b/tests/SystemUnitTests/BUILD.gn
@@ -21,12 +21,14 @@
"//testing/gmock",
"//testing/gtest",
"../../src/System",
+ "../../third_party/marl:Marl",
]
sources = [
"//gpu/swiftshader_tests_main.cc",
"LRUCacheTests.cpp",
"unittests.cpp",
+ "SynchronizationTests.cpp",
]
include_dirs = [
diff --git a/tests/SystemUnitTests/CMakeLists.txt b/tests/SystemUnitTests/CMakeLists.txt
index f6c1f1f..6871d53 100644
--- a/tests/SystemUnitTests/CMakeLists.txt
+++ b/tests/SystemUnitTests/CMakeLists.txt
@@ -26,6 +26,7 @@
LRUCacheTests.cpp
main.cpp
unittests.cpp
+ SynchronizationTests.cpp
)
add_executable(system-unittests
diff --git a/tests/SystemUnitTests/SynchronizationTests.cpp b/tests/SystemUnitTests/SynchronizationTests.cpp
new file mode 100644
index 0000000..66c4f99
--- /dev/null
+++ b/tests/SystemUnitTests/SynchronizationTests.cpp
@@ -0,0 +1,93 @@
+// Copyright 2020 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "System/Synchronization.hpp"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <thread>
+
+TEST(EventCounter, ConstructUnsignalled)
+{
+ sw::CountedEvent ev;
+ ASSERT_FALSE(ev.signalled());
+}
+
+TEST(EventCounter, ConstructSignalled)
+{
+ sw::CountedEvent ev(true);
+ ASSERT_TRUE(ev.signalled());
+}
+
+TEST(EventCounter, Reset)
+{
+ sw::CountedEvent ev(true);
+ ev.reset();
+ ASSERT_FALSE(ev.signalled());
+}
+
+TEST(EventCounter, AddUnsignalled)
+{
+ sw::CountedEvent ev;
+ ev.add();
+ ASSERT_FALSE(ev.signalled());
+}
+
+TEST(EventCounter, AddDoneUnsignalled)
+{
+ sw::CountedEvent ev;
+ ev.add();
+ ev.done();
+ ASSERT_TRUE(ev.signalled());
+}
+
+TEST(EventCounter, Wait)
+{
+ sw::CountedEvent ev;
+ bool b = false;
+
+ ev.add();
+ auto t = std::thread([=, &b] {
+ b = true;
+ ev.done();
+ });
+
+ ev.wait();
+ ASSERT_TRUE(b);
+ t.join();
+}
+
+TEST(EventCounter, WaitNoTimeout)
+{
+ sw::CountedEvent ev;
+ bool b = false;
+
+ ev.add();
+ auto t = std::thread([=, &b] {
+ b = true;
+ ev.done();
+ });
+
+ ASSERT_TRUE(ev.wait(std::chrono::system_clock::now() + std::chrono::seconds(10)));
+ ASSERT_TRUE(b);
+ t.join();
+}
+
+TEST(EventCounter, WaitTimeout)
+{
+ sw::CountedEvent ev;
+ ev.add();
+ ASSERT_FALSE(ev.wait(std::chrono::system_clock::now() + std::chrono::milliseconds(1)));
+}