| //===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| /// \file |
| /// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions |
| /// into a conditional branch (B.cond), when the NZCV flags can be set for |
| /// "free". This is preferred on targets that have more flexibility when |
| /// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming |
| /// all other variables are equal). This can also reduce register pressure. |
| /// |
| /// A few examples: |
| /// |
| /// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS. |
| /// cbz w8, .LBB0_2 -> b.eq .LBB0_2 |
| /// |
| /// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses. |
| /// cbz w8, .LBB1_2 -> b.eq .LBB1_2 |
| /// |
| /// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses. |
| /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 |
| /// |
| //===----------------------------------------------------------------------===// |
| |
| #include "AArch64.h" |
| #include "AArch64Subtarget.h" |
| #include "llvm/CodeGen/MachineFunction.h" |
| #include "llvm/CodeGen/MachineFunctionPass.h" |
| #include "llvm/CodeGen/MachineInstrBuilder.h" |
| #include "llvm/CodeGen/MachineRegisterInfo.h" |
| #include "llvm/CodeGen/Passes.h" |
| #include "llvm/CodeGen/TargetInstrInfo.h" |
| #include "llvm/CodeGen/TargetRegisterInfo.h" |
| #include "llvm/CodeGen/TargetSubtargetInfo.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| using namespace llvm; |
| |
// Debug tag for this file. NOTE(review): DEBUG_TYPE is presumably the name
// that LLVM_DEBUG output is keyed on -- confirm against llvm/Support/Debug.h.
| #define DEBUG_TYPE "aarch64-cond-br-tuning" |
// Human-readable pass name; returned by getPassName() and passed to
// INITIALIZE_PASS as the pass description.
| #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" |
| |
| namespace { |
| class AArch64CondBrTuning : public MachineFunctionPass { |
| const AArch64InstrInfo *TII; |
| const TargetRegisterInfo *TRI; |
| |
| MachineRegisterInfo *MRI; |
| |
| public: |
| static char ID; |
| AArch64CondBrTuning() : MachineFunctionPass(ID) { |
| initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); |
| } |
| void getAnalysisUsage(AnalysisUsage &AU) const override; |
| bool runOnMachineFunction(MachineFunction &MF) override; |
| StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } |
| |
| private: |
| MachineInstr *getOperandDef(const MachineOperand &MO); |
| MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting); |
| MachineInstr *convertToCondBr(MachineInstr &MI); |
| bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); |
| }; |
| } // end anonymous namespace |
| |
// Storage for the pass identifier; the constructor hands it to the
// MachineFunctionPass base class.
| char AArch64CondBrTuning::ID = 0; |
| |
// Registers the pass under the name "aarch64-cond-br-tuning" with
// AARCH64_CONDBR_TUNING_NAME as its description.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS -- confirm against
// llvm/PassSupport.h.
| INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", |
| AARCH64_CONDBR_TUNING_NAME, false, false) |
| |
| void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { |
| AU.setPreservesCFG(); |
| MachineFunctionPass::getAnalysisUsage(AU); |
| } |
| |
| MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { |
| if (!Register::isVirtualRegister(MO.getReg())) |
| return nullptr; |
| return MRI->getUniqueVRegDef(MO.getReg()); |
| } |
| |
| MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, |
| bool IsFlagSetting) { |
| // If this is already the flag setting version of the instruction (e.g., SUBS) |
| // just make sure the implicit-def of NZCV isn't marked dead. |
| if (IsFlagSetting) { |
| for (unsigned I = MI.getNumExplicitOperands(), E = MI.getNumOperands(); |
| I != E; ++I) { |
| MachineOperand &MO = MI.getOperand(I); |
| if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) |
| MO.setIsDead(false); |
| } |
| return &MI; |
| } |
| bool Is64Bit; |
| unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode(), Is64Bit); |
| Register NewDestReg = MI.getOperand(0).getReg(); |
| if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) |
| NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; |
| |
| MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), |
| TII->get(NewOpc), NewDestReg); |
| for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) |
| MIB.add(MI.getOperand(I)); |
| |
| return MIB; |
| } |
| |
| MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { |
| AArch64CC::CondCode CC; |
| MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); |
| switch (MI.getOpcode()) { |
| default: |
| llvm_unreachable("Unexpected opcode!"); |
| |
| case AArch64::CBZW: |
| case AArch64::CBZX: |
| CC = AArch64CC::EQ; |
| break; |
| case AArch64::CBNZW: |
| case AArch64::CBNZX: |
| CC = AArch64CC::NE; |
| break; |
| case AArch64::TBZW: |
| case AArch64::TBZX: |
| CC = AArch64CC::PL; |
| break; |
| case AArch64::TBNZW: |
| case AArch64::TBNZX: |
| CC = AArch64CC::MI; |
| break; |
| } |
| return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) |
| .addImm(CC) |
| .addMBB(TargetMBB); |
| } |
| |
| bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, |
| MachineInstr &DefMI) { |
| // We don't want NZCV bits live across blocks. |
| if (MI.getParent() != DefMI.getParent()) |
| return false; |
| |
| bool IsFlagSetting = true; |
| unsigned MIOpc = MI.getOpcode(); |
| MachineInstr *NewCmp = nullptr, *NewBr = nullptr; |
| switch (DefMI.getOpcode()) { |
| default: |
| return false; |
| case AArch64::ADDWri: |
| case AArch64::ADDWrr: |
| case AArch64::ADDWrs: |
| case AArch64::ADDWrx: |
| case AArch64::ANDWri: |
| case AArch64::ANDWrr: |
| case AArch64::ANDWrs: |
| case AArch64::BICWrr: |
| case AArch64::BICWrs: |
| case AArch64::SUBWri: |
| case AArch64::SUBWrr: |
| case AArch64::SUBWrs: |
| case AArch64::SUBWrx: |
| IsFlagSetting = false; |
| LLVM_FALLTHROUGH; |
| case AArch64::ADDSWri: |
| case AArch64::ADDSWrr: |
| case AArch64::ADDSWrs: |
| case AArch64::ADDSWrx: |
| case AArch64::ANDSWri: |
| case AArch64::ANDSWrr: |
| case AArch64::ANDSWrs: |
| case AArch64::BICSWrr: |
| case AArch64::BICSWrs: |
| case AArch64::SUBSWri: |
| case AArch64::SUBSWrr: |
| case AArch64::SUBSWrs: |
| case AArch64::SUBSWrx: |
| switch (MIOpc) { |
| default: |
| llvm_unreachable("Unexpected opcode!"); |
| |
| case AArch64::CBZW: |
| case AArch64::CBNZW: |
| case AArch64::TBZW: |
| case AArch64::TBNZW: |
| // Check to see if the TBZ/TBNZ is checking the sign bit. |
| if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && |
| MI.getOperand(1).getImm() != 31) |
| return false; |
| |
| // There must not be any instruction between DefMI and MI that clobbers or |
| // reads NZCV. |
| MachineBasicBlock::iterator I(DefMI), E(MI); |
| for (I = std::next(I); I != E; ++I) { |
| if (I->modifiesRegister(AArch64::NZCV, TRI) || |
| I->readsRegister(AArch64::NZCV, TRI)) |
| return false; |
| } |
| LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); |
| LLVM_DEBUG(DefMI.print(dbgs())); |
| LLVM_DEBUG(dbgs() << " "); |
| LLVM_DEBUG(MI.print(dbgs())); |
| |
| NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); |
| NewBr = convertToCondBr(MI); |
| break; |
| } |
| break; |
| |
| case AArch64::ADDXri: |
| case AArch64::ADDXrr: |
| case AArch64::ADDXrs: |
| case AArch64::ADDXrx: |
| case AArch64::ANDXri: |
| case AArch64::ANDXrr: |
| case AArch64::ANDXrs: |
| case AArch64::BICXrr: |
| case AArch64::BICXrs: |
| case AArch64::SUBXri: |
| case AArch64::SUBXrr: |
| case AArch64::SUBXrs: |
| case AArch64::SUBXrx: |
| IsFlagSetting = false; |
| LLVM_FALLTHROUGH; |
| case AArch64::ADDSXri: |
| case AArch64::ADDSXrr: |
| case AArch64::ADDSXrs: |
| case AArch64::ADDSXrx: |
| case AArch64::ANDSXri: |
| case AArch64::ANDSXrr: |
| case AArch64::ANDSXrs: |
| case AArch64::BICSXrr: |
| case AArch64::BICSXrs: |
| case AArch64::SUBSXri: |
| case AArch64::SUBSXrr: |
| case AArch64::SUBSXrs: |
| case AArch64::SUBSXrx: |
| switch (MIOpc) { |
| default: |
| llvm_unreachable("Unexpected opcode!"); |
| |
| case AArch64::CBZX: |
| case AArch64::CBNZX: |
| case AArch64::TBZX: |
| case AArch64::TBNZX: { |
| // Check to see if the TBZ/TBNZ is checking the sign bit. |
| if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && |
| MI.getOperand(1).getImm() != 63) |
| return false; |
| // There must not be any instruction between DefMI and MI that clobbers or |
| // reads NZCV. |
| MachineBasicBlock::iterator I(DefMI), E(MI); |
| for (I = std::next(I); I != E; ++I) { |
| if (I->modifiesRegister(AArch64::NZCV, TRI) || |
| I->readsRegister(AArch64::NZCV, TRI)) |
| return false; |
| } |
| LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); |
| LLVM_DEBUG(DefMI.print(dbgs())); |
| LLVM_DEBUG(dbgs() << " "); |
| LLVM_DEBUG(MI.print(dbgs())); |
| |
| NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); |
| NewBr = convertToCondBr(MI); |
| break; |
| } |
| } |
| break; |
| } |
| (void)NewCmp; (void)NewBr; |
| assert(NewCmp && NewBr && "Expected new instructions."); |
| |
| LLVM_DEBUG(dbgs() << " with instruction:\n "); |
| LLVM_DEBUG(NewCmp->print(dbgs())); |
| LLVM_DEBUG(dbgs() << " "); |
| LLVM_DEBUG(NewBr->print(dbgs())); |
| |
| // If this was a flag setting version of the instruction, we use the original |
| // instruction by just clearing the dead marked on the implicit-def of NCZV. |
| // Therefore, we should not erase this instruction. |
| if (!IsFlagSetting) |
| DefMI.eraseFromParent(); |
| MI.eraseFromParent(); |
| return true; |
| } |
| |
| bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { |
| if (skipFunction(MF.getFunction())) |
| return false; |
| |
| LLVM_DEBUG( |
| dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" |
| << "********** Function: " << MF.getName() << '\n'); |
| |
| TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); |
| TRI = MF.getSubtarget().getRegisterInfo(); |
| MRI = &MF.getRegInfo(); |
| |
| bool Changed = false; |
| for (MachineBasicBlock &MBB : MF) { |
| bool LocalChange = false; |
| for (MachineBasicBlock::iterator I = MBB.getFirstTerminator(), |
| E = MBB.end(); |
| I != E; ++I) { |
| MachineInstr &MI = *I; |
| switch (MI.getOpcode()) { |
| default: |
| break; |
| case AArch64::CBZW: |
| case AArch64::CBZX: |
| case AArch64::CBNZW: |
| case AArch64::CBNZX: |
| case AArch64::TBZW: |
| case AArch64::TBZX: |
| case AArch64::TBNZW: |
| case AArch64::TBNZX: |
| MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); |
| LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); |
| break; |
| } |
| // If the optimization was successful, we can't optimize any other |
| // branches because doing so would clobber the NZCV flags. |
| if (LocalChange) { |
| Changed = true; |
| break; |
| } |
| } |
| } |
| return Changed; |
| } |
| |
| FunctionPass *llvm::createAArch64CondBrTuning() { |
| return new AArch64CondBrTuning(); |
| } |