Added some tests for OpenIdRelyingParty.
[dotnetoauth.git] / src / DotNetOpenAuth / Messaging / MessagingUtilities.cs
blob571d54b684ee30ef69c7cabb1a813e9e51bf7dee
1 //-----------------------------------------------------------------------
2 // <copyright file="MessagingUtilities.cs" company="Andrew Arnott">
3 // Copyright (c) Andrew Arnott. All rights reserved.
4 // </copyright>
5 //-----------------------------------------------------------------------
7 namespace DotNetOpenAuth.Messaging {
8 using System;
9 using System.Collections.Generic;
10 using System.Collections.Specialized;
11 using System.Diagnostics.CodeAnalysis;
12 using System.IO;
13 using System.Linq;
14 using System.Net;
15 using System.Security.Cryptography;
16 using System.Text;
17 using System.Web;
18 using DotNetOpenAuth.Messaging.Reflection;
20 /// <summary>
21 /// A grab-bag of utility methods useful for the channel stack of the protocol.
22 /// </summary>
23 public static class MessagingUtilities {
/// <summary>
/// Shared source of cryptographically strong random bytes, used when creating secrets.
/// </summary>
/// <remarks>
/// A single instance is shared by every caller in this class; the random number
/// generator is thread-safe, so no additional locking is performed around it.
/// </remarks>
internal static readonly RandomNumberGenerator CryptoRandomDataGenerator = new RNGCryptoServiceProvider();
/// <summary>
/// Reconstructs the URL the browser actually requested, before any server-side URL rewriting,
/// including the cookieless session directory segment when one is present.
/// </summary>
/// <returns>The URL as it appears in the user agent's Location bar.</returns>
[SuppressMessage("Microsoft.Usage", "CA2234:PassSystemUriObjectsInsteadOfStrings", Justification = "The Uri merging requires use of a string value.")]
[SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Expensive call should not be a property.")]
public static Uri GetRequestUrlFromContext() {
    ErrorUtilities.VerifyHttpContext();
    HttpRequest currentRequest = HttpContext.Current.Request;

    // Request.Url provides the server portion of the address, while Request.RawUrl holds the
    // path and query exactly as the client sent them. That raw form keeps the cookieless session
    // "directory" and is not affected by URL rewriting, which matters because callers compare
    // this URL against values such as the return_to parameter.
    // Response.ApplyAppPathModifier would cover the cookieless session, but not rewriting.
    Uri serverBase = currentRequest.Url;
    string rawPathAndQuery = currentRequest.RawUrl;
    return new Uri(serverBase, rawPathAndQuery);
}
/// <summary>
/// Returns the query string parameters exactly as the client sent them, ignoring any
/// URL rewriting performed by the host site.
/// </summary>
/// <returns>A <see cref="NameValueCollection"/> containing all the parameters in the query string.</returns>
public static NameValueCollection GetQueryFromContextNVC() {
    ErrorUtilities.VerifyHttpContext();
    HttpRequest request = HttpContext.Current.Request;

    // If the (possibly rewritten) URL still matches the raw URL, no rewriting happened and the
    // already-parsed QueryString reflects the original request.
    bool wasRewritten = request.Url.PathAndQuery != request.RawUrl;
    if (!wasRewritten) {
        return request.QueryString;
    }

    // The host rewrote the URL; OpenID processing needs the original parameters, so re-parse
    // them from the URL the browser actually requested.
    return HttpUtility.ParseQueryString(GetRequestUrlFromContext().Query);
}
/// <summary>
/// Strips any and all URI query parameters whose names start with some prefix.
/// </summary>
/// <param name="uri">The URI that may have a query with parameters to remove.</param>
/// <param name="prefix">The prefix for parameters to remove. Matched ordinally and case-insensitively. May not be null.</param>
/// <returns>Either a new Uri with the parameters removed if there were any to remove, or the same Uri instance if no parameters needed to be removed.</returns>
/// <remarks>
/// When parameters are removed, the remaining ones are re-encoded in their original order.
/// A repeated parameter keeps every one of its values, and a value-only segment (no '=')
/// stays a value-only segment.
/// </remarks>
public static Uri StripQueryArgumentsWithPrefix(this Uri uri, string prefix) {
    ErrorUtilities.VerifyArgumentNotNull(uri, "uri");
    ErrorUtilities.VerifyArgumentNotNull(prefix, "prefix");

    NameValueCollection queryArgs = HttpUtility.ParseQueryString(uri.Query);

    // A value-only segment is stored under a null key, which can never match a prefix.
    var matchingKeys = queryArgs.AllKeys.Where(key => key != null && key.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)).ToList();
    if (matchingKeys.Count == 0) {
        return uri;
    }

    foreach (string key in matchingKeys) {
        queryArgs.Remove(key);
    }

    // Rebuild the query directly from the collection. Routing through a Dictionary would
    // comma-join repeated parameters and would throw on the null key of a value-only segment.
    StringBuilder query = new StringBuilder();
    foreach (string key in queryArgs.AllKeys) {
        string[] values = queryArgs.GetValues(key);
        if (values == null) {
            continue;
        }

        foreach (string value in values) {
            if (query.Length > 0) {
                query.Append('&');
            }

            if (key != null) {
                query.Append(HttpUtility.UrlEncode(key));
                query.Append('=');
            }

            query.Append(HttpUtility.UrlEncode(value));
        }
    }

    UriBuilder builder = new UriBuilder(uri);
    builder.Query = query.ToString();
    return builder.Uri;
}
/// <summary>
/// Produces a new buffer filled with cryptographically strong random bytes.
/// </summary>
/// <param name="length">How many random bytes to produce.</param>
/// <returns>A freshly allocated array of <paramref name="length"/> random bytes; zero bytes may occur.</returns>
internal static byte[] GetCryptoRandomData(int length) {
    var randomBytes = new byte[length];
    CryptoRandomDataGenerator.GetBytes(randomBytes);
    return randomBytes;
}
/// <summary>
/// Produces cryptographically strong random bytes and returns them base64-encoded.
/// </summary>
/// <param name="binaryLength">How many random bytes to generate before encoding.</param>
/// <returns>
/// The base64 encoding of the random bytes. Its length in characters is greater than
/// <paramref name="binaryLength"/> whenever <paramref name="binaryLength"/> is positive.
/// </returns>
internal static string GetCryptoRandomDataAsBase64(int binaryLength) {
    return Convert.ToBase64String(GetCryptoRandomData(binaryLength));
}
/// <summary>
/// Adds a set of HTTP headers to an <see cref="HttpResponse"/> instance. Headers that
/// <see cref="HttpResponse"/> exposes as dedicated properties are assigned to those
/// properties instead of being added as raw headers.
/// </summary>
/// <param name="headers">The headers to add.</param>
/// <param name="response">The <see cref="HttpResponse"/> instance to set the appropriate values to.</param>
internal static void ApplyHeadersToResponse(WebHeaderCollection headers, HttpResponse response) {
    ErrorUtilities.VerifyArgumentNotNull(headers, "headers");
    ErrorUtilities.VerifyArgumentNotNull(response, "response");

    foreach (string headerName in headers) {
        string headerValue = headers[headerName];

        // HTTP header field names are case-insensitive, so "content-type" must be routed to
        // the ContentType property just like "Content-Type". Reading the value by name (rather
        // than through the HttpResponseHeader indexer) also works when the collection was
        // created for request headers.
        if (string.Equals(headerName, "Content-Type", StringComparison.OrdinalIgnoreCase)) {
            response.ContentType = headerValue;
        } else {
            // Add more special cases above as necessary.
            response.AddHeader(headerName, headerValue);
        }
    }
}
/// <summary>
/// Copies everything remaining in one stream into another.
/// </summary>
/// <param name="copyFrom">The stream to copy from, at the position where copying should begin.</param>
/// <param name="copyTo">The stream to copy to, at the position where bytes should be written.</param>
/// <returns>The total number of bytes copied.</returns>
/// <remarks>
/// Copying starts at each stream's current position, and neither position is
/// rewound once copying finishes.
/// </remarks>
internal static int CopyTo(this Stream copyFrom, Stream copyTo) {
    // Invoked in static form on purpose: instance-call syntax could bind to a
    // Stream.CopyTo(Stream, int) instance method where the framework provides one.
    const int NoLimit = int.MaxValue;
    return CopyTo(copyFrom, copyTo, NoLimit);
}
/// <summary>
/// Copies the contents of one stream to another, up to a maximum number of bytes.
/// </summary>
/// <param name="copyFrom">The stream to copy from, at the position where copying should begin.</param>
/// <param name="copyTo">The stream to copy to, at the position where bytes should be written.</param>
/// <param name="maximumBytesToCopy">The maximum bytes to copy. Must not be negative.</param>
/// <returns>The total number of bytes copied.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="maximumBytesToCopy"/> is negative.</exception>
/// <remarks>
/// Copying begins at the streams' current positions.
/// The positions are NOT reset after copying is complete.
/// </remarks>
internal static int CopyTo(this Stream copyFrom, Stream copyTo, int maximumBytesToCopy) {
    ErrorUtilities.VerifyArgumentNotNull(copyFrom, "copyFrom");
    ErrorUtilities.VerifyArgumentNotNull(copyTo, "copyTo");
    ErrorUtilities.VerifyArgument(copyFrom.CanRead, MessagingStrings.StreamUnreadable);
    ErrorUtilities.VerifyArgument(copyTo.CanWrite, MessagingStrings.StreamUnwritable, "copyTo");
    if (maximumBytesToCopy < 0) {
        throw new ArgumentOutOfRangeException("maximumBytesToCopy");
    }

    byte[] buffer = new byte[1024];
    int totalCopiedBytes = 0;
    int remainingBytes = maximumBytesToCopy;

    // Stop as soon as the limit is reached rather than issuing a zero-length Read. Some stream
    // implementations (e.g. socket-backed ones) may wait for incoming data even when asked for
    // zero bytes.
    while (remainingBytes > 0) {
        int readBytes = copyFrom.Read(buffer, 0, Math.Min(buffer.Length, remainingBytes));
        if (readBytes <= 0) {
            break;
        }

        copyTo.Write(buffer, 0, readBytes);
        totalCopiedBytes += readBytes;
        remainingBytes -= readBytes;
    }

    return totalCopiedBytes;
}
/// <summary>
/// Buffers the remaining contents of a stream into memory, producing a seekable copy that
/// outlives the original (which may then be closed).
/// </summary>
/// <param name="copyFrom">The stream to copy bytes from, starting at its current position.</param>
/// <returns>A seekable stream positioned at its start and holding the copied contents.</returns>
internal static Stream CreateSnapshot(this Stream copyFrom) {
    ErrorUtilities.VerifyArgumentNotNull(copyFrom, "copyFrom");

    // Pre-size to the source's total length when it is known; otherwise start at 4KB and grow.
    int initialCapacity;
    if (copyFrom.CanSeek) {
        initialCapacity = (int)copyFrom.Length;
    } else {
        initialCapacity = 4 * 1024;
    }

    MemoryStream snapshot = new MemoryStream(initialCapacity);
    copyFrom.CopyTo(snapshot);
    snapshot.Position = 0;
    return snapshot;
}
/// <summary>
/// Creates a fresh copy of an <see cref="HttpWebRequest"/> aimed at the same URI, so that it
/// can be sent again.
/// </summary>
/// <param name="request">The request to clone.</param>
/// <returns>The newly created instance.</returns>
internal static HttpWebRequest Clone(this HttpWebRequest request) {
    ErrorUtilities.VerifyArgumentNotNull(request, "request");
    Uri sameRecipient = request.RequestUri;
    return Clone(request, sameRecipient);
}
/// <summary>
/// Creates a fresh <see cref="HttpWebRequest"/> for a (possibly different) URI that carries
/// over the settings and headers of an existing request, so it can be sent again.
/// </summary>
/// <param name="request">The request to clone.</param>
/// <param name="newRequestUri">The new recipient of the request.</param>
/// <returns>The newly created instance.</returns>
internal static HttpWebRequest Clone(this HttpWebRequest request, Uri newRequestUri) {
    ErrorUtilities.VerifyArgumentNotNull(request, "request");
    ErrorUtilities.VerifyArgumentNotNull(newRequestUri, "newRequestUri");

    HttpWebRequest copy = (HttpWebRequest)WebRequest.Create(newRequestUri);

    // Property assignments keep a fixed order; in particular SendChunked must be assigned
    // before TransferEncoding, whose setter depends on it.
    copy.Accept = request.Accept;
    copy.AllowAutoRedirect = request.AllowAutoRedirect;
    copy.AllowWriteStreamBuffering = request.AllowWriteStreamBuffering;
    copy.AuthenticationLevel = request.AuthenticationLevel;
    copy.AutomaticDecompression = request.AutomaticDecompression;
    copy.CachePolicy = request.CachePolicy;
    copy.ClientCertificates = request.ClientCertificates;
    copy.ConnectionGroupName = request.ConnectionGroupName;

    // A negative ContentLength means "not set" and is not a legal value to assign.
    if (request.ContentLength >= 0) {
        copy.ContentLength = request.ContentLength;
    }

    copy.ContentType = request.ContentType;
    copy.ContinueDelegate = request.ContinueDelegate;
    copy.CookieContainer = request.CookieContainer;
    copy.Credentials = request.Credentials;
    copy.Expect = request.Expect;
    copy.IfModifiedSince = request.IfModifiedSince;
    copy.ImpersonationLevel = request.ImpersonationLevel;
    copy.KeepAlive = request.KeepAlive;
    copy.MaximumAutomaticRedirections = request.MaximumAutomaticRedirections;
    copy.MaximumResponseHeadersLength = request.MaximumResponseHeadersLength;
    copy.MediaType = request.MediaType;
    copy.Method = request.Method;
    copy.Pipelined = request.Pipelined;
    copy.PreAuthenticate = request.PreAuthenticate;
    copy.ProtocolVersion = request.ProtocolVersion;
    copy.Proxy = request.Proxy;
    copy.ReadWriteTimeout = request.ReadWriteTimeout;
    copy.Referer = request.Referer;
    copy.SendChunked = request.SendChunked;
    copy.Timeout = request.Timeout;
    copy.TransferEncoding = request.TransferEncoding;
    copy.UnsafeAuthenticatedConnectionSharing = request.UnsafeAuthenticatedConnectionSharing;
    copy.UseDefaultCredentials = request.UseDefaultCredentials;
    copy.UserAgent = request.UserAgent;

    // Raw headers are copied last, and only when the properties above have not already
    // produced them: .NET rejects direct writes of headers it wants set through properties.
    // NOTE(review): a restricted header that none of the properties above covers (e.g. Range)
    // would still make Headers.Add throw; confirm callers never clone such requests.
    foreach (string headerName in request.Headers) {
        bool alreadyPresent = !string.IsNullOrEmpty(copy.Headers[headerName]);
        if (alreadyPresent) {
            continue;
        }

        copy.Headers.Add(headerName, request.Headers[headerName]);
    }

    return copy;
}
/// <summary>
/// Tests whether two arrays are equal in length and contents.
/// </summary>
/// <typeparam name="T">The type of elements in the arrays.</typeparam>
/// <param name="first">The first array in the comparison. May not be null.</param>
/// <param name="second">The second array in the comparison. May not be null.</param>
/// <returns>True if the arrays have the same length and pairwise-equal elements; false otherwise.</returns>
/// <remarks>
/// Elements are compared with <see cref="EqualityComparer{T}.Default"/>, so arrays may
/// contain null elements.
/// </remarks>
internal static bool AreEquivalent<T>(T[] first, T[] second) {
    ErrorUtilities.VerifyArgumentNotNull(first, "first");
    ErrorUtilities.VerifyArgumentNotNull(second, "second");
    if (first.Length != second.Length) {
        return false;
    }

    // Calling first[i].Equals directly would throw NullReferenceException for a null element;
    // the default comparer handles nulls and avoids boxing value types.
    EqualityComparer<T> comparer = EqualityComparer<T>.Default;
    for (int i = 0; i < first.Length; i++) {
        if (!comparer.Equals(first[i], second[i])) {
            return false;
        }
    }

    return true;
}
/// <summary>
/// Tests two sequences for same contents and ordering.
/// </summary>
/// <typeparam name="T">The type of elements in the sequences.</typeparam>
/// <param name="sequence1">The first sequence in the comparison. May be null.</param>
/// <param name="sequence2">The second sequence in the comparison. May be null.</param>
/// <returns>
/// True if both sequences are null, or if both are non-null and yield pairwise-equal elements
/// (nulls included) in the same order and of the same count; false otherwise.
/// </returns>
internal static bool AreEquivalent<T>(IEnumerable<T> sequence1, IEnumerable<T> sequence2) {
    if (sequence1 == null && sequence2 == null) {
        return true;
    }

    if ((sequence1 == null) ^ (sequence2 == null)) {
        return false;
    }

    // Enumerators are IDisposable (iterator blocks may hold resources or run finally blocks),
    // so dispose them however the comparison ends.
    using (IEnumerator<T> iterator1 = sequence1.GetEnumerator())
    using (IEnumerator<T> iterator2 = sequence2.GetEnumerator()) {
        while (true) {
            bool movenext1 = iterator1.MoveNext();
            bool movenext2 = iterator2.MoveNext();
            if (!movenext1 || !movenext2) {
                // Equal only if both sequences ran out at the same time.
                return movenext1 == movenext2;
            }

            // object.Equals(a, b): both null -> true, exactly one null -> false, else a.Equals(b).
            if (!object.Equals(iterator1.Current, iterator2.Current)) {
                return false;
            }
        }
    }
}
/// <summary>
/// Tests whether two dictionaries hold the same key/value pairs, regardless of the order in
/// which they enumerate them.
/// </summary>
/// <typeparam name="TKey">The type of keys in the dictionaries.</typeparam>
/// <typeparam name="TValue">The type of values in the dictionaries.</typeparam>
/// <param name="first">The first dictionary in the comparison. May not be null.</param>
/// <param name="second">The second dictionary in the comparison. May not be null.</param>
/// <returns>True if the dictionaries have the same count and every pair of each is found in the other; false otherwise.</returns>
internal static bool AreEquivalent<TKey, TValue>(IDictionary<TKey, TValue> first, IDictionary<TKey, TValue> second) {
    ErrorUtilities.VerifyArgumentNotNull(first, "first");
    ErrorUtilities.VerifyArgumentNotNull(second, "second");
    if (first.Count != second.Count) {
        return false;
    }

    // Dictionaries guarantee no enumeration order, so a positional comparison could report two
    // equal dictionaries as different. Check containment in both directions so the result does
    // not depend on which argument's key comparer happens to be more lenient.
    return ContainsAllPairs(second, first) && ContainsAllPairs(first, second);
}

/// <summary>
/// Checks whether every key/value pair in <paramref name="pairs"/> is present in
/// <paramref name="dictionary"/>, comparing values with <see cref="EqualityComparer{T}.Default"/>.
/// </summary>
/// <typeparam name="TKey">The type of keys in the dictionaries.</typeparam>
/// <typeparam name="TValue">The type of values in the dictionaries.</typeparam>
/// <param name="dictionary">The dictionary that keys are looked up in.</param>
/// <param name="pairs">The pairs that must all be present.</param>
/// <returns>True if every pair is found with an equal value; false otherwise.</returns>
private static bool ContainsAllPairs<TKey, TValue>(IDictionary<TKey, TValue> dictionary, IDictionary<TKey, TValue> pairs) {
    EqualityComparer<TValue> valueComparer = EqualityComparer<TValue>.Default;
    foreach (KeyValuePair<TKey, TValue> pair in pairs) {
        TValue value;
        if (!dictionary.TryGetValue(pair.Key, out value) || !valueComparer.Equals(pair.Value, value)) {
            return false;
        }
    }

    return true;
}
/// <summary>
/// Concatenates a list of name-value pairs as key=value&amp;key=value,
/// taking care to properly encode each key and value for URL
/// transmission. No ? is prefixed to the string.
/// </summary>
/// <param name="args">The dictionary of key/values to read from. May not be null, nor contain null keys or values.</param>
/// <returns>The formulated querystring style string, or an empty string when <paramref name="args"/> is empty.</returns>
internal static string CreateQueryString(IEnumerable<KeyValuePair<string, string>> args) {
    ErrorUtilities.VerifyArgumentNotNull(args, "args");

    // Enumerate args only once, so single-pass sequences work and lazy sources are not re-evaluated.
    StringBuilder sb = new StringBuilder();
    foreach (var p in args) {
        ErrorUtilities.VerifyArgument(p.Key != null, MessagingStrings.UnexpectedNullKey);
        ErrorUtilities.VerifyArgument(p.Value != null, MessagingStrings.UnexpectedNullValue, p.Key);

        // Every emitted pair contains at least '=', so a non-empty builder means "not the first pair".
        if (sb.Length > 0) {
            sb.Append('&');
        }

        sb.Append(HttpUtility.UrlEncode(p.Key));
        sb.Append('=');
        sb.Append(HttpUtility.UrlEncode(p.Value));
    }

    return sb.ToString();
}
/// <summary>
/// Adds a set of name-value pairs to the end of a given URL
/// as part of the querystring piece. Prefixes a ? or &amp; before
/// first element as necessary.
/// </summary>
/// <param name="builder">The UriBuilder to add arguments to.</param>
/// <param name="args">
/// The arguments to add to the query.
/// If null or empty, <paramref name="builder"/> is not changed.
/// </param>
internal static void AppendQueryArgs(this UriBuilder builder, IEnumerable<KeyValuePair<string, string>> args) {
    if (builder == null) {
        throw new ArgumentNullException("builder");
    }

    if (args == null) {
        return;
    }

    // Materialize once so the emptiness check and the encoding pass see the same single enumeration.
    var argList = args.ToList();
    if (argList.Count == 0) {
        return;
    }

    // UriBuilder.Query includes the leading '?', and its setter adds one back, so drop it here.
    // A bare "?" contributes nothing, which avoids producing a query that starts with '&'.
    string existingQuery = builder.Query;
    string existingArgs = string.IsNullOrEmpty(existingQuery) ? string.Empty : existingQuery.Substring(1);

    StringBuilder sb = new StringBuilder(existingArgs.Length + 50 + (argList.Count * 10));
    if (existingArgs.Length > 0) {
        sb.Append(existingArgs);
        sb.Append('&');
    }

    sb.Append(CreateQueryString(argList));

    builder.Query = sb.ToString();
}
/// <summary>
/// Builds the receiving endpoint described by an incoming <see cref="HttpRequestInfo"/>.
/// </summary>
/// <param name="request">The request to get recipient information from.</param>
/// <returns>
/// An endpoint for the request's URL, using GET delivery when the HTTP method is exactly
/// "GET" and POST delivery for any other method.
/// </returns>
internal static MessageReceivingEndpoint GetRecipient(this HttpRequestInfo request) {
    var recipientUrl = request.Url;
    bool isGet = request.HttpMethod == "GET";
    return new MessageReceivingEndpoint(recipientUrl, isGet ? HttpDeliveryMethods.GetRequest : HttpDeliveryMethods.PostRequest);
}
/// <summary>
/// Adds each of the supplied extra parameters to a message through a <see cref="MessageDictionary"/> view.
/// </summary>
/// <param name="message">The message to copy the extra data into.</param>
/// <param name="extraParameters">The extra data to copy into the message. When null, the message is left untouched.</param>
internal static void AddExtraParameters(this IMessage message, IDictionary<string, string> extraParameters) {
    ErrorUtilities.VerifyArgumentNotNull(message, "message");
    if (extraParameters == null) {
        return;
    }

    var messageDictionary = new MessageDictionary(message);
    foreach (KeyValuePair<string, string> extra in extraParameters) {
        messageDictionary.Add(extra);
    }
}
/// <summary>
/// Converts a <see cref="NameValueCollection"/> to an IDictionary&lt;string, string&gt;.
/// </summary>
/// <param name="nvc">The NameValueCollection to convert. May be null.</param>
/// <returns>The generated dictionary, or null if <paramref name="nvc"/> is null.</returns>
/// <remarks>
/// Entries stored under a null key (for example a value-only segment of a parsed query
/// string) are skipped, because a <see cref="Dictionary{TKey,TValue}"/> cannot hold a null key.
/// A key with several values maps to the comma-separated string returned by the
/// <see cref="NameValueCollection"/> indexer.
/// </remarks>
internal static Dictionary<string, string> ToDictionary(this NameValueCollection nvc) {
    if (nvc == null) {
        return null;
    }

    var dictionary = new Dictionary<string, string>(nvc.Count);
    foreach (string key in nvc) {
        // Dictionary.Add would throw ArgumentNullException on a null key.
        if (key == null) {
            continue;
        }

        dictionary.Add(key, nvc[key]);
    }

    return dictionary;
}
/// <summary>
/// Sorts a sequence in ascending key order, using a <see cref="Comparison&lt;T&gt;"/>
/// delegate rather than an <see cref="IComparer&lt;T&gt;"/> to compare keys.
/// </summary>
/// <typeparam name="TSource">The type of the elements of source.</typeparam>
/// <typeparam name="TKey">The type of the key returned by keySelector.</typeparam>
/// <param name="source">A sequence of values to order.</param>
/// <param name="keySelector">A function to extract a key from an element.</param>
/// <param name="comparer">A comparison function to compare keys.</param>
/// <returns>An System.Linq.IOrderedEnumerable&lt;TElement&gt; whose elements are sorted according to a key.</returns>
internal static IOrderedEnumerable<TSource> OrderBy<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Comparison<TKey> comparer) {
    // Wrap the delegate so LINQ's comparer-based overload can use it. The call is made through
    // the Enumerable class explicitly so it cannot resolve back to this method.
    IComparer<TKey> keyComparer = new ComparisonHelper<TKey>(comparer);
    return Enumerable.OrderBy(source, keySelector, keyComparer);
}
/// <summary>
/// Determines whether the specified message is a request (indirect message or direct request).
/// </summary>
/// <param name="message">The message in question. May not be null.</param>
/// <returns>
/// <c>true</c> when the message names a recipient; otherwise, <c>false</c>.
/// </returns>
/// <remarks>
/// Implementing <see cref="IDirectedProtocolMessage"/> alone does not make a message a
/// request: some types implement the interface only for the benefit of derived classes.
/// A message counts as a request only when its <see cref="IDirectedProtocolMessage.Recipient"/>
/// property is non-null.
/// </remarks>
internal static bool IsRequest(this IDirectedProtocolMessage message) {
    ErrorUtilities.VerifyArgumentNotNull(message, "message");
    bool hasRecipient = message.Recipient != null;
    return hasRecipient;
}
/// <summary>
/// Determines whether the specified message is a direct response.
/// </summary>
/// <param name="message">The message in question. May not be null.</param>
/// <returns>
/// <c>true</c> if the specified message is a direct response; otherwise, <c>false</c>.
/// </returns>
/// <remarks>
/// Although an <see cref="IProtocolMessage"/> may implement the
/// <see cref="IDirectResponseProtocolMessage"/> interface, it may only be doing
/// that for its derived classes. These objects are only direct responses if their
/// <see cref="IDirectResponseProtocolMessage.OriginatingRequest"/> property is non-null.
/// </remarks>
internal static bool IsDirectResponse(this IDirectResponseProtocolMessage message) {
    ErrorUtilities.VerifyArgumentNotNull(message, "message");
    return message.OriginatingRequest != null;
}
/// <summary>
/// A class to convert a <see cref="Comparison&lt;T&gt;"/> into an <see cref="IComparer&lt;T&gt;"/>.
/// </summary>
/// <typeparam name="T">The type of objects being compared.</typeparam>
private class ComparisonHelper<T> : IComparer<T> {
    /// <summary>
    /// The comparison method to use. Assigned once, in the constructor.
    /// </summary>
    private readonly Comparison<T> comparison;

    /// <summary>
    /// Initializes a new instance of the ComparisonHelper class.
    /// </summary>
    /// <param name="comparison">The comparison method to use. May not be null.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="comparison"/> is null.</exception>
    internal ComparisonHelper(Comparison<T> comparison) {
        if (comparison == null) {
            throw new ArgumentNullException("comparison");
        }

        this.comparison = comparison;
    }

    #region IComparer<T> Members

    /// <summary>
    /// Compares two instances of <typeparamref name="T"/> using the wrapped comparison.
    /// </summary>
    /// <param name="x">The first object to compare.</param>
    /// <param name="y">The second object to compare.</param>
    /// <returns>
    /// Whatever the wrapped comparison returns. Per the <see cref="Comparison&lt;T&gt;"/>
    /// contract this is negative, zero, or positive, not necessarily -1, 0, or 1.
    /// </returns>
    public int Compare(T x, T y) {
        return this.comparison(x, y);
    }

    #endregion
}