LLVM API Documentation

Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

Parallelize.cpp

Go to the documentation of this file.
00001 //===- Parallelize.cpp - Auto parallelization using DS Graphs -------------===//
00002 // 
00003 //                     The LLVM Compiler Infrastructure
00004 //
00005 // This file was developed by the LLVM research group and is distributed under
00006 // the University of Illinois Open Source License. See LICENSE.TXT for details.
00007 // 
00008 //===----------------------------------------------------------------------===//
00009 //
00010 // This file implements a pass that automatically parallelizes a program,
00011 // using the Cilk multi-threaded runtime system to execute parallel code.
00012 // 
00013 // The pass uses the Program Dependence Graph (class PDGIterator) to
00014 // identify parallelizable function calls, i.e., calls whose instances
00015 // can be executed in parallel with instances of other function calls.
00016 // (In the future, this should also execute different instances of the same
00017 // function call in parallel, but that requires parallelizing across
00018 // loop iterations.)
00019 //
00020 // The output of the pass is LLVM code with:
00021 // (1) all parallelizable functions renamed to flag them as parallelizable;
00022 // (2) calls to a sync() function introduced at synchronization points.
00023 // The CWriter recognizes these functions and inserts the appropriate Cilk
00024 // keywords when writing out C code.  This C code must be compiled with cilk2c.
00025 // 
00026 // Current algorithmic limitations:
00027 // -- no array dependence analysis
00028 // -- no parallelization for function calls in different loop iterations
00029 //    (except in unlikely trivial cases)
00030 //
00031 // Limitations of using Cilk:
00032 // -- No parallelism within a function body, e.g., in a loop;
00033 // -- Simplistic synchronization model requiring all parallel threads 
00034 //    created within a function to block at a sync().
00035 // -- Excessive overhead at "spawned" function calls, which has no benefit
00036 //    once all threads are busy (especially common when the degree of
00037 //    parallelism is low).
00038 //
00039 //===----------------------------------------------------------------------===//
00040 
#include "llvm/DerivedTypes.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "PgmDependenceGraph.h"
#include "llvm/Analysis/DataStructure/DataStructure.h"
#include "llvm/Analysis/DataStructure/DSGraph.h"
#include "llvm/Support/InstVisitor.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/hash_set"
#include "llvm/ADT/hash_map"
#include <algorithm>
#include <cstring>
#include <functional>
00055 using namespace llvm;
00056 
00057 //---------------------------------------------------------------------------- 
00058 // Global constants used in marking Cilk functions and function calls.
00059 //---------------------------------------------------------------------------- 
00060 
// Suffix appended to a function's name to mark it (and, transitively, every
// direct call to it) as parallelizable for the Cilk back-end.
static const char * const CilkSuffix = ".llvm2cilk";
// Name of the dummy function whose call sites mark synchronization points;
// the code generator emits a Cilk sync at each such call.
static const char * const DummySyncFuncName = "__sync.llvm2cilk";
00063 
00064 //---------------------------------------------------------------------------- 
00065 // Routines to identify Cilk functions, calls to Cilk functions, and syncs.
00066 //---------------------------------------------------------------------------- 
00067 
00068 static bool isCilk(const Function& F) {
00069   return (F.getName().rfind(CilkSuffix) ==
00070           F.getName().size() - std::strlen(CilkSuffix));
00071 }
00072 
00073 static bool isCilkMain(const Function& F) {
00074   return F.getName() == "main" + std::string(CilkSuffix);
00075 }
00076 
00077 
00078 static bool isCilk(const CallInst& CI) {
00079   return CI.getCalledFunction() && isCilk(*CI.getCalledFunction());
00080 }
00081 
00082 static bool isSync(const CallInst& CI) { 
00083   return CI.getCalledFunction() &&
00084          CI.getCalledFunction()->getName() == DummySyncFuncName;
00085 }
00086 
00087 
00088 //---------------------------------------------------------------------------- 
00089 // class Cilkifier
00090 //
00091 // Code generation pass that transforms code to identify where Cilk keywords
00092 // should be inserted.  This relies on `llvm-dis -c' to print out the keywords.
00093 //---------------------------------------------------------------------------- 
class Cilkifier: public InstVisitor<Cilkifier> {
  // Dummy sync function created in the constructor; calls to it are inserted
  // before dependent instructions to mark synchronization points.
  Function* DummySyncFunc;

