2010-04-06 Jb Evain <jbevain@novell.com>
[mcs.git] / class / System.Core / System.Linq.Expressions / ExpressionTransformer.cs
blob5e05de795dad0c5bc952826bfcf52bfab8a30b67
1 //
2 // ExpressionTransformer.cs
3 //
4 // Authors:
5 // Roei Erez (roeie@mainsoft.com)
6 // Jb Evain (jbevain@novell.com)
7 //
8 // Copyright (C) 2007 Novell, Inc (http://www.novell.com)
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining
11 // a copy of this software and associated documentation files (the
12 // "Software"), to deal in the Software without restriction, including
13 // without limitation the rights to use, copy, modify, merge, publish,
14 // distribute, sublicense, and/or sell copies of the Software, and to
15 // permit persons to whom the Software is furnished to do so, subject to
16 // the following conditions:
18 // The above copyright notice and this permission notice shall be
19 // included in all copies or substantial portions of the Software.
21 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
22 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
24 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
25 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
26 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
27 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
30 using System;
31 using System.Collections.ObjectModel;
32 using System.Collections.Generic;
33 using System.Linq.Expressions;
35 namespace System.Linq.Expressions {
37 abstract class ExpressionTransformer {
// Public entry point of the transformer: hands the root node to Visit and
// returns whatever Visit produces for it.
public Expression Transform (Expression expression)
{
	Expression transformed = this.Visit (expression);
	return transformed;
}
// Central dispatcher: routes a node to the type-specific Visit* method chosen
// by its NodeType and returns that method's result.
// A null node is returned unchanged (null).
// Throws NotSupportedException for any NodeType not handled below.
protected virtual Expression Visit (Expression exp)
{
	if (exp == null)
		return exp;

	switch (exp.NodeType) {
	case ExpressionType.Negate:
	case ExpressionType.NegateChecked:
	case ExpressionType.Not:
	case ExpressionType.Convert:
	case ExpressionType.ConvertChecked:
	case ExpressionType.ArrayLength:
	case ExpressionType.Quote:
	case ExpressionType.TypeAs:
	case ExpressionType.UnaryPlus:
		return this.VisitUnary ((UnaryExpression) exp);
	case ExpressionType.Add:
	case ExpressionType.AddChecked:
	case ExpressionType.Subtract:
	case ExpressionType.SubtractChecked:
	case ExpressionType.Multiply:
	case ExpressionType.MultiplyChecked:
	case ExpressionType.Divide:
	case ExpressionType.Power:
	case ExpressionType.Modulo:
	case ExpressionType.And:
	case ExpressionType.AndAlso:
	case ExpressionType.Or:
	case ExpressionType.OrElse:
	case ExpressionType.LessThan:
	case ExpressionType.LessThanOrEqual:
	case ExpressionType.GreaterThan:
	case ExpressionType.GreaterThanOrEqual:
	case ExpressionType.Equal:
	case ExpressionType.NotEqual:
	case ExpressionType.Coalesce:
	case ExpressionType.ArrayIndex:
	case ExpressionType.RightShift:
	case ExpressionType.LeftShift:
	case ExpressionType.ExclusiveOr:
		return this.VisitBinary ((BinaryExpression) exp);
	case ExpressionType.TypeIs:
		return this.VisitTypeIs ((TypeBinaryExpression) exp);
	case ExpressionType.Conditional:
		return this.VisitConditional ((ConditionalExpression) exp);
	case ExpressionType.Constant:
		return this.VisitConstant ((ConstantExpression) exp);
	case ExpressionType.Parameter:
		return this.VisitParameter ((ParameterExpression) exp);
	case ExpressionType.MemberAccess:
		return this.VisitMemberAccess ((MemberExpression) exp);
	case ExpressionType.Call:
		return this.VisitMethodCall ((MethodCallExpression) exp);
	case ExpressionType.Lambda:
		return this.VisitLambda ((LambdaExpression) exp);
	case ExpressionType.New:
		return this.VisitNew ((NewExpression) exp);
	case ExpressionType.NewArrayInit:
	case ExpressionType.NewArrayBounds:
		return this.VisitNewArray ((NewArrayExpression) exp);
	case ExpressionType.Invoke:
		return this.VisitInvocation ((InvocationExpression) exp);
	case ExpressionType.MemberInit:
		return this.VisitMemberInit ((MemberInitExpression) exp);
	case ExpressionType.ListInit:
		return this.VisitListInit ((ListInitExpression) exp);
	default:
		// NotSupportedException derives from Exception, so existing
		// catch (Exception) handlers keep working; the type is now specific.
		throw new NotSupportedException (string.Format ("Unhandled expression type: '{0}'", exp.NodeType));
	}
}
// Dispatches a member binding (from a member-init or nested member binding)
// to the handler for its BindingType.
// Throws NotSupportedException for an unrecognized binding type.
protected virtual MemberBinding VisitBinding (MemberBinding binding)
{
	switch (binding.BindingType) {
	case MemberBindingType.Assignment:
		return this.VisitMemberAssignment ((MemberAssignment) binding);
	case MemberBindingType.MemberBinding:
		return this.VisitMemberMemberBinding ((MemberMemberBinding) binding);
	case MemberBindingType.ListBinding:
		return this.VisitMemberListBinding ((MemberListBinding) binding);
	default:
		// Specific exception type instead of bare System.Exception; still
		// caught by any existing catch (Exception) handler.
		throw new NotSupportedException (string.Format ("Unhandled binding type '{0}'", binding.BindingType));
	}
}
// Visits the arguments of one collection-initializer element. When the
// visited argument collection is a different instance from the original,
// a new ElementInit is built around the same Add method; otherwise the
// original initializer is returned.
protected virtual ElementInit VisitElementInitializer (ElementInit initializer)
{
	ReadOnlyCollection<Expression> visitedArgs = VisitExpressionList (initializer.Arguments);
	return visitedArgs == initializer.Arguments
		? initializer
		: Expression.ElementInit (initializer.AddMethod, visitedArgs);
}
// Visits the single operand of a unary node. If the operand instance changed,
// the node is recreated with the same node type, result type and operator method.
protected virtual Expression VisitUnary (UnaryExpression u)
{
	Expression visitedOperand = Visit (u.Operand);
	if (visitedOperand == u.Operand)
		return u;

	return Expression.MakeUnary (u.NodeType, visitedOperand, u.Type, u.Method);
}
// Visits the left operand, right operand and (possibly null) conversion lambda
// of a binary node, in that order. If any of them changed instance, the node is
// rebuilt. A coalesce that carries a conversion is rebuilt through
// Expression.Coalesce so the conversion is kept. Every other case goes through
// MakeBinary with the original lifting flag and operator method.
protected virtual Expression VisitBinary (BinaryExpression b)
{
	Expression left = this.Visit (b.Left);
	Expression right = this.Visit (b.Right);
	Expression conversion = this.Visit (b.Conversion);

	if (left == b.Left && right == b.Right && conversion == b.Conversion)
		return b;

	if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null) {
		// Direct cast instead of 'as': if an override returns a non-lambda for
		// the conversion, fail with InvalidCastException rather than silently
		// building a coalesce without its conversion.
		return Expression.Coalesce (left, right, (LambdaExpression) conversion);
	}

	return Expression.MakeBinary (b.NodeType, left, right, b.IsLiftedToNull, b.Method);
}
// Visits the tested operand of a "TypeIs" node. If the operand instance
// changed, a new TypeIs node is built against the same type operand.
protected virtual Expression VisitTypeIs (TypeBinaryExpression b)
{
	Expression operand = Visit (b.Expression);
	if (operand == b.Expression)
		return b;

	return Expression.TypeIs (operand, b.TypeOperand);
}
// Constant nodes have no children; the base transformer returns them as-is.
// Override to rewrite constants.
protected virtual Expression VisitConstant (ConstantExpression c)
{
	return c;
}
// Visits test, true branch and false branch (in that order). A new conditional
// is built only if at least one of the three came back as a different instance.
protected virtual Expression VisitConditional (ConditionalExpression c)
{
	Expression newTest = Visit (c.Test);
	Expression newIfTrue = Visit (c.IfTrue);
	Expression newIfFalse = Visit (c.IfFalse);

	bool unchanged = newTest == c.Test && newIfTrue == c.IfTrue && newIfFalse == c.IfFalse;
	if (unchanged)
		return c;

	return Expression.Condition (newTest, newIfTrue, newIfFalse);
}
// Parameter references are leaves; the base transformer returns them as-is.
// Override to substitute parameters.
protected virtual Expression VisitParameter (ParameterExpression p)
{
	return p;
}
// Visits the target of a field/property access. The target may be null for
// static members, and it is passed to Visit either way. If the target
// instance changed, the access is rebuilt on the same member.
protected virtual Expression VisitMemberAccess (MemberExpression m)
{
	Expression target = Visit (m.Expression);
	if (target == m.Expression)
		return m;

	return Expression.MakeMemberAccess (target, m.Member);
}
// Visits the call's instance (null for static calls) and then its argument
// list. The call is recreated against the same MethodInfo when either the
// instance or the argument collection is a different object than before.
protected virtual Expression VisitMethodCall (MethodCallExpression m)
{
	Expression instance = Visit (m.Object);
	IEnumerable<Expression> arguments = VisitExpressionList (m.Arguments);

	bool rebuild = instance != m.Object || arguments != m.Arguments;
	return rebuild ? Expression.Call (instance, m.Method, arguments) : m;
}
// Visits every expression in the collection.
// Returns the original collection instance when no element changed;
// otherwise wraps the visited elements in a new ReadOnlyCollection.
// Callers (call, new, new-array, invoke, element-init) compare the result to
// their original argument collection by reference to decide whether to rebuild
// their node, so preserving identity here is what keeps untouched nodes untouched.
protected virtual ReadOnlyCollection<Expression> VisitExpressionList (ReadOnlyCollection<Expression> original)
{
	IList<Expression> list = VisitList (original, Visit);

	// VisitList hands back 'original' itself when nothing changed (it never
	// returns null). The previous null-only check was dead code and
	// re-wrapped an unchanged list on every call.
	if (list == null || object.ReferenceEquals (list, original))
		return original;

	return new ReadOnlyCollection<Expression> (list);
}
// Visits the value assigned by a member-assignment binding. If the value
// instance changed, a new binding to the same member is created.
protected virtual MemberAssignment VisitMemberAssignment (MemberAssignment assignment)
{
	Expression value = Visit (assignment.Expression);
	if (value == assignment.Expression)
		return assignment;

	return Expression.Bind (assignment.Member, value);
}
// Visits the nested bindings of a member-member binding. If the visited binding
// sequence is a different instance, the binding is rebuilt for the same member.
protected virtual MemberMemberBinding VisitMemberMemberBinding (MemberMemberBinding binding)
{
	IEnumerable<MemberBinding> nested = VisitBindingList (binding.Bindings);
	return nested == binding.Bindings
		? binding
		: Expression.MemberBind (binding.Member, nested);
}
// Visits the element initializers of a member-list binding. If the visited
// initializer sequence is a different instance, the binding is rebuilt for the
// same member.
protected virtual MemberListBinding VisitMemberListBinding (MemberListBinding binding)
{
	IEnumerable<ElementInit> elements = VisitElementInitializerList (binding.Initializers);
	return elements == binding.Initializers
		? binding
		: Expression.ListBind (binding.Member, elements);
}
// Applies VisitBinding to every binding via VisitList. The result is the
// original collection when nothing changed, otherwise a new list.
protected virtual IEnumerable<MemberBinding> VisitBindingList (ReadOnlyCollection<MemberBinding> original)
{
	Func<MemberBinding, MemberBinding> visitor = VisitBinding;
	return VisitList (original, visitor);
}
// Applies VisitElementInitializer to every element via VisitList. The result is
// the original collection when nothing changed, otherwise a new list.
protected virtual IEnumerable<ElementInit> VisitElementInitializerList (ReadOnlyCollection<ElementInit> original)
{
	Func<ElementInit, ElementInit> visitor = VisitElementInitializer;
	return VisitList (original, visitor);
}
// Runs 'visit' over each element in order.
// Equality is judged with EqualityComparer<TElement>.Default.
// - If every visited element equals its source, the original collection
//   instance is returned.
// - Otherwise a new List is returned. The list is created lazily at the first
//   differing element, seeded with the unchanged prefix, and then receives
//   every subsequent visited element.
private IList<TElement> VisitList<TElement> (ReadOnlyCollection<TElement> original, Func<TElement, TElement> visit)
{
	List<TElement> result = null;
	int count = original.Count;

	for (int index = 0; index < count; index++) {
		TElement source = original [index];
		TElement visited = visit (source);

		if (result == null) {
			if (EqualityComparer<TElement>.Default.Equals (visited, source))
				continue;

			// First change: copy everything before it, then keep appending.
			result = new List<TElement> (count);
			for (int k = 0; k < index; k++)
				result.Add (original [k]);
		}

		result.Add (visited);
	}

	return result != null ? (IList<TElement>) result : original;
}
// Visits only the lambda body. The delegate type and parameter list are reused
// as-is when the body instance changes and the lambda is rebuilt.
protected virtual Expression VisitLambda (LambdaExpression lambda)
{
	Expression newBody = Visit (lambda.Body);
	return newBody == lambda.Body
		? lambda
		: Expression.Lambda (lambda.Type, newBody, lambda.Parameters);
}
// Visits the constructor arguments of a "new" expression. The node is rebuilt
// only when the visited argument collection is a different instance, and the
// Members mapping is kept when present.
// A NewExpression without a Constructor is returned unchanged. Expression.New(Type)
// produces such a node for a value type's default construction, e.g. "new S()",
// and it has no arguments to rewrite. Passing its null constructor to
// Expression.New(ConstructorInfo, ...) would throw ArgumentNullException.
protected virtual NewExpression VisitNew (NewExpression nex)
{
	IEnumerable<Expression> args = this.VisitExpressionList (nex.Arguments);
	if (args == nex.Arguments)
		return nex;

	if (nex.Constructor == null)
		return nex;

	// Members is set when the construction maps arguments to members
	// (e.g. anonymous types); keep that mapping on the rebuilt node.
	if (nex.Members != null)
		return Expression.New (nex.Constructor, args, nex.Members);

	return Expression.New (nex.Constructor, args);
}
// Visits the "new" part first, then the member bindings. If either came back
// as a different instance, a new member-init node is created from them.
protected virtual Expression VisitMemberInit (MemberInitExpression init)
{
	NewExpression construction = VisitNew (init.NewExpression);
	IEnumerable<MemberBinding> bindings = VisitBindingList (init.Bindings);

	if (construction == init.NewExpression && bindings == init.Bindings)
		return init;

	return Expression.MemberInit (construction, bindings);
}
// Visits the "new" part first, then the element initializers. If either came
// back as a different instance, a new list-init node is created from them.
protected virtual Expression VisitListInit (ListInitExpression init)
{
	NewExpression construction = VisitNew (init.NewExpression);
	IEnumerable<ElementInit> elements = VisitElementInitializerList (init.Initializers);

	if (construction == init.NewExpression && elements == init.Initializers)
		return init;

	return Expression.ListInit (construction, elements);
}
// Visits the element (NewArrayInit) or bound (NewArrayBounds) expressions.
// When the visited collection is a different instance, the array creation is
// rebuilt with the same node kind and element type (taken from the array type).
protected virtual Expression VisitNewArray (NewArrayExpression na)
{
	IEnumerable<Expression> items = VisitExpressionList (na.Expressions);
	if (items == na.Expressions)
		return na;

	Type elementType = na.Type.GetElementType ();
	return na.NodeType == ExpressionType.NewArrayInit
		? Expression.NewArrayInit (elementType, items)
		: Expression.NewArrayBounds (elementType, items);
}
// Visits the invocation's arguments first and the invoked expression second.
// If either came back as a different instance, a new invocation node is built.
protected virtual Expression VisitInvocation (InvocationExpression iv)
{
	IEnumerable<Expression> arguments = VisitExpressionList (iv.Arguments);
	Expression target = Visit (iv.Expression);

	bool unchanged = arguments == iv.Arguments && target == iv.Expression;
	return unchanged ? iv : Expression.Invoke (target, arguments);
}