//===-- StructRetPromotion.cpp - Promote sret arguments ------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass finds functions that return a struct (using a pointer to the struct
// as the first argument of the function, marked with the 'sret' attribute) and
// replaces them with a new function that simply returns each of the elements of
// that struct (using multiple return values).
//
// This pass works under a number of conditions:
//  1. The returned struct must not contain other structs
//  2. The returned struct must only be used to load values from
//  3. The placeholder struct passed in is the result of an alloca
//
//===----------------------------------------------------------------------===//

22 #define DEBUG_TYPE "sretpromotion"
23 #include "llvm/Transforms/IPO.h"
24 #include "llvm/Constants.h"
25 #include "llvm/DerivedTypes.h"
26 #include "llvm/LLVMContext.h"
27 #include "llvm/Module.h"
28 #include "llvm/CallGraphSCCPass.h"
29 #include "llvm/Instructions.h"
30 #include "llvm/Analysis/CallGraph.h"
31 #include "llvm/Support/CallSite.h"
32 #include "llvm/Support/CFG.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/ADT/Statistic.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/Statistic.h"
37 #include "llvm/Support/raw_ostream.h"
40 STATISTIC(NumRejectedSRETUses
, "Number of sret rejected due to unexpected uses");
41 STATISTIC(NumSRET
, "Number of sret promoted");
43 /// SRETPromotion - This pass removes sret parameter and updates
44 /// function to use multiple return value.
46 struct SRETPromotion
: public CallGraphSCCPass
{
47 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const {
48 CallGraphSCCPass::getAnalysisUsage(AU
);
51 virtual bool runOnSCC(std::vector
<CallGraphNode
*> &SCC
);
52 static char ID
; // Pass identification, replacement for typeid
53 SRETPromotion() : CallGraphSCCPass(&ID
) {}
56 CallGraphNode
*PromoteReturn(CallGraphNode
*CGN
);
57 bool isSafeToUpdateAllCallers(Function
*F
);
58 Function
*cloneFunctionBody(Function
*F
, const StructType
*STy
);
59 CallGraphNode
*updateCallSites(Function
*F
, Function
*NF
);
60 bool nestedStructType(const StructType
*STy
);
64 char SRETPromotion::ID
= 0;
65 static RegisterPass
<SRETPromotion
>
66 X("sretpromotion", "Promote sret arguments to multiple ret values");
68 Pass
*llvm::createStructRetPromotionPass() {
69 return new SRETPromotion();
72 bool SRETPromotion::runOnSCC(std::vector
<CallGraphNode
*> &SCC
) {
75 for (unsigned i
= 0, e
= SCC
.size(); i
!= e
; ++i
)
76 if (CallGraphNode
*NewNode
= PromoteReturn(SCC
[i
])) {
84 /// PromoteReturn - This method promotes function that uses StructRet paramater
85 /// into a function that uses multiple return values.
86 CallGraphNode
*SRETPromotion::PromoteReturn(CallGraphNode
*CGN
) {
87 Function
*F
= CGN
->getFunction();
89 if (!F
|| F
->isDeclaration() || !F
->hasLocalLinkage())
92 // Make sure that function returns struct.
93 if (F
->arg_size() == 0 || !F
->hasStructRetAttr() || F
->doesNotReturn())
96 DEBUG(errs() << "SretPromotion: Looking at sret function "
97 << F
->getName() << "\n");
99 assert(F
->getReturnType() == Type::getVoidTy(F
->getContext()) &&
100 "Invalid function return type");
101 Function::arg_iterator AI
= F
->arg_begin();
102 const llvm::PointerType
*FArgType
= dyn_cast
<PointerType
>(AI
->getType());
103 assert(FArgType
&& "Invalid sret parameter type");
104 const llvm::StructType
*STy
=
105 dyn_cast
<StructType
>(FArgType
->getElementType());
106 assert(STy
&& "Invalid sret parameter element type");
108 // Check if it is ok to perform this promotion.
109 if (isSafeToUpdateAllCallers(F
) == false) {
110 DEBUG(errs() << "SretPromotion: Not all callers can be updated\n");
111 NumRejectedSRETUses
++;
115 DEBUG(errs() << "SretPromotion: sret argument will be promoted\n");
117 // [1] Replace use of sret parameter
118 AllocaInst
*TheAlloca
= new AllocaInst(STy
, NULL
, "mrv",
119 F
->getEntryBlock().begin());
120 Value
*NFirstArg
= F
->arg_begin();
121 NFirstArg
->replaceAllUsesWith(TheAlloca
);
123 // [2] Find and replace ret instructions
124 for (Function::iterator FI
= F
->begin(), FE
= F
->end(); FI
!= FE
; ++FI
)
125 for(BasicBlock::iterator BI
= FI
->begin(), BE
= FI
->end(); BI
!= BE
; ) {
128 if (isa
<ReturnInst
>(I
)) {
129 Value
*NV
= new LoadInst(TheAlloca
, "mrv.ld", I
);
130 ReturnInst
*NR
= ReturnInst::Create(F
->getContext(), NV
, I
);
131 I
->replaceAllUsesWith(NR
);
132 I
->eraseFromParent();
136 // [3] Create the new function body and insert it into the module.
137 Function
*NF
= cloneFunctionBody(F
, STy
);
139 // [4] Update all call sites to use new function
140 CallGraphNode
*NF_CFN
= updateCallSites(F
, NF
);
142 CallGraph
&CG
= getAnalysis
<CallGraph
>();
143 NF_CFN
->stealCalledFunctionsFrom(CG
[F
]);
145 delete CG
.removeFunctionFromModule(F
);
149 // Check if it is ok to perform this promotion.
150 bool SRETPromotion::isSafeToUpdateAllCallers(Function
*F
) {
153 // No users. OK to modify signature.
156 for (Value::use_iterator FnUseI
= F
->use_begin(), FnUseE
= F
->use_end();
157 FnUseI
!= FnUseE
; ++FnUseI
) {
158 // The function is passed in as an argument to (possibly) another function,
159 // we can't change it!
160 CallSite CS
= CallSite::get(*FnUseI
);
161 Instruction
*Call
= CS
.getInstruction();
162 // The function is used by something else than a call or invoke instruction,
163 // we can't change it!
164 if (!Call
|| !CS
.isCallee(FnUseI
))
166 CallSite::arg_iterator AI
= CS
.arg_begin();
167 Value
*FirstArg
= *AI
;
169 if (!isa
<AllocaInst
>(FirstArg
))
172 // Check FirstArg's users.
173 for (Value::use_iterator ArgI
= FirstArg
->use_begin(),
174 ArgE
= FirstArg
->use_end(); ArgI
!= ArgE
; ++ArgI
) {
176 // If FirstArg user is a CallInst that does not correspond to current
177 // call site then this function F is not suitable for sret promotion.
178 if (CallInst
*CI
= dyn_cast
<CallInst
>(ArgI
)) {
182 // If FirstArg user is a GEP whose all users are not LoadInst then
183 // this function F is not suitable for sret promotion.
184 else if (GetElementPtrInst
*GEP
= dyn_cast
<GetElementPtrInst
>(ArgI
)) {
185 // TODO : Use dom info and insert PHINodes to collect get results
186 // from multiple call sites for this GEP.
187 if (GEP
->getParent() != Call
->getParent())
189 for (Value::use_iterator GEPI
= GEP
->use_begin(), GEPE
= GEP
->use_end();
190 GEPI
!= GEPE
; ++GEPI
)
191 if (!isa
<LoadInst
>(GEPI
))
194 // Any other FirstArg users make this function unsuitable for sret
204 /// cloneFunctionBody - Create a new function based on F and
205 /// insert it into module. Remove first argument. Use STy as
206 /// the return type for new function.
207 Function
*SRETPromotion::cloneFunctionBody(Function
*F
,
208 const StructType
*STy
) {
210 const FunctionType
*FTy
= F
->getFunctionType();
211 std::vector
<const Type
*> Params
;
213 // Attributes - Keep track of the parameter attributes for the arguments.
214 SmallVector
<AttributeWithIndex
, 8> AttributesVec
;
215 const AttrListPtr
&PAL
= F
->getAttributes();
217 // Add any return attributes.
218 if (Attributes attrs
= PAL
.getRetAttributes())
219 AttributesVec
.push_back(AttributeWithIndex::get(0, attrs
));
221 // Skip first argument.
222 Function::arg_iterator I
= F
->arg_begin(), E
= F
->arg_end();
224 // 0th parameter attribute is reserved for return type.
225 // 1th parameter attribute is for first 1st sret argument.
226 unsigned ParamIndex
= 2;
228 Params
.push_back(I
->getType());
229 if (Attributes Attrs
= PAL
.getParamAttributes(ParamIndex
))
230 AttributesVec
.push_back(AttributeWithIndex::get(ParamIndex
- 1, Attrs
));
235 // Add any fn attributes.
236 if (Attributes attrs
= PAL
.getFnAttributes())
237 AttributesVec
.push_back(AttributeWithIndex::get(~0, attrs
));
240 FunctionType
*NFTy
= FunctionType::get(STy
, Params
, FTy
->isVarArg());
241 Function
*NF
= Function::Create(NFTy
, F
->getLinkage());
243 NF
->copyAttributesFrom(F
);
244 NF
->setAttributes(AttrListPtr::get(AttributesVec
.begin(), AttributesVec
.end()));
245 F
->getParent()->getFunctionList().insert(F
, NF
);
246 NF
->getBasicBlockList().splice(NF
->begin(), F
->getBasicBlockList());
251 Function::arg_iterator NI
= NF
->arg_begin();
254 I
->replaceAllUsesWith(NI
);
263 /// updateCallSites - Update all sites that call F to use NF.
264 CallGraphNode
*SRETPromotion::updateCallSites(Function
*F
, Function
*NF
) {
265 CallGraph
&CG
= getAnalysis
<CallGraph
>();
266 SmallVector
<Value
*, 16> Args
;
268 // Attributes - Keep track of the parameter attributes for the arguments.
269 SmallVector
<AttributeWithIndex
, 8> ArgAttrsVec
;
271 // Get a new callgraph node for NF.
272 CallGraphNode
*NF_CGN
= CG
.getOrInsertFunction(NF
);
274 while (!F
->use_empty()) {
275 CallSite CS
= CallSite::get(*F
->use_begin());
276 Instruction
*Call
= CS
.getInstruction();
278 const AttrListPtr
&PAL
= F
->getAttributes();
279 // Add any return attributes.
280 if (Attributes attrs
= PAL
.getRetAttributes())
281 ArgAttrsVec
.push_back(AttributeWithIndex::get(0, attrs
));
283 // Copy arguments, however skip first one.
284 CallSite::arg_iterator AI
= CS
.arg_begin(), AE
= CS
.arg_end();
285 Value
*FirstCArg
= *AI
;
287 // 0th parameter attribute is reserved for return type.
288 // 1th parameter attribute is for first 1st sret argument.
289 unsigned ParamIndex
= 2;
292 if (Attributes Attrs
= PAL
.getParamAttributes(ParamIndex
))
293 ArgAttrsVec
.push_back(AttributeWithIndex::get(ParamIndex
- 1, Attrs
));
298 // Add any function attributes.
299 if (Attributes attrs
= PAL
.getFnAttributes())
300 ArgAttrsVec
.push_back(AttributeWithIndex::get(~0, attrs
));
302 AttrListPtr NewPAL
= AttrListPtr::get(ArgAttrsVec
.begin(), ArgAttrsVec
.end());
304 // Build new call instruction.
306 if (InvokeInst
*II
= dyn_cast
<InvokeInst
>(Call
)) {
307 New
= InvokeInst::Create(NF
, II
->getNormalDest(), II
->getUnwindDest(),
308 Args
.begin(), Args
.end(), "", Call
);
309 cast
<InvokeInst
>(New
)->setCallingConv(CS
.getCallingConv());
310 cast
<InvokeInst
>(New
)->setAttributes(NewPAL
);
312 New
= CallInst::Create(NF
, Args
.begin(), Args
.end(), "", Call
);
313 cast
<CallInst
>(New
)->setCallingConv(CS
.getCallingConv());
314 cast
<CallInst
>(New
)->setAttributes(NewPAL
);
315 if (cast
<CallInst
>(Call
)->isTailCall())
316 cast
<CallInst
>(New
)->setTailCall();
322 // Update the callgraph to know that the callsite has been transformed.
323 CallGraphNode
*CalleeNode
= CG
[Call
->getParent()->getParent()];
324 CalleeNode
->removeCallEdgeFor(Call
);
325 CalleeNode
->addCalledFunction(New
, NF_CGN
);
327 // Update all users of sret parameter to extract value using extractvalue.
328 for (Value::use_iterator UI
= FirstCArg
->use_begin(),
329 UE
= FirstCArg
->use_end(); UI
!= UE
; ) {
331 CallInst
*C2
= dyn_cast
<CallInst
>(U2
);
332 if (C2
&& (C2
== Call
))
335 GetElementPtrInst
*UGEP
= cast
<GetElementPtrInst
>(U2
);
336 ConstantInt
*Idx
= cast
<ConstantInt
>(UGEP
->getOperand(2));
337 Value
*GR
= ExtractValueInst::Create(New
, Idx
->getZExtValue(),
339 while(!UGEP
->use_empty()) {
340 // isSafeToUpdateAllCallers has checked that all GEP uses are
342 LoadInst
*L
= cast
<LoadInst
>(*UGEP
->use_begin());
343 L
->replaceAllUsesWith(GR
);
344 L
->eraseFromParent();
346 UGEP
->eraseFromParent();
349 Call
->eraseFromParent();
355 /// nestedStructType - Return true if STy includes any
356 /// other aggregate types
357 bool SRETPromotion::nestedStructType(const StructType
*STy
) {
358 unsigned Num
= STy
->getNumElements();
359 for (unsigned i
= 0; i
< Num
; i
++) {
360 const Type
*Ty
= STy
->getElementType(i
);
361 if (!Ty
->isSingleValueType() && Ty
!= Type::getVoidTy(STy
->getContext()))