| //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
// This pass identifies and eliminates redundant TLS loads when it is enabled,
// either by the -tls-load-hoist option or the "tls-load-hoist" function
// attribute. For an example, see the comment at the top of TLSVariableHoist.h.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/IR/BasicBlock.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/InstrTypes.h" |
| #include "llvm/IR/Instruction.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/IntrinsicInst.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/IR/Value.h" |
| #include "llvm/InitializePasses.h" |
| #include "llvm/Pass.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "llvm/Transforms/Scalar.h" |
| #include "llvm/Transforms/Scalar/TLSVariableHoist.h" |
| #include <algorithm> |
| #include <cassert> |
| #include <cstdint> |
| #include <iterator> |
| #include <tuple> |
| #include <utility> |
| |
| using namespace llvm; |
| using namespace tlshoist; |
| |
| #define DEBUG_TYPE "tlshoist" |
| |
| static cl::opt<bool> TLSLoadHoist( |
| "tls-load-hoist", cl::init(false), cl::Hidden, |
| cl::desc("hoist the TLS loads in PIC model to eliminate redundant " |
| "TLS address calculation.")); |
| |
| namespace { |
| |
| /// The TLS Variable hoist pass. |
| class TLSVariableHoistLegacyPass : public FunctionPass { |
| public: |
| static char ID; // Pass identification, replacement for typeid |
| |
| TLSVariableHoistLegacyPass() : FunctionPass(ID) { |
| initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry()); |
| } |
| |
| bool runOnFunction(Function &Fn) override; |
| |
| StringRef getPassName() const override { return "TLS Variable Hoist"; } |
| |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| AU.setPreservesCFG(); |
| AU.addRequired<DominatorTreeWrapperPass>(); |
| AU.addRequired<LoopInfoWrapperPass>(); |
| } |
| |
| private: |
| TLSVariableHoistPass Impl; |
| }; |
| |
| } // end anonymous namespace |
| |
// Storage for the legacy pass identifier declared in the class.
char TLSVariableHoistLegacyPass::ID = 0;

