Emulate gather/scatter for MSan builds

MemorySanitizer doesn't support instrumenting the masked_gather and
masked_scatter LLVM intrinsics. Its visitIntrinsicInst() method ends up
calling handleUnknownIntrinsic(), which silently leaves them unhandled,
and visitInstruction() subsequently checks all operands for poisoned
bits. In the case of a scatter, a 0 bit in the mask means the
corresponding element doesn't get written, so it doesn't matter if it
holds uninitialized data. The current implementation therefore leads to
false positives.

Work around it by emulating gather and scatter as element-wise loads and
stores. This can be correctly instrumented by MemorySanitizer.

Note that this change currently has no effect, since we don't support
MSan instrumentation for Reactor yet. We just unpoison all stores.
Previously we did that in an element-wise manner after the intrinsic
executed. Now it's done as part of the element stores.

Bug: b/155148722
Change-Id: I9058cd926667fb6df5d9626bc87fb2d0a596771b
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/49809
Tested-by: Nicolas Capens <nicolascapens@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp
index 4ab185b..634c8a9 100644
--- a/src/Reactor/LLVMReactor.cpp
+++ b/src/Reactor/LLVMReactor.cpp
@@ -355,84 +355,6 @@
 	return jit->builder->CreateTrunc(mulh, ty);
 }
 
-llvm::Value *createGather(llvm::Value *base, llvm::Type *elTy, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment, bool zeroMaskedLanes)
-{
-	ASSERT(base->getType()->isPointerTy());
-	ASSERT(offsets->getType()->isVectorTy());
-	ASSERT(mask->getType()->isVectorTy());
-
-	auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
-	auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
-	auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
-	auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
-	auto i8PtrTy = i8Ty->getPointerTo();
-	auto elPtrTy = elTy->getPointerTo();
-	auto elVecTy = ::llvm::VectorType::get(elTy, numEls, false);
-	auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
-	auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
-	auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
-	auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
-	auto i8Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
-	auto passthrough = zeroMaskedLanes ? ::llvm::Constant::getNullValue(elVecTy) : llvm::UndefValue::get(elVecTy);
-	auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
-	auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_gather, { elVecTy, elPtrVecTy });
-	return jit->builder->CreateCall(func, { elPtrs, align, i8Mask, passthrough });
-}
-
-void createScatter(llvm::Value *base, llvm::Value *val, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment)
-{
-	ASSERT(base->getType()->isPointerTy());
-	ASSERT(val->getType()->isVectorTy());
-	ASSERT(offsets->getType()->isVectorTy());
-	ASSERT(mask->getType()->isVectorTy());
-
-	auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
-	auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
-	auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
-	auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
-	auto i8PtrTy = i8Ty->getPointerTo();
-	auto elVecTy = val->getType();
-	auto elTy = llvm::cast<llvm::VectorType>(elVecTy)->getElementType();
-	auto elPtrTy = elTy->getPointerTo();
-	auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
-	auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
-	auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
-	auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
-	auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
-	auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
-	auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_scatter, { elVecTy, elPtrVecTy });
-	jit->builder->CreateCall(func, { val, elPtrs, align, i1Mask });
-
-#if __has_feature(memory_sanitizer)
-	// Mark memory writes as initialized by calling __msan_unpoison
-	{
-		// void __msan_unpoison(const volatile void *a, size_t size)
-		auto voidTy = ::llvm::Type::getVoidTy(jit->context);
-		auto int8Ty = ::llvm::Type::getInt8Ty(jit->context);
-		auto int8PtrTy = int8Ty->getPointerTo();
-		auto sizetTy = ::llvm::IntegerType::get(jit->context, sizeof(size_t) * 8);
-		auto funcTy = ::llvm::FunctionType::get(voidTy, { int8PtrTy, sizetTy }, false);
-		auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
-		auto size = jit->module->getDataLayout().getTypeStoreSize(elTy);
-		for(unsigned i = 0; i < numEls; i++)
-		{
-			// Check mask for this element
-			auto idx = ::llvm::ConstantInt::get(i32Ty, i);
-			auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
-			auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
-			jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
-			jit->builder->SetInsertPoint(thenBlock);
-
-			// Insert __msan_unpoison call in conditional block
-			auto elPtr = jit->builder->CreateExtractElement(elPtrs, idx);
-			jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, int8PtrTy),
-			                                 ::llvm::ConstantInt::get(sizetTy, size) });
-			jit->builder->CreateBr(mergeBlock);
-			jit->builder->SetInsertPoint(mergeBlock);
-		}
-	}
-#endif
-}
 }  // namespace
 
 namespace rr {
@@ -1044,9 +966,9 @@
 			auto elTy = T(type);
 			ASSERT(V(ptr)->getType()->getContainedType(0) == elTy);
 
-#if __has_feature(memory_sanitizer)
-			// Mark all memory writes as initialized by calling __msan_unpoison
+			if(__has_feature(memory_sanitizer))
 			{
+				// Mark all memory writes as initialized by calling __msan_unpoison
 				// void __msan_unpoison(const volatile void *a, size_t size)
 				auto voidTy = ::llvm::Type::getVoidTy(jit->context);
 				auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
@@ -1055,10 +977,10 @@
 				auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
 				auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
 				auto size = jit->module->getDataLayout().getTypeStoreSize(elTy);
+
 				jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(V(ptr), voidPtrTy),
 				                                 ::llvm::ConstantInt::get(sizetTy, size) });
 			}
-#endif
 
 			if(!atomic)
 			{
@@ -1150,9 +1072,9 @@
 	auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_store, { elVecTy, elVecPtrTy });
 	jit->builder->CreateCall(func, { V(val), V(ptr), align, i1Mask });
 
-#if __has_feature(memory_sanitizer)
-	// Mark memory writes as initialized by calling __msan_unpoison
+	if(__has_feature(memory_sanitizer))
 	{
+		// Mark memory writes as initialized by calling __msan_unpoison
 		// void __msan_unpoison(const volatile void *a, size_t size)
 		auto voidTy = ::llvm::Type::getVoidTy(jit->context);
 		auto voidPtrTy = voidTy->getPointerTo();
@@ -1160,6 +1082,7 @@
 		auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
 		auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
 		auto size = jit->module->getDataLayout().getTypeStoreSize(llvm::cast<llvm::VectorType>(elVecTy)->getElementType());
+
 		for(unsigned i = 0; i < numEls; i++)
 		{
 			// Check mask for this element
@@ -1173,11 +1096,66 @@
 			auto elPtr = jit->builder->CreateGEP(V(ptr), idx);
 			jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy),
 			                                 ::llvm::ConstantInt::get(sizetTy, size) });
+
 			jit->builder->CreateBr(mergeBlock);
 			jit->builder->SetInsertPoint(mergeBlock);
 		}
 	}
