blob: c2cd8fa0324ee87d61524ede169101de37237524 [file] [log] [blame] [edit]
//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/SwitchLoweringUtils.h"
using namespace llvm;
using namespace SwitchCG;
uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
unsigned First, unsigned Last) {
assert(Last >= First);
const APInt &LowCase = Clusters[First].Low->getValue();
const APInt &HighCase = Clusters[Last].High->getValue();
assert(LowCase.getBitWidth() == HighCase.getBitWidth());
// FIXME: A range of consecutive cases has 100% density, but only requires one
// comparison to lower. We should discriminate against such consecutive ranges
// in jump tables.
return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}
/// Return the number of case values contained in Clusters[First..Last].
///
/// \p TotalCases is a running (prefix) total: TotalCases[i] is the number of
/// case values in clusters 0..i. The count for a sub-range is therefore the
/// total at Last minus the total just before First (zero when First is 0).
uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
  // Keep the subtraction in 'unsigned', exactly as the prefix sums are stored.
  const unsigned Before = First == 0 ? 0u : TotalCases[First - 1];
  const unsigned Count = TotalCases[Last] - Before;
  return Count;
}
/// Find runs of clusters in \p Clusters that the target considers suitable for
/// jump tables and replace each such run, in place, with the single jump-table
/// cluster produced by buildJumpTable().
///
/// Preconditions (asserted in debug builds): \p Clusters is non-empty, holds
/// only CC_Range clusters, and is sorted by strictly increasing signed value.
/// The relative order of the surviving clusters is preserved.
///
/// Returns without changes when the target disallows jump tables for the
/// function, or when there are fewer than max(2, MinJumpTableEntries)
/// clusters. At -O0 only the "whole switch as one table" attempt is made.
void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif
  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;
  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  // Partitions with at most this many clusters are scored as "FewCases" in
  // the tie-break below (treated as good as a jump table).
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;
  // Accumulated number of cases in each cluster and those prior to it.
  // These prefix sums are consumed by getJumpTableNumCases().
  // NOTE(review): stored as 'unsigned', so a count above UINT_MAX would be
  // truncated; presumably unreachable since every case value is listed in the
  // switch — confirm.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }
  uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);
  // Cheap case: the whole range may be suitable for jump table.
  // On success the entire vector collapses to the one jump-table cluster.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }
  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;
  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.
  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  // Partitions whose size is above SmallNumberOfEntries but below
  // MinJumpTableEntries (and not a single cluster) add no score (NoTable).
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;
  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
    // Search for a solution that results in fewer partitions.
    // Every end point j in (i, N-1] is tried; only candidates the target
    // deems suitable for a jump table are considered.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);
      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        // Cost of this choice = this partition plus the best partitioning
        // of everything after j (already computed, since i runs downwards).
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;
        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;
        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }
  // Iterate over the partitions, replacing some with jump tables in-place.
  // DstIndex never overtakes First, so the compaction only overwrites slots
  // that have already been consumed. Partitions smaller than
  // MinJumpTableEntries, or rejected by buildJumpTable, are copied unchanged.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;
    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}
bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
unsigned First, unsigned Last,
const SwitchInst *SI,
MachineBasicBlock *DefaultMBB,
CaseCluster &JTCluster) {
assert(First <= Last);
auto Prob = BranchProbability::getZero();
unsigned NumCmps = 0;
std::vector<MachineBasicBlock*> Table;
DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
// Initialize probabilities in JTProbs.
for (unsigned I = First; I <= Last; ++I)
JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
for (unsigned I = First; I <= Last; ++I) {
assert(Clusters[I].Kind == CC_Range);
Prob += Clusters[I].Prob;
const APInt &Low = Clusters[I].Low->getValue();
const APInt &High = Clusters[I].High->getValue();
NumCmps += (Low == High) ? 1 : 2;
if (I != First) {
// Fill the gap between this and the previous cluster.
const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
assert(PreviousHigh.slt(Low));
uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
for (uint64_t J = 0; J < Gap; J++)
Table.push_back(DefaultMBB);
}
uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
for (uint64_t J = 0; J < ClusterSize; ++J)
Table.push_back(Clusters[I].MBB);
JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
}
unsigned NumDests = JTProbs.size();
if (TLI->isSuitableForBitTests(NumDests, NumCmps,
Clusters[First].Low->getValue(),
Clusters[Last].High->getValue(), *DL)) {
// Clusters[First..Last] should be lowered as bit tests instead.
return false;
}
// Create the MBB that will load from and jump through the table.
// Note: We create it here, but it's not inserted into the function yet.
MachineFunction *CurMF = FuncInfo.MF;
MachineBasicBlock *JumpTableMBB =
CurMF->CreateMachineBasicBlock(SI->getParent());
// Add successors. Note: use table order for determinism.
SmallPtrSet<MachineBasicBlock *, 8> Done;
for (MachineBasicBlock *Succ : Table) {
if (Done.count(Succ))
continue;
addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
Done.insert(Succ);
}
JumpTableMBB->normalizeSuccProbs();
unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
->createJumpTableIndex(Table);
// Set up the jump table info.
JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
JumpTableHeader JTH(Clusters[First].Low->getValue(),
Clusters[Last].High->getValue(), SI->getCondition(),
nullptr, false);
JTCases.emplace_back(std::move(JTH), std::move(JT));
JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
JTCases.size() - 1, Prob);
return true;
}
/// Partition \p Clusters into as few subsets as possible, where each subset's
/// value range fits in a machine word and has at most 3 unique destinations.
/// Each subset that buildBitTests() accepts is replaced, in place, by a single
/// bit-test cluster. Other clusters are kept unchanged and in order.
///
/// Preconditions (asserted in debug builds): \p Clusters is non-empty, holds
/// only CC_Range or CC_JumpTable clusters, and is sorted by strictly
/// increasing signed value. Jump-table clusters are never folded into a
/// bit-test subset.
///
/// Does nothing at -O0, or when SHL is not legal for the pointer type.
void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If target does not have legal shift left, do not emit bit tests at all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // Scratch set of destination block numbers for the candidate currently
  // being checked. It is allocated once and cleared per candidate rather than
  // constructed anew for every (i, j) pair.
  BitVector Dests(FuncInfo.MF->getNumBlockIDs());

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check nbr of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      Dests.reset();
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      // NOTE(review): j is scanned from the widest candidate downwards, so
      // this 'break' also abandons every narrower candidate Clusters[i..j']
      // (j' < j), even where dropping the offending clusters would bring the
      // destination count to <= 3. Switching to 'continue' could find more
      // bit-test partitions, but it changes codegen, so it is left as-is.
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  // DstIndex never overtakes First, so only consumed slots are overwritten.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}
/// Try to lower Clusters[First..Last] as a single bit-test cluster.
///
/// On success this:
///  - appends a new entry to BitTestCases, holding one BitTestCase per
///    distinct destination, each with a freshly created MachineBasicBlock;
///  - sets \p BTCluster to the matching bit-test cluster.
/// It then returns true.
///
/// Returns false, with no side effects, when the span holds a single cluster
/// or when the target does not consider it suitable for bit tests. Every
/// cluster in the span is expected to be a CC_Range cluster, and the span's
/// value range must fit in a machine word (both asserted).
bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  // A lone cluster is never turned into bit tests.
  if (First == Last)
    return false;
  // Count distinct destinations (by block number) and the comparisons that a
  // compare-based lowering would need: one per single-value cluster, two per
  // multi-value range.
  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    // NOTE(review): this compares ConstantInt pointers, not values, and
    // presumably relies on ConstantInt uniquing. buildJumpTable compares the
    // APInt values instead.
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();
  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));
  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;
  // LowBound is subtracted from the switch value before testing bits;
  // CmpRange is High - LowBound in both branches below.
  APInt LowBound;
  APInt CmpRange;
  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");
  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }
  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    // The tested range now starts at 0, so values in [0, Low) fall inside it
    // without being covered by any cluster; it is therefore not contiguous.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }
  // Build one CaseBits entry per destination. For every case value routed
  // there, the bit at position (value - LowBound) is set in its Mask.
  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    // A linear search is used over the destinations found so far.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];
    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    // (Hi - Lo + 1) consecutive one bits, shifted up to start at bit Lo.
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }
  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });
  // One bit-test block per destination, in the sorted order above.
  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);
  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}
/// Sort single-value case clusters by signed value, then fold each run of
/// consecutive values that share a destination block into one range cluster.
///
/// The probability of a merged range is the sum of its members'
/// probabilities. Every input cluster must cover exactly one value
/// (asserted in debug builds). The vector is compacted in place.
void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &LHS, const CaseCluster &RHS) {
    return LHS.Low->getValue().slt(RHS.Low->getValue());
  });

  // Out <= In always holds, so the compaction only writes into slots that
  // have already been read.
  unsigned Out = 0;
  for (unsigned In = 0, End = Clusters.size(); In != End; ++In) {
    const CaseCluster &Cur = Clusters[In];
    if (Out > 0) {
      CaseCluster &Prev = Clusters[Out - 1];
      const bool SameDest = Prev.MBB == Cur.MBB;
      // Extend Prev when Cur's value lies exactly one past Prev's High.
      if (SameDest && (Cur.Low->getValue() - Prev.High->getValue()) == 1) {
        Prev.High = Cur.Low;
        Prev.Prob += Cur.Prob;
        continue;
      }
    }
    Clusters[Out++] = Cur;
  }
  Clusters.resize(Out);
}