  // Data used when transforming each function.
  hash_set<const Instruction*>  stmtsVisited;    // Flags for recursive DFS
  hash_map<const CallInst*, hash_set<CallInst*> > spawnToSyncsMap;
                                          // syncs inserted for each spawn

  // Input data for the transformation.
  const hash_set<Function*>*    cilkFunctions;   // Set of parallel functions
  PgmDependenceGraph*           depGraph;        // dep graph of current func

  // Recursive forward DFS from I: inserts a sync immediately before the
  // first instruction on each CFG path that has a dependence from 'root'.
  void          DFSVisitInstr   (Instruction* I,
                                 Instruction* root,
                                 hash_set<const Instruction*>& depsOfRoot);

public:
  /*ctor*/      Cilkifier       (Module& M);

  // Transform a single function including its name, its call sites, and syncs
  // 
  void          TransformFunc   (Function* F,
                                 const hash_set<Function*>& cilkFunctions,
                                 PgmDependenceGraph&  _depGraph);

  // The visitor function that does most of the hard work, via DFSVisitInstr
  // 
  void visitCallInst(CallInst& CI);
};
00122 
00123 
// Cilkifier ctor - Create the dummy sync function (__sync.llvm2cilk) and
// insert it into the module so sync call sites can be generated later.
Cilkifier::Cilkifier(Module& M) {
  // create the dummy Sync function and add it to the Module
  // (void return, no fixed parameters; the trailing 0 presumably terminates
  //  the vararg argument-type list -- TODO confirm against this LLVM
  //  version's getOrInsertFunction signature)
  DummySyncFunc = M.getOrInsertFunction(DummySyncFuncName, Type::VoidTy, 0);
}
00128 
// TransformFunc - Transform one parallelizable function: memoize the set of
// parallel functions and the dependence graph, rename F with the Cilk marker
// suffix (which also marks every direct call site), and insert sync calls
// after each spawn via the InstVisitor walk over F's instructions.
void Cilkifier::TransformFunc(Function* F,
                              const hash_set<Function*>& _cilkFunctions,
                              PgmDependenceGraph& _depGraph) {
  // Memoize the information for this function
  cilkFunctions = &_cilkFunctions;
  depGraph = &_depGraph;

  // Add the marker suffix to the Function name
  // This should automatically mark all calls to the function also!
  F->setName(F->getName() + CilkSuffix);

  // Insert sync operations for each separate spawn
  visit(*F);

  // Now traverse the CFG in rPostorder and eliminate redundant syncs, i.e.,
  // two consecutive sync's on a straight-line path with no intervening spawn.
  // NOTE(review): this elimination step is not implemented -- the function
  // currently leaves any redundant syncs in place.
}
00147 
00148 
// DFSVisitInstr - Recursive forward DFS from instruction I, looking for the
// first instruction on each CFG path that has a dependence from 'root' (the
// spawned call).  A call to the dummy sync function is created immediately
// before each such instruction and the search stops along that path;
// otherwise the walk continues to unvisited successors (the next instruction
// in the block, or each successor block's first instruction at a branch).
void Cilkifier::DFSVisitInstr(Instruction* I,
                              Instruction* root,
                              hash_set<const Instruction*>& depsOfRoot)
{
  assert(stmtsVisited.find(I) == stmtsVisited.end());
  stmtsVisited.insert(I);

  // If there is a dependence from root to I, insert Sync and return
  if (depsOfRoot.find(I) != depsOfRoot.end()) {
    // Insert a sync before I and stop searching along this path.
    // If I is a Phi instruction, the dependence can only be an SSA dep.
    // and we need to insert the sync in the predecessor on the appropriate
    // incoming edge!
    CallInst* syncI = 0;
    if (PHINode* phiI = dyn_cast<PHINode>(I)) {
      // check all operands of the Phi and insert before each one
      for (unsigned i = 0, N = phiI->getNumIncomingValues(); i < N; ++i)
        if (phiI->getIncomingValue(i) == root)
          syncI = new CallInst(DummySyncFunc, std::vector<Value*>(), "",
                               phiI->getIncomingBlock(i)->getTerminator());
    } else
      syncI = new CallInst(DummySyncFunc, std::vector<Value*>(), "", I);

    // Remember the sync for each spawn to eliminate redundant ones later
    spawnToSyncsMap[cast<CallInst>(root)].insert(syncI);

    return;
  }

  // else visit unvisited successors
  if (BranchInst* brI = dyn_cast<BranchInst>(I)) {
    // visit first instruction in each successor BB
    for (unsigned i = 0, N = brI->getNumSuccessors(); i < N; ++i)
      if (stmtsVisited.find(&brI->getSuccessor(i)->front())
          == stmtsVisited.end())
        DFSVisitInstr(&brI->getSuccessor(i)->front(), root, depsOfRoot);
  } else
    if (Instruction* nextI = I->getNext())
      if (stmtsVisited.find(nextI) == stmtsVisited.end())
        DFSVisitInstr(nextI, root, depsOfRoot);
}
00190 
00191 
00192 void Cilkifier::visitCallInst(CallInst& CI)
00193 {
00194   assert(CI.getCalledFunction() != 0 && "Only direct calls can be spawned.");
00195   if (cilkFunctions->find(CI.getCalledFunction()) == cilkFunctions->end())
00196     return;                             // not a spawn
00197 
00198   // Find all the outgoing memory dependences.
00199   hash_set<const Instruction*> depsOfRoot;
00200   for (PgmDependenceGraph::iterator DI =
00201          depGraph->outDepBegin(CI, MemoryDeps); ! DI.fini(); ++DI)
00202     depsOfRoot.insert(&DI->getSink()->getInstr());
00203 
00204   // Now find all outgoing SSA dependences to the eventual non-Phi users of
00205   // the call value (i.e., direct users that are not phis, and for any
00206   // user that is a Phi, direct non-Phi users of that Phi, and recursively).
00207   std::vector<const PHINode*> phiUsers;
00208   hash_set<const PHINode*> phisSeen;    // ensures we don't visit a phi twice
00209   for (Value::use_iterator UI=CI.use_begin(), UE=CI.use_end(); UI != UE; ++UI)
00210     if (const PHINode* phiUser = dyn_cast<PHINode>(*UI)) {
00211       if (phisSeen.find(phiUser) == phisSeen.end()) {
00212         phiUsers.push_back(phiUser);
00213         phisSeen.insert(phiUser);
00214       }
00215     }
00216     else
00217       depsOfRoot.insert(cast<Instruction>(*UI));
00218 
00219   // Now we've found the non-Phi users and immediate phi users.
00220   // Recursively walk the phi users and add their non-phi users.
00221   for (const PHINode* phiUser; !phiUsers.empty(); phiUsers.pop_back()) {
00222     phiUser = phiUsers.back();
00223     for (Value::use_const_iterator UI=phiUser->use_begin(),
00224            UE=phiUser->use_end(); UI != UE; ++UI)
00225       if (const PHINode* pn = dyn_cast<PHINode>(*UI)) {
00226         if (phisSeen.find(pn) == phisSeen.end()) {
00227           phiUsers.push_back(pn);
00228           phisSeen.insert(pn);
00229         }
00230       } else
00231         depsOfRoot.insert(cast<Instruction>(*UI));
00232   }
00233 
00234   // Walk paths of the CFG starting at the call instruction and insert
00235   // one sync before the first dependence on each path, if any.
00236   if (! depsOfRoot.empty()) {
00237     stmtsVisited.clear();             // start a new DFS for this CallInst
00238     assert(CI.getNext() && "Call instruction cannot be a terminator!");
00239     DFSVisitInstr(CI.getNext(), &CI, depsOfRoot);
00240   }
00241 
00242   // Now, eliminate all users of the SSA value of the CallInst, i.e., 
00243   // if the call instruction returns a value, delete the return value
00244   // register and replace it by a stack slot.
00245   if (CI.getType() != Type::VoidTy)
00246     DemoteRegToStack(CI);
00247 }
00248 
00249 
00250 //---------------------------------------------------------------------------- 
00251 // class FindParallelCalls
00252 //
00253 // Find all CallInst instructions that have at least one other CallInst
00254 // that is independent.  These are the instructions that can produce
00255 // useful parallelism.
00256 //---------------------------------------------------------------------------- 
00257 
class FindParallelCalls : public InstVisitor<FindParallelCalls> {
  typedef hash_set<CallInst*>           DependentsSet;
  typedef DependentsSet::iterator       Dependents_iterator;
  typedef DependentsSet::const_iterator Dependents_const_iterator;

