StyleCop fixes
[dotnetoauth.git] / src / DotNetOpenAuth / OpenId / ChannelElements / SigningBindingElement.cs
blob be768329b72eca8d77c2ff05ab765668fc7a4171
1 //-----------------------------------------------------------------------
2 // <copyright file="SigningBindingElement.cs" company="Andrew Arnott">
3 // Copyright (c) Andrew Arnott. All rights reserved.
4 // </copyright>
5 //-----------------------------------------------------------------------
7 namespace DotNetOpenAuth.OpenId.ChannelElements {
8 using System;
9 using System.Collections.Generic;
10 using System.Diagnostics;
11 using System.Linq;
12 using System.Net.Security;
13 using DotNetOpenAuth.Loggers;
14 using DotNetOpenAuth.Messaging;
15 using DotNetOpenAuth.Messaging.Bindings;
16 using DotNetOpenAuth.Messaging.Reflection;
17 using DotNetOpenAuth.OpenId.Messages;
18 using DotNetOpenAuth.OpenId.Provider;
19 using DotNetOpenAuth.OpenId.RelyingParty;
/// <summary>
/// Signs outgoing OpenID messages and verifies the signatures on incoming ones.
/// </summary>
/// <remarks>
/// A single instance serves either a Relying Party or a Provider, depending on which
/// constructor created it. Only messages implementing <see cref="ITamperResistantOpenIdMessage"/>
/// are touched; every other message passes through unchanged.
/// </remarks>
internal class SigningBindingElement : IChannelBindingElement {
    /// <summary>
    /// The association store used by Relying Parties to look up the secrets needed for signing.
    /// </summary>
    /// <remarks>
    /// Never set on a Provider; may also be null on a "dumb" Relying Party that keeps no associations.
    /// </remarks>
    private readonly IAssociationStore<Uri> rpAssociations;

    /// <summary>
    /// The association store used by Providers to look up the secrets needed for signing.
    /// </summary>
    /// <remarks>
    /// Non-null exactly when this element was created by the Provider constructor,
    /// which is what <see cref="IsOnProvider"/> tests.
    /// </remarks>
    private readonly IAssociationStore<AssociationRelyingPartyType> opAssociations;

    /// <summary>
    /// The security settings at the Provider.
    /// Only defined when this element is instantiated to service a Provider.
    /// </summary>
    private readonly ProviderSecuritySettings opSecuritySettings;

    /// <summary>
    /// A logger specifically used for logging verbose text on everything about the signing process.
    /// </summary>
    private static readonly ILog signingLogger = Logger.Create(typeof(SigningBindingElement));

    /// <summary>
    /// Initializes a new instance of the SigningBindingElement class for use by a Relying Party.
    /// </summary>
    /// <param name="associationStore">The association store used to look up the secrets needed for signing. May be null for dumb Relying Parties.</param>
    internal SigningBindingElement(IAssociationStore<Uri> associationStore) {
        this.rpAssociations = associationStore;
    }

    /// <summary>
    /// Initializes a new instance of the SigningBindingElement class for use by a Provider.
    /// </summary>
    /// <param name="associationStore">The association store used to look up the secrets needed for signing.</param>
    /// <param name="securitySettings">The security settings.</param>
    internal SigningBindingElement(IAssociationStore<AssociationRelyingPartyType> associationStore, ProviderSecuritySettings securitySettings) {
        ErrorUtilities.VerifyArgumentNotNull(associationStore, "associationStore");
        ErrorUtilities.VerifyArgumentNotNull(securitySettings, "securitySettings");

        this.opAssociations = associationStore;
        this.opSecuritySettings = securitySettings;
    }

    #region IChannelBindingElement Properties

    /// <summary>
    /// Gets the protection offered (if any) by this binding element.
    /// </summary>
    /// <value><see cref="MessageProtections.TamperProtection"/></value>
    public MessageProtections Protection {
        get { return MessageProtections.TamperProtection; }
    }

    /// <summary>
    /// Gets or sets the channel that this binding element belongs to.
    /// </summary>
    public Channel Channel { get; set; }

    #endregion

    /// <summary>
    /// Gets a value indicating whether this binding element is on a Provider channel.
    /// </summary>
    private bool IsOnProvider {
        get { return this.opAssociations != null; }
    }

    #region IChannelBindingElement Methods

    /// <summary>
    /// Prepares a message for sending based on the rules of this channel binding element.
    /// </summary>
    /// <param name="message">The message to prepare for sending.</param>
    /// <returns>
    /// True if the <paramref name="message"/> applied to this binding element
    /// and the operation was successful. False otherwise.
    /// </returns>
    public bool PrepareMessageForSending(IProtocolMessage message) {
        var signedMessage = message as ITamperResistantOpenIdMessage;
        if (signedMessage != null) {
            Logger.DebugFormat("Signing {0} message.", message.GetType().Name);
            Association association = this.GetAssociation(signedMessage);

            // On a Relying Party without an association store there is nothing to sign with;
            // fail with a descriptive internal error rather than a NullReferenceException.
            ErrorUtilities.VerifyInternal(association != null, "No association is available to sign the outgoing message.");

            signedMessage.AssociationHandle = association.Handle;
            signedMessage.SignedParameterOrder = this.GetSignedParameterOrder(signedMessage);
            signedMessage.Signature = this.GetSignature(signedMessage, association);
            return true;
        }

        return false;
    }

    /// <summary>
    /// Performs any transformation on an incoming message that may be necessary and/or
    /// validates an incoming message based on the rules of this channel binding element.
    /// </summary>
    /// <param name="message">The incoming message to process.</param>
    /// <returns>
    /// True if the <paramref name="message"/> applied to this binding element
    /// and the operation was successful. False if the operation did not apply to this message.
    /// </returns>
    /// <exception cref="ProtocolException">
    /// Thrown when the binding element rules indicate that this message is invalid and should
    /// NOT be processed.
    /// </exception>
    public bool PrepareMessageForReceiving(IProtocolMessage message) {
        var signedMessage = message as ITamperResistantOpenIdMessage;
        if (signedMessage != null) {
            Logger.DebugFormat("Verifying incoming {0} message signature of: {1}", message.GetType().Name, signedMessage.Signature);

            EnsureParametersRequiringSignatureAreSigned(signedMessage);

            Association association = this.GetSpecificAssociation(signedMessage);
            if (association != null) {
                string signature = this.GetSignature(signedMessage, association);

                // Constant-time comparison so response timing does not reveal how many
                // leading characters of a forged signature were correct.
                if (!EqualsConstantTime(signedMessage.Signature, signature)) {
                    Logger.Error("Signature verification failed.");
                    throw new InvalidSignatureException(message);
                }
            } else {
                // A Provider has no one else to ask: the signature names an association handle
                // we do not (or no longer) recognize, so it cannot be verified.
                // The check_authentication fallback below is only meaningful on a Relying Party.
                if (this.IsOnProvider) {
                    Logger.Error("Signature verification failed because the association handle is not recognized.");
                    throw new InvalidSignatureException(message);
                }

                ErrorUtilities.VerifyInternal(this.Channel != null, "Cannot verify private association signature because we don't have a channel.");

                // We did not recognize the association the provider used to sign the message.
                // Ask the provider to check the signature then.
                var indirectSignedResponse = (IndirectSignedResponse)signedMessage;
                var checkSignatureRequest = new CheckAuthenticationRequest(indirectSignedResponse);
                var checkSignatureResponse = this.Channel.Request<CheckAuthenticationResponse>(checkSignatureRequest);
                if (!checkSignatureResponse.IsValid) {
                    Logger.Error("Provider reports signature verification failed.");
                    throw new InvalidSignatureException(message);
                }

                // If the OP confirms that a handle should be invalidated as well, do that.
                if (!string.IsNullOrEmpty(checkSignatureResponse.InvalidateHandle)) {
                    if (this.rpAssociations != null) {
                        this.rpAssociations.RemoveAssociation(indirectSignedResponse.ProviderEndpoint, checkSignatureResponse.InvalidateHandle);
                    }
                }
            }

            return true;
        }

        return false;
    }

    #endregion

    /// <summary>
    /// Compares two strings for ordinal equality, taking time that depends only on their
    /// lengths and not on the position of the first differing character.
    /// </summary>
    /// <param name="value1">The first string. May be null.</param>
    /// <param name="value2">The second string. May be null.</param>
    /// <returns>
    /// <c>true</c> if both strings are null, or both are non-null with identical characters;
    /// otherwise <c>false</c> (same result as an ordinal <see cref="string.Equals(string, string, StringComparison)"/>).
    /// </returns>
    private static bool EqualsConstantTime(string value1, string value2) {
        if (value1 == null || value2 == null) {
            return value1 == value2;
        }

        if (value1.Length != value2.Length) {
            return false;
        }

        // Accumulate differences without short-circuiting.
        int difference = 0;
        for (int i = 0; i < value1.Length; i++) {
            difference |= value1[i] ^ value2[i];
        }

        return difference == 0;
    }

    /// <summary>
    /// Ensures that all message parameters that must be signed are in fact included
    /// in the signature.
    /// </summary>
    /// <param name="signedMessage">The signed message.</param>
    private static void EnsureParametersRequiringSignatureAreSigned(ITamperResistantOpenIdMessage signedMessage) {
        // Verify that the signed parameter order includes the mandated fields.
        // We do this in such a way that derived classes that add mandated fields automatically
        // get included in the list of checked parameters.
        Protocol protocol = Protocol.Lookup(signedMessage.Version);
        string prefix = protocol.openid.Prefix;
        List<string> partsRequiringProtection = (from part in MessageDescription.Get(signedMessage.GetType(), signedMessage.Version).Mapping.Values
                                                 where part.RequiredProtection != ProtectionLevel.None
                                                 select part.Name).ToList();
        ErrorUtilities.VerifyInternal(partsRequiringProtection.All(name => name.StartsWith(prefix, StringComparison.Ordinal)), "Signing only works when the parameters start with the 'openid.' prefix.");

        // openid.signed lists names without the 'openid.' prefix.
        string[] signedParts = signedMessage.SignedParameterOrder.Split(',');
        string[] unsignedParts = partsRequiringProtection
            .Where(partName => !signedParts.Contains(partName.Substring(prefix.Length)))
            .ToArray();
        ErrorUtilities.VerifyProtocol(unsignedParts.Length == 0, OpenIdStrings.SignatureDoesNotIncludeMandatoryParts, string.Join(", ", unsignedParts));
    }

    /// <summary>
    /// Gets the value to use for the openid.signed parameter.
    /// </summary>
    /// <param name="signedMessage">The signable message.</param>
    /// <returns>
    /// A comma-delimited list of parameter names, omitting the 'openid.' prefix, that determines
    /// the inclusion and order of message parts that will be signed.
    /// </returns>
    /// <remarks>
    /// Every non-null message part whose required protection is anything other than
    /// <see cref="ProtectionLevel.None"/> is included. This is the same criterion
    /// <see cref="EnsureParametersRequiringSignatureAreSigned"/> enforces on receipt.
    /// </remarks>
    private string GetSignedParameterOrder(ITamperResistantOpenIdMessage signedMessage) {
        ErrorUtilities.VerifyArgumentNotNull(signedMessage, "signedMessage");

        Protocol protocol = Protocol.Lookup(signedMessage.Version);
        string prefix = protocol.openid.Prefix;

        MessageDescription description = MessageDescription.Get(signedMessage.GetType(), signedMessage.Version);

        // ProtectionLevel is not a [Flags] enum (EncryptAndSign & Sign == 0), so compare against None
        // rather than masking with Sign. Materialize once because the list is enumerated repeatedly below.
        List<string> signedParts = (from part in description.Mapping.Values
                                    where part.RequiredProtection != ProtectionLevel.None
                                        && part.GetValue(signedMessage) != null
                                    select part.Name).ToList();
        ErrorUtilities.VerifyInternal(signedParts.All(name => name.StartsWith(prefix, StringComparison.Ordinal)), "All signed message parts must start with 'openid.'.");

        // Extension arguments are signed only when the Provider's settings ask for it.
        // A Relying Party has no provider settings, so nothing extra is added there.
        if (this.opSecuritySettings != null && this.opSecuritySettings.SignOutgoingExtensions) {
            // Tack on any ExtraData parameters that start with 'openid.'.
            foreach (string key in signedMessage.ExtraData.Keys) {
                if (key.StartsWith(prefix, StringComparison.Ordinal)) {
                    signedParts.Add(key);
                } else {
                    Logger.DebugFormat("The extra parameter '{0}' will not be signed because it does not start with 'openid.'.", key);
                }
            }
        }

        int skipLength = prefix.Length;
        return string.Join(",", signedParts.Select(name => name.Substring(skipLength)).ToArray());
    }

    /// <summary>
    /// Calculates the signature for a given message.
    /// </summary>
    /// <param name="signedMessage">The message to sign or verify.</param>
    /// <param name="association">The association to use to sign the message.</param>
    /// <returns>The Base64-encoded signature over the parts named in openid.signed.</returns>
    private string GetSignature(ITamperResistantOpenIdMessage signedMessage, Association association) {
        ErrorUtilities.VerifyArgumentNotNull(signedMessage, "signedMessage");
        ErrorUtilities.VerifyNonZeroLength(signedMessage.SignedParameterOrder, "signedMessage.SignedParameterOrder");
        ErrorUtilities.VerifyArgumentNotNull(association, "association");

        // Collect the key/value pairs in exactly the order openid.signed lists them,
        // re-attaching the 'openid.' prefix that openid.signed omits.
        // NOTE(review): an earlier comment claimed openid.mode=check_authentication is swapped back to
        // id_res here; this method does not do that itself, so presumably the message/MessageDictionary
        // supplies the original mode value — confirm.
        Protocol protocol = Protocol.Lookup(signedMessage.Version);
        string prefix = protocol.openid.Prefix;
        MessageDictionary dictionary = new MessageDictionary(signedMessage);
        List<KeyValuePair<string, string>> parametersToSign = (from name in signedMessage.SignedParameterOrder.Split(',')
                                                               let prefixedName = prefix + name
                                                               select new KeyValuePair<string, string>(prefixedName, dictionary[prefixedName])).ToList();

        byte[] dataToSign = KeyValueFormEncoding.GetBytes(parametersToSign);
        string signature = Convert.ToBase64String(association.Sign(dataToSign));

        if (signingLogger.IsDebugEnabled) {
            signingLogger.DebugFormat(
                "Signing these message parts: {0}{1}{0}Base64 representation of signed data: {2}{0}Signature: {3}",
                Environment.NewLine,
                parametersToSign.ToStringDeferred(),
                Convert.ToBase64String(dataToSign),
                signature);
        }

        return signature;
    }

    /// <summary>
    /// Gets the association to use to sign an outgoing message.
    /// </summary>
    /// <param name="signedMessage">The message to sign.</param>
    /// <returns>
    /// The association to use to sign the message. On a Provider this is never null
    /// (a dumb-mode association is created if needed). On a Relying Party it is null
    /// when there is no association store.
    /// </returns>
    private Association GetAssociation(ITamperResistantOpenIdMessage signedMessage) {
        if (this.IsOnProvider) {
            // We're on a Provider to either sign (smart/dumb) or verify a dumb signature.
            return this.GetSpecificAssociation(signedMessage) ?? this.GetDumbAssociationForSigning();
        } else {
            // We're on a Relying Party signing an outgoing message addressed to its recipient.
            IDirectedProtocolMessage directedMessage = (IDirectedProtocolMessage)signedMessage;
            if (this.rpAssociations != null) {
                return this.rpAssociations.GetAssociation(directedMessage.Recipient, signedMessage.AssociationHandle);
            } else {
                return null;
            }
        }
    }

    /// <summary>
    /// Gets a specific association referenced in a given message's association handle.
    /// </summary>
    /// <param name="signedMessage">The signed message whose association handle should be used to lookup the association to return.</param>
    /// <returns>The referenced association; or <c>null</c> if such an association cannot be found.</returns>
    /// <remarks>
    /// On a Provider, if the association handle set in the message does not match any valid association,
    /// the association handle property is cleared, and the
    /// <see cref="ITamperResistantOpenIdMessage.InvalidateHandle"/> property is set to the
    /// handle that could not be found.
    /// On a Relying Party without an association store this always returns <c>null</c>; with one,
    /// the message is cast to <see cref="IndirectSignedResponse"/> to find the Provider endpoint.
    /// </remarks>
    private Association GetSpecificAssociation(ITamperResistantOpenIdMessage signedMessage) {
        Association association = null;

        if (!string.IsNullOrEmpty(signedMessage.AssociationHandle)) {
            if (this.IsOnProvider) {
                // Since we have an association handle, we're either signing with a smart association,
                // or verifying a dumb one.
                bool signing = string.IsNullOrEmpty(signedMessage.Signature);
                ErrorUtilities.VerifyInternal(signing == (signedMessage is PositiveAssertionResponse), "Ooops... somehow we think we're signing a message that isn't a positive assertion!");
                AssociationRelyingPartyType type = signing ? AssociationRelyingPartyType.Smart : AssociationRelyingPartyType.Dumb;
                association = this.opAssociations.GetAssociation(type, signedMessage.AssociationHandle);
                if (association == null) {
                    // There was no valid association with the requested handle.
                    // Let's tell the RP to forget about that association.
                    signedMessage.InvalidateHandle = signedMessage.AssociationHandle;
                    signedMessage.AssociationHandle = null;
                }
            } else if (this.rpAssociations != null) { // if on a smart RP
                Uri providerEndpoint = ((IndirectSignedResponse)signedMessage).ProviderEndpoint;
                association = this.rpAssociations.GetAssociation(providerEndpoint, signedMessage.AssociationHandle);
            }
        }

        return association;
    }

    /// <summary>
    /// Gets a private Provider association used for signing messages in "dumb" mode.
    /// </summary>
    /// <returns>
    /// An existing dumb-mode association, or a newly created HMAC-SHA256 one that is then
    /// stored for reuse.
    /// </returns>
    private Association GetDumbAssociationForSigning() {
        // If no assoc_handle was given or it was invalid, the only thing
        // left to do is sign a message using a 'dumb' mode association.
        Protocol protocol = Protocol.Default;
        Association association = this.opAssociations.GetAssociation(AssociationRelyingPartyType.Dumb);
        if (association == null) {
            association = HmacShaAssociation.Create(protocol, protocol.Args.SignatureAlgorithm.HMAC_SHA256, AssociationRelyingPartyType.Dumb, this.opSecuritySettings);
            this.opAssociations.StoreAssociation(AssociationRelyingPartyType.Dumb, association);
        }

        return association;
    }
}