// Source text extracted from the LLVM API documentation (Doxygen) pages.
//===- Parallelize.cpp - Auto parallelization using DS Graphs -------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file was developed by the LLVM research group and is distributed under
// the University of Illinois Open Source License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that automatically parallelizes a program,
// using the Cilk multi-threaded runtime system to execute parallel code.
//
// The pass uses the Program Dependence Graph (class PDGIterator) to
// identify parallelizable function calls, i.e., calls whose instances
// can be executed in parallel with instances of other function calls.
// (In the future, this should also execute different instances of the same
// function call in parallel, but that requires parallelizing across
// loop iterations.)
//
// The output of the pass is LLVM code with:
// (1) all parallelizable functions renamed to flag them as parallelizable;
// (2) calls to a sync() function introduced at synchronization points.
// The CWriter recognizes these functions and inserts the appropriate Cilk
// keywords when writing out C code.  This C code must be compiled with cilk2c.
//
// Current algorithmic limitations:
// -- no array dependence analysis
// -- no parallelization for function calls in different loop iterations
//    (except in unlikely trivial cases)
//
// Limitations of using Cilk:
// -- No parallelism within a function body, e.g., in a loop;
// -- Simplistic synchronization model requiring all parallel threads
//    created within a function to block at a sync().
// -- Excessive overhead at "spawned" function calls, which has no benefit
//    once all threads are busy (especially common when the degree of
//    parallelism is low).
//
//===----------------------------------------------------------------------===//

#include "llvm/DerivedTypes.h"
#include "llvm/Instructions.h"
#include "llvm/Module.h"
#include "PgmDependenceGraph.h"
#include "llvm/Analysis/DataStructure/DataStructure.h"
#include "llvm/Analysis/DataStructure/DSGraph.h"
#include "llvm/Support/InstVisitor.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/hash_set"
#include "llvm/ADT/hash_map"
#include <functional>
#include <algorithm>
using namespace llvm;

//----------------------------------------------------------------------------
// Global constants used in marking Cilk functions and function calls.
//----------------------------------------------------------------------------

// Suffix appended to the name of every function chosen for parallelization;
// the C back-end recognizes this marker and emits the Cilk keywords.
static const char * const CilkSuffix = ".llvm2cilk";

// Name of the dummy external function whose call sites mark Cilk sync points.
static const char * const DummySyncFuncName = "__sync.llvm2cilk";

