//===- UniformityAnalysis.cpp ---------------------------------------------===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Analysis/UniformityAnalysis.h" |
| #include "llvm/ADT/GenericUniformityImpl.h" |
| #include "llvm/Analysis/CycleAnalysis.h" |
| #include "llvm/Analysis/TargetTransformInfo.h" |
| #include "llvm/IR/Constants.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/IR/InstIterator.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/InitializePasses.h" |
| |
| using namespace llvm; |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs( |
| const Instruction &I) const { |
| return isDivergent((const Value *)&I); |
| } |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent( |
| const Instruction &Instr, bool AllDefsDivergent) { |
| return markDivergent(&Instr); |
| } |
| |
| template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() { |
| for (auto &I : instructions(F)) { |
| if (TTI->isSourceOfDivergence(&I)) { |
| assert(!I.isTerminator()); |
| markDivergent(I); |
| } else if (TTI->isAlwaysUniform(&I)) { |
| addUniformOverride(I); |
| } |
| } |
| for (auto &Arg : F.args()) { |
| if (TTI->isSourceOfDivergence(&Arg)) { |
| markDivergent(&Arg); |
| } |
| } |
| } |
| |
| template <> |
| void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( |
| const Value *V) { |
| for (const auto *User : V->users()) { |
| const auto *UserInstr = dyn_cast<const Instruction>(User); |
| if (!UserInstr) |
| continue; |
| if (isAlwaysUniform(*UserInstr)) |
| continue; |
| if (markDivergent(*UserInstr)) { |
| Worklist.push_back(UserInstr); |
| } |
| } |
| } |
| |
| template <> |
| void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( |
| const Instruction &Instr) { |
| assert(!isAlwaysUniform(Instr)); |
| if (Instr.isTerminator()) |
| return; |
| pushUsers(cast<Value>(&Instr)); |
| } |
| |
| template <> |
| bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( |
| const Instruction &I, const Cycle &DefCycle) const { |
| if (isAlwaysUniform(I)) |
| return false; |
| for (const Use &U : I.operands()) { |
| if (auto *I = dyn_cast<Instruction>(&U)) { |
| if (DefCycle.contains(I->getParent())) |
| return true; |
| } |
| } |
| return false; |
| } |
| |
// Explicitly instantiate the SSA (LLVM IR) uniformity info and the deleter
// used for its implementation object, so that
// GenericUniformityAnalysisImplDeleter<...>::operator() is emitted in this
// translation unit.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
| |
| //===----------------------------------------------------------------------===// |
| // UniformityInfoAnalysis and related pass implementations |
| //===----------------------------------------------------------------------===// |
| |
| llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, |
| FunctionAnalysisManager &FAM) { |
| auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); |
| auto &TTI = FAM.getResult<TargetIRAnalysis>(F); |
| auto &CI = FAM.getResult<CycleAnalysis>(F); |
| return UniformityInfo{F, DT, CI, &TTI}; |
| } |
| |
// Out-of-line definition of the analysis' static AnalysisKey.
AnalysisKey UniformityInfoAnalysis::Key;
| |
| UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) |
| : OS(OS) {} |
| |
| PreservedAnalyses UniformityInfoPrinterPass::run(Function &F, |
| FunctionAnalysisManager &AM) { |
| OS << "UniformityInfo for function '" << F.getName() << "':\n"; |
| AM.getResult<UniformityInfoAnalysis>(F).print(OS); |
| |
| return PreservedAnalyses::all(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UniformityInfoWrapperPass Implementation |
| //===----------------------------------------------------------------------===// |
| |
| char UniformityInfoWrapperPass::ID = 0; |
| |
| UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) { |
| initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry()); |
| } |
| |
| INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo", |
| "Uniform Info Analysis", true, true) |
| INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
| INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo", |
| "Uniform Info Analysis", true, true) |
| |
| void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { |
| AU.setPreservesAll(); |
| AU.addRequired<DominatorTreeWrapperPass>(); |
| AU.addRequired<CycleInfoWrapperPass>(); |
| AU.addRequired<TargetTransformInfoWrapperPass>(); |
| } |
| |
| bool UniformityInfoWrapperPass::runOnFunction(Function &F) { |
| auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult(); |
| auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
| auto &targetTransformInfo = |
| getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
| |
| m_function = &F; |
| m_uniformityInfo = |
| UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo}; |
| return false; |
| } |
| |
| void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { |
| OS << "UniformityInfo for function '" << m_function->getName() << "':\n"; |
| } |
| |
| void UniformityInfoWrapperPass::releaseMemory() { |
| m_uniformityInfo = UniformityInfo{}; |
| m_function = nullptr; |
| } |