blob: 4ec7181ad85950bc7d78bb38da4c2cfbcdfd15f2 [file] [log] [blame]
//===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass identifies/eliminate Redundant TLS Loads if related option is set.
// The example: Please refer to the comment at the head of TLSVariableHoist.h.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/TLSVariableHoist.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <utility>
using namespace llvm;
using namespace tlshoist;
#define DEBUG_TYPE "tlshoist"
static cl::opt<bool> TLSLoadHoist(
"tls-load-hoist", cl::init(false), cl::Hidden,
cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
"TLS address calculation."));
namespace {
/// The TLS Variable hoist pass.
class TLSVariableHoistLegacyPass : public FunctionPass {
public:
static char ID; // Pass identification, replacement for typeid
TLSVariableHoistLegacyPass() : FunctionPass(ID) {
initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &Fn) override;
StringRef getPassName() const override { return "TLS Variable Hoist"; }
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
}
private:
TLSVariableHoistPass Impl;
};
} // end anonymous namespace
char TLSVariableHoistLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
"TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
"TLS Variable Hoist", false, false)
FunctionPass *llvm::createTLSVariableHoistPass() {
return new TLSVariableHoistLegacyPass();
}
/// Perform the TLS Variable Hoist optimization for the given function.
bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
if (skipFunction(Fn))
return false;
LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
bool MadeChange =
Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
if (MadeChange) {
LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
<< Fn.getName() << '\n');
LLVM_DEBUG(dbgs() << Fn);
}
LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
return MadeChange;
}
void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
// Skip all cast instructions. They are visited indirectly later on.
if (Inst->isCast())
return;
// Scan all operands.
for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
if (!GV || !GV->isThreadLocal())
continue;
// Add Candidate to TLSCandMap (GV --> Candidate).
TLSCandMap[GV].addUser(Inst, Idx);
}
}
void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
// First, quickly check if there is TLS Variable.
Module *M = Fn.getParent();
bool HasTLS = llvm::any_of(
M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
// If non, directly return.
if (!HasTLS)
return;
TLSCandMap.clear();
// Then, collect TLS Variable info.
for (BasicBlock &BB : Fn) {
// Ignore unreachable basic blocks.
if (!DT->isReachableFromEntry(&BB))
continue;
for (Instruction &Inst : BB)
collectTLSCandidate(&Inst);
}
}
static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
if (Cand.Users.size() != 1)
return false;
BasicBlock *BB = Cand.Users[0].Inst->getParent();
if (LI->getLoopFor(BB))
return false;
return true;
}
Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
Loop *L) {
assert(L && "Unexcepted Loop status!");
// Get the outermost loop.
while (Loop *Parent = L->getParentLoop())
L = Parent;
BasicBlock *PreHeader = L->getLoopPreheader();
// There is unique predecessor outside the loop.
if (PreHeader)
return PreHeader->getTerminator();
BasicBlock *Header = L->getHeader();
BasicBlock *Dom = Header;
for (BasicBlock *PredBB : predecessors(Header))
Dom = DT->findNearestCommonDominator(Dom, PredBB);
assert(Dom && "Not find dominator BB!");
Instruction *Term = Dom->getTerminator();
return Term;
}
Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
Instruction *I2) {
if (!I1)
return I2;
return DT->findNearestCommonDominator(I1, I2);
}
BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
GlobalVariable *GV,
BasicBlock *&PosBB) {
tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
// We should hoist the TLS use out of loop, so choose its nearest instruction
// which dominate the loop and the outside loops (if exist).
Instruction *LastPos = nullptr;
for (auto &User : Cand.Users) {
BasicBlock *BB = User.Inst->getParent();
Instruction *Pos = User.Inst;
if (Loop *L = LI->getLoopFor(BB)) {
Pos = getNearestLoopDomInst(BB, L);
assert(Pos && "Not find insert position out of loop!");
}
Pos = getDomInst(LastPos, Pos);
LastPos = Pos;
}
assert(LastPos && "Unexpected insert position!");
BasicBlock *Parent = LastPos->getParent();
PosBB = Parent;
return LastPos->getIterator();
}
// Generate a bitcast (no type change) to replace the uses of TLS Candidate.
Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
GlobalVariable *GV) {
BasicBlock *PosBB = &Fn.getEntryBlock();
BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
Type *Ty = GV->getType();
auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
CastInst->insertInto(PosBB, Iter);
return CastInst;
}
bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
GlobalVariable *GV) {
tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
// If only used 1 time and not in loops, we no need to replace it.
if (oneUseOutsideLoop(Cand, LI))
return false;
// Generate a bitcast (no type change)
auto *CastInst = genBitCastInst(Fn, GV);
// to replace the uses of TLS Candidate
for (auto &User : Cand.Users)
User.Inst->setOperand(User.OpndIdx, CastInst);
return true;
}
bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
if (TLSCandMap.empty())
return false;
bool Replaced = false;
for (auto &GV2Cand : TLSCandMap) {
GlobalVariable *GV = GV2Cand.first;
Replaced |= tryReplaceTLSCandidate(Fn, GV);
}
return Replaced;
}
/// Optimize expensive TLS variables in the given function.
bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
LoopInfo &LI) {
if (Fn.hasOptNone())
return false;
if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
return false;
this->LI = &LI;
this->DT = &DT;
assert(this->LI && this->DT && "Unexcepted requirement!");
// Collect all TLS variable candidates.
collectTLSCandidates(Fn);
bool MadeChange = tryReplaceTLSCandidates(Fn);
return MadeChange;
}
PreservedAnalyses TLSVariableHoistPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto &LI = AM.getResult<LoopAnalysis>(F);
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
if (!runImpl(F, DT, LI))
return PreservedAnalyses::all();
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
return PA;
}