LLVM API Documentation

UnifyFunctionExitNodes.cpp

Go to the documentation of this file.
00001 //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===//
00002 //
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file was developed by the LLVM research group and is distributed under
00006 // the University of Illinois Open Source License. See LICENSE.TXT for details.
00007 //
00008 //===----------------------------------------------------------------------===//
00009 //
// This pass is used to ensure that functions have at most one return
// instruction in them.  It likewise merges multiple 'unwind' blocks into a
// single unwind block and multiple 'unreachable' blocks into a single
// unreachable block.  The pass records the resulting return, unwind, and
// unreachable blocks; each recorded block is a null pointer when the function
// contains no block of that kind.
00014 //
00015 //===----------------------------------------------------------------------===//
00016 
00017 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
00018 #include "llvm/Transforms/Scalar.h"
00019 #include "llvm/BasicBlock.h"
00020 #include "llvm/Function.h"
00021 #include "llvm/Instructions.h"
00022 #include "llvm/Type.h"
00023 using namespace llvm;
00024 
00025 static RegisterOpt<UnifyFunctionExitNodes>
00026 X("mergereturn", "Unify function exit nodes");
00027 
00028 Pass *llvm::createUnifyFunctionExitNodesPass() {
00029   return new UnifyFunctionExitNodes();
00030 }
00031 
00032 void UnifyFunctionExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
00033   // We preserve the non-critical-edgeness property
00034   AU.addPreservedID(BreakCriticalEdgesID);
00035 }
00036 
00037 // UnifyAllExitNodes - Unify all exit nodes of the CFG by creating a new
00038 // BasicBlock, and converting all returns to unconditional branches to this
00039 // new basic block.  The singular exit node is returned.
00040 //
00041 // If there are no return stmts in the Function, a null pointer is returned.
00042 //
00043 bool UnifyFunctionExitNodes::runOnFunction(Function &F) {
00044   // Loop over all of the blocks in a function, tracking all of the blocks that
00045   // return.
00046   //
00047   std::vector<BasicBlock*> ReturningBlocks;
00048   std::vector<BasicBlock*> UnwindingBlocks;
00049   std::vector<BasicBlock*> UnreachableBlocks;
00050   for(Function::iterator I = F.begin(), E = F.end(); I != E; ++I)
00051     if (isa<ReturnInst>(I->getTerminator()))
00052       ReturningBlocks.push_back(I);
00053     else if (isa<UnwindInst>(I->getTerminator()))
00054       UnwindingBlocks.push_back(I);
00055     else if (isa<UnreachableInst>(I->getTerminator()))
00056       UnreachableBlocks.push_back(I);
00057 
00058   // Handle unwinding blocks first.
00059   if (UnwindingBlocks.empty()) {
00060     UnwindBlock = 0;
00061   } else if (UnwindingBlocks.size() == 1) {
00062     UnwindBlock = UnwindingBlocks.front();
00063   } else {
00064     UnwindBlock = new BasicBlock("UnifiedUnwindBlock", &F);
00065     new UnwindInst(UnwindBlock);
00066 
00067     for (std::vector<BasicBlock*>::iterator I = UnwindingBlocks.begin(),
00068            E = UnwindingBlocks.end(); I != E; ++I) {
00069       BasicBlock *BB = *I;
00070       BB->getInstList().pop_back();  // Remove the unwind insn
00071       new BranchInst(UnwindBlock, BB);
00072     }
00073   }
00074 
00075   // Then unreachable blocks.
00076   if (UnreachableBlocks.empty()) {
00077     UnreachableBlock = 0;
00078   } else if (UnreachableBlocks.size() == 1) {
00079     UnreachableBlock = UnreachableBlocks.front();
00080   } else {
00081     UnreachableBlock = new BasicBlock("UnifiedUnreachableBlock", &F);
00082     new UnreachableInst(UnreachableBlock);
00083 
00084     for (std::vector<BasicBlock*>::iterator I = UnreachableBlocks.begin(),
00085            E = UnreachableBlocks.end(); I != E; ++I) {
00086       BasicBlock *BB = *I;
00087       BB->getInstList().pop_back();  // Remove the unreachable inst.
00088       new BranchInst(UnreachableBlock, BB);
00089     }
00090   }
00091 
00092   // Now handle return blocks.
00093   if (ReturningBlocks.empty()) {
00094     ReturnBlock = 0;
00095     return false;                          // No blocks return
00096   } else if (ReturningBlocks.size() == 1) {
00097     ReturnBlock = ReturningBlocks.front(); // Already has a single return block
00098     return false;
00099   }
00100 
00101   // Otherwise, we need to insert a new basic block into the function, add a PHI
00102   // node (if the function returns a value), and convert all of the return
00103   // instructions into unconditional branches.
00104   //
00105   BasicBlock *NewRetBlock = new BasicBlock("UnifiedReturnBlock", &F);
00106 
00107   PHINode *PN = 0;
00108   if (F.getReturnType() != Type::VoidTy) {
00109     // If the function doesn't return void... add a PHI node to the block...
00110     PN = new PHINode(F.getReturnType(), "UnifiedRetVal");
00111     NewRetBlock->getInstList().push_back(PN);
00112   }
00113   new ReturnInst(PN, NewRetBlock);
00114 
00115   // Loop over all of the blocks, replacing the return instruction with an
00116   // unconditional branch.
00117   //
00118   for (std::vector<BasicBlock*>::iterator I = ReturningBlocks.begin(),
00119          E = ReturningBlocks.end(); I != E; ++I) {
00120     BasicBlock *BB = *I;
00121 
00122     // Add an incoming element to the PHI node for every return instruction that
00123     // is merging into this new block...
00124     if (PN) PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
00125 
00126     BB->getInstList().pop_back();  // Remove the return insn
00127     new BranchInst(NewRetBlock, BB);
00128   }
00129   ReturnBlock = NewRetBlock;
00130   return true;
00131 }