//----------------------------------------------------------------------------
// Routines to identify Cilk functions, calls to Cilk functions, and syncs.
00066 //---------------------------------------------------------------------------- 00067 00068 static bool isCilk(const Function& F) { 00069 return (F.getName().rfind(CilkSuffix) == 00070 F.getName().size() - std::strlen(CilkSuffix)); 00071 } 00072 00073 static bool isCilkMain(const Function& F) { 00074 return F.getName() == "main" + std::string(CilkSuffix); 00075 } 00076 00077 00078 static bool isCilk(const CallInst& CI) { 00079 return CI.getCalledFunction() && isCilk(*CI.getCalledFunction()); 00080 } 00081 00082 static bool isSync(const CallInst& CI) { 00083 return CI.getCalledFunction() && 00084 CI.getCalledFunction()->getName() == DummySyncFuncName; 00085 } 00086 00087 00088 //---------------------------------------------------------------------------- 00089 // class Cilkifier 00090 // 00091 // Code generation pass that transforms code to identify where Cilk keywords 00092 // should be inserted. This relies on `llvm-dis -c' to print out the keywords. 00093 //---------------------------------------------------------------------------- 00094 class Cilkifier: public InstVisitor<Cilkifier> { 00095 Function* DummySyncFunc; 00096 00097 // Data used when transforming each function. 00098 hash_set<const Instruction*> stmtsVisited; // Flags for recursive DFS 00099 hash_map<const CallInst*, hash_set<CallInst*> > spawnToSyncsMap; 00100 00101 // Input data for the transformation. 
00102 const hash_set<Function*>* cilkFunctions; // Set of parallel functions 00103 PgmDependenceGraph* depGraph; 00104 00105 void DFSVisitInstr (Instruction* I, 00106 Instruction* root, 00107 hash_set<const Instruction*>& depsOfRoot); 00108 00109 public: 00110 /*ctor*/ Cilkifier (Module& M); 00111 00112 // Transform a single function including its name, its call sites, and syncs 00113 // 00114 void TransformFunc (Function* F, 00115 const hash_set<Function*>& cilkFunctions, 00116 PgmDependenceGraph& _depGraph); 00117 00118 // The visitor function that does most of the hard work, via DFSVisitInstr 00119 // 00120 void visitCallInst(CallInst& CI); 00121 }; 00122 00123 00124 Cilkifier::Cilkifier(Module& M) { 00125 // create the dummy Sync function and add it to the Module 00126 DummySyncFunc = M.getOrInsertFunction(DummySyncFuncName, Type::VoidTy, 0); 00127 } 00128 00129 void Cilkifier::TransformFunc(Function* F, 00130 const hash_set<Function*>& _cilkFunctions, 00131 PgmDependenceGraph& _depGraph) { 00132 // Memoize the information for this function 00133 cilkFunctions = &_cilkFunctions; 00134 depGraph = &_depGraph; 00135 00136 // Add the marker suffix to the Function name 00137 // This should automatically mark all calls to the function also! 00138 F->setName(F->getName() + CilkSuffix); 00139 00140 // Insert sync operations for each separate spawn 00141 visit(*F); 00142 00143 // Now traverse the CFG in rPostorder and eliminate redundant syncs, i.e., 00144 // two consecutive sync's on a straight-line path with no intervening spawn. 00145 00146 } 00147 00148 00149 void Cilkifier::DFSVisitInstr(Instruction* I, 00150 Instruction* root, 00151 hash_set<const Instruction*>& depsOfRoot) 00152 { 00153 assert(stmtsVisited.find(I) == stmtsVisited.end()); 00154 stmtsVisited.insert(I); 00155 00156 // If there is a dependence from root to I, insert Sync and return 00157 if (depsOfRoot.find(I) != depsOfRoot.end()) { 00158 // Insert a sync before I and stop searching along this path. 
00159 // If I is a Phi instruction, the dependence can only be an SSA dep. 00160 // and we need to insert the sync in the predecessor on the appropriate 00161 // incoming edge! 00162 CallInst* syncI = 0; 00163 if (PHINode* phiI = dyn_cast<PHINode>(I)) { 00164 // check all operands of the Phi and insert before each one 00165 for (unsigned i = 0, N = phiI->getNumIncomingValues(); i < N; ++i) 00166 if (phiI->getIncomingValue(i) == root) 00167 syncI = new CallInst(DummySyncFunc, std::vector<Value*>(), "", 00168 phiI->getIncomingBlock(i)->getTerminator()); 00169 } else 00170 syncI = new CallInst(DummySyncFunc, std::vector<Value*>(), "", I); 00171 00172 // Remember the sync for each spawn to eliminate redundant ones later 00173 spawnToSyncsMap[cast<CallInst>(root)].insert(syncI); 00174 00175 return; 00176 } 00177 00178 // else visit unvisited successors 00179 if (BranchInst* brI = dyn_cast<BranchInst>(I)) { 00180 // visit first instruction in each successor BB 00181 for (unsigned i = 0, N = brI->getNumSuccessors(); i < N; ++i) 00182 if (stmtsVisited.find(&brI->getSuccessor(i)->front()) 00183 == stmtsVisited.end()) 00184 DFSVisitInstr(&brI->getSuccessor(i)->front(), root, depsOfRoot); 00185 } else 00186 if (Instruction* nextI = I->getNext()) 00187 if (stmtsVisited.find(nextI) == stmtsVisited.end()) 00188 DFSVisitInstr(nextI, root, depsOfRoot); 00189 } 00190 00191 00192 void Cilkifier::visitCallInst(CallInst& CI) 00193 { 00194 assert(CI.getCalledFunction() != 0 && "Only direct calls can be spawned."); 00195 if (cilkFunctions->find(CI.getCalledFunction()) == cilkFunctions->end()) 00196 return; // not a spawn 00197 00198 // Find all the outgoing memory dependences. 00199 hash_set<const Instruction*> depsOfRoot; 00200 for (PgmDependenceGraph::iterator DI = 00201 depGraph->outDepBegin(CI, MemoryDeps); ! 
DI.fini(); ++DI) 00202 depsOfRoot.insert(&DI->getSink()->getInstr()); 00203 00204 // Now find all outgoing SSA dependences to the eventual non-Phi users of 00205 // the call value (i.e., direct users that are not phis, and for any 00206 // user that is a Phi, direct non-Phi users of that Phi, and recursively). 00207 std::vector<const PHINode*> phiUsers; 00208 hash_set<const PHINode*> phisSeen; // ensures we don't visit a phi twice 00209 for (Value::use_iterator UI=CI.use_begin(), UE=CI.use_end(); UI != UE; ++UI) 00210 if (const PHINode* phiUser = dyn_cast<PHINode>(*UI)) { 00211 if (phisSeen.find(phiUser) == phisSeen.end()) { 00212 phiUsers.push_back(phiUser); 00213 phisSeen.insert(phiUser); 00214 } 00215 } 00216 else 00217 depsOfRoot.insert(cast<Instruction>(*UI)); 00218 00219 // Now we've found the non-Phi users and immediate phi users. 00220 // Recursively walk the phi users and add their non-phi users. 00221 for (const PHINode* phiUser; !phiUsers.empty(); phiUsers.pop_back()) { 00222 phiUser = phiUsers.back(); 00223 for (Value::use_const_iterator UI=phiUser->use_begin(), 00224 UE=phiUser->use_end(); UI != UE; ++UI) 00225 if (const PHINode* pn = dyn_cast<PHINode>(*UI)) { 00226 if (phisSeen.find(pn) == phisSeen.end()) { 00227 phiUsers.push_back(pn); 00228 phisSeen.insert(pn); 00229 } 00230 } else 00231 depsOfRoot.insert(cast<Instruction>(*UI)); 00232 } 00233 00234 // Walk paths of the CFG starting at the call instruction and insert 00235 // one sync before the first dependence on each path, if any. 00236 if (! depsOfRoot.empty()) { 00237 stmtsVisited.clear(); // start a new DFS for this CallInst 00238 assert(CI.getNext() && "Call instruction cannot be a terminator!"); 00239 DFSVisitInstr(CI.getNext(), &CI, depsOfRoot); 00240 } 00241 00242 // Now, eliminate all users of the SSA value of the CallInst, i.e., 00243 // if the call instruction returns a value, delete the return value 00244 // register and replace it by a stack slot. 
00245 if (CI.getType() != Type::VoidTy) 00246 DemoteRegToStack(CI); 00247 } 00248 00249 00250 //---------------------------------------------------------------------------- 00251 // class FindParallelCalls 00252 // 00253 // Find all CallInst instructions that have at least one other CallInst 00254 // that is independent. These are the instructions that can produce 00255 // useful parallelism. 00256 //---------------------------------------------------------------------------- 00257 00258 class FindParallelCalls : public InstVisitor<FindParallelCalls> { 00259 typedef hash_set<CallInst*> DependentsSet; 00260 typedef DependentsSet::iterator Dependents_iterator; 00261 typedef DependentsSet::const_iterator Dependents_const_iterator; 00262 00263 PgmDependenceGraph& depGraph; // dependence graph for the function 00264 hash_set<Instruction*> stmtsVisited; // flags for DFS walk of depGraph 00265 hash_map<CallInst*, bool > completed; // flags marking if a CI is done 00266 hash_map<CallInst*, DependentsSet> dependents; // dependent CIs for each CI 00267 00268 void VisitOutEdges(Instruction* I, 00269 CallInst* root, 00270 DependentsSet& depsOfRoot); 00271 00272 FindParallelCalls(const FindParallelCalls &); // DO NOT IMPLEMENT 00273 void operator=(const FindParallelCalls&); // DO NOT IMPLEMENT 00274 public: 00275 std::vector<CallInst*> parallelCalls; 00276 00277 public: 00278 /*ctor*/ FindParallelCalls (Function& F, PgmDependenceGraph& DG); 00279 void visitCallInst (CallInst& CI); 00280 }; 00281 00282 00283 FindParallelCalls::FindParallelCalls(Function& F, 00284 PgmDependenceGraph& DG) 00285 : depGraph(DG) 00286 { 00287 // Find all CallInsts reachable from each CallInst using a recursive DFS 00288 visit(F); 00289 00290 // Now we've found all CallInsts reachable from each CallInst. 00291 // Find those CallInsts that are parallel with at least one other CallInst 00292 // by counting total inEdges and outEdges. 
00293 unsigned long totalNumCalls = completed.size(); 00294 00295 if (totalNumCalls == 1) { 00296 // Check first for the special case of a single call instruction not 00297 // in any loop. It is not parallel, even if it has no dependences 00298 // (this is why it is a special case). 00299 // 00300 // FIXME: 00301 // THIS CASE IS NOT HANDLED RIGHT NOW, I.E., THERE IS NO 00302 // PARALLELISM FOR CALLS IN DIFFERENT ITERATIONS OF A LOOP. 00303 return; 00304 } 00305 00306 hash_map<CallInst*, unsigned long> numDeps; 00307 for (hash_map<CallInst*, DependentsSet>::iterator II = dependents.begin(), 00308 IE = dependents.end(); II != IE; ++II) { 00309 CallInst* fromCI = II->first; 00310 numDeps[fromCI] += II->second.size(); 00311 for (Dependents_iterator DI = II->second.begin(), DE = II->second.end(); 00312 DI != DE; ++DI) 00313 numDeps[*DI]++; // *DI can be reached from II->first 00314 } 00315 00316 for (hash_map<CallInst*, DependentsSet>::iterator 00317 II = dependents.begin(), IE = dependents.end(); II != IE; ++II) 00318 00319 // FIXME: Remove "- 1" when considering parallelism in loops 00320 if (numDeps[II->first] < totalNumCalls - 1) 00321 parallelCalls.push_back(II->first); 00322 } 00323 00324 00325 void FindParallelCalls::VisitOutEdges(Instruction* I, 00326 CallInst* root, 00327 DependentsSet& depsOfRoot) 00328 { 00329 assert(stmtsVisited.find(I) == stmtsVisited.end() && "Stmt visited twice?"); 00330 stmtsVisited.insert(I); 00331 00332 if (CallInst* CI = dyn_cast<CallInst>(I)) 00333 // FIXME: Ignoring parallelism in a loop. Here we're actually *ignoring* 00334 // a self-dependence in order to get the count comparison right above. 00335 // When we include loop parallelism, self-dependences should be included. 00336 if (CI != root) { 00337 // CallInst root has a path to CallInst I and any calls reachable from I 00338 depsOfRoot.insert(CI); 00339 if (completed[CI]) { 00340 // We have already visited I so we know all nodes it can reach! 
00341 DependentsSet& depsOfI = dependents[CI]; 00342 depsOfRoot.insert(depsOfI.begin(), depsOfI.end()); 00343 return; 00344 } 00345 } 00346 00347 // If we reach here, we need to visit all children of I 00348 for (PgmDependenceGraph::iterator DI = depGraph.outDepBegin(*I); 00349 ! DI.fini(); ++DI) { 00350 Instruction* sink = &DI->getSink()->getInstr(); 00351 if (stmtsVisited.find(sink) == stmtsVisited.end()) 00352 VisitOutEdges(sink, root, depsOfRoot); 00353 } 00354 } 00355 00356 00357 void FindParallelCalls::visitCallInst(CallInst& CI) { 00358 if (completed[&CI]) 00359 return; 00360 stmtsVisited.clear(); // clear flags to do a fresh DFS 00361 00362 // Visit all children of CI using a recursive walk through dep graph 00363 DependentsSet& depsOfRoot = dependents[&CI]; 00364 for (PgmDependenceGraph::iterator DI = depGraph.outDepBegin(CI); 00365 ! DI.fini(); ++DI) { 00366 Instruction* sink = &DI->getSink()->getInstr(); 00367 if (stmtsVisited.find(sink) == stmtsVisited.end()) 00368 VisitOutEdges(sink, &CI, depsOfRoot); 00369 } 00370 00371 completed[&CI] = true; 00372 } 00373 00374 00375 //---------------------------------------------------------------------------- 00376 // class Parallelize 00377 // 00378 // (1) Find candidate parallel functions: any function F s.t. 00379 // there is a call C1 to the function F that is followed or preceded 00380 // by at least one other call C2 that is independent of this one 00381 // (i.e., there is no dependence path from C1 to C2 or C2 to C1) 00382 // (2) Label such a function F as a cilk function. 
// (3) Convert every call to F to a spawn
// (4) For every function X, insert sync statements so that
//     every spawn is postdominated by a sync before any statements
//     with a data dependence to/from the call site for the spawn
//
//----------------------------------------------------------------------------

