Added replay detection tests.
[dotnetoauth.git] / src / DotNetOpenAuth.Test / Mocks / CoordinatingChannel.cs
blob 711f92466877aea0f1c687cd7501b9afc97b7ba4
1 //-----------------------------------------------------------------------
2 // <copyright file="CoordinatingChannel.cs" company="Andrew Arnott">
3 // Copyright (c) Andrew Arnott. All rights reserved.
4 // </copyright>
5 //-----------------------------------------------------------------------
7 namespace DotNetOpenAuth.Test.Mocks {
8 using System;
9 using System.Collections.Generic;
10 using System.Linq;
11 using System.Text;
12 using System.Threading;
13 using DotNetOpenAuth.Messaging;
15 internal class CoordinatingChannel : Channel {
/// <summary>
/// The real channel whose message factory, binding elements, response reading
/// and receive-side verification this mock delegates to.
/// </summary>
private readonly Channel wrappedChannel;

/// <summary>
/// Set by the remote <see cref="CoordinatingChannel"/> after it drops a message into
/// <see cref="incomingMessage"/>; waited on by AwaitIncomingMessage.
/// </summary>
private readonly EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

/// <summary>
/// Optional callback invoked on every message this channel receives; may be null.
/// </summary>
private readonly Action<IProtocolMessage> incomingMessageFilter;

/// <summary>
/// Optional callback invoked on every message this channel sends; may be null.
/// </summary>
private readonly Action<IProtocolMessage> outgoingMessageFilter;

/// <summary>
/// The "in-slot": the last message delivered by the remote channel, or null once it
/// has been taken by AwaitIncomingMessage. Not readonly because it is written by the
/// remote channel and cleared on consumption.
/// </summary>
private IProtocolMessage incomingMessage;
/// <summary>
/// Initializes a new instance of the <see cref="CoordinatingChannel"/> class that
/// borrows its message factory and binding elements from another channel.
/// </summary>
/// <param name="wrappedChannel">The channel to borrow the factory, binding elements and verification from.</param>
/// <param name="incomingMessageFilter">Callback run on each received message; may be null.</param>
/// <param name="outgoingMessageFilter">Callback run on each sent message; may be null.</param>
internal CoordinatingChannel(Channel wrappedChannel, Action<IProtocolMessage> incomingMessageFilter, Action<IProtocolMessage> outgoingMessageFilter)
    : base(GetMessageFactory(wrappedChannel), wrappedChannel.BindingElements.ToArray()) {
    // NOTE(review): GetMessageFactory already validates wrappedChannel (as "channel")
    // while the base-constructor arguments are evaluated, so this check presumably
    // never fires for a null argument.
    ErrorUtilities.VerifyArgumentNotNull(wrappedChannel, "wrappedChannel");

    this.outgoingMessageFilter = outgoingMessageFilter;
    this.incomingMessageFilter = incomingMessageFilter;
    this.wrappedChannel = wrappedChannel;

    // Use the same outgoing/incoming binding element order as the wrapped channel,
    // in case it was customized there.
    this.CustomizeBindingElementOrder(wrappedChannel.OutgoingBindingElements, wrappedChannel.IncomingBindingElements);
}
/// <summary>
/// Backing store for <see cref="RemoteChannel"/>.
/// </summary>
private CoordinatingChannel remoteChannel;

/// <summary>
/// Gets or sets the coordinating channel used by the other party.
/// </summary>
/// <remarks>
/// The send paths and CloneSerializedParts dereference this without a null check,
/// so it must be assigned before any message is sent.
/// </remarks>
internal CoordinatingChannel RemoteChannel {
    get { return this.remoteChannel; }
    set { this.remoteChannel = value; }
}
/// <summary>
/// Feeds a copy of an already-seen message back through receive-side verification,
/// as if it had arrived a second time.
/// </summary>
/// <param name="message">The message to replay; a serialized copy is verified, not this instance.</param>
internal void Replay(IProtocolMessage message) {
    IProtocolMessage replayedCopy = this.CloneSerializedParts(message);
    this.VerifyMessageAfterReceiving(replayedCopy);
}
/// <summary>
/// Blocks until the remote channel delivers a message, then presents it as the current request.
/// </summary>
/// <returns>A request info wrapping the delivered message.</returns>
protected internal override HttpRequestInfo GetRequestFromContext() {
    var delivered = (IDirectedProtocolMessage)this.AwaitIncomingMessage();
    return new HttpRequestInfo(delivered);
}
/// <summary>
/// Sends a request to the remote channel and blocks until it replies.
/// </summary>
/// <param name="request">The outgoing request; the outgoing filter is applied to it first.</param>
/// <returns>The response delivered by the remote channel, after the incoming filter has run on it.</returns>
protected override IProtocolMessage RequestInternal(IDirectedProtocolMessage request) {
    this.ProcessMessageFilter(request, true);

    // The remote side receives a freshly deserialized copy, never our instance.
    HttpRequestInfo spoofed = this.SpoofHttpMethod(request);
    CoordinatingChannel peer = this.RemoteChannel;
    peer.incomingMessage = spoofed.Message;
    peer.incomingMessageSignal.Set();

    // Wait for the remote side to drop its answer into our in-slot.
    IProtocolMessage reply = this.AwaitIncomingMessage();
    this.ProcessMessageFilter(reply, false);
    return reply;
}
/// <summary>
/// Delivers a response to the remote channel's in-slot and wakes it up.
/// </summary>
/// <param name="response">The response; the outgoing filter runs on it before it is copied.</param>
/// <returns>Always null: no real HTTP response exists in this mock transport.</returns>
protected override UserAgentResponse SendDirectMessageResponse(IProtocolMessage response) {
    this.ProcessMessageFilter(response, true);

    CoordinatingChannel peer = this.RemoteChannel;
    peer.incomingMessage = this.CloneSerializedParts(response);
    peer.incomingMessageSignal.Set();

    return null;
}
/// <summary>
/// Sends an indirect message to the remote channel.
/// </summary>
/// <param name="message">The message to send.</param>
/// <returns>Whatever SendDirectMessageResponse returns (null in this class).</returns>
protected override UserAgentResponse SendIndirectMessage(IDirectedProtocolMessage message) {
    // In this mock transport, direct and indirect messages are the same.
    // SendDirectMessageResponse already applies the outgoing message filter, so it is
    // intentionally not applied here too; doing both ran the filter twice per message.
    return this.SendDirectMessageResponse(message);
}
/// <summary>
/// Extracts the message from a (spoofed) request, running the incoming filter on it.
/// </summary>
/// <param name="request">The request info carrying the message.</param>
/// <returns>The message carried by <paramref name="request"/>.</returns>
protected override IDirectedProtocolMessage ReadFromRequestInternal(HttpRequestInfo request) {
    IDirectedProtocolMessage received = request.Message;
    this.ProcessMessageFilter(received, false);
    return received;
}
/// <summary>
/// Parses a direct response by delegating to the wrapped channel's protected implementation.
/// </summary>
/// <param name="response">The response to read.</param>
/// <returns>The fields returned by the wrapped channel.</returns>
protected override IDictionary<string, string> ReadFromResponseInternal(DirectWebResponse response) {
    return Channel_Accessor.AttachShadow(this.wrappedChannel).ReadFromResponseInternal(response);
}
/// <summary>
/// Runs receive-side verification by delegating to the wrapped channel's protected implementation.
/// </summary>
/// <param name="message">The received message to verify.</param>
protected override void VerifyMessageAfterReceiving(IProtocolMessage message) {
    Channel_Accessor.AttachShadow(this.wrappedChannel).VerifyMessageAfterReceiving(message);
}
/// <summary>
/// Builds pretend HTTP request information around a message so that signing and
/// verification have a request to work against.
/// </summary>
/// <param name="message">The outgoing message to wrap.</param>
/// <returns>A request info whose Message is a serialized copy of <paramref name="message"/>.</returns>
protected virtual HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
    HttpRequestInfo spoofed = new HttpRequestInfo(message);

    // Replace the original instance with a copy so sender and receiver never share it.
    IDirectedProtocolMessage copy = this.CloneSerializedParts(message);
    spoofed.Message = copy;

    return spoofed;
}
/// <summary>
/// Copies a message by serializing it to its fields and deserializing them into a new
/// instance obtained from the remote channel's message factory.
/// </summary>
/// <typeparam name="T">The static type of the message being copied.</typeparam>
/// <param name="message">The message to copy; must not be null.</param>
/// <returns>The new instance, cast to <typeparamref name="T"/>.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when <paramref name="message"/> is neither a request nor a direct response.
/// </exception>
protected virtual T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage {
    ErrorUtilities.VerifyArgumentNotNull(message, "message");

    MessageSerializer serializer = MessageSerializer.Get(message.GetType());
    var serializedFields = serializer.Serialize(message);

    var directed = message as IDirectedProtocolMessage;
    var directResponse = message as IDirectResponseProtocolMessage;

    IProtocolMessage copy;
    if (directed != null && directed.IsRequest()) {
        // Recreate the recipient endpoint only when the original message had one.
        MessageReceivingEndpoint endpoint = directed.Recipient == null
            ? null
            : new MessageReceivingEndpoint(directed.Recipient, directed.HttpMethods);
        copy = this.RemoteChannel.MessageFactory.GetNewRequestMessage(endpoint, serializedFields);
    } else if (directResponse != null && directResponse.IsDirectResponse()) {
        copy = this.RemoteChannel.MessageFactory.GetNewResponseMessage(directResponse.OriginatingRequest, serializedFields);
    } else {
        throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types.");
    }

    ErrorUtilities.VerifyInternal(copy != null, "Message factory did not generate a message instance for " + message.GetType().Name);

    // The factory hands back a blank instance; populate it from the original's fields.
    serializer.Deserialize(serializedFields, copy);
    return (T)copy;
}
/// <summary>
/// Reads a channel's message factory through the test accessor.
/// </summary>
/// <param name="channel">The channel to read from; must not be null.</param>
/// <returns>The channel's message factory.</returns>
private static IMessageFactory GetMessageFactory(Channel channel) {
    ErrorUtilities.VerifyArgumentNotNull(channel, "channel");
    return Channel_Accessor.AttachShadow(channel).MessageFactory;
}
/// <summary>
/// Waits for the remote channel to signal delivery, then takes the message from the in-slot.
/// </summary>
/// <returns>The delivered message; the in-slot is left empty (null).</returns>
private IProtocolMessage AwaitIncomingMessage() {
    // No timeout: this blocks until the remote channel sets the signal.
    this.incomingMessageSignal.WaitOne();

    IProtocolMessage delivered = this.incomingMessage;
    this.incomingMessage = null;
    return delivered;
}
/// <summary>
/// Invokes the outgoing or incoming message filter, whichever applies, if one was supplied.
/// </summary>
/// <param name="message">The message to pass to the filter.</param>
/// <param name="outgoing">True to use the outgoing filter; false to use the incoming filter.</param>
private void ProcessMessageFilter(IProtocolMessage message, bool outgoing) {
    Action<IProtocolMessage> filter = outgoing ? this.outgoingMessageFilter : this.incomingMessageFilter;
    if (filter != null) {
        filter(message);
    }
}