Updates referencesource to .NET 4.7
[mono-project.git] / mcs / class / referencesource / System.Data.Entity / System / Data / Objects / ELinq / LinqExpressionNormalizer.cs
blobe6042b498d49cc00c3380b8b829fcf227d3eec84
1 //---------------------------------------------------------------------
2 // <copyright file="LinqExpressionNormalizer.cs" company="Microsoft">
3 // Copyright (c) Microsoft Corporation. All rights reserved.
4 // </copyright>
5 //
6 // @owner Microsoft, Microsoft
7 //---------------------------------------------------------------------
9 using System.Linq.Expressions;
10 using System.Diagnostics;
11 using System.Collections.Generic;
12 using System.Reflection;
13 namespace System.Data.Objects.ELinq
15 /// <summary>
16 /// Replaces expression patterns produced by the compiler with approximations
17 /// used in query translation. For instance, the following VB code:
18 ///
19 /// x = y
20 ///
21 /// becomes the expression
22 ///
23 /// Equal(MethodCallExpression(Microsoft.VisualBasic.CompilerServices.Operators.CompareString(x, y, False)), 0)
24 ///
25 /// which is normalized to
26 ///
27 /// Equal(x, y)
28 ///
29 /// Comment convention:
30 ///
31 /// CODE(Lang): _VB or C# coding pattern being simplified_
32 /// ORIGINAL: _original LINQ expression_
33 /// NORMALIZED: _normalized LINQ expression_
34 /// </summary>
35 internal class LinqExpressionNormalizer : EntityExpressionVisitor
/// <summary>
/// Lift-to-null argument used when this visitor builds equality/relational nodes from method calls.
/// If we encounter a MethodCallExpression, we never need to lift to null. Lifting exists to translate
/// certain patterns in the language; here the user (or compiler) has explicitly asked for a method
/// invocation, at which point lifting can no longer occur.
/// </summary>
private const bool LiftToNull = false;

