// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/path_service.h"
#include "base/posix/eintr_wrapper.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_restrictions.h"
#include "ipc/unix_domain_socket_util.h"
#include "testing/gtest/include/gtest/gtest.h"
19 class SocketAcceptor
: public base::MessageLoopForIO::Watcher
{
21 SocketAcceptor(int fd
, base::MessageLoopProxy
* target_thread
)
23 target_thread_(target_thread
),
24 started_watching_event_(false, false),
25 accepted_event_(false, false) {
26 target_thread
->PostTask(FROM_HERE
,
27 base::Bind(&SocketAcceptor::StartWatching
, base::Unretained(this), fd
));
30 ~SocketAcceptor() override
{
34 int server_fd() const { return server_fd_
; }
36 void WaitUntilReady() {
37 started_watching_event_
.Wait();
40 void WaitForAccept() {
41 accepted_event_
.Wait();
46 target_thread_
->PostTask(FROM_HERE
,
47 base::Bind(&SocketAcceptor::StopWatching
, base::Unretained(this),
53 void StartWatching(int fd
) {
54 watcher_
.reset(new base::MessageLoopForIO::FileDescriptorWatcher
);
55 base::MessageLoopForIO::current()->WatchFileDescriptor(
56 fd
, true, base::MessageLoopForIO::WATCH_READ
, watcher_
.get(), this);
57 started_watching_event_
.Signal();
59 void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher
* watcher
) {
60 watcher
->StopWatchingFileDescriptor();
63 void OnFileCanReadWithoutBlocking(int fd
) override
{
64 ASSERT_EQ(-1, server_fd_
);
65 IPC::ServerAcceptConnection(fd
, &server_fd_
);
66 watcher_
->StopWatchingFileDescriptor();
67 accepted_event_
.Signal();
69 void OnFileCanWriteWithoutBlocking(int fd
) override
{}
72 base::MessageLoopProxy
* target_thread_
;
73 scoped_ptr
<base::MessageLoopForIO::FileDescriptorWatcher
> watcher_
;
74 base::WaitableEvent started_watching_event_
;
75 base::WaitableEvent accepted_event_
;
77 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor
);
80 const base::FilePath
GetChannelDir() {
81 base::FilePath tmp_dir
;
82 PathService::Get(base::DIR_TEMP
, &tmp_dir
);
86 class TestUnixSocketConnection
{
88 TestUnixSocketConnection()
89 : worker_("WorkerThread"),
90 server_listen_fd_(-1),
93 socket_name_
= GetChannelDir().Append("TestSocket");
94 base::Thread::Options options
;
95 options
.message_loop_type
= base::MessageLoop::TYPE_IO
;
96 worker_
.StartWithOptions(options
);
99 bool CreateServerSocket() {
100 IPC::CreateServerUnixDomainSocket(socket_name_
, &server_listen_fd_
);
101 if (server_listen_fd_
< 0)
103 struct stat socket_stat
;
104 stat(socket_name_
.value().c_str(), &socket_stat
);
105 EXPECT_TRUE(S_ISSOCK(socket_stat
.st_mode
));
106 acceptor_
.reset(new SocketAcceptor(server_listen_fd_
,
107 worker_
.message_loop_proxy().get()));
108 acceptor_
->WaitUntilReady();
112 bool CreateClientSocket() {
113 DCHECK(server_listen_fd_
>= 0);
114 IPC::CreateClientUnixDomainSocket(socket_name_
, &client_fd_
);
117 acceptor_
->WaitForAccept();
118 server_fd_
= acceptor_
->server_fd();
119 return server_fd_
>= 0;
122 virtual ~TestUnixSocketConnection() {
127 if (server_listen_fd_
>= 0) {
128 close(server_listen_fd_
);
129 unlink(socket_name_
.value().c_str());
133 int client_fd() const { return client_fd_
; }
134 int server_fd() const { return server_fd_
; }
137 base::Thread worker_
;
138 base::FilePath socket_name_
;
139 int server_listen_fd_
;
142 scoped_ptr
<SocketAcceptor
> acceptor_
;
145 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
146 // IPC::CreateClientUnixDomainSocket can successfully connect to.
147 TEST(UnixDomainSocketUtil
, Connect
) {
148 TestUnixSocketConnection connection
;
149 ASSERT_TRUE(connection
.CreateServerSocket());
150 ASSERT_TRUE(connection
.CreateClientSocket());
153 // Ensure that messages can be sent across the resulting socket.
// Sends a fixed string from the client descriptor and checks that the
// accepted server descriptor receives exactly the same bytes.
154 TEST(UnixDomainSocketUtil
, SendReceive
) {
155 TestUnixSocketConnection connection
;
156 ASSERT_TRUE(connection
.CreateServerSocket());
157 ASSERT_TRUE(connection
.CreateClientSocket());
// sizeof(buffer) counts the trailing '\0', so the terminator is sent too.
159 const char buffer
[] = "Hello, server!";
160 size_t buf_len
= sizeof(buffer
);
// NOTE(review): `sent_bytes` is compared below, but no declaration of it is
// visible here. This send() call appears to have lost its
// `... sent_bytes =` line. Confirm and restore it (presumably ssize_t, which
// also makes the comparison with size_t `buf_len` signed/unsigned).
162 HANDLE_EINTR(send(connection
.client_fd(), buffer
, buf_len
, 0));
163 ASSERT_EQ(buf_len
, sent_bytes
);
164 char recv_buf
[sizeof(buffer
)];
// NOTE(review): recv() returns ssize_t. Storing it in size_t turns an error
// (-1) into a huge value. The test also assumes a single recv() delivers the
// whole message. Confirm that is acceptable here.
165 size_t received_bytes
=
166 HANDLE_EINTR(recv(connection
.server_fd(), recv_buf
, buf_len
, 0));
167 ASSERT_EQ(buf_len
, received_bytes
);
168 ASSERT_EQ(0, memcmp(recv_buf
, buffer
, buf_len
));