Add some handle read/write helpers to mojo/common/test/test_utils.h
[chromium-blink-merge.git] / mojo / system / raw_channel_posix_unittest.cc
blob5af66c381763e19381fee247c1a0576976263fa0
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 // TODO(vtl): Factor out the remaining POSIX-specific bits of this test (once we
6 // have a non-POSIX implementation).
8 #include "mojo/system/raw_channel.h"
10 #include <sys/socket.h>
12 #include <vector>
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/compiler_specific.h"
17 #include "base/location.h"
18 #include "base/logging.h"
19 #include "base/memory/scoped_ptr.h"
20 #include "base/memory/scoped_vector.h"
21 #include "base/message_loop/message_loop.h"
22 #include "base/rand_util.h"
23 #include "base/synchronization/lock.h"
24 #include "base/synchronization/waitable_event.h"
25 #include "base/threading/platform_thread.h" // For |Sleep()|.
26 #include "base/threading/simple_thread.h"
27 #include "base/time/time.h"
28 #include "mojo/common/test/test_utils.h"
29 #include "mojo/system/embedder/platform_channel_pair.h"
30 #include "mojo/system/embedder/platform_handle.h"
31 #include "mojo/system/embedder/scoped_platform_handle.h"
32 #include "mojo/system/message_in_transit.h"
33 #include "mojo/system/test_utils.h"
35 namespace mojo {
36 namespace system {
37 namespace {
39 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
40 std::vector<unsigned char> bytes(num_bytes, 0);
41 for (size_t i = 0; i < num_bytes; i++)
42 bytes[i] = static_cast<unsigned char>(i + num_bytes);
43 return make_scoped_ptr(
44 new MessageInTransit(MessageInTransit::OWNED_BUFFER,
45 MessageInTransit::kTypeMessagePipeEndpoint,
46 MessageInTransit::kSubtypeMessagePipeEndpointData,
47 num_bytes, 0, bytes.data()));
// Returns true iff the |num_bytes| bytes at |bytes| follow the test pattern:
// byte |i| equals |i + num_bytes| truncated to |unsigned char|. An empty
// range (|num_bytes| == 0) trivially matches.
bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
  const unsigned char* const data = static_cast<const unsigned char*>(bytes);
  uint32_t index = 0;
  while (index < num_bytes) {
    const unsigned char expected =
        static_cast<unsigned char>(index + num_bytes);
    if (data[index] != expected)
      return false;
    ++index;
  }
  return true;
}
59 void InitOnIOThread(RawChannel* raw_channel) {
60 CHECK(raw_channel->Init());
63 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
64 uint32_t num_bytes) {
65 scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
67 size_t write_size = 0;
68 mojo::test::BlockingWrite(
69 handle, message->main_buffer(), message->main_buffer_size(), &write_size);
70 return write_size == message->main_buffer_size();
73 // -----------------------------------------------------------------------------
75 class RawChannelPosixTest : public test::TestWithIOThreadBase {
76 public:
77 RawChannelPosixTest() {}
78 virtual ~RawChannelPosixTest() {}
80 virtual void SetUp() OVERRIDE {
81 test::TestWithIOThreadBase::SetUp();
83 embedder::PlatformChannelPair channel_pair;
84 handles[0] = channel_pair.PassServerHandle();
85 handles[1] = channel_pair.PassClientHandle();
88 virtual void TearDown() OVERRIDE {
89 handles[0].reset();
90 handles[1].reset();
92 test::TestWithIOThreadBase::TearDown();
95 protected:
96 embedder::ScopedPlatformHandle handles[2];
98 private:
99 DISALLOW_COPY_AND_ASSIGN(RawChannelPosixTest);
102 // RawChannelPosixTest.WriteMessage --------------------------------------------
104 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
105 public:
106 WriteOnlyRawChannelDelegate() {}
107 virtual ~WriteOnlyRawChannelDelegate() {}
109 // |RawChannel::Delegate| implementation:
110 virtual void OnReadMessage(const MessageInTransit& /*message*/) OVERRIDE {
111 NOTREACHED();
113 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
114 NOTREACHED();
117 private:
118 DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
121 static const int64_t kMessageReaderSleepMs = 1;
122 static const size_t kMessageReaderMaxPollIterations = 3000;
124 class TestMessageReaderAndChecker {
125 public:
126 explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
127 : handle_(handle) {}
128 ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
130 bool ReadAndCheckNextMessage(uint32_t expected_size) {
131 unsigned char buffer[4096];
133 for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
134 size_t read_size = 0;
135 CHECK(mojo::test::NonBlockingRead(handle_, buffer, sizeof(buffer),
136 &read_size));
138 // Append newly-read data to |bytes_|.
139 bytes_.insert(bytes_.end(), buffer, buffer + read_size);
141 // If we have the header....
142 size_t message_size;
143 if (MessageInTransit::GetNextMessageSize(bytes_.data(), bytes_.size(),
144 &message_size)) {
145 // If we've read the whole message....
146 if (bytes_.size() >= message_size) {
147 bool rv = true;
148 MessageInTransit message(MessageInTransit::UNOWNED_BUFFER,
149 message_size, bytes_.data());
150 CHECK_EQ(message.main_buffer_size(), message_size);
152 if (message.num_bytes() != expected_size) {
153 LOG(ERROR) << "Wrong size: " << message_size << " instead of "
154 << expected_size << " bytes.";
155 rv = false;
156 } else if (!CheckMessageData(message.bytes(),
157 message.num_bytes())) {
158 LOG(ERROR) << "Incorrect message bytes.";
159 rv = false;
162 // Erase message data.
163 bytes_.erase(bytes_.begin(),
164 bytes_.begin() +
165 message.main_buffer_size());
166 return rv;
170 if (static_cast<size_t>(read_size) < sizeof(buffer)) {
171 i++;
172 base::PlatformThread::Sleep(
173 base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
177 LOG(ERROR) << "Too many iterations.";
178 return false;
181 private:
182 const embedder::PlatformHandle handle_;
184 // The start of the received data should always be on a message boundary.
185 std::vector<unsigned char> bytes_;
187 DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
190 // Tests writing (and verifies reading using our own custom reader).
191 TEST_F(RawChannelPosixTest, WriteMessage) {
192 WriteOnlyRawChannelDelegate delegate;
193 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass(),
194 &delegate,
195 io_thread_message_loop()));
197 TestMessageReaderAndChecker checker(handles[1].get());
199 test::PostTaskAndWait(io_thread_task_runner(),
200 FROM_HERE,
201 base::Bind(&InitOnIOThread, rc.get()));
203 // Write and read, for a variety of sizes.
204 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
205 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
206 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
209 // Write/queue and read afterwards, for a variety of sizes.
210 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
211 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
212 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
213 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
215 test::PostTaskAndWait(io_thread_task_runner(),
216 FROM_HERE,
217 base::Bind(&RawChannel::Shutdown,
218 base::Unretained(rc.get())));
221 // RawChannelPosixTest.OnReadMessage -------------------------------------------
223 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
224 public:
225 ReadCheckerRawChannelDelegate()
226 : done_event_(false, false),
227 position_(0) {}
228 virtual ~ReadCheckerRawChannelDelegate() {}
230 // |RawChannel::Delegate| implementation (called on the I/O thread):
231 virtual void OnReadMessage(const MessageInTransit& message) OVERRIDE {
232 size_t position;
233 size_t expected_size;
234 bool should_signal = false;
236 base::AutoLock locker(lock_);
237 CHECK_LT(position_, expected_sizes_.size());
238 position = position_;
239 expected_size = expected_sizes_[position];
240 position_++;
241 if (position_ >= expected_sizes_.size())
242 should_signal = true;
245 EXPECT_EQ(expected_size, message.num_bytes()) << position;
246 if (message.num_bytes() == expected_size) {
247 EXPECT_TRUE(CheckMessageData(message.bytes(), message.num_bytes()))
248 << position;
251 if (should_signal)
252 done_event_.Signal();
254 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
255 NOTREACHED();
258 // Wait for all the messages (of sizes |expected_sizes_|) to be seen.
259 void Wait() {
260 done_event_.Wait();
263 void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
264 base::AutoLock locker(lock_);
265 CHECK_EQ(position_, expected_sizes_.size());
266 expected_sizes_ = expected_sizes;
267 position_ = 0;
270 private:
271 base::WaitableEvent done_event_;
273 base::Lock lock_; // Protects the following members.
274 std::vector<uint32_t> expected_sizes_;
275 size_t position_;
277 DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
280 // Tests reading (writing using our own custom writer).
281 TEST_F(RawChannelPosixTest, OnReadMessage) {
282 ReadCheckerRawChannelDelegate delegate;
283 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass(),
284 &delegate,
285 io_thread_message_loop()));
287 test::PostTaskAndWait(io_thread_task_runner(),
288 FROM_HERE,
289 base::Bind(&InitOnIOThread, rc.get()));
291 // Write and read, for a variety of sizes.
292 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
293 delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
295 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
297 delegate.Wait();
300 // Set up reader and write as fast as we can.
301 // Write/queue and read afterwards, for a variety of sizes.
302 std::vector<uint32_t> expected_sizes;
303 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
304 expected_sizes.push_back(size);
305 delegate.SetExpectedSizes(expected_sizes);
306 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
307 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
308 delegate.Wait();
310 test::PostTaskAndWait(io_thread_task_runner(),
311 FROM_HERE,
312 base::Bind(&RawChannel::Shutdown,
313 base::Unretained(rc.get())));
316 // RawChannelPosixTest.WriteMessageAndOnReadMessage ----------------------------
318 class RawChannelWriterThread : public base::SimpleThread {
319 public:
320 RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
321 : base::SimpleThread("raw_channel_writer_thread"),
322 raw_channel_(raw_channel),
323 left_to_write_(write_count) {
326 virtual ~RawChannelWriterThread() {
327 Join();
330 private:
331 virtual void Run() OVERRIDE {
332 static const int kMaxRandomMessageSize = 25000;
334 while (left_to_write_-- > 0) {
335 EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
336 static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
340 RawChannel* const raw_channel_;
341 size_t left_to_write_;
343 DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
346 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
347 public:
348 explicit ReadCountdownRawChannelDelegate(size_t expected_count)
349 : done_event_(false, false),
350 expected_count_(expected_count),
351 count_(0) {}
352 virtual ~ReadCountdownRawChannelDelegate() {}
354 // |RawChannel::Delegate| implementation (called on the I/O thread):
355 virtual void OnReadMessage(const MessageInTransit& message) OVERRIDE {
356 EXPECT_LT(count_, expected_count_);
357 count_++;
359 EXPECT_TRUE(CheckMessageData(message.bytes(), message.num_bytes()));
361 if (count_ >= expected_count_)
362 done_event_.Signal();
364 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
365 NOTREACHED();
368 // Wait for all the messages to have been seen.
369 void Wait() {
370 done_event_.Wait();
373 private:
374 base::WaitableEvent done_event_;
375 size_t expected_count_;
376 size_t count_;
378 DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
381 TEST_F(RawChannelPosixTest, WriteMessageAndOnReadMessage) {
382 static const size_t kNumWriterThreads = 10;
383 static const size_t kNumWriteMessagesPerThread = 4000;
385 WriteOnlyRawChannelDelegate writer_delegate;
386 scoped_ptr<RawChannel> writer_rc(
387 RawChannel::Create(handles[0].Pass(),
388 &writer_delegate,
389 io_thread_message_loop()));
391 test::PostTaskAndWait(io_thread_task_runner(),
392 FROM_HERE,
393 base::Bind(&InitOnIOThread, writer_rc.get()));
395 ReadCountdownRawChannelDelegate reader_delegate(
396 kNumWriterThreads * kNumWriteMessagesPerThread);
397 scoped_ptr<RawChannel> reader_rc(
398 RawChannel::Create(handles[1].Pass(),
399 &reader_delegate,
400 io_thread_message_loop()));
402 test::PostTaskAndWait(io_thread_task_runner(),
403 FROM_HERE,
404 base::Bind(&InitOnIOThread, reader_rc.get()));
407 ScopedVector<RawChannelWriterThread> writer_threads;
408 for (size_t i = 0; i < kNumWriterThreads; i++) {
409 writer_threads.push_back(new RawChannelWriterThread(
410 writer_rc.get(), kNumWriteMessagesPerThread));
412 for (size_t i = 0; i < writer_threads.size(); i++)
413 writer_threads[i]->Start();
414 } // Joins all the writer threads.
416 // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
417 // any, but we want to know about them.)
418 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));
420 // Wait for reading to finish.
421 reader_delegate.Wait();
423 test::PostTaskAndWait(io_thread_task_runner(),
424 FROM_HERE,
425 base::Bind(&RawChannel::Shutdown,
426 base::Unretained(reader_rc.get())));
428 test::PostTaskAndWait(io_thread_task_runner(),
429 FROM_HERE,
430 base::Bind(&RawChannel::Shutdown,
431 base::Unretained(writer_rc.get())));
434 // RawChannelPosixTest.OnFatalError --------------------------------------------
436 class FatalErrorRecordingRawChannelDelegate
437 : public ReadCountdownRawChannelDelegate {
438 public:
439 FatalErrorRecordingRawChannelDelegate(size_t expected_read_count_)
440 : ReadCountdownRawChannelDelegate(expected_read_count_),
441 got_fatal_error_event_(false, false),
442 on_fatal_error_call_count_(0),
443 last_fatal_error_(FATAL_ERROR_UNKNOWN) {}
445 virtual ~FatalErrorRecordingRawChannelDelegate() {}
447 virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
448 CHECK_EQ(on_fatal_error_call_count_, 0);
449 on_fatal_error_call_count_++;
450 last_fatal_error_ = fatal_error;
451 got_fatal_error_event_.Signal();
454 FatalError WaitForFatalError() {
455 got_fatal_error_event_.Wait();
456 CHECK_EQ(on_fatal_error_call_count_, 1);
457 return last_fatal_error_;
460 private:
461 base::WaitableEvent got_fatal_error_event_;
463 int on_fatal_error_call_count_;
464 FatalError last_fatal_error_;
466 DISALLOW_COPY_AND_ASSIGN(FatalErrorRecordingRawChannelDelegate);
469 // Tests fatal errors.
470 // TODO(vtl): Figure out how to make reading fail reliably. (I'm not convinced
471 // that it does.)
472 TEST_F(RawChannelPosixTest, OnFatalError) {
473 const size_t kMessageCount = 5;
475 FatalErrorRecordingRawChannelDelegate delegate(2 * kMessageCount);
476 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass(),
477 &delegate,
478 io_thread_message_loop()));
480 test::PostTaskAndWait(io_thread_task_runner(),
481 FROM_HERE,
482 base::Bind(&InitOnIOThread, rc.get()));
484 // Write into the other end a few messages.
485 uint32_t message_size = 1;
486 for (size_t count = 0; count < kMessageCount;
487 ++count, message_size += message_size / 2 + 1) {
488 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
491 // Shut down read at the other end, which should make writing fail.
492 EXPECT_EQ(0, shutdown(handles[1].get().fd, SHUT_RD));
494 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
496 // TODO(vtl): In theory, it's conceivable that closing the other end might
497 // lead to read failing. In practice, it doesn't seem to.
498 EXPECT_EQ(RawChannel::Delegate::FATAL_ERROR_FAILED_WRITE,
499 delegate.WaitForFatalError());
501 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));
503 // Sleep a bit, to make sure we don't get another |OnFatalError()|
504 // notification. (If we actually get another one, |OnFatalError()| crashes.)
505 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));
507 // Write into the other end a few more messages.
508 for (size_t count = 0; count < kMessageCount;
509 ++count, message_size += message_size / 2 + 1) {
510 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
512 // Wait for reading to finish. A writing failure shouldn't affect reading.
513 delegate.Wait();
515 test::PostTaskAndWait(io_thread_task_runner(),
516 FROM_HERE,
517 base::Bind(&RawChannel::Shutdown,
518 base::Unretained(rc.get())));
521 // RawChannelPosixTest.WriteMessageAfterShutdown -------------------------------
523 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
524 // correctly.
525 TEST_F(RawChannelPosixTest, WriteMessageAfterShutdown) {
526 WriteOnlyRawChannelDelegate delegate;
527 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass(),
528 &delegate,
529 io_thread_message_loop()));
531 test::PostTaskAndWait(io_thread_task_runner(),
532 FROM_HERE,
533 base::Bind(&InitOnIOThread, rc.get()));
534 test::PostTaskAndWait(io_thread_task_runner(),
535 FROM_HERE,
536 base::Bind(&RawChannel::Shutdown,
537 base::Unretained(rc.get())));
539 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
542 } // namespace
543 } // namespace system
544 } // namespace mojo