/// <summary>
/// Maps expressions produced during normalization to the pattern they matched (currently populated by
/// CreateCompareExpression). Used to identify composite expression patterns, e.g. a normalized
/// Compare(x, y) that is subsequently related to the constant 0.
/// </summary>
private readonly Dictionary<Expression, Pattern> _patterns = new Dictionary<Expression, Pattern>();
/// <summary>
/// Visits both operands, then simplifies binary patterns:
///
/// - VB 'Is' operator
///     CODE(VB):   x Is y
///     ORIGINAL:   Equal(Convert(x, typeof(object)), Convert(y, typeof(object)))
///     NORMALIZED: Equal(x, y)
/// - Compare patterns (a normalized Compare(x, y) related to the constant 0)
///     CODE(VB):   x = y
///     ORIGINAL:   Equal(Microsoft.VisualBasic.CompilerServices.Operators.CompareString(x, y, False), 0)
///     NORMALIZED: Equal(x, y)
/// </summary>
internal override Expression VisitBinary(BinaryExpression b)
{
    BinaryExpression visited = (BinaryExpression)base.VisitBinary(b);

    // VB 'Is': drop the object conversions the compiler wraps around both operands.
    if (visited.NodeType == ExpressionType.Equal)
    {
        Expression unwrappedLeft = UnwrapObjectConvert(visited.Left);
        Expression unwrappedRight = UnwrapObjectConvert(visited.Right);
        bool anythingUnwrapped = unwrappedLeft != visited.Left || unwrappedRight != visited.Right;
        if (anythingUnwrapped)
        {
            visited = CreateRelationalOperator(ExpressionType.Equal, unwrappedLeft, unwrappedRight);
        }
    }

    // Compare pattern: "Compare(x, y) <op> 0" collapses to "x <op> y" when <op> is relational.
    Pattern matched;
    bool isCompareAgainstZero = _patterns.TryGetValue(visited.Left, out matched)
        && matched.Kind == PatternKind.Compare
        && IsConstantZero(visited.Right);
    if (!isCompareAgainstZero)
    {
        return visited;
    }

    ComparePattern compare = (ComparePattern)matched;
    BinaryExpression relational;
    return TryCreateRelationalOperator(visited.NodeType, compare.Left, compare.Right, out relational)
        ? relational
        : visited;
}
/// <summary>
/// Strips conversions to <c>object</c> from an operand.
///
/// CODE: x
/// ORIGINAL: Convert(x, typeof(object))
/// ORIGINAL(Funcletized): Constant(x, typeof(object))
/// NORMALIZED: x
/// </summary>
/// <remarks>
/// An object-typed constant with a non-null value whose runtime type is not <c>object</c> is re-created
/// as a constant typed with that runtime type. Null constants are returned unchanged (they are handled later).
/// </remarks>
private static Expression UnwrapObjectConvert(Expression input)
{
    // funcletized (already evaluated) Converts arrive as object-typed constants
    if (input.NodeType == ExpressionType.Constant && input.Type == typeof(object))
    {
        object evaluated = ((ConstantExpression)input).Value;
        if (evaluated != null)
        {
            Type runtimeType = evaluated.GetType();
            if (runtimeType != typeof(object))
            {
                return Expression.Constant(evaluated, runtimeType);
            }
        }
    }

    // peel off any chain of Convert(..., typeof(object)) wrappers
    Expression unwrapped = input;
    while (unwrapped.NodeType == ExpressionType.Convert && unwrapped.Type == typeof(object))
    {
        unwrapped = ((UnaryExpression)unwrapped).Operand;
    }

    return unwrapped;
}
/// <summary>
/// Returns true if the given expression is a constant whose value equals the Int32 value 0.
/// Null-valued constants and non-constant expressions return false.
/// </summary>
private static bool IsConstantZero(Expression expression)
{
    if (expression.NodeType != ExpressionType.Constant)
    {
        return false;
    }

    // Guard against null-valued constants: calling Equals on a null Value would throw.
    object value = ((ConstantExpression)expression).Value;
    return value != null && value.Equals(0);
}
/// <summary>
/// Handles MethodCall patterns:
///
/// - Operator overloads (op_Equality, op_Addition, op_UnaryNegation, ...)
/// - Static Equals/Compare methods and the VB Operators.CompareString helper
/// - Instance Equals/CompareTo methods and List&lt;T&gt;.Contains
/// - Coalesce wrappers added by the VB compiler to predicate arguments
/// </summary>
internal override Expression VisitMethodCall(MethodCallExpression m)
{
    m = (MethodCallExpression)base.VisitMethodCall(m);

    // Each helper returns null when the call matches none of its patterns.
    Expression normalized = m.Method.IsStatic
        ? NormalizeStaticMethodCall(m)
        : NormalizeInstanceMethodCall(m);
    if (normalized != null)
    {
        return normalized;
    }

    // check for coalesce operators added by the VB compiler to predicate arguments
    return NormalizePredicateArgument(m);
}

