1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/path_service.h"
#include "base/posix/eintr_wrapper.h"
#include "base/single_thread_task_runner.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_restrictions.h"
#include "ipc/unix_domain_socket_util.h"
#include "testing/gtest/include/gtest/gtest.h"
21 class SocketAcceptor
: public base::MessageLoopForIO::Watcher
{
23 SocketAcceptor(int fd
, base::SingleThreadTaskRunner
* target_thread
)
25 target_thread_(target_thread
),
26 started_watching_event_(false, false),
27 accepted_event_(false, false) {
28 target_thread
->PostTask(FROM_HERE
,
29 base::Bind(&SocketAcceptor::StartWatching
, base::Unretained(this), fd
));
32 ~SocketAcceptor() override
{
36 int server_fd() const { return server_fd_
; }
38 void WaitUntilReady() {
39 started_watching_event_
.Wait();
42 void WaitForAccept() {
43 accepted_event_
.Wait();
48 target_thread_
->PostTask(FROM_HERE
,
49 base::Bind(&SocketAcceptor::StopWatching
, base::Unretained(this),
55 void StartWatching(int fd
) {
56 watcher_
.reset(new base::MessageLoopForIO::FileDescriptorWatcher
);
57 base::MessageLoopForIO::current()->WatchFileDescriptor(
58 fd
, true, base::MessageLoopForIO::WATCH_READ
, watcher_
.get(), this);
59 started_watching_event_
.Signal();
61 void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher
* watcher
) {
62 watcher
->StopWatchingFileDescriptor();
65 void OnFileCanReadWithoutBlocking(int fd
) override
{
66 ASSERT_EQ(-1, server_fd_
);
67 IPC::ServerAcceptConnection(fd
, &server_fd_
);
68 watcher_
->StopWatchingFileDescriptor();
69 accepted_event_
.Signal();
71 void OnFileCanWriteWithoutBlocking(int fd
) override
{}
74 base::SingleThreadTaskRunner
* target_thread_
;
75 scoped_ptr
<base::MessageLoopForIO::FileDescriptorWatcher
> watcher_
;
76 base::WaitableEvent started_watching_event_
;
77 base::WaitableEvent accepted_event_
;
79 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor
);
82 const base::FilePath
GetChannelDir() {
83 base::FilePath tmp_dir
;
84 PathService::Get(base::DIR_TEMP
, &tmp_dir
);
88 class TestUnixSocketConnection
{
90 TestUnixSocketConnection()
91 : worker_("WorkerThread"),
92 server_listen_fd_(-1),
95 socket_name_
= GetChannelDir().Append("TestSocket");
96 base::Thread::Options options
;
97 options
.message_loop_type
= base::MessageLoop::TYPE_IO
;
98 worker_
.StartWithOptions(options
);
101 bool CreateServerSocket() {
102 IPC::CreateServerUnixDomainSocket(socket_name_
, &server_listen_fd_
);
103 if (server_listen_fd_
< 0)
105 struct stat socket_stat
;
106 stat(socket_name_
.value().c_str(), &socket_stat
);
107 EXPECT_TRUE(S_ISSOCK(socket_stat
.st_mode
));
109 new SocketAcceptor(server_listen_fd_
, worker_
.task_runner().get()));
110 acceptor_
->WaitUntilReady();
114 bool CreateClientSocket() {
115 DCHECK(server_listen_fd_
>= 0);
116 IPC::CreateClientUnixDomainSocket(socket_name_
, &client_fd_
);
119 acceptor_
->WaitForAccept();
120 server_fd_
= acceptor_
->server_fd();
121 return server_fd_
>= 0;
124 virtual ~TestUnixSocketConnection() {
129 if (server_listen_fd_
>= 0) {
130 close(server_listen_fd_
);
131 unlink(socket_name_
.value().c_str());
135 int client_fd() const { return client_fd_
; }
136 int server_fd() const { return server_fd_
; }
139 base::Thread worker_
;
140 base::FilePath socket_name_
;
141 int server_listen_fd_
;
144 scoped_ptr
<SocketAcceptor
> acceptor_
;
147 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
148 // IPC::CreateClientUnixDomainSocket can successfully connect to.
149 TEST(UnixDomainSocketUtil
, Connect
) {
150 TestUnixSocketConnection connection
;
151 ASSERT_TRUE(connection
.CreateServerSocket());
152 ASSERT_TRUE(connection
.CreateClientSocket());
155 // Ensure that messages can be sent across the resulting socket.
156 TEST(UnixDomainSocketUtil
, SendReceive
) {
157 TestUnixSocketConnection connection
;
158 ASSERT_TRUE(connection
.CreateServerSocket());
159 ASSERT_TRUE(connection
.CreateClientSocket());
161 const char buffer
[] = "Hello, server!";
162 size_t buf_len
= sizeof(buffer
);
164 HANDLE_EINTR(send(connection
.client_fd(), buffer
, buf_len
, 0));
165 ASSERT_EQ(buf_len
, sent_bytes
);
166 char recv_buf
[sizeof(buffer
)];
167 size_t received_bytes
=
168 HANDLE_EINTR(recv(connection
.server_fd(), recv_buf
, buf_len
, 0));
169 ASSERT_EQ(buf_len
, received_bytes
);
170 ASSERT_EQ(0, memcmp(recv_buf
, buffer
, buf_len
));