LLVM API Documentation

LowerSwitch.cpp

Go to the documentation of this file.
00001 //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
00002 //
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file was developed by the LLVM research group and is distributed under
00006 // the University of Illinois Open Source License. See LICENSE.TXT for details.
00007 //
00008 //===----------------------------------------------------------------------===//
00009 //
00010 // The LowerSwitch transformation rewrites switch statements with a sequence of
00011 // branches, which allows targets to get away with not implementing the switch
00012 // statement until it is convenient.
00013 //
00014 //===----------------------------------------------------------------------===//
00015 
00016 #include "llvm/Transforms/Scalar.h"
00017 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
00018 #include "llvm/Constants.h"
00019 #include "llvm/Function.h"
00020 #include "llvm/Instructions.h"
00021 #include "llvm/Pass.h"
00022 #include "llvm/Support/Debug.h"
00023 #include "llvm/Support/Visibility.h"
00024 #include "llvm/ADT/Statistic.h"
00025 #include <algorithm>
00026 #include <iostream>
00027 using namespace llvm;
00028 
00029 namespace {
00030   Statistic<> NumLowered("lowerswitch", "Number of SwitchInst's replaced");
00031 
00032   /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
00033   /// instructions.  Note that this cannot be a BasicBlock pass because it
00034   /// modifies the CFG!
00035   class VISIBILITY_HIDDEN LowerSwitch : public FunctionPass {
00036   public:
00037     virtual bool runOnFunction(Function &F);
00038     
00039     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
00040       // This is a cluster of orthogonal Transforms 
00041       AU.addPreserved<UnifyFunctionExitNodes>();
00042       AU.addPreservedID(PromoteMemoryToRegisterID);
00043       AU.addPreservedID(LowerSelectID);
00044       AU.addPreservedID(LowerInvokePassID);
00045       AU.addPreservedID(LowerAllocationsID);
00046     }
00047         
00048     typedef std::pair<Constant*, BasicBlock*> Case;
00049     typedef std::vector<Case>::iterator       CaseItr;
00050   private:
00051     void processSwitchInst(SwitchInst *SI);
00052 
00053     BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val,
00054                               BasicBlock* OrigBlock, BasicBlock* Default);
00055     BasicBlock* newLeafBlock(Case& Leaf, Value* Val,
00056                              BasicBlock* OrigBlock, BasicBlock* Default);
00057   };
00058 
00059   /// The comparison function for sorting the switch case values in the vector.
00060   struct CaseCmp {
00061     bool operator () (const LowerSwitch::Case& C1,
00062                       const LowerSwitch::Case& C2) {
00063       if (const ConstantUInt* U1 = dyn_cast<const ConstantUInt>(C1.first))
00064         return U1->getValue() < cast<const ConstantUInt>(C2.first)->getValue();
00065 
00066       const ConstantSInt* S1 = dyn_cast<const ConstantSInt>(C1.first);
00067       return S1->getValue() < cast<const ConstantSInt>(C2.first)->getValue();
00068     }
00069   };
00070 
00071   RegisterOpt<LowerSwitch>
00072   X("lowerswitch", "Lower SwitchInst's to branches");
00073 }
00074 
00075 // Publically exposed interface to pass...
00076 const PassInfo *llvm::LowerSwitchID = X.getPassInfo();
00077 // createLowerSwitchPass - Interface to this file...
00078 FunctionPass *llvm::createLowerSwitchPass() {
00079   return new LowerSwitch();
00080 }
00081 
00082 bool LowerSwitch::runOnFunction(Function &F) {
00083   bool Changed = false;
00084 
00085   for (Function::iterator I = F.begin(), E = F.end(); I != E; ) {
00086     BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks
00087 
00088     if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) {
00089       Changed = true;
00090       processSwitchInst(SI);
00091     }
00092   }
00093 
00094   return Changed;
00095 }
00096 
00097 // operator<< - Used for debugging purposes.
00098 //
00099 std::ostream& operator<<(std::ostream &O,
00100                          const std::vector<LowerSwitch::Case> &C) {
00101   O << "[";
00102 
00103   for (std::vector<LowerSwitch::Case>::const_iterator B = C.begin(),
00104          E = C.end(); B != E; ) {
00105     O << *B->first;
00106     if (++B != E) O << ", ";
00107   }
00108 
00109   return O << "]";
00110 }
00111 
00112 // switchConvert - Convert the switch statement into a binary lookup of
00113 // the case values. The function recursively builds this tree.
00114 //
00115 BasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End,
00116                                        Value* Val, BasicBlock* OrigBlock,
00117                                        BasicBlock* Default)
00118 {
00119   unsigned Size = End - Begin;
00120 
00121   if (Size == 1)
00122     return newLeafBlock(*Begin, Val, OrigBlock, Default);
00123 
00124   unsigned Mid = Size / 2;
00125   std::vector<Case> LHS(Begin, Begin + Mid);
00126   DEBUG(std::cerr << "LHS: " << LHS << "\n");
00127   std::vector<Case> RHS(Begin + Mid, End);
00128   DEBUG(std::cerr << "RHS: " << RHS << "\n");
00129 
00130   Case& Pivot = *(Begin + Mid);
00131   DEBUG(std::cerr << "Pivot ==> "
00132                   << (int64_t)cast<ConstantInt>(Pivot.first)->getRawValue()
00133                   << "\n");
00134 
00135   BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val,
00136                                       OrigBlock, Default);
00137   BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val,
00138                                       OrigBlock, Default);
00139 
00140   // Create a new node that checks if the value is < pivot. Go to the
00141   // left branch if it is and right branch if not.
00142   Function* F = OrigBlock->getParent();
00143   BasicBlock* NewNode = new BasicBlock("NodeBlock");
00144   F->getBasicBlockList().insert(OrigBlock->getNext(), NewNode);
00145 
00146   SetCondInst* Comp = new SetCondInst(Instruction::SetLT, Val, Pivot.first,
00147                                       "Pivot");
00148   NewNode->getInstList().push_back(Comp);
00149   new BranchInst(LBranch, RBranch, Comp, NewNode);
00150   return NewNode;
00151 }
00152 
00153 // newLeafBlock - Create a new leaf block for the binary lookup tree. It
00154 // checks if the switch's value == the case's value. If not, then it
00155 // jumps to the default branch. At this point in the tree, the value
00156 // can't be another valid case value, so the jump to the "default" branch
00157 // is warranted.
00158 //
00159 BasicBlock* LowerSwitch::newLeafBlock(Case& Leaf, Value* Val,
00160                                       BasicBlock* OrigBlock,
00161                                       BasicBlock* Default)
00162 {
00163   Function* F = OrigBlock->getParent();
00164   BasicBlock* NewLeaf = new BasicBlock("LeafBlock");
00165   F->getBasicBlockList().insert(OrigBlock->getNext(), NewLeaf);
00166 
00167   // Make the seteq instruction...
00168   SetCondInst* Comp = new SetCondInst(Instruction::SetEQ, Val,
00169                                       Leaf.first, "SwitchLeaf");
00170   NewLeaf->getInstList().push_back(Comp);
00171 
00172   // Make the conditional branch...
00173   BasicBlock* Succ = Leaf.second;
00174   new BranchInst(Succ, Default, Comp, NewLeaf);
00175 
00176   // If there were any PHI nodes in this successor, rewrite one entry
00177   // from OrigBlock to come from NewLeaf.
00178   for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
00179     PHINode* PN = cast<PHINode>(I);
00180     int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
00181     assert(BlockIdx != -1 && "Switch didn't go to this successor??");
00182     PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf);
00183   }
00184 
00185   return NewLeaf;
00186 }
00187 
00188 // processSwitchInst - Replace the specified switch instruction with a sequence
00189 // of chained if-then insts in a balanced binary search.
00190 //
00191 void LowerSwitch::processSwitchInst(SwitchInst *SI) {
00192   BasicBlock *CurBlock = SI->getParent();
00193   BasicBlock *OrigBlock = CurBlock;
00194   Function *F = CurBlock->getParent();
00195   Value *Val = SI->getOperand(0);  // The value we are switching on...
00196   BasicBlock* Default = SI->getDefaultDest();
00197 
00198   // If there is only the default destination, don't bother with the code below.
00199   if (SI->getNumOperands() == 2) {
00200     new BranchInst(SI->getDefaultDest(), CurBlock);
00201     CurBlock->getInstList().erase(SI);
00202     return;
00203   }
00204 
00205   // Create a new, empty default block so that the new hierarchy of
00206   // if-then statements go to this and the PHI nodes are happy.
00207   BasicBlock* NewDefault = new BasicBlock("NewDefault");
00208   F->getBasicBlockList().insert(Default, NewDefault);
00209 
00210   new BranchInst(Default, NewDefault);
00211 
00212   // If there is an entry in any PHI nodes for the default edge, make sure
00213   // to update them as well.
00214   for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) {
00215     PHINode *PN = cast<PHINode>(I);
00216     int BlockIdx = PN->getBasicBlockIndex(OrigBlock);
00217     assert(BlockIdx != -1 && "Switch didn't go to this successor??");
00218     PN->setIncomingBlock((unsigned)BlockIdx, NewDefault);
00219   }
00220 
00221   std::vector<Case> Cases;
00222 
00223   // Expand comparisons for all of the non-default cases...
00224   for (unsigned i = 1; i < SI->getNumSuccessors(); ++i)
00225     Cases.push_back(Case(SI->getSuccessorValue(i), SI->getSuccessor(i)));
00226 
00227   std::sort(Cases.begin(), Cases.end(), CaseCmp());
00228   DEBUG(std::cerr << "Cases: " << Cases << "\n");
00229   BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val,
00230                                           OrigBlock, NewDefault);
00231 
00232   // Branch to our shiny new if-then stuff...
00233   new BranchInst(SwitchBlock, OrigBlock);
00234 
00235   // We are now done with the switch instruction, delete it.
00236   CurBlock->getInstList().erase(SI);
00237 }