/// <summary>
/// Normalizes a call to a static method (operator overloads, Equals, Compare, VB CompareString).
/// Returns null when no pattern applies.
/// </summary>
private Expression NormalizeStaticMethodCall(MethodCallExpression m)
{
    // handle operator overloads
    if (m.Method.Name.StartsWith("op_", StringComparison.Ordinal))
    {
        Expression normalizedOperator = NormalizeOperatorOverload(m);
        if (normalizedOperator != null)
        {
            return normalizedOperator;
        }
    }

    // check for static Equals method
    // CODE(C#): Object.Equals(x, y)
    // ORIGINAL: MethodCallExpression(<object.Equals>, x, y)
    // NORMALIZED: Equal(x, y)
    // Only the two-argument form can back a binary node: Expression.Equal throws ArgumentException when the
    // implementing method does not take exactly two parameters (e.g. String.Equals(a, b, StringComparison)),
    // so any other overload is left as a method call instead of failing here.
    if (m.Method.Name == "Equals" && m.Arguments.Count == 2)
    {
        return Expression.Equal(m.Arguments[0], m.Arguments[1], LiftToNull, m.Method);
    }

    // check for Microsoft.VisualBasic.CompilerServices.Operators.CompareString method
    // CODE(VB): x = y; where x and y are strings, a part of the expression looks like:
    // ORIGINAL: MethodCallExpression(Microsoft.VisualBasic.CompilerServices.Operators.CompareString(x, y, False))
    // NORMALIZED: see CreateCompareExpression method
    if (m.Method.Name == "CompareString" && m.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators")
    {
        return CreateCompareExpression(m.Arguments[0], m.Arguments[1]);
    }

    // check for static Compare method
    // CODE(C#): Class.Compare(x, y)
    // ORIGINAL: MethodCallExpression(<Compare>, x, y)
    // NORMALIZED: see CreateCompareExpression method
    // NOTE: only the first two arguments are carried into the normalized expression; any further
    // arguments (e.g. a StringComparison) do not appear in the result.
    if (m.Method.Name == "Compare" && m.Arguments.Count > 1 && m.Method.ReturnType == typeof(int))
    {
        return CreateCompareExpression(m.Arguments[0], m.Arguments[1]);
    }

    return null;
}

/// <summary>
/// Rewrites a static operator-overload call (op_*) as the corresponding expression node bound to the
/// same method. Returns null for operator names or arities that are not recognized.
/// </summary>
private static Expression NormalizeOperatorOverload(MethodCallExpression m)
{
    MethodInfo method = m.Method;

    if (m.Arguments.Count == 2)
    {
        // CODE(C#): x == y
        // ORIGINAL: MethodCallExpression(<op_Equality>, x, y)
        // NORMALIZED: Equal(x, y)
        Expression left = m.Arguments[0];
        Expression right = m.Arguments[1];
        switch (method.Name)
        {
            case "op_Equality":
                return Expression.Equal(left, right, LiftToNull, method);
            case "op_Inequality":
                return Expression.NotEqual(left, right, LiftToNull, method);
            case "op_GreaterThan":
                return Expression.GreaterThan(left, right, LiftToNull, method);
            case "op_GreaterThanOrEqual":
                return Expression.GreaterThanOrEqual(left, right, LiftToNull, method);
            case "op_LessThan":
                return Expression.LessThan(left, right, LiftToNull, method);
            case "op_LessThanOrEqual":
                return Expression.LessThanOrEqual(left, right, LiftToNull, method);
            case "op_Multiply":
                return Expression.Multiply(left, right, method);
            case "op_Subtraction":
                return Expression.Subtract(left, right, method);
            case "op_Addition":
                return Expression.Add(left, right, method);
            case "op_Division":
                return Expression.Divide(left, right, method);
            case "op_Modulus":
                return Expression.Modulo(left, right, method);
            case "op_BitwiseAnd":
                return Expression.And(left, right, method);
            case "op_BitwiseOr":
                return Expression.Or(left, right, method);
            case "op_ExclusiveOr":
                return Expression.ExclusiveOr(left, right, method);
        }
    }
    else if (m.Arguments.Count == 1)
    {
        // CODE(C#): +x
        // ORIGINAL: MethodCallExpression(<op_UnaryPlus>, x)
        // NORMALIZED: UnaryPlus(x)
        Expression operand = m.Arguments[0];
        switch (method.Name)
        {
            case "op_UnaryNegation":
                return Expression.Negate(operand, method);
            case "op_UnaryPlus":
                return Expression.UnaryPlus(operand, method);
            case "op_Explicit":
            case "op_Implicit":
                return Expression.Convert(operand, m.Type, method);
            case "op_OnesComplement":
            case "op_False":
                return Expression.Not(operand, method);
        }
    }

    return null;
}

/// <summary>
/// Normalizes a call to an instance method (Equals, CompareTo, List&lt;T&gt;.Contains).
/// Returns null when no pattern applies.
/// </summary>
private Expression NormalizeInstanceMethodCall(MethodCallExpression m)
{
    // check for instance Equals method
    if (m.Method.Name == "Equals" && m.Arguments.Count > 0)
    {
        // type-specific Equals method on spatial types becomes a call to the 'STEquals' spatial canonical function, so should remain in the expression tree.
        Type parameterType = m.Method.GetParameters()[0].ParameterType;
        if (parameterType != typeof(System.Data.Spatial.DbGeography) && parameterType != typeof(System.Data.Spatial.DbGeometry))
        {
            // CODE(C#): x.Equals(y)
            // ORIGINAL: MethodCallExpression(x, <Equals>, y)
            // NORMALIZED: Equal(x, y)
            return CreateRelationalOperator(ExpressionType.Equal, m.Object, m.Arguments[0]);
        }
    }

    // check for instance CompareTo method
    // CODE(C#): x.CompareTo(y)
    // ORIGINAL: MethodCallExpression(x.CompareTo(y))
    // NORMALIZED: see CreateCompareExpression method
    if (m.Method.Name == "CompareTo" && m.Arguments.Count == 1 && m.Method.ReturnType == typeof(int))
    {
        return CreateCompareExpression(m.Object, m.Arguments[0]);
    }

    // check for List<> instance Contains method
    // CODE(C#): List<T> x.Contains(y)
    // ORIGINAL: MethodCallExpression(x.Contains(y))
    // NORMALIZED: IEnumerable<T>.Contains(x, y)
    if (m.Method.Name == "Contains" && m.Arguments.Count == 1)
    {
        Type declaringType = m.Method.DeclaringType;
        if (declaringType.IsGenericType && declaringType.GetGenericTypeDefinition() == typeof(List<>))
        {
            MethodInfo containsMethod;
            if (ReflectionUtil.TryLookupMethod(SequenceMethod.Contains, out containsMethod))
            {
                MethodInfo enumerableContainsMethod = containsMethod.MakeGenericMethod(declaringType.GetGenericArguments());
                return Expression.Call(enumerableContainsMethod, m.Object, m.Arguments[0]);
            }
        }
    }

    return null;
}
/// <summary>
/// Looks for a 'predicate' argument in the given call and, if that argument matches the VB coalesce
/// pattern, returns a new call expression in which only that argument is replaced by its normalized form.
/// Otherwise the original call expression is returned unchanged.
/// </summary>
private static MethodCallExpression NormalizePredicateArgument(MethodCallExpression callExpression)
{
    int predicateIndex;
    Expression replacement;
    if (!HasPredicateArgument(callExpression, out predicateIndex) ||
        !TryMatchCoalescePattern(callExpression.Arguments[predicateIndex], out replacement))
    {
        // nothing to normalize
        return callExpression;
    }

    Expression[] rewrittenArguments = new Expression[callExpression.Arguments.Count];
    callExpression.Arguments.CopyTo(rewrittenArguments, 0);
    rewrittenArguments[predicateIndex] = replacement;

    return Expression.Call(callExpression.Object, callExpression.Method, rewrittenArguments);
}
/// <summary>
/// Reports whether the call is to a recognized sequence method that takes a 'predicate' argument
/// (e.g. Where(source, predicate)), and if so yields that argument's ordinal.
/// </summary>
/// <remarks>
/// Every recognized method carries its predicate as the second argument, so a match always yields
/// ordinal 1 and calls with fewer than two arguments are rejected up front. This will need to change
/// if a method with multiple predicates is ever supported.
/// </remarks>
private static bool HasPredicateArgument(MethodCallExpression callExpression, out int argumentOrdinal)
{
    argumentOrdinal = default(int);

    SequenceMethod sequenceMethod;
    if (callExpression.Arguments.Count < 2 ||
        !ReflectionUtil.TryIdentifySequenceMethod(callExpression.Method, out sequenceMethod))
    {
        return false;
    }

    switch (sequenceMethod)
    {
        case SequenceMethod.FirstPredicate:
        case SequenceMethod.FirstOrDefaultPredicate:
        case SequenceMethod.SinglePredicate:
        case SequenceMethod.SingleOrDefaultPredicate:
        case SequenceMethod.LastPredicate:
        case SequenceMethod.LastOrDefaultPredicate:
        case SequenceMethod.Where:
        case SequenceMethod.WhereOrdinal:
        case SequenceMethod.CountPredicate:
        case SequenceMethod.LongCountPredicate:
        case SequenceMethod.AnyPredicate:
        case SequenceMethod.All:
        case SequenceMethod.SkipWhile:
        case SequenceMethod.SkipWhileOrdinal:
        case SequenceMethod.TakeWhile:
        case SequenceMethod.TakeWhileOrdinal:
            argumentOrdinal = 1; // the predicate is always the second argument
            return true;

        default:
            return false;
    }
}
/// <summary>
/// Recognizes Lambda(Coalesce(left, Constant(false)), ...), a pattern the VB compiler introduces for
/// predicate arguments, optionally wrapped in a Quote. On a match, yields the normalized
/// Lambda(Convert(left, typeof(bool)), ...) (re-quoted if the input was quoted) and returns true;
/// otherwise yields null and returns false.
/// </summary>
private static bool TryMatchCoalescePattern(Expression expression, out Expression normalized)
{
    normalized = null;

    switch (expression.NodeType)
    {
        case ExpressionType.Quote:
        {
            // normalize the quoted lambda, then quote the result again
            Expression unquotedResult;
            if (!TryMatchCoalescePattern(((UnaryExpression)expression).Operand, out unquotedResult))
            {
                return false;
            }
            normalized = Expression.Quote(unquotedResult);
            return true;
        }

        case ExpressionType.Lambda:
            return TryCollapseCoalesceLambda((LambdaExpression)expression, out normalized);

        default:
            return false;
    }
}

/// <summary>
/// Collapses a coalesce lambda body:
/// CODE(VB): where a.NullableInt = 1
/// ORIGINAL: Lambda(Coalesce(expr, Constant(false)), a)
/// NORMALIZED: Lambda(Convert(expr, typeof(bool)), a)
/// </summary>
private static bool TryCollapseCoalesceLambda(LambdaExpression lambda, out Expression normalized)
{
    normalized = null;

    Expression body = lambda.Body;
    if (body.NodeType != ExpressionType.Coalesce || body.Type != typeof(bool))
    {
        return false;
    }

    BinaryExpression coalesce = (BinaryExpression)body;
    Expression fallback = coalesce.Right;
    bool fallsBackToFalse = fallback.NodeType == ExpressionType.Constant &&
                            false.Equals(((ConstantExpression)fallback).Value);
    if (!fallsBackToFalse)
    {
        return false;
    }

    normalized = Expression.Lambda(lambda.Type, Expression.Convert(coalesce.Left, typeof(bool)), lambda.Parameters);
    return true;
}
private static readonly MethodInfo s_relationalOperatorPlaceholderMethod =
    typeof(LinqExpressionNormalizer).GetMethod(
        "RelationalOperatorPlaceholder",
        BindingFlags.NonPublic | BindingFlags.Static);

/// <summary>
/// Stand-in implementing method bound to relational nodes so that LINQ accepts operand types the CLR
/// has no native relational operator for (e.g. String > String). It exists only to make those expression
/// nodes constructible and must never actually be invoked.
/// </summary>
private static bool RelationalOperatorPlaceholder<TLeft, TRight>(TLeft left, TRight right)
{
    Debug.Fail("This method should never be called. It exists merely to support creation of relational LINQ expressions.");
    bool sameReference = object.ReferenceEquals(left, right);
    return sameReference;
}
/// <summary>
/// Builds a node relating <paramref name="left"/> and <paramref name="right"/> with the relational
/// operator <paramref name="op"/>. An unsupported operator trips a debug assertion and yields null.
/// </summary>
private static BinaryExpression CreateRelationalOperator(ExpressionType op, Expression left, Expression right)
{
    BinaryExpression relation;
    bool created = TryCreateRelationalOperator(op, left, right, out relation);
    if (!created)
    {
        Debug.Fail("CreateRelationalOperator has unknown op " + op);
    }

    return relation;
}
/// <summary>
/// Attempts to build a node relating <paramref name="left"/> and <paramref name="right"/> with the
/// relational operator <paramref name="op"/>. The node is bound to the placeholder method closed over the
/// two operand types, so operand pairs without a native CLR operator are accepted.
/// </summary>
/// <returns>
/// false, with <paramref name="result"/> set to null, when <paramref name="op"/> is not one of Equal,
/// NotEqual, LessThan, LessThanOrEqual, GreaterThan or GreaterThanOrEqual; true otherwise.
/// </returns>
private static bool TryCreateRelationalOperator(ExpressionType op, Expression left, Expression right, out BinaryExpression result)
{
    // Closed over the operand types before the operator is inspected.
    MethodInfo placeholder = s_relationalOperatorPlaceholderMethod.MakeGenericMethod(left.Type, right.Type);

    Func<Expression, Expression, bool, MethodInfo, BinaryExpression> factory;
    switch (op)
    {
        case ExpressionType.Equal:
            factory = Expression.Equal;
            break;
        case ExpressionType.NotEqual:
            factory = Expression.NotEqual;
            break;
        case ExpressionType.LessThan:
            factory = Expression.LessThan;
            break;
        case ExpressionType.LessThanOrEqual:
            factory = Expression.LessThanOrEqual;
            break;
        case ExpressionType.GreaterThan:
            factory = Expression.GreaterThan;
            break;
        case ExpressionType.GreaterThanOrEqual:
            factory = Expression.GreaterThanOrEqual;
            break;
        default:
            result = null;
            return false;
    }

    result = factory(left, right, LiftToNull, placeholder);
    return true;
}
/// <summary>
/// CODE(C#): Class.Compare(left, right)
/// ORIGINAL: MethodCallExpression(Compare, left, right)
/// NORMALIZED: Condition(Equal(left, right), 0, Condition(left > right, 1, -1))
///
/// Why is this an improvement? We know how to evaluate Condition in the store, but we don't know how to
/// evaluate the MethodCallExpression. When the comparison appears within a larger expression, e.g.
/// left.CompareTo(right) > 0, it can be simplified further to left > right; the resulting node is
/// registered as a ComparePattern to make that possible.
/// </summary>
private Expression CreateCompareExpression(Expression left, Expression right)
{
    Expression areEqual = CreateRelationalOperator(ExpressionType.Equal, left, right);
    Expression leftIsGreater = CreateRelationalOperator(ExpressionType.GreaterThan, left, right);

    Expression whenNotEqual = Expression.Condition(leftIsGreater, Expression.Constant(1), Expression.Constant(-1));
    Expression comparison = Expression.Condition(areEqual, Expression.Constant(0), whenNotEqual);

    // remember that this node matches the Compare pattern
    _patterns[comparison] = new ComparePattern(left, right);

    return comparison;
}
/// <summary>
/// Base type for the records the normalizer associates with expression nodes that were produced by
/// normalizing a recognized pattern.
/// </summary>
private abstract class Pattern
{
    /// <summary>
    /// Identifies which kind of pattern this record describes.
    /// </summary>
    internal abstract PatternKind Kind
    {
        get;
    }
}
/// <summary>
/// Enumerates the kinds of expression patterns the normalizer records.
/// </summary>
private enum PatternKind
{
    /// <summary>
    /// A three-way comparison built by CreateCompareExpression from x.CompareTo(y), Class.Compare(x, y)
    /// or VB's Operators.CompareString(x, y, ...).
    /// </summary>
    Compare,
}
/// <summary>
/// Records the operands of a comparison normalized by CreateCompareExpression: x.CompareTo(y),
/// Class.Compare(x, y) or VB's Operators.CompareString(x, y, ...). VisitBinary uses these operands to
/// rewrite a relation between the normalized comparison and the constant 0 (e.g. Compare(x, y) > 0)
/// as the same relation between x and y (x > y).
/// </summary>
private sealed class ComparePattern : Pattern
{
    /// <summary>
    /// Gets left-hand argument to Compare operation.
    /// </summary>
    internal readonly Expression Left;

    /// <summary>
    /// Gets right-hand argument to Compare operation.
    /// </summary>
    internal readonly Expression Right;

    internal ComparePattern(Expression left, Expression right)
    {
        this.Left = left;
        this.Right = right;
    }

    internal override PatternKind Kind
    {
        get { return PatternKind.Compare; }
    }
}