//-----------------------------------------------------------------------
// <copyright file="MessagingUtilities.cs" company="Andrew Arnott">
//     Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------
7 namespace DotNetOpenAuth
.Messaging
{
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Net;
using System.Security.Cryptography;
using System.Text;
using System.Web;
using DotNetOpenAuth.Messaging.Reflection;
/// <summary>
/// A grab-bag of utility methods useful for the channel stack of the protocol.
/// </summary>
23 public static class MessagingUtilities
{
/// <summary>
/// The cryptographically strong random data generator used for creating secrets.
/// </summary>
/// <remarks>
/// The random number generator is thread-safe.
/// It is obtained from <see cref="RandomNumberGenerator.Create()"/> rather than by constructing
/// <c>RNGCryptoServiceProvider</c> directly, because that type is obsolete (SYSLIB0023) on .NET 6
/// and later, while the factory method is available on every framework version.
/// </remarks>
internal static readonly RandomNumberGenerator CryptoRandomDataGenerator = RandomNumberGenerator.Create();
/// <summary>
/// Gets the original request URL, as seen from the browser before any URL rewrites on the server if any.
/// Cookieless session directory (if applicable) is also included.
/// </summary>
/// <returns>The URL in the user agent's Location bar.</returns>
/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
[SuppressMessage("Microsoft.Usage", "CA2234:PassSystemUriObjectsInsteadOfStrings", Justification = "The Uri merging requires use of a string value.")]
[SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Expensive call should not be a property.")]
public static Uri GetRequestUrlFromContext() {
    HttpContext current = HttpContext.Current;
    if (current != null) {
        // Request.Url contributes the scheme, host and port of the server; Request.RawUrl is
        // merged on top of it so the result carries the path exactly as the browser sent it.
        // That keeps the cookieless session "directory" (if any) and ignores server-side URL
        // rewriting, which matters because the result may be compared with the return_to
        // parameter.  Response.ApplyAppPathModifier would cover the cookieless session case
        // but not the URL rewriting case.
        var request = current.Request;
        return new Uri(request.Url, request.RawUrl);
    }

    throw new InvalidOperationException(MessagingStrings.CurrentHttpContextRequired);
}
/// <summary>
/// Strips any and all URI query parameters that start with some prefix.
/// </summary>
/// <param name="uri">The URI that may have a query with parameters to remove.</param>
/// <param name="prefix">The prefix for parameters to remove. Matching is ordinal and case-insensitive.</param>
/// <returns>Either a new Uri with the parameters removed if there were any to remove, or the same Uri instance if no parameters needed to be removed.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="uri"/> or <paramref name="prefix"/> is null.</exception>
/// <remarks>
/// The surviving parameters are re-encoded one value at a time, so repeated keys
/// (e.g. <c>a=1&amp;a=2</c>) stay separate and value-only entries (e.g. <c>?flag</c>) are kept.
/// Routing them through a string-to-string dictionary would merge repeated values into a
/// single comma-joined value, and would fail on the null key that value-only entries parse to.
/// </remarks>
public static Uri StripQueryArgumentsWithPrefix(this Uri uri, string prefix) {
    if (uri == null) {
        throw new ArgumentNullException("uri");
    }
    if (prefix == null) {
        throw new ArgumentNullException("prefix");
    }

    NameValueCollection queryArgs = HttpUtility.ParseQueryString(uri.Query);

    // OfType<string> skips the null key used for value-only entries, so those entries are
    // never treated as a match and are always retained.
    var matchingKeys = queryArgs.Keys.OfType<string>()
        .Where(key => key.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
        .ToList();
    if (matchingKeys.Count == 0) {
        return uri;
    }

    foreach (string key in matchingKeys) {
        queryArgs.Remove(key);
    }

    StringBuilder query = new StringBuilder();
    foreach (string key in queryArgs.AllKeys) {
        // GetValues returns null for a key that holds no values; emit it as "key=".
        string[] values = queryArgs.GetValues(key) ?? new[] { string.Empty };
        foreach (string value in values) {
            if (query.Length > 0) {
                query.Append('&');
            }
            if (key != null) {
                query.Append(HttpUtility.UrlEncode(key));
                query.Append('=');
            }
            query.Append(HttpUtility.UrlEncode(value));
        }
    }

    UriBuilder builder = new UriBuilder(uri);
    builder.Query = query.ToString();
    return builder.Uri;
}
/// <summary>
/// Gets a cryptographically strong random sequence of values.
/// </summary>
/// <param name="length">The length of the sequence to generate.</param>
/// <returns>A newly allocated array of <paramref name="length"/> random bytes, which may contain zeros.</returns>
internal static byte[] GetCryptoRandomData(int length) {
    var randomBytes = new byte[length];

    // Fill the freshly allocated array from the shared strong generator.
    CryptoRandomDataGenerator.GetBytes(randomBytes);

    return randomBytes;
}
/// <summary>
/// Adds a set of HTTP headers to an <see cref="HttpResponse"/> instance,
/// taking care to set some headers to the appropriate properties of
/// <see cref="HttpResponse" />.
/// </summary>
/// <param name="headers">The headers to add.</param>
/// <param name="response">The <see cref="HttpResponse"/> instance to set the appropriate values to.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="headers"/> or <paramref name="response"/> is null.</exception>
internal static void ApplyHeadersToResponse(WebHeaderCollection headers, HttpResponse response) {
    if (headers == null) {
        throw new ArgumentNullException("headers");
    }
    if (response == null) {
        throw new ArgumentNullException("response");
    }

    foreach (string headerName in headers) {
        string headerValue = headers[headerName];

        // HTTP header names are case-insensitive, so any casing of Content-Type must be
        // routed to the dedicated ContentType property rather than added as a raw header.
        if (string.Equals(headerName, "Content-Type", StringComparison.OrdinalIgnoreCase)) {
            response.ContentType = headerValue;
        } else {
            // Add more special cases above as necessary.
            response.AddHeader(headerName, headerValue);
        }
    }
}
/// <summary>
/// Copies the contents of one stream to another, up to <see cref="int.MaxValue"/> bytes.
/// </summary>
/// <param name="copyFrom">The stream to copy from, at the position where copying should begin.</param>
/// <param name="copyTo">The stream to copy to, at the position where bytes should be written.</param>
/// <returns>The total number of bytes copied.</returns>
/// <remarks>
/// Copying begins at the streams' current positions.
/// The positions are NOT reset after copying is complete.
/// This is the bounded overload with <see cref="int.MaxValue"/> as the limit.
/// </remarks>
internal static int CopyTo(this Stream copyFrom, Stream copyTo) {
    const int NoPracticalLimit = int.MaxValue;
    return CopyTo(copyFrom, copyTo, NoPracticalLimit);
}
/// <summary>
/// Copies the contents of one stream to another.
/// </summary>
/// <param name="copyFrom">The stream to copy from, at the position where copying should begin.</param>
/// <param name="copyTo">The stream to copy to, at the position where bytes should be written.</param>
/// <param name="maximumBytesToCopy">The maximum bytes to copy.</param>
/// <returns>The total number of bytes copied.</returns>
/// <remarks>
/// Copying begins at the streams' current positions.
/// The positions are NOT reset after copying is complete.
/// Data is moved in chunks of at most 1024 bytes; copying stops when the source reports
/// no more data or the byte budget is exhausted.
/// </remarks>
internal static int CopyTo(this Stream copyFrom, Stream copyTo, int maximumBytesToCopy) {
    ErrorUtilities.VerifyArgumentNotNull(copyFrom, "copyFrom");
    ErrorUtilities.VerifyArgumentNotNull(copyTo, "copyTo");
    ErrorUtilities.VerifyArgument(copyFrom.CanRead, MessagingStrings.StreamUnreadable);
    ErrorUtilities.VerifyArgument(copyTo.CanWrite, MessagingStrings.StreamUnwritable, "copyTo");

    const int ChunkSize = 1024;
    byte[] chunk = new byte[ChunkSize];
    int remaining = maximumBytesToCopy;
    int copied = 0;
    for (;;) {
        // Never ask for more than the remaining budget; a zero request ends the loop.
        int bytesRead = copyFrom.Read(chunk, 0, Math.Min(ChunkSize, remaining));
        if (bytesRead <= 0) {
            break;
        }

        int bytesToWrite = Math.Min(remaining, bytesRead);
        copyTo.Write(chunk, 0, bytesToWrite);
        copied += bytesToWrite;
        remaining -= bytesToWrite;
    }

    return copied;
}
/// <summary>
/// Creates a snapshot of some stream so it is seekable, and the original can be closed.
/// </summary>
/// <param name="copyFrom">The stream to copy bytes from, starting at its current position.</param>
/// <returns>A seekable stream with the same contents as the original, positioned at its beginning.</returns>
internal static Stream CreateSnapshot(this Stream copyFrom) {
    ErrorUtilities.VerifyArgumentNotNull(copyFrom, "copyFrom");

    // Pre-size the buffer for the bytes that will actually be copied.  Copying starts at the
    // current position, so the remaining length is what matters.  Fall back to a small default
    // when that is not known, negative, or too large for an int, rather than casting blindly.
    int initialCapacity = 4 * 1024;
    if (copyFrom.CanSeek) {
        long remaining = copyFrom.Length - copyFrom.Position;
        if (remaining >= 0 && remaining <= int.MaxValue) {
            initialCapacity = (int)remaining;
        }
    }

    MemoryStream copyTo = new MemoryStream(initialCapacity);
    copyFrom.CopyTo(copyTo);

    // Rewind so the caller reads the snapshot from its first byte.
    copyTo.Position = 0;
    return copyTo;
}
/// <summary>
/// Tests whether two arrays are equal in length and contents.
/// </summary>
/// <typeparam name="T">The type of elements in the arrays.</typeparam>
/// <param name="first">The first array in the comparison. May not be null.</param>
/// <param name="second">The second array in the comparison. May not be null.</param>
/// <returns>True if the arrays equal; false otherwise.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="first"/> or <paramref name="second"/> is null.</exception>
/// <remarks>
/// Elements are compared pairwise, in order, using <see cref="EqualityComparer{T}.Default"/>,
/// so arrays containing null elements are compared instead of throwing.
/// </remarks>
internal static bool AreEquivalent<T>(T[] first, T[] second) {
    if (first == null) {
        throw new ArgumentNullException("first");
    }
    if (second == null) {
        throw new ArgumentNullException("second");
    }
    if (first.Length != second.Length) {
        return false;
    }

    EqualityComparer<T> comparer = EqualityComparer<T>.Default;
    for (int i = 0; i < first.Length; i++) {
        if (!comparer.Equals(first[i], second[i])) {
            return false;
        }
    }

    return true;
}
/// <summary>
/// Tests whether two dictionaries are equal in length and contents.
/// </summary>
/// <typeparam name="TKey">The type of keys in the dictionaries.</typeparam>
/// <typeparam name="TValue">The type of values in the dictionaries.</typeparam>
/// <param name="first">The first dictionary in the comparison. May not be null.</param>
/// <param name="second">The second dictionary in the comparison. May not be null.</param>
/// <returns>True if the dictionaries hold the same key/value pairs; false otherwise.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="first"/> or <paramref name="second"/> is null.</exception>
/// <remarks>
/// The result does not depend on enumeration order.  Each key of <paramref name="first"/> is
/// looked up in <paramref name="second"/> (using that dictionary's own key comparer), and the
/// values are compared with <see cref="EqualityComparer{T}.Default"/>.
/// </remarks>
internal static bool AreEquivalent<TKey, TValue>(IDictionary<TKey, TValue> first, IDictionary<TKey, TValue> second) {
    if (first == null) {
        throw new ArgumentNullException("first");
    }
    if (second == null) {
        throw new ArgumentNullException("second");
    }
    if (first.Count != second.Count) {
        return false;
    }

    EqualityComparer<TValue> valueComparer = EqualityComparer<TValue>.Default;
    foreach (KeyValuePair<TKey, TValue> pair in first) {
        TValue otherValue;
        if (!second.TryGetValue(pair.Key, out otherValue) || !valueComparer.Equals(pair.Value, otherValue)) {
            return false;
        }
    }

    return true;
}
/// <summary>
/// Concatenates a list of name-value pairs as key=value&amp;key=value,
/// taking care to properly encode each key and value for URL
/// transmission. No ? is prefixed to the string.
/// </summary>
/// <param name="args">The dictionary of key/values to read from.</param>
/// <returns>The formulated querystring style string, or an empty string when <paramref name="args"/> is empty.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="args"/> is null.</exception>
internal static string CreateQueryString(IDictionary<string, string> args) {
    if (args == null) {
        throw new ArgumentNullException("args");
    }

    // Encode each pair as "key=value" and join the pairs with '&'; joining an empty
    // list yields the empty string.
    var encodedPairs = new List<string>(args.Count);
    foreach (KeyValuePair<string, string> pair in args) {
        encodedPairs.Add(HttpUtility.UrlEncode(pair.Key) + "=" + HttpUtility.UrlEncode(pair.Value));
    }

    return string.Join("&", encodedPairs.ToArray());
}
/// <summary>
/// Adds a set of name-value pairs to the end of a given URL
/// as part of the querystring piece. Prefixes a ? or &amp; before
/// first element as necessary.
/// </summary>
/// <param name="builder">The UriBuilder to add arguments to.</param>
/// <param name="args">
/// The arguments to add to the query.
/// If null or empty, <paramref name="builder"/> is not changed.
/// </param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="builder"/> is null.</exception>
internal static void AppendQueryArgs(UriBuilder builder, IDictionary<string, string> args) {
    if (builder == null) {
        throw new ArgumentNullException("builder");
    }

    if (args == null || args.Count == 0) {
        return;
    }

    // UriBuilder.Query reports the query with its leading '?'.  Drop only that character
    // (the Query setter supplies the separator itself), and only if it is really there.
    string existingQuery = builder.Query ?? string.Empty;
    if (existingQuery.StartsWith("?", StringComparison.Ordinal)) {
        existingQuery = existingQuery.Substring(1);
    }

    StringBuilder sb = new StringBuilder(50 + (args.Count * 10));

    // An existing query of just "?" has no arguments, so it must not produce a leading '&'.
    if (existingQuery.Length > 0) {
        sb.Append(existingQuery);
        sb.Append('&');
    }
    sb.Append(CreateQueryString(args));

    builder.Query = sb.ToString();
}
/// <summary>
/// Extracts the recipient from an HttpRequestInfo.
/// </summary>
/// <param name="request">The request to get recipient information from.</param>
/// <returns>
/// The recipient: the request URL, paired with <see cref="HttpDeliveryMethods.GetRequest"/> when the
/// HTTP method is exactly "GET", and with <see cref="HttpDeliveryMethods.PostRequest"/> otherwise.
/// </returns>
internal static MessageReceivingEndpoint GetRecipient(this HttpRequestInfo request) {
    var url = request.Url;

    HttpDeliveryMethods deliveryMethod;
    if (request.HttpMethod == "GET") {
        deliveryMethod = HttpDeliveryMethods.GetRequest;
    } else {
        deliveryMethod = HttpDeliveryMethods.PostRequest;
    }

    return new MessageReceivingEndpoint(url, deliveryMethod);
}
/// <summary>
/// Copies some extra parameters into a message.
/// </summary>
/// <param name="message">The message to copy the extra data into.</param>
/// <param name="extraParameters">The extra data to copy into the message. May be null to do nothing.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
internal static void AddNonOAuthParameters(this IProtocolMessage message, IDictionary<string, string> extraParameters) {
    if (message == null) {
        throw new ArgumentNullException("message");
    }

    if (extraParameters == null) {
        return;
    }

    // Each pair is added through a MessageDictionary view over the message.
    var messageDictionary = new MessageDictionary(message);
    foreach (KeyValuePair<string, string> extraParameter in extraParameters) {
        messageDictionary.Add(extraParameter);
    }
}
/// <summary>
/// Converts a <see cref="NameValueCollection"/> to a <see cref="Dictionary{TKey, TValue}"/> of strings.
/// </summary>
/// <param name="nvc">The NameValueCollection to convert. May be null.</param>
/// <returns>The generated dictionary, or null if <paramref name="nvc"/> is null.</returns>
/// <remarks>
/// Each key maps to the collection's indexer value for that key, so a key holding several
/// values becomes their comma-separated concatenation.  A null key in <paramref name="nvc"/>
/// causes an <see cref="ArgumentNullException"/>, because <see cref="Dictionary{TKey, TValue}"/>
/// does not accept null keys.
/// </remarks>
internal static Dictionary<string, string> ToDictionary(this NameValueCollection nvc) {
    if (nvc == null) {
        return null;
    }

    var result = new Dictionary<string, string>();
    foreach (string name in nvc) {
        string combinedValue = nvc[name];
        result.Add(name, combinedValue);
    }

    return result;
}
/// <summary>
/// Sorts the elements of a sequence in ascending order by using a specified comparison delegate.
/// </summary>
/// <typeparam name="TSource">The type of the elements of source.</typeparam>
/// <typeparam name="TKey">The type of the key returned by keySelector.</typeparam>
/// <param name="source">A sequence of values to order.</param>
/// <param name="keySelector">A function to extract a key from an element.</param>
/// <param name="comparer">A comparison function to compare keys.</param>
/// <returns>An <see cref="IOrderedEnumerable{TElement}"/> whose elements are sorted according to a key.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="comparer"/> is null.</exception>
internal static IOrderedEnumerable<TSource> OrderBy<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Comparison<TKey> comparer) {
    // Adapt the delegate to the IComparer<TKey> that LINQ's OrderBy expects.
    IComparer<TKey> keyComparer = new ComparisonHelper<TKey>(comparer);
    return System.Linq.Enumerable.OrderBy<TSource, TKey>(source, keySelector, keyComparer);
}
/// <summary>
/// Determines whether the specified message is a request (indirect message or direct request).
/// </summary>
/// <param name="message">The message in question.</param>
/// <returns>
/// <c>true</c> if the specified message is a request; otherwise, <c>false</c>.
/// </returns>
/// <remarks>
/// Although an <see cref="IProtocolMessage"/> may implement the <see cref="IDirectedProtocolMessage"/>
/// interface, it may only be doing that for its derived classes.  These objects are only requests
/// if their <see cref="IDirectedProtocolMessage.Recipient"/> property is non-null.
/// </remarks>
internal static bool IsRequest(this IDirectedProtocolMessage message) {
    ErrorUtilities.VerifyArgumentNotNull(message, "message");

    var recipient = message.Recipient;
    bool hasRecipient = recipient != null;
    return hasRecipient;
}
/// <summary>
/// Determines whether the specified message is a direct response.
/// </summary>
/// <param name="message">The message in question.</param>
/// <returns>
/// <c>true</c> if the specified message is a direct response; otherwise, <c>false</c>.
/// </returns>
/// <remarks>
/// Although an <see cref="IProtocolMessage"/> may implement the
/// <see cref="IDirectResponseProtocolMessage"/> interface, it may only be doing
/// that for its derived classes.  These objects are only direct responses if their
/// <see cref="IDirectResponseProtocolMessage.OriginatingRequest"/> property is non-null.
/// </remarks>
internal static bool IsDirectResponse(this IDirectResponseProtocolMessage message) {
    ErrorUtilities.VerifyArgumentNotNull(message, "message");

    var originatingRequest = message.OriginatingRequest;
    bool answersARequest = originatingRequest != null;
    return answersARequest;
}
/// <summary>
/// A class to convert a <see cref="Comparison{T}"/> into an <see cref="IComparer{T}"/>.
/// </summary>
/// <typeparam name="T">The type of objects being compared.</typeparam>
private class ComparisonHelper<T> : IComparer<T> {
    /// <summary>
    /// The comparison method to use; assigned once in the constructor.
    /// </summary>
    private readonly Comparison<T> comparison;

    /// <summary>
    /// Initializes a new instance of the ComparisonHelper class.
    /// </summary>
    /// <param name="comparison">The comparison method to use.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="comparison"/> is null.</exception>
    internal ComparisonHelper(Comparison<T> comparison) {
        if (comparison == null) {
            throw new ArgumentNullException("comparison");
        }

        this.comparison = comparison;
    }

    #region IComparer<T> Members

    /// <summary>
    /// Compares two instances of <typeparamref name="T"/>.
    /// </summary>
    /// <param name="x">The first object to compare.</param>
    /// <param name="y">The second object to compare.</param>
    /// <returns>Whatever the wrapped comparison returns for (<paramref name="x"/>, <paramref name="y"/>).</returns>
    public int Compare(T x, T y) {
        Comparison<T> compare = this.comparison;
        return compare(x, y);
    }

    #endregion
}