namespace {
  // Parallelize - The driver ModulePass: finds candidate parallel functions
  // with FindParallelCalls, then marks them and inserts syncs via Cilkifier.
  class Parallelize : public ModulePass {
  public:
    /// Driver functions to transform a program
    ///
    bool runOnModule(Module& M);

    /// getAnalysisUsage - Modifies extensively so preserve nothing.
    /// Uses the DependenceGraph and the Top-down DS Graph (only to find
    /// all functions called via an indirect call).
    ///
    void getAnalysisUsage(AnalysisUsage &AU) const {
      AU.addRequired<TDDataStructures>();
      AU.addRequired<MemoryDepAnalysis>();  // force this not to be released
      AU.addRequired<PgmDependenceGraph>(); // because it is needed by this
    }
  };

  RegisterOpt<Parallelize> X("parallel", "Parallelize program using Cilk");
}


// runOnModule - Find the functions worth parallelizing, exclude any that may
// be called indirectly (Cilk spawn sites must be direct calls), then rename
// each survivor with the Cilk suffix and insert syncs after its spawns.
// Returns false (no change) if there is no main or nothing to parallelize.
bool Parallelize::runOnModule(Module& M) {
  hash_set<Function*> parallelFunctions;
  hash_set<Function*> safeParallelFunctions;
  hash_set<const GlobalValue*> indirectlyCalled;

  // If there is no main (i.e., for an incomplete program), we can do nothing.
  // If there is a main, mark main as a parallel function.
  Function* mainFunc = M.getMainFunction();
  if (!mainFunc)
    return false;

  // (1) Find candidate parallel functions and mark them as Cilk functions
  for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI)
    if (! FI->isExternal()) {
      Function* F = FI;
      // The TD DS graph is only consulted below to enumerate the possible
      // callees of indirect call sites.
      DSGraph& tdg = getAnalysis<TDDataStructures>().getDSGraph(*F);

      // All the hard analysis work gets done here!
      FindParallelCalls finder(*F,
                              getAnalysis<PgmDependenceGraph>().getGraph(*F));
                        /* getAnalysis<MemoryDepAnalysis>().getGraph(*F)); */

      // Now we know which call instructions are useful to parallelize.
      // Remember those callee functions.
      for (std::vector<CallInst*>::iterator
             CII = finder.parallelCalls.begin(),
             CIE = finder.parallelCalls.end(); CII != CIE; ++CII) {
        // Check if this is a direct call...
        if ((*CII)->getCalledFunction() != NULL) {
          // direct call: if this is to a non-external function,
          // mark it as a parallelizable function
          if (! (*CII)->getCalledFunction()->isExternal())
            parallelFunctions.insert((*CII)->getCalledFunction());
        } else {
          // Indirect call: mark all potential callees as bad
          std::vector<GlobalValue*> callees =
            tdg.getNodeForValue((*CII)->getCalledValue())
            .getNode()->getGlobals();
          indirectlyCalled.insert(callees.begin(), callees.end());
        }
      }
    }

  // Remove all indirectly called functions from the list of Cilk functions.
  for (hash_set<Function*>::iterator PFI = parallelFunctions.begin(),
         PFE = parallelFunctions.end(); PFI != PFE; ++PFI)
    if (indirectlyCalled.count(*PFI) == 0)
      safeParallelFunctions.insert(*PFI);