  PgmDependenceGraph& depGraph;         // dependence graph for the function
  hash_set<Instruction*> stmtsVisited;  // flags for DFS walk of depGraph
  hash_map<CallInst*, bool > completed; // flags marking if a CI is done
  hash_map<CallInst*, DependentsSet> dependents; // dependent CIs for each CI

  // Recursive helper for visitCallInst: walks out-edges of I in the
  // dependence graph, accumulating reachable CallInsts into depsOfRoot.
  void VisitOutEdges(Instruction*   I,
                     CallInst*      root,
                     DependentsSet& depsOfRoot);

  FindParallelCalls(const FindParallelCalls &); // DO NOT IMPLEMENT
  void operator=(const FindParallelCalls&);     // DO NOT IMPLEMENT
public:
  // Calls found to be independent of at least one other call, and hence
  // worth spawning; filled in by the constructor.
  std::vector<CallInst*> parallelCalls;

public:
  /*ctor*/      FindParallelCalls       (Function& F, PgmDependenceGraph& DG);
  void          visitCallInst           (CallInst& CI);
};
00281 
00282 
// FindParallelCalls ctor - Compute, for every CallInst in F, the set of
// other CallInsts reachable from it in the dependence graph, then record in
// parallelCalls those calls that are independent of at least one other call.
FindParallelCalls::FindParallelCalls(Function& F,
                                     PgmDependenceGraph& DG)
  : depGraph(DG)
{
  // Find all CallInsts reachable from each CallInst using a recursive DFS
  visit(F);

  // Now we've found all CallInsts reachable from each CallInst.
  // Find those CallInsts that are parallel with at least one other CallInst
  // by counting total inEdges and outEdges.
  unsigned long totalNumCalls = completed.size();

  if (totalNumCalls == 1) {
    // Check first for the special case of a single call instruction not
    // in any loop.  It is not parallel, even if it has no dependences
    // (this is why it is a special case).
    //
    // FIXME:
    // THIS CASE IS NOT HANDLED RIGHT NOW, I.E., THERE IS NO
    // PARALLELISM FOR CALLS IN DIFFERENT ITERATIONS OF A LOOP.
    return;
  }

  // numDeps[C] counts every dependence incident on call C: the calls
  // reachable from C (out-edges) plus each call that can reach C (in-edges).
  hash_map<CallInst*, unsigned long> numDeps;
  for (hash_map<CallInst*, DependentsSet>::iterator II = dependents.begin(),
         IE = dependents.end(); II != IE; ++II) {
    CallInst* fromCI = II->first;
    numDeps[fromCI] += II->second.size();
    for (Dependents_iterator DI = II->second.begin(), DE = II->second.end();
         DI != DE; ++DI)
      numDeps[*DI]++;                 // *DI can be reached from II->first
  }

  // A call is parallel if it is NOT ordered against every one of the other
  // totalNumCalls - 1 calls.
  for (hash_map<CallInst*, DependentsSet>::iterator
         II = dependents.begin(), IE = dependents.end(); II != IE; ++II)

    // FIXME: Remove "- 1" when considering parallelism in loops
    if (numDeps[II->first] < totalNumCalls - 1)
      parallelCalls.push_back(II->first);
}
00323 
00324 
// VisitOutEdges - Recursive DFS over the dependence graph starting at I.
// Adds to depsOfRoot every CallInst (other than root itself) reachable from
// I, reusing the memoized result for any call whose own DFS has already
// completed instead of re-walking its subgraph.
void FindParallelCalls::VisitOutEdges(Instruction* I,
                                      CallInst* root,
                                      DependentsSet& depsOfRoot)
{
  assert(stmtsVisited.find(I) == stmtsVisited.end() && "Stmt visited twice?");
  stmtsVisited.insert(I);

  if (CallInst* CI = dyn_cast<CallInst>(I))
    // FIXME: Ignoring parallelism in a loop.  Here we're actually *ignoring*
    // a self-dependence in order to get the count comparison right above.
    // When we include loop parallelism, self-dependences should be included.
    if (CI != root) {
      // CallInst root has a path to CallInst I and any calls reachable from I
      depsOfRoot.insert(CI);
      if (completed[CI]) {
        // We have already visited I so we know all nodes it can reach!
        DependentsSet& depsOfI = dependents[CI];
        depsOfRoot.insert(depsOfI.begin(), depsOfI.end());
        return;
      }
    }

  // If we reach here, we need to visit all children of I
  for (PgmDependenceGraph::iterator DI = depGraph.outDepBegin(*I);
       ! DI.fini(); ++DI) {
    Instruction* sink = &DI->getSink()->getInstr();
    if (stmtsVisited.find(sink) == stmtsVisited.end())
      VisitOutEdges(sink, root, depsOfRoot);
  }
}
00355 
00356 
00357 void FindParallelCalls::visitCallInst(CallInst& CI) {
00358   if (completed[&CI])
00359     return;
00360   stmtsVisited.clear();                      // clear flags to do a fresh DFS
00361 
00362   // Visit all children of CI using a recursive walk through dep graph
00363   DependentsSet& depsOfRoot = dependents[&CI];
00364   for (PgmDependenceGraph::iterator DI = depGraph.outDepBegin(CI);
00365        ! DI.fini(); ++DI) {
00366     Instruction* sink = &DI->getSink()->getInstr();
00367     if (stmtsVisited.find(sink) == stmtsVisited.end())
00368       VisitOutEdges(sink, &CI, depsOfRoot);
00369   }
00370 
00371   completed[&CI] = true;
00372 }
00373 
00374 
00375 //---------------------------------------------------------------------------- 
00376 // class Parallelize
00377 //
00378 // (1) Find candidate parallel functions: any function F s.t.
00379 //       there is a call C1 to the function F that is followed or preceded
00380 //       by at least one other call C2 that is independent of this one
00381 //       (i.e., there is no dependence path from C1 to C2 or C2 to C1)
00382 // (2) Label such a function F as a cilk function.
00383 // (3) Convert every call to F to a spawn
00384 // (4) For every function X, insert sync statements so that
00385 //        every spawn is postdominated by a sync before any statements
00386 //        with a data dependence to/from the call site for the spawn
00387 // 
00388 //---------------------------------------------------------------------------- 
00389 
namespace {
  /// Parallelize - Module pass driving the whole transformation: finds
  /// candidate parallel functions, renames them (and so their call sites)
  /// with the Cilk suffix, and inserts sync calls.  Registered below under
  /// the "parallel" option name.
  class Parallelize : public ModulePass {
  public:
    /// Driver functions to transform a program
    ///
    bool runOnModule(Module& M);

