WebKit roll 90808:90810
[chromium-blink-merge.git] / net / websockets / websocket_handshake_handler.cc
blob4a424ca91c6eab3ba9eb24c6e64a9b0d05d77e36
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_handshake_handler.h"
7 #include "base/base64.h"
8 #include "base/md5.h"
9 #include "base/sha1.h"
10 #include "base/string_number_conversions.h"
11 #include "base/string_piece.h"
12 #include "base/string_util.h"
13 #include "googleurl/src/gurl.h"
14 #include "net/http/http_response_headers.h"
15 #include "net/http/http_util.h"
17 namespace {
// First version that introduced the new WebSocket handshake, which does not
// require sending "key3" or "response key" data after the headers.
const int kMinVersionOfHybiNewHandshake = 4;

// Number of bytes of "key3" data expected after the request headers of a
// handshake older than |kMinVersionOfHybiNewHandshake|.
const size_t kRequestKey3Size = 8U;

// Number of bytes of "response key" data that follow the response headers of
// a handshake older than |kMinVersionOfHybiNewHandshake|.
const size_t kResponseKeySize = 16U;

// Appended to the challenge when the value of Sec-WebSocket-Accept is
// calculated.
const char* const kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
29 void ParseHandshakeHeader(
30 const char* handshake_message, int len,
31 std::string* status_line,
32 std::string* headers) {
33 size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n");
34 if (i == base::StringPiece::npos) {
35 *status_line = std::string(handshake_message, len);
36 *headers = "";
37 return;
39 // |status_line| includes \r\n.
40 *status_line = std::string(handshake_message, i + 2);
42 int header_len = len - (i + 2) - 2;
43 if (header_len > 0) {
44 // |handshake_message| includes tailing \r\n\r\n.
45 // |headers| doesn't include 2nd \r\n.
46 *headers = std::string(handshake_message + i + 2, header_len);
47 } else {
48 *headers = "";
52 void FetchHeaders(const std::string& headers,
53 const char* const headers_to_get[],
54 size_t headers_to_get_len,
55 std::vector<std::string>* values) {
56 net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
57 while (iter.GetNext()) {
58 for (size_t i = 0; i < headers_to_get_len; i++) {
59 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
60 headers_to_get[i])) {
61 values->push_back(iter.values());
67 bool GetHeaderName(std::string::const_iterator line_begin,
68 std::string::const_iterator line_end,
69 std::string::const_iterator* name_begin,
70 std::string::const_iterator* name_end) {
71 std::string::const_iterator colon = std::find(line_begin, line_end, ':');
72 if (colon == line_end) {
73 return false;
75 *name_begin = line_begin;
76 *name_end = colon;
77 if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
78 return false;
79 net::HttpUtil::TrimLWS(name_begin, name_end);
80 return true;
83 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
84 // is, lines that are not formatted as "<name>: <value>\r\n".
85 std::string FilterHeaders(
86 const std::string& headers,
87 const char* const headers_to_remove[],
88 size_t headers_to_remove_len) {
89 std::string filtered_headers;
91 StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
92 while (lines.GetNext()) {
93 std::string::const_iterator line_begin = lines.token_begin();
94 std::string::const_iterator line_end = lines.token_end();
95 std::string::const_iterator name_begin;
96 std::string::const_iterator name_end;
97 bool should_remove = false;
98 if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
99 for (size_t i = 0; i < headers_to_remove_len; ++i) {
100 if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
101 should_remove = true;
102 break;
106 if (!should_remove) {
107 filtered_headers.append(line_begin, line_end);
108 filtered_headers.append("\r\n");
111 return filtered_headers;
114 // Gets a key number from |key| and appends the number to |challenge|.
115 // The key number (/part_N/) is extracted as step 4.-8. in
116 // 5.2. Sending the server's opening handshake of
117 // http://www.ietf.org/id/draft-ietf-hybi-thewebsocketprotocol-00.txt
118 void GetKeyNumber(const std::string& key, std::string* challenge) {
119 uint32 key_number = 0;
120 uint32 spaces = 0;
121 for (size_t i = 0; i < key.size(); ++i) {
122 if (isdigit(key[i])) {
123 // key_number should not overflow. (it comes from
124 // WebCore/websockets/WebSocketHandshake.cpp).
125 key_number = key_number * 10 + key[i] - '0';
126 } else if (key[i] == ' ') {
127 ++spaces;
130 // spaces should not be zero in valid handshake request.
131 if (spaces == 0)
132 return;
133 key_number /= spaces;
135 char part[4];
136 for (int i = 0; i < 4; i++) {
137 part[3 - i] = key_number & 0xFF;
138 key_number >>= 8;
140 challenge->append(part, 4);
143 int GetVersionFromRequest(const std::string& request_headers) {
144 std::vector<std::string> values;
145 const char* const headers_to_get[2] = { "sec-websocket-version",
146 "sec-websocket-draft" };
147 FetchHeaders(request_headers, headers_to_get, 2, &values);
148 DCHECK_LE(values.size(), 1U);
149 if (values.empty())
150 return 0;
151 int version;
152 bool conversion_success = base::StringToInt(values[0], &version);
153 DCHECK(conversion_success);
154 DCHECK_GE(version, 1);
155 return version;
158 } // anonymous namespace
160 namespace net {
162 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
163 : original_length_(0),
164 raw_length_(0),
165 protocol_version_(-1) {}
167 bool WebSocketHandshakeRequestHandler::ParseRequest(
168 const char* data, int length) {
169 DCHECK_GT(length, 0);
170 std::string input(data, length);
171 int input_header_length =
172 HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
173 if (input_header_length <= 0)
174 return false;
176 ParseHandshakeHeader(input.data(),
177 input_header_length,
178 &status_line_,
179 &headers_);
181 // WebSocket protocol drafts hixie-76 (hybi-00), hybi-01, 02 and 03 require
182 // the clients to send key3 after the handshake request header fields.
183 // Hybi-04 and later drafts, on the other hand, no longer have key3
184 // in the handshake format.
185 protocol_version_ = GetVersionFromRequest(headers_);
186 DCHECK_GE(protocol_version_, 0);
187 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
188 key3_ = "";
189 original_length_ = input_header_length;
190 return true;
193 if (input_header_length + kRequestKey3Size > input.size())
194 return false;
196 // Assumes WebKit doesn't send any data after handshake request message
197 // until handshake is finished.
198 // Thus, |key3_| is part of handshake message, and not in part
199 // of WebSocket frame stream.
200 DCHECK_EQ(kRequestKey3Size, input.size() - input_header_length);
201 key3_ = std::string(input.data() + input_header_length,
202 input.size() - input_header_length);
203 original_length_ = input.size();
204 return true;
207 size_t WebSocketHandshakeRequestHandler::original_length() const {
208 return original_length_;
211 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
212 const std::string& name, const std::string& value) {
213 DCHECK(!headers_.empty());
214 HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
217 void WebSocketHandshakeRequestHandler::RemoveHeaders(
218 const char* const headers_to_remove[],
219 size_t headers_to_remove_len) {
220 DCHECK(!headers_.empty());
221 headers_ = FilterHeaders(
222 headers_, headers_to_remove, headers_to_remove_len);
225 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
226 const GURL& url, std::string* challenge) {
227 HttpRequestInfo request_info;
228 request_info.url = url;
229 size_t method_end = base::StringPiece(status_line_).find_first_of(" ");
230 if (method_end != base::StringPiece::npos)
231 request_info.method = std::string(status_line_.data(), method_end);
233 request_info.extra_headers.Clear();
234 request_info.extra_headers.AddHeadersFromString(headers_);
236 request_info.extra_headers.RemoveHeader("Upgrade");
237 request_info.extra_headers.RemoveHeader("Connection");
239 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
240 std::string key;
241 bool header_present =
242 request_info.extra_headers.GetHeader("Sec-WebSocket-Key", &key);
243 DCHECK(header_present);
244 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key");
245 *challenge = key;
246 } else {
247 challenge->clear();
248 std::string key;
249 bool header_present =
250 request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key);
251 DCHECK(header_present);
252 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1");
253 GetKeyNumber(key, challenge);
255 header_present =
256 request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key);
257 DCHECK(header_present);
258 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2");
259 GetKeyNumber(key, challenge);
261 challenge->append(key3_);
264 return request_info;
267 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
268 const GURL& url, spdy::SpdyHeaderBlock* headers, std::string* challenge) {
269 // We don't set "method" and "version". These are fixed value in WebSocket
270 // protocol.
271 (*headers)["url"] = url.spec();
273 std::string new_key; // For protocols hybi-04 and newer.
274 std::string old_key1; // For protocols hybi-03 and older.
275 std::string old_key2; // Ditto.
276 HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
277 while (iter.GetNext()) {
278 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
279 "connection")) {
280 // Ignore "Connection" header.
281 continue;
282 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
283 "upgrade")) {
284 // Ignore "Upgrade" header.
285 continue;
286 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
287 "sec-websocket-key1")) {
288 // Only used for generating challenge.
289 old_key1 = iter.values();
290 continue;
291 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
292 "sec-websocket-key2")) {
293 // Only used for generating challenge.
294 old_key2 = iter.values();
295 continue;
296 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
297 "sec-websocket-key")) {
298 // Only used for generating challenge.
299 new_key = iter.values();
300 continue;
302 // Others should be sent out to |headers|.
303 std::string name = StringToLowerASCII(iter.name());
304 spdy::SpdyHeaderBlock::iterator found = headers->find(name);
305 if (found == headers->end()) {
306 (*headers)[name] = iter.values();
307 } else {
308 // For now, websocket doesn't use multiple headers, but follows to http.
309 found->second.append(1, '\0'); // +=() doesn't append 0's
310 found->second.append(iter.values());
314 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
315 DVLOG_IF(1, !old_key1.empty())
316 << "Server sent unexpected Sec-WebSocket-Key1 header.";
317 DVLOG_IF(1, !old_key2.empty())
318 << "Server sent unexpected Sec-WebSocket-Key2 header.";
319 *challenge = new_key;
320 } else {
321 DVLOG_IF(1, !new_key.empty())
322 << "Server sent unexpected Sec-WebSocket-Key header.";
323 challenge->clear();
324 GetKeyNumber(old_key1, challenge);
325 GetKeyNumber(old_key2, challenge);
326 challenge->append(key3_);
329 return true;
332 std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
333 DCHECK(!status_line_.empty());
334 DCHECK(!headers_.empty());
335 // The following works on both hybi-04 and older handshake,
336 // because |key3_| is guaranteed to be empty if the handshake was hybi-04's.
337 std::string raw_request = status_line_ + headers_ + "\r\n" + key3_;
338 raw_length_ = raw_request.size();
339 return raw_request;
342 size_t WebSocketHandshakeRequestHandler::raw_length() const {
343 DCHECK_GT(raw_length_, 0);
344 return raw_length_;
347 int WebSocketHandshakeRequestHandler::protocol_version() const {
348 DCHECK_GE(protocol_version_, 0);
349 return protocol_version_;
352 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
353 : original_header_length_(0),
354 protocol_version_(0) {}
356 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
358 int WebSocketHandshakeResponseHandler::protocol_version() const {
359 DCHECK_GE(protocol_version_, 0);
360 return protocol_version_;
363 void WebSocketHandshakeResponseHandler::set_protocol_version(
364 int protocol_version) {
365 DCHECK_GE(protocol_version, 0);
366 protocol_version_ = protocol_version;
369 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
370 const char* data, int length) {
371 DCHECK_GT(length, 0);
372 if (HasResponse()) {
373 DCHECK(!status_line_.empty());
374 DCHECK(!headers_.empty());
375 DCHECK_EQ(GetResponseKeySize(), key_.size());
376 return 0;
379 size_t old_original_length = original_.size();
381 original_.append(data, length);
382 // TODO(ukai): fail fast when response gives wrong status code.
383 original_header_length_ = HttpUtil::LocateEndOfHeaders(
384 original_.data(), original_.size(), 0);
385 if (!HasResponse())
386 return length;
388 ParseHandshakeHeader(original_.data(),
389 original_header_length_,
390 &status_line_,
391 &headers_);
392 int header_size = status_line_.size() + headers_.size();
393 DCHECK_GE(original_header_length_, header_size);
394 header_separator_ = std::string(original_.data() + header_size,
395 original_header_length_ - header_size);
396 key_ = std::string(original_.data() + original_header_length_,
397 GetResponseKeySize());
398 return original_header_length_ + GetResponseKeySize() - old_original_length;
401 bool WebSocketHandshakeResponseHandler::HasResponse() const {
402 return original_header_length_ > 0 &&
403 original_header_length_ + GetResponseKeySize() <= original_.size();
406 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
407 const HttpResponseInfo& response_info,
408 const std::string& challenge) {
409 if (!response_info.headers.get())
410 return false;
412 std::string response_message;
413 response_message = response_info.headers->GetStatusLine();
414 response_message += "\r\n";
415 if (protocol_version_ >= kMinVersionOfHybiNewHandshake)
416 response_message += "Upgrade: websocket\r\n";
417 else
418 response_message += "Upgrade: WebSocket\r\n";
419 response_message += "Connection: Upgrade\r\n";
421 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
422 std::string hash = base::SHA1HashString(challenge + kWebSocketGuid);
423 std::string websocket_accept;
424 bool encode_success = base::Base64Encode(hash, &websocket_accept);
425 DCHECK(encode_success);
426 response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n";
429 void* iter = NULL;
430 std::string name;
431 std::string value;
432 while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
433 response_message += name + ": " + value + "\r\n";
435 response_message += "\r\n";
437 if (protocol_version_ < kMinVersionOfHybiNewHandshake) {
438 MD5Digest digest;
439 MD5Sum(challenge.data(), challenge.size(), &digest);
441 const char* digest_data = reinterpret_cast<char*>(digest.a);
442 response_message.append(digest_data, sizeof(digest.a));
445 return ParseRawResponse(response_message.data(),
446 response_message.size()) == response_message.size();
449 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
450 const spdy::SpdyHeaderBlock& headers,
451 const std::string& challenge) {
452 std::string response_message;
453 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
454 response_message = "HTTP/1.1 101 Switching Protocols\r\n";
455 response_message += "Upgrade: websocket\r\n";
456 } else {
457 response_message = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n";
458 response_message += "Upgrade: WebSocket\r\n";
460 response_message += "Connection: Upgrade\r\n";
462 if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
463 std::string hash = base::SHA1HashString(challenge + kWebSocketGuid);
464 std::string websocket_accept;
465 bool encode_success = base::Base64Encode(hash, &websocket_accept);
466 DCHECK(encode_success);
467 response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n";
470 for (spdy::SpdyHeaderBlock::const_iterator iter = headers.begin();
471 iter != headers.end();
472 ++iter) {
473 // For each value, if the server sends a NUL-separated list of values,
474 // we separate that back out into individual headers for each value
475 // in the list.
476 const std::string& value = iter->second;
477 size_t start = 0;
478 size_t end = 0;
479 do {
480 end = value.find('\0', start);
481 std::string tval;
482 if (end != std::string::npos)
483 tval = value.substr(start, (end - start));
484 else
485 tval = value.substr(start);
486 response_message += iter->first + ": " + tval + "\r\n";
487 start = end + 1;
488 } while (end != std::string::npos);
490 response_message += "\r\n";
492 if (protocol_version_ < kMinVersionOfHybiNewHandshake) {
493 MD5Digest digest;
494 MD5Sum(challenge.data(), challenge.size(), &digest);
496 const char* digest_data = reinterpret_cast<char*>(digest.a);
497 response_message.append(digest_data, sizeof(digest.a));
500 return ParseRawResponse(response_message.data(),
501 response_message.size()) == response_message.size();
504 void WebSocketHandshakeResponseHandler::GetHeaders(
505 const char* const headers_to_get[],
506 size_t headers_to_get_len,
507 std::vector<std::string>* values) {
508 DCHECK(HasResponse());
509 DCHECK(!status_line_.empty());
510 DCHECK(!headers_.empty());
511 DCHECK_EQ(GetResponseKeySize(), key_.size());
513 FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
516 void WebSocketHandshakeResponseHandler::RemoveHeaders(
517 const char* const headers_to_remove[],
518 size_t headers_to_remove_len) {
519 DCHECK(HasResponse());
520 DCHECK(!status_line_.empty());
521 DCHECK(!headers_.empty());
522 DCHECK_EQ(GetResponseKeySize(), key_.size());
524 headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
527 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
528 DCHECK(HasResponse());
529 return std::string(original_.data(),
530 original_header_length_ + GetResponseKeySize());
533 std::string WebSocketHandshakeResponseHandler::GetResponse() {
534 DCHECK(HasResponse());
535 DCHECK(!status_line_.empty());
536 // headers_ might be empty for wrong response from server.
537 DCHECK_EQ(GetResponseKeySize(), key_.size());
539 return status_line_ + headers_ + header_separator_ + key_;
542 size_t WebSocketHandshakeResponseHandler::GetResponseKeySize() const {
543 if (protocol_version_ >= kMinVersionOfHybiNewHandshake)
544 return 0;
545 return kResponseKeySize;
548 } // namespace net