Insert __msan_unposion for masked/scattered stores
Fixes more MSan LLVM JIT false positives.
Change-Id: I0579e2cc71b089424fe168cc60fed63c4b431f91
Bug: b/140204746
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/44708
Presubmit-Ready: James Price <jrprice@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Tested-by: James Price <jrprice@google.com>
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp
index 0d0e55b..4d232bf 100644
--- a/src/Reactor/LLVMReactor.cpp
+++ b/src/Reactor/LLVMReactor.cpp
@@ -459,10 +459,39 @@
auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
- auto i8Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls), false); // vec<int, int, ...> -> vec<bool, bool, ...>
+ auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls), false); // vec<int, int, ...> -> vec<bool, bool, ...>
auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_scatter, { elVecTy, elPtrVecTy });
- jit->builder->CreateCall(func, { val, elPtrs, align, i8Mask });
+ jit->builder->CreateCall(func, { val, elPtrs, align, i1Mask });
+
+#if __has_feature(memory_sanitizer)
+ // Mark memory writes as initialized by calling __msan_unpoison
+ {
+ // void __msan_unpoison(const volatile void *a, size_t size)
+ auto voidTy = ::llvm::Type::getVoidTy(jit->context);
+ auto voidPtrTy = voidTy->getPointerTo();
+ auto sizetTy = ::llvm::IntegerType::get(jit->context, sizeof(size_t) * 8);
+ auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
+ auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
+ auto size = jit->module->getDataLayout().getTypeStoreSize(elTy);
+ for(unsigned i = 0; i < numEls; i++)
+ {
+ // Check mask for this element
+ auto idx = ::llvm::ConstantInt::get(i32Ty, i);
+ auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+ auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+ jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
+ jit->builder->SetInsertPoint(thenBlock);
+
+ // Insert __msan_unpoison call in conditional block
+ auto elPtr = jit->builder->CreateExtractElement(elPtrs, idx);
+ jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy),
+ ::llvm::ConstantInt::get(sizetTy, size) });
+ jit->builder->CreateBr(mergeBlock);
+ jit->builder->SetInsertPoint(mergeBlock);
+ }
+ }
+#endif
}
} // namespace
@@ -1165,10 +1194,39 @@
auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
auto elVecTy = V(val)->getType();
auto elVecPtrTy = elVecTy->getPointerTo();
- auto i8Mask = jit->builder->CreateIntCast(V(mask), ::llvm::VectorType::get(i1Ty, numEls), false); // vec<int, int, ...> -> vec<bool, bool, ...>
+ auto i1Mask = jit->builder->CreateIntCast(V(mask), ::llvm::VectorType::get(i1Ty, numEls), false); // vec<int, int, ...> -> vec<bool, bool, ...>
auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_store, { elVecTy, elVecPtrTy });
- jit->builder->CreateCall(func, { V(val), V(ptr), align, i8Mask });
+ jit->builder->CreateCall(func, { V(val), V(ptr), align, i1Mask });
+
+#if __has_feature(memory_sanitizer)
+ // Mark memory writes as initialized by calling __msan_unpoison
+ {
+ // void __msan_unpoison(const volatile void *a, size_t size)
+ auto voidTy = ::llvm::Type::getVoidTy(jit->context);
+ auto voidPtrTy = voidTy->getPointerTo();
+ auto sizetTy = ::llvm::IntegerType::get(jit->context, sizeof(size_t) * 8);
+ auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
+ auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
+ auto size = jit->module->getDataLayout().getTypeStoreSize(llvm::cast<llvm::VectorType>(elVecTy)->getElementType());
+ for(unsigned i = 0; i < numEls; i++)
+ {
+ // Check mask for this element
+ auto idx = ::llvm::ConstantInt::get(i32Ty, i);
+ auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+ auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
+ jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
+ jit->builder->SetInsertPoint(thenBlock);
+
+ // Insert __msan_unpoison call in conditional block
+ auto elPtr = jit->builder->CreateGEP(V(ptr), idx);
+ jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy),
+ ::llvm::ConstantInt::get(sizetTy, size) });
+ jit->builder->CreateBr(mergeBlock);
+ jit->builder->SetInsertPoint(mergeBlock);
+ }
+ }
+#endif
}
RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */)