Vulkan: Fix use-after-destruction of VkFence event

Remove sw::TaskEvents and sw::WaitGroup, replace this with sw::CountedEvent.

See b/173784261 for details.

Fixes: b/173784261
Change-Id: I21fb69c810558a1929bba5cc46f106d9d4e51c4b
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/50628
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
diff --git a/src/Device/Renderer.cpp b/src/Device/Renderer.cpp
index 262446a..44a6fe8 100644
--- a/src/Device/Renderer.cpp
+++ b/src/Device/Renderer.cpp
@@ -181,7 +181,7 @@
 }
 
 void Renderer::draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex,
-                    TaskEvents *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
+                    CountedEvent *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
                     PushConstantStorage const &pushConstants, bool update)
 {
 	if(count == 0) { return; }
@@ -421,7 +421,7 @@
 
 	if(events)
 	{
-		events->start();
+		events->add();
 	}
 }
 
@@ -429,7 +429,7 @@
 {
 	if(events)
 	{
-		events->finish();
+		events->done();
 		events = nullptr;
 	}
 
diff --git a/src/Device/Renderer.hpp b/src/Device/Renderer.hpp
index ab19a8b..57eb7ba 100644
--- a/src/Device/Renderer.hpp
+++ b/src/Device/Renderer.hpp
@@ -43,11 +43,11 @@
 
 namespace sw {
 
+class CountedEvent;
 struct DrawCall;
 class PixelShader;
 class VertexShader;
 struct Task;
-class TaskEvents;
 class Resource;
 struct Constants;
 
@@ -172,7 +172,7 @@
 	vk::ImageView *stencilBuffer;
 	vk::DescriptorSet::Array descriptorSetObjects;
 	const vk::PipelineLayout *pipelineLayout;
-	TaskEvents *events;
+	sw::CountedEvent *events;
 
 	vk::Query *occlusionQuery;
 
@@ -210,7 +210,7 @@
 	bool hasOcclusionQuery() const { return occlusionQuery != nullptr; }
 
 	void draw(const sw::Context *context, VkIndexType indexType, unsigned int count, int baseVertex,
-	          TaskEvents *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
+	          CountedEvent *events, int instanceID, int viewID, void *indexBuffer, const VkExtent3D &framebufferExtent,
 	          PushConstantStorage const &pushConstants, bool update = true);
 
 	// Viewport & Clipper
diff --git a/src/System/CMakeLists.txt b/src/System/CMakeLists.txt
index 24aae13..d28feca 100644
--- a/src/System/CMakeLists.txt
+++ b/src/System/CMakeLists.txt
@@ -68,6 +68,11 @@
         ${ROOT_PROJECT_COMPILE_OPTIONS}
 )
 
+target_link_libraries(vk_system
+    PUBLIC
+        marl
+)
+
 target_link_options(vk_system
     PUBLIC
         ${SWIFTSHADER_LINK_FLAGS}
diff --git a/src/System/Synchronization.hpp b/src/System/Synchronization.hpp
index 81794ab..8e9d63b 100644
--- a/src/System/Synchronization.hpp
+++ b/src/System/Synchronization.hpp
@@ -22,103 +22,85 @@
 #ifndef sw_Synchronization_hpp
 #define sw_Synchronization_hpp
 
+#include "Debug.hpp"
+
 #include <assert.h>
 #include <chrono>
 #include <condition_variable>
 #include <queue>
 
+#include "marl/event.h"
 #include "marl/mutex.h"
+#include "marl/waitgroup.h"
 
 namespace sw {
 
-// TaskEvents is an interface for notifying when tasks begin and end.
-// Tasks can be nested and/or overlapping.
-// TaskEvents is used for task queue synchronization.
-class TaskEvents
+// CountedEvent is an event that is signalled when the internal counter is
+// decremented and reaches zero.
+// The counter is incremented with calls to add() and decremented with calls to
+// done().
+class CountedEvent
 {
 public:
-	// start() is called before a task begins.
-	virtual void start() = 0;
-	// finish() is called after a task ends. finish() must only be called after
-	// a corresponding call to start().
-	virtual void finish() = 0;
-	// complete() is a helper for calling start() followed by finish().
-	inline void complete()
+	// Constructs the CountedEvent with the initial signalled state set to the
+	// provided value.
+	CountedEvent(bool signalled = false)
+	    : ev(marl::Event::Mode::Manual, signalled)
+	{}
+
+	// add() increments the internal counter.
+	// add() must not be called when the event is already signalled.
+	void add() const
 	{
-		start();
-		finish();
+		ASSERT(!ev.isSignalled());
+		wg.add();
 	}
 
-protected:
-	virtual ~TaskEvents() = default;
-};
-
-// WaitGroup is a synchronization primitive that allows you to wait for
-// collection of asynchronous tasks to finish executing.
-// Call add() before each task begins, and then call done() when after each task
-// is finished.
-// At the same time, wait() can be used to block until all tasks have finished.
-// WaitGroup takes its name after Golang's sync.WaitGroup.
-class WaitGroup : public TaskEvents
-{
-public:
-	// add() begins a new task.
-	void add()
+	// done() decrements the internal counter, signalling the event if the new
+	// counter value is zero.
+	// done() must not be called when the event is already signalled.
+	void done() const
 	{
-		marl::lock lock(mutex);
-		++count_;
-	}
-
-	// done() is called when a task of the WaitGroup has been completed.
-	// Returns true if there are no more tasks currently running in the
-	// WaitGroup.
-	bool done()
-	{
-		marl::lock lock(mutex);
-		assert(count_ > 0);
-		--count_;
-		if(count_ == 0)
+		ASSERT(!ev.isSignalled());
+		if(wg.done())
 		{
-			condition.notify_all();
+			ev.signal();
 		}
-		return count_ == 0;
 	}
 
-	// wait() blocks until all the tasks have been finished.
-	void wait()
+	// reset() clears the signal state.
+	// done() must not be called when the internal counter is non-zero.
+	void reset() const
 	{
-		marl::lock lock(mutex);
-		lock.wait(condition, [this]() REQUIRES(mutex) { return count_ == 0; });
+		ev.clear();
 	}
 
-	// wait() blocks until all the tasks have been finished or the timeout
-	// has been reached, returning true if all tasks have been completed, or
-	// false if the timeout has been reached.
+	// signalled() returns the current signal state.
+	bool signalled() const
+	{
+		return ev.isSignalled();
+	}
+
+	// wait() waits until the event is signalled.
+	void wait() const
+	{
+		ev.wait();
+	}
+
+	// wait() waits until the event is signalled or the timeout is reached.
+	// If the timeout was reached, then wait() return false.
 	template<class CLOCK, class DURATION>
-	bool wait(const std::chrono::time_point<CLOCK, DURATION> &timeout)
+	bool wait(const std::chrono::time_point<CLOCK, DURATION> &timeout) const
 	{
-		marl::lock lock(mutex);
-		return condition.wait_until(lock, timeout, [this]() REQUIRES(mutex) { return count_ == 0; });
+		return ev.wait_until(timeout);
 	}
 
-	// count() returns the number of times add() has been called without a call
-	// to done().
-	// Note: No lock is held after count() returns, so the count may immediately
-	// change after returning.
-	int32_t count()
-	{
-		marl::lock lock(mutex);
-		return count_;
-	}
-
-	// TaskEvents compliance
-	void start() override { add(); }
-	void finish() override { done(); }
+	// event() returns the internal marl event.
+	const marl::Event &event() { return ev; }
 
 private:
-	marl::mutex mutex;
-	int32_t count_ GUARDED_BY(mutex) = 0;
-	std::condition_variable condition;
+	const marl::WaitGroup wg;
+	const marl::Event ev;
 };
 
 // Chan is a thread-safe FIFO queue of type T.
diff --git a/src/Vulkan/VkCommandBuffer.hpp b/src/Vulkan/VkCommandBuffer.hpp
index aeaba58..e08f3ec 100644
--- a/src/Vulkan/VkCommandBuffer.hpp
+++ b/src/Vulkan/VkCommandBuffer.hpp
@@ -19,6 +19,7 @@
 #include "VkDescriptorSet.hpp"
 #include "VkObject.hpp"
 #include "Device/Context.hpp"
+#include "System/Synchronization.hpp"
 
 #include <memory>
 #include <vector>
@@ -27,7 +28,6 @@
 
 class Context;
 class Renderer;
-class TaskEvents;
 
 }  // namespace sw
 
@@ -150,7 +150,7 @@
 		};
 
 		sw::Renderer *renderer = nullptr;
-		sw::TaskEvents *events = nullptr;
+		sw::CountedEvent *events = nullptr;
 		RenderPass *renderPass = nullptr;
 		Framebuffer *renderPassFramebuffer = nullptr;
 		std::array<PipelineState, vk::VK_PIPELINE_BIND_POINT_RANGE_SIZE> pipelineState;
diff --git a/src/Vulkan/VkDevice.cpp b/src/Vulkan/VkDevice.cpp
index e673abe..7b65016 100644
--- a/src/Vulkan/VkDevice.cpp
+++ b/src/Vulkan/VkDevice.cpp
@@ -246,7 +246,7 @@
 		marl::containers::vector<marl::Event, 8> events;
 		for(uint32_t i = 0; i < fenceCount; i++)
 		{
-			events.push_back(Cast(pFences[i])->getEvent());
+			events.push_back(Cast(pFences[i])->getCountedEvent()->event());
 		}
 
 		auto any = marl::Event::any(events.begin(), events.end());
diff --git a/src/Vulkan/VkFence.hpp b/src/Vulkan/VkFence.hpp
index 086eff4..e3bfcff 100644
--- a/src/Vulkan/VkFence.hpp
+++ b/src/Vulkan/VkFence.hpp
@@ -18,17 +18,13 @@
 #include "VkObject.hpp"
 #include "System/Synchronization.hpp"
 
-#include "marl/containers.h"
-#include "marl/event.h"
-#include "marl/waitgroup.h"
-
 namespace vk {
 
-class Fence : public Object<Fence, VkFence>, public sw::TaskEvents
+class Fence : public Object<Fence, VkFence>
 {
 public:
 	Fence(const VkFenceCreateInfo *pCreateInfo, void *mem)
-	    : event(marl::Event::Mode::Manual, (pCreateInfo->flags & VK_FENCE_CREATE_SIGNALED_BIT) != 0)
+	    : counted_event(std::make_shared<sw::CountedEvent>((pCreateInfo->flags & VK_FENCE_CREATE_SIGNALED_BIT) != 0))
 	{}
 
 	static size_t ComputeRequiredAllocationSize(const VkFenceCreateInfo *pCreateInfo)
@@ -38,49 +34,38 @@
 
 	void reset()
 	{
-		event.clear();
+		counted_event->reset();
+	}
+
+	void complete()
+	{
+		counted_event->add();
+		counted_event->done();
 	}
 
 	VkResult getStatus()
 	{
-		return event.isSignalled() ? VK_SUCCESS : VK_NOT_READY;
+		return counted_event->signalled() ? VK_SUCCESS : VK_NOT_READY;
 	}
 
 	VkResult wait()
 	{
-		event.wait();
+		counted_event->wait();
 		return VK_SUCCESS;
 	}
 
 	template<class CLOCK, class DURATION>
 	VkResult wait(const std::chrono::time_point<CLOCK, DURATION> &timeout)
 	{
-		return event.wait_until(timeout) ? VK_SUCCESS : VK_TIMEOUT;
+		return counted_event->wait(timeout) ? VK_SUCCESS : VK_TIMEOUT;
 	}
 
-	const marl::Event &getEvent() const { return event; }
-
-	// TaskEvents compliance
-	void start() override
-	{
-		ASSERT(!event.isSignalled());
-		wg.add();
-	}
-
-	void finish() override
-	{
-		ASSERT(!event.isSignalled());
-		if(wg.done())
-		{
-			event.signal();
-		}
-	}
+	const std::shared_ptr<sw::CountedEvent> &getCountedEvent() const { return counted_event; };
 
 private:
 	Fence(const Fence &) = delete;
 
-	marl::WaitGroup wg;
-	const marl::Event event;
+	const std::shared_ptr<sw::CountedEvent> counted_event;
 };
 
 static inline Fence *Cast(VkFence object)
diff --git a/src/Vulkan/VkQueue.cpp b/src/Vulkan/VkQueue.cpp
index ccb8cfe..fd8bc6d 100644
--- a/src/Vulkan/VkQueue.cpp
+++ b/src/Vulkan/VkQueue.cpp
@@ -102,11 +102,10 @@
 	Task task;
 	task.submitCount = submitCount;
 	task.pSubmits = DeepCopySubmitInfo(submitCount, pSubmits);
-	task.events = fence;
-
-	if(task.events)
+	if(fence)
 	{
-		task.events->start();
+		task.events = fence->getCountedEvent();
+		task.events->add();
 	}
 
 	pending.put(task);
@@ -132,7 +131,7 @@
 		{
 			CommandBuffer::ExecutionState executionState;
 			executionState.renderer = renderer.get();
-			executionState.events = task.events;
+			executionState.events = task.events.get();
 			for(uint32_t j = 0; j < submitInfo.commandBufferCount; j++)
 			{
 				vk::Cast(submitInfo.pCommandBuffers[j])->submit(executionState);
@@ -155,7 +154,7 @@
 		// TODO: fix renderer signaling so that work submitted separately from (but before) a fence
 		// is guaranteed complete by the time the fence signals.
 		renderer->synchronize();
-		task.events->finish();
+		task.events->done();
 	}
 }
 
@@ -187,14 +186,14 @@
 VkResult Queue::waitIdle()
 {
 	// Wait for task queue to flush.
-	sw::WaitGroup wg;
-	wg.add();
+	auto event = std::make_shared<sw::CountedEvent>();
+	event->add();  // done() is called at the end of submitQueue()
 
 	Task task;
-	task.events = &wg;
+	task.events = event;
 	pending.put(task);
 
-	wg.wait();
+	event->wait();
 
 	garbageCollect();
 
diff --git a/src/Vulkan/VkQueue.hpp b/src/Vulkan/VkQueue.hpp
index a436c8a..08ff88c 100644
--- a/src/Vulkan/VkQueue.hpp
+++ b/src/Vulkan/VkQueue.hpp
@@ -65,7 +65,7 @@
 	{
 		uint32_t submitCount = 0;
 		VkSubmitInfo *pSubmits = nullptr;
-		sw::TaskEvents *events = nullptr;
+		std::shared_ptr<sw::CountedEvent> events;
 
 		enum Type
 		{
diff --git a/tests/SystemUnitTests/BUILD.gn b/tests/SystemUnitTests/BUILD.gn
index f33d4d0..b3256e0 100644
--- a/tests/SystemUnitTests/BUILD.gn
+++ b/tests/SystemUnitTests/BUILD.gn
@@ -21,12 +21,14 @@
     "//testing/gmock",
     "//testing/gtest",
     "../../src/System",
+    "../../third_party/marl:Marl",
   ]
 
   sources = [
     "//gpu/swiftshader_tests_main.cc",
     "LRUCacheTests.cpp",
     "unittests.cpp",
+    "SynchronizationTests.cpp",
   ]
 
   include_dirs = [
diff --git a/tests/SystemUnitTests/CMakeLists.txt b/tests/SystemUnitTests/CMakeLists.txt
index f6c1f1f..6871d53 100644
--- a/tests/SystemUnitTests/CMakeLists.txt
+++ b/tests/SystemUnitTests/CMakeLists.txt
@@ -26,6 +26,7 @@
     LRUCacheTests.cpp
     main.cpp
     unittests.cpp
+    SynchronizationTests.cpp
 )
 
 add_executable(system-unittests
diff --git a/tests/SystemUnitTests/SynchronizationTests.cpp b/tests/SystemUnitTests/SynchronizationTests.cpp
new file mode 100644
index 0000000..66c4f99
--- /dev/null
+++ b/tests/SystemUnitTests/SynchronizationTests.cpp
@@ -0,0 +1,93 @@
+// Copyright 2020 The SwiftShader Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "System/Synchronization.hpp"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <thread>
+
+TEST(EventCounter, ConstructUnsignalled)
+{
+	sw::CountedEvent ev;
+	ASSERT_FALSE(ev.signalled());
+}
+
+TEST(EventCounter, ConstructSignalled)
+{
+	sw::CountedEvent ev(true);
+	ASSERT_TRUE(ev.signalled());
+}
+
+TEST(EventCounter, Reset)
+{
+	sw::CountedEvent ev(true);
+	ev.reset();
+	ASSERT_FALSE(ev.signalled());
+}
+
+TEST(EventCounter, AddUnsignalled)
+{
+	sw::CountedEvent ev;
+	ev.add();
+	ASSERT_FALSE(ev.signalled());
+}
+
+TEST(EventCounter, AddDoneUnsignalled)
+{
+	sw::CountedEvent ev;
+	ev.add();
+	ev.done();
+	ASSERT_TRUE(ev.signalled());
+}
+
+TEST(EventCounter, Wait)
+{
+	sw::CountedEvent ev;
+	bool b = false;
+
+	ev.add();
+	auto t = std::thread([=, &b] {
+		b = true;
+		ev.done();
+	});
+
+	ev.wait();
+	ASSERT_TRUE(b);
+	t.join();
+}
+
+TEST(EventCounter, WaitNoTimeout)
+{
+	sw::CountedEvent ev;
+	bool b = false;
+
+	ev.add();
+	auto t = std::thread([=, &b] {
+		b = true;
+		ev.done();
+	});
+
+	ASSERT_TRUE(ev.wait(std::chrono::system_clock::now() + std::chrono::seconds(10)));
+	ASSERT_TRUE(b);
+	t.join();
+}
+
+TEST(EventCounter, WaitTimeout)
+{
+	sw::CountedEvent ev;
+	ev.add();
+	ASSERT_FALSE(ev.wait(std::chrono::system_clock::now() + std::chrono::milliseconds(1)));
+}