Subzero. ARM32. Show FP lowering some love.

FP lowering for ARM32 has been neglected for some time; this CL improves it.

1) It emits vpush {list} and vpop {list} when possible.

2) It stops saving aliased VFP registers multiple times (yes, Subzero used
to save both the D and the S registers even when they aliased).

3) It introduces Vmla (FP multiply-accumulate) and Vmls (FP
multiply-subtract).

(1 + 2) minimally (but positively) affected SPEC.

(3) caused a 2% geomean improvement.
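
To illustrate (a sketch, not actual compiler output): with (1) and (2), a
function that clobbers the callee-saved s16..s19 can save them with a single
"vpush {s16, s17, s18, s19}" in the prologue, instead of also pushing the
aliasing d8/d9. With (3), an fmul whose only use is a following fadd/fsub is
folded into the accumulating instruction (register choices here are made up):

  t = fmul float a, b      ; before: vmul.f32 s2, s0, s1
  d = fadd float c, t      ;         vadd.f32 s4, s3, s2
                           ; after:  vmla.f32 s3, s0, s1   (s3 += s0 * s1)

The analogous mla/mls folding is applied when an i32 mul feeds an add or sub.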

BUG= https://code.google.com/p/nativeclient/issues/detail?id=4076
R=stichnot@chromium.org

Review URL: https://codereview.chromium.org/1481133002 .
diff --git a/src/IceTargetLoweringARM32.cpp b/src/IceTargetLoweringARM32.cpp
index 74e4a48..2930349 100644
--- a/src/IceTargetLoweringARM32.cpp
+++ b/src/IceTargetLoweringARM32.cpp
@@ -876,6 +876,54 @@
   return true;
 }
 
+// The calling convention helper class (TargetARM32::CallingConv) expects the
+// following registers to be declared in a certain order, so we have these
+// sanity checks to ensure nothing breaks unknowingly.
+// TODO(jpp): modify the CallingConv class so it does not rely on any register
+// declaration order.
+#define SANITY_CHECK_QS(_0, _1)                                                \
+  static_assert((RegARM32::Reg_##_1 + 1) == RegARM32::Reg_##_0,                \
+                "ARM32 " #_0 " and " #_1 " registers are declared "            \
+                "incorrectly.")
+SANITY_CHECK_QS(q0, q1);
+SANITY_CHECK_QS(q1, q2);
+SANITY_CHECK_QS(q2, q3);
+SANITY_CHECK_QS(q3, q4);
+#undef SANITY_CHECK_QS
+#define SANITY_CHECK_DS(_0, _1)                                                \
+  static_assert((RegARM32::Reg_##_1 + 1) == RegARM32::Reg_##_0,                \
+                "ARM32 " #_0 " and " #_1 " registers are declared "            \
+                "incorrectly.")
+SANITY_CHECK_DS(d0, d1);
+SANITY_CHECK_DS(d1, d2);
+SANITY_CHECK_DS(d2, d3);
+SANITY_CHECK_DS(d3, d4);
+SANITY_CHECK_DS(d4, d5);
+SANITY_CHECK_DS(d5, d6);
+SANITY_CHECK_DS(d6, d7);
+SANITY_CHECK_DS(d7, d8);
+#undef SANITY_CHECK_DS
+#define SANITY_CHECK_SS(_0, _1)                                                \
+  static_assert((RegARM32::Reg_##_0 + 1) == RegARM32::Reg_##_1,                \
+                "ARM32 " #_0 " and " #_1 " registers are declared "            \
+                "incorrectly.")
+SANITY_CHECK_SS(s0, s1);
+SANITY_CHECK_SS(s1, s2);
+SANITY_CHECK_SS(s2, s3);
+SANITY_CHECK_SS(s3, s4);
+SANITY_CHECK_SS(s4, s5);
+SANITY_CHECK_SS(s5, s6);
+SANITY_CHECK_SS(s6, s7);
+SANITY_CHECK_SS(s7, s8);
+SANITY_CHECK_SS(s8, s9);
+SANITY_CHECK_SS(s9, s10);
+SANITY_CHECK_SS(s10, s11);
+SANITY_CHECK_SS(s11, s12);
+SANITY_CHECK_SS(s12, s13);
+SANITY_CHECK_SS(s13, s14);
+SANITY_CHECK_SS(s14, s15);
+#undef SANITY_CHECK_SS
+
 bool TargetARM32::CallingConv::FPInReg(Type Ty, int32_t *Reg) {
   if (!VFPRegsFree.any()) {
     return false;
@@ -885,9 +933,6 @@
     // Q registers are declared in reverse order, so RegARM32::Reg_q0 >
     // RegARM32::Reg_q1. Therefore, we need to subtract QRegStart from Reg_q0.
     // Same thing goes for D registers.
-    static_assert(RegARM32::Reg_q0 > RegARM32::Reg_q1,
-                  "ARM32 Q registers are possibly declared incorrectly.");
-
     int32_t QRegStart = (VFPRegsFree & ValidV128Regs).find_first();
     if (QRegStart >= 0) {
       VFPRegsFree.reset(QRegStart, QRegStart + 4);
@@ -895,9 +940,6 @@
       return true;
     }
   } else if (Ty == IceType_f64) {
-    static_assert(RegARM32::Reg_d0 > RegARM32::Reg_d1,
-                  "ARM32 D registers are possibly declared incorrectly.");
-
     int32_t DRegStart = (VFPRegsFree & ValidF64Regs).find_first();
     if (DRegStart >= 0) {
       VFPRegsFree.reset(DRegStart, DRegStart + 2);
@@ -905,9 +947,6 @@
       return true;
     }
   } else {
-    static_assert(RegARM32::Reg_s0 < RegARM32::Reg_s1,
-                  "ARM32 S registers are possibly declared incorrectly.");
-
     assert(Ty == IceType_f32);
     int32_t SReg = VFPRegsFree.find_first();
     assert(SReg >= 0);
@@ -1096,44 +1135,78 @@
 
   // Add push instructions for preserved registers. On ARM, "push" can push a
   // whole list of GPRs via a bitmask (0-15). Unlike x86, ARM also has
-  // callee-saved float/vector registers. The "vpush" instruction can handle a
-  // whole list of float/vector registers, but it only handles contiguous
-  // sequences of registers by specifying the start and the length.
-  VarList GPRsToPreserve;
-  GPRsToPreserve.reserve(CalleeSaves.size());
-  uint32_t NumCallee = 0;
-  size_t PreservedRegsSizeBytes = 0;
+  // callee-saved float/vector registers.
+  //
+  // The "vpush" instruction can handle a whole list of float/vector registers,
+  // but it only handles contiguous sequences of registers by specifying the
+  // start and the length.
+  PreservedGPRs.reserve(CalleeSaves.size());
+  PreservedSRegs.reserve(CalleeSaves.size());
+
   // Consider FP and LR as callee-save / used as needed.
   if (UsesFramePointer) {
+    if (RegsUsed[RegARM32::Reg_fp]) {
+      llvm::report_fatal_error("Frame pointer has been used.");
+    }
     CalleeSaves[RegARM32::Reg_fp] = true;
-    assert(RegsUsed[RegARM32::Reg_fp] == false);
     RegsUsed[RegARM32::Reg_fp] = true;
   }
   if (!MaybeLeafFunc) {
     CalleeSaves[RegARM32::Reg_lr] = true;
     RegsUsed[RegARM32::Reg_lr] = true;
   }
+
+  // Make two passes over the used registers. The first pass records all the
+  // used registers -- and their aliases. The second pass then figures out
+  // which GPRs and VFP S registers should be saved. We don't bother saving
+  // D/Q registers because their uses are already recorded as S register uses.
+  llvm::SmallBitVector ToPreserve(RegARM32::Reg_NUM);
   for (SizeT i = 0; i < CalleeSaves.size(); ++i) {
-    if (RegARM32::isI64RegisterPair(i)) {
-      // We don't save register pairs explicitly. Instead, we rely on the code
-      // fake-defing/fake-using each register in the pair.
+    if (NeedSandboxing && i == RegARM32::Reg_r9) {
+      // r9 is never updated in sandboxed code.
       continue;
     }
     if (CalleeSaves[i] && RegsUsed[i]) {
-      if (NeedSandboxing && i == RegARM32::Reg_r9) {
-        // r9 is never updated in sandboxed code.
+      ToPreserve |= RegisterAliases[i];
+    }
+  }
+
+  uint32_t NumCallee = 0;
+  size_t PreservedRegsSizeBytes = 0;
+
+  // Each entry in RegClasses is a tuple of
+  //
+  // <First Register in Class, Last Register in Class, Vector of Saved Registers>
+  //
+  // We use these tuples to figure out which registers we should push/pop
+  // during the prolog/epilog.
+  using RegClassType = std::tuple<uint32_t, uint32_t, VarList *>;
+  const RegClassType RegClasses[] = {
+      RegClassType(RegARM32::Reg_GPR_First, RegARM32::Reg_GPR_Last,
+                   &PreservedGPRs),
+      RegClassType(RegARM32::Reg_SREG_First, RegARM32::Reg_SREG_Last,
+                   &PreservedSRegs)};
+  for (const auto &RegClass : RegClasses) {
+    const uint32_t FirstRegInClass = std::get<0>(RegClass);
+    const uint32_t LastRegInClass = std::get<1>(RegClass);
+    VarList *const PreservedRegsInClass = std::get<2>(RegClass);
+    for (uint32_t Reg = FirstRegInClass; Reg <= LastRegInClass; ++Reg) {
+      if (!ToPreserve[Reg]) {
         continue;
       }
       ++NumCallee;
-      Variable *PhysicalRegister = getPhysicalRegister(i);
+      Variable *PhysicalRegister = getPhysicalRegister(Reg);
       PreservedRegsSizeBytes +=
           typeWidthInBytesOnStack(PhysicalRegister->getType());
-      GPRsToPreserve.push_back(getPhysicalRegister(i));
+      PreservedRegsInClass->push_back(PhysicalRegister);
     }
   }
+
   Ctx->statsUpdateRegistersSaved(NumCallee);
-  if (!GPRsToPreserve.empty())
-    _push(GPRsToPreserve);
+  if (!PreservedSRegs.empty())
+    _push(PreservedSRegs);
+  if (!PreservedGPRs.empty())
+    _push(PreservedGPRs);
 
   // Generate "mov FP, SP" if needed.
   if (UsesFramePointer) {
@@ -1160,13 +1233,13 @@
       GlobalsSize + LocalsSlotsPaddingBytes;
 
   // Adds the out args space to the stack, and align SP if necessary.
-  if (NeedsStackAlignment) {
+  if (!NeedsStackAlignment) {
+    SpillAreaSizeBytes += MaxOutArgsSizeBytes;
+  } else {
     uint32_t StackOffset = PreservedRegsSizeBytes;
     uint32_t StackSize = applyStackAlignment(StackOffset + SpillAreaSizeBytes);
     StackSize = applyStackAlignment(StackSize + MaxOutArgsSizeBytes);
     SpillAreaSizeBytes = StackSize - StackOffset;
-  } else {
-    SpillAreaSizeBytes += MaxOutArgsSizeBytes;
   }
 
   // Combine fixed alloca with SpillAreaSize.
@@ -1285,43 +1358,21 @@
     }
   }
 
-  // Add pop instructions for preserved registers.
-  llvm::SmallBitVector CalleeSaves =
-      getRegisterSet(RegSet_CalleeSave, RegSet_None);
-  VarList GPRsToRestore;
-  GPRsToRestore.reserve(CalleeSaves.size());
-  // Consider FP and LR as callee-save / used as needed.
-  if (UsesFramePointer) {
-    CalleeSaves[RegARM32::Reg_fp] = true;
-  }
-  if (!MaybeLeafFunc) {
-    CalleeSaves[RegARM32::Reg_lr] = true;
-  }
-  // Pop registers in ascending order just like push (instead of in reverse
-  // order).
-  for (SizeT i = 0; i < CalleeSaves.size(); ++i) {
-    if (RegARM32::isI64RegisterPair(i)) {
-      continue;
-    }
-
-    if (CalleeSaves[i] && RegsUsed[i]) {
-      if (NeedSandboxing && i == RegARM32::Reg_r9) {
-        continue;
-      }
-      GPRsToRestore.push_back(getPhysicalRegister(i));
-    }
-  }
-  if (!GPRsToRestore.empty())
-    _pop(GPRsToRestore);
+  if (!PreservedGPRs.empty())
+    _pop(PreservedGPRs);
+  if (!PreservedSRegs.empty())
+    _pop(PreservedSRegs);
 
   if (!Ctx->getFlags().getUseSandboxing())
     return;
 
   // Change the original ret instruction into a sandboxed return sequence.
+  //
   // bundle_lock
   // bic lr, #0xc000000f
   // bx lr
   // bundle_unlock
+  //
   // This isn't just aligning to the getBundleAlignLog2Bytes(). It needs to
   // restrict to the lower 1GB as well.
   Variable *LR = getPhysicalRegister(RegARM32::Reg_lr);
@@ -2641,8 +2692,8 @@
 } // end of namespace StrengthReduction
 } // end of anonymous namespace
 
-void TargetARM32::lowerArithmetic(const InstArithmetic *Inst) {
-  Variable *Dest = Inst->getDest();
+void TargetARM32::lowerArithmetic(const InstArithmetic *Instr) {
+  Variable *Dest = Instr->getDest();
 
   if (Dest->isRematerializable()) {
     Context.insert(InstFakeDef::create(Func, Dest));
@@ -2651,14 +2702,14 @@
 
   Type DestTy = Dest->getType();
   if (DestTy == IceType_i1) {
-    lowerInt1Arithmetic(Inst);
+    lowerInt1Arithmetic(Instr);
     return;
   }
 
-  Operand *Src0 = legalizeUndef(Inst->getSrc(0));
-  Operand *Src1 = legalizeUndef(Inst->getSrc(1));
+  Operand *Src0 = legalizeUndef(Instr->getSrc(0));
+  Operand *Src1 = legalizeUndef(Instr->getSrc(1));
   if (DestTy == IceType_i64) {
-    lowerInt64Arithmetic(Inst->getOp(), Inst->getDest(), Src0, Src1);
+    lowerInt64Arithmetic(Instr->getOp(), Instr->getDest(), Src0, Src1);
     return;
   }
 
@@ -2679,7 +2730,7 @@
   // difficult to determine (constant may be moved to a register).
   // * Handle floating point arithmetic separately: it requires Src1 to be
   // legalized to a register.
-  switch (Inst->getOp()) {
+  switch (Instr->getOp()) {
   default:
     break;
   case InstArithmetic::Udiv: {
@@ -2718,6 +2769,14 @@
   }
   case InstArithmetic::Fadd: {
     Variable *Src0R = legalizeToReg(Src0);
+    if (const Inst *Src1Producer = Computations.getProducerOf(Src1)) {
+      Variable *Src1R = legalizeToReg(Src1Producer->getSrc(0));
+      Variable *Src2R = legalizeToReg(Src1Producer->getSrc(1));
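+      // Fold the fmul that produced Src1: vmla computes Src0R += Src1R * Src2R.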
+      _vmla(Src0R, Src1R, Src2R);
+      _mov(Dest, Src0R);
+      return;
+    }
+
     Variable *Src1R = legalizeToReg(Src1);
     _vadd(T, Src0R, Src1R);
     _mov(Dest, T);
@@ -2725,6 +2784,13 @@
   }
   case InstArithmetic::Fsub: {
     Variable *Src0R = legalizeToReg(Src0);
+    if (const Inst *Src1Producer = Computations.getProducerOf(Src1)) {
+      Variable *Src1R = legalizeToReg(Src1Producer->getSrc(0));
+      Variable *Src2R = legalizeToReg(Src1Producer->getSrc(1));
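+      // Fold the fmul that produced Src1: vmls computes Src0R -= Src1R * Src2R.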
+      _vmls(Src0R, Src1R, Src2R);
+      _mov(Dest, Src0R);
+      return;
+    }
     Variable *Src1R = legalizeToReg(Src1);
     _vsub(T, Src0R, Src1R);
     _mov(Dest, T);
@@ -2748,11 +2814,20 @@
 
   // Handle everything else here.
   Int32Operands Srcs(Src0, Src1);
-  switch (Inst->getOp()) {
+  switch (Instr->getOp()) {
   case InstArithmetic::_num:
     llvm::report_fatal_error("Unknown arithmetic operator");
     return;
   case InstArithmetic::Add: {
+    if (const Inst *Src1Producer = Computations.getProducerOf(Src1)) {
+      Variable *Src0R = legalizeToReg(Src0);
+      Variable *Src1R = legalizeToReg(Src1Producer->getSrc(0));
+      Variable *Src2R = legalizeToReg(Src1Producer->getSrc(1));
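+      // Fold the mul that produced Src1: mla computes T = Src1R * Src2R + Src0R.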
+      _mla(T, Src1R, Src2R, Src0R);
+      _mov(Dest, T);
+      return;
+    }
+
     if (Srcs.hasConstOperand()) {
       if (!Srcs.immediateIsFlexEncodable() &&
           Srcs.negatedImmediateIsFlexEncodable()) {
@@ -2805,6 +2880,15 @@
     return;
   }
   case InstArithmetic::Sub: {
+    if (const Inst *Src1Producer = Computations.getProducerOf(Src1)) {
+      Variable *Src0R = legalizeToReg(Src0);
+      Variable *Src1R = legalizeToReg(Src1Producer->getSrc(0));
+      Variable *Src2R = legalizeToReg(Src1Producer->getSrc(1));
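+      // Fold the mul that produced Src1: mls computes T = Src0R - Src1R * Src2R.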
+      _mls(T, Src1R, Src2R, Src0R);
+      _mov(Dest, T);
+      return;
+    }
+
     if (Srcs.hasConstOperand()) {
       if (Srcs.immediateIsFlexEncodable()) {
         Variable *Src0R = Srcs.src0R(this);
@@ -3013,7 +3097,7 @@
   InstARM32Label *NewShortCircuitLabel = nullptr;
   Operand *_1 = legalize(Ctx->getConstantInt1(1), Legal_Reg | Legal_Flex);
 
-  const Inst *Producer = BoolComputations.getProducerOf(Boolean);
+  const Inst *Producer = Computations.getProducerOf(Boolean);
 
   if (Producer == nullptr) {
     // No producer, no problem: just do emit code to perform (Boolean & 1) and
@@ -3234,7 +3318,7 @@
     case IceType_void:
       break;
     case IceType_i1:
-      assert(BoolComputations.getProducerOf(Dest) == nullptr);
+      assert(Computations.getProducerOf(Dest) == nullptr);
     // Fall-through intended.
     case IceType_i8:
     case IceType_i16:
@@ -5309,6 +5393,7 @@
   return Reg;
 }
 
+// TODO(jpp): remove unneeded else clauses in legalize.
 Operand *TargetARM32::legalize(Operand *From, LegalMask Allowed,
                                int32_t RegNum) {
   Type Ty = From->getType();
@@ -5412,24 +5497,27 @@
     }
     // There should be no constants of vector type (other than undef).
     assert(!isVectorType(Ty));
-    bool CanBeFlex = Allowed & Legal_Flex;
     if (auto *C32 = llvm::dyn_cast<ConstantInteger32>(From)) {
       uint32_t RotateAmt;
       uint32_t Immed_8;
       uint32_t Value = static_cast<uint32_t>(C32->getValue());
-      // Check if the immediate will fit in a Flexible second operand, if a
-      // Flexible second operand is allowed. We need to know the exact value,
-      // so that rules out relocatable constants. Also try the inverse and use
-      // MVN if possible.
-      if (CanBeFlex &&
-          OperandARM32FlexImm::canHoldImm(Value, &RotateAmt, &Immed_8)) {
-        return OperandARM32FlexImm::create(Func, Ty, Immed_8, RotateAmt);
-      } else if (CanBeFlex && OperandARM32FlexImm::canHoldImm(
-                                  ~Value, &RotateAmt, &Immed_8)) {
-        auto InvertedFlex =
+      if (OperandARM32FlexImm::canHoldImm(Value, &RotateAmt, &Immed_8)) {
+        // The immediate can be encoded as a Flex immediate. We may return the
+        // Flex operand if the caller has Allow'ed it.
+        auto *OpF = OperandARM32FlexImm::create(Func, Ty, Immed_8, RotateAmt);
+        const bool CanBeFlex = Allowed & Legal_Flex;
+        if (CanBeFlex)
+          return OpF;
+        return copyToReg(OpF, RegNum);
+      } else if (OperandARM32FlexImm::canHoldImm(~Value, &RotateAmt,
+                                                 &Immed_8)) {
+        // Even though the immediate can't be encoded as a Flex operand, its
+        // inverted bit pattern can, so we use ARM's mvn to load the 32-bit
+        // constant with a single instruction.
+        auto *InvOpF =
             OperandARM32FlexImm::create(Func, Ty, Immed_8, RotateAmt);
         Variable *Reg = makeReg(Ty, RegNum);
-        _mvn(Reg, InvertedFlex);
+        _mvn(Reg, InvOpF);
         return Reg;
       } else {
         // Do a movw/movt to a register.
@@ -5486,8 +5574,6 @@
         return From;
       }
 
-      // TODO(jpp): We don't need to rematerialize Var if legalize() was invoked
-      // for a Variable in a Mem operand.
       Variable *T = makeReg(Var->getType(), RegNum);
       _mov(T, Var);
       return T;
@@ -5688,7 +5774,7 @@
   // FlagsWereSet is used to determine whether Boolean was folded or not. If not,
   // add an explicit _tst instruction below.
   bool FlagsWereSet = false;
-  if (const Inst *Producer = BoolComputations.getProducerOf(Boolean)) {
+  if (const Inst *Producer = Computations.getProducerOf(Boolean)) {
     switch (Producer->getKind()) {
     default:
       llvm::report_fatal_error("Unexpected producer.");
@@ -5772,7 +5858,7 @@
   Operand *_1 = legalize(Ctx->getConstantInt1(1), Legal_Reg | Legal_Flex);
 
   SafeBoolChain Safe = SBC_Yes;
-  if (const Inst *Producer = BoolComputations.getProducerOf(Boolean)) {
+  if (const Inst *Producer = Computations.getProducerOf(Boolean)) {
     switch (Producer->getKind()) {
     default:
       llvm::report_fatal_error("Unexpected producer.");
@@ -5884,9 +5970,75 @@
   }
 }
 } // end of namespace BoolFolding
+
+namespace FpFolding {
+bool shouldTrackProducer(const Inst &Instr) {
+  switch (Instr.getKind()) {
+  default:
+    return false;
+  case Inst::Arithmetic: {
+    switch (llvm::cast<InstArithmetic>(&Instr)->getOp()) {
+    default:
+      return false;
+    case InstArithmetic::Fmul:
+      return true;
+    }
+  }
+  }
+}
+
+bool isValidConsumer(const Inst &Instr) {
+  switch (Instr.getKind()) {
+  default:
+    return false;
+  case Inst::Arithmetic: {
+    switch (llvm::cast<InstArithmetic>(&Instr)->getOp()) {
+    default:
+      return false;
+    case InstArithmetic::Fadd:
+    case InstArithmetic::Fsub:
+      return true;
+    }
+  }
+  }
+}
+} // end of namespace FpFolding
+
+namespace IntFolding {
+bool shouldTrackProducer(const Inst &Instr) {
+  switch (Instr.getKind()) {
+  default:
+    return false;
+  case Inst::Arithmetic: {
+    switch (llvm::cast<InstArithmetic>(&Instr)->getOp()) {
+    default:
+      return false;
+    case InstArithmetic::Mul:
+      return true;
+    }
+  }
+  }
+}
+
+bool isValidConsumer(const Inst &Instr) {
+  switch (Instr.getKind()) {
+  default:
+    return false;
+  case Inst::Arithmetic: {
+    switch (llvm::cast<InstArithmetic>(&Instr)->getOp()) {
+    default:
+      return false;
+    case InstArithmetic::Add:
+    case InstArithmetic::Sub:
+      return true;
+    }
+  }
+  }
+}
+} // end of namespace IntFolding
 } // end of anonymous namespace
 
-void TargetARM32::BoolComputationTracker::recordProducers(CfgNode *Node) {
+void TargetARM32::ComputationTracker::recordProducers(CfgNode *Node) {
   for (Inst &Instr : Node->getInsts()) {
     // Check whether Instr is a valid producer.
     Variable *Dest = Instr.getDest();
@@ -5894,7 +6046,22 @@
         && Dest            // only instructions with an actual dest var; and
         && Dest->getType() == IceType_i1 // only bool-type dest vars; and
         && BoolFolding::shouldTrackProducer(Instr)) { // white-listed instr.
-      KnownComputations.emplace(Dest->getIndex(), BoolComputationEntry(&Instr));
+      KnownComputations.emplace(Dest->getIndex(),
+                                ComputationEntry(&Instr, IceType_i1));
+    }
+    if (!Instr.isDeleted() // only consider non-deleted instructions; and
+        && Dest            // only instructions with an actual dest var; and
+        && isScalarFloatingType(Dest->getType()) // fp-type only dest vars; and
+        && FpFolding::shouldTrackProducer(Instr)) { // white-listed instr.
+      KnownComputations.emplace(Dest->getIndex(),
+                                ComputationEntry(&Instr, Dest->getType()));
+    }
+    if (!Instr.isDeleted() // only consider non-deleted instructions; and
+        && Dest            // only instructions with an actual dest var; and
+        && Dest->getType() == IceType_i32            // i32 only dest vars; and
+        && IntFolding::shouldTrackProducer(Instr)) { // white-listed instr.
+      KnownComputations.emplace(Dest->getIndex(),
+                                ComputationEntry(&Instr, IceType_i32));
     }
     // Check each src variable against the map.
     FOREACH_VAR_IN_INST(Var, Instr) {
@@ -5905,9 +6072,29 @@
       }
 
       ++ComputationIter->second.NumUses;
-      if (!BoolFolding::isValidConsumer(Instr)) {
+      switch (ComputationIter->second.ComputationType) {
+      default:
         KnownComputations.erase(VarNum);
         continue;
+      case IceType_i1:
+        if (!BoolFolding::isValidConsumer(Instr)) {
+          KnownComputations.erase(VarNum);
+          continue;
+        }
+        break;
+      case IceType_i32:
+        if (IndexOfVarInInst(Var) != 1 || !IntFolding::isValidConsumer(Instr)) {
+          KnownComputations.erase(VarNum);
+          continue;
+        }
+        break;
+      case IceType_f32:
+      case IceType_f64:
+        if (IndexOfVarInInst(Var) != 1 || !FpFolding::isValidConsumer(Instr)) {
+          KnownComputations.erase(VarNum);
+          continue;
+        }
+        break;
       }
 
       if (Instr.isLastUse(Var)) {