1
//-----------------------------------------------------------------------
2 // <copyright file="CoordinatingChannel.cs" company="Andrew Arnott">
3 // Copyright (c) Andrew Arnott. All rights reserved.
5 //-----------------------------------------------------------------------
7 namespace DotNetOpenAuth
.Test
.Mocks
{
9 using System
.Collections
.Generic
;
12 using System
.Threading
;
13 using DotNetOpenAuth
.Messaging
;
15 internal class CoordinatingChannel
: Channel
{
/// <summary>
/// The real channel whose message factory, binding elements and non-public
/// behavior this mock channel reuses.
/// </summary>
private readonly Channel wrappedChannel;

/// <summary>
/// Set by the remote <see cref="CoordinatingChannel"/> after it writes a message
/// into <see cref="incomingMessage"/>; waited on before that slot is read.
/// </summary>
private readonly EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

/// <summary>
/// The in-slot that the remote channel writes to before setting
/// <see cref="incomingMessageSignal"/>. Not readonly: it is reassigned per message.
/// </summary>
private IProtocolMessage incomingMessage;

/// <summary>
/// Optional callback for messages in the incoming direction; may be null.
/// </summary>
private readonly Action<IProtocolMessage> incomingMessageFilter;

/// <summary>
/// Optional callback for messages in the outgoing direction; may be null.
/// </summary>
private readonly Action<IProtocolMessage> outgoingMessageFilter;
/// <summary>
/// Initializes a new instance of the <see cref="CoordinatingChannel"/> class.
/// </summary>
/// <param name="wrappedChannel">The channel whose message factory and binding elements are reused. Must not be null.</param>
/// <param name="incomingMessageFilter">Optional callback for incoming messages; may be null.</param>
/// <param name="outgoingMessageFilter">Optional callback for outgoing messages; may be null.</param>
internal CoordinatingChannel(Channel wrappedChannel, Action<IProtocolMessage> incomingMessageFilter, Action<IProtocolMessage> outgoingMessageFilter)
	: base(GetMessageFactory(VerifyWrappedChannel(wrappedChannel)), wrappedChannel.BindingElements.ToArray()) {
	// The null check has to live in the base initializer: by the time the constructor body
	// runs, wrappedChannel has already been dereferenced (BindingElements) and passed to
	// GetMessageFactory, so a check here could never fire.
	this.wrappedChannel = wrappedChannel;
	this.incomingMessageFilter = incomingMessageFilter;
	this.outgoingMessageFilter = outgoingMessageFilter;

	// Preserve any customized binding element ordering.
	this.CustomizeBindingElementOrder(this.wrappedChannel.OutgoingBindingElements, this.wrappedChannel.IncomingBindingElements);
}

/// <summary>
/// Validates the constructor's <c>wrappedChannel</c> argument before anything dereferences it,
/// so that a null is reported under the correct parameter name.
/// </summary>
/// <param name="wrappedChannel">The argument to validate.</param>
/// <returns>The same <paramref name="wrappedChannel"/> instance.</returns>
private static Channel VerifyWrappedChannel(Channel wrappedChannel) {
	ErrorUtilities.VerifyArgumentNotNull(wrappedChannel, "wrappedChannel");
	return wrappedChannel;
}
/// <summary>
/// Backing store for <see cref="RemoteChannel"/>.
/// </summary>
private CoordinatingChannel remoteChannel;

/// <summary>
/// Gets or sets the coordinating channel used by the other party.
/// Messages this channel sends are written into that channel's in-slot.
/// </summary>
internal CoordinatingChannel RemoteChannel {
	get { return this.remoteChannel; }
	set { this.remoteChannel = value; }
}
/// <summary>
/// Replays the specified message as if it were received again: a fresh serialized copy of it
/// is pushed through the receive-side verification.
/// </summary>
/// <param name="message">The message to replay.</param>
internal void Replay(IProtocolMessage message) {
	IProtocolMessage replica = this.CloneSerializedParts(message);
	this.VerifyMessageAfterReceiving(replica);
}
/// <summary>
/// Blocks until the remote channel delivers a message, then presents it as an incoming request.
/// </summary>
/// <returns>An <see cref="HttpRequestInfo"/> wrapping the delivered message.</returns>
/// <remarks>
/// The delivered message is cast to <see cref="IDirectedProtocolMessage"/>; a message that does
/// not implement it causes an <see cref="InvalidCastException"/>.
/// </remarks>
protected internal override HttpRequestInfo GetRequestFromContext() {
	var delivered = (IDirectedProtocolMessage)this.AwaitIncomingMessage();
	return new HttpRequestInfo(delivered);
}
/// <summary>
/// Sends a request to <see cref="RemoteChannel"/> and blocks until that channel posts a response back.
/// </summary>
/// <param name="request">The outgoing request.</param>
/// <returns>The response message delivered into this channel's in-slot.</returns>
protected override IProtocolMessage RequestInternal(IDirectedProtocolMessage request) {
	this.ProcessMessageFilter(request, true);
	HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);

	// Drop the outgoing message in the other channel's in-slot and let them know it's there.
	this.RemoteChannel.incomingMessage = requestInfo.Message;
	this.RemoteChannel.incomingMessageSignal.Set();

	// Now wait for a response...
	IProtocolMessage response = this.AwaitIncomingMessage();
	this.ProcessMessageFilter(response, false);

	// Previously missing: without this the method never returned the awaited response.
	return response;
}
/// <summary>
/// Delivers a direct response into the remote channel's in-slot instead of writing an HTTP response.
/// </summary>
/// <param name="response">The response message to deliver.</param>
/// <returns>
/// Always <c>null</c>; this in-memory transport builds no <see cref="UserAgentResponse"/>.
/// </returns>
protected override UserAgentResponse SendDirectMessageResponse(IProtocolMessage response) {
	this.ProcessMessageFilter(response, true);

	// Give the remote party a serialized/deserialized copy rather than this very instance, then wake it.
	this.RemoteChannel.incomingMessage = CloneSerializedParts(response);
	this.RemoteChannel.incomingMessageSignal.Set();

	// NOTE(review): the method had no return statement; null is returned because nothing here
	// produces a UserAgentResponse — confirm callers tolerate a null result.
	return null;
}
/// <summary>
/// Sends an indirect message. In this mock transport, direct and indirect messages travel the
/// same way, so this forwards to <see cref="SendDirectMessageResponse"/>.
/// </summary>
/// <param name="message">The message to send.</param>
/// <returns>Whatever <see cref="SendDirectMessageResponse"/> returns.</returns>
protected override UserAgentResponse SendIndirectMessage(IDirectedProtocolMessage message) {
	// SendDirectMessageResponse already applies the outgoing filter. Calling ProcessMessageFilter
	// here as well made the outgoing filter run twice for every indirect message.
	return this.SendDirectMessageResponse(message);
}
/// <summary>
/// Extracts the message carried by a (spoofed) request after passing it through
/// <c>ProcessMessageFilter</c> as an incoming (<c>outgoing: false</c>) message.
/// </summary>
/// <param name="request">The request whose message is read.</param>
/// <returns>The message stored on <paramref name="request"/>.</returns>
protected override IDirectedProtocolMessage ReadFromRequestInternal(HttpRequestInfo request) {
	var carried = request.Message;
	this.ProcessMessageFilter(carried, false);
	return carried;
}
/// <summary>
/// Delegates response parsing to the wrapped channel's own (non-public) implementation,
/// reached through <c>Channel_Accessor</c>.
/// </summary>
/// <param name="response">The direct web response to read.</param>
/// <returns>The key/value parts the wrapped channel extracts from <paramref name="response"/>.</returns>
protected override IDictionary<string, string> ReadFromResponseInternal(DirectWebResponse response) {
	return Channel_Accessor.AttachShadow(this.wrappedChannel).ReadFromResponseInternal(response);
}
/// <summary>
/// Runs the wrapped channel's own receive-side verification on <paramref name="message"/>,
/// reached through <c>Channel_Accessor</c>.
/// </summary>
/// <param name="message">The received message to verify.</param>
protected override void VerifyMessageAfterReceiving(IProtocolMessage message) {
	Channel_Accessor.AttachShadow(this.wrappedChannel).VerifyMessageAfterReceiving(message);
}
/// <summary>
/// Spoofs HTTP request information for signing/verification purposes.
/// </summary>
/// <param name="message">The message to add a pretend HTTP method to.</param>
/// <returns>A spoofed HttpRequestInfo that wraps the new message.</returns>
protected virtual HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
	HttpRequestInfo requestInfo = new HttpRequestInfo(message);

	// Swap in a serialize/deserialize copy built by the remote channel's message factory,
	// so the receiver gets its own instance rather than the sender's.
	requestInfo.Message = this.CloneSerializedParts(message);

	// Previously missing: the spoofed request was built but never returned.
	return requestInfo;
}
/// <summary>
/// Creates a copy of a message by serializing it and deserializing the resulting fields into a
/// new instance obtained from the remote channel's message factory.
/// </summary>
/// <typeparam name="T">The static type of the message being cloned.</typeparam>
/// <param name="message">The message to clone. Must not be null.</param>
/// <returns>A new message instance populated with the serialized parts of <paramref name="message"/>.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when <paramref name="message"/> is neither a directed request nor a direct response.
/// </exception>
protected virtual T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage {
	ErrorUtilities.VerifyArgumentNotNull(message, "message");

	IProtocolMessage clonedMessage;
	MessageSerializer serializer = MessageSerializer.Get(message.GetType());
	var fields = serializer.Serialize(message);

	var directedMessage = message as IDirectedProtocolMessage;
	var directResponse = message as IDirectResponseProtocolMessage;
	if (directedMessage != null && directedMessage.IsRequest()) {
		// A request without a recipient is still handed to the factory, with a null endpoint.
		MessageReceivingEndpoint recipient = null;
		if (directedMessage.Recipient != null) {
			recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods);
		}

		clonedMessage = this.RemoteChannel.MessageFactory.GetNewRequestMessage(recipient, fields);
	} else if (directResponse != null && directResponse.IsDirectResponse()) {
		clonedMessage = this.RemoteChannel.MessageFactory.GetNewResponseMessage(directResponse.OriginatingRequest, fields);
	} else {
		// Only reached when neither branch above applies; without this else the throw was unconditional.
		throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types.");
	}

	ErrorUtilities.VerifyInternal(clonedMessage != null, "Message factory did not generate a message instance for " + message.GetType().Name);

	// Fill the cloned message with data.
	serializer.Deserialize(fields, clonedMessage);

	return (T)clonedMessage;
}
/// <summary>
/// Reads the (non-public) message factory of <paramref name="channel"/> through
/// <c>Channel_Accessor</c>, for use in the base constructor call.
/// </summary>
/// <param name="channel">The channel whose factory is retrieved. Must not be null.</param>
/// <returns>The channel's message factory.</returns>
private static IMessageFactory GetMessageFactory(Channel channel) {
	ErrorUtilities.VerifyArgumentNotNull(channel, "channel");
	return Channel_Accessor.AttachShadow(channel).MessageFactory;
}
/// <summary>
/// Blocks until the remote channel signals that a message is waiting in this channel's in-slot,
/// then takes the message and empties the slot.
/// </summary>
/// <returns>The message deposited by the remote channel.</returns>
private IProtocolMessage AwaitIncomingMessage() {
	this.incomingMessageSignal.WaitOne();
	IProtocolMessage response = this.incomingMessage;

	// Empty the slot so each delivered message is consumed at most once.
	this.incomingMessage = null;

	// Previously missing: the dequeued message was never returned to the caller.
	return response;
}
// Applies the optional per-direction message filter(s) to a message.
// Callers pass outgoing = true when sending and outgoing = false when receiving.
// Each filter is skipped when it is null.
// NOTE(review): in the text as it appears here, 'outgoing' is never read, so both filters would
// run for every message. The call sites suggest the intended dispatch is
// "if (outgoing) outgoing filter, else incoming filter". Confirm this and restore the
// if/else once the full method (which continues past this view) is visible.
148 private void ProcessMessageFilter(IProtocolMessage message
, bool outgoing
) {
150 if (this.outgoingMessageFilter
!= null) {
151 this.outgoingMessageFilter(message
);
154 if (this.incomingMessageFilter
!= null) {
155 this.incomingMessageFilter(message
);