// NOTE: the #undef directly below makes the #ifdef block dead code; it is
// kept only as a record of the alternative in-place erase approach.
#undef CAN_USE_BIND1ST_ON_REFERENCE_TYPE_ARGS
#ifdef CAN_USE_BIND1ST_ON_REFERENCE_TYPE_ARGS
  // Use this indecipherable STLese because erase invalidates iterators.
  // Otherwise we have to copy sets as above.
  hash_set<Function*>::iterator extrasBegin =
    std::remove_if(parallelFunctions.begin(), parallelFunctions.end(),
                   compose1(std::bind2nd(std::greater<int>(), 0),
                            bind_obj(&indirectlyCalled,
                                     &hash_set<const GlobalValue*>::count)));
  parallelFunctions.erase(extrasBegin, parallelFunctions.end());
#endif

  // If there are no parallel functions, we can just give up.
  if (safeParallelFunctions.empty())
    return false;

  // Add main as a parallel function since Cilk requires this.
  safeParallelFunctions.insert(mainFunc);

  // (2,3) Transform each Cilk function and all its calls simply by
  //       adding a unique suffix to the function name.
  //       This should identify both functions and calls to such functions
  //       to the code generator.
  // (4) Also, insert calls to sync at appropriate points.
  Cilkifier cilkifier(M);
  for (hash_set<Function*>::iterator CFI = safeParallelFunctions.begin(),
         CFE = safeParallelFunctions.end(); CFI != CFE; ++CFI) {
    cilkifier.TransformFunc(*CFI, safeParallelFunctions,
                            getAnalysis<PgmDependenceGraph>().getGraph(**CFI));
                     /* getAnalysis<MemoryDepAnalysis>().getGraph(**CFI)); */
  }

  return true;
}