Merge Chromium + Blink git repositories
[chromium-blink-merge.git] / net / websockets / websocket_deflate_stream.cc
blob cf1690356537137d6353a4f214f5cb31c5073b07
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
7 #include <algorithm>
8 #include <string>
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_parameters.h"
19 #include "net/websockets/websocket_deflate_predictor.h"
20 #include "net/websockets/websocket_deflater.h"
21 #include "net/websockets/websocket_errors.h"
22 #include "net/websockets/websocket_frame.h"
23 #include "net/websockets/websocket_inflater.h"
24 #include "net/websockets/websocket_stream.h"
26 class GURL;
28 namespace net {
namespace {

// Window-bits value handed to inflater_.Initialize(). 15 is the largest
// max_window_bits value permitted by the permessage-deflate extension.
const int kWindowBits = 15;

// Byte granularity used by this stream: it sizes the inflater's buffers and
// is the buffered-output threshold at which a new frame is emitted.
const size_t kChunkSize = 4096;  // 4 KiB

}  // namespace
37 WebSocketDeflateStream::WebSocketDeflateStream(
38 scoped_ptr<WebSocketStream> stream,
39 const WebSocketDeflateParameters& params,
40 scoped_ptr<WebSocketDeflatePredictor> predictor)
41 : stream_(stream.Pass()),
42 deflater_(params.client_context_take_over_mode()),
43 inflater_(kChunkSize, kChunkSize),
44 reading_state_(NOT_READING),
45 writing_state_(NOT_WRITING),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
48 predictor_(predictor.Pass()) {
49 DCHECK(stream_);
50 DCHECK(params.IsValidAsResponse());
51 int client_max_window_bits = 15;
52 if (params.is_client_max_window_bits_specified()) {
53 DCHECK(params.has_client_max_window_bits_value());
54 client_max_window_bits = params.client_max_window_bits();
56 deflater_.Initialize(client_max_window_bits);
57 inflater_.Initialize(kWindowBits);
60 WebSocketDeflateStream::~WebSocketDeflateStream() {}
62 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
63 const CompletionCallback& callback) {
64 int result = stream_->ReadFrames(
65 frames,
66 base::Bind(&WebSocketDeflateStream::OnReadComplete,
67 base::Unretained(this),
68 base::Unretained(frames),
69 callback));
70 if (result < 0)
71 return result;
72 DCHECK_EQ(OK, result);
73 DCHECK(!frames->empty());
75 return InflateAndReadIfNecessary(frames, callback);
78 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
79 const CompletionCallback& callback) {
80 int result = Deflate(frames);
81 if (result != OK)
82 return result;
83 if (frames->empty())
84 return OK;
85 return stream_->WriteFrames(frames, callback);
88 void WebSocketDeflateStream::Close() { stream_->Close(); }
90 std::string WebSocketDeflateStream::GetSubProtocol() const {
91 return stream_->GetSubProtocol();
94 std::string WebSocketDeflateStream::GetExtensions() const {
95 return stream_->GetExtensions();
98 void WebSocketDeflateStream::OnReadComplete(
99 ScopedVector<WebSocketFrame>* frames,
100 const CompletionCallback& callback,
101 int result) {
102 if (result != OK) {
103 frames->clear();
104 callback.Run(result);
105 return;
108 int r = InflateAndReadIfNecessary(frames, callback);
109 if (r != ERR_IO_PENDING)
110 callback.Run(r);
113 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
114 ScopedVector<WebSocketFrame> frames_to_write;
115 // Store frames of the currently processed message if writing_state_ equals to
116 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
117 ScopedVector<WebSocketFrame> frames_of_message;
118 for (size_t i = 0; i < frames->size(); ++i) {
119 DCHECK(!(*frames)[i]->header.reserved1);
120 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
121 frames_to_write.push_back((*frames)[i]);
122 (*frames)[i] = NULL;
123 continue;
125 if (writing_state_ == NOT_WRITING)
126 OnMessageStart(*frames, i);
128 scoped_ptr<WebSocketFrame> frame((*frames)[i]);
129 (*frames)[i] = NULL;
130 predictor_->RecordInputDataFrame(frame.get());
132 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
133 if (frame->header.final)
134 writing_state_ = NOT_WRITING;
135 predictor_->RecordWrittenDataFrame(frame.get());
136 frames_to_write.push_back(frame.Pass());
137 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
138 } else {
139 if (frame->data.get() &&
140 !deflater_.AddBytes(
141 frame->data->data(),
142 static_cast<size_t>(frame->header.payload_length))) {
143 DVLOG(1) << "WebSocket protocol error. "
144 << "deflater_.AddBytes() returns an error.";
145 return ERR_WS_PROTOCOL_ERROR;
147 if (frame->header.final && !deflater_.Finish()) {
148 DVLOG(1) << "WebSocket protocol error. "
149 << "deflater_.Finish() returns an error.";
150 return ERR_WS_PROTOCOL_ERROR;
153 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
154 if (deflater_.CurrentOutputSize() >= kChunkSize ||
155 frame->header.final) {
156 int result = AppendCompressedFrame(frame->header, &frames_to_write);
157 if (result != OK)
158 return result;
160 if (frame->header.final)
161 writing_state_ = NOT_WRITING;
162 } else {
163 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
164 bool final = frame->header.final;
165 frames_of_message.push_back(frame.Pass());
166 if (final) {
167 int result = AppendPossiblyCompressedMessage(&frames_of_message,
168 &frames_to_write);
169 if (result != OK)
170 return result;
171 frames_of_message.clear();
172 writing_state_ = NOT_WRITING;
177 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
178 frames->swap(frames_to_write);
179 return OK;
182 void WebSocketDeflateStream::OnMessageStart(
183 const ScopedVector<WebSocketFrame>& frames, size_t index) {
184 WebSocketFrame* frame = frames[index];
185 current_writing_opcode_ = frame->header.opcode;
186 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
187 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
188 WebSocketDeflatePredictor::Result prediction =
189 predictor_->Predict(frames, index);
191 switch (prediction) {
192 case WebSocketDeflatePredictor::DEFLATE:
193 writing_state_ = WRITING_COMPRESSED_MESSAGE;
194 return;
195 case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
196 writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
197 return;
198 case WebSocketDeflatePredictor::TRY_DEFLATE:
199 writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
200 return;
202 NOTREACHED();
205 int WebSocketDeflateStream::AppendCompressedFrame(
206 const WebSocketFrameHeader& header,
207 ScopedVector<WebSocketFrame>* frames_to_write) {
208 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
209 scoped_refptr<IOBufferWithSize> compressed_payload =
210 deflater_.GetOutput(deflater_.CurrentOutputSize());
211 if (!compressed_payload.get()) {
212 DVLOG(1) << "WebSocket protocol error. "
213 << "deflater_.GetOutput() returns an error.";
214 return ERR_WS_PROTOCOL_ERROR;
216 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
217 compressed->header.CopyFrom(header);
218 compressed->header.opcode = opcode;
219 compressed->header.final = header.final;
220 compressed->header.reserved1 =
221 (opcode != WebSocketFrameHeader::kOpCodeContinuation);
222 compressed->data = compressed_payload;
223 compressed->header.payload_length = compressed_payload->size();
225 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
226 predictor_->RecordWrittenDataFrame(compressed.get());
227 frames_to_write->push_back(compressed.Pass());
228 return OK;
231 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
232 ScopedVector<WebSocketFrame>* frames,
233 ScopedVector<WebSocketFrame>* frames_to_write) {
234 DCHECK(!frames->empty());
236 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
237 scoped_refptr<IOBufferWithSize> compressed_payload =
238 deflater_.GetOutput(deflater_.CurrentOutputSize());
239 if (!compressed_payload.get()) {
240 DVLOG(1) << "WebSocket protocol error. "
241 << "deflater_.GetOutput() returns an error.";
242 return ERR_WS_PROTOCOL_ERROR;
245 uint64 original_payload_length = 0;
246 for (size_t i = 0; i < frames->size(); ++i) {
247 WebSocketFrame* frame = (*frames)[i];
248 // Asserts checking that frames represent one whole data message.
249 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
250 DCHECK_EQ(i == 0,
251 WebSocketFrameHeader::kOpCodeContinuation !=
252 frame->header.opcode);
253 DCHECK_EQ(i == frames->size() - 1, frame->header.final);
254 original_payload_length += frame->header.payload_length;
256 if (original_payload_length <=
257 static_cast<uint64>(compressed_payload->size())) {
258 // Compression is not effective. Use the original frames.
259 for (size_t i = 0; i < frames->size(); ++i) {
260 WebSocketFrame* frame = (*frames)[i];
261 frames_to_write->push_back(frame);
262 predictor_->RecordWrittenDataFrame(frame);
263 (*frames)[i] = NULL;
265 frames->weak_clear();
266 return OK;
268 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
269 compressed->header.CopyFrom((*frames)[0]->header);
270 compressed->header.opcode = opcode;
271 compressed->header.final = true;
272 compressed->header.reserved1 = true;
273 compressed->data = compressed_payload;
274 compressed->header.payload_length = compressed_payload->size();
276 predictor_->RecordWrittenDataFrame(compressed.get());
277 frames_to_write->push_back(compressed.Pass());
278 return OK;
281 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
282 ScopedVector<WebSocketFrame> frames_to_output;
283 ScopedVector<WebSocketFrame> frames_passed;
284 frames->swap(frames_passed);
285 for (size_t i = 0; i < frames_passed.size(); ++i) {
286 scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
287 frames_passed[i] = NULL;
288 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
289 << " final=" << frame->header.final
290 << " reserved1=" << frame->header.reserved1
291 << " payload_length=" << frame->header.payload_length;
293 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
294 frames_to_output.push_back(frame.Pass());
295 continue;
298 if (reading_state_ == NOT_READING) {
299 if (frame->header.reserved1)
300 reading_state_ = READING_COMPRESSED_MESSAGE;
301 else
302 reading_state_ = READING_UNCOMPRESSED_MESSAGE;
303 current_reading_opcode_ = frame->header.opcode;
304 } else {
305 if (frame->header.reserved1) {
306 DVLOG(1) << "WebSocket protocol error. "
307 << "Receiving a non-first frame with RSV1 flag set.";
308 return ERR_WS_PROTOCOL_ERROR;
312 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
313 if (frame->header.final)
314 reading_state_ = NOT_READING;
315 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
316 frames_to_output.push_back(frame.Pass());
317 } else {
318 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
319 if (frame->data.get() &&
320 !inflater_.AddBytes(
321 frame->data->data(),
322 static_cast<size_t>(frame->header.payload_length))) {
323 DVLOG(1) << "WebSocket protocol error. "
324 << "inflater_.AddBytes() returns an error.";
325 return ERR_WS_PROTOCOL_ERROR;
327 if (frame->header.final) {
328 if (!inflater_.Finish()) {
329 DVLOG(1) << "WebSocket protocol error. "
330 << "inflater_.Finish() returns an error.";
331 return ERR_WS_PROTOCOL_ERROR;
334 // TODO(yhirano): Many frames can be generated by the inflater and
335 // memory consumption can grow.
336 // We could avoid it, but avoiding it makes this class much more
337 // complicated.
338 while (inflater_.CurrentOutputSize() >= kChunkSize ||
339 frame->header.final) {
340 size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
341 scoped_ptr<WebSocketFrame> inflated(
342 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
343 scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
344 bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
345 if (!data.get()) {
346 DVLOG(1) << "WebSocket protocol error. "
347 << "inflater_.GetOutput() returns an error.";
348 return ERR_WS_PROTOCOL_ERROR;
350 inflated->header.CopyFrom(frame->header);
351 inflated->header.opcode = current_reading_opcode_;
352 inflated->header.final = is_final;
353 inflated->header.reserved1 = false;
354 inflated->data = data;
355 inflated->header.payload_length = data->size();
356 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
357 << " final=" << inflated->header.final
358 << " reserved1=" << inflated->header.reserved1
359 << " payload_length=" << inflated->header.payload_length;
360 frames_to_output.push_back(inflated.Pass());
361 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
362 if (is_final)
363 break;
365 if (frame->header.final)
366 reading_state_ = NOT_READING;
369 frames->swap(frames_to_output);
370 return frames->empty() ? ERR_IO_PENDING : OK;
373 int WebSocketDeflateStream::InflateAndReadIfNecessary(
374 ScopedVector<WebSocketFrame>* frames,
375 const CompletionCallback& callback) {
376 int result = Inflate(frames);
377 while (result == ERR_IO_PENDING) {
378 DCHECK(frames->empty());
380 result = stream_->ReadFrames(
381 frames,
382 base::Bind(&WebSocketDeflateStream::OnReadComplete,
383 base::Unretained(this),
384 base::Unretained(frames),
385 callback));
386 if (result < 0)
387 break;
388 DCHECK_EQ(OK, result);
389 DCHECK(!frames->empty());
391 result = Inflate(frames);
393 if (result < 0)
394 frames->clear();
395 return result;
398 } // namespace net