//---------------------------------------------------------------------
// <copyright file="LinqExpressionNormalizer.cs" company="Microsoft">
//      Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//
// @owner Microsoft, Microsoft
//---------------------------------------------------------------------
9 using System
.Linq
.Expressions
;
10 using System
.Diagnostics
;
11 using System
.Collections
.Generic
;
12 using System
.Reflection
;
13 namespace System
.Data
.Objects
.ELinq
/// <summary>
/// Replaces expression patterns produced by the compiler with approximations
/// used in query translation. For instance, the following VB code, where x and y are strings:
///
///     x = y
///
/// becomes the expression
///
///     Equal(MethodCallExpression(Microsoft.VisualBasic.CompilerServices.Operators.CompareString(x, y, False)), 0)
///
/// which is normalized to
///
///     Equal(x, y)
///
/// Comment convention:
///
///     CODE(Lang): _VB or C# coding pattern being simplified_
///     ORIGINAL: _original LINQ expression_
///     NORMALIZED: _normalized LINQ expression_
/// </summary>
35 internal class LinqExpressionNormalizer
: EntityExpressionVisitor
/// <summary>
/// The liftToNull flag passed to every relational node this normalizer builds.
/// A MethodCallExpression never needs lifting to null. Lifting exists to translate certain
/// language patterns; once the user (or compiler) explicitly asks for a method invocation,
/// lifting can no longer occur.
/// </summary>
private const bool LiftToNull = false;
/// <summary>
/// Records, for expressions created by this normalizer, the pattern each one matches
/// (see CreateCompareExpression). VisitBinary consults it to recognize composite patterns
/// such as Compare(x, y) compared against 0.
/// </summary>
private readonly Dictionary<Expression, Pattern> _patterns =
    new Dictionary<Expression, Pattern>();