// Register the legacy pass under the command-line name "tlshoist" and record
// that it depends on the dominator tree and loop info wrapper passes.
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)
| |
| FunctionPass *llvm::createTLSVariableHoistPass() { |
| return new TLSVariableHoistLegacyPass(); |
| } |
| |
| /// Perform the TLS Variable Hoist optimization for the given function. |
| bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) { |
| if (skipFunction(Fn)) |
| return false; |
| |
| LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n"); |
| LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); |
| |
| bool MadeChange = |
| Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), |
| getAnalysis<LoopInfoWrapperPass>().getLoopInfo()); |
| |
| if (MadeChange) { |
| LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: " |
| << Fn.getName() << '\n'); |
| LLVM_DEBUG(dbgs() << Fn); |
| } |
| LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n"); |
| |
| return MadeChange; |
| } |
| |
| void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) { |
| // Skip all cast instructions. They are visited indirectly later on. |
| if (Inst->isCast()) |
| return; |
| |
| // Scan all operands. |
| for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { |
| auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx)); |
| if (!GV || !GV->isThreadLocal()) |
| continue; |
| |
| // Add Candidate to TLSCandMap (GV --> Candidate). |
| TLSCandMap[GV].addUser(Inst, Idx); |
| } |
| } |
| |
| void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) { |
| // First, quickly check if there is TLS Variable. |
| Module *M = Fn.getParent(); |
| |
| bool HasTLS = llvm::any_of( |
| M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); }); |
| |
| // If non, directly return. |
| if (!HasTLS) |
| return; |
| |
| TLSCandMap.clear(); |
| |
| // Then, collect TLS Variable info. |
| for (BasicBlock &BB : Fn) { |
| // Ignore unreachable basic blocks. |
| if (!DT->isReachableFromEntry(&BB)) |
| continue; |
| |
| for (Instruction &Inst : BB) |
| collectTLSCandidate(&Inst); |
| } |
| } |
| |
| static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) { |
| if (Cand.Users.size() != 1) |
| return false; |
| |
| BasicBlock *BB = Cand.Users[0].Inst->getParent(); |
| if (LI->getLoopFor(BB)) |
| return false; |
| |
| return true; |
| } |
| |
| Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB, |
| Loop *L) { |
| assert(L && "Unexcepted Loop status!"); |
| |
| // Get the outermost loop. |
| while (Loop *Parent = L->getParentLoop()) |
| L = Parent; |
| |
| BasicBlock *PreHeader = L->getLoopPreheader(); |
| |
| // There is unique predecessor outside the loop. |
| if (PreHeader) |
| return PreHeader->getTerminator(); |
| |
| BasicBlock *Header = L->getHeader(); |
| BasicBlock *Dom = Header; |
| for (BasicBlock *PredBB : predecessors(Header)) |
| Dom = DT->findNearestCommonDominator(Dom, PredBB); |
| |
| assert(Dom && "Not find dominator BB!"); |
| Instruction *Term = Dom->getTerminator(); |
| |
| return Term; |
| } |
| |
| Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1, |
| Instruction *I2) { |
| if (!I1) |
| return I2; |
| return DT->findNearestCommonDominator(I1, I2); |
| } |
| |
| BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn, |
| GlobalVariable *GV, |
| BasicBlock *&PosBB) { |
| tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; |
| |
| // We should hoist the TLS use out of loop, so choose its nearest instruction |
| // which dominate the loop and the outside loops (if exist). |
| Instruction *LastPos = nullptr; |
| for (auto &User : Cand.Users) { |
| BasicBlock *BB = User.Inst->getParent(); |
| Instruction *Pos = User.Inst; |
| if (Loop *L = LI->getLoopFor(BB)) { |
| Pos = getNearestLoopDomInst(BB, L); |
| assert(Pos && "Not find insert position out of loop!"); |
| } |
| Pos = getDomInst(LastPos, Pos); |
| LastPos = Pos; |
| } |
| |
| assert(LastPos && "Unexpected insert position!"); |
| BasicBlock *Parent = LastPos->getParent(); |
| PosBB = Parent; |
| return LastPos->getIterator(); |
| } |
| |
| // Generate a bitcast (no type change) to replace the uses of TLS Candidate. |
| Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn, |
| GlobalVariable *GV) { |
| BasicBlock *PosBB = &Fn.getEntryBlock(); |
| BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB); |
| Type *Ty = GV->getType(); |
| auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast"); |
| CastInst->insertInto(PosBB, Iter); |
| return CastInst; |
| } |
| |
| bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn, |
| GlobalVariable *GV) { |
| |
| tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; |
| |
| // If only used 1 time and not in loops, we no need to replace it. |
| if (oneUseOutsideLoop(Cand, LI)) |
| return false; |
| |
| // Generate a bitcast (no type change) |
| auto *CastInst = genBitCastInst(Fn, GV); |
| |
| // to replace the uses of TLS Candidate |
| for (auto &User : Cand.Users) |
| User.Inst->setOperand(User.OpndIdx, CastInst); |
| |
| return true; |
| } |
| |
| bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) { |
| if (TLSCandMap.empty()) |
| return false; |
| |
| bool Replaced = false; |
| for (auto &GV2Cand : TLSCandMap) { |
| GlobalVariable *GV = GV2Cand.first; |
| Replaced |= tryReplaceTLSCandidate(Fn, GV); |
| } |
| |
| return Replaced; |
| } |
| |
| /// Optimize expensive TLS variables in the given function. |
| bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT, |
| LoopInfo &LI) { |
| if (Fn.hasOptNone()) |
| return false; |
| |
| if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist")) |
| return false; |
| |
| this->LI = &LI; |
| this->DT = &DT; |
| assert(this->LI && this->DT && "Unexcepted requirement!"); |
| |
| // Collect all TLS variable candidates. |
| collectTLSCandidates(Fn); |
| |
| bool MadeChange = tryReplaceTLSCandidates(Fn); |
| |
| return MadeChange; |
| } |
| |
| PreservedAnalyses TLSVariableHoistPass::run(Function &F, |
| FunctionAnalysisManager &AM) { |
| |
| auto &LI = AM.getResult<LoopAnalysis>(F); |
| auto &DT = AM.getResult<DominatorTreeAnalysis>(F); |
| |
| if (!runImpl(F, DT, LI)) |
| return PreservedAnalyses::all(); |
| |
| PreservedAnalyses PA; |
| PA.preserveSet<CFGAnalyses>(); |
| return PA; |
| } |