-#endif
+}
+
+static llvm::Value *createGather(llvm::Value *base, llvm::Type *elTy, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment, bool zeroMaskedLanes)
+{
+	ASSERT(base->getType()->isPointerTy());
+	ASSERT(offsets->getType()->isVectorTy());
+	ASSERT(mask->getType()->isVectorTy());
+
+	auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
+	auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
+	auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
+	auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
+	auto i8PtrTy = i8Ty->getPointerTo();
+	auto elPtrTy = elTy->getPointerTo();
+	auto elVecTy = ::llvm::VectorType::get(elTy, numEls, false);
+	auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
+	auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
+	auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
+	auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
+	auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
+	auto passthrough = zeroMaskedLanes ? ::llvm::Constant::getNullValue(elVecTy) : llvm::UndefValue::get(elVecTy);
+
+	if(!__has_feature(memory_sanitizer))
+	{
+		auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
+		auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_gather, { elVecTy, elPtrVecTy });
+		return jit->builder->CreateCall(func, { elPtrs, align, i1Mask, passthrough });
+	}
+	else  // __has_feature(memory_sanitizer)
+	{
+		// MemorySanitizer currently does not support instrumenting llvm::Intrinsic::masked_gather
+		// Work around it by emulating gather with element-wise loads.
+		// TODO(b/172238865): Remove when supported by MemorySanitizer.
+
+		Value *result = Nucleus::allocateStackVariable(T(elVecTy));
+		Nucleus::createStore(V(passthrough), result, T(elVecTy));
+
+		for(unsigned i = 0; i < numEls; i++)
+		{
+			// Check mask for this element
+			Value *elementMask = Nucleus::createExtractElement(V(i1Mask), T(i1Ty), i);
+
+			If(RValue<Bool>(elementMask))
+			{
+				Value *elPtr = Nucleus::createExtractElement(V(elPtrs), T(elPtrTy), i);
+				Value *el = Nucleus::createLoad(elPtr, T(elTy), /*isVolatile */ false, alignment, /* atomic */ false, std::memory_order_relaxed);
+
+				Value *v = Nucleus::createLoad(result, T(elVecTy));
+				v = Nucleus::createInsertElement(v, el, i);
+				Nucleus::createStore(v, result, T(elVecTy));
+			}
+		}
+
+		return V(Nucleus::createLoad(result, T(elVecTy)));
+	}
 }
 
 RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */)
@@ -1187,7 +1165,60 @@
 
 RValue<Int4> Gather(RValue<Pointer<Int>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */)
 {
-	return As<Int4>(V(createGather(V(base.value()), T(Float::type()), V(offsets.value()), V(mask.value()), alignment, zeroMaskedLanes)));
+	return As<Int4>(V(createGather(V(base.value()), T(Int::type()), V(offsets.value()), V(mask.value()), alignment, zeroMaskedLanes)));
+}
+
+static void createScatter(llvm::Value *base, llvm::Value *val, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment)
+{
+	ASSERT(base->getType()->isPointerTy());
+	ASSERT(val->getType()->isVectorTy());
+	ASSERT(offsets->getType()->isVectorTy());
+	ASSERT(mask->getType()->isVectorTy());
+
+	auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
+	auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
+	auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
+	auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
+	auto i8PtrTy = i8Ty->getPointerTo();
+	auto elVecTy = val->getType();
+	auto elTy = llvm::cast<llvm::VectorType>(elVecTy)->getElementType();
+	auto elPtrTy = elTy->getPointerTo();
+	auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
+
+	auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
+	auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
+	auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
+	auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
+
+	if(!__has_feature(memory_sanitizer))
+	{
+		auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
+		auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_scatter, { elVecTy, elPtrVecTy });
+		jit->builder->CreateCall(func, { val, elPtrs, align, i1Mask });
+	}
+	else  // __has_feature(memory_sanitizer)
+	{
+		// MemorySanitizer currently does not support instrumenting llvm::Intrinsic::masked_scatter
+		// Work around it by emulating scatter with element-wise stores.
+		// TODO(b/172238865): Remove when supported by MemorySanitizer.
+
+		for(unsigned i = 0; i < numEls; i++)
+		{
+			// Check mask for this element
+			auto idx = ::llvm::ConstantInt::get(i32Ty, i);
+			auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+			auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+			jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
+			jit->builder->SetInsertPoint(thenBlock);
+
+			auto el = jit->builder->CreateExtractElement(val, idx);
+			auto elPtr = jit->builder->CreateExtractElement(elPtrs, idx);
+			Nucleus::createStore(V(el), V(elPtr), T(elTy), /*isVolatile */ false, alignment, /* atomic */ false, std::memory_order_relaxed);
+
+			jit->builder->CreateBr(mergeBlock);
+			jit->builder->SetInsertPoint(mergeBlock);
+		}
+	}
 }
 
 void Scatter(RValue<Pointer<Float>> base, RValue<Float4> val, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment)