1 //===- MergeFunctions.cpp - Merge identical functions ---------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This pass looks for equivalent functions that are mergeable and folds them.
12 // A hash is computed from the function, based on its type and number of basic blocks.
15 // Once all hashes are computed, we perform an expensive equality comparison
16 // on each function pair. This takes n^2/2 comparisons per bucket, so it's
17 // important that the hash function be high quality. The equality comparison
18 // iterates through each instruction in each basic block.
20 // When a match is found, the functions are folded. We can only fold two
21 // functions when we know that the definition of one of them is not overridable.
24 //===----------------------------------------------------------------------===//
28 // * fold vector<T*>::push_back and vector<S*>::push_back.
30 // These two functions have different types, but in a way that doesn't matter
31 // to us. As long as we never see an S or T itself, using S* and S** is the
32 // same as using a T* and T**.
34 // * virtual functions.
36 // Many functions have their address taken by the virtual function table for
37 // the object they belong to. However, as long as it's only used for a lookup
38 // and call, this is irrelevant, and we'd like to fold such implementations.
40 //===----------------------------------------------------------------------===//
42 #define DEBUG_TYPE "mergefunc"
43 #include "llvm/Transforms/IPO.h"
44 #include "llvm/ADT/DenseMap.h"
45 #include "llvm/ADT/FoldingSet.h"
46 #include "llvm/ADT/Statistic.h"
47 #include "llvm/Constants.h"
48 #include "llvm/InlineAsm.h"
49 #include "llvm/Instructions.h"
50 #include "llvm/LLVMContext.h"
51 #include "llvm/Module.h"
52 #include "llvm/Pass.h"
53 #include "llvm/Support/CallSite.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/ErrorHandling.h"
56 #include "llvm/Support/raw_ostream.h"
// Statistic counter for this pass; it is incremented inside fold().
61 STATISTIC(NumFunctionsMerged
, "Number of functions merged");
// Module-level pass whose runOnModule() (defined near the end of this file)
// performs the merging.
// NOTE(review): the closing of this struct is not present in this extract.
64 struct MergeFunctions
: public ModulePass
{
65 static char ID
; // Pass identification, replacement for typeid
// The address of ID is handed to the ModulePass base constructor.
66 MergeFunctions() : ModulePass(&ID
) {}
// Declared here, defined out of line.
68 bool runOnModule(Module
&M
);
// Out-of-line definition of the static pass ID declared in MergeFunctions.
72 char MergeFunctions::ID
= 0;
// Static registration object for the pass, constructed with the strings
// "mergefunc" and "Merge Functions".
73 static RegisterPass
<MergeFunctions
>
74 X("mergefunc", "Merge Functions");
// Factory function: returns a newly allocated MergeFunctions pass.
// NOTE(review): the closing brace of this function is not present in this
// extract.
76 ModulePass
*llvm::createMergeFunctionsPass() {
77 return new MergeFunctions();
80 // ===----------------------------------------------------------------------===
81 // Comparison of functions
82 // ===----------------------------------------------------------------------===
// Computes the bucket hash for F. The values fed into ID are: F->size(),
// the calling convention, whether F has a GC, whether the function type is
// vararg, and the TypeID of the return type and of every parameter type.
// Only TypeIDs are hashed for types, so two functions whose signatures differ
// only in which pointer types they use hash identically.
// NOTE(review): the declaration of ID is not present in this extract
// (presumably a FoldingSetNodeID, given the FoldingSet.h include) -- confirm.
84 static unsigned long hash(const Function
*F
) {
85 const FunctionType
*FTy
= F
->getFunctionType();
88 ID
.AddInteger(F
->size());
89 ID
.AddInteger(F
->getCallingConv());
90 ID
.AddBoolean(F
->hasGC());
91 ID
.AddBoolean(FTy
->isVarArg());
92 ID
.AddInteger(FTy
->getReturnType()->getTypeID());
// One TypeID per formal parameter, in declaration order.
93 for (unsigned i
= 0, e
= FTy
->getNumParams(); i
!= e
; ++i
)
94 ID
.AddInteger(FTy
->getParamType(i
)->getTypeID());
95 return ID
.ComputeHash();
98 /// IgnoreBitcasts - given a bitcast, returns the first non-bitcast found by
99 /// walking the chain of cast operands. Otherwise, returns the argument.
// NOTE(review): the return statement and closing brace are not present in
// this extract.
100 static Value
* IgnoreBitcasts(Value
*V
) {
// As long as V is a BitCastInst, replace it with the value being cast
// (operand 0).
101 while (BitCastInst
*BC
= dyn_cast
<BitCastInst
>(V
))
102 V
= BC
->getOperand(0);
107 /// isEquivalentType - any two pointers are equivalent. Otherwise, standard
108 /// type equivalence rules apply.
// Struct, function and array/vector types are handled by recursing into
// their component types.
// NOTE(review): several lines of this function are missing from this extract,
// including the statements executed by most of the `if` checks and by the
// first groups of case labels (presumably `return` statements) -- consult the
// complete file before relying on the comments below.
109 static bool isEquivalentType(const Type
*Ty1
, const Type
*Ty2
) {
// Types of different kinds (TypeID) are handled first.
112 if (Ty1
->getTypeID() != Ty2
->getTypeID())
// From here on both types share the same TypeID.
115 switch(Ty1
->getTypeID()) {
117 case Type::FloatTyID
:
118 case Type::DoubleTyID
:
119 case Type::X86_FP80TyID
:
120 case Type::FP128TyID
:
121 case Type::PPC_FP128TyID
:
122 case Type::LabelTyID
:
123 case Type::MetadataTyID
:
126 case Type::IntegerTyID
:
127 case Type::OpaqueTyID
:
128 // Ty1 == Ty2 would have returned true earlier.
132 llvm_unreachable("Unknown type!");
// Pointers: the pointee types are deliberately not compared; only the
// address spaces have to match.
135 case Type::PointerTyID
: {
136 const PointerType
*PTy1
= cast
<PointerType
>(Ty1
);
137 const PointerType
*PTy2
= cast
<PointerType
>(Ty2
);
138 return PTy1
->getAddressSpace() == PTy2
->getAddressSpace();
// Structs: checks element count, packedness, then each pair of element
// types recursively.
141 case Type::StructTyID
: {
142 const StructType
*STy1
= cast
<StructType
>(Ty1
);
143 const StructType
*STy2
= cast
<StructType
>(Ty2
);
144 if (STy1
->getNumElements() != STy2
->getNumElements())
147 if (STy1
->isPacked() != STy2
->isPacked())
150 for (unsigned i
= 0, e
= STy1
->getNumElements(); i
!= e
; ++i
) {
151 if (!isEquivalentType(STy1
->getElementType(i
), STy2
->getElementType(i
)))
// Function types: checks parameter count and vararg-ness, then the return
// type and each pair of parameter types recursively.
157 case Type::FunctionTyID
: {
158 const FunctionType
*FTy1
= cast
<FunctionType
>(Ty1
);
159 const FunctionType
*FTy2
= cast
<FunctionType
>(Ty2
);
160 if (FTy1
->getNumParams() != FTy2
->getNumParams() ||
161 FTy1
->isVarArg() != FTy2
->isVarArg())
164 if (!isEquivalentType(FTy1
->getReturnType(), FTy2
->getReturnType()))
167 for (unsigned i
= 0, e
= FTy1
->getNumParams(); i
!= e
; ++i
) {
168 if (!isEquivalentType(FTy1
->getParamType(i
), FTy2
->getParamType(i
)))
// Arrays and vectors: the result is the equivalence of the element types.
// NOTE(review): the visible code does not compare the number of elements
// (array length / vector width), so e.g. two arrays of different length with
// the same element type would be reported as equivalent -- confirm whether
// that is intended.
174 case Type::ArrayTyID
:
175 case Type::VectorTyID
: {
176 const SequentialType
*STy1
= cast
<SequentialType
>(Ty1
);
177 const SequentialType
*STy2
= cast
<SequentialType
>(Ty2
);
178 return isEquivalentType(STy1
->getElementType(), STy2
->getElementType());
183 /// isEquivalentOperation - determine whether the two operations are the same
184 /// except that pointer-to-A and pointer-to-B are equivalent. This should be
185 /// kept in sync with Instruction::isSameOperationAs.
// NOTE(review): the line carrying this function's return type and the
// statements executed by several of the `if` checks below (presumably
// `return` statements) are missing from this extract.
187 isEquivalentOperation(const Instruction
*I1
, const Instruction
*I2
) {
// Basic shape check: opcode, operand count, result type (using
// isEquivalentType rather than pointer identity) and the subclass optional
// data must all agree.
188 if (I1
->getOpcode() != I2
->getOpcode() ||
189 I1
->getNumOperands() != I2
->getNumOperands() ||
190 !isEquivalentType(I1
->getType(), I2
->getType()) ||
191 !I1
->hasSameSubclassOptionalData(I2
))
194 // We have two instructions of identical opcode and #operands. Check to see
195 // if all operands are the same type
196 for (unsigned i
= 0, e
= I1
->getNumOperands(); i
!= e
; ++i
)
197 if (!isEquivalentType(I1
->getOperand(i
)->getType(),
198 I2
->getOperand(i
)->getType()))
201 // Check special state that is a part of some instructions.
// The cast<>s applied to I2 below rely on the opcode comparison at the top of
// this function.
// Loads: volatility and alignment must match.
202 if (const LoadInst
*LI
= dyn_cast
<LoadInst
>(I1
))
203 return LI
->isVolatile() == cast
<LoadInst
>(I2
)->isVolatile() &&
204 LI
->getAlignment() == cast
<LoadInst
>(I2
)->getAlignment();
// Stores: volatility and alignment must match.
205 if (const StoreInst
*SI
= dyn_cast
<StoreInst
>(I1
))
206 return SI
->isVolatile() == cast
<StoreInst
>(I2
)->isVolatile() &&
207 SI
->getAlignment() == cast
<StoreInst
>(I2
)->getAlignment();
// Comparisons: the predicate must match.
208 if (const CmpInst
*CI
= dyn_cast
<CmpInst
>(I1
))
209 return CI
->getPredicate() == cast
<CmpInst
>(I2
)->getPredicate();
// Calls: tail-call flag, calling convention and attribute list (compared via
// getRawPointer()) must match.
210 if (const CallInst
*CI
= dyn_cast
<CallInst
>(I1
))
211 return CI
->isTailCall() == cast
<CallInst
>(I2
)->isTailCall() &&
212 CI
->getCallingConv() == cast
<CallInst
>(I2
)->getCallingConv() &&
213 CI
->getAttributes().getRawPointer() ==
214 cast
<CallInst
>(I2
)->getAttributes().getRawPointer();
// Invokes: calling convention and attribute list must match.
215 if (const InvokeInst
*CI
= dyn_cast
<InvokeInst
>(I1
))
216 return CI
->getCallingConv() == cast
<InvokeInst
>(I2
)->getCallingConv() &&
217 CI
->getAttributes().getRawPointer() ==
218 cast
<InvokeInst
>(I2
)->getAttributes().getRawPointer();
// insertvalue: checks the number of indices and then each index pairwise.
219 if (const InsertValueInst
*IVI
= dyn_cast
<InsertValueInst
>(I1
)) {
220 if (IVI
->getNumIndices() != cast
<InsertValueInst
>(I2
)->getNumIndices())
222 for (unsigned i
= 0, e
= IVI
->getNumIndices(); i
!= e
; ++i
)
223 if (IVI
->idx_begin()[i
] != cast
<InsertValueInst
>(I2
)->idx_begin()[i
])
// extractvalue: same index checks as for insertvalue.
227 if (const ExtractValueInst
*EVI
= dyn_cast
<ExtractValueInst
>(I1
)) {
228 if (EVI
->getNumIndices() != cast
<ExtractValueInst
>(I2
)->getNumIndices())
230 for (unsigned i
= 0, e
= EVI
->getNumIndices(); i
!= e
; ++i
)
231 if (EVI
->idx_begin()[i
] != cast
<ExtractValueInst
>(I2
)->idx_begin()[i
])
// Compares two operand values V and U. It must not be called with basic
// blocks, and the operand types are expected to be equivalent already (both
// conditions are asserted).
// NOTE(review): the statements executed by the Constant and
// Instruction/Argument checks, and the end of the function, are missing from
// this extract.
239 static bool compare(const Value
*V
, const Value
*U
) {
240 assert(!isa
<BasicBlock
>(V
) && !isa
<BasicBlock
>(U
) &&
241 "Must not compare basic blocks.");
243 assert(isEquivalentType(V
->getType(), U
->getType()) &&
244 "Two of the same operation have operands of different type.");
246 // TODO: If the constant is an expression of F, we should accept that it's
247 // equal to the same expression in terms of G.
248 if (isa
<Constant
>(V
))
251 // The caller has ensured that ValueMap[V] != U. Since Arguments are
252 // pre-loaded into the ValueMap, and Instructions are added as we go, we know
253 // that this can only be a mis-match.
254 if (isa
<Instruction
>(V
) || isa
<Argument
>(V
))
// Inline assembly: equal when both the asm string and the constraint string
// are equal.
257 if (isa
<InlineAsm
>(V
) && isa
<InlineAsm
>(U
)) {
258 const InlineAsm
*IAF
= cast
<InlineAsm
>(V
);
259 const InlineAsm
*IAG
= cast
<InlineAsm
>(U
);
260 return IAF
->getAsmString() == IAG
->getAsmString() &&
261 IAF
->getConstraintString() == IAG
->getConstraintString();
// Compares basic blocks BB1 and BB2 instruction by instruction.
// ValueMap holds the pairings between values of the first function and values
// of the second that have been established so far. SpeculationMap records
// pairings of PHI operands; the Function overload of equals() checks them
// against ValueMap after the whole body has been compared.
// NOTE(review): many lines of this function are missing from this extract,
// including the start of the do-loop, iterator increments, ValueMap updates
// and the statements executed by most `if` checks.
267 static bool equals(const BasicBlock
*BB1
, const BasicBlock
*BB2
,
268 DenseMap
<const Value
*, const Value
*> &ValueMap
,
269 DenseMap
<const Value
*, const Value
*> &SpeculationMap
) {
270 // Speculatively add it anyways. If it's false, we'll notice a difference
271 // later, and this won't matter.
// FI/FE walk BB1 and GI/GE walk BB2.
274 BasicBlock::const_iterator FI
= BB1
->begin(), FE
= BB1
->end();
275 BasicBlock::const_iterator GI
= BB2
->begin(), GE
= BB2
->end();
// Bitcast instructions on either side get special treatment; what is done
// with them is not visible in this extract.
278 if (isa
<BitCastInst
>(FI
)) {
282 if (isa
<BitCastInst
>(GI
)) {
// The two current instructions must be equivalent operations.
287 if (!isEquivalentOperation(FI
, GI
))
// GEPs: if both have all-zero indices they are treated like a bitcast;
// otherwise the types of the pointer operands (operand 0) are compared for
// exact identity.
290 if (isa
<GetElementPtrInst
>(FI
)) {
291 const GetElementPtrInst
*GEPF
= cast
<GetElementPtrInst
>(FI
);
292 const GetElementPtrInst
*GEPG
= cast
<GetElementPtrInst
>(GI
);
293 if (GEPF
->hasAllZeroIndices() && GEPG
->hasAllZeroIndices()) {
294 // It's effectively a bitcast.
299 // TODO: we only really care about the elements before the index
300 if (FI
->getOperand(0)->getType() != GI
->getOperand(0)->getType())
// FI is already paired with GI ...
304 if (ValueMap
[FI
] == GI
) {
// ... or FI is already paired with some other value.
309 if (ValueMap
[FI
] != NULL
)
// Compare the operands pairwise, looking through bitcast instructions.
312 for (unsigned i
= 0, e
= FI
->getNumOperands(); i
!= e
; ++i
) {
313 Value
*OpF
= IgnoreBitcasts(FI
->getOperand(i
));
314 Value
*OpG
= IgnoreBitcasts(GI
->getOperand(i
));
// Operand pair already established in ValueMap ...
316 if (ValueMap
[OpF
] == OpG
)
// ... or OpF is paired with a different value.
319 if (ValueMap
[OpF
] != NULL
)
// The operands must be the same kind of Value and have equivalent types.
322 if (OpF
->getValueID() != OpG
->getValueID() ||
323 !isEquivalentType(OpF
->getType(), OpG
->getType()))
// PHI operands: record the pairing in SpeculationMap on first sight; a later
// sighting of OpF with a different partner is detected by the else-if.
326 if (isa
<PHINode
>(FI
)) {
327 if (SpeculationMap
[OpF
] == NULL
)
328 SpeculationMap
[OpF
] = OpG
;
329 else if (SpeculationMap
[OpF
] != OpG
)
// Basic-block operands (asserted to occur only on terminators here) are
// compared by recursing into the referenced blocks.
332 } else if (isa
<BasicBlock
>(OpF
)) {
333 assert(isa
<TerminatorInst
>(FI
) &&
334 "BasicBlock referenced by non-Terminator non-PHI");
335 // This call changes the ValueMap, hence we can't use
336 // Value *& = ValueMap[...]
337 if (!equals(cast
<BasicBlock
>(OpF
), cast
<BasicBlock
>(OpG
), ValueMap
,
// Any other operand kind is checked with compare().
341 if (!compare(OpF
, OpG
))
// Continue until either block runs out of instructions.
350 } while (FI
!= FE
&& GI
!= GE
);
// The blocks match only if both were exhausted at the same time.
352 return FI
== FE
&& GI
== GE
;
// Whole-function equality test. Checks function-level properties first
// (attributes, GC, section, vararg-ness, calling convention, equivalent
// function type), then compares the bodies starting from the entry blocks,
// and finally validates the pairings collected in SpeculationMap.
// NOTE(review): the statements executed by the `if` checks (presumably
// `return false`), the body of the argument loop and the final return are
// missing from this extract.
355 static bool equals(const Function
*F
, const Function
*G
) {
356 // We need to recheck everything, but check the things that weren't included
357 // in the hash first.
359 if (F
->getAttributes() != G
->getAttributes())
362 if (F
->hasGC() != G
->hasGC())
// The GC names are only compared when a GC is present.
365 if (F
->hasGC() && F
->getGC() != G
->getGC())
368 if (F
->hasSection() != G
->hasSection())
// The section names are only compared when a section is present.
371 if (F
->hasSection() && F
->getSection() != G
->getSection())
374 if (F
->isVarArg() != G
->isVarArg())
377 // TODO: if it's internal and only used in direct calls, we could handle this
379 if (F
->getCallingConv() != G
->getCallingConv())
// Function types are compared with isEquivalentType, not by identity.
382 if (!isEquivalentType(F
->getFunctionType(), G
->getFunctionType()))
// Fresh maps for this pair of functions; both are passed down to the
// basic-block comparison.
385 DenseMap
<const Value
*, const Value
*> ValueMap
;
386 DenseMap
<const Value
*, const Value
*> SpeculationMap
;
389 assert(F
->arg_size() == G
->arg_size() &&
390 "Identical functions have a different number of args.");
// Walks the arguments of F and G in lock step. The loop body is missing from
// this extract; per the comment in compare(), arguments are pre-loaded into
// the ValueMap.
392 for (Function::const_arg_iterator fi
= F
->arg_begin(), gi
= G
->arg_begin(),
393 fe
= F
->arg_end(); fi
!= fe
; ++fi
, ++gi
)
// Compare the bodies, starting at the entry blocks.
396 if (!equals(&F
->getEntryBlock(), &G
->getEntryBlock(), ValueMap
,
// Every speculated pairing is checked against what ValueMap ended up with.
400 for (DenseMap
<const Value
*, const Value
*>::iterator
401 I
= SpeculationMap
.begin(), E
= SpeculationMap
.end(); I
!= E
; ++I
) {
402 if (ValueMap
[I
->first
] != I
->second
)
409 // ===----------------------------------------------------------------------===
410 // Folding of functions
411 // ===----------------------------------------------------------------------===
414 // * F is external strong, G is external strong:
415 // turn G into a thunk to F (1)
416 // * F is external strong, G is external weak:
417 // turn G into a thunk to F (1)
418 // * F is external weak, G is external weak:
420 // * F is external strong, G is internal:
421 // address of G taken:
422 // turn G into a thunk to F (1)
423 // address of G not taken:
424 // make G an alias to F (2)
425 // * F is internal, G is external weak
426 // address of F is taken:
427 // turn G into a thunk to F (1)
428 // address of F is not taken:
429 // make G an alias of F (2)
430 // * F is internal, G is internal:
431 // address of F and G are taken:
432 // turn G into a thunk to F (1)
433 // address of G is not taken:
434 // make G an alias to F (2)
436 // alias requires linkage == (external,local,weak) fallback to creating a thunk
437 // external means 'externally visible' linkage != (internal,private)
438 // internal means linkage == (internal,private)
439 // weak means linkage mayBeOverridable
440 // being external implies that the address is taken
442 // 1. turn G into a thunk to F
443 // 2. make G an alias to F
// Coarse classification of a function's linkage, produced by categorize()
// and consumed by fold().
// NOTE(review): the enumerator list is missing from this extract; the names
// Internal, ExternalWeak and ExternalStrong are used elsewhere in this file.
445 enum LinkageCategory
{
// Maps the linkage of F to a LinkageCategory.
// NOTE(review): the return statements for the first two groups of case labels
// are missing from this extract (presumably Internal and ExternalWeak, in
// that order) -- confirm against the complete file.
451 static LinkageCategory
categorize(const Function
*F
) {
452 switch (F
->getLinkage()) {
// Group 1: internal, private and linker-private linkage.
453 case GlobalValue::InternalLinkage
:
454 case GlobalValue::PrivateLinkage
:
455 case GlobalValue::LinkerPrivateLinkage
:
// Group 2: the weak linkages.
458 case GlobalValue::WeakAnyLinkage
:
459 case GlobalValue::WeakODRLinkage
:
460 case GlobalValue::ExternalWeakLinkage
:
// Group 3: all of the following linkages are categorized as ExternalStrong.
463 case GlobalValue::ExternalLinkage
:
464 case GlobalValue::AvailableExternallyLinkage
:
465 case GlobalValue::LinkOnceAnyLinkage
:
466 case GlobalValue::LinkOnceODRLinkage
:
467 case GlobalValue::AppendingLinkage
:
468 case GlobalValue::DLLImportLinkage
:
469 case GlobalValue::DLLExportLinkage
:
470 case GlobalValue::CommonLinkage
:
471 return ExternalStrong
;
474 llvm_unreachable("Unknown LinkageType.");
// Replaces G with a new function NewG that has G's type and linkage and whose
// single basic block calls F and returns the result, inserting bitcasts where
// parameter or return types differ. All uses of G are redirected to NewG and
// G is erased.
// NOTE(review): several lines are missing from this extract (the remaining
// Function::Create arguments, parts of the argument loop that fill Args and
// declare `i`, and possibly more) -- consult the complete file.
478 static void ThunkGToF(Function
*F
, Function
*G
) {
479 Function
*NewG
= Function::Create(G
->getFunctionType(), G
->getLinkage(), "",
// The only basic block of the thunk.
481 BasicBlock
*BB
= BasicBlock::Create(F
->getContext(), "", NewG
);
// Arguments that will be passed on to F.
483 std::vector
<Value
*> Args
;
485 const FunctionType
*FFTy
= F
->getFunctionType();
// For each argument of NewG: if its type differs from the type F expects at
// that position, a bitcast to F's parameter type is emitted into BB.
486 for (Function::arg_iterator AI
= NewG
->arg_begin(), AE
= NewG
->arg_end();
488 if (FFTy
->getParamType(i
) == AI
->getType())
491 Value
*BCI
= new BitCastInst(AI
, FFTy
->getParamType(i
), "", BB
);
// Emit the call to F using F's calling convention.
497 CallInst
*CI
= CallInst::Create(F
, Args
.begin(), Args
.end(), "", BB
);
499 CI
->setCallingConv(F
->getCallingConv());
// Return: nothing for void, a bitcast of the call result when the result
// type differs from NewG's return type, otherwise the call result itself.
500 if (NewG
->getReturnType()->isVoidTy()) {
501 ReturnInst::Create(F
->getContext(), BB
);
502 } else if (CI
->getType() != NewG
->getReturnType()) {
503 Value
*BCI
= new BitCastInst(CI
, NewG
->getReturnType(), "", BB
);
504 ReturnInst::Create(F
->getContext(), BCI
, BB
);
506 ReturnInst::Create(F
->getContext(), CI
, BB
);
// NewG takes over G's attributes, then replaces G everywhere.
509 NewG
->copyAttributesFrom(G
);
511 G
->replaceAllUsesWith(NewG
);
512 G
->eraseFromParent();
514 // TODO: look at direct callers to G and make them all direct callers to F.
// Replaces G with a GlobalAlias that points at F (bitcast to G's type).
// If G's linkage is none of external, local or weak, it falls back to
// ThunkGToF instead.
// NOTE(review): the alias is created with an empty name in the visible code;
// a line appears to be missing before setVisibility, so any name transfer
// from G is not visible here. The closing brace is also missing.
517 static void AliasGToF(Function
*F
, Function
*G
) {
518 if (!G
->hasExternalLinkage() && !G
->hasLocalLinkage() && !G
->hasWeakLinkage())
519 return ThunkGToF(F
, G
);
// The alias has G's type and linkage and lives in G's module.
521 GlobalAlias
*GA
= new GlobalAlias(
522 G
->getType(), G
->getLinkage(), "",
523 ConstantExpr::getBitCast(F
, G
->getType()), G
->getParent());
// F must satisfy the stricter of the two alignments, since it now also
// stands in for G.
524 F
->setAlignment(std::max(F
->getAlignment(), G
->getAlignment()));
526 GA
->setVisibility(G
->getVisibility());
// Redirect all uses of G to the alias, then delete G.
527 G
->replaceAllUsesWith(GA
);
528 G
->eraseFromParent();
// Folds the functions FnVec[i] and FnVec[j] into one; called from
// runOnModule() for two functions of the same hash bucket. Which
// transformation applies depends on the linkage categories of the two
// functions and on whether their addresses are taken (see the table in the
// comment block above).
// NOTE(review): large parts of this function are missing from this extract
// (the switch/branch structure, the calls that create thunks or aliases, and
// the return statements); the comments below describe only what is visible.
531 static bool fold(std::vector
<Function
*> &FnVec
, unsigned i
, unsigned j
) {
532 Function
*F
= FnVec
[i
];
533 Function
*G
= FnVec
[j
];
535 LinkageCategory catF
= categorize(F
);
536 LinkageCategory catG
= categorize(G
);
// Normalise the ordering of the pair: swap when F is ExternalWeak, or when F
// is Internal and G is ExternalStrong. The vector entries and the categories
// are visibly swapped; a swap of the local F/G pointers is presumably on a
// line missing from this extract.
538 if (catF
== ExternalWeak
|| (catF
== Internal
&& catG
== ExternalStrong
)) {
539 std::swap(FnVec
[i
], FnVec
[j
]);
541 std::swap(catF
, catG
);
// Whether G's address is taken influences the choice between thunk and alias
// (the chosen action is not visible here).
552 if (G
->hasAddressTaken())
// Both functions are expected to be weak in this branch.
561 assert(catG
== ExternalWeak
);
563 // Make them both thunks to the same internal function.
564 F
->setAlignment(std::max(F
->getAlignment(), G
->getAlignment()));
// H is a new function with F's type, linkage and attributes; all uses of F
// are redirected to H, and F itself is then given internal linkage.
565 Function
*H
= Function::Create(F
->getFunctionType(), F
->getLinkage(), "",
567 H
->copyAttributesFrom(F
);
569 F
->replaceAllUsesWith(H
);
574 F
->setLinkage(GlobalValue::InternalLinkage
);
583 if (F
->hasAddressTaken())
// Ordering by address-taken status: if only G has its address taken, the pair
// is swapped so that afterwards it is F whose address is taken.
589 bool addrTakenF
= F
->hasAddressTaken();
590 bool addrTakenG
= G
->hasAddressTaken();
591 if (!addrTakenF
&& addrTakenG
) {
592 std::swap(FnVec
[i
], FnVec
[j
]);
594 std::swap(addrTakenF
, addrTakenG
);
597 if (addrTakenF
&& addrTakenG
) {
// Count the merge in the pass statistic.
608 ++NumFunctionsMerged
;
612 // ===----------------------------------------------------------------------===
614 // ===----------------------------------------------------------------------===
// Pass entry point. Buckets all function definitions of M by hash(), then
// repeatedly compares the functions within each bucket pairwise with equals()
// and folds pairs via fold(), until an iteration makes no change.
// NOTE(review): several lines are missing from this extract (the declaration
// of LocalChanged, the start of the do-loop, the check of isEqual, the index
// adjustments after erase, and the final return) -- consult the complete file.
616 bool MergeFunctions::runOnModule(Module
&M
) {
617 bool Changed
= false;
// Hash value -> functions with that hash.
619 std::map
<unsigned long, std::vector
<Function
*> > FnMap
;
// Declarations and intrinsics are tested for here (the statement executed for
// them is not visible; presumably they are skipped); other functions are
// added to their hash bucket.
621 for (Module::iterator F
= M
.begin(), E
= M
.end(); F
!= E
; ++F
) {
622 if (F
->isDeclaration() || F
->isIntrinsic())
625 FnMap
[hash(F
)].push_back(F
);
628 // TODO: instead of running in a loop, we could also fold functions in
629 // callgraph order. Constructing the CFG probably isn't cheaper than just
630 // running in a loop, unless it happened to already be available.
// Start of one iteration over all buckets.
634 LocalChanged
= false;
635 DEBUG(dbgs() << "size: " << FnMap
.size() << "\n");
636 for (std::map
<unsigned long, std::vector
<Function
*> >::iterator
637 I
= FnMap
.begin(), E
= FnMap
.end(); I
!= E
; ++I
) {
638 std::vector
<Function
*> &FnVec
= I
->second
;
639 DEBUG(dbgs() << "hash (" << I
->first
<< "): " << FnVec
.size() << "\n");
// Every unordered pair (i, j) with i < j in the bucket is compared.
641 for (int i
= 0, e
= FnVec
.size(); i
!= e
; ++i
) {
642 for (int j
= i
+ 1; j
!= e
; ++j
) {
643 bool isEqual
= equals(FnVec
[i
], FnVec
[j
]);
645 DEBUG(dbgs() << " " << FnVec
[i
]->getName()
646 << (isEqual
? " == " : " != ")
647 << FnVec
[j
]->getName() << "\n");
// After a successful fold, entry j is removed from the bucket.
650 if (fold(FnVec
, i
, j
)) {
652 FnVec
.erase(FnVec
.begin() + j
);
// Accumulate into the overall result and repeat while something changed.
660 Changed
|= LocalChanged
;
661 } while (LocalChanged
);