    /// getAnalysisUsage - Modifies extensively so preserve nothing.
    /// Uses the DependenceGraph and the Top-down DS Graph (only to find
    /// all functions called via an indirect call).
    ///
    void getAnalysisUsage(AnalysisUsage &AU) const {
      AU.addRequired<TDDataStructures>();
      AU.addRequired<MemoryDepAnalysis>();  // force this not to be released
      AU.addRequired<PgmDependenceGraph>(); // because it is needed by this
    }
  };

  RegisterOpt<Parallelize> X("parallel", "Parallelize program using Cilk");
}
00410 
00411 
// runOnModule - Top-level driver; returns true iff the module was changed.
// Steps: (1) run FindParallelCalls on every defined function to collect
// callees worth spawning; (2) drop any candidate that may also be called
// indirectly (per the top-down DS graph); (3) rename the survivors (plus
// main, which Cilk requires) and insert syncs via the Cilkifier.
bool Parallelize::runOnModule(Module& M) {
  hash_set<Function*> parallelFunctions;
  hash_set<Function*> safeParallelFunctions;
  hash_set<const GlobalValue*> indirectlyCalled;

  // If there is no main (i.e., for an incomplete program), we can do nothing.
  // If there is a main, mark main as a parallel function.
  Function* mainFunc = M.getMainFunction();
  if (!mainFunc)
    return false;

  // (1) Find candidate parallel functions and mark them as Cilk functions
  for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI)
    if (! FI->isExternal()) {
      Function* F = FI;
      DSGraph& tdg = getAnalysis<TDDataStructures>().getDSGraph(*F);

      // All the hard analysis work gets done here!
      FindParallelCalls finder(*F,
                              getAnalysis<PgmDependenceGraph>().getGraph(*F));
                      /* getAnalysis<MemoryDepAnalysis>().getGraph(*F)); */

      // Now we know which call instructions are useful to parallelize.
      // Remember those callee functions.
      for (std::vector<CallInst*>::iterator
             CII = finder.parallelCalls.begin(),
             CIE = finder.parallelCalls.end(); CII != CIE; ++CII) {
        // Check if this is a direct call...
        if ((*CII)->getCalledFunction() != NULL) {
          // direct call: if this is to a non-external function,
          // mark it as a parallelizable function
          if (! (*CII)->getCalledFunction()->isExternal())
            parallelFunctions.insert((*CII)->getCalledFunction());
        } else {
          // Indirect call: mark all potential callees as bad
          std::vector<GlobalValue*> callees =
            tdg.getNodeForValue((*CII)->getCalledValue())
            .getNode()->getGlobals();
          indirectlyCalled.insert(callees.begin(), callees.end());
        }
      }
    }

  // Remove all indirectly called functions from the list of Cilk functions.
  for (hash_set<Function*>::iterator PFI = parallelFunctions.begin(),
         PFE = parallelFunctions.end(); PFI != PFE; ++PFI)
    if (indirectlyCalled.count(*PFI) == 0)
      safeParallelFunctions.insert(*PFI);

  // NOTE(review): the #undef immediately above the #ifdef guarantees the
  // region below is never compiled; it is kept only as documentation of the
  // alternative in-place erase approach to the filtering loop above.