/// <summary>
/// Handles binary patterns:
///
///     - VB 'Is' operator (equality between operands converted to object)
///     - Compare patterns (a recorded Compare(x, y) node compared against constant 0)
/// </summary>
/// <param name="b">The binary expression being visited.</param>
/// <returns>The normalized expression, or the visited node when no pattern applies.</returns>
internal override Expression VisitBinary(BinaryExpression b)
{
    BinaryExpression visited = (BinaryExpression)base.VisitBinary(b);

    // ORIGINAL: Equal(Convert(x, typeof(object)), Convert(y, typeof(object))
    // NORMALIZED: Equal(x, y)
    if (visited.NodeType == ExpressionType.Equal)
    {
        Expression unwrappedLeft = UnwrapObjectConvert(visited.Left);
        Expression unwrappedRight = UnwrapObjectConvert(visited.Right);

        // Reference comparison: only rebuild the node when an operand actually changed.
        bool anyUnwrapped = unwrappedLeft != visited.Left || unwrappedRight != visited.Right;
        if (anyUnwrapped)
        {
            visited = CreateRelationalOperator(ExpressionType.Equal, unwrappedLeft, unwrappedRight);
        }
    }

    // ORIGINAL: Equal(Microsoft.VisualBasic.CompilerServices.Operators.CompareString(x, y, False), 0)
    // NORMALIZED: Equal(x, y)
    Pattern matched;
    bool comparesPatternWithZero =
        _patterns.TryGetValue(visited.Left, out matched)
        && matched.Kind == PatternKind.Compare
        && IsConstantZero(visited.Right);

    if (!comparesPatternWithZero)
    {
        return visited;
    }

    // Relate the original Compare operands directly using this node's operator.
    ComparePattern compare = (ComparePattern)matched;
    BinaryExpression relational;
    return TryCreateRelationalOperator(visited.NodeType, compare.Left, compare.Right, out relational)
        ? relational
        : visited;
}
/// <summary>
/// Strips conversions to <see cref="object"/> from an equality operand.
///
///     ORIGINAL: Convert(x, typeof(object))                    NORMALIZED: x
///     ORIGINAL(Funcletized): Constant(x, typeof(object))      NORMALIZED: Constant(x, x.GetType())
/// </summary>
/// <param name="input">The operand to unwrap.</param>
/// <returns>
/// A re-typed constant for a funcletized (already evaluated) object constant whose value is
/// non-null and not a plain object; otherwise the innermost operand below any chain of
/// object-typed Convert nodes (which is <paramref name="input"/> itself when there is none).
/// </returns>
private static Expression UnwrapObjectConvert(Expression input)
{
    bool isObjectConstant = input.NodeType == ExpressionType.Constant && input.Type == typeof(object);
    if (isObjectConstant)
    {
        object value = ((ConstantExpression)input).Value;

        // Null constants are deliberately left alone here; they are handled later.
        Type runtimeType = value == null ? null : value.GetType();
        if (runtimeType != null && runtimeType != typeof(object))
        {
            return Expression.Constant(value, runtimeType);
        }
    }

    Expression current = input;
    while (current.NodeType == ExpressionType.Convert && current.Type == typeof(object))
    {
        current = ((UnaryExpression)current).Operand;
    }

    return current;
}
/// <summary>
/// Returns true if the given expression is a constant whose value equals 0 according to
/// value.Equals(0) (in practice a boxed <see cref="int"/> zero).
/// </summary>
/// <param name="expression">The expression to test.</param>
/// <returns>
/// True for a non-null constant equal to 0; false for any other node, including a constant
/// whose value is null (e.g. Constant(null, typeof(int?))).
/// </returns>
private static bool IsConstantZero(Expression expression)
{
    if (expression.NodeType != ExpressionType.Constant)
    {
        return false;
    }

    // object.Equals(a, b) returns false when a is null and otherwise calls a.Equals(b).
    // This keeps the Value.Equals(0) semantics without dereferencing a null constant value.
    return object.Equals(((ConstantExpression)expression).Value, 0);
}
/// <summary>
/// Handles MethodCall patterns:
///
///     - Operator overloads (op_Equality, op_Addition, op_Implicit, ...)
///     - Static Equals, static Compare and VB CompareString calls
///     - Instance Equals, CompareTo and List&lt;T&gt;.Contains calls
///     - Coalesce operators added by the VB compiler to predicate arguments
/// </summary>
/// <param name="m">The method call expression being visited.</param>
/// <returns>
/// The normalized expression, or the visited call (with a normalized predicate argument
/// when one applies) if no other pattern matches.
/// </returns>
internal override Expression VisitMethodCall(MethodCallExpression m)
{
    m = (MethodCallExpression)base.VisitMethodCall(m);

    if (m.Method.IsStatic)
    {
        // handle operator overloads
        if (m.Method.Name.StartsWith("op_", StringComparison.Ordinal))
        {
            // handle binary operator overloads
            if (m.Arguments.Count == 2)
            {
                // CODE(C#): x == y, where the operands' type defines the operator
                // ORIGINAL: MethodCallExpression(<op_Equality>, x, y)
                // NORMALIZED: Equal(x, y)
                switch (m.Method.Name)
                {
                    case "op_Equality":
                        return Expression.Equal(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);

                    case "op_Inequality":
                        return Expression.NotEqual(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);

                    case "op_GreaterThan":
                        return Expression.GreaterThan(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);

                    case "op_GreaterThanOrEqual":
                        return Expression.GreaterThanOrEqual(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);

                    case "op_LessThan":
                        return Expression.LessThan(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);

                    case "op_LessThanOrEqual":
                        return Expression.LessThanOrEqual(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);

                    case "op_Multiply":
                        return Expression.Multiply(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_Subtraction":
                        return Expression.Subtract(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_Addition":
                        return Expression.Add(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_Division":
                        return Expression.Divide(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_Modulus":
                        return Expression.Modulo(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_BitwiseAnd":
                        return Expression.And(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_BitwiseOr":
                        return Expression.Or(m.Arguments[0], m.Arguments[1], m.Method);

                    case "op_ExclusiveOr":
                        return Expression.ExclusiveOr(m.Arguments[0], m.Arguments[1], m.Method);
                }
            }

            // handle unary operator overloads
            if (m.Arguments.Count == 1)
            {
                // CODE(C#): +x, where x's type defines the operator
                // ORIGINAL: MethodCallExpression(<op_UnaryPlus>, x)
                // NORMALIZED: UnaryPlus(x)
                switch (m.Method.Name)
                {
                    case "op_UnaryNegation":
                        return Expression.Negate(m.Arguments[0], m.Method);

                    case "op_UnaryPlus":
                        return Expression.UnaryPlus(m.Arguments[0], m.Method);

                    case "op_Explicit":
                    case "op_Implicit":
                        return Expression.Convert(m.Arguments[0], m.Type, m.Method);

                    case "op_OnesComplement":
                    case "op_False":
                        return Expression.Not(m.Arguments[0], m.Method);
                }
            }
        }

        // check for static Equals method
        // CODE(C#): Object.Equals(x, y)
        // ORIGINAL: MethodCallExpression(<object.Equals>, x, y)
        // NORMALIZED: Equal(x, y)
        // Only the two-argument form maps onto a binary Equal node. Expression.Equal requires the
        // method to take exactly two parameters and throws ArgumentException otherwise (e.g. for
        // String.Equals(a, b, StringComparison)), so other overloads are left untouched.
        if (m.Method.Name == "Equals" && m.Arguments.Count == 2)
        {
            return Expression.Equal(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);
        }

        // check for Microsoft.VisualBasic.CompilerServices.Operators.CompareString method
        if (m.Method.Name == "CompareString" &&
            m.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators")
        {
            // CODE(VB): x = y; where x and y are strings, a part of the expression looks like:
            // ORIGINAL: MethodCallExpression(Microsoft.VisualBasic.CompilerServices.Operators.CompareString(x, y, False)
            // NORMALIZED: see CreateCompareExpression method
            return CreateCompareExpression(m.Arguments[0], m.Arguments[1]);
        }

        // check for static Compare method
        if (m.Method.Name == "Compare" && m.Arguments.Count > 1 && m.Method.ReturnType == typeof(int))
        {
            // CODE(C#): Class.Compare(x, y)
            // ORIGINAL: MethodCallExpression(<Compare>, x, y)
            // NORMALIZED: see CreateCompareExpression method
            return CreateCompareExpression(m.Arguments[0], m.Arguments[1]);
        }
    }
    else
    {
        // check for instance Equals method
        if (m.Method.Name == "Equals" && m.Arguments.Count > 0)
        {
            // Type-specific Equals on spatial types becomes a call to the 'STEquals' spatial
            // canonical function, so it must remain in the expression tree.
            Type parameterType = m.Method.GetParameters()[0].ParameterType;
            if (parameterType != typeof(System.Data.Spatial.DbGeography) &&
                parameterType != typeof(System.Data.Spatial.DbGeometry))
            {
                // CODE(C#): x.Equals(y)
                // ORIGINAL: MethodCallExpression(x, <Equals>, y)
                // NORMALIZED: Equal(x, y)
                return CreateRelationalOperator(ExpressionType.Equal, m.Object, m.Arguments[0]);
            }
        }

        // check for instance CompareTo method
        if (m.Method.Name == "CompareTo" && m.Arguments.Count == 1 && m.Method.ReturnType == typeof(int))
        {
            // CODE(C#): x.CompareTo(y)
            // ORIGINAL: MethodCallExpression(x.CompareTo(y))
            // NORMALIZED: see CreateCompareExpression method
            return CreateCompareExpression(m.Object, m.Arguments[0]);
        }

        // check for List<> instance Contains method
        if (m.Method.Name == "Contains" && m.Arguments.Count == 1)
        {
            Type declaringType = m.Method.DeclaringType;
            if (declaringType.IsGenericType && declaringType.GetGenericTypeDefinition() == typeof(List<>))
            {
                // CODE(C#): List<T> x.Contains(y)
                // ORIGINAL: MethodCallExpression(x.Contains(y))
                // NORMALIZED: IEnumerable<T>.Contains(x, y)
                MethodInfo containsMethod;
                if (ReflectionUtil.TryLookupMethod(SequenceMethod.Contains, out containsMethod))
                {
                    MethodInfo enumerableContainsMethod =
                        containsMethod.MakeGenericMethod(declaringType.GetGenericArguments());
                    return Expression.Call(enumerableContainsMethod, m.Object, m.Arguments[0]);
                }
            }
        }
    }

    // check for coalesce operators added by the VB compiler to predicate arguments
    return NormalizePredicateArgument(m);
}
/// <summary>
/// Identifies and normalizes any predicate argument in the given call expression.
/// </summary>
/// <param name="callExpression">The call whose predicate argument may need normalizing.</param>
/// <returns>
/// <paramref name="callExpression"/> itself when nothing changes; otherwise a new call to the
/// same method and instance, with the predicate argument replaced by its normalized form.
/// </returns>
private static MethodCallExpression NormalizePredicateArgument(MethodCallExpression callExpression)
{
    int predicateIndex;
    Expression rewrittenPredicate;
    if (!HasPredicateArgument(callExpression, out predicateIndex) ||
        !TryMatchCoalescePattern(callExpression.Arguments[predicateIndex], out rewrittenPredicate))
    {
        // nothing to normalize; hand back the original call
        return callExpression;
    }

    List<Expression> arguments = new List<Expression>(callExpression.Arguments);
    arguments[predicateIndex] = rewrittenPredicate;
    return Expression.Call(callExpression.Object, callExpression.Method, arguments);
}
/// <summary>
/// Determines whether the given call expression is a known sequence method that takes a
/// 'predicate' argument (e.g. Where(source, predicate)), and reports that argument's ordinal.
/// </summary>
/// <remarks>
/// This method will need to be replaced if we ever encounter a method with multiple predicates.
/// </remarks>
/// <param name="callExpression">The call to inspect.</param>
/// <param name="argumentOrdinal">Receives 1 on a match; 0 otherwise.</param>
/// <returns>True if the call is one of the recognized predicate-taking sequence methods.</returns>
private static bool HasPredicateArgument(MethodCallExpression callExpression, out int argumentOrdinal)
{
    argumentOrdinal = default(int);

    // Every supported predicate-taking method has the predicate as its second argument,
    // so calls with fewer than two arguments can be rejected immediately.
    if (callExpression.Arguments.Count < 2)
    {
        return false;
    }

    SequenceMethod sequenceMethod;
    if (!ReflectionUtil.TryIdentifySequenceMethod(callExpression.Method, out sequenceMethod))
    {
        return false;
    }

    switch (sequenceMethod)
    {
        case SequenceMethod.FirstPredicate:
        case SequenceMethod.FirstOrDefaultPredicate:
        case SequenceMethod.SinglePredicate:
        case SequenceMethod.SingleOrDefaultPredicate:
        case SequenceMethod.LastPredicate:
        case SequenceMethod.LastOrDefaultPredicate:
        case SequenceMethod.Where:
        case SequenceMethod.WhereOrdinal:
        case SequenceMethod.CountPredicate:
        case SequenceMethod.LongCountPredicate:
        case SequenceMethod.AnyPredicate:
        case SequenceMethod.All:
        case SequenceMethod.SkipWhile:
        case SequenceMethod.SkipWhileOrdinal:
        case SequenceMethod.TakeWhile:
        case SequenceMethod.TakeWhileOrdinal:
            argumentOrdinal = 1; // the predicate is always the second argument
            return true;

        default:
            return false;
    }
}
/// <summary>
/// Determines whether the given expression has the form Lambda(Coalesce(left, Constant(false)), ...),
/// a pattern the VB compiler introduces for predicate arguments, optionally wrapped in a Quote.
/// On success, outputs the normalized form Lambda((bool)left, ...), re-quoted if the input was quoted.
/// </summary>
/// <param name="expression">The candidate predicate expression.</param>
/// <param name="normalized">The normalized expression on success; null otherwise.</param>
/// <returns>True if the pattern matched.</returns>
private static bool TryMatchCoalescePattern(Expression expression, out Expression normalized)
{
    normalized = null;

    switch (expression.NodeType)
    {
        case ExpressionType.Quote:
            {
                // Look through the quote, then re-wrap whatever the operand normalizes to.
                Expression inner;
                if (!TryMatchCoalescePattern(((UnaryExpression)expression).Operand, out inner))
                {
                    return false;
                }

                normalized = Expression.Quote(inner);
                return true;
            }

        case ExpressionType.Lambda:
            {
                // CODE(VB): where a.NullableInt = 1
                // ORIGINAL: Lambda(Coalesce(expr, Constant(false)), a)
                // NORMALIZED: Lambda(expr, a)
                LambdaExpression lambda = (LambdaExpression)expression;
                bool isBooleanCoalesce = lambda.Body.NodeType == ExpressionType.Coalesce &&
                                         lambda.Body.Type == typeof(bool);
                if (!isBooleanCoalesce)
                {
                    return false;
                }

                BinaryExpression coalesce = (BinaryExpression)lambda.Body;
                bool fallsBackToFalse = coalesce.Right.NodeType == ExpressionType.Constant &&
                                        false.Equals(((ConstantExpression)coalesce.Right).Value);
                if (!fallsBackToFalse)
                {
                    return false;
                }

                Expression convertedBody = Expression.Convert(coalesce.Left, typeof(bool));
                normalized = Expression.Lambda(lambda.Type, convertedBody, lambda.Parameters);
                return true;
            }

        default:
            return false;
    }
}
/// <summary>
/// Open generic definition of RelationalOperatorPlaceholder, looked up by name among the
/// non-public static methods of this class. It is closed over the operand types when
/// relational nodes are built.
/// </summary>
private static readonly MethodInfo s_relationalOperatorPlaceholderMethod =
    typeof(LinqExpressionNormalizer).GetMethod(
        "RelationalOperatorPlaceholder",
        BindingFlags.NonPublic | BindingFlags.Static);
/// <summary>
/// Exists solely so that valid relational operator LINQ expressions can be created for operand
/// types the CLR does not natively relate (e.g. String &gt; String). It must never be invoked.
/// </summary>
/// <typeparam name="TLeft">Type of the left operand.</typeparam>
/// <typeparam name="TRight">Type of the right operand.</typeparam>
/// <returns>A reference-equality result; only meaningful as a signature for LINQ expressions.</returns>
private static bool RelationalOperatorPlaceholder<TLeft, TRight>(TLeft left, TRight right)
{
    Debug.Fail("This method should never be called. It exists merely to support creation of relational LINQ expressions.");

    object boxedLeft = left;
    object boxedRight = right;
    return object.ReferenceEquals(boxedLeft, boxedRight);
}
/// <summary>
/// Creates a node relating <paramref name="left"/> and <paramref name="right"/> with the given
/// relational operator.
/// </summary>
/// <param name="op">One of Equal, NotEqual, LessThan, LessThanOrEqual, GreaterThan, GreaterThanOrEqual.</param>
/// <returns>The relational node, or null (after a Debug.Fail) when <paramref name="op"/> is not relational.</returns>
private static BinaryExpression CreateRelationalOperator(ExpressionType op, Expression left, Expression right)
{
    BinaryExpression relational;
    bool isKnownOperator = TryCreateRelationalOperator(op, left, right, out relational);
    if (!isKnownOperator)
    {
        Debug.Fail("CreateRelationalOperator has unknown op " + op);
    }

    return relational;
}
/// <summary>
/// Tries to create a node relating <paramref name="left"/> and <paramref name="right"/> with the
/// given operator. The node uses RelationalOperatorPlaceholder, closed over the operand types, as
/// its implementing method, so any pair of operand types can be related.
/// </summary>
/// <param name="op">The operator to apply.</param>
/// <param name="result">The created node, or null when <paramref name="op"/> is not relational.</param>
/// <returns>False if <paramref name="op"/> does not define a known relation.</returns>
private static bool TryCreateRelationalOperator(ExpressionType op, Expression left, Expression right, out BinaryExpression result)
{
    // Closed over the operand types up front, whatever the operator turns out to be.
    MethodInfo placeholder = s_relationalOperatorPlaceholderMethod.MakeGenericMethod(left.Type, right.Type);

    result = null;
    switch (op)
    {
        case ExpressionType.Equal:
            result = Expression.Equal(left, right, LiftToNull, placeholder);
            break;

        case ExpressionType.NotEqual:
            result = Expression.NotEqual(left, right, LiftToNull, placeholder);
            break;

        case ExpressionType.LessThan:
            result = Expression.LessThan(left, right, LiftToNull, placeholder);
            break;

        case ExpressionType.LessThanOrEqual:
            result = Expression.LessThanOrEqual(left, right, LiftToNull, placeholder);
            break;

        case ExpressionType.GreaterThan:
            result = Expression.GreaterThan(left, right, LiftToNull, placeholder);
            break;

        case ExpressionType.GreaterThanOrEqual:
            result = Expression.GreaterThanOrEqual(left, right, LiftToNull, placeholder);
            break;
    }

    // The Expression factories never return null, so null here means "not a relational operator".
    return result != null;
}
/// <summary>
/// Rewrites a three-way comparison into nested conditionals.
///
///     CODE(C#): Class.Compare(left, right)
///     ORIGINAL: MethodCallExpression(Compare, left, right)
///     NORMALIZED: Condition(Equal(left, right), 0, Condition(left &gt; right, 1, -1))
///
/// Why is this an improvement? We know how to evaluate Condition in the store, but we don't know
/// how to evaluate the MethodCallExpression. When the comparison appears inside a larger expression,
/// e.g. left.CompareTo(right) &gt; 0, VisitBinary can further simplify it to left &gt; right, because
/// the result is registered here as a ComparePattern.
/// </summary>
private Expression CreateCompareExpression(Expression left, Expression right)
{
    BinaryExpression areEqual = CreateRelationalOperator(ExpressionType.Equal, left, right);
    Expression greaterOrLess = Expression.Condition(
        CreateRelationalOperator(ExpressionType.GreaterThan, left, right),
        Expression.Constant(1),
        Expression.Constant(-1));
    Expression comparison = Expression.Condition(areEqual, Expression.Constant(0), greaterOrLess);

    // Remember that this node matches the Compare pattern.
    _patterns[comparison] = new ComparePattern(left, right);
    return comparison;
}
/// <summary>
/// Base type for the description of a pattern that an expression built by this normalizer matches.
/// </summary>
private abstract class Pattern
{
    /// <summary>
    /// Gets the kind of pattern this instance describes.
    /// </summary>
    internal abstract PatternKind Kind { get; }
}
/// <summary>
/// Enumerates the kinds of pattern recorded in the normalizer's pattern table.
/// </summary>
private enum PatternKind
{
    /// <summary>A three-way comparison built by CreateCompareExpression (see ComparePattern).</summary>
    Compare,
}
/// <summary>
/// Matches an expression of the form x.CompareTo(y) or Class.Compare(x, y), remembering the
/// two operands so that a later comparison of the result against 0 can relate them directly.
/// </summary>
private sealed class ComparePattern : Pattern
{
    /// <summary>
    /// Gets the left-hand argument to the Compare operation.
    /// </summary>
    internal readonly Expression Left;

    /// <summary>
    /// Gets the right-hand argument to the Compare operation.
    /// </summary>
    internal readonly Expression Right;

    internal ComparePattern(Expression left, Expression right)
    {
        Left = left;
        Right = right;
    }

    internal override PatternKind Kind
    {
        get { return PatternKind.Compare; }
    }
}