[ManagedMemoryRewrite] Iterate over operands of the expanded instruction, not the...
[polly-mirror.git] / lib / CodeGen / ManagedMemoryRewrite.cpp
blob95f9d7dceed2dc175d918d82f53b80f0f0f007aa
1 //===---- ManagedMemoryRewrite.cpp - Rewrite global & malloc'd memory -----===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Take a module and rewrite:
11 // 1. `malloc` -> `polly_mallocManaged`
12 // 2. `free` -> `polly_freeManaged`
13 // 3. global arrays with initializers -> global arrays that are initialized
14 // with a constructor call to
15 // `polly_mallocManaged`.
17 //===----------------------------------------------------------------------===//
19 #include "polly/CodeGen/CodeGeneration.h"
20 #include "polly/CodeGen/IslAst.h"
21 #include "polly/CodeGen/IslNodeBuilder.h"
22 #include "polly/CodeGen/PPCGCodeGeneration.h"
23 #include "polly/CodeGen/Utils.h"
24 #include "polly/DependenceInfo.h"
25 #include "polly/LinkAllPasses.h"
26 #include "polly/Options.h"
27 #include "polly/ScopDetection.h"
28 #include "polly/ScopInfo.h"
29 #include "polly/Support/SCEVValidator.h"
30 #include "llvm/Analysis/AliasAnalysis.h"
31 #include "llvm/Analysis/BasicAliasAnalysis.h"
32 #include "llvm/Analysis/CaptureTracking.h"
33 #include "llvm/Analysis/GlobalsModRef.h"
34 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
35 #include "llvm/Analysis/TargetLibraryInfo.h"
36 #include "llvm/Analysis/TargetTransformInfo.h"
37 #include "llvm/IR/LegacyPassManager.h"
38 #include "llvm/IR/Verifier.h"
39 #include "llvm/IRReader/IRReader.h"
40 #include "llvm/Linker/Linker.h"
41 #include "llvm/Support/TargetRegistry.h"
42 #include "llvm/Support/TargetSelect.h"
43 #include "llvm/Target/TargetMachine.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
46 #include "llvm/Transforms/Utils/ModuleUtils.h"
48 static cl::opt<bool> RewriteAllocas(
49 "polly-acc-rewrite-allocas",
50 cl::desc(
51 "Ask the managed memory rewriter to also rewrite alloca instructions"),
52 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
54 static cl::opt<bool> IgnoreLinkageForGlobals(
55 "polly-acc-rewrite-ignore-linkage-for-globals",
56 cl::desc(
57 "By default, we only rewrite globals with internal linkage. This flag "
58 "enables rewriting of globals regardless of linkage"),
59 cl::Hidden, cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
61 #define DEBUG_TYPE "polly-acc-rewrite-managed-memory"
62 namespace {
64 static llvm::Function *getOrCreatePollyMallocManaged(Module &M) {
65 const char *Name = "polly_mallocManaged";
66 Function *F = M.getFunction(Name);
68 // If F is not available, declare it.
69 if (!F) {
70 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
71 PollyIRBuilder Builder(M.getContext());
72 // TODO: How do I get `size_t`? I assume from DataLayout?
73 FunctionType *Ty = FunctionType::get(Builder.getInt8PtrTy(),
74 {Builder.getInt64Ty()}, false);
75 F = Function::Create(Ty, Linkage, Name, &M);
78 return F;
81 static llvm::Function *getOrCreatePollyFreeManaged(Module &M) {
82 const char *Name = "polly_freeManaged";
83 Function *F = M.getFunction(Name);
85 // If F is not available, declare it.
86 if (!F) {
87 GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
88 PollyIRBuilder Builder(M.getContext());
89 // TODO: How do I get `size_t`? I assume from DataLayout?
90 FunctionType *Ty =
91 FunctionType::get(Builder.getVoidTy(), {Builder.getInt8PtrTy()}, false);
92 F = Function::Create(Ty, Linkage, Name, &M);
95 return F;
98 // Expand a constant expression `Cur`, which is used at instruction `Parent`
99 // at index `index`.
100 // Since a constant expression can expand to multiple instructions, store all
101 // the expands into a set called `Expands`.
102 // Note that this goes inorder on the constant expression tree.
103 // A * ((B * D) + C)
104 // will be processed with first A, then B * D, then B, then D, and then C.
105 // Though ConstantExprs are not treated as "trees" but as DAGs, since you can
106 // have something like this:
107 // *
108 // / \
109 // \ /
110 // (D)
112 // For the purposes of this expansion, we expand the two occurences of D
113 // separately. Therefore, we expand the DAG into the tree:
114 // *
115 // / \
116 // D D
117 // TODO: We don't _have_to do this, but this is the simplest solution.
118 // We can write a solution that keeps track of which constants have been
119 // already expanded.
120 static void expandConstantExpr(ConstantExpr *Cur, PollyIRBuilder &Builder,
121 Instruction *Parent, int index,
122 SmallPtrSet<Instruction *, 4> &Expands) {
123 assert(Cur && "invalid constant expression passed");
124 Instruction *I = Cur->getAsInstruction();
125 assert(I && "unable to convert ConstantExpr to Instruction");
127 DEBUG(dbgs() << "Expanding ConstantExpression: " << *Cur
128 << " | in Instruction: " << *I << "\n";);
130 // Invalidate `Cur` so that no one after this point uses `Cur`. Rather,
131 // they should mutate `I`.
132 Cur = nullptr;
134 Expands.insert(I);
135 Parent->setOperand(index, I);
137 // The things that `Parent` uses (its operands) should be created
138 // before `Parent`.
139 Builder.SetInsertPoint(Parent);
140 Builder.Insert(I);
142 for (unsigned i = 0; i < I->getNumOperands(); i++) {
143 Value *Op = I->getOperand(i);
144 assert(isa<Constant>(Op) && "constant must have a constant operand");
146 if (ConstantExpr *CExprOp = dyn_cast<ConstantExpr>(Op))
147 expandConstantExpr(CExprOp, Builder, I, i, Expands);
151 // Edit all uses of `OldVal` to NewVal` in `Inst`. This will rewrite
152 // `ConstantExpr`s that are used in the `Inst`.
153 // Note that `replaceAllUsesWith` is insufficient for this purpose because it
154 // does not rewrite values in `ConstantExpr`s.
155 static void rewriteOldValToNew(Instruction *Inst, Value *OldVal, Value *NewVal,
156 PollyIRBuilder &Builder) {
158 // This contains a set of instructions in which OldVal must be replaced.
159 // We start with `Inst`, and we fill it up with the expanded `ConstantExpr`s
160 // from `Inst`s arguments.
161 // We need to go through this process because `replaceAllUsesWith` does not
162 // actually edit `ConstantExpr`s.
163 SmallPtrSet<Instruction *, 4> InstsToVisit = {Inst};
165 // Expand all `ConstantExpr`s and place it in `InstsToVisit`.
166 for (unsigned i = 0; i < Inst->getNumOperands(); i++) {
167 Value *Operand = Inst->getOperand(i);
168 if (ConstantExpr *ValueConstExpr = dyn_cast<ConstantExpr>(Operand))
169 expandConstantExpr(ValueConstExpr, Builder, Inst, i, InstsToVisit);
172 // Now visit each instruction and use `replaceUsesOfWith`. We know that
173 // will work because `I` cannot have any `ConstantExpr` within it.
174 for (Instruction *I : InstsToVisit)
175 I->replaceUsesOfWith(OldVal, NewVal);
178 // Given a value `Current`, return all Instructions that may contain `Current`
179 // in an expression.
180 // We need this auxiliary function, because if we have a
181 // `Constant` that is a user of `V`, we need to recurse into the
182 // `Constant`s uses to gather the root instruciton.
183 static void getInstructionUsersOfValue(Value *V,
184 SmallVector<Instruction *, 4> &Owners) {
185 if (auto *I = dyn_cast<Instruction>(V)) {
186 Owners.push_back(I);
187 } else {
188 // Anything that is a `User` must be a constant or an instruction.
189 auto *C = cast<Constant>(V);
190 for (Use &CUse : C->uses())
191 getInstructionUsersOfValue(CUse.getUser(), Owners);
195 static void
196 replaceGlobalArray(Module &M, const DataLayout &DL, GlobalVariable &Array,
197 SmallPtrSet<GlobalVariable *, 4> &ReplacedGlobals) {
198 // We only want arrays.
199 ArrayType *ArrayTy = dyn_cast<ArrayType>(Array.getType()->getElementType());
200 if (!ArrayTy)
201 return;
202 Type *ElemTy = ArrayTy->getElementType();
203 PointerType *ElemPtrTy = ElemTy->getPointerTo();
205 // We only wish to replace arrays that are visible in the module they
206 // inhabit. Otherwise, our type edit from [T] to T* would be illegal across
207 // modules.
208 const bool OnlyVisibleInsideModule = Array.hasPrivateLinkage() ||
209 Array.hasInternalLinkage() ||
210 IgnoreLinkageForGlobals;
211 if (!OnlyVisibleInsideModule)
212 return;
214 if (!Array.hasInitializer() ||
215 !isa<ConstantAggregateZero>(Array.getInitializer()))
216 return;
218 // At this point, we have committed to replacing this array.
219 ReplacedGlobals.insert(&Array);
221 std::string NewName = (Array.getName() + Twine(".toptr")).str();
222 GlobalVariable *ReplacementToArr =
223 cast<GlobalVariable>(M.getOrInsertGlobal(NewName, ElemPtrTy));
224 ReplacementToArr->setInitializer(ConstantPointerNull::get(ElemPtrTy));
226 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
227 Twine FnName = Array.getName() + ".constructor";
228 PollyIRBuilder Builder(M.getContext());
229 FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
230 const GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
231 Function *F = Function::Create(Ty, Linkage, FnName, &M);
232 BasicBlock *Start = BasicBlock::Create(M.getContext(), "entry", F);
233 Builder.SetInsertPoint(Start);
235 int ArraySizeInt = DL.getTypeAllocSizeInBits(ArrayTy) / 8;
236 Value *ArraySize = Builder.getInt64(ArraySizeInt);
237 ArraySize->setName("array.size");
239 Value *AllocatedMemRaw =
240 Builder.CreateCall(PollyMallocManaged, {ArraySize}, "mem.raw");
241 Value *AllocatedMemTyped =
242 Builder.CreatePointerCast(AllocatedMemRaw, ElemPtrTy, "mem.typed");
243 Builder.CreateStore(AllocatedMemTyped, ReplacementToArr);
244 Builder.CreateRetVoid();
246 const int Priority = 0;
247 appendToGlobalCtors(M, F, Priority, ReplacementToArr);
249 SmallVector<Instruction *, 4> ArrayUserInstructions;
250 // Get all instructions that use array. We need to do this weird thing
251 // because `Constant`s that contain this array neeed to be expanded into
252 // instructions so that we can replace their parameters. `Constant`s cannot
253 // be edited easily, so we choose to convert all `Constant`s to
254 // `Instruction`s and handle all of the uses of `Array` uniformly.
255 for (Use &ArrayUse : Array.uses())
256 getInstructionUsersOfValue(ArrayUse.getUser(), ArrayUserInstructions);
258 for (Instruction *UserOfArrayInst : ArrayUserInstructions) {
260 Builder.SetInsertPoint(UserOfArrayInst);
261 // <ty>** -> <ty>*
262 Value *ArrPtrLoaded = Builder.CreateLoad(ReplacementToArr, "arrptr.load");
263 // <ty>* -> [ty]*
264 Value *ArrPtrLoadedBitcasted = Builder.CreateBitCast(
265 ArrPtrLoaded, ArrayTy->getPointerTo(), "arrptr.bitcast");
266 rewriteOldValToNew(UserOfArrayInst, &Array, ArrPtrLoadedBitcasted, Builder);
270 // We return all `allocas` that may need to be converted to a call to
271 // cudaMallocManaged.
272 static void getAllocasToBeManaged(Function &F,
273 SmallSet<AllocaInst *, 4> &Allocas) {
274 for (BasicBlock &BB : F) {
275 for (Instruction &I : BB) {
276 auto *Alloca = dyn_cast<AllocaInst>(&I);
277 if (!Alloca)
278 continue;
279 dbgs() << "Checking if " << *Alloca << "may be captured: ";
281 if (PointerMayBeCaptured(Alloca, /* ReturnCaptures */ false,
282 /* StoreCaptures */ true)) {
283 Allocas.insert(Alloca);
284 DEBUG(dbgs() << "YES (captured)\n");
285 } else {
286 DEBUG(dbgs() << "NO (not captured)\n");
292 static void rewriteAllocaAsManagedMemory(AllocaInst *Alloca,
293 const DataLayout &DL) {
294 DEBUG(dbgs() << "rewriting: " << *Alloca << " to managed mem.\n");
295 Module *M = Alloca->getModule();
296 assert(M && "Alloca does not have a module");
298 PollyIRBuilder Builder(M->getContext());
299 Builder.SetInsertPoint(Alloca);
301 Value *MallocManagedFn = getOrCreatePollyMallocManaged(*Alloca->getModule());
302 const int Size = DL.getTypeAllocSize(Alloca->getType()->getElementType());
303 Value *SizeVal = Builder.getInt64(Size);
304 Value *RawManagedMem = Builder.CreateCall(MallocManagedFn, {SizeVal});
305 Value *Bitcasted = Builder.CreateBitCast(RawManagedMem, Alloca->getType());
307 Function *F = Alloca->getFunction();
308 assert(F && "Alloca has invalid function");
310 Bitcasted->takeName(Alloca);
311 Alloca->replaceAllUsesWith(Bitcasted);
312 Alloca->eraseFromParent();
314 for (BasicBlock &BB : *F) {
315 ReturnInst *Return = dyn_cast<ReturnInst>(BB.getTerminator());
316 if (!Return)
317 continue;
318 Builder.SetInsertPoint(Return);
320 Value *FreeManagedFn = getOrCreatePollyFreeManaged(*M);
321 Builder.CreateCall(FreeManagedFn, {RawManagedMem});
325 // Replace all uses of `Old` with `New`, even inside `ConstantExpr`.
327 // `replaceAllUsesWith` does replace values in `ConstantExpr`. This function
328 // actually does replace it in `ConstantExpr`. The caveat is that if there is
329 // a use that is *outside* a function (say, at global declarations), we fail.
330 // So, this is meant to be used on values which we know will only be used
331 // within functions.
333 // This process works by looking through the uses of `Old`. If it finds a
334 // `ConstantExpr`, it recursively looks for the owning instruction.
335 // Then, it expands all the `ConstantExpr` to instructions and replaces
336 // `Old` with `New` in the expanded instructions.
337 static void replaceAllUsesAndConstantUses(Value *Old, Value *New,
338 PollyIRBuilder &Builder) {
339 SmallVector<Instruction *, 4> UserInstructions;
340 // Get all instructions that use array. We need to do this weird thing
341 // because `Constant`s that contain this array neeed to be expanded into
342 // instructions so that we can replace their parameters. `Constant`s cannot
343 // be edited easily, so we choose to convert all `Constant`s to
344 // `Instruction`s and handle all of the uses of `Array` uniformly.
345 for (Use &ArrayUse : Old->uses())
346 getInstructionUsersOfValue(ArrayUse.getUser(), UserInstructions);
348 for (Instruction *I : UserInstructions)
349 rewriteOldValToNew(I, Old, New, Builder);
352 class ManagedMemoryRewritePass : public ModulePass {
353 public:
354 static char ID;
355 GPUArch Architecture;
356 GPURuntime Runtime;
358 ManagedMemoryRewritePass() : ModulePass(ID) {}
359 virtual bool runOnModule(Module &M) {
360 const DataLayout &DL = M.getDataLayout();
362 Function *Malloc = M.getFunction("malloc");
364 if (Malloc) {
365 PollyIRBuilder Builder(M.getContext());
366 Function *PollyMallocManaged = getOrCreatePollyMallocManaged(M);
367 assert(PollyMallocManaged && "unable to create polly_mallocManaged");
369 replaceAllUsesAndConstantUses(Malloc, PollyMallocManaged, Builder);
370 Malloc->eraseFromParent();
373 Function *Free = M.getFunction("free");
375 if (Free) {
376 PollyIRBuilder Builder(M.getContext());
377 Function *PollyFreeManaged = getOrCreatePollyFreeManaged(M);
378 assert(PollyFreeManaged && "unable to create polly_freeManaged");
380 replaceAllUsesAndConstantUses(Free, PollyFreeManaged, Builder);
381 Free->eraseFromParent();
384 SmallPtrSet<GlobalVariable *, 4> GlobalsToErase;
385 for (GlobalVariable &Global : M.globals())
386 replaceGlobalArray(M, DL, Global, GlobalsToErase);
387 for (GlobalVariable *G : GlobalsToErase)
388 G->eraseFromParent();
390 // Rewrite allocas to cudaMallocs if we are asked to do so.
391 if (RewriteAllocas) {
392 SmallSet<AllocaInst *, 4> AllocasToBeManaged;
393 for (Function &F : M.functions())
394 getAllocasToBeManaged(F, AllocasToBeManaged);
396 for (AllocaInst *Alloca : AllocasToBeManaged)
397 rewriteAllocaAsManagedMemory(Alloca, DL);
400 return true;
404 } // namespace
405 char ManagedMemoryRewritePass::ID = 42;
407 Pass *polly::createManagedMemoryRewritePassPass(GPUArch Arch,
408 GPURuntime Runtime) {
409 ManagedMemoryRewritePass *pass = new ManagedMemoryRewritePass();
410 pass->Runtime = Runtime;
411 pass->Architecture = Arch;
412 return pass;
415 INITIALIZE_PASS_BEGIN(
416 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
417 "Polly - Rewrite all allocations in heap & data section to managed memory",
418 false, false)
419 INITIALIZE_PASS_DEPENDENCY(PPCGCodeGeneration);
420 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
421 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
422 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass);
423 INITIALIZE_PASS_DEPENDENCY(RegionInfoPass);
424 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass);
425 INITIALIZE_PASS_DEPENDENCY(ScopDetectionWrapperPass);
426 INITIALIZE_PASS_END(
427 ManagedMemoryRewritePass, "polly-acc-rewrite-managed-memory",
428 "Polly - Rewrite all allocations in heap & data section to managed memory",
429 false, false)