1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "google_apis/gcm/engine/connection_handler_impl.h"
7 #include "base/message_loop/message_loop.h"
8 #include "google/protobuf/io/coded_stream.h"
9 #include "google_apis/gcm/base/mcs_util.h"
10 #include "google_apis/gcm/base/socket_stream.h"
11 #include "google_apis/gcm/protocol/mcs.pb.h"
12 #include "net/base/net_errors.h"
13 #include "net/socket/stream_socket.h"
15 using namespace google::protobuf::io
;
// Number of bytes the MCS version packet occupies on the wire.
const int kVersionPacketLen = 1;
// Number of bytes a message tag packet occupies on the wire.
const int kTagPacketLen = 1;
// Minimum and maximum number of bytes the varint-encoded message size packet
// occupies. A two-byte varint can encode values up to 16383, so larger
// messages presumably cannot be framed with this limit (TODO confirm).
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// The current MCS protocol version.
// TODO(zea): bump to 41 once the server supports it.
const int kMCSVersion = 38;
35 ConnectionHandlerImpl::ConnectionHandlerImpl(
36 base::TimeDelta read_timeout
,
37 const ProtoReceivedCallback
& read_callback
,
38 const ProtoSentCallback
& write_callback
,
39 const ConnectionChangedCallback
& connection_callback
)
40 : read_timeout_(read_timeout
),
41 handshake_complete_(false),
44 read_callback_(read_callback
),
45 write_callback_(write_callback
),
46 connection_callback_(connection_callback
),
47 weak_ptr_factory_(this) {
50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
53 void ConnectionHandlerImpl::Init(
54 const mcs_proto::LoginRequest
& login_request
,
55 net::StreamSocket
* socket
) {
56 DCHECK(!read_callback_
.is_null());
57 DCHECK(!write_callback_
.is_null());
58 DCHECK(!connection_callback_
.is_null());
60 // Invalidate any previously outstanding reads.
61 weak_ptr_factory_
.InvalidateWeakPtrs();
63 handshake_complete_
= false;
67 input_stream_
.reset(new SocketInputStream(socket_
));
68 output_stream_
.reset(new SocketOutputStream(socket_
));
73 bool ConnectionHandlerImpl::CanSendMessage() const {
74 return handshake_complete_
&& output_stream_
.get() &&
75 output_stream_
->GetState() == SocketOutputStream::EMPTY
;
78 void ConnectionHandlerImpl::SendMessage(
79 const google::protobuf::MessageLite
& message
) {
80 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
81 DCHECK(handshake_complete_
);
84 CodedOutputStream
coded_output_stream(output_stream_
.get());
85 DVLOG(1) << "Writing proto of size " << message
.ByteSize();
86 int tag
= GetMCSProtoTag(message
);
88 coded_output_stream
.WriteRaw(&tag
, 1);
89 coded_output_stream
.WriteVarint32(message
.ByteSize());
90 message
.SerializeToCodedStream(&coded_output_stream
);
93 if (output_stream_
->Flush(
94 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
95 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
100 void ConnectionHandlerImpl::Login(
101 const google::protobuf::MessageLite
& login_request
) {
102 DCHECK_EQ(output_stream_
->GetState(), SocketOutputStream::EMPTY
);
104 const char version_byte
[1] = {kMCSVersion
};
105 const char login_request_tag
[1] = {kLoginRequestTag
};
107 CodedOutputStream
coded_output_stream(output_stream_
.get());
108 coded_output_stream
.WriteRaw(version_byte
, 1);
109 coded_output_stream
.WriteRaw(login_request_tag
, 1);
110 coded_output_stream
.WriteVarint32(login_request
.ByteSize());
111 login_request
.SerializeToCodedStream(&coded_output_stream
);
114 if (output_stream_
->Flush(
115 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
116 weak_ptr_factory_
.GetWeakPtr())) != net::ERR_IO_PENDING
) {
117 base::MessageLoop::current()->PostTask(
119 base::Bind(&ConnectionHandlerImpl::OnMessageSent
,
120 weak_ptr_factory_
.GetWeakPtr()));
123 read_timeout_timer_
.Start(FROM_HERE
,
125 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
126 weak_ptr_factory_
.GetWeakPtr()));
127 WaitForData(MCS_VERSION_TAG_AND_SIZE
);
130 void ConnectionHandlerImpl::OnMessageSent() {
131 if (!output_stream_
.get()) {
132 // The connection has already been closed. Just return.
133 DCHECK(!input_stream_
.get());
134 DCHECK(!read_timeout_timer_
.IsRunning());
138 if (output_stream_
->GetState() != SocketOutputStream::EMPTY
) {
139 int last_error
= output_stream_
->last_error();
141 // If the socket stream had an error, plumb it up, else plumb up FAILED.
142 if (last_error
== net::OK
)
143 last_error
= net::ERR_FAILED
;
144 connection_callback_
.Run(last_error
);
148 write_callback_
.Run();
151 void ConnectionHandlerImpl::GetNextMessage() {
152 DCHECK(SocketInputStream::EMPTY
== input_stream_
->GetState() ||
153 SocketInputStream::READY
== input_stream_
->GetState());
157 WaitForData(MCS_TAG_AND_SIZE
);
160 void ConnectionHandlerImpl::WaitForData(ProcessingState state
) {
161 DVLOG(1) << "Waiting for MCS data: state == " << state
;
163 if (!input_stream_
) {
164 // The connection has already been closed. Just return.
165 DCHECK(!output_stream_
.get());
166 DCHECK(!read_timeout_timer_
.IsRunning());
170 if (input_stream_
->GetState() != SocketInputStream::EMPTY
&&
171 input_stream_
->GetState() != SocketInputStream::READY
) {
172 // An error occurred.
173 int last_error
= output_stream_
->last_error();
175 // If the socket stream had an error, plumb it up, else plumb up FAILED.
176 if (last_error
== net::OK
)
177 last_error
= net::ERR_FAILED
;
178 connection_callback_
.Run(last_error
);
182 // Used to determine whether a Socket::Read is necessary.
183 int min_bytes_needed
= 0;
184 // Used to limit the size of the Socket::Read.
185 int max_bytes_needed
= 0;
188 case MCS_VERSION_TAG_AND_SIZE
:
189 min_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMin
;
190 max_bytes_needed
= kVersionPacketLen
+ kTagPacketLen
+ kSizePacketLenMax
;
192 case MCS_TAG_AND_SIZE
:
193 min_bytes_needed
= kTagPacketLen
+ kSizePacketLenMin
;
194 max_bytes_needed
= kTagPacketLen
+ kSizePacketLenMax
;
197 // If in this state, the minimum size packet length must already have been
198 // insufficient, so set both to the max length.
199 min_bytes_needed
= kSizePacketLenMax
;
200 max_bytes_needed
= kSizePacketLenMax
;
202 case MCS_PROTO_BYTES
:
203 read_timeout_timer_
.Reset();
204 // No variability in the message size, set both to the same.
205 min_bytes_needed
= message_size_
;
206 max_bytes_needed
= message_size_
;
211 DCHECK_GE(max_bytes_needed
, min_bytes_needed
);
213 int byte_count
= input_stream_
->UnreadByteCount();
214 if (min_bytes_needed
- byte_count
> 0 &&
215 input_stream_
->Refresh(
216 base::Bind(&ConnectionHandlerImpl::WaitForData
,
217 weak_ptr_factory_
.GetWeakPtr(),
219 max_bytes_needed
- byte_count
) == net::ERR_IO_PENDING
) {
223 // Check for refresh errors.
224 if (input_stream_
->GetState() != SocketInputStream::READY
) {
225 // An error occurred.
226 int last_error
= output_stream_
->last_error();
228 // If the socket stream had an error, plumb it up, else plumb up FAILED.
229 if (last_error
== net::OK
)
230 last_error
= net::ERR_FAILED
;
231 connection_callback_
.Run(last_error
);
235 // Received enough bytes, process them.
236 DVLOG(1) << "Processing MCS data: state == " << state
;
238 case MCS_VERSION_TAG_AND_SIZE
:
241 case MCS_TAG_AND_SIZE
:
247 case MCS_PROTO_BYTES
:
255 void ConnectionHandlerImpl::OnGotVersion() {
258 CodedInputStream
coded_input_stream(input_stream_
.get());
259 coded_input_stream
.ReadRaw(&version
, 1);
261 if (version
< kMCSVersion
) {
262 LOG(ERROR
) << "Invalid GCM version response: " << static_cast<int>(version
);
263 connection_callback_
.Run(net::ERR_FAILED
);
267 input_stream_
->RebuildBuffer();
269 // Process the LoginResponse message tag.
273 void ConnectionHandlerImpl::OnGotMessageTag() {
274 if (input_stream_
->GetState() != SocketInputStream::READY
) {
275 LOG(ERROR
) << "Failed to receive protobuf tag.";
276 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
281 CodedInputStream
coded_input_stream(input_stream_
.get());
282 coded_input_stream
.ReadRaw(&message_tag_
, 1);
285 DVLOG(1) << "Received proto of type "
286 << static_cast<unsigned int>(message_tag_
);
288 if (!read_timeout_timer_
.IsRunning()) {
289 read_timeout_timer_
.Start(FROM_HERE
,
291 base::Bind(&ConnectionHandlerImpl::OnTimeout
,
292 weak_ptr_factory_
.GetWeakPtr()));
297 void ConnectionHandlerImpl::OnGotMessageSize() {
298 if (input_stream_
->GetState() != SocketInputStream::READY
) {
299 LOG(ERROR
) << "Failed to receive message size.";
300 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
304 bool need_another_byte
= false;
305 int prev_byte_count
= input_stream_
->ByteCount();
307 CodedInputStream
coded_input_stream(input_stream_
.get());
308 if (!coded_input_stream
.ReadVarint32(&message_size_
))
309 need_another_byte
= true;
312 if (need_another_byte
) {
313 DVLOG(1) << "Expecting another message size byte.";
314 if (prev_byte_count
>= kSizePacketLenMax
) {
315 // Already had enough bytes, something else went wrong.
316 LOG(ERROR
) << "Failed to process message size.";
317 read_callback_
.Run(scoped_ptr
<google::protobuf::MessageLite
>());
320 // Back up by the amount read (should always be 1 byte).
321 int bytes_read
= prev_byte_count
- input_stream_
->ByteCount();
322 DCHECK_EQ(bytes_read
, 1);
323 input_stream_
->BackUp(bytes_read
);
324 WaitForData(MCS_FULL_SIZE
);
328 DVLOG(1) << "Proto size: " << message_size_
;
330 if (message_size_
> 0)
331 WaitForData(MCS_PROTO_BYTES
);
336 void ConnectionHandlerImpl::OnGotMessageBytes() {
337 read_timeout_timer_
.Stop();
338 scoped_ptr
<google::protobuf::MessageLite
> protobuf(
339 BuildProtobufFromTag(message_tag_
));
340 // Messages with no content are valid; just use the default protobuf for
342 if (protobuf
.get() && message_size_
== 0) {
343 base::MessageLoop::current()->PostTask(
345 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
346 weak_ptr_factory_
.GetWeakPtr()));
347 read_callback_
.Run(protobuf
.Pass());
351 if (!protobuf
.get() ||
352 input_stream_
->GetState() != SocketInputStream::READY
) {
353 LOG(ERROR
) << "Failed to extract protobuf bytes of type "
354 << static_cast<unsigned int>(message_tag_
);
355 protobuf
.reset(); // Return a null pointer to denote an error.
356 read_callback_
.Run(protobuf
.Pass());
361 CodedInputStream
coded_input_stream(input_stream_
.get());
362 if (!protobuf
->ParsePartialFromCodedStream(&coded_input_stream
)) {
363 NOTREACHED() << "Unable to parse GCM message of type "
364 << static_cast<unsigned int>(message_tag_
);
365 protobuf
.reset(); // Return a null pointer to denote an error.
366 read_callback_
.Run(protobuf
.Pass());
371 input_stream_
->RebuildBuffer();
372 base::MessageLoop::current()->PostTask(
374 base::Bind(&ConnectionHandlerImpl::GetNextMessage
,
375 weak_ptr_factory_
.GetWeakPtr()));
376 if (message_tag_
== kLoginResponseTag
) {
377 if (handshake_complete_
) {
378 LOG(ERROR
) << "Unexpected login response.";
380 handshake_complete_
= true;
381 DVLOG(1) << "GCM Handshake complete.";
382 connection_callback_
.Run(net::OK
);
385 read_callback_
.Run(protobuf
.Pass());
388 void ConnectionHandlerImpl::OnTimeout() {
389 LOG(ERROR
) << "Timed out waiting for GCM Protocol buffer.";
391 connection_callback_
.Run(net::ERR_TIMED_OUT
);
394 void ConnectionHandlerImpl::CloseConnection() {
395 DVLOG(1) << "Closing connection.";
396 read_timeout_timer_
.Stop();
397 socket_
->Disconnect();
398 input_stream_
.reset();
399 output_stream_
.reset();
400 weak_ptr_factory_
.InvalidateWeakPtrs();