libgo: update to Go 1.11
[official-gcc.git] / libgo / go / net / splice_test.go
blob44a5c00ba872e825c22df97c801a5c9dd875be56
1 // Copyright 2018 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 // +build linux
7 package net
9 import (
10 "bytes"
11 "fmt"
12 "io"
13 "io/ioutil"
14 "sync"
15 "testing"
18 func TestSplice(t *testing.T) {
19 t.Run("simple", testSpliceSimple)
20 t.Run("multipleWrite", testSpliceMultipleWrite)
21 t.Run("big", testSpliceBig)
22 t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader)
23 t.Run("readerAtEOF", testSpliceReaderAtEOF)
24 t.Run("issue25985", testSpliceIssue25985)
27 func testSpliceSimple(t *testing.T) {
28 srv, err := newSpliceTestServer()
29 if err != nil {
30 t.Fatal(err)
32 defer srv.Close()
33 copyDone := srv.Copy()
34 msg := []byte("splice test")
35 if _, err := srv.Write(msg); err != nil {
36 t.Fatal(err)
38 got := make([]byte, len(msg))
39 if _, err := io.ReadFull(srv, got); err != nil {
40 t.Fatal(err)
42 if !bytes.Equal(got, msg) {
43 t.Errorf("got %q, wrote %q", got, msg)
45 srv.CloseWrite()
46 srv.CloseRead()
47 if err := <-copyDone; err != nil {
48 t.Errorf("splice: %v", err)
52 func testSpliceMultipleWrite(t *testing.T) {
53 srv, err := newSpliceTestServer()
54 if err != nil {
55 t.Fatal(err)
57 defer srv.Close()
58 copyDone := srv.Copy()
59 msg1 := []byte("splice test part 1 ")
60 msg2 := []byte(" splice test part 2")
61 if _, err := srv.Write(msg1); err != nil {
62 t.Fatalf("Write: %v", err)
64 if _, err := srv.Write(msg2); err != nil {
65 t.Fatal(err)
67 got := make([]byte, len(msg1)+len(msg2))
68 if _, err := io.ReadFull(srv, got); err != nil {
69 t.Fatal(err)
71 want := append(msg1, msg2...)
72 if !bytes.Equal(got, want) {
73 t.Errorf("got %q, wrote %q", got, want)
75 srv.CloseWrite()
76 srv.CloseRead()
77 if err := <-copyDone; err != nil {
78 t.Errorf("splice: %v", err)
82 func testSpliceBig(t *testing.T) {
83 // The maximum amount of data that internal/poll.Splice will use in a
84 // splice(2) call is 4 << 20. Use a bigger size here so that we test an
85 // amount that doesn't fit in a single call.
86 size := 5 << 20
87 srv, err := newSpliceTestServer()
88 if err != nil {
89 t.Fatal(err)
91 defer srv.Close()
92 big := make([]byte, size)
93 copyDone := srv.Copy()
94 type readResult struct {
95 b []byte
96 err error
98 readDone := make(chan readResult)
99 go func() {
100 got := make([]byte, len(big))
101 _, err := io.ReadFull(srv, got)
102 readDone <- readResult{got, err}
104 if _, err := srv.Write(big); err != nil {
105 t.Fatal(err)
107 res := <-readDone
108 if res.err != nil {
109 t.Fatal(res.err)
111 got := res.b
112 if !bytes.Equal(got, big) {
113 t.Errorf("input and output differ")
115 srv.CloseWrite()
116 srv.CloseRead()
117 if err := <-copyDone; err != nil {
118 t.Errorf("splice: %v", err)
122 func testSpliceHonorsLimitedReader(t *testing.T) {
123 t.Run("stopsAfterN", testSpliceStopsAfterN)
124 t.Run("updatesN", testSpliceUpdatesN)
127 func testSpliceStopsAfterN(t *testing.T) {
128 clientUp, serverUp, err := spliceTestSocketPair("tcp")
129 if err != nil {
130 t.Fatal(err)
132 defer clientUp.Close()
133 defer serverUp.Close()
134 clientDown, serverDown, err := spliceTestSocketPair("tcp")
135 if err != nil {
136 t.Fatal(err)
138 defer clientDown.Close()
139 defer serverDown.Close()
140 count := 128
141 copyDone := make(chan error)
142 lr := &io.LimitedReader{
143 N: int64(count),
144 R: serverUp,
146 go func() {
147 _, err := io.Copy(serverDown, lr)
148 serverDown.Close()
149 copyDone <- err
151 msg := make([]byte, 2*count)
152 if _, err := clientUp.Write(msg); err != nil {
153 t.Fatal(err)
155 clientUp.Close()
156 var buf bytes.Buffer
157 if _, err := io.Copy(&buf, clientDown); err != nil {
158 t.Fatal(err)
160 if buf.Len() != count {
161 t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count)
163 clientDown.Close()
164 if err := <-copyDone; err != nil {
165 t.Errorf("splice: %v", err)
169 func testSpliceUpdatesN(t *testing.T) {
170 clientUp, serverUp, err := spliceTestSocketPair("tcp")
171 if err != nil {
172 t.Fatal(err)
174 defer clientUp.Close()
175 defer serverUp.Close()
176 clientDown, serverDown, err := spliceTestSocketPair("tcp")
177 if err != nil {
178 t.Fatal(err)
180 defer clientDown.Close()
181 defer serverDown.Close()
182 count := 128
183 copyDone := make(chan error)
184 lr := &io.LimitedReader{
185 N: int64(100 + count),
186 R: serverUp,
188 go func() {
189 _, err := io.Copy(serverDown, lr)
190 copyDone <- err
192 msg := make([]byte, count)
193 if _, err := clientUp.Write(msg); err != nil {
194 t.Fatal(err)
196 clientUp.Close()
197 got := make([]byte, count)
198 if _, err := io.ReadFull(clientDown, got); err != nil {
199 t.Fatal(err)
201 clientDown.Close()
202 if err := <-copyDone; err != nil {
203 t.Errorf("splice: %v", err)
205 wantN := int64(100)
206 if lr.N != wantN {
207 t.Errorf("lr.N = %d, want %d", lr.N, wantN)
211 func testSpliceReaderAtEOF(t *testing.T) {
212 clientUp, serverUp, err := spliceTestSocketPair("tcp")
213 if err != nil {
214 t.Fatal(err)
216 defer clientUp.Close()
217 defer serverUp.Close()
218 clientDown, serverDown, err := spliceTestSocketPair("tcp")
219 if err != nil {
220 t.Fatal(err)
222 defer clientDown.Close()
223 defer serverDown.Close()
225 serverUp.Close()
226 _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
227 if !handled {
228 t.Errorf("closed connection: got err = %v, handled = %t, want handled = true", err, handled)
230 lr := &io.LimitedReader{
231 N: 0,
232 R: serverUp,
234 _, err, handled = splice(serverDown.(*TCPConn).fd, lr)
235 if !handled {
236 t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled)
240 func testSpliceIssue25985(t *testing.T) {
241 front, err := newLocalListener("tcp")
242 if err != nil {
243 t.Fatal(err)
245 defer front.Close()
246 back, err := newLocalListener("tcp")
247 if err != nil {
248 t.Fatal(err)
250 defer back.Close()
252 var wg sync.WaitGroup
253 wg.Add(2)
255 proxy := func() {
256 src, err := front.Accept()
257 if err != nil {
258 return
260 dst, err := Dial("tcp", back.Addr().String())
261 if err != nil {
262 return
264 defer dst.Close()
265 defer src.Close()
266 go func() {
267 io.Copy(src, dst)
268 wg.Done()
270 go func() {
271 io.Copy(dst, src)
272 wg.Done()
276 go proxy()
278 toFront, err := Dial("tcp", front.Addr().String())
279 if err != nil {
280 t.Fatal(err)
283 io.WriteString(toFront, "foo")
284 toFront.Close()
286 fromProxy, err := back.Accept()
287 if err != nil {
288 t.Fatal(err)
290 defer fromProxy.Close()
292 _, err = ioutil.ReadAll(fromProxy)
293 if err != nil {
294 t.Fatal(err)
297 wg.Wait()
300 func BenchmarkTCPReadFrom(b *testing.B) {
301 testHookUninstaller.Do(uninstallTestHooks)
303 var chunkSizes []int
304 for i := uint(10); i <= 20; i++ {
305 chunkSizes = append(chunkSizes, 1<<i)
307 // To benchmark the genericReadFrom code path, set this to false.
308 useSplice := true
309 for _, chunkSize := range chunkSizes {
310 b.Run(fmt.Sprint(chunkSize), func(b *testing.B) {
311 benchmarkSplice(b, chunkSize, useSplice)
316 func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) {
317 srv, err := newSpliceTestServer()
318 if err != nil {
319 b.Fatal(err)
321 defer srv.Close()
322 var copyDone <-chan error
323 if useSplice {
324 copyDone = srv.Copy()
325 } else {
326 copyDone = srv.CopyNoSplice()
328 chunk := make([]byte, chunkSize)
329 discardDone := make(chan struct{})
330 go func() {
331 for {
332 buf := make([]byte, chunkSize)
333 _, err := srv.Read(buf)
334 if err != nil {
335 break
338 discardDone <- struct{}{}
340 b.SetBytes(int64(chunkSize))
341 b.ResetTimer()
342 for i := 0; i < b.N; i++ {
343 srv.Write(chunk)
345 srv.CloseWrite()
346 <-copyDone
347 srv.CloseRead()
348 <-discardDone
// spliceTestServer bundles the four connection endpoints used by the
// splice tests: an upstream client/server pair and a downstream
// client/server pair.
//
// The field order matters: newSpliceTestServer fills the struct with a
// positional literal.
type spliceTestServer struct {
	// clientUp is the client end of the upstream connection; Write
	// sends to it.
	clientUp io.WriteCloser
	// clientDown is the client end of the downstream connection; Read
	// receives from it.
	clientDown io.ReadCloser
	// serverUp is the server end of the upstream connection; Copy
	// reads from it.
	serverUp io.ReadCloser
	// serverDown is the server end of the downstream connection; Copy
	// writes to it.
	serverDown io.WriteCloser
}
358 func newSpliceTestServer() (*spliceTestServer, error) {
359 // For now, both networks are hard-coded to TCP.
360 // If splice is enabled for non-tcp upstream connections,
361 // newSpliceTestServer will need to take a network parameter.
362 clientUp, serverUp, err := spliceTestSocketPair("tcp")
363 if err != nil {
364 return nil, err
366 clientDown, serverDown, err := spliceTestSocketPair("tcp")
367 if err != nil {
368 clientUp.Close()
369 serverUp.Close()
370 return nil, err
372 return &spliceTestServer{clientUp, clientDown, serverUp, serverDown}, nil
375 // Read reads from the downstream connection.
376 func (srv *spliceTestServer) Read(b []byte) (int, error) {
377 return srv.clientDown.Read(b)
380 // Write writes to the upstream connection.
381 func (srv *spliceTestServer) Write(b []byte) (int, error) {
382 return srv.clientUp.Write(b)
385 // Close closes the server.
386 func (srv *spliceTestServer) Close() error {
387 err := srv.closeUp()
388 err1 := srv.closeDown()
389 if err == nil {
390 return err1
392 return err
395 // CloseWrite closes the client side of the upstream connection.
396 func (srv *spliceTestServer) CloseWrite() error {
397 return srv.clientUp.Close()
400 // CloseRead closes the client side of the downstream connection.
401 func (srv *spliceTestServer) CloseRead() error {
402 return srv.clientDown.Close()
405 // Copy copies from the server side of the upstream connection
406 // to the server side of the downstream connection, in a separate
407 // goroutine. Copy is done when the first send on the returned
408 // channel succeeds.
409 func (srv *spliceTestServer) Copy() <-chan error {
410 ch := make(chan error)
411 go func() {
412 _, err := io.Copy(srv.serverDown, srv.serverUp)
413 ch <- err
414 close(ch)
416 return ch
419 // CopyNoSplice is like Copy, but ensures that the splice code path
420 // is not reached.
421 func (srv *spliceTestServer) CopyNoSplice() <-chan error {
422 type onlyReader struct {
423 io.Reader
425 ch := make(chan error)
426 go func() {
427 _, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp})
428 ch <- err
429 close(ch)
431 return ch
434 func (srv *spliceTestServer) closeUp() error {
435 var err, err1 error
436 if srv.serverUp != nil {
437 err = srv.serverUp.Close()
439 if srv.clientUp != nil {
440 err1 = srv.clientUp.Close()
442 if err == nil {
443 return err1
445 return err
448 func (srv *spliceTestServer) closeDown() error {
449 var err, err1 error
450 if srv.serverDown != nil {
451 err = srv.serverDown.Close()
453 if srv.clientDown != nil {
454 err1 = srv.clientDown.Close()
456 if err == nil {
457 return err1
459 return err
462 func spliceTestSocketPair(net string) (client, server Conn, err error) {
463 ln, err := newLocalListener(net)
464 if err != nil {
465 return nil, nil, err
467 defer ln.Close()
468 var cerr, serr error
469 acceptDone := make(chan struct{})
470 go func() {
471 server, serr = ln.Accept()
472 acceptDone <- struct{}{}
474 client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
475 <-acceptDone
476 if cerr != nil {
477 if server != nil {
478 server.Close()
480 return nil, nil, cerr
482 if serr != nil {
483 if client != nil {
484 client.Close()
486 return nil, nil, serr
488 return client, server, nil