// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/server/web_socket.h"

#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "base/base64.h"
#include "base/logging.h"
#include "base/md5.h"
#include "base/rand_util.h"
#include "base/sha1.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/sys_byteorder.h"
#include "net/server/http_connection.h"
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
25 static uint32
WebSocketKeyFingerprint(const std::string
& str
) {
27 const char* p_char
= str
.c_str();
28 int length
= str
.length();
30 for (int i
= 0; i
< length
; ++i
) {
31 if (p_char
[i
] >= '0' && p_char
[i
] <= '9')
32 result
.append(&p_char
[i
], 1);
33 else if (p_char
[i
] == ' ')
39 if (!base::StringToInt64(result
, &number
))
41 return base::HostToNet32(static_cast<uint32
>(number
/ spaces
));
44 class WebSocketHixie76
: public net::WebSocket
{
46 static net::WebSocket
* Create(HttpConnection
* connection
,
47 const HttpServerRequestInfo
& request
,
49 if (connection
->recv_data().length() < *pos
+ kWebSocketHandshakeBodyLen
)
51 return new WebSocketHixie76(connection
, request
, pos
);
54 virtual void Accept(const HttpServerRequestInfo
& request
) OVERRIDE
{
55 std::string key1
= request
.GetHeaderValue("sec-websocket-key1");
56 std::string key2
= request
.GetHeaderValue("sec-websocket-key2");
58 uint32 fp1
= WebSocketKeyFingerprint(key1
);
59 uint32 fp2
= WebSocketKeyFingerprint(key2
);
62 memcpy(data
, &fp1
, 4);
63 memcpy(data
+ 4, &fp2
, 4);
64 memcpy(data
+ 8, &key3_
[0], 8);
66 base::MD5Digest digest
;
67 base::MD5Sum(data
, 16, &digest
);
69 std::string origin
= request
.GetHeaderValue("origin");
70 std::string host
= request
.GetHeaderValue("host");
71 std::string location
= "ws://" + host
+ request
.path
;
72 connection_
->Send(base::StringPrintf(
73 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
74 "Upgrade: WebSocket\r\n"
75 "Connection: Upgrade\r\n"
76 "Sec-WebSocket-Origin: %s\r\n"
77 "Sec-WebSocket-Location: %s\r\n"
81 connection_
->Send(reinterpret_cast<char*>(digest
.a
), 16);
84 virtual ParseResult
Read(std::string
* message
) OVERRIDE
{
86 const std::string
& data
= connection_
->recv_data();
90 size_t pos
= data
.find('\377', 1);
91 if (pos
== std::string::npos
)
92 return FRAME_INCOMPLETE
;
94 std::string
buffer(data
.begin() + 1, data
.begin() + pos
);
95 message
->swap(buffer
);
96 connection_
->Shift(pos
+ 1);
101 virtual void Send(const std::string
& message
) OVERRIDE
{
102 char message_start
= 0;
103 char message_end
= -1;
104 connection_
->Send(&message_start
, 1);
105 connection_
->Send(message
);
106 connection_
->Send(&message_end
, 1);
110 static const int kWebSocketHandshakeBodyLen
;
112 WebSocketHixie76(HttpConnection
* connection
,
113 const HttpServerRequestInfo
& request
,
114 size_t* pos
) : WebSocket(connection
) {
115 std::string key1
= request
.GetHeaderValue("sec-websocket-key1");
116 std::string key2
= request
.GetHeaderValue("sec-websocket-key2");
119 connection
->Send(HttpServerResponseInfo::CreateFor500(
120 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't "
126 connection
->Send(HttpServerResponseInfo::CreateFor500(
127 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't "
132 key3_
= connection
->recv_data().substr(
134 *pos
+ kWebSocketHandshakeBodyLen
);
135 *pos
+= kWebSocketHandshakeBodyLen
;
140 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76
);
143 const int WebSocketHixie76::kWebSocketHandshakeBodyLen
= 8;
// Constants for the hybi-17 (RFC 6455) frame format.

// Frame opcode, stored in the low four bits of the first header byte.
typedef int OpCode;
const OpCode kOpCodeContinuation = 0x0;
const OpCode kOpCodeText = 0x1;
const OpCode kOpCodeBinary = 0x2;
const OpCode kOpCodeClose = 0x8;
const OpCode kOpCodePing = 0x9;
const OpCode kOpCodePong = 0xA;

// Bits of the first header byte.
const unsigned char kFinalBit = 0x80;
const unsigned char kReserved1Bit = 0x40;
const unsigned char kReserved2Bit = 0x20;
const unsigned char kReserved3Bit = 0x10;
const unsigned char kOpCodeMask = 0xF;
// Bits of the second header byte.
const unsigned char kMaskBit = 0x80;
const unsigned char kPayloadLengthMask = 0x7F;

// Payload lengths up to 125 fit in the second header byte. The values 126
// and 127 mean that a 2-byte or 8-byte big-endian length follows.
const size_t kMaxSingleBytePayloadLength = 125;
const size_t kTwoBytePayloadLengthField = 126;
const size_t kEightBytePayloadLengthField = 127;
const size_t kMaskingKeyWidthInBytes = 4;
170 class WebSocketHybi17
: public WebSocket
{
172 static WebSocket
* Create(HttpConnection
* connection
,
173 const HttpServerRequestInfo
& request
,
175 std::string version
= request
.GetHeaderValue("sec-websocket-version");
176 if (version
!= "8" && version
!= "13")
179 std::string key
= request
.GetHeaderValue("sec-websocket-key");
181 connection
->Send(HttpServerResponseInfo::CreateFor500(
182 "Invalid request format. Sec-WebSocket-Key is empty or isn't "
186 return new WebSocketHybi17(connection
, request
, pos
);
189 virtual void Accept(const HttpServerRequestInfo
& request
) OVERRIDE
{
190 static const char* const kWebSocketGuid
=
191 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
192 std::string key
= request
.GetHeaderValue("sec-websocket-key");
193 std::string data
= base::StringPrintf("%s%s", key
.c_str(), kWebSocketGuid
);
194 std::string encoded_hash
;
195 base::Base64Encode(base::SHA1HashString(data
), &encoded_hash
);
197 std::string response
= base::StringPrintf(
198 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
199 "Upgrade: WebSocket\r\n"
200 "Connection: Upgrade\r\n"
201 "Sec-WebSocket-Accept: %s\r\n"
203 encoded_hash
.c_str());
204 connection_
->Send(response
);
207 virtual ParseResult
Read(std::string
* message
) OVERRIDE
{
208 const std::string
& frame
= connection_
->recv_data();
209 int bytes_consumed
= 0;
212 WebSocket::DecodeFrameHybi17(frame
, true, &bytes_consumed
, message
);
213 if (result
== FRAME_OK
)
214 connection_
->Shift(bytes_consumed
);
215 if (result
== FRAME_CLOSE
)
220 virtual void Send(const std::string
& message
) OVERRIDE
{
223 std::string data
= WebSocket::EncodeFrameHybi17(message
, 0);
224 connection_
->Send(data
);
228 WebSocketHybi17(HttpConnection
* connection
,
229 const HttpServerRequestInfo
& request
,
231 : WebSocket(connection
),
250 const char* payload_
;
251 size_t payload_length_
;
252 const char* frame_end_
;
255 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17
);
258 } // anonymous namespace
260 WebSocket
* WebSocket::CreateWebSocket(HttpConnection
* connection
,
261 const HttpServerRequestInfo
& request
,
263 WebSocket
* socket
= WebSocketHybi17::Create(connection
, request
, pos
);
267 return WebSocketHixie76::Create(connection
, request
, pos
);
271 WebSocket::ParseResult
WebSocket::DecodeFrameHybi17(const std::string
& frame
,
274 std::string
* output
) {
275 size_t data_length
= frame
.length();
277 return FRAME_INCOMPLETE
;
279 const char* buffer_begin
= const_cast<char*>(frame
.data());
280 const char* p
= buffer_begin
;
281 const char* buffer_end
= p
+ data_length
;
283 unsigned char first_byte
= *p
++;
284 unsigned char second_byte
= *p
++;
286 bool final
= (first_byte
& kFinalBit
) != 0;
287 bool reserved1
= (first_byte
& kReserved1Bit
) != 0;
288 bool reserved2
= (first_byte
& kReserved2Bit
) != 0;
289 bool reserved3
= (first_byte
& kReserved3Bit
) != 0;
290 int op_code
= first_byte
& kOpCodeMask
;
291 bool masked
= (second_byte
& kMaskBit
) != 0;
292 if (!final
|| reserved1
|| reserved2
|| reserved3
)
293 return FRAME_ERROR
; // Extensions and not supported.
302 case kOpCodeBinary
: // We don't support binary frames yet.
303 case kOpCodeContinuation
: // We don't support binary frames yet.
304 case kOpCodePing
: // We don't support binary frames yet.
305 case kOpCodePong
: // We don't support binary frames yet.
310 if (client_frame
&& !masked
) // In Hybi-17 spec client MUST mask his frame.
313 uint64 payload_length64
= second_byte
& kPayloadLengthMask
;
314 if (payload_length64
> kMaxSingleBytePayloadLength
) {
315 int extended_payload_length_size
;
316 if (payload_length64
== kTwoBytePayloadLengthField
)
317 extended_payload_length_size
= 2;
319 DCHECK(payload_length64
== kEightBytePayloadLengthField
);
320 extended_payload_length_size
= 8;
322 if (buffer_end
- p
< extended_payload_length_size
)
323 return FRAME_INCOMPLETE
;
324 payload_length64
= 0;
325 for (int i
= 0; i
< extended_payload_length_size
; ++i
) {
326 payload_length64
<<= 8;
327 payload_length64
|= static_cast<unsigned char>(*p
++);
331 size_t actual_masking_key_length
= masked
? kMaskingKeyWidthInBytes
: 0;
332 static const uint64 max_payload_length
= 0x7FFFFFFFFFFFFFFFull
;
333 static size_t max_length
= std::numeric_limits
<size_t>::max();
334 if (payload_length64
> max_payload_length
||
335 payload_length64
+ actual_masking_key_length
> max_length
) {
336 // WebSocket frame length too large.
339 size_t payload_length
= static_cast<size_t>(payload_length64
);
341 size_t total_length
= actual_masking_key_length
+ payload_length
;
342 if (static_cast<size_t>(buffer_end
- p
) < total_length
)
343 return FRAME_INCOMPLETE
;
346 output
->resize(payload_length
);
347 const char* masking_key
= p
;
348 char* payload
= const_cast<char*>(p
+ kMaskingKeyWidthInBytes
);
349 for (size_t i
= 0; i
< payload_length
; ++i
) // Unmask the payload.
350 (*output
)[i
] = payload
[i
] ^ masking_key
[i
% kMaskingKeyWidthInBytes
];
352 std::string
buffer(p
, p
+ payload_length
);
353 output
->swap(buffer
);
356 size_t pos
= p
+ actual_masking_key_length
+ payload_length
- buffer_begin
;
357 *bytes_consumed
= pos
;
358 return closed
? FRAME_CLOSE
: FRAME_OK
;
362 std::string
WebSocket::EncodeFrameHybi17(const std::string
& message
,
364 std::vector
<char> frame
;
365 OpCode op_code
= kOpCodeText
;
366 size_t data_length
= message
.length();
368 frame
.push_back(kFinalBit
| op_code
);
369 char mask_key_bit
= masking_key
!= 0 ? kMaskBit
: 0;
370 if (data_length
<= kMaxSingleBytePayloadLength
)
371 frame
.push_back(data_length
| mask_key_bit
);
372 else if (data_length
<= 0xFFFF) {
373 frame
.push_back(kTwoBytePayloadLengthField
| mask_key_bit
);
374 frame
.push_back((data_length
& 0xFF00) >> 8);
375 frame
.push_back(data_length
& 0xFF);
377 frame
.push_back(kEightBytePayloadLengthField
| mask_key_bit
);
378 char extended_payload_length
[8];
379 size_t remaining
= data_length
;
380 // Fill the length into extended_payload_length in the network byte order.
381 for (int i
= 0; i
< 8; ++i
) {
382 extended_payload_length
[7 - i
] = remaining
& 0xFF;
385 frame
.insert(frame
.end(),
386 extended_payload_length
,
387 extended_payload_length
+ 8);
391 const char* data
= const_cast<char*>(message
.data());
392 if (masking_key
!= 0) {
393 const char* mask_bytes
= reinterpret_cast<char*>(&masking_key
);
394 frame
.insert(frame
.end(), mask_bytes
, mask_bytes
+ 4);
395 for (size_t i
= 0; i
< data_length
; ++i
) // Mask the payload.
396 frame
.push_back(data
[i
] ^ mask_bytes
[i
% kMaskingKeyWidthInBytes
]);
398 frame
.insert(frame
.end(), data
, data
+ data_length
);
400 return std::string(&frame
[0], frame
.size());
403 WebSocket::WebSocket(HttpConnection
* connection
) : connection_(connection
) {