Insert __msan_unposion for masked/scattered stores

Fixes more MSan LLVM JIT false positives.

Change-Id: I0579e2cc71b089424fe168cc60fed63c4b431f91
Bug: b/140204746
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/44708
Presubmit-Ready: James Price <jrprice@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Tested-by: James Price <jrprice@google.com>
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp
index 0d0e55b..4d232bf 100644
--- a/src/Reactor/LLVMReactor.cpp
+++ b/src/Reactor/LLVMReactor.cpp
@@ -459,10 +459,39 @@
 	auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
 	auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
 	auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
-	auto i8Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
+	auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
 	auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
 	auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_scatter, { elVecTy, elPtrVecTy });
-	jit->builder->CreateCall(func, { val, elPtrs, align, i8Mask });
+	jit->builder->CreateCall(func, { val, elPtrs, align, i1Mask });
+
+#if __has_feature(memory_sanitizer)
+	// Mark memory writes as initialized by calling __msan_unpoison
+	{
+		// void __msan_unpoison(const volatile void *a, size_t size)
+		auto voidTy = ::llvm::Type::getVoidTy(jit->context);
+		auto voidPtrTy = voidTy->getPointerTo();
+		auto sizetTy = ::llvm::IntegerType::get(jit->context, sizeof(size_t) * 8);
+		auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
+		auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
+		auto size = jit->module->getDataLayout().getTypeStoreSize(elTy);
+		for(unsigned i = 0; i < numEls; i++)
+		{
+			// Check mask for this element
+			auto idx = ::llvm::ConstantInt::get(i32Ty, i);
+			auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+			auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+			jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
+			jit->builder->SetInsertPoint(thenBlock);
+
+			// Insert __msan_unpoison call in conditional block
+			auto elPtr = jit->builder->CreateExtractElement(elPtrs, idx);
+			jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy),
+			                                 ::llvm::ConstantInt::get(sizetTy, size) });
+			jit->builder->CreateBr(mergeBlock);
+			jit->builder->SetInsertPoint(mergeBlock);
+		}
+	}
+#endif
 }
 }  // namespace
 
@@ -1165,10 +1194,39 @@
 	auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
 	auto elVecTy = V(val)->getType();
 	auto elVecPtrTy = elVecTy->getPointerTo();
-	auto i8Mask = jit->builder->CreateIntCast(V(mask), ::llvm::VectorType::get(i1Ty, numEls), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
+	auto i1Mask = jit->builder->CreateIntCast(V(mask), ::llvm::VectorType::get(i1Ty, numEls), false);  // vec<int, int, ...> -> vec<bool, bool, ...>
 	auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
 	auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_store, { elVecTy, elVecPtrTy });
-	jit->builder->CreateCall(func, { V(val), V(ptr), align, i8Mask });
+	jit->builder->CreateCall(func, { V(val), V(ptr), align, i1Mask });
+
+#if __has_feature(memory_sanitizer)
+	// Mark memory writes as initialized by calling __msan_unpoison
+	{
+		// void __msan_unpoison(const volatile void *a, size_t size)
+		auto voidTy = ::llvm::Type::getVoidTy(jit->context);
+		auto voidPtrTy = voidTy->getPointerTo();
+		auto sizetTy = ::llvm::IntegerType::get(jit->context, sizeof(size_t) * 8);
+		auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
+		auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
+		auto size = jit->module->getDataLayout().getTypeStoreSize(llvm::cast<llvm::VectorType>(elVecTy)->getElementType());
+		for(unsigned i = 0; i < numEls; i++)
+		{
+			// Check mask for this element
+			auto idx = ::llvm::ConstantInt::get(i32Ty, i);
+			auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+			auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+			jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
+			jit->builder->SetInsertPoint(thenBlock);
+
+			// Insert __msan_unpoison call in conditional block
+			auto elPtr = jit->builder->CreateGEP(V(ptr), idx);
+			jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy),
+			                                 ::llvm::ConstantInt::get(sizetTy, size) });
+			jit->builder->CreateBr(mergeBlock);
+			jit->builder->SetInsertPoint(mergeBlock);
+		}
+	}
+#endif
 }
 
 RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */)