summaryrefslogtreecommitdiff
path: root/libgo/go/rpc
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/rpc')
-rw-r--r--libgo/go/rpc/client.go250
-rw-r--r--libgo/go/rpc/debug.go90
-rw-r--r--libgo/go/rpc/jsonrpc/all_test.go156
-rw-r--r--libgo/go/rpc/jsonrpc/client.go121
-rw-r--r--libgo/go/rpc/jsonrpc/server.go133
-rw-r--r--libgo/go/rpc/server.go530
-rw-r--r--libgo/go/rpc/server_test.go384
7 files changed, 1664 insertions, 0 deletions
diff --git a/libgo/go/rpc/client.go b/libgo/go/rpc/client.go
new file mode 100644
index 000000000..601c49715
--- /dev/null
+++ b/libgo/go/rpc/client.go
@@ -0,0 +1,250 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "bufio"
+ "gob"
+ "http"
+ "io"
+ "log"
+ "net"
+ "os"
+ "sync"
+)
+
+// Call represents an active RPC.
+type Call struct {
+ ServiceMethod string // The name of the service and method to call.
+ Args interface{} // The argument to the function (*struct).
+ Reply interface{} // The reply from the function (*struct).
+ Error os.Error // After completion, the error status.
+ Done chan *Call // Strobes when call is complete; value is the error status.
+ seq uint64
+}
+
+// Client represents an RPC Client.
+// There may be multiple outstanding Calls associated
+// with a single Client.
+type Client struct {
+ mutex sync.Mutex // protects pending, seq
+ shutdown os.Error // non-nil if the client is shut down
+ sending sync.Mutex
+ seq uint64
+ codec ClientCodec
+ pending map[uint64]*Call
+ closing bool
+}
+
+// A ClientCodec implements writing of RPC requests and
+// reading of RPC responses for the client side of an RPC session.
+// The client calls WriteRequest to write a request to the connection
+// and calls ReadResponseHeader and ReadResponseBody in pairs
+// to read responses. The client calls Close when finished with the
+// connection.
+type ClientCodec interface {
+ WriteRequest(*Request, interface{}) os.Error
+ ReadResponseHeader(*Response) os.Error
+ ReadResponseBody(interface{}) os.Error
+
+ Close() os.Error
+}
+
+func (client *Client) send(c *Call) {
+ // Register this call.
+ client.mutex.Lock()
+ if client.shutdown != nil {
+ c.Error = client.shutdown
+ client.mutex.Unlock()
+ _ = c.Done <- c // do not block
+ return
+ }
+ c.seq = client.seq
+ client.seq++
+ client.pending[c.seq] = c
+ client.mutex.Unlock()
+
+ // Encode and send the request.
+ request := new(Request)
+ client.sending.Lock()
+ defer client.sending.Unlock()
+ request.Seq = c.seq
+ request.ServiceMethod = c.ServiceMethod
+ if err := client.codec.WriteRequest(request, c.Args); err != nil {
+ panic("rpc: client encode error: " + err.String())
+ }
+}
+
+func (client *Client) input() {
+ var err os.Error
+ for err == nil {
+ response := new(Response)
+ err = client.codec.ReadResponseHeader(response)
+ if err != nil {
+ if err == os.EOF && !client.closing {
+ err = io.ErrUnexpectedEOF
+ }
+ break
+ }
+ seq := response.Seq
+ client.mutex.Lock()
+ c := client.pending[seq]
+ client.pending[seq] = c, false
+ client.mutex.Unlock()
+ err = client.codec.ReadResponseBody(c.Reply)
+ if response.Error != "" {
+ c.Error = os.ErrorString(response.Error)
+ } else if err != nil {
+ c.Error = err
+ } else {
+ // Empty strings should turn into nil os.Errors
+ c.Error = nil
+ }
+ // We don't want to block here. It is the caller's responsibility to make
+ // sure the channel has enough buffer space. See comment in Go().
+ _ = c.Done <- c // do not block
+ }
+ // Terminate pending calls.
+ client.mutex.Lock()
+ client.shutdown = err
+ for _, call := range client.pending {
+ call.Error = err
+ _ = call.Done <- call // do not block
+ }
+ client.mutex.Unlock()
+ if err != os.EOF || !client.closing {
+ log.Println("rpc: client protocol error:", err)
+ }
+}
+
+// NewClient returns a new Client to handle requests to the
+// set of services at the other end of the connection.
+func NewClient(conn io.ReadWriteCloser) *Client {
+ return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+}
+
+// NewClientWithCodec is like NewClient but uses the specified
+// codec to encode requests and decode responses.
+func NewClientWithCodec(codec ClientCodec) *Client {
+ client := &Client{
+ codec: codec,
+ pending: make(map[uint64]*Call),
+ }
+ go client.input()
+ return client
+}
+
+type gobClientCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+}
+
+func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error {
+ if err := c.enc.Encode(r); err != nil {
+ return err
+ }
+ return c.enc.Encode(body)
+}
+
+func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobClientCodec) Close() os.Error {
+ return c.rwc.Close()
+}
+
+
+// DialHTTP connects to an HTTP RPC server at the specified network address
+// listening on the default HTTP RPC path.
+func DialHTTP(network, address string) (*Client, os.Error) {
+ return DialHTTPPath(network, address, DefaultRPCPath)
+}
+
+// DialHTTPPath connects to an HTTP RPC server
+// at the specified network address and path.
+func DialHTTPPath(network, address, path string) (*Client, os.Error) {
+ var err os.Error
+ conn, err := net.Dial(network, "", address)
+ if err != nil {
+ return nil, err
+ }
+ io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
+
+ // Require successful HTTP response
+ // before switching to RPC protocol.
+ resp, err := http.ReadResponse(bufio.NewReader(conn), "CONNECT")
+ if err == nil && resp.Status == connected {
+ return NewClient(conn), nil
+ }
+ if err == nil {
+ err = os.ErrorString("unexpected HTTP response: " + resp.Status)
+ }
+ conn.Close()
+ return nil, &net.OpError{"dial-http", network + " " + address, nil, err}
+}
+
+// Dial connects to an RPC server at the specified network address.
+func Dial(network, address string) (*Client, os.Error) {
+ conn, err := net.Dial(network, "", address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), nil
+}
+
+func (client *Client) Close() os.Error {
+ if client.shutdown != nil || client.closing {
+ return os.ErrorString("rpc: already closed")
+ }
+ client.mutex.Lock()
+ client.closing = true
+ client.mutex.Unlock()
+ return client.codec.Close()
+}
+
+// Go invokes the function asynchronously. It returns the Call structure representing
+// the invocation. The done channel will signal when the call is complete by returning
+// the same Call object. If done is nil, Go will allocate a new channel.
+// If non-nil, done must be buffered or Go will deliberately crash.
+func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
+ c := new(Call)
+ c.ServiceMethod = serviceMethod
+ c.Args = args
+ c.Reply = reply
+ if done == nil {
+ done = make(chan *Call, 10) // buffered.
+ } else {
+ // If caller passes done != nil, it must arrange that
+ // done has enough buffer for the number of simultaneous
+ // RPCs that will be using that channel. If the channel
+ // is totally unbuffered, it's best not to run at all.
+ if cap(done) == 0 {
+ log.Panic("rpc: done channel is unbuffered")
+ }
+ }
+ c.Done = done
+ if client.shutdown != nil {
+ c.Error = client.shutdown
+ _ = c.Done <- c // do not block
+ return c
+ }
+ client.send(c)
+ return c
+}
+
+// Call invokes the named function, waits for it to complete, and returns its error status.
+func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
+ if client.shutdown != nil {
+ return client.shutdown
+ }
+ call := <-client.Go(serviceMethod, args, reply, nil).Done
+ return call.Error
+}
diff --git a/libgo/go/rpc/debug.go b/libgo/go/rpc/debug.go
new file mode 100644
index 000000000..44b32e04b
--- /dev/null
+++ b/libgo/go/rpc/debug.go
@@ -0,0 +1,90 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+/*
+ Some HTML presented at http://machine:port/debug/rpc
+ Lists services, their methods, and some statistics, still rudimentary.
+*/
+
+import (
+ "fmt"
+ "http"
+ "sort"
+ "template"
+)
+
+const debugText = `<html>
+ <body>
+ <title>Services</title>
+ {.repeated section @}
+ <hr>
+ Service {Name}
+ <hr>
+ <table>
+ <th align=center>Method</th><th align=center>Calls</th>
+ {.repeated section Method}
+ <tr>
+ <td align=left font=fixed>{Name}({Type.ArgType}, {Type.ReplyType}) os.Error</td>
+ <td align=center>{Type.NumCalls}</td>
+ </tr>
+ {.end}
+ </table>
+ {.end}
+ </body>
+ </html>`
+
+var debug = template.MustParse(debugText, nil)
+
+type debugMethod struct {
+ Type *methodType
+ Name string
+}
+
+type methodArray []debugMethod
+
+type debugService struct {
+ Service *service
+ Name string
+ Method methodArray
+}
+
+type serviceArray []debugService
+
+func (s serviceArray) Len() int { return len(s) }
+func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
+func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func (m methodArray) Len() int { return len(m) }
+func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
+func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+type debugHTTP struct {
+ *Server
+}
+
+// Runs at /debug/rpc
+func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ // Build a sorted version of the data.
+ var services = make(serviceArray, len(server.serviceMap))
+ i := 0
+ server.Lock()
+ for sname, service := range server.serviceMap {
+ services[i] = debugService{service, sname, make(methodArray, len(service.method))}
+ j := 0
+ for mname, method := range service.method {
+ services[i].Method[j] = debugMethod{method, mname}
+ j++
+ }
+ sort.Sort(services[i].Method)
+ i++
+ }
+ server.Unlock()
+ sort.Sort(services)
+ err := debug.Execute(services, w)
+ if err != nil {
+ fmt.Fprintln(w, "rpc: error executing template:", err.String())
+ }
+}
diff --git a/libgo/go/rpc/jsonrpc/all_test.go b/libgo/go/rpc/jsonrpc/all_test.go
new file mode 100644
index 000000000..764ee7ff3
--- /dev/null
+++ b/libgo/go/rpc/jsonrpc/all_test.go
@@ -0,0 +1,156 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "fmt"
+ "json"
+ "net"
+ "os"
+ "rpc"
+ "testing"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+func (t *Arith) Add(args *Args, reply *Reply) os.Error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) os.Error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args *Args, reply *Reply) os.Error {
+ if args.B == 0 {
+ return os.ErrorString("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) os.Error {
+ panic("ERROR")
+}
+
+func init() {
+ rpc.Register(new(Arith))
+}
+
+func TestServer(t *testing.T) {
+ type addResp struct {
+ Id interface{} "id"
+ Result Reply "result"
+ Error interface{} "error"
+ }
+
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ // Send hand-coded requests to server, parse responses.
+ for i := 0; i < 10; i++ {
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
+ var resp addResp
+ err := dec.Decode(&resp)
+ if err != nil {
+ t.Fatalf("Decode: %s", err)
+ }
+ if resp.Error != nil {
+ t.Fatalf("resp.Error: %s", resp.Error)
+ }
+ if resp.Id.(string) != string(i) {
+ t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i))
+ }
+ if resp.Result.C != 2*i+1 {
+ t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
+ }
+ }
+
+ fmt.Fprintf(cli, "{}\n")
+ var resp addResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestClient(t *testing.T) {
+ // Assume server is okay (TestServer is above).
+ // Test client against server.
+ cli, srv := net.Pipe()
+ go ServeConn(srv)
+
+ client := NewClient(cli)
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.String())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.String())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.String() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+}
diff --git a/libgo/go/rpc/jsonrpc/client.go b/libgo/go/rpc/jsonrpc/client.go
new file mode 100644
index 000000000..dcaa69f9d
--- /dev/null
+++ b/libgo/go/rpc/jsonrpc/client.go
@@ -0,0 +1,121 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec
+// for the rpc package.
+package jsonrpc
+
+import (
+ "fmt"
+ "io"
+ "json"
+ "net"
+ "os"
+ "rpc"
+ "sync"
+)
+
+type clientCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req clientRequest
+ resp clientResponse
+
+ // JSON-RPC responses include the request id but not the request method.
+ // Package rpc expects both.
+ // We save the request method in pending when sending a request
+ // and then look it up by request ID when filling out the rpc Response.
+ mutex sync.Mutex // protects pending
+ pending map[uint64]string // map request id to method name
+}
+
+// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
+func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
+ return &clientCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]string),
+ }
+}
+
+type clientRequest struct {
+ Method string "method"
+ Params [1]interface{} "params"
+ Id uint64 "id"
+}
+
+func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) os.Error {
+ c.mutex.Lock()
+ c.pending[r.Seq] = r.ServiceMethod
+ c.mutex.Unlock()
+ c.req.Method = r.ServiceMethod
+ c.req.Params[0] = param
+ c.req.Id = r.Seq
+ return c.enc.Encode(&c.req)
+}
+
+type clientResponse struct {
+ Id uint64 "id"
+ Result *json.RawMessage "result"
+ Error interface{} "error"
+}
+
+func (r *clientResponse) reset() {
+ r.Id = 0
+ r.Result = nil
+ r.Error = nil
+}
+
+func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error {
+ c.resp.reset()
+ if err := c.dec.Decode(&c.resp); err != nil {
+ return err
+ }
+
+ c.mutex.Lock()
+ r.ServiceMethod = c.pending[c.resp.Id]
+ c.pending[c.resp.Id] = "", false
+ c.mutex.Unlock()
+
+ r.Error = ""
+ r.Seq = c.resp.Id
+ if c.resp.Error != nil {
+ x, ok := c.resp.Error.(string)
+ if !ok {
+ return fmt.Errorf("invalid error %v", c.resp.Error)
+ }
+ if x == "" {
+ x = "unspecified error"
+ }
+ r.Error = x
+ }
+ return nil
+}
+
+func (c *clientCodec) ReadResponseBody(x interface{}) os.Error {
+ return json.Unmarshal(*c.resp.Result, x)
+}
+
+func (c *clientCodec) Close() os.Error {
+ return c.c.Close()
+}
+
+// NewClient returns a new rpc.Client to handle requests to the
+// set of services at the other end of the connection.
+func NewClient(conn io.ReadWriteCloser) *rpc.Client {
+ return rpc.NewClientWithCodec(NewClientCodec(conn))
+}
+
+// Dial connects to a JSON-RPC server at the specified network address.
+func Dial(network, address string) (*rpc.Client, os.Error) {
+ conn, err := net.Dial(network, "", address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), err
+}
diff --git a/libgo/go/rpc/jsonrpc/server.go b/libgo/go/rpc/jsonrpc/server.go
new file mode 100644
index 000000000..bf53bda8d
--- /dev/null
+++ b/libgo/go/rpc/jsonrpc/server.go
@@ -0,0 +1,133 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "io"
+ "json"
+ "os"
+ "rpc"
+ "sync"
+)
+
+type serverCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req serverRequest
+ resp serverResponse
+
+ // JSON-RPC clients can use arbitrary json values as request IDs.
+ // Package rpc expects uint64 request IDs.
+ // We assign uint64 sequence numbers to incoming requests
+ // but save the original request ID in the pending map.
+ // When rpc responds, we use the sequence number in
+ // the response to find the original request ID.
+ mutex sync.Mutex // protects seq, pending
+ seq uint64
+ pending map[uint64]*json.RawMessage
+}
+
+// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
+func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
+ return &serverCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]*json.RawMessage),
+ }
+}
+
+type serverRequest struct {
+ Method string "method"
+ Params *json.RawMessage "params"
+ Id *json.RawMessage "id"
+}
+
+func (r *serverRequest) reset() {
+ r.Method = ""
+ if r.Params != nil {
+ *r.Params = (*r.Params)[0:0]
+ }
+ if r.Id != nil {
+ *r.Id = (*r.Id)[0:0]
+ }
+}
+
+type serverResponse struct {
+ Id *json.RawMessage "id"
+ Result interface{} "result"
+ Error interface{} "error"
+}
+
+func (c *serverCodec) ReadRequestHeader(r *rpc.Request) os.Error {
+ c.req.reset()
+ if err := c.dec.Decode(&c.req); err != nil {
+ return err
+ }
+ r.ServiceMethod = c.req.Method
+
+ // JSON request id can be any JSON value;
+ // RPC package expects uint64. Translate to
+ // internal uint64 and save JSON on the side.
+ c.mutex.Lock()
+ c.seq++
+ c.pending[c.seq] = c.req.Id
+ c.req.Id = nil
+ r.Seq = c.seq
+ c.mutex.Unlock()
+
+ return nil
+}
+
+func (c *serverCodec) ReadRequestBody(x interface{}) os.Error {
+ // JSON params is array value.
+ // RPC params is struct.
+ // Unmarshal into array containing struct for now.
+ // Should think about making RPC more general.
+ var params [1]interface{}
+ params[0] = x
+ return json.Unmarshal(*c.req.Params, &params)
+}
+
+var null = json.RawMessage([]byte("null"))
+
+func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error {
+ var resp serverResponse
+ c.mutex.Lock()
+ b, ok := c.pending[r.Seq]
+ if !ok {
+ c.mutex.Unlock()
+ return os.NewError("invalid sequence number in response")
+ }
+ c.pending[r.Seq] = nil, false
+ c.mutex.Unlock()
+
+ if b == nil {
+ // Invalid request so no id. Use JSON null.
+ b = &null
+ }
+ resp.Id = b
+ resp.Result = x
+ if r.Error == "" {
+ resp.Error = nil
+ } else {
+ resp.Error = r.Error
+ }
+ return c.enc.Encode(resp)
+}
+
+func (c *serverCodec) Close() os.Error {
+ return c.c.Close()
+}
+
+// ServeConn runs the JSON-RPC server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+func ServeConn(conn io.ReadWriteCloser) {
+ rpc.ServeCodec(NewServerCodec(conn))
+}
diff --git a/libgo/go/rpc/server.go b/libgo/go/rpc/server.go
new file mode 100644
index 000000000..5c50bcc3a
--- /dev/null
+++ b/libgo/go/rpc/server.go
@@ -0,0 +1,530 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+ The rpc package provides access to the exported methods of an object across a
+ network or other I/O connection. A server registers an object, making it visible
+ as a service with the name of the type of the object. After registration, exported
+ methods of the object will be accessible remotely. A server may register multiple
+ objects (services) of different types but it is an error to register multiple
+ objects of the same type.
+
+ Only methods that satisfy these criteria will be made available for remote access;
+ other methods will be ignored:
+
+ - the method receiver and name are exported, that is, begin with an upper case letter.
+ - the method has two arguments, both pointers to exported types.
+ - the method has return type os.Error.
+
+ The method's first argument represents the arguments provided by the caller; the
+ second argument represents the result parameters to be returned to the caller.
+ The method's return value, if non-nil, is passed back as a string that the client
+ sees as an os.ErrorString.
+
+ The server may handle requests on a single connection by calling ServeConn. More
+ typically it will create a network listener and call Accept or, for an HTTP
+ listener, HandleHTTP and http.Serve.
+
+ A client wishing to use the service establishes a connection and then invokes
+ NewClient on the connection. The convenience function Dial (DialHTTP) performs
+ both steps for a raw network connection (an HTTP connection). The resulting
+ Client object has two methods, Call and Go, that specify the service and method to
+ call, a pointer containing the arguments, and a pointer to receive the result
+ parameters.
+
+ Call waits for the remote call to complete; Go launches the call asynchronously
+ and returns a channel that will signal completion.
+
+ Package "gob" is used to transport the data.
+
+ Here is a simple example. A server wishes to export an object of type Arith:
+
+ package server
+
+ type Args struct {
+ A, B int
+ }
+
+ type Quotient struct {
+ Quo, Rem int
+ }
+
+ type Arith int
+
+ func (t *Arith) Multiply(args *Args, reply *int) os.Error {
+ *reply = args.A * args.B
+ return nil
+ }
+
+ func (t *Arith) Divide(args *Args, quo *Quotient) os.Error {
+ if args.B == 0 {
+ return os.ErrorString("divide by zero")
+ }
+ quo.Quo = args.A / args.B
+ quo.Rem = args.A % args.B
+ return nil
+ }
+
+ The server calls (for HTTP service):
+
+ arith := new(Arith)
+ rpc.Register(arith)
+ rpc.HandleHTTP()
+ l, e := net.Listen("tcp", ":1234")
+ if e != nil {
+ log.Exit("listen error:", e)
+ }
+ go http.Serve(l, nil)
+
+ At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
+ "Arith.Divide". To invoke one, a client first dials the server:
+
+ client, err := rpc.DialHTTP("tcp", serverAddress + ":1234")
+ if err != nil {
+ log.Exit("dialing:", err)
+ }
+
+ Then it can make a remote call:
+
+ // Synchronous call
+ args := &server.Args{7,8}
+ var reply int
+ err = client.Call("Arith.Multiply", args, &reply)
+ if err != nil {
+ log.Exit("arith error:", err)
+ }
+ fmt.Printf("Arith: %d*%d=%d", args.A, args.B, *reply)
+
+ or
+
+ // Asynchronous call
+ quotient := new(Quotient)
+ divCall := client.Go("Arith.Divide", args, &quotient, nil)
+ replyCall := <-divCall.Done // will be equal to divCall
+ // check errors, print, etc.
+
+ A server implementation will often provide a simple, type-safe wrapper for the
+ client.
+*/
+package rpc
+
+import (
+ "gob"
+ "http"
+ "log"
+ "io"
+ "net"
+ "os"
+ "reflect"
+ "strings"
+ "sync"
+ "unicode"
+ "utf8"
+)
+
+const (
+ // Defaults used by HandleHTTP
+ DefaultRPCPath = "/_goRPC_"
+ DefaultDebugPath = "/debug/rpc"
+)
+
+// Precompute the reflect type for os.Error. Can't use os.Error directly
+// because Typeof takes an empty interface value. This is annoying.
+var unusedError *os.Error
+var typeOfOsError = reflect.Typeof(unusedError).(*reflect.PtrType).Elem()
+
+type methodType struct {
+ sync.Mutex // protects counters
+ method reflect.Method
+ ArgType *reflect.PtrType
+ ReplyType *reflect.PtrType
+ numCalls uint
+}
+
+type service struct {
+ name string // name of service
+ rcvr reflect.Value // receiver of methods for the service
+ typ reflect.Type // type of the receiver
+ method map[string]*methodType // registered methods
+}
+
+// Request is a header written before every RPC call. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Request struct {
+ ServiceMethod string // format: "Service.Method"
+ Seq uint64 // sequence number chosen by client
+}
+
+// Response is a header written before every RPC return. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Response struct {
+ ServiceMethod string // echoes that of the Request
+ Seq uint64 // echoes that of the request
+ Error string // error, if any.
+}
+
+// ClientInfo records information about an RPC client connection.
+type ClientInfo struct {
+ LocalAddr string
+ RemoteAddr string
+}
+
+// Server represents an RPC Server.
+type Server struct {
+ sync.Mutex // protects the serviceMap
+ serviceMap map[string]*service
+}
+
+// NewServer returns a new Server.
+func NewServer() *Server {
+ return &Server{serviceMap: make(map[string]*service)}
+}
+
+// DefaultServer is the default instance of *Server.
+var DefaultServer = NewServer()
+
+// Is this an exported - upper case - name?
+func isExported(name string) bool {
+ rune, _ := utf8.DecodeRuneInString(name)
+ return unicode.IsUpper(rune)
+}
+
+// Register publishes in the server the set of methods of the
+// receiver value that satisfy the following conditions:
+// - exported method
+// - two arguments, both pointers to exported structs
+// - one return value, of type os.Error
+// It returns an error if the receiver is not an exported type or has no
+// suitable methods.
+// The client accesses each method using a string of the form "Type.Method",
+// where Type is the receiver's concrete type.
+func (server *Server) Register(rcvr interface{}) os.Error {
+ return server.register(rcvr, "", false)
+}
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func (server *Server) RegisterName(name string, rcvr interface{}) os.Error {
+ return server.register(rcvr, name, true)
+}
+
+func (server *Server) register(rcvr interface{}, name string, useName bool) os.Error {
+ server.Lock()
+ defer server.Unlock()
+ if server.serviceMap == nil {
+ server.serviceMap = make(map[string]*service)
+ }
+ s := new(service)
+ s.typ = reflect.Typeof(rcvr)
+ s.rcvr = reflect.NewValue(rcvr)
+ sname := reflect.Indirect(s.rcvr).Type().Name()
+ if useName {
+ sname = name
+ }
+ if sname == "" {
+ log.Exit("rpc: no service name for type", s.typ.String())
+ }
+ if s.typ.PkgPath() != "" && !isExported(sname) && !useName {
+ s := "rpc Register: type " + sname + " is not exported"
+ log.Print(s)
+ return os.ErrorString(s)
+ }
+ if _, present := server.serviceMap[sname]; present {
+ return os.ErrorString("rpc: service already defined: " + sname)
+ }
+ s.name = sname
+ s.method = make(map[string]*methodType)
+
+ // Install the methods
+ for m := 0; m < s.typ.NumMethod(); m++ {
+ method := s.typ.Method(m)
+ mtype := method.Type
+ mname := method.Name
+ if mtype.PkgPath() != "" || !isExported(mname) {
+ continue
+ }
+ // Method needs three ins: receiver, *args, *reply.
+ if mtype.NumIn() != 3 {
+ log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+ continue
+ }
+ argType, ok := mtype.In(1).(*reflect.PtrType)
+ if !ok {
+ log.Println(mname, "arg type not a pointer:", mtype.In(1))
+ continue
+ }
+ replyType, ok := mtype.In(2).(*reflect.PtrType)
+ if !ok {
+ log.Println(mname, "reply type not a pointer:", mtype.In(2))
+ continue
+ }
+ if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) {
+ log.Println(mname, "argument type not exported:", argType)
+ continue
+ }
+ if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) {
+ log.Println(mname, "reply type not exported:", replyType)
+ continue
+ }
+ if mtype.NumIn() == 4 {
+ t := mtype.In(3)
+ if t != reflect.Typeof((*ClientInfo)(nil)) {
+ log.Println(mname, "last argument not *ClientInfo")
+ continue
+ }
+ }
+ // Method needs one out: os.Error.
+ if mtype.NumOut() != 1 {
+ log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+ continue
+ }
+ if returnType := mtype.Out(0); returnType != typeOfOsError {
+ log.Println("method", mname, "returns", returnType.String(), "not os.Error")
+ continue
+ }
+ s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
+ }
+
+ if len(s.method) == 0 {
+ s := "rpc Register: type " + sname + " has no exported methods of suitable type"
+ log.Print(s)
+ return os.ErrorString(s)
+ }
+ server.serviceMap[s.name] = s
+ return nil
+}
+
+// A value sent as a placeholder for the response when the server receives an invalid request.
+type InvalidRequest struct {
+ marker int
+}
+
+var invalidRequest = InvalidRequest{1}
+
+func _new(t *reflect.PtrType) *reflect.PtrValue {
+ v := reflect.MakeZero(t).(*reflect.PtrValue)
+ v.PointTo(reflect.MakeZero(t.Elem()))
+ return v
+}
+
+func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
+ resp := new(Response)
+ // Encode the response header
+ resp.ServiceMethod = req.ServiceMethod
+ if errmsg != "" {
+ resp.Error = errmsg
+ }
+ resp.Seq = req.Seq
+ sending.Lock()
+ err := codec.WriteResponse(resp, reply)
+ if err != nil {
+ log.Println("rpc: writing response:", err)
+ }
+ sending.Unlock()
+}
+
+func (m *methodType) NumCalls() (n uint) {
+ m.Lock()
+ n = m.numCalls
+ m.Unlock()
+ return n
+}
+
+func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+ mtype.Lock()
+ mtype.numCalls++
+ mtype.Unlock()
+ function := mtype.method.Func
+ // Invoke the method, providing a new value for the reply.
+ returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
+ // The return value for the method is an os.Error.
+ errInter := returnValues[0].Interface()
+ errmsg := ""
+ if errInter != nil {
+ errmsg = errInter.(os.Error).String()
+ }
+ sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+}
+
+type gobServerCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+}
+
+func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error {
+ if err := c.enc.Encode(r); err != nil {
+ return err
+ }
+ return c.enc.Encode(body)
+}
+
+func (c *gobServerCodec) Close() os.Error {
+ return c.rwc.Close()
+}
+
+
+// ServeConn runs the server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func (server *Server) ServeConn(conn io.ReadWriteCloser) {
+ server.ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func (server *Server) ServeCodec(codec ServerCodec) {
+ sending := new(sync.Mutex)
+ for {
+ // Grab the request header.
+ req := new(Request)
+ err := codec.ReadRequestHeader(req)
+ if err != nil {
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ if err == io.ErrUnexpectedEOF {
+ log.Println("rpc:", err)
+ }
+ break
+ }
+ s := "rpc: server cannot decode request: " + err.String()
+ sendResponse(sending, req, invalidRequest, codec, s)
+ break
+ }
+ serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
+ if len(serviceMethod) != 2 {
+ s := "rpc: service/method request ill-formed: " + req.ServiceMethod
+ sendResponse(sending, req, invalidRequest, codec, s)
+ continue
+ }
+ // Look up the request.
+ server.Lock()
+ service, ok := server.serviceMap[serviceMethod[0]]
+ server.Unlock()
+ if !ok {
+ s := "rpc: can't find service " + req.ServiceMethod
+ sendResponse(sending, req, invalidRequest, codec, s)
+ continue
+ }
+ mtype, ok := service.method[serviceMethod[1]]
+ if !ok {
+ s := "rpc: can't find method " + req.ServiceMethod
+ sendResponse(sending, req, invalidRequest, codec, s)
+ continue
+ }
+ // Decode the argument value.
+ argv := _new(mtype.ArgType)
+ replyv := _new(mtype.ReplyType)
+ err = codec.ReadRequestBody(argv.Interface())
+ if err != nil {
+ log.Println("rpc: tearing down", serviceMethod[0], "connection:", err)
+ sendResponse(sending, req, replyv.Interface(), codec, err.String())
+ break
+ }
+ go service.call(sending, mtype, req, argv, replyv, codec)
+ }
+ codec.Close()
+}
+
+// Accept accepts connections on the listener and serves requests
+// for each incoming connection. Accept blocks; the caller typically
+// invokes it in a go statement.
+func (server *Server) Accept(lis net.Listener) {
+ for {
+ conn, err := lis.Accept()
+ if err != nil {
+ log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
+ }
+ go server.ServeConn(conn)
+ }
+}
+
+// Register publishes the receiver's methods in the DefaultServer.
+func Register(rcvr interface{}) os.Error { return DefaultServer.Register(rcvr) }
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func RegisterName(name string, rcvr interface{}) os.Error {
+ return DefaultServer.RegisterName(name, rcvr)
+}
+
+// A ServerCodec implements reading of RPC requests and writing of
+// RPC responses for the server side of an RPC session.
+// The server calls ReadRequestHeader and ReadRequestBody in pairs
+// to read requests from the connection, and it calls WriteResponse to
+// write a response back. The server calls Close when finished with the
+// connection.
+type ServerCodec interface {
+ ReadRequestHeader(*Request) os.Error
+ ReadRequestBody(interface{}) os.Error
+ WriteResponse(*Response, interface{}) os.Error
+
+ Close() os.Error
+}
+
+// ServeConn runs the DefaultServer on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+func ServeConn(conn io.ReadWriteCloser) {
+ DefaultServer.ServeConn(conn)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func ServeCodec(codec ServerCodec) {
+ DefaultServer.ServeCodec(codec)
+}
+
+// Accept accepts connections on the listener and serves requests
+// to DefaultServer for each incoming connection.
+// Accept blocks; the caller typically invokes it in a go statement.
+func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
+
+// Can connect to RPC service using HTTP CONNECT to rpcPath.
+var connected = "200 Connected to Go RPC"
+
+// ServeHTTP implements an http.Handler that answers RPC requests.
+func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ if req.Method != "CONNECT" {
+ w.SetHeader("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ io.WriteString(w, "405 must CONNECT\n")
+ return
+ }
+ conn, _, err := w.Hijack()
+ if err != nil {
+ log.Print("rpc hijacking ", w.RemoteAddr(), ": ", err.String())
+ return
+ }
+ io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
+ server.ServeConn(conn)
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
+// and a debugging handler on debugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func (server *Server) HandleHTTP(rpcPath, debugPath string) {
+ http.Handle(rpcPath, server)
+ http.Handle(debugPath, debugHTTP{server})
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
+// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func HandleHTTP() {
+ DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
+}
diff --git a/libgo/go/rpc/server_test.go b/libgo/go/rpc/server_test.go
new file mode 100644
index 000000000..355d51ce4
--- /dev/null
+++ b/libgo/go/rpc/server_test.go
@@ -0,0 +1,384 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "fmt"
+ "http"
+ "log"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+var (
+ serverAddr, newServerAddr string
+ httpServerAddr string
+ once, newOnce, httpOnce sync.Once
+)
+
+const (
+ second = 1e9
+ newHttpPath = "/foo"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+func (t *Arith) Add(args *Args, reply *Reply) os.Error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) os.Error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args *Args, reply *Reply) os.Error {
+ if args.B == 0 {
+ return os.ErrorString("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) String(args *Args, reply *string) os.Error {
+ *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ return nil
+}
+
+func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) {
+ _, err = fmt.Sscan(*args, &reply.C)
+ return
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) os.Error {
+ panic("ERROR")
+}
+
+func listenTCP() (net.Listener, string) {
+ l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
+ if e != nil {
+ log.Exitf("net.Listen tcp :0: %v", e)
+ }
+ return l, l.Addr().String()
+}
+
+func startServer() {
+ Register(new(Arith))
+
+ var l net.Listener
+ l, serverAddr = listenTCP()
+ log.Println("Test RPC server listening on", serverAddr)
+ go Accept(l)
+
+ HandleHTTP()
+ httpOnce.Do(startHttpServer)
+}
+
+func startNewServer() {
+ s := NewServer()
+ s.Register(new(Arith))
+
+ var l net.Listener
+ l, newServerAddr = listenTCP()
+ log.Println("NewServer test RPC server listening on", newServerAddr)
+ go Accept(l)
+
+ s.HandleHTTP(newHttpPath, "/bar")
+ httpOnce.Do(startHttpServer)
+}
+
+func startHttpServer() {
+ var l net.Listener
+ l, httpServerAddr = listenTCP()
+ httpServerAddr = l.Addr().String()
+ log.Println("Test HTTP RPC server listening on", httpServerAddr)
+ go http.Serve(l, nil)
+}
+
+func TestRPC(t *testing.T) {
+ once.Do(startServer)
+ testRPC(t, serverAddr)
+ newOnce.Do(startNewServer)
+ testRPC(t, newServerAddr)
+}
+
+func testRPC(t *testing.T, addr string) {
+ client, err := Dial("tcp", addr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.String())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.String())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.String() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+
+ // Non-struct argument
+ const Val = 12345
+ str := fmt.Sprint(Val)
+ reply = new(Reply)
+ err = client.Call("Arith.Scan", &str, reply)
+ if err != nil {
+ t.Errorf("Scan: expected no error but got string %q", err.String())
+ } else if reply.C != Val {
+ t.Errorf("Scan: expected %d got %d", Val, reply.C)
+ }
+
+ // Non-struct reply
+ args = &Args{27, 35}
+ str = ""
+ err = client.Call("Arith.String", args, &str)
+ if err != nil {
+ t.Errorf("String: expected no error but got string %q", err.String())
+ }
+ expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ if str != expect {
+ t.Errorf("String: expected %s got %s", expect, str)
+ }
+}
+
+func TestHTTPRPC(t *testing.T) {
+ once.Do(startServer)
+ testHTTPRPC(t, "")
+ newOnce.Do(startNewServer)
+ testHTTPRPC(t, newHttpPath)
+}
+
+func testHTTPRPC(t *testing.T, path string) {
+ var client *Client
+ var err os.Error
+ if path == "" {
+ client, err = DialHTTP("tcp", httpServerAddr)
+ } else {
+ client, err = DialHTTPPath("tcp", httpServerAddr, path)
+ }
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+}
+
+func TestCheckUnknownService(t *testing.T) {
+ once.Do(startServer)
+
+ conn, err := net.Dial("tcp", "", serverAddr)
+ if err != nil {
+ t.Fatal("dialing:", err)
+ }
+
+ client := NewClient(conn)
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Unknown.Add", args, reply)
+ if err == nil {
+ t.Error("expected error calling unknown service")
+ } else if strings.Index(err.String(), "service") < 0 {
+ t.Error("expected error about service; got", err)
+ }
+}
+
+func TestCheckUnknownMethod(t *testing.T) {
+ once.Do(startServer)
+
+ conn, err := net.Dial("tcp", "", serverAddr)
+ if err != nil {
+ t.Fatal("dialing:", err)
+ }
+
+ client := NewClient(conn)
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Unknown", args, reply)
+ if err == nil {
+ t.Error("expected error calling unknown service")
+ } else if strings.Index(err.String(), "method") < 0 {
+ t.Error("expected error about method; got", err)
+ }
+}
+
+func TestCheckBadType(t *testing.T) {
+ once.Do(startServer)
+
+ conn, err := net.Dial("tcp", "", serverAddr)
+ if err != nil {
+ t.Fatal("dialing:", err)
+ }
+
+ client := NewClient(conn)
+
+ reply := new(Reply)
+ err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
+ if err == nil {
+ t.Error("expected error calling Arith.Add with wrong arg type")
+ } else if strings.Index(err.String(), "type") < 0 {
+ t.Error("expected error about type; got", err)
+ }
+}
+
+type ArgNotPointer int
+type ReplyNotPointer int
+type ArgNotPublic int
+type ReplyNotPublic int
+type local struct{}
+
+func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error {
+ return nil
+}
+
+func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error {
+ return nil
+}
+
+func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) os.Error {
+ return nil
+}
+
+func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error {
+ return nil
+}
+
+// Check that registration handles lots of bad methods and a type with no suitable methods.
+func TestRegistrationError(t *testing.T) {
+ err := Register(new(ArgNotPointer))
+ if err == nil {
+ t.Errorf("expected error registering ArgNotPointer")
+ }
+ err = Register(new(ReplyNotPointer))
+ if err == nil {
+ t.Errorf("expected error registering ReplyNotPointer")
+ }
+ err = Register(new(ArgNotPublic))
+ if err == nil {
+ t.Errorf("expected error registering ArgNotPublic")
+ }
+ err = Register(new(ReplyNotPublic))
+ if err == nil {
+ t.Errorf("expected error registering ReplyNotPublic")
+ }
+}
+
+type WriteFailCodec int
+
+func (WriteFailCodec) WriteRequest(*Request, interface{}) os.Error {
+ // the panic caused by this error used to not unlock a lock.
+ return os.NewError("fail")
+}
+
+func (WriteFailCodec) ReadResponseHeader(*Response) os.Error {
+ time.Sleep(60e9)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) ReadResponseBody(interface{}) os.Error {
+ time.Sleep(60e9)
+ panic("unreachable")
+}
+
+func (WriteFailCodec) Close() os.Error {
+ return nil
+}
+
+func TestSendDeadlock(t *testing.T) {
+ client := NewClientWithCodec(WriteFailCodec(0))
+
+ done := make(chan bool)
+ go func() {
+ testSendDeadlock(client)
+ testSendDeadlock(client)
+ done <- true
+ }()
+ for i := 0; i < 50; i++ {
+ time.Sleep(100 * 1e6)
+ _, ok := <-done
+ if ok {
+ return
+ }
+ }
+ t.Fatal("deadlock")
+}
+
+func testSendDeadlock(client *Client) {
+ defer func() {
+ recover()
+ }()
+ args := &Args{7, 8}
+ reply := new(Reply)
+ client.Call("Arith.Add", args, reply)
+}