1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
25 serverAddr
, newServerAddr
string
27 once
, newOnce
, httpOnce sync
.Once
44 // Some of Arith's methods have value args, some have pointer args. That's deliberate.
46 func (t
*Arith
) Add(args Args
, reply
*Reply
) error
{
47 reply
.C
= args
.A
+ args
.B
51 func (t
*Arith
) Mul(args
*Args
, reply
*Reply
) error
{
52 reply
.C
= args
.A
* args
.B
56 func (t
*Arith
) Div(args Args
, reply
*Reply
) error
{
58 return errors
.New("divide by zero")
60 reply
.C
= args
.A
/ args
.B
64 func (t
*Arith
) String(args
*Args
, reply
*string) error
{
65 *reply
= fmt
.Sprintf("%d+%d=%d", args
.A
, args
.B
, args
.A
+args
.B
)
69 func (t
*Arith
) Scan(args
string, reply
*Reply
) (err error
) {
70 _
, err
= fmt
.Sscan(args
, &reply
.C
)
74 func (t
*Arith
) Error(args
*Args
, reply
*Reply
) error
{
78 func (t
*Arith
) SleepMilli(args
*Args
, reply
*Reply
) error
{
79 time
.Sleep(time
.Duration(args
.A
) * time
.Millisecond
)
85 func (t
*hidden
) Exported(args Args
, reply
*Reply
) error
{
86 reply
.C
= args
.A
+ args
.B
94 type BuiltinTypes
struct{}
96 func (BuiltinTypes
) Map(args
*Args
, reply
*map[int]int) error
{
97 (*reply
)[args
.A
] = args
.B
101 func (BuiltinTypes
) Slice(args
*Args
, reply
*[]int) error
{
102 *reply
= append(*reply
, args
.A
, args
.B
)
106 func (BuiltinTypes
) Array(args
*Args
, reply
*[2]int) error
{
112 func listenTCP() (net
.Listener
, string) {
113 l
, e
:= net
.Listen("tcp", "127.0.0.1:0") // any available address
115 log
.Fatalf("net.Listen tcp :0: %v", e
)
117 return l
, l
.Addr().String()
123 RegisterName("net.rpc.Arith", new(Arith
))
124 Register(BuiltinTypes
{})
127 l
, serverAddr
= listenTCP()
128 log
.Println("Test RPC server listening on", serverAddr
)
132 httpOnce
.Do(startHttpServer
)
135 func startNewServer() {
136 newServer
= NewServer()
137 newServer
.Register(new(Arith
))
138 newServer
.Register(new(Embed
))
139 newServer
.RegisterName("net.rpc.Arith", new(Arith
))
140 newServer
.RegisterName("newServer.Arith", new(Arith
))
143 l
, newServerAddr
= listenTCP()
144 log
.Println("NewServer test RPC server listening on", newServerAddr
)
145 go newServer
.Accept(l
)
147 newServer
.HandleHTTP(newHttpPath
, "/bar")
148 httpOnce
.Do(startHttpServer
)
151 func startHttpServer() {
152 server
:= httptest
.NewServer(nil)
153 httpServerAddr
= server
.Listener
.Addr().String()
154 log
.Println("Test HTTP RPC server listening on", httpServerAddr
)
157 func TestRPC(t
*testing
.T
) {
159 testRPC(t
, serverAddr
)
160 newOnce
.Do(startNewServer
)
161 testRPC(t
, newServerAddr
)
162 testNewServerRPC(t
, newServerAddr
)
165 func testRPC(t
*testing
.T
, addr
string) {
166 client
, err
:= Dial("tcp", addr
)
168 t
.Fatal("dialing", err
)
175 err
= client
.Call("Arith.Add", args
, reply
)
177 t
.Errorf("Add: expected no error but got string %q", err
.Error())
179 if reply
.C
!= args
.A
+args
.B
{
180 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
183 // Methods exported from unexported embedded structs
186 err
= client
.Call("Embed.Exported", args
, reply
)
188 t
.Errorf("Add: expected no error but got string %q", err
.Error())
190 if reply
.C
!= args
.A
+args
.B
{
191 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
194 // Nonexistent method
197 err
= client
.Call("Arith.BadOperation", args
, reply
)
200 t
.Error("BadOperation: expected error")
201 } else if !strings
.HasPrefix(err
.Error(), "rpc: can't find method ") {
202 t
.Errorf("BadOperation: expected can't find method error; got %q", err
)
208 err
= client
.Call("Arith.Unknown", args
, reply
)
210 t
.Error("expected error calling unknown service")
211 } else if !strings
.Contains(err
.Error(), "method") {
212 t
.Error("expected error about method; got", err
)
217 mulReply
:= new(Reply
)
218 mulCall
:= client
.Go("Arith.Mul", args
, mulReply
, nil)
219 addReply
:= new(Reply
)
220 addCall
:= client
.Go("Arith.Add", args
, addReply
, nil)
222 addCall
= <-addCall
.Done
223 if addCall
.Error
!= nil {
224 t
.Errorf("Add: expected no error but got string %q", addCall
.Error
.Error())
226 if addReply
.C
!= args
.A
+args
.B
{
227 t
.Errorf("Add: expected %d got %d", addReply
.C
, args
.A
+args
.B
)
230 mulCall
= <-mulCall
.Done
231 if mulCall
.Error
!= nil {
232 t
.Errorf("Mul: expected no error but got string %q", mulCall
.Error
.Error())
234 if mulReply
.C
!= args
.A
*args
.B
{
235 t
.Errorf("Mul: expected %d got %d", mulReply
.C
, args
.A
*args
.B
)
241 err
= client
.Call("Arith.Div", args
, reply
)
242 // expect an error: zero divide
244 t
.Error("Div: expected error")
245 } else if err
.Error() != "divide by zero" {
246 t
.Error("Div: expected divide by zero error; got", err
)
251 err
= client
.Call("Arith.Add", reply
, reply
) // args, reply would be the correct thing to use
253 t
.Error("expected error calling Arith.Add with wrong arg type")
254 } else if !strings
.Contains(err
.Error(), "type") {
255 t
.Error("expected error about type; got", err
)
258 // Non-struct argument
260 str
:= fmt
.Sprint(Val
)
262 err
= client
.Call("Arith.Scan", &str
, reply
)
264 t
.Errorf("Scan: expected no error but got string %q", err
.Error())
265 } else if reply
.C
!= Val
{
266 t
.Errorf("Scan: expected %d got %d", Val
, reply
.C
)
272 err
= client
.Call("Arith.String", args
, &str
)
274 t
.Errorf("String: expected no error but got string %q", err
.Error())
276 expect
:= fmt
.Sprintf("%d+%d=%d", args
.A
, args
.B
, args
.A
+args
.B
)
278 t
.Errorf("String: expected %s got %s", expect
, str
)
283 err
= client
.Call("Arith.Mul", args
, reply
)
285 t
.Errorf("Mul: expected no error but got string %q", err
.Error())
287 if reply
.C
!= args
.A
*args
.B
{
288 t
.Errorf("Mul: expected %d got %d", reply
.C
, args
.A
*args
.B
)
291 // ServiceName contain "." character
294 err
= client
.Call("net.rpc.Arith.Add", args
, reply
)
296 t
.Errorf("Add: expected no error but got string %q", err
.Error())
298 if reply
.C
!= args
.A
+args
.B
{
299 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
303 func testNewServerRPC(t
*testing
.T
, addr
string) {
304 client
, err
:= Dial("tcp", addr
)
306 t
.Fatal("dialing", err
)
313 err
= client
.Call("newServer.Arith.Add", args
, reply
)
315 t
.Errorf("Add: expected no error but got string %q", err
.Error())
317 if reply
.C
!= args
.A
+args
.B
{
318 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
322 func TestHTTP(t
*testing
.T
) {
325 newOnce
.Do(startNewServer
)
326 testHTTPRPC(t
, newHttpPath
)
329 func testHTTPRPC(t
*testing
.T
, path
string) {
333 client
, err
= DialHTTP("tcp", httpServerAddr
)
335 client
, err
= DialHTTPPath("tcp", httpServerAddr
, path
)
338 t
.Fatal("dialing", err
)
345 err
= client
.Call("Arith.Add", args
, reply
)
347 t
.Errorf("Add: expected no error but got string %q", err
.Error())
349 if reply
.C
!= args
.A
+args
.B
{
350 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
354 func TestBuiltinTypes(t
*testing
.T
) {
357 client
, err
:= DialHTTP("tcp", httpServerAddr
)
359 t
.Fatal("dialing", err
)
365 replyMap
:= map[int]int{}
366 err
= client
.Call("BuiltinTypes.Map", args
, &replyMap
)
368 t
.Errorf("Map: expected no error but got string %q", err
.Error())
370 if replyMap
[args
.A
] != args
.B
{
371 t
.Errorf("Map: expected %d got %d", args
.B
, replyMap
[args
.A
])
376 replySlice
:= []int{}
377 err
= client
.Call("BuiltinTypes.Slice", args
, &replySlice
)
379 t
.Errorf("Slice: expected no error but got string %q", err
.Error())
381 if e
:= []int{args
.A
, args
.B
}; !reflect
.DeepEqual(replySlice
, e
) {
382 t
.Errorf("Slice: expected %v got %v", e
, replySlice
)
387 replyArray
:= [2]int{}
388 err
= client
.Call("BuiltinTypes.Array", args
, &replyArray
)
390 t
.Errorf("Array: expected no error but got string %q", err
.Error())
392 if e
:= [2]int{args
.A
, args
.B
}; !reflect
.DeepEqual(replyArray
, e
) {
393 t
.Errorf("Array: expected %v got %v", e
, replyArray
)
397 // CodecEmulator provides a client-like api and a ServerCodec interface.
398 // Can be used to test ServeRequest.
399 type CodecEmulator
struct {
407 func (codec
*CodecEmulator
) Call(serviceMethod
string, args
*Args
, reply
*Reply
) error
{
408 codec
.serviceMethod
= serviceMethod
412 var serverError error
413 if codec
.server
== nil {
414 serverError
= ServeRequest(codec
)
416 serverError
= codec
.server
.ServeRequest(codec
)
418 if codec
.err
== nil && serverError
!= nil {
419 codec
.err
= serverError
424 func (codec
*CodecEmulator
) ReadRequestHeader(req
*Request
) error
{
425 req
.ServiceMethod
= codec
.serviceMethod
430 func (codec
*CodecEmulator
) ReadRequestBody(argv
interface{}) error
{
431 if codec
.args
== nil {
432 return io
.ErrUnexpectedEOF
434 *(argv
.(*Args
)) = *codec
.args
438 func (codec
*CodecEmulator
) WriteResponse(resp
*Response
, reply
interface{}) error
{
439 if resp
.Error
!= "" {
440 codec
.err
= errors
.New(resp
.Error
)
442 *codec
.reply
= *(reply
.(*Reply
))
447 func (codec
*CodecEmulator
) Close() error
{
451 func TestServeRequest(t
*testing
.T
) {
453 testServeRequest(t
, nil)
454 newOnce
.Do(startNewServer
)
455 testServeRequest(t
, newServer
)
458 func testServeRequest(t
*testing
.T
, server
*Server
) {
459 client
:= CodecEmulator
{server
: server
}
464 err
:= client
.Call("Arith.Add", args
, reply
)
466 t
.Errorf("Add: expected no error but got string %q", err
.Error())
468 if reply
.C
!= args
.A
+args
.B
{
469 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
472 err
= client
.Call("Arith.Add", nil, reply
)
474 t
.Errorf("expected error calling Arith.Add with nil arg")
478 type ReplyNotPointer
int
479 type ArgNotPublic
int
480 type ReplyNotPublic
int
481 type NeedsPtrType
int
484 func (t
*ReplyNotPointer
) ReplyNotPointer(args
*Args
, reply Reply
) error
{
488 func (t
*ArgNotPublic
) ArgNotPublic(args
*local
, reply
*Reply
) error
{
492 func (t
*ReplyNotPublic
) ReplyNotPublic(args
*Args
, reply
*local
) error
{
496 func (t
*NeedsPtrType
) NeedsPtrType(args
*Args
, reply
*Reply
) error
{
500 // Check that registration handles lots of bad methods and a type with no suitable methods.
501 func TestRegistrationError(t
*testing
.T
) {
502 err
:= Register(new(ReplyNotPointer
))
504 t
.Error("expected error registering ReplyNotPointer")
506 err
= Register(new(ArgNotPublic
))
508 t
.Error("expected error registering ArgNotPublic")
510 err
= Register(new(ReplyNotPublic
))
512 t
.Error("expected error registering ReplyNotPublic")
514 err
= Register(NeedsPtrType(0))
516 t
.Error("expected error registering NeedsPtrType")
517 } else if !strings
.Contains(err
.Error(), "pointer") {
518 t
.Error("expected hint when registering NeedsPtrType")
522 type WriteFailCodec
int
524 func (WriteFailCodec
) WriteRequest(*Request
, interface{}) error
{
525 // the panic caused by this error used to not unlock a lock.
526 return errors
.New("fail")
529 func (WriteFailCodec
) ReadResponseHeader(*Response
) error
{
533 func (WriteFailCodec
) ReadResponseBody(interface{}) error
{
537 func (WriteFailCodec
) Close() error
{
541 func TestSendDeadlock(t
*testing
.T
) {
542 client
:= NewClientWithCodec(WriteFailCodec(0))
545 done
:= make(chan bool)
547 testSendDeadlock(client
)
548 testSendDeadlock(client
)
554 case <-time
.After(5 * time
.Second
):
559 func testSendDeadlock(client
*Client
) {
565 client
.Call("Arith.Add", args
, reply
)
568 func dialDirect() (*Client
, error
) {
569 return Dial("tcp", serverAddr
)
572 func dialHTTP() (*Client
, error
) {
573 return DialHTTP("tcp", httpServerAddr
)
576 func countMallocs(dial
func() (*Client
, error
), t
*testing
.T
) float64 {
578 client
, err
:= dial()
580 t
.Fatal("error dialing", err
)
586 return testing
.AllocsPerRun(100, func() {
587 err
:= client
.Call("Arith.Add", args
, reply
)
589 t
.Errorf("Add: expected no error but got string %q", err
.Error())
591 if reply
.C
!= args
.A
+args
.B
{
592 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
597 func TestCountMallocs(t
*testing
.T
) {
599 t
.Skip("skipping malloc count in short mode")
601 if runtime
.GOMAXPROCS(0) > 1 {
602 t
.Skip("skipping; GOMAXPROCS>1")
604 fmt
.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect
, t
))
607 func TestCountMallocsOverHTTP(t
*testing
.T
) {
609 t
.Skip("skipping malloc count in short mode")
611 if runtime
.GOMAXPROCS(0) > 1 {
612 t
.Skip("skipping; GOMAXPROCS>1")
614 fmt
.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP
, t
))
617 type writeCrasher
struct {
621 func (writeCrasher
) Close() error
{
625 func (w
*writeCrasher
) Read(p
[]byte) (int, error
) {
630 func (writeCrasher
) Write(p
[]byte) (int, error
) {
631 return 0, errors
.New("fake write failure")
634 func TestClientWriteError(t
*testing
.T
) {
635 w
:= &writeCrasher
{done
: make(chan bool)}
640 err
:= c
.Call("foo", 1, &res
)
642 t
.Fatal("expected error")
644 if err
.Error() != "fake write failure" {
645 t
.Error("unexpected value of error:", err
)
650 func TestTCPClose(t
*testing
.T
) {
653 client
, err
:= dialHTTP()
655 t
.Fatalf("dialing: %v", err
)
661 err
= client
.Call("Arith.Mul", args
, &reply
)
663 t
.Fatal("arith error:", err
)
665 t
.Logf("Arith: %d*%d=%d\n", args
.A
, args
.B
, reply
)
666 if reply
.C
!= args
.A
*args
.B
{
667 t
.Errorf("Add: expected %d got %d", reply
.C
, args
.A
*args
.B
)
671 func TestErrorAfterClientClose(t
*testing
.T
) {
674 client
, err
:= dialHTTP()
676 t
.Fatalf("dialing: %v", err
)
680 t
.Fatal("close error:", err
)
682 err
= client
.Call("Arith.Add", &Args
{7, 9}, new(Reply
))
683 if err
!= ErrShutdown
{
684 t
.Errorf("Forever: expected ErrShutdown got %v", err
)
688 // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
689 func TestAcceptExitAfterListenerClose(t
*testing
.T
) {
690 newServer
:= NewServer()
691 newServer
.Register(new(Arith
))
692 newServer
.RegisterName("net.rpc.Arith", new(Arith
))
693 newServer
.RegisterName("newServer.Arith", new(Arith
))
701 func TestShutdown(t
*testing
.T
) {
704 ch
:= make(chan net
.Conn
, 1)
713 c
, err
:= net
.Dial("tcp", l
.Addr().String())
722 newServer
:= NewServer()
723 newServer
.Register(new(Arith
))
724 go newServer
.ServeConn(c1
)
728 client
:= NewClient(c
)
729 err
= client
.Call("Arith.Add", args
, reply
)
734 // On an unloaded system 10ms is usually enough to fail 100% of the time
735 // with a broken server. On a loaded system, a broken server might incorrectly
736 // be reported as passing, but we're OK with that kind of flakiness.
737 // If the code is correct, this test will never fail, regardless of timeout.
739 done
:= make(chan *Call
, 1)
740 call
:= client
.Go("Arith.SleepMilli", args
, reply
, done
)
741 c
.(*net
.TCPConn
).CloseWrite()
743 if call
.Error
!= nil {
748 func benchmarkEndToEnd(dial
func() (*Client
, error
), b
*testing
.B
) {
750 client
, err
:= dial()
752 b
.Fatal("error dialing:", err
)
760 b
.RunParallel(func(pb
*testing
.PB
) {
763 err
:= client
.Call("Arith.Add", args
, reply
)
765 b
.Fatalf("rpc error: Add: expected no error but got string %q", err
.Error())
767 if reply
.C
!= args
.A
+args
.B
{
768 b
.Fatalf("rpc error: Add: expected %d got %d", reply
.C
, args
.A
+args
.B
)
774 func benchmarkEndToEndAsync(dial
func() (*Client
, error
), b
*testing
.B
) {
778 const MaxConcurrentCalls
= 100
780 client
, err
:= dial()
782 b
.Fatal("error dialing:", err
)
786 // Asynchronous calls
788 procs
:= 4 * runtime
.GOMAXPROCS(-1)
791 var wg sync
.WaitGroup
793 gate
:= make(chan bool, MaxConcurrentCalls
)
794 res
:= make(chan *Call
, MaxConcurrentCalls
)
797 for p
:= 0; p
< procs
; p
++ {
799 for atomic
.AddInt32(&send
, -1) >= 0 {
802 client
.Go("Arith.Add", args
, reply
, res
)
806 for call
:= range res
{
807 A
:= call
.Args
.(*Args
).A
808 B
:= call
.Args
.(*Args
).B
809 C
:= call
.Reply
.(*Reply
).C
811 b
.Errorf("incorrect reply: Add: expected %d got %d", A
+B
, C
)
815 if atomic
.AddInt32(&recv
, -1) == 0 {
825 func BenchmarkEndToEnd(b
*testing
.B
) {
826 benchmarkEndToEnd(dialDirect
, b
)
829 func BenchmarkEndToEndHTTP(b
*testing
.B
) {
830 benchmarkEndToEnd(dialHTTP
, b
)
833 func BenchmarkEndToEndAsync(b
*testing
.B
) {
834 benchmarkEndToEndAsync(dialDirect
, b
)
837 func BenchmarkEndToEndAsyncHTTP(b
*testing
.B
) {
838 benchmarkEndToEndAsync(dialHTTP
, b
)