1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <sys/socket.h>
8 #include "base/files/file_path.h"
9 #include "base/path_service.h"
10 #include "base/posix/eintr_wrapper.h"
11 #include "base/synchronization/waitable_event.h"
12 #include "base/threading/thread.h"
13 #include "base/threading/thread_restrictions.h"
14 #include "ipc/unix_domain_socket_util.h"
15 #include "testing/gtest/include/gtest/gtest.h"
// SocketAcceptor implements base::MessageLoopForIO::Watcher. It asks a target
// thread to watch a file descriptor for readability and, when the descriptor
// becomes readable, accepts a connection on it via IPC::ServerAcceptConnection
// and stores the resulting descriptor in server_fd_.
// NOTE(review): this listing appears to be a lossy extraction -- lines carry a
// leading source-line number, tokens are split across lines, and some lines
// (closing braces, access specifiers, the server_fd_ declaration) are not
// visible. Confirm against the pristine file before changing any code here.
19 class SocketAcceptor
: public base::MessageLoopForIO::Watcher
{
// Constructor: remembers |target_thread| and posts a StartWatching(fd) task to
// it. Both WaitableEvents are constructed with (false, false).
// NOTE(review): OnFileCanReadWithoutBlocking asserts server_fd_ == -1, so
// presumably server_fd_ is initialized to -1 on a line not shown -- verify.
// NOTE(review): the task is bound with base::Unretained(this); presumably this
// object must outlive the posted task -- confirm.
21 SocketAcceptor(int fd
, base::MessageLoopProxy
* target_thread
)
23 target_thread_(target_thread
),
24 started_watching_event_(false, false),
25 accepted_event_(false, false) {
26 target_thread
->PostTask(FROM_HERE
,
27 base::Bind(&SocketAcceptor::StartWatching
, base::Unretained(this), fd
));
// Destructor. Its body is not visible in this listing.
30 virtual ~SocketAcceptor() {
// Returns the descriptor written by IPC::ServerAcceptConnection in
// OnFileCanReadWithoutBlocking.
34 int server_fd() const { return server_fd_
; }
// Blocks until StartWatching has signalled started_watching_event_.
36 void WaitUntilReady() {
37 started_watching_event_
.Wait();
// Blocks until OnFileCanReadWithoutBlocking has signalled accepted_event_.
40 void WaitForAccept() {
41 accepted_event_
.Wait();
// Posts a StopWatching task to target_thread_.
// NOTE(review): the enclosing method header and the final Bind argument are
// not visible in this listing.
46 target_thread_
->PostTask(FROM_HERE
,
47 base::Bind(&SocketAcceptor::StopWatching
, base::Unretained(this),
// Creates the FileDescriptorWatcher, registers |fd| for WATCH_READ on the
// current MessageLoopForIO with this object as the delegate, then signals
// started_watching_event_ so WaitUntilReady() can return.
53 void StartWatching(int fd
) {
54 watcher_
.reset(new base::MessageLoopForIO::FileDescriptorWatcher
);
55 base::MessageLoopForIO::current()->WatchFileDescriptor(
56 fd
, true, base::MessageLoopForIO::WATCH_READ
, watcher_
.get(), this);
57 started_watching_event_
.Signal();
// Stops |watcher| from watching its file descriptor.
59 void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher
* watcher
) {
60 watcher
->StopWatchingFileDescriptor();
// Watcher callback for a readable |fd|: asserts that no connection has been
// recorded yet (server_fd_ is still -1), accepts into server_fd_, stops
// watching, and signals accepted_event_ so WaitForAccept() can return.
// The return value of IPC::ServerAcceptConnection is not inspected here.
63 virtual void OnFileCanReadWithoutBlocking(int fd
) OVERRIDE
{
64 ASSERT_EQ(-1, server_fd_
);
65 IPC::ServerAcceptConnection(fd
, &server_fd_
);
66 watcher_
->StopWatchingFileDescriptor();
67 accepted_event_
.Signal();
// Watcher callback for writability; intentionally does nothing.
69 virtual void OnFileCanWriteWithoutBlocking(int fd
) OVERRIDE
{}
// Thread (message loop proxy) that the watch/stop-watch tasks are posted to.
72 base::MessageLoopProxy
* target_thread_
;
// Watcher created in StartWatching.
73 scoped_ptr
<base::MessageLoopForIO::FileDescriptorWatcher
> watcher_
;
// Signalled at the end of StartWatching.
74 base::WaitableEvent started_watching_event_
;
// Signalled at the end of OnFileCanReadWithoutBlocking.
75 base::WaitableEvent accepted_event_
;
77 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor
);
// Returns the directory under which the test socket path is built (the
// TestUnixSocketConnection constructor appends "TestSocket" to it).
// When OS_ANDROID is defined, PathService::Get(base::DIR_CACHE, ...) is used to
// fill |tmp_dir|; the other visible return is the fixed path "/var/tmp".
// NOTE(review): the Android-branch return and the #else/#endif lines are not
// visible in this listing; presumably |tmp_dir| is returned on Android and
// "/var/tmp" elsewhere -- confirm against the pristine file.
// NOTE(review): the result of PathService::Get is not checked here.
80 const base::FilePath
GetChannelDir() {
81 #if defined(OS_ANDROID)
82 base::FilePath tmp_dir
;
83 PathService::Get(base::DIR_CACHE
, &tmp_dir
);
86 return base::FilePath("/var/tmp");
// Test helper that owns a worker thread, a server listening socket at
// GetChannelDir()/"TestSocket", and the client/server descriptors of one
// connection made through the IPC unix-domain-socket utilities.
// NOTE(review): this listing appears to be a lossy extraction -- some lines
// (closing braces, early returns, the client_fd_/server_fd_ declarations and
// their initializers) are not visible. Confirm against the pristine file.
90 class TestUnixSocketConnection
{
// Constructor: names the worker thread "WorkerThread", starts with no
// listening descriptor (-1), computes the socket path, and starts the worker
// with a TYPE_IO message loop.
92 TestUnixSocketConnection()
93 : worker_("WorkerThread"),
94 server_listen_fd_(-1),
97 socket_name_
= GetChannelDir().Append("TestSocket");
98 base::Thread::Options options
;
99 options
.message_loop_type
= base::MessageLoop::TYPE_IO
;
100 worker_
.StartWithOptions(options
);
// Creates the listening socket at socket_name_, expects the path to be a
// socket on disk, then creates a SocketAcceptor on the worker thread and waits
// until it is watching the listening descriptor.
// NOTE(review): the statement guarded by the "< 0" check and the final return
// are not visible here; presumably false on failure and true on success.
103 bool CreateServerSocket() {
104 IPC::CreateServerUnixDomainSocket(socket_name_
, &server_listen_fd_
);
105 if (server_listen_fd_
< 0)
107 struct stat socket_stat
;
// NOTE(review): the return value of stat() is not checked, so on failure
// socket_stat.st_mode would be read without having been filled in.
108 stat(socket_name_
.value().c_str(), &socket_stat
);
109 EXPECT_TRUE(S_ISSOCK(socket_stat
.st_mode
));
110 acceptor_
.reset(new SocketAcceptor(server_listen_fd_
,
111 worker_
.message_loop_proxy().get()));
112 acceptor_
->WaitUntilReady();
// Connects a client socket to socket_name_, waits for the acceptor to accept
// it, and records the accepted descriptor. Returns true when the accepted
// server-side descriptor is non-negative. Requires CreateServerSocket() to
// have produced a valid listening descriptor (DCHECK).
116 bool CreateClientSocket() {
117 DCHECK(server_listen_fd_
>= 0);
118 IPC::CreateClientUnixDomainSocket(socket_name_
, &client_fd_
);
121 acceptor_
->WaitForAccept();
122 server_fd_
= acceptor_
->server_fd();
123 return server_fd_
>= 0;
// Destructor: when a listening descriptor exists, closes it and unlinks the
// socket path. Other cleanup lines, if any, are not visible in this listing.
126 virtual ~TestUnixSocketConnection() {
131 if (server_listen_fd_
>= 0) {
132 close(server_listen_fd_
);
133 unlink(socket_name_
.value().c_str());
// Descriptor filled in by IPC::CreateClientUnixDomainSocket.
137 int client_fd() const { return client_fd_
; }
// Descriptor obtained from the acceptor after the connection was accepted.
138 int server_fd() const { return server_fd_
; }
// Thread on which the SocketAcceptor watches the listening descriptor.
141 base::Thread worker_
;
// Filesystem path of the unix domain socket.
142 base::FilePath socket_name_
;
// Listening descriptor; -1 until CreateServerSocket() succeeds.
143 int server_listen_fd_
;
// Acceptor created in CreateServerSocket().
146 scoped_ptr
<SocketAcceptor
> acceptor_
;
149 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
150 // IPC::CreateClientUnixDomainSocket can successfully connect to.
// Builds a TestUnixSocketConnection and requires both the server-side and the
// client-side setup steps to report success; the client step only returns true
// once the acceptor has produced a non-negative server descriptor.
// NOTE(review): the closing brace of this test is not visible in this listing.
151 TEST(UnixDomainSocketUtil
, Connect
) {
152 TestUnixSocketConnection connection
;
153 ASSERT_TRUE(connection
.CreateServerSocket());
154 ASSERT_TRUE(connection
.CreateClientSocket());
157 // Ensure that messages can be sent across the resulting socket.
// Establishes a connection, sends a fixed message from the client descriptor,
// receives it on the server descriptor, and checks both the byte counts and
// the payload.
// NOTE(review): the declaration of |sent_bytes| and the closing brace are not
// visible in this listing; presumably |sent_bytes| receives the send() result
// the same way |received_bytes| receives the recv() result -- confirm.
158 TEST(UnixDomainSocketUtil
, SendReceive
) {
159 TestUnixSocketConnection connection
;
160 ASSERT_TRUE(connection
.CreateServerSocket());
161 ASSERT_TRUE(connection
.CreateClientSocket());
// buf_len is sizeof(buffer), so the terminating NUL is part of the payload.
163 const char buffer
[] = "Hello, server!";
164 size_t buf_len
= sizeof(buffer
);
// HANDLE_EINTR wraps the call; presumably it retries on EINTR -- see
// base/posix/eintr_wrapper.h.
166 HANDLE_EINTR(send(connection
.client_fd(), buffer
, buf_len
, 0));
167 ASSERT_EQ(buf_len
, sent_bytes
);
168 char recv_buf
[sizeof(buffer
)];
// NOTE(review): the recv() result is stored in a size_t; a -1 error return
// would become a large unsigned value, which still fails the equality check
// below.
169 size_t received_bytes
=
170 HANDLE_EINTR(recv(connection
.server_fd(), recv_buf
, buf_len
, 0));
171 ASSERT_EQ(buf_len
, received_bytes
);
// The received bytes must match what was sent, byte for byte.
172 ASSERT_EQ(0, memcmp(recv_buf
, buffer
, buf_len
));