#undef CAN_USE_BIND1ST_ON_REFERENCE_TYPE_ARGS
#ifdef CAN_USE_BIND1ST_ON_REFERENCE_TYPE_ARGS
  // Use this indecipherable STLese because erase invalidates iterators.
  // Otherwise we have to copy sets as above.
  hash_set<Function*>::iterator extrasBegin = 
    std::remove_if(parallelFunctions.begin(), parallelFunctions.end(),
                   compose1(std::bind2nd(std::greater<int>(), 0),
                            bind_obj(&indirectlyCalled,
                                     &hash_set<const GlobalValue*>::count)));
  parallelFunctions.erase(extrasBegin, parallelFunctions.end());
#endif

  // If there are no parallel functions, we can just give up.
  if (safeParallelFunctions.empty())
    return false;

  // Add main as a parallel function since Cilk requires this.
  safeParallelFunctions.insert(mainFunc);

  // (2,3) Transform each Cilk function and all its calls simply by
  //     adding a unique suffix to the function name.
  //     This should identify both functions and calls to such functions
  //     to the code generator.
  // (4) Also, insert calls to sync at appropriate points.
  Cilkifier cilkifier(M);
  for (hash_set<Function*>::iterator CFI = safeParallelFunctions.begin(),
         CFE = safeParallelFunctions.end(); CFI != CFE; ++CFI) {
    cilkifier.TransformFunc(*CFI, safeParallelFunctions,
                           getAnalysis<PgmDependenceGraph>().getGraph(**CFI));
    /* getAnalysis<MemoryDepAnalysis>().getGraph(**CFI)); */
  }

  return true;
}
00495