diff options
Diffstat (limited to 'libgo/go/rpc')
-rw-r--r-- | libgo/go/rpc/client.go | 250 | ||||
-rw-r--r-- | libgo/go/rpc/debug.go | 90 | ||||
-rw-r--r-- | libgo/go/rpc/jsonrpc/all_test.go | 156 | ||||
-rw-r--r-- | libgo/go/rpc/jsonrpc/client.go | 121 | ||||
-rw-r--r-- | libgo/go/rpc/jsonrpc/server.go | 133 | ||||
-rw-r--r-- | libgo/go/rpc/server.go | 530 | ||||
-rw-r--r-- | libgo/go/rpc/server_test.go | 384 |
7 files changed, 1664 insertions, 0 deletions
diff --git a/libgo/go/rpc/client.go b/libgo/go/rpc/client.go new file mode 100644 index 000000000..601c49715 --- /dev/null +++ b/libgo/go/rpc/client.go @@ -0,0 +1,250 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "bufio" + "gob" + "http" + "io" + "log" + "net" + "os" + "sync" +) + +// Call represents an active RPC. +type Call struct { + ServiceMethod string // The name of the service and method to call. + Args interface{} // The argument to the function (*struct). + Reply interface{} // The reply from the function (*struct). + Error os.Error // After completion, the error status. + Done chan *Call // Strobes when call is complete; value is the error status. + seq uint64 +} + +// Client represents an RPC Client. +// There may be multiple outstanding Calls associated +// with a single Client. +type Client struct { + mutex sync.Mutex // protects pending, seq + shutdown os.Error // non-nil if the client is shut down + sending sync.Mutex + seq uint64 + codec ClientCodec + pending map[uint64]*Call + closing bool +} + +// A ClientCodec implements writing of RPC requests and +// reading of RPC responses for the client side of an RPC session. +// The client calls WriteRequest to write a request to the connection +// and calls ReadResponseHeader and ReadResponseBody in pairs +// to read responses. The client calls Close when finished with the +// connection. +type ClientCodec interface { + WriteRequest(*Request, interface{}) os.Error + ReadResponseHeader(*Response) os.Error + ReadResponseBody(interface{}) os.Error + + Close() os.Error +} + +func (client *Client) send(c *Call) { + // Register this call. + client.mutex.Lock() + if client.shutdown != nil { + c.Error = client.shutdown + client.mutex.Unlock() + _ = c.Done <- c // do not block + return + } + c.seq = client.seq + client.seq++ + client.pending[c.seq] = c + client.mutex.Unlock() + + // Encode and send the request. + request := new(Request) + client.sending.Lock() + defer client.sending.Unlock() + request.Seq = c.seq + request.ServiceMethod = c.ServiceMethod + if err := client.codec.WriteRequest(request, c.Args); err != nil { + panic("rpc: client encode error: " + err.String()) + } +} + +func (client *Client) input() { + var err os.Error + for err == nil { + response := new(Response) + err = client.codec.ReadResponseHeader(response) + if err != nil { + if err == os.EOF && !client.closing { + err = io.ErrUnexpectedEOF + } + break + } + seq := response.Seq + client.mutex.Lock() + c := client.pending[seq] + client.pending[seq] = c, false + client.mutex.Unlock() + err = client.codec.ReadResponseBody(c.Reply) + if response.Error != "" { + c.Error = os.ErrorString(response.Error) + } else if err != nil { + c.Error = err + } else { + // Empty strings should turn into nil os.Errors + c.Error = nil + } + // We don't want to block here. It is the caller's responsibility to make + // sure the channel has enough buffer space. See comment in Go(). + _ = c.Done <- c // do not block + } + // Terminate pending calls. + client.mutex.Lock() + client.shutdown = err + for _, call := range client.pending { + call.Error = err + _ = call.Done <- call // do not block + } + client.mutex.Unlock() + if err != os.EOF || !client.closing { + log.Println("rpc: client protocol error:", err) + } +} + +// NewClient returns a new Client to handle requests to the +// set of services at the other end of the connection. +func NewClient(conn io.ReadWriteCloser) *Client { + return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) +} + +// NewClientWithCodec is like NewClient but uses the specified +// codec to encode requests and decode responses. +func NewClientWithCodec(codec ClientCodec) *Client { + client := &Client{ + codec: codec, + pending: make(map[uint64]*Call), + } + go client.input() + return client +} + +type gobClientCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder +} + +func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error { + if err := c.enc.Encode(r); err != nil { + return err + } + return c.enc.Encode(body) +} + +func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error { + return c.dec.Decode(r) +} + +func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error { + return c.dec.Decode(body) +} + +func (c *gobClientCodec) Close() os.Error { + return c.rwc.Close() +} + + +// DialHTTP connects to an HTTP RPC server at the specified network address +// listening on the default HTTP RPC path. +func DialHTTP(network, address string) (*Client, os.Error) { + return DialHTTPPath(network, address, DefaultRPCPath) +} + +// DialHTTPPath connects to an HTTP RPC server +// at the specified network address and path. +func DialHTTPPath(network, address, path string) (*Client, os.Error) { + var err os.Error + conn, err := net.Dial(network, "", address) + if err != nil { + return nil, err + } + io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n") + + // Require successful HTTP response + // before switching to RPC protocol. + resp, err := http.ReadResponse(bufio.NewReader(conn), "CONNECT") + if err == nil && resp.Status == connected { + return NewClient(conn), nil + } + if err == nil { + err = os.ErrorString("unexpected HTTP response: " + resp.Status) + } + conn.Close() + return nil, &net.OpError{"dial-http", network + " " + address, nil, err} +} + +// Dial connects to an RPC server at the specified network address. +func Dial(network, address string) (*Client, os.Error) { + conn, err := net.Dial(network, "", address) + if err != nil { + return nil, err + } + return NewClient(conn), nil +} + +func (client *Client) Close() os.Error { + if client.shutdown != nil || client.closing { + return os.ErrorString("rpc: already closed") + } + client.mutex.Lock() + client.closing = true + client.mutex.Unlock() + return client.codec.Close() +} + +// Go invokes the function asynchronously. It returns the Call structure representing +// the invocation. The done channel will signal when the call is complete by returning +// the same Call object. If done is nil, Go will allocate a new channel. +// If non-nil, done must be buffered or Go will deliberately crash. +func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { + c := new(Call) + c.ServiceMethod = serviceMethod + c.Args = args + c.Reply = reply + if done == nil { + done = make(chan *Call, 10) // buffered. + } else { + // If caller passes done != nil, it must arrange that + // done has enough buffer for the number of simultaneous + // RPCs that will be using that channel. If the channel + // is totally unbuffered, it's best not to run at all. + if cap(done) == 0 { + log.Panic("rpc: done channel is unbuffered") + } + } + c.Done = done + if client.shutdown != nil { + c.Error = client.shutdown + _ = c.Done <- c // do not block + return c + } + client.send(c) + return c +} + +// Call invokes the named function, waits for it to complete, and returns its error status. +func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error { + if client.shutdown != nil { + return client.shutdown + } + call := <-client.Go(serviceMethod, args, reply, nil).Done + return call.Error +} diff --git a/libgo/go/rpc/debug.go b/libgo/go/rpc/debug.go new file mode 100644 index 000000000..44b32e04b --- /dev/null +++ b/libgo/go/rpc/debug.go @@ -0,0 +1,90 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +/* + Some HTML presented at http://machine:port/debug/rpc + Lists services, their methods, and some statistics, still rudimentary. +*/ + +import ( + "fmt" + "http" + "sort" + "template" +) + +const debugText = `<html> + <body> + <title>Services</title> + {.repeated section @} + <hr> + Service {Name} + <hr> + <table> + <th align=center>Method</th><th align=center>Calls</th> + {.repeated section Method} + <tr> + <td align=left font=fixed>{Name}({Type.ArgType}, {Type.ReplyType}) os.Error</td> + <td align=center>{Type.NumCalls}</td> + </tr> + {.end} + </table> + {.end} + </body> + </html>` + +var debug = template.MustParse(debugText, nil) + +type debugMethod struct { + Type *methodType + Name string +} + +type methodArray []debugMethod + +type debugService struct { + Service *service + Name string + Method methodArray +} + +type serviceArray []debugService + +func (s serviceArray) Len() int { return len(s) } +func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name } +func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func (m methodArray) Len() int { return len(m) } +func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name } +func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] } + +type debugHTTP struct { + *Server +} + +// Runs at /debug/rpc +func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Build a sorted version of the data. + var services = make(serviceArray, len(server.serviceMap)) + i := 0 + server.Lock() + for sname, service := range server.serviceMap { + services[i] = debugService{service, sname, make(methodArray, len(service.method))} + j := 0 + for mname, method := range service.method { + services[i].Method[j] = debugMethod{method, mname} + j++ + } + sort.Sort(services[i].Method) + i++ + } + server.Unlock() + sort.Sort(services) + err := debug.Execute(services, w) + if err != nil { + fmt.Fprintln(w, "rpc: error executing template:", err.String()) + } +} diff --git a/libgo/go/rpc/jsonrpc/all_test.go b/libgo/go/rpc/jsonrpc/all_test.go new file mode 100644 index 000000000..764ee7ff3 --- /dev/null +++ b/libgo/go/rpc/jsonrpc/all_test.go @@ -0,0 +1,156 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonrpc + +import ( + "fmt" + "json" + "net" + "os" + "rpc" + "testing" +) + +type Args struct { + A, B int +} + +type Reply struct { + C int +} + +type Arith int + +func (t *Arith) Add(args *Args, reply *Reply) os.Error { + reply.C = args.A + args.B + return nil +} + +func (t *Arith) Mul(args *Args, reply *Reply) os.Error { + reply.C = args.A * args.B + return nil +} + +func (t *Arith) Div(args *Args, reply *Reply) os.Error { + if args.B == 0 { + return os.ErrorString("divide by zero") + } + reply.C = args.A / args.B + return nil +} + +func (t *Arith) Error(args *Args, reply *Reply) os.Error { + panic("ERROR") +} + +func init() { + rpc.Register(new(Arith)) +} + +func TestServer(t *testing.T) { + type addResp struct { + Id interface{} "id" + Result Reply "result" + Error interface{} "error" + } + + cli, srv := net.Pipe() + defer cli.Close() + go ServeConn(srv) + dec := json.NewDecoder(cli) + + // Send hand-coded requests to server, parse responses. + for i := 0; i < 10; i++ { + fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1) + var resp addResp + err := dec.Decode(&resp) + if err != nil { + t.Fatalf("Decode: %s", err) + } + if resp.Error != nil { + t.Fatalf("resp.Error: %s", resp.Error) + } + if resp.Id.(string) != string(i) { + t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i)) + } + if resp.Result.C != 2*i+1 { + t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) + } + } + + fmt.Fprintf(cli, "{}\n") + var resp addResp + if err := dec.Decode(&resp); err != nil { + t.Fatalf("Decode after empty: %s", err) + } + if resp.Error == nil { + t.Fatalf("Expected error, got nil") + } +} + +func TestClient(t *testing.T) { + // Assume server is okay (TestServer is above). + // Test client against server. + cli, srv := net.Pipe() + go ServeConn(srv) + + client := NewClient(cli) + defer client.Close() + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.String()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.String()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + } + + // Out of order. + args = &Args{7, 8} + mulReply := new(Reply) + mulCall := client.Go("Arith.Mul", args, mulReply, nil) + addReply := new(Reply) + addCall := client.Go("Arith.Add", args, addReply, nil) + + addCall = <-addCall.Done + if addCall.Error != nil { + t.Errorf("Add: expected no error but got string %q", addCall.Error.String()) + } + if addReply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + } + + mulCall = <-mulCall.Done + if mulCall.Error != nil { + t.Errorf("Mul: expected no error but got string %q", mulCall.Error.String()) + } + if mulReply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + } + + // Error test + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.Div", args, reply) + // expect an error: zero divide + if err == nil { + t.Error("Div: expected error") + } else if err.String() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err) + } +} diff --git a/libgo/go/rpc/jsonrpc/client.go b/libgo/go/rpc/jsonrpc/client.go new file mode 100644 index 000000000..dcaa69f9d --- /dev/null +++ b/libgo/go/rpc/jsonrpc/client.go @@ -0,0 +1,121 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec +// for the rpc package. +package jsonrpc + +import ( + "fmt" + "io" + "json" + "net" + "os" + "rpc" + "sync" +) + +type clientCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + req clientRequest + resp clientResponse + + // JSON-RPC responses include the request id but not the request method. + // Package rpc expects both. + // We save the request method in pending when sending a request + // and then look it up by request ID when filling out the rpc Response. + mutex sync.Mutex // protects pending + pending map[uint64]string // map request id to method name +} + +// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn. +func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { + return &clientCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]string), + } +} + +type clientRequest struct { + Method string "method" + Params [1]interface{} "params" + Id uint64 "id" +} + +func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) os.Error { + c.mutex.Lock() + c.pending[r.Seq] = r.ServiceMethod + c.mutex.Unlock() + c.req.Method = r.ServiceMethod + c.req.Params[0] = param + c.req.Id = r.Seq + return c.enc.Encode(&c.req) +} + +type clientResponse struct { + Id uint64 "id" + Result *json.RawMessage "result" + Error interface{} "error" +} + +func (r *clientResponse) reset() { + r.Id = 0 + r.Result = nil + r.Error = nil +} + +func (c *clientCodec) ReadResponseHeader(r *rpc.Response) os.Error { + c.resp.reset() + if err := c.dec.Decode(&c.resp); err != nil { + return err + } + + c.mutex.Lock() + r.ServiceMethod = c.pending[c.resp.Id] + c.pending[c.resp.Id] = "", false + c.mutex.Unlock() + + r.Error = "" + r.Seq = c.resp.Id + if c.resp.Error != nil { + x, ok := c.resp.Error.(string) + if !ok { + return fmt.Errorf("invalid error %v", c.resp.Error) + } + if x == "" { + x = "unspecified error" + } + r.Error = x + } + return nil +} + +func (c *clientCodec) ReadResponseBody(x interface{}) os.Error { + return json.Unmarshal(*c.resp.Result, x) +} + +func (c *clientCodec) Close() os.Error { + return c.c.Close() +} + +// NewClient returns a new rpc.Client to handle requests to the +// set of services at the other end of the connection. +func NewClient(conn io.ReadWriteCloser) *rpc.Client { + return rpc.NewClientWithCodec(NewClientCodec(conn)) +} + +// Dial connects to a JSON-RPC server at the specified network address. +func Dial(network, address string) (*rpc.Client, os.Error) { + conn, err := net.Dial(network, "", address) + if err != nil { + return nil, err + } + return NewClient(conn), err +} diff --git a/libgo/go/rpc/jsonrpc/server.go b/libgo/go/rpc/jsonrpc/server.go new file mode 100644 index 000000000..bf53bda8d --- /dev/null +++ b/libgo/go/rpc/jsonrpc/server.go @@ -0,0 +1,133 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonrpc + +import ( + "io" + "json" + "os" + "rpc" + "sync" +) + +type serverCodec struct { + dec *json.Decoder // for reading JSON values + enc *json.Encoder // for writing JSON values + c io.Closer + + // temporary work space + req serverRequest + resp serverResponse + + // JSON-RPC clients can use arbitrary json values as request IDs. + // Package rpc expects uint64 request IDs. + // We assign uint64 sequence numbers to incoming requests + // but save the original request ID in the pending map. + // When rpc responds, we use the sequence number in + // the response to find the original request ID. + mutex sync.Mutex // protects seq, pending + seq uint64 + pending map[uint64]*json.RawMessage +} + +// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn. +func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { + return &serverCodec{ + dec: json.NewDecoder(conn), + enc: json.NewEncoder(conn), + c: conn, + pending: make(map[uint64]*json.RawMessage), + } +} + +type serverRequest struct { + Method string "method" + Params *json.RawMessage "params" + Id *json.RawMessage "id" +} + +func (r *serverRequest) reset() { + r.Method = "" + if r.Params != nil { + *r.Params = (*r.Params)[0:0] + } + if r.Id != nil { + *r.Id = (*r.Id)[0:0] + } +} + +type serverResponse struct { + Id *json.RawMessage "id" + Result interface{} "result" + Error interface{} "error" +} + +func (c *serverCodec) ReadRequestHeader(r *rpc.Request) os.Error { + c.req.reset() + if err := c.dec.Decode(&c.req); err != nil { + return err + } + r.ServiceMethod = c.req.Method + + // JSON request id can be any JSON value; + // RPC package expects uint64. Translate to + // internal uint64 and save JSON on the side. + c.mutex.Lock() + c.seq++ + c.pending[c.seq] = c.req.Id + c.req.Id = nil + r.Seq = c.seq + c.mutex.Unlock() + + return nil +} + +func (c *serverCodec) ReadRequestBody(x interface{}) os.Error { + // JSON params is array value. + // RPC params is struct. + // Unmarshal into array containing struct for now. + // Should think about making RPC more general. + var params [1]interface{} + params[0] = x + return json.Unmarshal(*c.req.Params, ¶ms) +} + +var null = json.RawMessage([]byte("null")) + +func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) os.Error { + var resp serverResponse + c.mutex.Lock() + b, ok := c.pending[r.Seq] + if !ok { + c.mutex.Unlock() + return os.NewError("invalid sequence number in response") + } + c.pending[r.Seq] = nil, false + c.mutex.Unlock() + + if b == nil { + // Invalid request so no id. Use JSON null. + b = &null + } + resp.Id = b + resp.Result = x + if r.Error == "" { + resp.Error = nil + } else { + resp.Error = r.Error + } + return c.enc.Encode(resp) +} + +func (c *serverCodec) Close() os.Error { + return c.c.Close() +} + +// ServeConn runs the JSON-RPC server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +func ServeConn(conn io.ReadWriteCloser) { + rpc.ServeCodec(NewServerCodec(conn)) +} diff --git a/libgo/go/rpc/server.go b/libgo/go/rpc/server.go new file mode 100644 index 000000000..5c50bcc3a --- /dev/null +++ b/libgo/go/rpc/server.go @@ -0,0 +1,530 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* + The rpc package provides access to the exported methods of an object across a + network or other I/O connection. A server registers an object, making it visible + as a service with the name of the type of the object. After registration, exported + methods of the object will be accessible remotely. A server may register multiple + objects (services) of different types but it is an error to register multiple + objects of the same type. + + Only methods that satisfy these criteria will be made available for remote access; + other methods will be ignored: + + - the method receiver and name are exported, that is, begin with an upper case letter. + - the method has two arguments, both pointers to exported types. + - the method has return type os.Error. + + The method's first argument represents the arguments provided by the caller; the + second argument represents the result parameters to be returned to the caller. + The method's return value, if non-nil, is passed back as a string that the client + sees as an os.ErrorString. + + The server may handle requests on a single connection by calling ServeConn. More + typically it will create a network listener and call Accept or, for an HTTP + listener, HandleHTTP and http.Serve. + + A client wishing to use the service establishes a connection and then invokes + NewClient on the connection. The convenience function Dial (DialHTTP) performs + both steps for a raw network connection (an HTTP connection). The resulting + Client object has two methods, Call and Go, that specify the service and method to + call, a pointer containing the arguments, and a pointer to receive the result + parameters. + + Call waits for the remote call to complete; Go launches the call asynchronously + and returns a channel that will signal completion. + + Package "gob" is used to transport the data. + + Here is a simple example. A server wishes to export an object of type Arith: + + package server + + type Args struct { + A, B int + } + + type Quotient struct { + Quo, Rem int + } + + type Arith int + + func (t *Arith) Multiply(args *Args, reply *int) os.Error { + *reply = args.A * args.B + return nil + } + + func (t *Arith) Divide(args *Args, quo *Quotient) os.Error { + if args.B == 0 { + return os.ErrorString("divide by zero") + } + quo.Quo = args.A / args.B + quo.Rem = args.A % args.B + return nil + } + + The server calls (for HTTP service): + + arith := new(Arith) + rpc.Register(arith) + rpc.HandleHTTP() + l, e := net.Listen("tcp", ":1234") + if e != nil { + log.Exit("listen error:", e) + } + go http.Serve(l, nil) + + At this point, clients can see a service "Arith" with methods "Arith.Multiply" and + "Arith.Divide". To invoke one, a client first dials the server: + + client, err := rpc.DialHTTP("tcp", serverAddress + ":1234") + if err != nil { + log.Exit("dialing:", err) + } + + Then it can make a remote call: + + // Synchronous call + args := &server.Args{7,8} + var reply int + err = client.Call("Arith.Multiply", args, &reply) + if err != nil { + log.Exit("arith error:", err) + } + fmt.Printf("Arith: %d*%d=%d", args.A, args.B, *reply) + + or + + // Asynchronous call + quotient := new(Quotient) + divCall := client.Go("Arith.Divide", args, "ient, nil) + replyCall := <-divCall.Done // will be equal to divCall + // check errors, print, etc. + + A server implementation will often provide a simple, type-safe wrapper for the + client. +*/ +package rpc + +import ( + "gob" + "http" + "log" + "io" + "net" + "os" + "reflect" + "strings" + "sync" + "unicode" + "utf8" +) + +const ( + // Defaults used by HandleHTTP + DefaultRPCPath = "/_goRPC_" + DefaultDebugPath = "/debug/rpc" +) + +// Precompute the reflect type for os.Error. Can't use os.Error directly +// because Typeof takes an empty interface value. This is annoying. +var unusedError *os.Error +var typeOfOsError = reflect.Typeof(unusedError).(*reflect.PtrType).Elem() + +type methodType struct { + sync.Mutex // protects counters + method reflect.Method + ArgType *reflect.PtrType + ReplyType *reflect.PtrType + numCalls uint +} + +type service struct { + name string // name of service + rcvr reflect.Value // receiver of methods for the service + typ reflect.Type // type of the receiver + method map[string]*methodType // registered methods +} + +// Request is a header written before every RPC call. It is used internally +// but documented here as an aid to debugging, such as when analyzing +// network traffic. +type Request struct { + ServiceMethod string // format: "Service.Method" + Seq uint64 // sequence number chosen by client +} + +// Response is a header written before every RPC return. It is used internally +// but documented here as an aid to debugging, such as when analyzing +// network traffic. +type Response struct { + ServiceMethod string // echoes that of the Request + Seq uint64 // echoes that of the request + Error string // error, if any. +} + +// ClientInfo records information about an RPC client connection. +type ClientInfo struct { + LocalAddr string + RemoteAddr string +} + +// Server represents an RPC Server. +type Server struct { + sync.Mutex // protects the serviceMap + serviceMap map[string]*service +} + +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{serviceMap: make(map[string]*service)} +} + +// DefaultServer is the default instance of *Server. +var DefaultServer = NewServer() + +// Is this an exported - upper case - name? +func isExported(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method +// - two arguments, both pointers to exported structs +// - one return value, of type os.Error +// It returns an error if the receiver is not an exported type or has no +// suitable methods. +// The client accesses each method using a string of the form "Type.Method", +// where Type is the receiver's concrete type. +func (server *Server) Register(rcvr interface{}) os.Error { + return server.register(rcvr, "", false) +} + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func (server *Server) RegisterName(name string, rcvr interface{}) os.Error { + return server.register(rcvr, name, true) +} + +func (server *Server) register(rcvr interface{}, name string, useName bool) os.Error { + server.Lock() + defer server.Unlock() + if server.serviceMap == nil { + server.serviceMap = make(map[string]*service) + } + s := new(service) + s.typ = reflect.Typeof(rcvr) + s.rcvr = reflect.NewValue(rcvr) + sname := reflect.Indirect(s.rcvr).Type().Name() + if useName { + sname = name + } + if sname == "" { + log.Exit("rpc: no service name for type", s.typ.String()) + } + if s.typ.PkgPath() != "" && !isExported(sname) && !useName { + s := "rpc Register: type " + sname + " is not exported" + log.Print(s) + return os.ErrorString(s) + } + if _, present := server.serviceMap[sname]; present { + return os.ErrorString("rpc: service already defined: " + sname) + } + s.name = sname + s.method = make(map[string]*methodType) + + // Install the methods + for m := 0; m < s.typ.NumMethod(); m++ { + method := s.typ.Method(m) + mtype := method.Type + mname := method.Name + if mtype.PkgPath() != "" || !isExported(mname) { + continue + } + // Method needs three ins: receiver, *args, *reply. + if mtype.NumIn() != 3 { + log.Println("method", mname, "has wrong number of ins:", mtype.NumIn()) + continue + } + argType, ok := mtype.In(1).(*reflect.PtrType) + if !ok { + log.Println(mname, "arg type not a pointer:", mtype.In(1)) + continue + } + replyType, ok := mtype.In(2).(*reflect.PtrType) + if !ok { + log.Println(mname, "reply type not a pointer:", mtype.In(2)) + continue + } + if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) { + log.Println(mname, "argument type not exported:", argType) + continue + } + if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) { + log.Println(mname, "reply type not exported:", replyType) + continue + } + if mtype.NumIn() == 4 { + t := mtype.In(3) + if t != reflect.Typeof((*ClientInfo)(nil)) { + log.Println(mname, "last argument not *ClientInfo") + continue + } + } + // Method needs one out: os.Error. + if mtype.NumOut() != 1 { + log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) + continue + } + if returnType := mtype.Out(0); returnType != typeOfOsError { + log.Println("method", mname, "returns", returnType.String(), "not os.Error") + continue + } + s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} + } + + if len(s.method) == 0 { + s := "rpc Register: type " + sname + " has no exported methods of suitable type" + log.Print(s) + return os.ErrorString(s) + } + server.serviceMap[s.name] = s + return nil +} + +// A value sent as a placeholder for the response when the server receives an invalid request. +type InvalidRequest struct { + marker int +} + +var invalidRequest = InvalidRequest{1} + +func _new(t *reflect.PtrType) *reflect.PtrValue { + v := reflect.MakeZero(t).(*reflect.PtrValue) + v.PointTo(reflect.MakeZero(t.Elem())) + return v +} + +func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { + resp := new(Response) + // Encode the response header + resp.ServiceMethod = req.ServiceMethod + if errmsg != "" { + resp.Error = errmsg + } + resp.Seq = req.Seq + sending.Lock() + err := codec.WriteResponse(resp, reply) + if err != nil { + log.Println("rpc: writing response:", err) + } + sending.Unlock() +} + +func (m *methodType) NumCalls() (n uint) { + m.Lock() + n = m.numCalls + m.Unlock() + return n +} + +func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { + mtype.Lock() + mtype.numCalls++ + mtype.Unlock() + function := mtype.method.Func + // Invoke the method, providing a new value for the reply. + returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv}) + // The return value for the method is an os.Error. + errInter := returnValues[0].Interface() + errmsg := "" + if errInter != nil { + errmsg = errInter.(os.Error).String() + } + sendResponse(sending, req, replyv.Interface(), codec, errmsg) +} + +type gobServerCodec struct { + rwc io.ReadWriteCloser + dec *gob.Decoder + enc *gob.Encoder +} + +func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error { + return c.dec.Decode(r) +} + +func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error { + return c.dec.Decode(body) +} + +func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error { + if err := c.enc.Encode(r); err != nil { + return err + } + return c.enc.Encode(body) +} + +func (c *gobServerCodec) Close() os.Error { + return c.rwc.Close() +} + + +// ServeConn runs the server on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func (server *Server) ServeConn(conn io.ReadWriteCloser) { + server.ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)}) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func (server *Server) ServeCodec(codec ServerCodec) { + sending := new(sync.Mutex) + for { + // Grab the request header. + req := new(Request) + err := codec.ReadRequestHeader(req) + if err != nil { + if err == os.EOF || err == io.ErrUnexpectedEOF { + if err == io.ErrUnexpectedEOF { + log.Println("rpc:", err) + } + break + } + s := "rpc: server cannot decode request: " + err.String() + sendResponse(sending, req, invalidRequest, codec, s) + break + } + serviceMethod := strings.Split(req.ServiceMethod, ".", -1) + if len(serviceMethod) != 2 { + s := "rpc: service/method request ill-formed: " + req.ServiceMethod + sendResponse(sending, req, invalidRequest, codec, s) + continue + } + // Look up the request. + server.Lock() + service, ok := server.serviceMap[serviceMethod[0]] + server.Unlock() + if !ok { + s := "rpc: can't find service " + req.ServiceMethod + sendResponse(sending, req, invalidRequest, codec, s) + continue + } + mtype, ok := service.method[serviceMethod[1]] + if !ok { + s := "rpc: can't find method " + req.ServiceMethod + sendResponse(sending, req, invalidRequest, codec, s) + continue + } + // Decode the argument value. + argv := _new(mtype.ArgType) + replyv := _new(mtype.ReplyType) + err = codec.ReadRequestBody(argv.Interface()) + if err != nil { + log.Println("rpc: tearing down", serviceMethod[0], "connection:", err) + sendResponse(sending, req, replyv.Interface(), codec, err.String()) + break + } + go service.call(sending, mtype, req, argv, replyv, codec) + } + codec.Close() +} + +// Accept accepts connections on the listener and serves requests +// for each incoming connection. Accept blocks; the caller typically +// invokes it in a go statement. +func (server *Server) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit? + } + go server.ServeConn(conn) + } +} + +// Register publishes the receiver's methods in the DefaultServer. +func Register(rcvr interface{}) os.Error { return DefaultServer.Register(rcvr) } + +// RegisterName is like Register but uses the provided name for the type +// instead of the receiver's concrete type. +func RegisterName(name string, rcvr interface{}) os.Error { + return DefaultServer.RegisterName(name, rcvr) +} + +// A ServerCodec implements reading of RPC requests and writing of +// RPC responses for the server side of an RPC session. +// The server calls ReadRequestHeader and ReadRequestBody in pairs +// to read requests from the connection, and it calls WriteResponse to +// write a response back. The server calls Close when finished with the +// connection. +type ServerCodec interface { + ReadRequestHeader(*Request) os.Error + ReadRequestBody(interface{}) os.Error + WriteResponse(*Response, interface{}) os.Error + + Close() os.Error +} + +// ServeConn runs the DefaultServer on a single connection. +// ServeConn blocks, serving the connection until the client hangs up. +// The caller typically invokes ServeConn in a go statement. +// ServeConn uses the gob wire format (see package gob) on the +// connection. To use an alternate codec, use ServeCodec. +func ServeConn(conn io.ReadWriteCloser) { + DefaultServer.ServeConn(conn) +} + +// ServeCodec is like ServeConn but uses the specified codec to +// decode requests and encode responses. +func ServeCodec(codec ServerCodec) { + DefaultServer.ServeCodec(codec) +} + +// Accept accepts connections on the listener and serves requests +// to DefaultServer for each incoming connection. +// Accept blocks; the caller typically invokes it in a go statement. +func Accept(lis net.Listener) { DefaultServer.Accept(lis) } + +// Can connect to RPC service using HTTP CONNECT to rpcPath. +var connected = "200 Connected to Go RPC" + +// ServeHTTP implements an http.Handler that answers RPC requests. +func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != "CONNECT" { + w.SetHeader("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + io.WriteString(w, "405 must CONNECT\n") + return + } + conn, _, err := w.Hijack() + if err != nil { + log.Print("rpc hijacking ", w.RemoteAddr(), ": ", err.String()) + return + } + io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") + server.ServeConn(conn) +} + +// HandleHTTP registers an HTTP handler for RPC messages on rpcPath, +// and a debugging handler on debugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func (server *Server) HandleHTTP(rpcPath, debugPath string) { + http.Handle(rpcPath, server) + http.Handle(debugPath, debugHTTP{server}) +} + +// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer +// on DefaultRPCPath and a debugging handler on DefaultDebugPath. +// It is still necessary to invoke http.Serve(), typically in a go statement. +func HandleHTTP() { + DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath) +} diff --git a/libgo/go/rpc/server_test.go b/libgo/go/rpc/server_test.go new file mode 100644 index 000000000..355d51ce4 --- /dev/null +++ b/libgo/go/rpc/server_test.go @@ -0,0 +1,384 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rpc + +import ( + "fmt" + "http" + "log" + "net" + "os" + "strings" + "sync" + "testing" + "time" +) + +var ( + serverAddr, newServerAddr string + httpServerAddr string + once, newOnce, httpOnce sync.Once +) + +const ( + second = 1e9 + newHttpPath = "/foo" +) + +type Args struct { + A, B int +} + +type Reply struct { + C int +} + +type Arith int + +func (t *Arith) Add(args *Args, reply *Reply) os.Error { + reply.C = args.A + args.B + return nil +} + +func (t *Arith) Mul(args *Args, reply *Reply) os.Error { + reply.C = args.A * args.B + return nil +} + +func (t *Arith) Div(args *Args, reply *Reply) os.Error { + if args.B == 0 { + return os.ErrorString("divide by zero") + } + reply.C = args.A / args.B + return nil +} + +func (t *Arith) String(args *Args, reply *string) os.Error { + *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + return nil +} + +func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) { + _, err = fmt.Sscan(*args, &reply.C) + return +} + +func (t *Arith) Error(args *Args, reply *Reply) os.Error { + panic("ERROR") +} + +func listenTCP() (net.Listener, string) { + l, e := net.Listen("tcp", "127.0.0.1:0") // any available address + if e != nil { + log.Exitf("net.Listen tcp :0: %v", e) + } + return l, l.Addr().String() +} + +func startServer() { + Register(new(Arith)) + + var l net.Listener + l, serverAddr = listenTCP() + log.Println("Test RPC server listening on", serverAddr) + go Accept(l) + + HandleHTTP() + httpOnce.Do(startHttpServer) +} + +func startNewServer() { + s := NewServer() + s.Register(new(Arith)) + + var l net.Listener + l, newServerAddr = listenTCP() + log.Println("NewServer test RPC server listening on", newServerAddr) + go Accept(l) + + s.HandleHTTP(newHttpPath, "/bar") + httpOnce.Do(startHttpServer) +} + +func startHttpServer() { + var l net.Listener + l, httpServerAddr = listenTCP() + httpServerAddr = l.Addr().String() + log.Println("Test HTTP RPC server listening on", httpServerAddr) + go http.Serve(l, nil) +} + +func TestRPC(t *testing.T) { + once.Do(startServer) + testRPC(t, serverAddr) + newOnce.Do(startNewServer) + testRPC(t, newServerAddr) +} + +func testRPC(t *testing.T, addr string) { + client, err := Dial("tcp", addr) + if err != nil { + t.Fatal("dialing", err) + } + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.String()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } + + args = &Args{7, 8} + reply = new(Reply) + err = client.Call("Arith.Mul", args, reply) + if err != nil { + t.Errorf("Mul: expected no error but got string %q", err.String()) + } + if reply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + } + + // Out of order. + args = &Args{7, 8} + mulReply := new(Reply) + mulCall := client.Go("Arith.Mul", args, mulReply, nil) + addReply := new(Reply) + addCall := client.Go("Arith.Add", args, addReply, nil) + + addCall = <-addCall.Done + if addCall.Error != nil { + t.Errorf("Add: expected no error but got string %q", addCall.Error.String()) + } + if addReply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + } + + mulCall = <-mulCall.Done + if mulCall.Error != nil { + t.Errorf("Mul: expected no error but got string %q", mulCall.Error.String()) + } + if mulReply.C != args.A*args.B { + t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + } + + // Error test + args = &Args{7, 0} + reply = new(Reply) + err = client.Call("Arith.Div", args, reply) + // expect an error: zero divide + if err == nil { + t.Error("Div: expected error") + } else if err.String() != "divide by zero" { + t.Error("Div: expected divide by zero error; got", err) + } + + // Non-struct argument + const Val = 12345 + str := fmt.Sprint(Val) + reply = new(Reply) + err = client.Call("Arith.Scan", &str, reply) + if err != nil { + t.Errorf("Scan: expected no error but got string %q", err.String()) + } else if reply.C != Val { + t.Errorf("Scan: expected %d got %d", Val, reply.C) + } + + // Non-struct reply + args = &Args{27, 35} + str = "" + err = client.Call("Arith.String", args, &str) + if err != nil { + t.Errorf("String: expected no error but got string %q", err.String()) + } + expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) + if str != expect { + t.Errorf("String: expected %s got %s", expect, str) + } +} + +func TestHTTPRPC(t *testing.T) { + once.Do(startServer) + testHTTPRPC(t, "") + newOnce.Do(startNewServer) + testHTTPRPC(t, newHttpPath) +} + +func testHTTPRPC(t *testing.T, path string) { + var client *Client + var err os.Error + if path == "" { + client, err = DialHTTP("tcp", httpServerAddr) + } else { + client, err = DialHTTPPath("tcp", httpServerAddr, path) + } + if err != nil { + t.Fatal("dialing", err) + } + + // Synchronous calls + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Errorf("Add: expected no error but got string %q", err.String()) + } + if reply.C != args.A+args.B { + t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + } +} + +func TestCheckUnknownService(t *testing.T) { + once.Do(startServer) + + conn, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing:", err) + } + + client := NewClient(conn) + + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Unknown.Add", args, reply) + if err == nil { + t.Error("expected error calling unknown service") + } else if strings.Index(err.String(), "service") < 0 { + t.Error("expected error about service; got", err) + } +} + +func TestCheckUnknownMethod(t *testing.T) { + once.Do(startServer) + + conn, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing:", err) + } + + client := NewClient(conn) + + args := &Args{7, 8} + reply := new(Reply) + err = client.Call("Arith.Unknown", args, reply) + if err == nil { + t.Error("expected error calling unknown service") + } else if strings.Index(err.String(), "method") < 0 { + t.Error("expected error about method; got", err) + } +} + +func TestCheckBadType(t *testing.T) { + once.Do(startServer) + + conn, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing:", err) + } + + client := NewClient(conn) + + reply := new(Reply) + err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use + if err == nil { + t.Error("expected error calling Arith.Add with wrong arg type") + } else if strings.Index(err.String(), "type") < 0 { + t.Error("expected error about type; got", err) + } +} + +type ArgNotPointer int +type ReplyNotPointer int +type ArgNotPublic int +type ReplyNotPublic int +type local struct{} + +func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error { + return nil +} + +func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error { + return nil +} + +func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) os.Error { + return nil +} + +func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error { + return nil +} + +// Check that registration handles lots of bad methods and a type with no suitable methods. +func TestRegistrationError(t *testing.T) { + err := Register(new(ArgNotPointer)) + if err == nil { + t.Errorf("expected error registering ArgNotPointer") + } + err = Register(new(ReplyNotPointer)) + if err == nil { + t.Errorf("expected error registering ReplyNotPointer") + } + err = Register(new(ArgNotPublic)) + if err == nil { + t.Errorf("expected error registering ArgNotPublic") + } + err = Register(new(ReplyNotPublic)) + if err == nil { + t.Errorf("expected error registering ReplyNotPublic") + } +} + +type WriteFailCodec int + +func (WriteFailCodec) WriteRequest(*Request, interface{}) os.Error { + // the panic caused by this error used to not unlock a lock. + return os.NewError("fail") +} + +func (WriteFailCodec) ReadResponseHeader(*Response) os.Error { + time.Sleep(60e9) + panic("unreachable") +} + +func (WriteFailCodec) ReadResponseBody(interface{}) os.Error { + time.Sleep(60e9) + panic("unreachable") +} + +func (WriteFailCodec) Close() os.Error { + return nil +} + +func TestSendDeadlock(t *testing.T) { + client := NewClientWithCodec(WriteFailCodec(0)) + + done := make(chan bool) + go func() { + testSendDeadlock(client) + testSendDeadlock(client) + done <- true + }() + for i := 0; i < 50; i++ { + time.Sleep(100 * 1e6) + _, ok := <-done + if ok { + return + } + } + t.Fatal("deadlock") +} + +func testSendDeadlock(client *Client) { + defer func() { + recover() + }() + args := &Args{7, 8} + reply := new(Reply) + client.Call("Arith.Add", args, reply) +} |