diff options
Diffstat (limited to 'libgo/go/net')
43 files changed, 8269 insertions, 0 deletions
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go new file mode 100644 index 000000000..03b9d87be --- /dev/null +++ b/libgo/go/net/dial.go @@ -0,0 +1,179 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "os" + +// Dial connects to the remote address raddr on the network net. +// If the string laddr is not empty, it is used as the local address +// for the connection. +// +// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), +// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" +// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixgram". +// +// For IP networks, addresses have the form host:port. If host is +// a literal IPv6 address, it must be enclosed in square brackets. +// +// Examples: +// Dial("tcp", "", "12.34.56.78:80") +// Dial("tcp", "", "google.com:80") +// Dial("tcp", "", "[de:ad:be:ef::ca:fe]:80") +// Dial("tcp", "127.0.0.1:123", "127.0.0.1:88") +// +func Dial(net, laddr, raddr string) (c Conn, err os.Error) { + switch prefixBefore(net, ':') { + case "tcp", "tcp4", "tcp6": + var la, ra *TCPAddr + if laddr != "" { + if la, err = ResolveTCPAddr(laddr); err != nil { + goto Error + } + } + if raddr != "" { + if ra, err = ResolveTCPAddr(raddr); err != nil { + goto Error + } + } + c, err := DialTCP(net, la, ra) + if err != nil { + return nil, err + } + return c, nil + case "udp", "udp4", "udp6": + var la, ra *UDPAddr + if laddr != "" { + if la, err = ResolveUDPAddr(laddr); err != nil { + goto Error + } + } + if raddr != "" { + if ra, err = ResolveUDPAddr(raddr); err != nil { + goto Error + } + } + c, err := DialUDP(net, la, ra) + if err != nil { + return nil, err + } + return c, nil + case "unix", "unixgram", "unixpacket": + var la, ra *UnixAddr + if raddr != "" { + if ra, err = ResolveUnixAddr(net, raddr); err != nil { + goto Error + } + } + if laddr != "" { + if la, err = ResolveUnixAddr(net, laddr); err != nil { + goto Error + } + } + c, err = DialUnix(net, la, ra) + if err != nil { + return nil, err + } + return c, nil + case "ip", "ip4", "ip6": + var la, ra *IPAddr + if laddr != "" { + if la, err = ResolveIPAddr(laddr); err != nil { + goto Error + } + } + if raddr != "" { + if ra, err = ResolveIPAddr(raddr); err != nil { + goto Error + } + } + c, err := DialIP(net, la, ra) + if err != nil { + return nil, err + } + return c, nil + + } + err = UnknownNetworkError(net) +Error: + return nil, &OpError{"dial", net + " " + raddr, nil, err} +} + +// Listen announces on the local network address laddr. +// The network string net must be a stream-oriented +// network: "tcp", "tcp4", "tcp6", or "unix", or "unixpacket". +func Listen(net, laddr string) (l Listener, err os.Error) { + switch net { + case "tcp", "tcp4", "tcp6": + var la *TCPAddr + if laddr != "" { + if la, err = ResolveTCPAddr(laddr); err != nil { + return nil, err + } + } + l, err := ListenTCP(net, la) + if err != nil { + return nil, err + } + return l, nil + case "unix", "unixpacket": + var la *UnixAddr + if laddr != "" { + if la, err = ResolveUnixAddr(net, laddr); err != nil { + return nil, err + } + } + l, err := ListenUnix(net, la) + if err != nil { + return nil, err + } + return l, nil + } + return nil, UnknownNetworkError(net) +} + +// ListenPacket announces on the local network address laddr. +// The network string net must be a packet-oriented network: +// "udp", "udp4", "udp6", or "unixgram". +func ListenPacket(net, laddr string) (c PacketConn, err os.Error) { + switch prefixBefore(net, ':') { + case "udp", "udp4", "udp6": + var la *UDPAddr + if laddr != "" { + if la, err = ResolveUDPAddr(laddr); err != nil { + return nil, err + } + } + c, err := ListenUDP(net, la) + if err != nil { + return nil, err + } + return c, nil + case "unixgram": + var la *UnixAddr + if laddr != "" { + if la, err = ResolveUnixAddr(net, laddr); err != nil { + return nil, err + } + } + c, err := DialUnix(net, la, nil) + if err != nil { + return nil, err + } + return c, nil + case "ip", "ip4", "ip6": + var la *IPAddr + if laddr != "" { + if la, err = ResolveIPAddr(laddr); err != nil { + return nil, err + } + } + c, err := ListenIP(net, la) + if err != nil { + return nil, err + } + return c, nil + } + return nil, UnknownNetworkError(net) +} diff --git a/libgo/go/net/dialgoogle_test.go b/libgo/go/net/dialgoogle_test.go new file mode 100644 index 000000000..a432800cf --- /dev/null +++ b/libgo/go/net/dialgoogle_test.go @@ -0,0 +1,107 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "fmt" + "io" + "strings" + "syscall" + "testing" +) + +// If an IPv6 tunnel is running, we can try dialing a real IPv6 address. +var ipv6 = flag.Bool("ipv6", false, "assume ipv6 tunnel is present") + +// fd is already connected to the destination, port 80. +// Run an HTTP request to fetch the appropriate page. +func fetchGoogle(t *testing.T, fd Conn, network, addr string) { + req := []byte("GET /intl/en/privacy/ HTTP/1.0\r\nHost: www.google.com\r\n\r\n") + n, err := fd.Write(req) + + buf := make([]byte, 1000) + n, err = io.ReadFull(fd, buf) + + if n < 1000 { + t.Errorf("fetchGoogle: short HTTP read from %s %s - %v", network, addr, err) + return + } +} + +func doDial(t *testing.T, network, addr string) { + fd, err := Dial(network, "", addr) + if err != nil { + t.Errorf("Dial(%q, %q, %q) = _, %v", network, "", addr, err) + return + } + fetchGoogle(t, fd, network, addr) + fd.Close() +} + +var googleaddrs = []string{ + "%d.%d.%d.%d:80", + "www.google.com:80", + "%d.%d.%d.%d:http", + "www.google.com:http", + "%03d.%03d.%03d.%03d:0080", + "[::ffff:%d.%d.%d.%d]:80", + "[::ffff:%02x%02x:%02x%02x]:80", + "[0:0:0:0:0000:ffff:%d.%d.%d.%d]:80", + "[0:0:0:0:000000:ffff:%d.%d.%d.%d]:80", + "[0:0:0:0:0:ffff::%d.%d.%d.%d]:80", + "[2001:4860:0:2001::68]:80", // ipv6.google.com; removed if ipv6 flag not set +} + +func TestDialGoogle(t *testing.T) { + // If no ipv6 tunnel, don't try the last address. + if !*ipv6 { + googleaddrs[len(googleaddrs)-1] = "" + } + + // Insert an actual IP address for google.com + // into the table. + + _, addrs, err := LookupHost("www.google.com") + if err != nil { + t.Fatalf("lookup www.google.com: %v", err) + } + if len(addrs) == 0 { + t.Fatalf("no addresses for www.google.com") + } + ip := ParseIP(addrs[0]).To4() + + for i, s := range googleaddrs { + if strings.Contains(s, "%") { + googleaddrs[i] = fmt.Sprintf(s, ip[0], ip[1], ip[2], ip[3]) + } + } + + for i := 0; i < len(googleaddrs); i++ { + addr := googleaddrs[i] + if addr == "" { + continue + } + t.Logf("-- %s --", addr) + doDial(t, "tcp", addr) + if addr[0] != '[' { + doDial(t, "tcp4", addr) + + if !preferIPv4 { + // make sure preferIPv4 flag works. + preferIPv4 = true + syscall.SocketDisableIPv6 = true + doDial(t, "tcp4", addr) + syscall.SocketDisableIPv6 = false + preferIPv4 = false + } + } + + // Only run tcp6 if the kernel will take it. + if kernelSupportsIPv6() { + doDial(t, "tcp6", addr) + } + } +} diff --git a/libgo/go/net/dict/dict.go b/libgo/go/net/dict/dict.go new file mode 100644 index 000000000..42f6553ad --- /dev/null +++ b/libgo/go/net/dict/dict.go @@ -0,0 +1,212 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dict implements the Dictionary Server Protocol +// as defined in RFC 2229. +package dict + +import ( + "container/vector" + "net/textproto" + "os" + "strconv" + "strings" +) + +// A Client represents a client connection to a dictionary server. +type Client struct { + text *textproto.Conn +} + +// Dial returns a new client connected to a dictionary server at +// addr on the given network. +func Dial(network, addr string) (*Client, os.Error) { + text, err := textproto.Dial(network, addr) + if err != nil { + return nil, err + } + _, _, err = text.ReadCodeLine(220) + if err != nil { + text.Close() + return nil, err + } + return &Client{text: text}, nil +} + +// Close closes the connection to the dictionary server. +func (c *Client) Close() os.Error { + return c.text.Close() +} + +// A Dict represents a dictionary available on the server. +type Dict struct { + Name string // short name of dictionary + Desc string // long description +} + +// Dicts returns a list of the dictionaries available on the server. +func (c *Client) Dicts() ([]Dict, os.Error) { + id, err := c.text.Cmd("SHOW DB") + if err != nil { + return nil, err + } + + c.text.StartResponse(id) + defer c.text.EndResponse(id) + + _, _, err = c.text.ReadCodeLine(110) + if err != nil { + return nil, err + } + lines, err := c.text.ReadDotLines() + if err != nil { + return nil, err + } + _, _, err = c.text.ReadCodeLine(250) + + dicts := make([]Dict, len(lines)) + for i := range dicts { + d := &dicts[i] + a, _ := fields(lines[i]) + if len(a) < 2 { + return nil, textproto.ProtocolError("invalid dictionary: " + lines[i]) + } + d.Name = a[0] + d.Desc = a[1] + } + return dicts, err +} + +// A Defn represents a definition. +type Defn struct { + Dict Dict // Dict where definition was found + Word string // Word being defined + Text []byte // Definition text, typically multiple lines +} + +// Define requests the definition of the given word. +// The argument dict names the dictionary to use, +// the Name field of a Dict returned by Dicts. +// +// The special dictionary name "*" means to look in all the +// server's dictionaries. +// The special dictionary name "!" means to look in all the +// server's dictionaries in turn, stopping after finding the word +// in one of them. +func (c *Client) Define(dict, word string) ([]*Defn, os.Error) { + id, err := c.text.Cmd("DEFINE %s %q", dict, word) + if err != nil { + return nil, err + } + + c.text.StartResponse(id) + defer c.text.EndResponse(id) + + _, line, err := c.text.ReadCodeLine(150) + if err != nil { + return nil, err + } + a, _ := fields(line) + if len(a) < 1 { + return nil, textproto.ProtocolError("malformed response: " + line) + } + n, err := strconv.Atoi(a[0]) + if err != nil { + return nil, textproto.ProtocolError("invalid definition count: " + a[0]) + } + def := make([]*Defn, n) + for i := 0; i < n; i++ { + _, line, err = c.text.ReadCodeLine(151) + if err != nil { + return nil, err + } + a, _ := fields(line) + if len(a) < 3 { + // skip it, to keep protocol in sync + i-- + n-- + def = def[0:n] + continue + } + d := &Defn{Word: a[0], Dict: Dict{a[1], a[2]}} + d.Text, err = c.text.ReadDotBytes() + if err != nil { + return nil, err + } + def[i] = d + } + _, _, err = c.text.ReadCodeLine(250) + return def, err +} + +// Fields returns the fields in s. +// Fields are space separated unquoted words +// or quoted with single or double quote. +func fields(s string) ([]string, os.Error) { + var v vector.StringVector + i := 0 + for { + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + if i >= len(s) { + break + } + if s[i] == '"' || s[i] == '\'' { + q := s[i] + // quoted string + var j int + for j = i + 1; ; j++ { + if j >= len(s) { + return nil, textproto.ProtocolError("malformed quoted string") + } + if s[j] == '\\' { + j++ + continue + } + if s[j] == q { + j++ + break + } + } + v.Push(unquote(s[i+1 : j-1])) + i = j + } else { + // atom + var j int + for j = i; j < len(s); j++ { + if s[j] == ' ' || s[j] == '\t' || s[j] == '\\' || s[j] == '"' || s[j] == '\'' { + break + } + } + v.Push(s[i:j]) + i = j + } + if i < len(s) { + c := s[i] + if c != ' ' && c != '\t' { + return nil, textproto.ProtocolError("quotes not on word boundaries") + } + } + } + return v, nil +} + +func unquote(s string) string { + if strings.Index(s, "\\") < 0 { + return s + } + b := []byte(s) + w := 0 + for r := 0; r < len(b); r++ { + c := b[r] + if c == '\\' { + r++ + c = b[r] + } + b[w] = c + w++ + } + return string(b[0:w]) +} diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go new file mode 100644 index 000000000..87d76261f --- /dev/null +++ b/libgo/go/net/dnsclient.go @@ -0,0 +1,417 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS client: see RFC 1035. +// Has to be linked into package net for Dial. + +// TODO(rsc): +// Check periodically whether /etc/resolv.conf has changed. +// Could potentially handle many outstanding lookups faster. +// Could have a small cache. +// Random UDP source port (net.Dial should do that for us). +// Random request IDs. + +package net + +import ( + "bytes" + "fmt" + "os" + "rand" + "sync" + "time" +) + +// DNSError represents a DNS lookup error. +type DNSError struct { + Error string // description of the error + Name string // name looked for + Server string // server used + IsTimeout bool +} + +func (e *DNSError) String() string { + if e == nil { + return "<nil>" + } + s := "lookup " + e.Name + if e.Server != "" { + s += " on " + e.Server + } + s += ": " + e.Error + return s +} + +func (e *DNSError) Timeout() bool { return e.IsTimeout } +func (e *DNSError) Temporary() bool { return e.IsTimeout } + +const noSuchHost = "no such host" + +// Send a request on the connection and hope for a reply. +// Up to cfg.attempts attempts. +func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, os.Error) { + if len(name) >= 256 { + return nil, &DNSError{Error: "name too long", Name: name} + } + out := new(dnsMsg) + out.id = uint16(rand.Int()) ^ uint16(time.Nanoseconds()) + out.question = []dnsQuestion{ + {name, qtype, dnsClassINET}, + } + out.recursion_desired = true + msg, ok := out.Pack() + if !ok { + return nil, &DNSError{Error: "internal error - cannot pack message", Name: name} + } + + for attempt := 0; attempt < cfg.attempts; attempt++ { + n, err := c.Write(msg) + if err != nil { + return nil, err + } + + c.SetReadTimeout(int64(cfg.timeout) * 1e9) // nanoseconds + + buf := make([]byte, 2000) // More than enough. + n, err = c.Read(buf) + if err != nil { + if e, ok := err.(Error); ok && e.Timeout() { + continue + } + return nil, err + } + buf = buf[0:n] + in := new(dnsMsg) + if !in.Unpack(buf) || in.id != out.id { + continue + } + return in, nil + } + var server string + if a := c.RemoteAddr(); a != nil { + server = a.String() + } + return nil, &DNSError{Error: "no answer from server", Name: name, Server: server, IsTimeout: true} +} + + +// Find answer for name in dns message. +// On return, if err == nil, addrs != nil. +func answer(name, server string, dns *dnsMsg, qtype uint16) (addrs []dnsRR, err os.Error) { + addrs = make([]dnsRR, 0, len(dns.answer)) + + if dns.rcode == dnsRcodeNameError && dns.recursion_available { + return nil, &DNSError{Error: noSuchHost, Name: name} + } + if dns.rcode != dnsRcodeSuccess { + // None of the error codes make sense + // for the query we sent. If we didn't get + // a name error and we didn't get success, + // the server is behaving incorrectly. + return nil, &DNSError{Error: "server misbehaving", Name: name, Server: server} + } + + // Look for the name. + // Presotto says it's okay to assume that servers listed in + // /etc/resolv.conf are recursive resolvers. + // We asked for recursion, so it should have included + // all the answers we need in this one packet. +Cname: + for cnameloop := 0; cnameloop < 10; cnameloop++ { + addrs = addrs[0:0] + for i := 0; i < len(dns.answer); i++ { + rr := dns.answer[i] + h := rr.Header() + if h.Class == dnsClassINET && h.Name == name { + switch h.Rrtype { + case qtype: + n := len(addrs) + addrs = addrs[0 : n+1] + addrs[n] = rr + case dnsTypeCNAME: + // redirect to cname + name = rr.(*dnsRR_CNAME).Cname + continue Cname + } + } + } + if len(addrs) == 0 { + return nil, &DNSError{Error: noSuchHost, Name: name, Server: server} + } + return addrs, nil + } + + return nil, &DNSError{Error: "too many redirects", Name: name, Server: server} +} + +// Do a lookup for a single name, which must be rooted +// (otherwise answer will not find the answers). +func tryOneName(cfg *dnsConfig, name string, qtype uint16) (addrs []dnsRR, err os.Error) { + if len(cfg.servers) == 0 { + return nil, &DNSError{Error: "no DNS servers", Name: name} + } + for i := 0; i < len(cfg.servers); i++ { + // Calling Dial here is scary -- we have to be sure + // not to dial a name that will require a DNS lookup, + // or Dial will call back here to translate it. + // The DNS config parser has already checked that + // all the cfg.servers[i] are IP addresses, which + // Dial will use without a DNS lookup. + server := cfg.servers[i] + ":53" + c, cerr := Dial("udp", "", server) + if cerr != nil { + err = cerr + continue + } + msg, merr := exchange(cfg, c, name, qtype) + c.Close() + if merr != nil { + err = merr + continue + } + addrs, err = answer(name, server, msg, qtype) + if err == nil || err.(*DNSError).Error == noSuchHost { + break + } + } + return +} + +func convertRR_A(records []dnsRR) []string { + addrs := make([]string, len(records)) + for i := 0; i < len(records); i++ { + rr := records[i] + a := rr.(*dnsRR_A).A + addrs[i] = IPv4(byte(a>>24), byte(a>>16), byte(a>>8), byte(a)).String() + } + return addrs +} + +var cfg *dnsConfig +var dnserr os.Error + +func loadConfig() { cfg, dnserr = dnsReadConfig() } + +func isDomainName(s string) bool { + // See RFC 1035, RFC 3696. + if len(s) == 0 { + return false + } + if len(s) > 255 { + return false + } + if s[len(s)-1] != '.' { // simplify checking loop: make name end in dot + s += "." + } + + last := byte('.') + ok := false // ok once we've seen a letter + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + ok = true + partlen++ + case '0' <= c && c <= '9': + // fine + partlen++ + case c == '-': + // byte before dash cannot be dot + if last == '.' { + return false + } + partlen++ + case c == '.': + // byte before dot cannot be dot, dash + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + + return ok +} + +var onceLoadConfig sync.Once + +func lookup(name string, qtype uint16) (cname string, addrs []dnsRR, err os.Error) { + if !isDomainName(name) { + return name, nil, &DNSError{Error: "invalid domain name", Name: name} + } + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + // If name is rooted (trailing dot) or has enough dots, + // try it by itself first. + rooted := len(name) > 0 && name[len(name)-1] == '.' + if rooted || count(name, '.') >= cfg.ndots { + rname := name + if !rooted { + rname += "." + } + // Can try as ordinary name. + addrs, err = tryOneName(cfg, rname, qtype) + if err == nil { + cname = rname + return + } + } + if rooted { + return + } + + // Otherwise, try suffixes. + for i := 0; i < len(cfg.search); i++ { + rname := name + "." + cfg.search[i] + if rname[len(rname)-1] != '.' { + rname += "." + } + addrs, err = tryOneName(cfg, rname, qtype) + if err == nil { + cname = rname + return + } + } + + // Last ditch effort: try unsuffixed. + rname := name + if !rooted { + rname += "." + } + addrs, err = tryOneName(cfg, rname, qtype) + if err == nil { + cname = rname + return + } + return +} + +// LookupHost looks for name using the local hosts file and DNS resolver. +// It returns the canonical name for the host and an array of that +// host's addresses. +func LookupHost(name string) (cname string, addrs []string, err os.Error) { + onceLoadConfig.Do(loadConfig) + if dnserr != nil || cfg == nil { + err = dnserr + return + } + // Use entries from /etc/hosts if they match. + addrs = lookupStaticHost(name) + if len(addrs) > 0 { + cname = name + return + } + var records []dnsRR + cname, records, err = lookup(name, dnsTypeA) + if err != nil { + return + } + addrs = convertRR_A(records) + return +} + +type SRV struct { + Target string + Port uint16 + Priority uint16 + Weight uint16 +} + +// LookupSRV tries to resolve an SRV query of the given service, +// protocol, and domain name, as specified in RFC 2782. In most cases +// the proto argument can be the same as the corresponding +// Addr.Network(). +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + target := "_" + service + "._" + proto + "." + name + var records []dnsRR + cname, records, err = lookup(target, dnsTypeSRV) + if err != nil { + return + } + addrs = make([]*SRV, len(records)) + for i := 0; i < len(records); i++ { + r := records[i].(*dnsRR_SRV) + addrs[i] = &SRV{r.Target, r.Port, r.Priority, r.Weight} + } + return +} + +type MX struct { + Host string + Pref uint16 +} + +func LookupMX(name string) (entries []*MX, err os.Error) { + var records []dnsRR + _, records, err = lookup(name, dnsTypeMX) + if err != nil { + return + } + entries = make([]*MX, len(records)) + for i := range records { + r := records[i].(*dnsRR_MX) + entries[i] = &MX{r.Mx, r.Pref} + } + return +} + +// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP +// address addr suitable for rDNS (PTR) record lookup or an error if it fails +// to parse the IP address. +func reverseaddr(addr string) (arpa string, err os.Error) { + ip := ParseIP(addr) + if ip == nil { + return "", &DNSError{Error: "unrecognized address", Name: addr} + } + if ip.To4() != nil { + return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[15], ip[14], ip[13], ip[12]), nil + } + // Must be IPv6 + var buf bytes.Buffer + // Add it, in reverse, to the buffer + for i := len(ip) - 1; i >= 0; i-- { + s := fmt.Sprintf("%02x", ip[i]) + buf.WriteByte(s[1]) + buf.WriteByte('.') + buf.WriteByte(s[0]) + buf.WriteByte('.') + } + // Append "ip6.arpa." and return (buf already has the final .) + return buf.String() + "ip6.arpa.", nil +} + +// LookupAddr performs a reverse lookup for the given address, returning a list +// of names mapping to that address. +func LookupAddr(addr string) (name []string, err os.Error) { + name = lookupStaticAddr(addr) + if len(name) > 0 { + return + } + var arpa string + arpa, err = reverseaddr(addr) + if err != nil { + return + } + var records []dnsRR + _, records, err = lookup(arpa, dnsTypePTR) + if err != nil { + return + } + name = make([]string, len(records)) + for i := range records { + r := records[i].(*dnsRR_PTR) + name[i] = r.Ptr + } + return +} diff --git a/libgo/go/net/dnsconfig.go b/libgo/go/net/dnsconfig.go new file mode 100644 index 000000000..26f0e04e9 --- /dev/null +++ b/libgo/go/net/dnsconfig.go @@ -0,0 +1,120 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read system DNS config from /etc/resolv.conf + +package net + +import "os" + +type dnsConfig struct { + servers []string // servers to use + search []string // suffixes to append to local name + ndots int // number of dots in name to trigger absolute lookup + timeout int // seconds before giving up on packet + attempts int // lost packets before giving up on server + rotate bool // round robin among servers +} + +var dnsconfigError os.Error + +type DNSConfigError struct { + Error os.Error +} + +func (e *DNSConfigError) String() string { + return "error reading DNS config: " + e.Error.String() +} + +func (e *DNSConfigError) Timeout() bool { return false } +func (e *DNSConfigError) Temporary() bool { return false } + + +// See resolv.conf(5) on a Linux machine. +// TODO(rsc): Supposed to call uname() and chop the beginning +// of the host name to get the default search domain. +// We assume it's in resolv.conf anyway. +func dnsReadConfig() (*dnsConfig, os.Error) { + file, err := open("/etc/resolv.conf") + if err != nil { + return nil, &DNSConfigError{err} + } + conf := new(dnsConfig) + conf.servers = make([]string, 3)[0:0] // small, but the standard limit + conf.search = make([]string, 0) + conf.ndots = 1 + conf.timeout = 5 + conf.attempts = 2 + conf.rotate = false + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + f := getFields(line) + if len(f) < 1 { + continue + } + switch f[0] { + case "nameserver": // add one name server + a := conf.servers + n := len(a) + if len(f) > 1 && n < cap(a) { + // One more check: make sure server name is + // just an IP address. Otherwise we need DNS + // to look it up. + name := f[1] + switch len(ParseIP(name)) { + case 16: + name = "[" + name + "]" + fallthrough + case 4: + a = a[0 : n+1] + a[n] = name + conf.servers = a + } + } + + case "domain": // set search path to just this domain + if len(f) > 1 { + conf.search = make([]string, 1) + conf.search[0] = f[1] + } else { + conf.search = make([]string, 0) + } + + case "search": // set search path to given servers + conf.search = make([]string, len(f)-1) + for i := 0; i < len(conf.search); i++ { + conf.search[i] = f[i+1] + } + + case "options": // magic options + for i := 1; i < len(f); i++ { + s := f[i] + switch { + case len(s) >= 6 && s[0:6] == "ndots:": + n, _, _ := dtoi(s, 6) + if n < 1 { + n = 1 + } + conf.ndots = n + case len(s) >= 8 && s[0:8] == "timeout:": + n, _, _ := dtoi(s, 8) + if n < 1 { + n = 1 + } + conf.timeout = n + case len(s) >= 8 && s[0:9] == "attempts:": + n, _, _ := dtoi(s, 9) + if n < 1 { + n = 1 + } + conf.attempts = n + case s == "rotate": + conf.rotate = true + } + } + } + } + file.close() + + return conf, nil +} diff --git a/libgo/go/net/dnsmsg.go b/libgo/go/net/dnsmsg.go new file mode 100644 index 000000000..dc195caf8 --- /dev/null +++ b/libgo/go/net/dnsmsg.go @@ -0,0 +1,743 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS packet assembly. See RFC 1035. +// +// This is intended to support name resolution during net.Dial. +// It doesn't have to be blazing fast. +// +// Rather than write the usual handful of routines to pack and +// unpack every message that can appear on the wire, we use +// reflection to write a generic pack/unpack for structs and then +// use it. Thus, if in the future we need to define new message +// structs, no new pack/unpack/printing code needs to be written. +// +// The first half of this file defines the DNS message formats. +// The second half implements the conversion to and from wire format. +// A few of the structure elements have string tags to aid the +// generic pack/unpack routines. +// +// TODO(rsc): There are enough names defined in this file that they're all +// prefixed with dns. Perhaps put this in its own package later. + +package net + +import ( + "fmt" + "os" + "reflect" +) + +// Packet formats + +// Wire constants. +const ( + // valid dnsRR_Header.Rrtype and dnsQuestion.qtype + dnsTypeA = 1 + dnsTypeNS = 2 + dnsTypeMD = 3 + dnsTypeMF = 4 + dnsTypeCNAME = 5 + dnsTypeSOA = 6 + dnsTypeMB = 7 + dnsTypeMG = 8 + dnsTypeMR = 9 + dnsTypeNULL = 10 + dnsTypeWKS = 11 + dnsTypePTR = 12 + dnsTypeHINFO = 13 + dnsTypeMINFO = 14 + dnsTypeMX = 15 + dnsTypeTXT = 16 + dnsTypeSRV = 33 + + // valid dnsQuestion.qtype only + dnsTypeAXFR = 252 + dnsTypeMAILB = 253 + dnsTypeMAILA = 254 + dnsTypeALL = 255 + + // valid dnsQuestion.qclass + dnsClassINET = 1 + dnsClassCSNET = 2 + dnsClassCHAOS = 3 + dnsClassHESIOD = 4 + dnsClassANY = 255 + + // dnsMsg.rcode + dnsRcodeSuccess = 0 + dnsRcodeFormatError = 1 + dnsRcodeServerFailure = 2 + dnsRcodeNameError = 3 + dnsRcodeNotImplemented = 4 + dnsRcodeRefused = 5 +) + +// The wire format for the DNS packet header. +type dnsHeader struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +const ( + // dnsHeader.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available +) + +// DNS queries. +type dnsQuestion struct { + Name string "domain-name" // "domain-name" specifies encoding; see packers below + Qtype uint16 + Qclass uint16 +} + +// DNS responses (resource records). +// There are many types of messages, +// but they all share the same header. +type dnsRR_Header struct { + Name string "domain-name" + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *dnsRR_Header) Header() *dnsRR_Header { + return h +} + +type dnsRR interface { + Header() *dnsRR_Header +} + + +// Specific DNS RR formats for each query type. + +type dnsRR_CNAME struct { + Hdr dnsRR_Header + Cname string "domain-name" +} + +func (rr *dnsRR_CNAME) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_HINFO struct { + Hdr dnsRR_Header + Cpu string + Os string +} + +func (rr *dnsRR_HINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MB struct { + Hdr dnsRR_Header + Mb string "domain-name" +} + +func (rr *dnsRR_MB) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MG struct { + Hdr dnsRR_Header + Mg string "domain-name" +} + +func (rr *dnsRR_MG) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MINFO struct { + Hdr dnsRR_Header + Rmail string "domain-name" + Email string "domain-name" +} + +func (rr *dnsRR_MINFO) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MR struct { + Hdr dnsRR_Header + Mr string "domain-name" +} + +func (rr *dnsRR_MR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_MX struct { + Hdr dnsRR_Header + Pref uint16 + Mx string "domain-name" +} + +func (rr *dnsRR_MX) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_NS struct { + Hdr dnsRR_Header + Ns string "domain-name" +} + +func (rr *dnsRR_NS) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_PTR struct { + Hdr dnsRR_Header + Ptr string "domain-name" +} + +func (rr *dnsRR_PTR) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SOA struct { + Hdr dnsRR_Header + Ns string "domain-name" + Mbox string "domain-name" + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *dnsRR_SOA) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_TXT struct { + Hdr dnsRR_Header + Txt string // not domain name +} + +func (rr *dnsRR_TXT) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_SRV struct { + Hdr dnsRR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string "domain-name" +} + +func (rr *dnsRR_SRV) Header() *dnsRR_Header { + return &rr.Hdr +} + +type dnsRR_A struct { + Hdr dnsRR_Header + A uint32 "ipv4" +} + +func (rr *dnsRR_A) Header() *dnsRR_Header { return &rr.Hdr } + + +// Packing and unpacking. +// +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +// Map of constructors for each RR wire type. +var rr_mk = map[int]func() dnsRR{ + dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) }, + dnsTypeHINFO: func() dnsRR { return new(dnsRR_HINFO) }, + dnsTypeMB: func() dnsRR { return new(dnsRR_MB) }, + dnsTypeMG: func() dnsRR { return new(dnsRR_MG) }, + dnsTypeMINFO: func() dnsRR { return new(dnsRR_MINFO) }, + dnsTypeMR: func() dnsRR { return new(dnsRR_MR) }, + dnsTypeMX: func() dnsRR { return new(dnsRR_MX) }, + dnsTypeNS: func() dnsRR { return new(dnsRR_NS) }, + dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) }, + dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) }, + dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) }, + dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) }, + dnsTypeA: func() dnsRR { return new(dnsRR_A) }, +} + +// Pack a domain name s into msg[off:]. +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. +func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { + // Add trailing dot to canonicalize name. + if n := len(s); n == 0 || s[n-1] != '.' { + s += "." + } + + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // There is also a trailing zero. + // Check that we have all the space we need. + tot := len(s) + 1 + if off+tot > len(msg) { + return len(msg), false + } + + // Emit sequence of counted strings, chopping at dots. + begin := 0 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + if i-begin >= 1<<6 { // top two bits of length must be clear + return len(msg), false + } + msg[off] = byte(i - begin) + off++ + for j := begin; j < i; j++ { + msg[off] = s[j] + off++ + } + begin = i + 1 + } + } + msg[off] = 0 + off++ + return off, true +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { + s = "" + ptr := 0 // number of pointers followed +Loop: + for { + if off >= len(msg) { + return "", len(msg), false + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + break Loop + } + // literal string + if off+c > len(msg) { + return "", len(msg), false + } + s += string(msg[off:off+c]) + "." + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= len(msg) { + return "", len(msg), false + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", len(msg), false + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", len(msg), false + } + } + if ptr == 0 { + off1 = off + } + return s, off1, true +} + +// TODO(rsc): Move into generic library? +// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, +// and other (often anonymous) structs. +func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().(*reflect.StructType).Field(i) + switch fv := val.Field(i).(type) { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case *reflect.StructValue: + off, ok = packStructValue(fv, msg, off) + case *reflect.UintValue: + i := fv.Get() + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + } + case *reflect.StringValue: + // There are multiple string encodings. + // The tag distinguishes ordinary strings from domain names. + s := fv.Get() + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "domain-name": + off, ok = packDomainName(s, msg, off) + if !ok { + return len(msg), false + } + case "": + // Counted string: 1 byte length. + if len(s) > 255 || off+1+len(s) > len(msg) { + return len(msg), false + } + msg[off] = byte(len(s)) + off++ + off += copy(msg[off:], s) + } + } + } + return off, true +} + +func structValue(any interface{}) *reflect.StructValue { + return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue) +} + +func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = packStructValue(structValue(any), msg, off) + return off, ok +} + +// TODO(rsc): Move into generic library? +// Unpack a reflect.StructValue from msg. +// Same restrictions as packStructValue. +func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().(*reflect.StructType).Field(i) + switch fv := val.Field(i).(type) { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case *reflect.StructValue: + off, ok = unpackStructValue(fv, msg, off) + case *reflect.UintValue: + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.Set(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.Set(uint64(i)) + off += 4 + } + case *reflect.StringValue: + var s string + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "domain-name": + s, off, ok = unpackDomainName(msg, off) + if !ok { + return len(msg), false + } + case "": + if off >= len(msg) || off+1+int(msg[off]) > len(msg) { + return len(msg), false + } + n := int(msg[off]) + off++ + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = msg[off+i] + } + off += n + s = string(b) + } + fv.Set(s) + } + } + return off, true +} + +func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = unpackStructValue(structValue(any), msg, off) + return off, ok +} + +// Generic struct printer. +// Doesn't care about the string tag "domain-name", +// but does look for an "ipv4" tag on uint32 variables, +// printing them as IP addresses. +func printStructValue(val *reflect.StructValue) string { + s := "{" + for i := 0; i < val.NumField(); i++ { + if i > 0 { + s += ", " + } + f := val.Type().(*reflect.StructType).Field(i) + if !f.Anonymous { + s += f.Name + "=" + } + fval := val.Field(i) + if fv, ok := fval.(*reflect.StructValue); ok { + s += printStructValue(fv) + } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" { + i := fv.Get() + s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + } else { + s += fmt.Sprint(fval.Interface()) + } + } + s += "}" + return s +} + +func printStruct(any interface{}) string { return printStructValue(structValue(any)) } + +// Resource record packer. +func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) { + var off1 int + // pack twice, once to find end of header + // and again to find end of packet. + // a bit inefficient but this doesn't need to be fast. + // off1 is end of header + // off2 is end of rr + off1, ok = packStruct(rr.Header(), msg, off) + off2, ok = packStruct(rr, msg, off) + if !ok { + return len(msg), false + } + // pack a third time; redo header with correct data length + rr.Header().Rdlength = uint16(off2 - off1) + packStruct(rr.Header(), msg, off) + return off2, true +} + +// Resource record unpacker. +func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) { + // unpack just the header, to find the rr type and length + var h dnsRR_Header + off0 := off + if off, ok = unpackStruct(&h, msg, off); !ok { + return nil, len(msg), false + } + end := off + int(h.Rdlength) + + // make an rr of that type and re-unpack. + // again inefficient but doesn't need to be fast. + mk, known := rr_mk[int(h.Rrtype)] + if !known { + return &h, end, true + } + rr = mk() + off, ok = unpackStruct(rr, msg, off0) + if off != end { + return &h, end, true + } + return rr, off, ok +} + +// Usable representation of a DNS packet. + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. +type dnsMsgHdr struct { + id uint16 + response bool + opcode int + authoritative bool + truncated bool + recursion_desired bool + recursion_available bool + rcode int +} + +type dnsMsg struct { + dnsMsgHdr + question []dnsQuestion + answer []dnsRR + ns []dnsRR + extra []dnsRR +} + + +func (dns *dnsMsg) Pack() (msg []byte, ok bool) { + var dh dnsHeader + + // Convert convenient dnsMsg into wire-like dnsHeader. + dh.Id = dns.id + dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode) + if dns.recursion_available { + dh.Bits |= _RA + } + if dns.recursion_desired { + dh.Bits |= _RD + } + if dns.truncated { + dh.Bits |= _TC + } + if dns.authoritative { + dh.Bits |= _AA + } + if dns.response { + dh.Bits |= _QR + } + + // Prepare variable sized arrays. + question := dns.question + answer := dns.answer + ns := dns.ns + extra := dns.extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // Could work harder to calculate message size, + // but this is far more than we need and not + // big enough to hurt the allocator. + msg = make([]byte, 2000) + + // Pack it in: header and then the pieces. + off := 0 + off, ok = packStruct(&dh, msg, off) + for i := 0; i < len(question); i++ { + off, ok = packStruct(&question[i], msg, off) + } + for i := 0; i < len(answer); i++ { + off, ok = packRR(answer[i], msg, off) + } + for i := 0; i < len(ns); i++ { + off, ok = packRR(ns[i], msg, off) + } + for i := 0; i < len(extra); i++ { + off, ok = packRR(extra[i], msg, off) + } + if !ok { + return nil, false + } + return msg[0:off], true +} + +func (dns *dnsMsg) Unpack(msg []byte) bool { + // Header. + var dh dnsHeader + off := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return false + } + dns.id = dh.Id + dns.response = (dh.Bits & _QR) != 0 + dns.opcode = int(dh.Bits>>11) & 0xF + dns.authoritative = (dh.Bits & _AA) != 0 + dns.truncated = (dh.Bits & _TC) != 0 + dns.recursion_desired = (dh.Bits & _RD) != 0 + dns.recursion_available = (dh.Bits & _RA) != 0 + dns.rcode = int(dh.Bits & 0xF) + + // Arrays. + dns.question = make([]dnsQuestion, dh.Qdcount) + dns.answer = make([]dnsRR, dh.Ancount) + dns.ns = make([]dnsRR, dh.Nscount) + dns.extra = make([]dnsRR, dh.Arcount) + + for i := 0; i < len(dns.question); i++ { + off, ok = unpackStruct(&dns.question[i], msg, off) + } + for i := 0; i < len(dns.answer); i++ { + dns.answer[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.ns); i++ { + dns.ns[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.extra); i++ { + dns.extra[i], off, ok = unpackRR(msg, off) + } + if !ok { + return false + } + // if off != len(msg) { + // println("extra bytes in dns packet", off, "<", len(msg)); + // } + return true +} + +func (dns *dnsMsg) String() string { + s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n" + if len(dns.question) > 0 { + s += "-- Questions\n" + for i := 0; i < len(dns.question); i++ { + s += printStruct(&dns.question[i]) + "\n" + } + } + if len(dns.answer) > 0 { + s += "-- Answers\n" + for i := 0; i < len(dns.answer); i++ { + s += printStruct(dns.answer[i]) + "\n" + } + } + if len(dns.ns) > 0 { + s += "-- Name servers\n" + for i := 0; i < len(dns.ns); i++ { + s += printStruct(dns.ns[i]) + "\n" + } + } + if len(dns.extra) > 0 { + s += "-- Extra\n" + for i := 0; i < len(dns.extra); i++ { + s += printStruct(dns.extra[i]) + "\n" + } + } + return s +} diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go new file mode 100644 index 000000000..0c1a62518 --- /dev/null +++ b/libgo/go/net/dnsname_test.go @@ -0,0 +1,69 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" + "runtime" +) + +type testCase struct { + name string + result bool +} + +var tests = []testCase{ + // RFC2181, section 11. + {"_xmpp-server._tcp.google.com", true}, + {"_xmpp-server._tcp.google.com", true}, + {"foo.com", true}, + {"1foo.com", true}, + {"26.0.0.73.com", true}, + {"fo-o.com", true}, + {"fo1o.com", true}, + {"foo1.com", true}, + {"a.b..com", false}, +} + +func getTestCases(ch chan<- testCase) { + defer close(ch) + var char59 = "" + var char63 = "" + var char64 = "" + for i := 0; i < 59; i++ { + char59 += "a" + } + char63 = char59 + "aaaa" + char64 = char63 + "a" + + for _, tc := range tests { + ch <- tc + } + + ch <- testCase{char63 + ".com", true} + ch <- testCase{char64 + ".com", false} + // 255 char name is fine: + ch <- testCase{char59 + "." + char63 + "." + char63 + "." + + char63 + ".com", + true} + // 256 char name is bad: + ch <- testCase{char59 + "a." + char63 + "." + char63 + "." + + char63 + ".com", + false} +} + +func TestDNSNames(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + ch := make(chan testCase) + go getTestCases(ch) + for tc := range ch { + if isDomainName(tc.name) != tc.result { + t.Errorf("isDomainName(%v) failed: Should be %v", + tc.name, tc.result) + } + } +} diff --git a/libgo/go/net/fd.go b/libgo/go/net/fd.go new file mode 100644 index 000000000..26d17d4e0 --- /dev/null +++ b/libgo/go/net/fd.go @@ -0,0 +1,612 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(rsc): All the prints in this file should go to standard error. + +package net + +import ( + "io" + "os" + "sync" + "syscall" + "time" +) + +// Network file descriptor. +type netFD struct { + // locking/lifetime of sysfd + sysmu sync.Mutex + sysref int + closing bool + + // immutable until Close + sysfd int + family int + proto int + sysfile *os.File + cr chan bool + cw chan bool + net string + laddr Addr + raddr Addr + + // owned by client + rdeadline_delta int64 + rdeadline int64 + rio sync.Mutex + wdeadline_delta int64 + wdeadline int64 + wio sync.Mutex + + // owned by fd wait server + ncr, ncw int +} + +type InvalidConnError struct{} + +func (e *InvalidConnError) String() string { return "invalid net.Conn" } +func (e *InvalidConnError) Temporary() bool { return false } +func (e *InvalidConnError) Timeout() bool { return false } + +// A pollServer helps FDs determine when to retry a non-blocking +// read or write after they get EAGAIN. When an FD needs to wait, +// send the fd on s.cr (for a read) or s.cw (for a write) to pass the +// request to the poll server. Then receive on fd.cr/fd.cw. +// When the pollServer finds that i/o on FD should be possible +// again, it will send fd on fd.cr/fd.cw to wake any waiting processes. +// This protocol is implemented as s.WaitRead() and s.WaitWrite(). +// +// There is one subtlety: when sending on s.cr/s.cw, the +// poll server is probably in a system call, waiting for an fd +// to become ready. It's not looking at the request channels. +// To resolve this, the poll server waits not just on the FDs it has +// been given but also its own pipe. After sending on the +// buffered channel s.cr/s.cw, WaitRead/WaitWrite writes a +// byte to the pipe, causing the pollServer's poll system call to +// return. In response to the pipe being readable, the pollServer +// re-polls its request channels. +// +// Note that the ordering is "send request" and then "wake up server". +// If the operations were reversed, there would be a race: the poll +// server might wake up and look at the request channel, see that it +// was empty, and go back to sleep, all before the requester managed +// to send the request. Because the send must complete before the wakeup, +// the request channel must be buffered. A buffer of size 1 is sufficient +// for any request load. If many processes are trying to submit requests, +// one will succeed, the pollServer will read the request, and then the +// channel will be empty for the next process's request. A larger buffer +// might help batch requests. +// +// To avoid races in closing, all fd operations are locked and +// refcounted. when netFD.Close() is called, it calls syscall.Shutdown +// and sets a closing flag. Only when the last reference is removed +// will the fd be closed. + +type pollServer struct { + cr, cw chan *netFD // buffered >= 1 + pr, pw *os.File + pending map[int]*netFD + poll *pollster // low-level OS hooks + deadline int64 // next deadline (nsec since 1970) +} + +func (s *pollServer) AddFD(fd *netFD, mode int) { + intfd := fd.sysfd + if intfd < 0 { + // fd closed underfoot + if mode == 'r' { + fd.cr <- true + } else { + fd.cw <- true + } + return + } + if err := s.poll.AddFD(intfd, mode, false); err != nil { + panic("pollServer AddFD " + err.String()) + return + } + + var t int64 + key := intfd << 1 + if mode == 'r' { + fd.ncr++ + t = fd.rdeadline + } else { + fd.ncw++ + key++ + t = fd.wdeadline + } + s.pending[key] = fd + if t > 0 && (s.deadline == 0 || t < s.deadline) { + s.deadline = t + } +} + +func (s *pollServer) LookupFD(fd int, mode int) *netFD { + key := fd << 1 + if mode == 'w' { + key++ + } + netfd, ok := s.pending[key] + if !ok { + return nil + } + s.pending[key] = nil, false + return netfd +} + +func (s *pollServer) WakeFD(fd *netFD, mode int) { + if mode == 'r' { + for fd.ncr > 0 { + fd.ncr-- + fd.cr <- true + } + } else { + for fd.ncw > 0 { + fd.ncw-- + fd.cw <- true + } + } +} + +func (s *pollServer) Now() int64 { + return time.Nanoseconds() +} + +func (s *pollServer) CheckDeadlines() { + now := s.Now() + // TODO(rsc): This will need to be handled more efficiently, + // probably with a heap indexed by wakeup time. + + var next_deadline int64 + for key, fd := range s.pending { + var t int64 + var mode int + if key&1 == 0 { + mode = 'r' + } else { + mode = 'w' + } + if mode == 'r' { + t = fd.rdeadline + } else { + t = fd.wdeadline + } + if t > 0 { + if t <= now { + s.pending[key] = nil, false + if mode == 'r' { + s.poll.DelFD(fd.sysfd, mode) + fd.rdeadline = -1 + } else { + s.poll.DelFD(fd.sysfd, mode) + fd.wdeadline = -1 + } + s.WakeFD(fd, mode) + } else if next_deadline == 0 || t < next_deadline { + next_deadline = t + } + } + } + s.deadline = next_deadline +} + +func (s *pollServer) Run() { + var scratch [100]byte + for { + var t = s.deadline + if t > 0 { + t = t - s.Now() + if t <= 0 { + s.CheckDeadlines() + continue + } + } + fd, mode, err := s.poll.WaitFD(t) + if err != nil { + print("pollServer WaitFD: ", err.String(), "\n") + return + } + if fd < 0 { + // Timeout happened. + s.CheckDeadlines() + continue + } + if fd == s.pr.Fd() { + // Drain our wakeup pipe. + for nn, _ := s.pr.Read(scratch[0:]); nn > 0; { + nn, _ = s.pr.Read(scratch[0:]) + } + // Read from channels + for fd, ok := <-s.cr; ok; fd, ok = <-s.cr { + s.AddFD(fd, 'r') + } + for fd, ok := <-s.cw; ok; fd, ok = <-s.cw { + s.AddFD(fd, 'w') + } + } else { + netfd := s.LookupFD(fd, mode) + if netfd == nil { + print("pollServer: unexpected wakeup for fd=", fd, " mode=", string(mode), "\n") + continue + } + s.WakeFD(netfd, mode) + } + } +} + +var wakeupbuf [1]byte + +func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } + +func (s *pollServer) WaitRead(fd *netFD) { + s.cr <- fd + s.Wakeup() + <-fd.cr +} + +func (s *pollServer) WaitWrite(fd *netFD) { + s.cw <- fd + s.Wakeup() + <-fd.cw +} + +// Network FD methods. +// All the network FDs use a single pollServer. + +var pollserver *pollServer +var onceStartServer sync.Once + +func startServer() { + p, err := newPollServer() + if err != nil { + print("Start pollServer: ", err.String(), "\n") + } + pollserver = p +} + +func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err os.Error) { + onceStartServer.Do(startServer) + if e := syscall.SetNonblock(fd, true); e != 0 { + return nil, &OpError{"setnonblock", net, laddr, os.Errno(e)} + } + f = &netFD{ + sysfd: fd, + family: family, + proto: proto, + net: net, + laddr: laddr, + raddr: raddr, + } + var ls, rs string + if laddr != nil { + ls = laddr.String() + } + if raddr != nil { + rs = raddr.String() + } + f.sysfile = os.NewFile(fd, net+":"+ls+"->"+rs) + f.cr = make(chan bool, 1) + f.cw = make(chan bool, 1) + return f, nil +} + +// Add a reference to this fd. +func (fd *netFD) incref() { + fd.sysmu.Lock() + fd.sysref++ + fd.sysmu.Unlock() +} + +// Remove a reference to this FD and close if we've been asked to do so (and +// there are no references left. +func (fd *netFD) decref() { + fd.sysmu.Lock() + fd.sysref-- + if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 { + // In case the user has set linger, switch to blocking mode so + // the close blocks. As long as this doesn't happen often, we + // can handle the extra OS processes. Otherwise we'll need to + // use the pollserver for Close too. Sigh. + syscall.SetNonblock(fd.sysfd, false) + fd.sysfile.Close() + fd.sysfile = nil + fd.sysfd = -1 + } + fd.sysmu.Unlock() +} + +func (fd *netFD) Close() os.Error { + if fd == nil || fd.sysfile == nil { + return os.EINVAL + } + + fd.incref() + syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR) + fd.closing = true + fd.decref() + return nil +} + +func (fd *netFD) Read(p []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfile == nil { + return 0, os.EINVAL + } + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, errno = syscall.Read(fd.sysfile.Fd(), p) + if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + n = 0 + oserr = os.Errno(errno) + } else if n == 0 && errno == 0 && fd.proto != syscall.SOCK_DGRAM { + err = os.EOF + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.raddr, oserr} + } + return +} + +func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, nil, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, sa, errno = syscall.Recvfrom(fd.sysfd, p, 0) + if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + n = 0 + oserr = os.Errno(errno) + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.laddr, oserr} + } + return +} + +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, 0, 0, nil, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.rdeadline_delta > 0 { + fd.rdeadline = pollserver.Now() + fd.rdeadline_delta + } else { + fd.rdeadline = 0 + } + var oserr os.Error + for { + var errno int + n, oobn, flags, sa, errno = syscall.Recvmsg(fd.sysfd, p, oob, 0) + if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.rdeadline >= 0 { + pollserver.WaitRead(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + if n == 0 { + oserr = os.EOF + } + break + } + if oserr != nil { + err = &OpError{"read", fd.net, fd.laddr, oserr} + return + } + return +} + +func (fd *netFD) Write(p []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfile == nil { + return 0, os.EINVAL + } + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + nn := 0 + var oserr os.Error + + for { + n, errno := syscall.Write(fd.sysfile.Fd(), p[nn:]) + if n > 0 { + nn += n + } + if nn == len(p) { + break + } + if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + n = 0 + oserr = os.Errno(errno) + break + } + if n == 0 { + oserr = io.ErrUnexpectedEOF + break + } + } + if oserr != nil { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return nn, err +} + +func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + var oserr os.Error + for { + errno := syscall.Sendto(fd.sysfd, p, 0, sa) + if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + break + } + if oserr == nil { + n = len(p) + } else { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return +} + +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { + if fd == nil || fd.sysfile == nil { + return 0, 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.wdeadline_delta > 0 { + fd.wdeadline = pollserver.Now() + fd.wdeadline_delta + } else { + fd.wdeadline = 0 + } + var oserr os.Error + for { + var errno int + errno = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) + if (errno == syscall.EAGAIN || errno == syscall.EINTR) && fd.wdeadline >= 0 { + pollserver.WaitWrite(fd) + continue + } + if errno != 0 { + oserr = os.Errno(errno) + } + break + } + if oserr == nil { + n = len(p) + oobn = len(oob) + } else { + err = &OpError{"write", fd.net, fd.raddr, oserr} + } + return +} + +func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { + if fd == nil || fd.sysfile == nil { + return nil, os.EINVAL + } + + fd.incref() + defer fd.decref() + + // See ../syscall/exec.go for description of ForkLock. + // It is okay to hold the lock across syscall.Accept + // because we have put fd.sysfd into non-blocking mode. + syscall.ForkLock.RLock() + var s, e int + var sa syscall.Sockaddr + for { + if fd.closing { + syscall.ForkLock.RUnlock() + return nil, os.EINVAL + } + s, sa, e = syscall.Accept(fd.sysfd) + if e != syscall.EAGAIN && e != syscall.EINTR { + break + } + syscall.ForkLock.RUnlock() + pollserver.WaitRead(fd) + syscall.ForkLock.RLock() + } + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, &OpError{"accept", fd.net, fd.laddr, os.Errno(e)} + } + syscall.CloseOnExec(s) + syscall.ForkLock.RUnlock() + + if nfd, err = newFD(s, fd.family, fd.proto, fd.net, fd.laddr, toAddr(sa)); err != nil { + syscall.Close(s) + return nil, err + } + return nfd, nil +} + +func (fd *netFD) dup() (f *os.File, err os.Error) { + ns, e := syscall.Dup(fd.sysfd) + if e != 0 { + return nil, &OpError{"dup", fd.net, fd.laddr, os.Errno(e)} + } + + // We want blocking mode for the new fd, hence the double negative. + if e = syscall.SetNonblock(ns, false); e != 0 { + return nil, &OpError{"setnonblock", fd.net, fd.laddr, os.Errno(e)} + } + + return os.NewFile(ns, fd.sysfile.Name()), nil +} + +func closesocket(s int) (errno int) { + return syscall.Close(s) +} diff --git a/libgo/go/net/fd_linux.go b/libgo/go/net/fd_linux.go new file mode 100644 index 000000000..ef86cb17f --- /dev/null +++ b/libgo/go/net/fd_linux.go @@ -0,0 +1,149 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via epoll(7). + +package net + +import ( + "os" + "syscall" +) + +const ( + readFlags = syscall.EPOLLIN | syscall.EPOLLRDHUP + writeFlags = syscall.EPOLLOUT +) + +type pollster struct { + epfd int + + // Events we're already waiting for + events map[int]uint32 +} + +func newpollster() (p *pollster, err os.Error) { + p = new(pollster) + var e int + + // The arg to epoll_create is a hint to the kernel + // about the number of FDs we will care about. + // We don't know. + if p.epfd, e = syscall.EpollCreate(16); e != 0 { + return nil, os.NewSyscallError("epoll_create", e) + } + p.events = make(map[int]uint32) + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { + var ev syscall.EpollEvent + var already bool + ev.Fd = int32(fd) + ev.Events, already = p.events[fd] + if !repeat { + ev.Events |= syscall.EPOLLONESHOT + } + if mode == 'r' { + ev.Events |= readFlags + } else { + ev.Events |= writeFlags + } + + var op int + if already { + op = syscall.EPOLL_CTL_MOD + } else { + op = syscall.EPOLL_CTL_ADD + } + if e := syscall.EpollCtl(p.epfd, op, fd, &ev); e != 0 { + return os.NewSyscallError("epoll_ctl", e) + } + p.events[fd] = ev.Events + return nil +} + +func (p *pollster) StopWaiting(fd int, bits uint) { + events, already := p.events[fd] + if !already { + print("Epoll unexpected fd=", fd, "\n") + return + } + + // If syscall.EPOLLONESHOT is not set, the wait + // is a repeating wait, so don't change it. + if events&syscall.EPOLLONESHOT == 0 { + return + } + + // Disable the given bits. + // If we're still waiting for other events, modify the fd + // event in the kernel. Otherwise, delete it. + events &= ^uint32(bits) + if int32(events)&^syscall.EPOLLONESHOT != 0 { + var ev syscall.EpollEvent + ev.Fd = int32(fd) + ev.Events = events + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &ev); e != 0 { + print("Epoll modify fd=", fd, ": ", os.Errno(e).String(), "\n") + } + p.events[fd] = events + } else { + if e := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); e != 0 { + print("Epoll delete fd=", fd, ": ", os.Errno(e).String(), "\n") + } + p.events[fd] = 0, false + } +} + +func (p *pollster) DelFD(fd int, mode int) { + if mode == 'r' { + p.StopWaiting(fd, readFlags) + } else { + p.StopWaiting(fd, writeFlags) + } +} + +func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { + // Get an event. + var evarray [1]syscall.EpollEvent + ev := &evarray[0] + var msec int = -1 + if nsec > 0 { + msec = int((nsec + 1e6 - 1) / 1e6) + } + n, e := syscall.EpollWait(p.epfd, evarray[0:], msec) + for e == syscall.EAGAIN || e == syscall.EINTR { + n, e = syscall.EpollWait(p.epfd, evarray[0:], msec) + } + if e != 0 { + return -1, 0, os.NewSyscallError("epoll_wait", e) + } + if n == 0 { + return -1, 0, nil + } + fd = int(ev.Fd) + + if ev.Events&writeFlags != 0 { + p.StopWaiting(fd, writeFlags) + return fd, 'w', nil + } + if ev.Events&readFlags != 0 { + p.StopWaiting(fd, readFlags) + return fd, 'r', nil + } + + // Other events are error conditions - wake whoever is waiting. + events, _ := p.events[fd] + if events&writeFlags != 0 { + p.StopWaiting(fd, writeFlags) + return fd, 'w', nil + } + p.StopWaiting(fd, readFlags) + return fd, 'r', nil +} + +func (p *pollster) Close() os.Error { + return os.NewSyscallError("close", syscall.Close(p.epfd)) +} diff --git a/libgo/go/net/fd_rtems.go b/libgo/go/net/fd_rtems.go new file mode 100644 index 000000000..61759ca6e --- /dev/null +++ b/libgo/go/net/fd_rtems.go @@ -0,0 +1,137 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Waiting for FDs via select(2). + +package net + +import ( + "os" + "syscall" +) + +type pollster struct { + readFds, writeFds, repeatFds *syscall.FdSet_t + maxFd int + readyReadFds, readyWriteFds *syscall.FdSet_t + nReady int + lastFd int +} + +func newpollster() (p *pollster, err os.Error) { + p = new(pollster) + p.readFds = new(syscall.FdSet_t) + p.writeFds = new(syscall.FdSet_t) + p.repeatFds = new(syscall.FdSet_t) + p.readyReadFds = new(syscall.FdSet_t) + p.readyWriteFds = new(syscall.FdSet_t) + p.maxFd = -1 + p.nReady = 0 + p.lastFd = 0 + return p, nil +} + +func (p *pollster) AddFD(fd int, mode int, repeat bool) os.Error { + if mode == 'r' { + syscall.FDSet(fd, p.readFds) + } else { + syscall.FDSet(fd, p.writeFds) + } + + if repeat { + syscall.FDSet(fd, p.repeatFds) + } + + if fd > p.maxFd { + p.maxFd = fd + } + + return nil +} + +func (p *pollster) DelFD(fd int, mode int) { + if mode == 'r' { + if !syscall.FDIsSet(fd, p.readFds) { + print("Select unexpected fd=", fd, " for read\n") + return + } + syscall.FDClr(fd, p.readFds) + } else { + if !syscall.FDIsSet(fd, p.writeFds) { + print("Select unexpected fd=", fd, " for write\n") + return + } + syscall.FDClr(fd, p.writeFds) + } + + // Doesn't matter if not already present. + syscall.FDClr(fd, p.repeatFds) + + // We don't worry about maxFd here. +} + +func (p *pollster) WaitFD(nsec int64) (fd int, mode int, err os.Error) { + if p.nReady == 0 { + var timeout *syscall.Timeval + var tv syscall.Timeval + timeout = nil + if nsec > 0 { + tv = syscall.NsecToTimeval(nsec) + timeout = &tv + } + + var n, e int + var tmpReadFds, tmpWriteFds syscall.FdSet_t + for { + // Temporary syscall.FdSet_ts into which the values are copied + // because select mutates the values. + tmpReadFds = *p.readFds + tmpWriteFds = *p.writeFds + + n, e = syscall.Select(p.maxFd + 1, &tmpReadFds, &tmpWriteFds, nil, timeout) + if e != syscall.EINTR { + break + } + } + if e != 0 { + return -1, 0, os.NewSyscallError("select", e) + } + if n == 0 { + return -1, 0, nil + } + + p.nReady = n + *p.readyReadFds = tmpReadFds + *p.readyWriteFds = tmpWriteFds + p.lastFd = 0 + } + + flag := false + for i := p.lastFd; i < p.maxFd + 1; i++ { + if syscall.FDIsSet(i, p.readyReadFds) { + flag = true + mode = 'r' + syscall.FDClr(i, p.readyReadFds) + } else if syscall.FDIsSet(i, p.readyWriteFds) { + flag = true + mode = 'w' + syscall.FDClr(i, p.readyWriteFds) + } + if flag { + if !syscall.FDIsSet(i, p.repeatFds) { + p.DelFD(i, mode) + } + p.nReady-- + p.lastFd = i + return i, mode, nil + } + } + + // Will not reach here. Just to shut up the compiler. + return -1, 0, nil +} + +func (p *pollster) Close() os.Error { + return nil +} diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go new file mode 100644 index 000000000..9b91eb398 --- /dev/null +++ b/libgo/go/net/fd_windows.go @@ -0,0 +1,555 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "runtime" + "sync" + "syscall" + "time" + "unsafe" +) + +// IO completion result parameters. +type ioResult struct { + key uint32 + qty uint32 + errno int +} + +// Network file descriptor. +type netFD struct { + // locking/lifetime of sysfd + sysmu sync.Mutex + sysref int + closing bool + + // immutable until Close + sysfd int + family int + proto int + cr chan *ioResult + cw chan *ioResult + net string + laddr Addr + raddr Addr + + // owned by client + rdeadline_delta int64 + rdeadline int64 + rio sync.Mutex + wdeadline_delta int64 + wdeadline int64 + wio sync.Mutex +} + +type InvalidConnError struct{} + +func (e *InvalidConnError) String() string { return "invalid net.Conn" } +func (e *InvalidConnError) Temporary() bool { return false } +func (e *InvalidConnError) Timeout() bool { return false } + +// pollServer will run around waiting for io completion request +// to arrive. Every request received will contain channel to signal +// io owner about the completion. + +type pollServer struct { + iocp int32 +} + +func newPollServer() (s *pollServer, err os.Error) { + s = new(pollServer) + var e int + if s.iocp, e = syscall.CreateIoCompletionPort(-1, 0, 0, 1); e != 0 { + return nil, os.NewSyscallError("CreateIoCompletionPort", e) + } + go s.Run() + return s, nil +} + +type ioPacket struct { + // Used by IOCP interface, + // it must be first field of the struct, + // as our code rely on it. + o syscall.Overlapped + + // Link to the io owner. + c chan *ioResult + + w *syscall.WSABuf +} + +func (s *pollServer) getCompletedIO() (ov *syscall.Overlapped, result *ioResult, err os.Error) { + var r ioResult + var o *syscall.Overlapped + _, e := syscall.GetQueuedCompletionStatus(s.iocp, &r.qty, &r.key, &o, syscall.INFINITE) + switch { + case e == 0: + // Dequeued successfully completed io packet. + return o, &r, nil + case e == syscall.WAIT_TIMEOUT && o == nil: + // Wait has timed out (should not happen now, but might be used in the future). + return nil, &r, os.NewSyscallError("GetQueuedCompletionStatus", e) + case o == nil: + // Failed to dequeue anything -> report the error. + return nil, &r, os.NewSyscallError("GetQueuedCompletionStatus", e) + default: + // Dequeued failed io packet. + r.errno = e + return o, &r, nil + } + return +} + +func (s *pollServer) Run() { + for { + o, r, err := s.getCompletedIO() + if err != nil { + panic("Run pollServer: " + err.String() + "\n") + } + p := (*ioPacket)(unsafe.Pointer(o)) + p.c <- r + } +} + +// Network FD methods. +// All the network FDs use a single pollServer. + +var pollserver *pollServer +var onceStartServer sync.Once + +func startServer() { + p, err := newPollServer() + if err != nil { + panic("Start pollServer: " + err.String() + "\n") + } + pollserver = p + + go timeoutIO() +} + +var initErr os.Error + +func newFD(fd, family, proto int, net string, laddr, raddr Addr) (f *netFD, err os.Error) { + if initErr != nil { + return nil, initErr + } + onceStartServer.Do(startServer) + // Associate our socket with pollserver.iocp. + if _, e := syscall.CreateIoCompletionPort(int32(fd), pollserver.iocp, 0, 0); e != 0 { + return nil, &OpError{"CreateIoCompletionPort", net, laddr, os.Errno(e)} + } + f = &netFD{ + sysfd: fd, + family: family, + proto: proto, + cr: make(chan *ioResult, 1), + cw: make(chan *ioResult, 1), + net: net, + laddr: laddr, + raddr: raddr, + } + runtime.SetFinalizer(f, (*netFD).Close) + return f, nil +} + +// Add a reference to this fd. +func (fd *netFD) incref() { + fd.sysmu.Lock() + fd.sysref++ + fd.sysmu.Unlock() +} + +// Remove a reference to this FD and close if we've been asked to do so (and +// there are no references left. +func (fd *netFD) decref() { + fd.sysmu.Lock() + fd.sysref-- + if fd.closing && fd.sysref == 0 && fd.sysfd >= 0 { + // In case the user has set linger, switch to blocking mode so + // the close blocks. As long as this doesn't happen often, we + // can handle the extra OS processes. Otherwise we'll need to + // use the pollserver for Close too. Sigh. + syscall.SetNonblock(fd.sysfd, false) + closesocket(fd.sysfd) + fd.sysfd = -1 + // no need for a finalizer anymore + runtime.SetFinalizer(fd, nil) + } + fd.sysmu.Unlock() +} + +func (fd *netFD) Close() os.Error { + if fd == nil || fd.sysfd == -1 { + return os.EINVAL + } + + fd.incref() + syscall.Shutdown(fd.sysfd, syscall.SHUT_RDWR) + fd.closing = true + fd.decref() + return nil +} + +func newWSABuf(p []byte) *syscall.WSABuf { + var p0 *byte + if len(p) > 0 { + p0 = (*byte)(unsafe.Pointer(&p[0])) + } + return &syscall.WSABuf{uint32(len(p)), p0} +} + +func waitPacket(fd *netFD, pckt *ioPacket, mode int) (r *ioResult) { + var delta int64 + if mode == 'r' { + delta = fd.rdeadline_delta + } + if mode == 'w' { + delta = fd.wdeadline_delta + } + if delta <= 0 { + return <-pckt.c + } + + select { + case r = <-pckt.c: + case <-time.After(delta): + a := &arg{f: cancel, fd: fd, pckt: pckt, c: make(chan int)} + ioChan <- a + <-a.c + r = <-pckt.c + if r.errno == 995 { // IO Canceled + r.errno = syscall.EWOULDBLOCK + } + } + return r +} + +const ( + read = iota + readfrom + write + writeto + cancel +) + +type arg struct { + f int + fd *netFD + pckt *ioPacket + done *uint32 + flags *uint32 + rsa *syscall.RawSockaddrAny + size *int32 + sa *syscall.Sockaddr + c chan int +} + +var ioChan chan *arg = make(chan *arg) + +func timeoutIO() { + // CancelIO only cancels all pending input and output (I/O) operations that are + // issued by the calling thread for the specified file, does not cancel I/O + // operations that other threads issue for a file handle. So we need do all timeout + // I/O in single OS thread. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + for { + o := <-ioChan + var e int + switch o.f { + case read: + e = syscall.WSARecv(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, o.flags, &o.pckt.o, nil) + case readfrom: + e = syscall.WSARecvFrom(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, o.flags, o.rsa, o.size, &o.pckt.o, nil) + case write: + e = syscall.WSASend(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, uint32(0), &o.pckt.o, nil) + case writeto: + e = syscall.WSASendto(uint32(o.fd.sysfd), o.pckt.w, 1, o.done, 0, *o.sa, &o.pckt.o, nil) + case cancel: + _, e = syscall.CancelIo(uint32(o.fd.sysfd)) + } + o.c <- e + } +} + +func (fd *netFD) Read(p []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == -1 { + return 0, os.EINVAL + } + // Submit receive request. + var pckt ioPacket + pckt.c = fd.cr + pckt.w = newWSABuf(p) + var done uint32 + flags := uint32(0) + var e int + if fd.rdeadline_delta > 0 { + a := &arg{f: read, fd: fd, pckt: &pckt, done: &done, flags: &flags, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSARecv(uint32(fd.sysfd), pckt.w, 1, &done, &flags, &pckt.o, nil) + } + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + return 0, &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + r := waitPacket(fd, &pckt, 'r') + if r.errno != 0 { + err = &OpError{"WSARecv", fd.net, fd.laddr, os.Errno(r.errno)} + } + n = int(r.qty) + if err == nil && n == 0 { + err = os.EOF + } + return +} + +func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err os.Error) { + if fd == nil { + return 0, nil, os.EINVAL + } + if len(p) == 0 { + return 0, nil, nil + } + fd.rio.Lock() + defer fd.rio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == -1 { + return 0, nil, os.EINVAL + } + // Submit receive request. + var pckt ioPacket + pckt.c = fd.cr + pckt.w = newWSABuf(p) + var done uint32 + flags := uint32(0) + var rsa syscall.RawSockaddrAny + l := int32(unsafe.Sizeof(rsa)) + var e int + if fd.rdeadline_delta > 0 { + a := &arg{f: readfrom, fd: fd, pckt: &pckt, done: &done, flags: &flags, rsa: &rsa, size: &l, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSARecvFrom(uint32(fd.sysfd), pckt.w, 1, &done, &flags, &rsa, &l, &pckt.o, nil) + } + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + return 0, nil, &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + r := waitPacket(fd, &pckt, 'r') + if r.errno != 0 { + err = &OpError{"WSARecvFrom", fd.net, fd.laddr, os.Errno(r.errno)} + } + n = int(r.qty) + sa, _ = rsa.Sockaddr() + return +} + +func (fd *netFD) Write(p []byte) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == -1 { + return 0, os.EINVAL + } + // Submit send request. + var pckt ioPacket + pckt.c = fd.cw + pckt.w = newWSABuf(p) + var done uint32 + var e int + if fd.wdeadline_delta > 0 { + a := &arg{f: write, fd: fd, pckt: &pckt, done: &done, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSASend(uint32(fd.sysfd), pckt.w, 1, &done, uint32(0), &pckt.o, nil) + } + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + return 0, &OpError{"WSASend", fd.net, fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + r := waitPacket(fd, &pckt, 'w') + if r.errno != 0 { + err = &OpError{"WSASend", fd.net, fd.laddr, os.Errno(r.errno)} + } + n = int(r.qty) + return +} + +func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err os.Error) { + if fd == nil { + return 0, os.EINVAL + } + if len(p) == 0 { + return 0, nil + } + fd.wio.Lock() + defer fd.wio.Unlock() + fd.incref() + defer fd.decref() + if fd.sysfd == -1 { + return 0, os.EINVAL + } + // Submit send request. + var pckt ioPacket + pckt.c = fd.cw + pckt.w = newWSABuf(p) + var done uint32 + var e int + if fd.wdeadline_delta > 0 { + a := &arg{f: writeto, fd: fd, pckt: &pckt, done: &done, sa: &sa, c: make(chan int)} + ioChan <- a + e = <-a.c + } else { + e = syscall.WSASendto(uint32(fd.sysfd), pckt.w, 1, &done, 0, sa, &pckt.o, nil) + } + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + return 0, &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(e)} + } + // Wait for our request to complete. + r := waitPacket(fd, &pckt, 'w') + if r.errno != 0 { + err = &OpError{"WSASendTo", fd.net, fd.laddr, os.Errno(r.errno)} + } + n = int(r.qty) + return +} + +func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (nfd *netFD, err os.Error) { + if fd == nil || fd.sysfd == -1 { + return nil, os.EINVAL + } + fd.incref() + defer fd.decref() + + // Get new socket. + // See ../syscall/exec.go for description of ForkLock. + syscall.ForkLock.RLock() + s, e := syscall.Socket(fd.family, fd.proto, 0) + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, os.Errno(e) + } + syscall.CloseOnExec(s) + syscall.ForkLock.RUnlock() + + // Associate our new socket with IOCP. + onceStartServer.Do(startServer) + if _, e = syscall.CreateIoCompletionPort(int32(s), pollserver.iocp, 0, 0); e != 0 { + return nil, &OpError{"CreateIoCompletionPort", fd.net, fd.laddr, os.Errno(e)} + } + + // Submit accept request. + // Will use new unique channel here, because, unlike Read or Write, + // Accept is expected to be executed by many goroutines simultaniously. + var pckt ioPacket + pckt.c = make(chan *ioResult) + attrs, e := syscall.AcceptIOCP(fd.sysfd, s, &pckt.o) + switch e { + case 0: + // IO completed immediately, but we need to get our completion message anyway. + case syscall.ERROR_IO_PENDING: + // IO started, and we have to wait for it's completion. + default: + closesocket(s) + return nil, &OpError{"AcceptEx", fd.net, fd.laddr, os.Errno(e)} + } + + // Wait for peer connection. + r := <-pckt.c + if r.errno != 0 { + closesocket(s) + return nil, &OpError{"AcceptEx", fd.net, fd.laddr, os.Errno(r.errno)} + } + + // Inherit properties of the listening socket. + e = syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, fd.sysfd) + if e != 0 { + closesocket(s) + return nil, &OpError{"Setsockopt", fd.net, fd.laddr, os.Errno(r.errno)} + } + + // Get local and peer addr out of AcceptEx buffer. + lsa, rsa := syscall.GetAcceptIOCPSockaddrs(attrs) + + // Create our netFD and return it for further use. + laddr := toAddr(lsa) + raddr := toAddr(rsa) + + f := &netFD{ + sysfd: s, + family: fd.family, + proto: fd.proto, + cr: make(chan *ioResult, 1), + cw: make(chan *ioResult, 1), + net: fd.net, + laddr: laddr, + raddr: raddr, + } + runtime.SetFinalizer(f, (*netFD).Close) + return f, nil +} + +func closesocket(s int) (errno int) { + return syscall.Closesocket(int32(s)) +} + +func init() { + var d syscall.WSAData + e := syscall.WSAStartup(uint32(0x101), &d) + if e != 0 { + initErr = os.NewSyscallError("WSAStartup", e) + } +} + +func (fd *netFD) dup() (f *os.File, err os.Error) { + // TODO: Implement this + return nil, os.NewSyscallError("dup", syscall.EWINDOWS) +} + +func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err os.Error) { + return 0, 0, 0, nil, os.EAFNOSUPPORT +} + +func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err os.Error) { + return 0, 0, os.EAFNOSUPPORT +} diff --git a/libgo/go/net/hosts.go b/libgo/go/net/hosts.go new file mode 100644 index 000000000..8525f578d --- /dev/null +++ b/libgo/go/net/hosts.go @@ -0,0 +1,86 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read static host/IP entries from /etc/hosts. + +package net + +import ( + "os" + "sync" +) + +const cacheMaxAge = int64(300) // 5 minutes. + +// hostsPath points to the file with static IP/address entries. +var hostsPath = "/etc/hosts" + +// Simple cache. +var hosts struct { + sync.Mutex + byName map[string][]string + byAddr map[string][]string + time int64 + path string +} + +func readHosts() { + now, _, _ := os.Time() + hp := hostsPath + if len(hosts.byName) == 0 || hosts.time+cacheMaxAge <= now || hosts.path != hp { + hs := make(map[string][]string) + is := make(map[string][]string) + var file *file + if file, _ = open(hp); file == nil { + return + } + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + if i := byteIndex(line, '#'); i >= 0 { + // Discard comments. + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 || ParseIP(f[0]) == nil { + continue + } + for i := 1; i < len(f); i++ { + h := f[i] + hs[h] = append(hs[h], f[0]) + is[f[0]] = append(is[f[0]], h) + } + } + // Update the data cache. + hosts.time, _, _ = os.Time() + hosts.path = hp + hosts.byName = hs + hosts.byAddr = is + file.close() + } +} + +// lookupStaticHosts looks up the addresses for the given host from /etc/hosts. +func lookupStaticHost(host string) []string { + hosts.Lock() + defer hosts.Unlock() + readHosts() + if len(hosts.byName) != 0 { + if ips, ok := hosts.byName[host]; ok { + return ips + } + } + return nil +} + +// rlookupStaticHosts looks up the hosts for the given address from /etc/hosts. +func lookupStaticAddr(addr string) []string { + hosts.Lock() + defer hosts.Unlock() + readHosts() + if len(hosts.byAddr) != 0 { + if hosts, ok := hosts.byAddr[addr]; ok { + return hosts + } + } + return nil +} diff --git a/libgo/go/net/hosts_test.go b/libgo/go/net/hosts_test.go new file mode 100644 index 000000000..84cd92e37 --- /dev/null +++ b/libgo/go/net/hosts_test.go @@ -0,0 +1,54 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" +) + +type hostTest struct { + host string + ips []IP +} + + +var hosttests = []hostTest{ + {"odin", []IP{ + IPv4(127, 0, 0, 2), + IPv4(127, 0, 0, 3), + ParseIP("::2"), + }}, + {"thor", []IP{ + IPv4(127, 1, 1, 1), + }}, + {"loki", []IP{}}, + {"ullr", []IP{ + IPv4(127, 1, 1, 2), + }}, + {"ullrhost", []IP{ + IPv4(127, 1, 1, 2), + }}, +} + +func TestLookupStaticHost(t *testing.T) { + p := hostsPath + hostsPath = "hosts_testdata" + for i := 0; i < len(hosttests); i++ { + tt := hosttests[i] + ips := lookupStaticHost(tt.host) + if len(ips) != len(tt.ips) { + t.Errorf("# of hosts = %v; want %v", + len(ips), len(tt.ips)) + return + } + for k, v := range ips { + if tt.ips[k].String() != v { + t.Errorf("lookupStaticHost(%q) = %v; want %v", + tt.host, v, tt.ips[k]) + } + } + } + hostsPath = p +} diff --git a/libgo/go/net/hosts_testdata b/libgo/go/net/hosts_testdata new file mode 100644 index 000000000..b60176389 --- /dev/null +++ b/libgo/go/net/hosts_testdata @@ -0,0 +1,12 @@ +255.255.255.255 broadcasthost +127.0.0.2 odin +127.0.0.3 odin # inline comment +::2 odin +127.1.1.1 thor +# aliases +127.1.1.2 ullr ullrhost +# Bogus entries that must be ignored. +123.123.123 loki +321.321.321.321 +# TODO(yvesj): Should we be able to parse this? From a Darwin system. +fe80::1%lo0 localhost diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go new file mode 100644 index 000000000..e82224a28 --- /dev/null +++ b/libgo/go/net/ip.go @@ -0,0 +1,446 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP address manipulations +// +// IPv4 addresses are 4 bytes; IPv6 addresses are 16 bytes. +// An IPv4 address can be converted to an IPv6 address by +// adding a canonical prefix (10 zeros, 2 0xFFs). +// This library accepts either size of byte array but always +// returns 16-byte addresses. + +package net + +// IP address lengths (bytes). +const ( + IPv4len = 4 + IPv6len = 16 +) + +// An IP is a single IP address, an array of bytes. +// Functions in this package accept either 4-byte (IP v4) +// or 16-byte (IP v6) arrays as input. Unless otherwise +// specified, functions in this package always return +// IP addresses in 16-byte form using the canonical +// embedding. +// +// Note that in this documentation, referring to an +// IP address as an IPv4 address or an IPv6 address +// is a semantic property of the address, not just the +// length of the byte array: a 16-byte array can still +// be an IPv4 address. +type IP []byte + +// An IP mask is an IP address. +type IPMask []byte + +// IPv4 returns the IP address (in 16-byte form) of the +// IPv4 address a.b.c.d. +func IPv4(a, b, c, d byte) IP { + p := make(IP, IPv6len) + for i := 0; i < 10; i++ { + p[i] = 0 + } + p[10] = 0xff + p[11] = 0xff + p[12] = a + p[13] = b + p[14] = c + p[15] = d + return p +} + +// IPv4Mask returns the IP mask (in 16-byte form) of the +// IPv4 mask a.b.c.d. +func IPv4Mask(a, b, c, d byte) IPMask { + p := make(IPMask, IPv6len) + for i := 0; i < 12; i++ { + p[i] = 0xff + } + p[12] = a + p[13] = b + p[14] = c + p[15] = d + return p +} + +// Well-known IPv4 addresses +var ( + IPv4bcast = IPv4(255, 255, 255, 255) // broadcast + IPv4allsys = IPv4(224, 0, 0, 1) // all systems + IPv4allrouter = IPv4(224, 0, 0, 2) // all routers + IPv4zero = IPv4(0, 0, 0, 0) // all zeros +) + +// Well-known IPv6 addresses +var ( + IPzero = make(IP, IPv6len) // all zeros +) + +// Is p all zeros? +func isZeros(p IP) bool { + for i := 0; i < len(p); i++ { + if p[i] != 0 { + return false + } + } + return true +} + +// To4 converts the IPv4 address ip to a 4-byte representation. +// If ip is not an IPv4 address, To4 returns nil. +func (ip IP) To4() IP { + if len(ip) == IPv4len { + return ip + } + if len(ip) == IPv6len && + isZeros(ip[0:10]) && + ip[10] == 0xff && + ip[11] == 0xff { + return ip[12:16] + } + return nil +} + +// To16 converts the IP address ip to a 16-byte representation. +// If ip is not an IP address (it is the wrong length), To16 returns nil. +func (ip IP) To16() IP { + if len(ip) == IPv4len { + return IPv4(ip[0], ip[1], ip[2], ip[3]) + } + if len(ip) == IPv6len { + return ip + } + return nil +} + +// Default route masks for IPv4. +var ( + classAMask = IPv4Mask(0xff, 0, 0, 0) + classBMask = IPv4Mask(0xff, 0xff, 0, 0) + classCMask = IPv4Mask(0xff, 0xff, 0xff, 0) +) + +// DefaultMask returns the default IP mask for the IP address ip. +// Only IPv4 addresses have default masks; DefaultMask returns +// nil if ip is not a valid IPv4 address. +func (ip IP) DefaultMask() IPMask { + if ip = ip.To4(); ip == nil { + return nil + } + switch true { + case ip[0] < 0x80: + return classAMask + case ip[0] < 0xC0: + return classBMask + default: + return classCMask + } + return nil // not reached +} + +// Mask returns the result of masking the IP address ip with mask. +func (ip IP) Mask(mask IPMask) IP { + n := len(ip) + if n != len(mask) { + return nil + } + out := make(IP, n) + for i := 0; i < n; i++ { + out[i] = ip[i] & mask[i] + } + return out +} + +// Convert i to decimal string. +func itod(i uint) string { + if i == 0 { + return "0" + } + + // Assemble decimal in reverse order. + var b [32]byte + bp := len(b) + for ; i > 0; i /= 10 { + bp-- + b[bp] = byte(i%10) + '0' + } + + return string(b[bp:]) +} + +// Convert i to hexadecimal string. +func itox(i uint) string { + if i == 0 { + return "0" + } + + // Assemble hexadecimal in reverse order. + var b [32]byte + bp := len(b) + for ; i > 0; i /= 16 { + bp-- + b[bp] = "0123456789abcdef"[byte(i%16)] + } + + return string(b[bp:]) +} + +// String returns the string form of the IP address ip. +// If the address is an IPv4 address, the string representation +// is dotted decimal ("74.125.19.99"). Otherwise the representation +// is IPv6 ("2001:4860:0:2001::68"). +func (ip IP) String() string { + p := ip + + if len(ip) == 0 { + return "" + } + + // If IPv4, use dotted notation. + if p4 := p.To4(); len(p4) == 4 { + return itod(uint(p4[0])) + "." + + itod(uint(p4[1])) + "." + + itod(uint(p4[2])) + "." + + itod(uint(p4[3])) + } + if len(p) != IPv6len { + return "?" + } + + // Find longest run of zeros. + e0 := -1 + e1 := -1 + for i := 0; i < 16; i += 2 { + j := i + for j < 16 && p[j] == 0 && p[j+1] == 0 { + j += 2 + } + if j > i && j-i > e1-e0 { + e0 = i + e1 = j + } + } + // The symbol "::" MUST NOT be used to shorten just one 16 bit 0 field. + if e1-e0 <= 2 { + e0 = -1 + e1 = -1 + } + + // Print with possible :: in place of run of zeros + var s string + for i := 0; i < 16; i += 2 { + if i == e0 { + s += "::" + i = e1 + if i >= 16 { + break + } + } else if i > 0 { + s += ":" + } + s += itox((uint(p[i]) << 8) | uint(p[i+1])) + } + return s +} + +// If mask is a sequence of 1 bits followed by 0 bits, +// return the number of 1 bits. +func simpleMaskLength(mask IPMask) int { + var n int + for i, v := range mask { + if v == 0xff { + n += 8 + continue + } + // found non-ff byte + // count 1 bits + for v&0x80 != 0 { + n++ + v <<= 1 + } + // rest must be 0 bits + if v != 0 { + return -1 + } + for i++; i < len(mask); i++ { + if mask[i] != 0 { + return -1 + } + } + break + } + return n +} + +// String returns the string representation of mask. +// If the mask is in the canonical form--ones followed by zeros--the +// string representation is just the decimal number of ones. +// If the mask is in a non-canonical form, it is formatted +// as an IP address. +func (mask IPMask) String() string { + switch len(mask) { + case 4: + n := simpleMaskLength(mask) + if n >= 0 { + return itod(uint(n + (IPv6len-IPv4len)*8)) + } + case 16: + n := simpleMaskLength(mask) + if n >= 12*8 { + return itod(uint(n - 12*8)) + } + } + return IP(mask).String() +} + +// Parse IPv4 address (d.d.d.d). +func parseIPv4(s string) IP { + var p [IPv4len]byte + i := 0 + for j := 0; j < IPv4len; j++ { + if i >= len(s) { + // Missing octets. + return nil + } + if j > 0 { + if s[i] != '.' { + return nil + } + i++ + } + var ( + n int + ok bool + ) + n, i, ok = dtoi(s, i) + if !ok || n > 0xFF { + return nil + } + p[j] = byte(n) + } + if i != len(s) { + return nil + } + return IPv4(p[0], p[1], p[2], p[3]) +} + +// Parse IPv6 address. Many forms. +// The basic form is a sequence of eight colon-separated +// 16-bit hex numbers separated by colons, +// as in 0123:4567:89ab:cdef:0123:4567:89ab:cdef. +// Two exceptions: +// * A run of zeros can be replaced with "::". +// * The last 32 bits can be in IPv4 form. +// Thus, ::ffff:1.2.3.4 is the IPv4 address 1.2.3.4. +func parseIPv6(s string) IP { + p := make(IP, 16) + ellipsis := -1 // position of ellipsis in p + i := 0 // index in string s + + // Might have leading ellipsis + if len(s) >= 2 && s[0] == ':' && s[1] == ':' { + ellipsis = 0 + i = 2 + // Might be only ellipsis + if i == len(s) { + return p + } + } + + // Loop, parsing hex numbers followed by colon. + j := 0 +L: + for j < IPv6len { + // Hex number. + n, i1, ok := xtoi(s, i) + if !ok || n > 0xFFFF { + return nil + } + + // If followed by dot, might be in trailing IPv4. + if i1 < len(s) && s[i1] == '.' { + if ellipsis < 0 && j != IPv6len-IPv4len { + // Not the right place. + return nil + } + if j+IPv4len > IPv6len { + // Not enough room. + return nil + } + p4 := parseIPv4(s[i:]) + if p4 == nil { + return nil + } + p[j] = p4[12] + p[j+1] = p4[13] + p[j+2] = p4[14] + p[j+3] = p4[15] + i = len(s) + j += 4 + break + } + + // Save this 16-bit chunk. + p[j] = byte(n >> 8) + p[j+1] = byte(n) + j += 2 + + // Stop at end of string. + i = i1 + if i == len(s) { + break + } + + // Otherwise must be followed by colon and more. + if s[i] != ':' && i+1 == len(s) { + return nil + } + i++ + + // Look for ellipsis. + if s[i] == ':' { + if ellipsis >= 0 { // already have one + return nil + } + ellipsis = j + if i++; i == len(s) { // can be at end + break + } + } + } + + // Must have used entire string. + if i != len(s) { + return nil + } + + // If didn't parse enough, expand ellipsis. + if j < IPv6len { + if ellipsis < 0 { + return nil + } + n := IPv6len - j + for k := j - 1; k >= ellipsis; k-- { + p[k+n] = p[k] + } + for k := ellipsis + n - 1; k >= ellipsis; k-- { + p[k] = 0 + } + } + return p +} + +// ParseIP parses s as an IP address, returning the result. +// The string s can be in dotted decimal ("74.125.19.99") +// or IPv6 ("2001:4860:0:2001::68") form. +// If s is not a valid textual representation of an IP address, +// ParseIP returns nil. +func ParseIP(s string) IP { + p := parseIPv4(s) + if p != nil { + return p + } + return parseIPv6(s) +} diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go new file mode 100644 index 000000000..e29c3021d --- /dev/null +++ b/libgo/go/net/ip_test.go @@ -0,0 +1,94 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" +) + +func isEqual(a, b IP) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil || len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +type parseIPTest struct { + in string + out IP +} + +var parseiptests = []parseIPTest{ + {"127.0.1.2", IPv4(127, 0, 1, 2)}, + {"127.0.0.1", IPv4(127, 0, 0, 1)}, + {"127.0.0.256", nil}, + {"abc", nil}, + {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, + {"2001:4860:0:2001::68", + IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, + 0, 0, 0, 0, 0, 0, 0x00, 0x68, + }, + }, + {"::ffff:4a7d:1363", IPv4(74, 125, 19, 99)}, +} + +func TestParseIP(t *testing.T) { + for i := 0; i < len(parseiptests); i++ { + tt := parseiptests[i] + if out := ParseIP(tt.in); !isEqual(out, tt.out) { + t.Errorf("ParseIP(%#q) = %v, want %v", tt.in, out, tt.out) + } + } +} + +type ipStringTest struct { + in IP + out string +} + +var ipstringtests = []ipStringTest{ + // cf. RFC 5952 (A Recommendation for IPv6 Address Text Representation) + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, + "2001:db8::123:12:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0x1}, + "2001:db8::1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0x1, + 0, 0, 0, 0x1, 0, 0, 0, 0x1}, + "2001:db8:0:1:0:1:0:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0x1, 0, 0, + 0, 0x1, 0, 0, 0, 0x1, 0, 0}, + "2001:db8:1:0:1:0:1:0"}, + {IP{0x20, 0x1, 0, 0, 0, 0, 0, 0, + 0, 0x1, 0, 0, 0, 0, 0, 0x1}, + "2001::1:0:0:1"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0x1, 0, 0, 0, 0, 0, 0}, + "2001:db8:0:0:1::"}, + {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, + 0, 0x1, 0, 0, 0, 0, 0, 0x1}, + "2001:db8::1:0:0:1"}, + {IP{0x20, 0x1, 0xD, 0xB8, 0, 0, 0, 0, + 0, 0xA, 0, 0xB, 0, 0xC, 0, 0xD}, + "2001:db8::a:b:c:d"}, +} + +func TestIPString(t *testing.T) { + for i := 0; i < len(ipstringtests); i++ { + tt := ipstringtests[i] + if out := tt.in.String(); out != tt.out { + t.Errorf("IP.String(%v) = %#q, want %#q", tt.in, out, tt.out) + } + } +} diff --git a/libgo/go/net/ipraw_test.go b/libgo/go/net/ipraw_test.go new file mode 100644 index 000000000..562298bdf --- /dev/null +++ b/libgo/go/net/ipraw_test.go @@ -0,0 +1,117 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// TODO(cw): ListenPacket test, Read() test, ipv6 test & +// Dial()/Listen() level tests + +package net + +import ( + "bytes" + "flag" + "os" + "testing" +) + +const ICMP_ECHO_REQUEST = 8 +const ICMP_ECHO_REPLY = 0 + +// returns a suitable 'ping request' packet, with id & seq and a +// payload length of pktlen +func makePingRequest(id, seq, pktlen int, filler []byte) []byte { + p := make([]byte, pktlen) + copy(p[8:], bytes.Repeat(filler, (pktlen-8)/len(filler)+1)) + + p[0] = ICMP_ECHO_REQUEST // type + p[1] = 0 // code + p[2] = 0 // cksum + p[3] = 0 // cksum + p[4] = uint8(id >> 8) // id + p[5] = uint8(id & 0xff) // id + p[6] = uint8(seq >> 8) // sequence + p[7] = uint8(seq & 0xff) // sequence + + // calculate icmp checksum + cklen := len(p) + s := uint32(0) + for i := 0; i < (cklen - 1); i += 2 { + s += uint32(p[i+1])<<8 | uint32(p[i]) + } + if cklen&1 == 1 { + s += uint32(p[cklen-1]) + } + s = (s >> 16) + (s & 0xffff) + s = s + (s >> 16) + + // place checksum back in header; using ^= avoids the + // assumption the checksum bytes are zero + p[2] ^= uint8(^s & 0xff) + p[3] ^= uint8(^s >> 8) + + return p +} + +func parsePingReply(p []byte) (id, seq int) { + id = int(p[4])<<8 | int(p[5]) + seq = int(p[6])<<8 | int(p[7]) + return +} + +var srchost = flag.String("srchost", "", "Source of the ICMP ECHO request") +var dsthost = flag.String("dsthost", "localhost", "Destination for the ICMP ECHO request") + +// test (raw) IP socket using ICMP +func TestICMP(t *testing.T) { + if os.Getuid() != 0 { + t.Logf("test disabled; must be root") + return + } + + var laddr *IPAddr + if *srchost != "" { + laddr, err := ResolveIPAddr(*srchost) + if err != nil { + t.Fatalf(`net.ResolveIPAddr("%v") = %v, %v`, *srchost, laddr, err) + } + } + + raddr, err := ResolveIPAddr(*dsthost) + if err != nil { + t.Fatalf(`net.ResolveIPAddr("%v") = %v, %v`, *dsthost, raddr, err) + } + + c, err := ListenIP("ip4:icmp", laddr) + if err != nil { + t.Fatalf(`net.ListenIP("ip4:icmp", %v) = %v, %v`, *srchost, c, err) + } + + sendid := os.Getpid() & 0xffff + const sendseq = 61455 + const pingpktlen = 128 + sendpkt := makePingRequest(sendid, sendseq, pingpktlen, []byte("Go Go Gadget Ping!!!")) + + n, err := c.WriteToIP(sendpkt, raddr) + if err != nil || n != pingpktlen { + t.Fatalf(`net.WriteToIP(..., %v) = %v, %v`, raddr, n, err) + } + + c.SetTimeout(100e6) + resp := make([]byte, 1024) + for { + n, from, err := c.ReadFrom(resp) + if err != nil { + t.Fatalf(`ReadFrom(...) = %v, %v, %v`, n, from, err) + } + if resp[0] != ICMP_ECHO_REPLY { + continue + } + rcvid, rcvseq := parsePingReply(resp) + if rcvid != sendid || rcvseq != sendseq { + t.Fatalf(`Ping reply saw id,seq=0x%x,0x%x (expected 0x%x, 0x%x)`, rcvid, rcvseq, sendid, sendseq) + } + return + } + t.Fatalf("saw no ping return") +} diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go new file mode 100644 index 000000000..241be1509 --- /dev/null +++ b/libgo/go/net/iprawsock.go @@ -0,0 +1,358 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// (Raw) IP sockets + +package net + +import ( + "os" + "sync" + "syscall" +) + +var onceReadProtocols sync.Once + +func sockaddrToIP(sa syscall.Sockaddr) Addr { + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &IPAddr{sa.Addr[0:]} + case *syscall.SockaddrInet6: + return &IPAddr{sa.Addr[0:]} + } + return nil +} + +// IPAddr represents the address of a IP end point. +type IPAddr struct { + IP IP +} + +// Network returns the address's network name, "ip". +func (a *IPAddr) Network() string { return "ip" } + +func (a *IPAddr) String() string { + if a == nil { + return "<nil>" + } + return a.IP.String() +} + +func (a *IPAddr) family() int { + if a == nil || len(a.IP) <= 4 { + return syscall.AF_INET + } + if ip := a.IP.To4(); ip != nil { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { + return ipToSockaddr(family, a.IP, 0) +} + +func (a *IPAddr) toAddr() sockaddr { + if a == nil { // nil *IPAddr + return nil // nil interface + } + return a +} + +// ResolveIPAddr parses addr as a IP address and resolves domain +// names to numeric addresses. A literal IPv6 host address must be +// enclosed in square brackets, as in "[::]". +func ResolveIPAddr(addr string) (*IPAddr, os.Error) { + ip, err := hostToIP(addr) + if err != nil { + return nil, err + } + return &IPAddr{ip}, nil +} + +// IPConn is the implementation of the Conn and PacketConn +// interfaces for IP network connections. +type IPConn struct { + fd *netFD +} + +func newIPConn(fd *netFD) *IPConn { return &IPConn{fd} } + +func (c *IPConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *IPConn) Read(b []byte) (n int, err os.Error) { + n, _, err = c.ReadFrom(b) + return +} + +// Write implements the net.Conn Write method. +func (c *IPConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the IP connection. +func (c *IPConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address. +func (c *IPConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *IPAddr. +func (c *IPConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *IPConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *IPConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *IPConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *IPConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *IPConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// IP-specific methods. + +// ReadFromIP reads a IP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromIP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetTimeout and +// SetReadTimeout. +func (c *IPConn) ReadFromIP(b []byte) (n int, addr *IPAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + // TODO(cw,rsc): consider using readv if we know the family + // type to avoid the header trim/copy + n, sa, err := c.fd.ReadFrom(b) + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + addr = &IPAddr{sa.Addr[0:]} + if len(b) >= 4 { // discard ipv4 header + hsize := (int(b[0]) & 0xf) * 4 + copy(b, b[hsize:]) + n -= hsize + } + case *syscall.SockaddrInet6: + addr = &IPAddr{sa.Addr[0:]} + } + return +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *IPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, uaddr, err := c.ReadFromIP(b) + return n, uaddr.toAddr(), err +} + +// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// +// WriteToIP can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + sa, err1 := addr.sockaddr(c.fd.family) + if err1 != nil { + return 0, &OpError{Op: "write", Net: "ip", Addr: addr, Error: err1} + } + return c.fd.WriteTo(b, sa) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*IPAddr) + if !ok { + return 0, &OpError{"writeto", "ip", addr, os.EINVAL} + } + return c.WriteToIP(b, a) +} + +// Convert "host" into IP address. +func hostToIP(host string) (ip IP, err os.Error) { + var addr IP + // Try as an IP address. + addr = ParseIP(host) + if addr == nil { + // Not an IP address. Try as a DNS name. + _, addrs, err1 := LookupHost(host) + if err1 != nil { + err = err1 + goto Error + } + addr = ParseIP(addrs[0]) + if addr == nil { + // should not happen + err = &AddrError{"LookupHost returned invalid address", addrs[0]} + goto Error + } + } + + return addr, nil + +Error: + return nil, err +} + + +var protocols map[string]int + +func readProtocols() { + protocols = make(map[string]int) + if file, err := open("/etc/protocols"); err == nil { + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // tcp 6 TCP # transmission control protocol + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + if proto, _, ok := dtoi(f[1], 0); ok { + protocols[f[0]] = proto + for _, alias := range f[2:] { + protocols[alias] = proto + } + } + } + file.close() + } +} + +func netProtoSplit(netProto string) (net string, proto int, err os.Error) { + onceReadProtocols.Do(readProtocols) + i := last(netProto, ':') + if i < 0 { // no colon + return "", 0, os.ErrorString("no IP protocol specified") + } + net = netProto[0:i] + protostr := netProto[i+1:] + proto, i, ok := dtoi(protostr, 0) + if !ok || i != len(protostr) { + // lookup by name + proto, ok = protocols[protostr] + if ok { + return + } + } + return +} + +// DialIP connects to the remote address raddr on the network net, +// which must be "ip", "ip4", or "ip6". +func DialIP(netProto string, laddr, raddr *IPAddr) (c *IPConn, err os.Error) { + net, proto, err := netProtoSplit(netProto) + if err != nil { + return + } + switch prefixBefore(net, ':') { + case "ip", "ip4", "ip6": + default: + return nil, UnknownNetworkError(net) + } + if raddr == nil { + return nil, &OpError{"dial", "ip", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + if e != nil { + return nil, e + } + return newIPConn(fd), nil +} + +// ListenIP listens for incoming IP packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send IP +// packets with per-packet addressing. +func ListenIP(netProto string, laddr *IPAddr) (c *IPConn, err os.Error) { + net, proto, err := netProtoSplit(netProto) + if err != nil { + return + } + switch prefixBefore(net, ':') { + case "ip", "ip4", "ip6": + default: + return nil, UnknownNetworkError(net) + } + fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_RAW, proto, "dial", sockaddrToIP) + if e != nil { + return nil, e + } + return newIPConn(fd), nil +} + +// BindToDevice binds an IPConn to a network interface. +func (c *IPConn) BindToDevice(device string) os.Error { + if !c.ok() { + return os.EINVAL + } + c.fd.incref() + defer c.fd.decref() + return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) +} diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go new file mode 100644 index 000000000..4ba6a55b9 --- /dev/null +++ b/libgo/go/net/ipsock.go @@ -0,0 +1,236 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IP sockets + +package net + +import ( + "os" + "syscall" +) + +// Should we try to use the IPv4 socket interface if we're +// only dealing with IPv4 sockets? As long as the host system +// understands IPv6, it's okay to pass IPv4 addresses to the IPv6 +// interface. That simplifies our code and is most general. +// Unfortunately, we need to run on kernels built without IPv6 support too. +// So probe the kernel to figure it out. +func kernelSupportsIPv6() bool { + // FreeBSD does not support this sort of interface. + if syscall.OS == "freebsd" { + return false + } + fd, e := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + if fd >= 0 { + closesocket(fd) + } + return e == 0 +} + +var preferIPv4 = !kernelSupportsIPv6() + +// TODO(rsc): if syscall.OS == "linux", we're supposd to read +// /proc/sys/net/core/somaxconn, +// to take advantage of kernels that have raised the limit. +func listenBacklog() int { return syscall.SOMAXCONN } + +// Internet sockets (TCP, UDP) + +// A sockaddr represents a TCP or UDP network address that can +// be converted into a syscall.Sockaddr. +type sockaddr interface { + Addr + sockaddr(family int) (syscall.Sockaddr, os.Error) + family() int +} + +func internetSocket(net string, laddr, raddr sockaddr, socktype, proto int, mode string, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) { + // Figure out IP version. + // If network has a suffix like "tcp4", obey it. + var oserr os.Error + family := syscall.AF_INET6 + switch net[len(net)-1] { + case '4': + family = syscall.AF_INET + case '6': + // nothing to do + default: + // Otherwise, guess. + // If the addresses are IPv4 and we prefer IPv4, use 4; else 6. + if preferIPv4 && + (laddr == nil || laddr.family() == syscall.AF_INET) && + (raddr == nil || raddr.family() == syscall.AF_INET) { + family = syscall.AF_INET + } + } + + var la, ra syscall.Sockaddr + if laddr != nil { + if la, oserr = laddr.sockaddr(family); oserr != nil { + goto Error + } + } + if raddr != nil { + if ra, oserr = raddr.sockaddr(family); oserr != nil { + goto Error + } + } + fd, oserr = socket(net, family, socktype, proto, la, ra, toAddr) + if oserr != nil { + goto Error + } + return fd, nil + +Error: + addr := raddr + if mode == "listen" { + addr = laddr + } + return nil, &OpError{mode, net, addr, oserr} +} + +func getip(fd int, remote bool) (ip []byte, port int, ok bool) { + // No attempt at error reporting because + // there are no possible errors, and the + // caller won't report them anyway. + var sa syscall.Sockaddr + if remote { + sa, _ = syscall.Getpeername(fd) + } else { + sa, _ = syscall.Getsockname(fd) + } + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return sa.Addr[0:], sa.Port, true + case *syscall.SockaddrInet6: + return sa.Addr[0:], sa.Port, true + } + return +} + +type InvalidAddrError string + +func (e InvalidAddrError) String() string { return string(e) } +func (e InvalidAddrError) Timeout() bool { return false } +func (e InvalidAddrError) Temporary() bool { return false } + + +func ipToSockaddr(family int, ip IP, port int) (syscall.Sockaddr, os.Error) { + switch family { + case syscall.AF_INET: + if len(ip) == 0 { + ip = IPv4zero + } + if ip = ip.To4(); ip == nil { + return nil, InvalidAddrError("non-IPv4 address") + } + s := new(syscall.SockaddrInet4) + for i := 0; i < IPv4len; i++ { + s.Addr[i] = ip[i] + } + s.Port = port + return s, nil + case syscall.AF_INET6: + if len(ip) == 0 { + ip = IPzero + } + // IPv4 callers use 0.0.0.0 to mean "announce on any available address". + // In IPv6 mode, Linux treats that as meaning "announce on 0.0.0.0", + // which it refuses to do. Rewrite to the IPv6 all zeros. + if p4 := ip.To4(); p4 != nil && p4[0] == 0 && p4[1] == 0 && p4[2] == 0 && p4[3] == 0 { + ip = IPzero + } + if ip = ip.To16(); ip == nil { + return nil, InvalidAddrError("non-IPv6 address") + } + s := new(syscall.SockaddrInet6) + for i := 0; i < IPv6len; i++ { + s.Addr[i] = ip[i] + } + s.Port = port + return s, nil + } + return nil, InvalidAddrError("unexpected socket family") +} + +// Split "host:port" into "host" and "port". +// Host cannot contain colons unless it is bracketed. +func splitHostPort(hostport string) (host, port string, err os.Error) { + // The port starts after the last colon. + i := last(hostport, ':') + if i < 0 { + err = &AddrError{"missing port in address", hostport} + return + } + + host, port = hostport[0:i], hostport[i+1:] + + // Can put brackets around host ... + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } else { + // ... but if there are no brackets, no colons. + if byteIndex(host, ':') >= 0 { + err = &AddrError{"too many colons in address", hostport} + return + } + } + return +} + +// Join "host" and "port" into "host:port". +// If host contains colons, will join into "[host]:port". +func joinHostPort(host, port string) string { + // If host has colons, have to bracket it. + if byteIndex(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + +// Convert "host:port" into IP address and port. +func hostPortToIP(net, hostport string) (ip IP, iport int, err os.Error) { + host, port, err := splitHostPort(hostport) + if err != nil { + goto Error + } + + var addr IP + if host != "" { + // Try as an IP address. + addr = ParseIP(host) + if addr == nil { + // Not an IP address. Try as a DNS name. + _, addrs, err1 := LookupHost(host) + if err1 != nil { + err = err1 + goto Error + } + addr = ParseIP(addrs[0]) + if addr == nil { + // should not happen + err = &AddrError{"LookupHost returned invalid address", addrs[0]} + goto Error + } + } + } + + p, i, ok := dtoi(port, 0) + if !ok || i != len(port) { + p, err = LookupPort(net, port) + if err != nil { + goto Error + } + } + if p < 0 || p > 0xFFFF { + err = &AddrError{"invalid port", port} + goto Error + } + + return addr, p, nil + +Error: + return nil, 0, err +} diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go new file mode 100644 index 000000000..c0c1c3b8a --- /dev/null +++ b/libgo/go/net/net.go @@ -0,0 +1,192 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The net package provides a portable interface to Unix +// networks sockets, including TCP/IP, UDP, domain name +// resolution, and Unix domain sockets. +package net + +// TODO(rsc): +// support for raw ethernet sockets + +import "os" + +// Addr represents a network end point address. +type Addr interface { + Network() string // name of the network + String() string // string form of address +} + +// Conn is a generic stream-oriented network connection. +type Conn interface { + // Read reads data from the connection. + // Read can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetTimeout and SetReadTimeout. + Read(b []byte) (n int, err os.Error) + + // Write writes data to the connection. + // Write can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetTimeout and SetWriteTimeout. + Write(b []byte) (n int, err os.Error) + + // Close closes the connection. + // The error returned is an os.Error to satisfy io.Closer; + Close() os.Error + + // LocalAddr returns the local network address. + LocalAddr() Addr + + // RemoteAddr returns the remote network address. + RemoteAddr() Addr + + // SetTimeout sets the read and write deadlines associated + // with the connection. + SetTimeout(nsec int64) os.Error + + // SetReadTimeout sets the time (in nanoseconds) that + // Read will wait for data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + SetReadTimeout(nsec int64) os.Error + + // SetWriteTimeout sets the time (in nanoseconds) that + // Write will wait to send its data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + // Even if write times out, it may return n > 0, indicating that + // some of the data was successfully written. + SetWriteTimeout(nsec int64) os.Error +} + +// An Error represents a network error. +type Error interface { + os.Error + Timeout() bool // Is the error a timeout? + Temporary() bool // Is the error temporary? +} + +// PacketConn is a generic packet-oriented network connection. +type PacketConn interface { + // ReadFrom reads a packet from the connection, + // copying the payload into b. It returns the number of + // bytes copied into b and the return address that + // was on the packet. + // ReadFrom can be made to time out and return + // an error with Timeout() == true after a fixed time limit; + // see SetTimeout and SetReadTimeout. + ReadFrom(b []byte) (n int, addr Addr, err os.Error) + + // WriteTo writes a packet with payload b to addr. + // WriteTo can be made to time out and return + // an error with Timeout() == true after a fixed time limit; + // see SetTimeout and SetWriteTimeout. + // On packet-oriented connections, write timeouts are rare. + WriteTo(b []byte, addr Addr) (n int, err os.Error) + + // Close closes the connection. + // The error returned is an os.Error to satisfy io.Closer; + Close() os.Error + + // LocalAddr returns the local network address. + LocalAddr() Addr + + // SetTimeout sets the read and write deadlines associated + // with the connection. + SetTimeout(nsec int64) os.Error + + // SetReadTimeout sets the time (in nanoseconds) that + // Read will wait for data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + SetReadTimeout(nsec int64) os.Error + + // SetWriteTimeout sets the time (in nanoseconds) that + // Write will wait to send its data before returning an error with Timeout() == true. + // Setting nsec == 0 (the default) disables the deadline. + // Even if write times out, it may return n > 0, indicating that + // some of the data was successfully written. + SetWriteTimeout(nsec int64) os.Error +} + +// A Listener is a generic network listener for stream-oriented protocols. +type Listener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (c Conn, err os.Error) + + // Close closes the listener. + // The error returned is an os.Error to satisfy io.Closer; + Close() os.Error + + // Addr returns the listener's network address. + Addr() Addr +} + +var errMissingAddress = os.ErrorString("missing address") + +type OpError struct { + Op string + Net string + Addr Addr + Error os.Error +} + +func (e *OpError) String() string { + if e == nil { + return "<nil>" + } + s := e.Op + if e.Net != "" { + s += " " + e.Net + } + if e.Addr != nil { + s += " " + e.Addr.String() + } + s += ": " + e.Error.String() + return s +} + +type temporary interface { + Temporary() bool +} + +func (e *OpError) Temporary() bool { + t, ok := e.Error.(temporary) + return ok && t.Temporary() +} + +type timeout interface { + Timeout() bool +} + +func (e *OpError) Timeout() bool { + t, ok := e.Error.(timeout) + return ok && t.Timeout() +} + +type AddrError struct { + Error string + Addr string +} + +func (e *AddrError) String() string { + if e == nil { + return "<nil>" + } + s := e.Error + if e.Addr != "" { + s += " " + e.Addr + } + return s +} + +func (e *AddrError) Temporary() bool { + return false +} + +func (e *AddrError) Timeout() bool { + return false +} + +type UnknownNetworkError string + +func (e UnknownNetworkError) String() string { return "unknown network " + string(e) } +func (e UnknownNetworkError) Temporary() bool { return false } +func (e UnknownNetworkError) Timeout() bool { return false } diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go new file mode 100644 index 000000000..275b31c0e --- /dev/null +++ b/libgo/go/net/net_test.go @@ -0,0 +1,126 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "regexp" + "runtime" + "testing" +) + +var runErrorTest = flag.Bool("run_error_test", false, "let TestDialError check for dns errors") + +type DialErrorTest struct { + Net string + Laddr string + Raddr string + Pattern string +} + +var dialErrorTests = []DialErrorTest{ + { + "datakit", "", "mh/astro/r70", + "dial datakit mh/astro/r70: unknown network datakit", + }, + { + "tcp", "", "127.0.0.1:☺", + "dial tcp 127.0.0.1:☺: unknown port tcp/☺", + }, + { + "tcp", "", "no-such-name.google.com.:80", + "dial tcp no-such-name.google.com.:80: lookup no-such-name.google.com.( on .*)?: no (.*)", + }, + { + "tcp", "", "no-such-name.no-such-top-level-domain.:80", + "dial tcp no-such-name.no-such-top-level-domain.:80: lookup no-such-name.no-such-top-level-domain.( on .*)?: no (.*)", + }, + { + "tcp", "", "no-such-name:80", + `dial tcp no-such-name:80: lookup no-such-name\.(.*\.)?( on .*)?: no (.*)`, + }, + { + "tcp", "", "mh/astro/r70:http", + "dial tcp mh/astro/r70:http: lookup mh/astro/r70: invalid domain name", + }, + { + "unix", "", "/etc/file-not-found", + "dial unix /etc/file-not-found: [nN]o such file or directory", + }, + { + "unix", "", "/etc/", + "dial unix /etc/: ([pP]ermission denied|[sS]ocket operation on non-socket|[cC]onnection refused)", + }, + { + "unixpacket", "", "/etc/file-not-found", + "dial unixpacket /etc/file-not-found: no such file or directory", + }, + { + "unixpacket", "", "/etc/", + "dial unixpacket /etc/: (permission denied|socket operation on non-socket|connection refused)", + }, +} + +func TestDialError(t *testing.T) { + if !*runErrorTest { + t.Logf("test disabled; use --run_error_test to enable") + return + } + for i, tt := range dialErrorTests { + c, e := Dial(tt.Net, tt.Laddr, tt.Raddr) + if c != nil { + c.Close() + } + if e == nil { + t.Errorf("#%d: nil error, want match for %#q", i, tt.Pattern) + continue + } + s := e.String() + match, _ := regexp.MatchString(tt.Pattern, s) + if !match { + t.Errorf("#%d: %q, want match for %#q", i, s, tt.Pattern) + } + } +} + +var revAddrTests = []struct { + Addr string + Reverse string + ErrPrefix string +}{ + {"1.2.3.4", "4.3.2.1.in-addr.arpa.", ""}, + {"245.110.36.114", "114.36.110.245.in-addr.arpa.", ""}, + {"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa.", ""}, + {"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", ""}, + {"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa.", ""}, + {"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa.", ""}, + {"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa.", ""}, + {"1.2.3", "", "unrecognized address"}, + {"1.2.3.4.5", "", "unrecognized address"}, + {"1234:567:bcbca::89a:bcde", "", "unrecognized address"}, + {"1234:567::bcbc:adad::89a:bcde", "", "unrecognized address"}, +} + +func TestReverseAddress(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + for i, tt := range revAddrTests { + a, e := reverseaddr(tt.Addr) + if len(tt.ErrPrefix) > 0 && e == nil { + t.Errorf("#%d: expected %q, got <nil> (error)", i, tt.ErrPrefix) + continue + } + if len(tt.ErrPrefix) == 0 && e != nil { + t.Errorf("#%d: expected <nil>, got %q (error)", i, e) + } + if e != nil && e.(*DNSError).Error != tt.ErrPrefix { + t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, e.(*DNSError).Error) + } + if a != tt.Reverse { + t.Errorf("#%d: expected %q, got %q (reverse address)", i, tt.Reverse, a) + } + } +} diff --git a/libgo/go/net/newpollserver.go b/libgo/go/net/newpollserver.go new file mode 100644 index 000000000..820e70b46 --- /dev/null +++ b/libgo/go/net/newpollserver.go @@ -0,0 +1,41 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func newPollServer() (s *pollServer, err os.Error) { + s = new(pollServer) + s.cr = make(chan *netFD, 1) + s.cw = make(chan *netFD, 1) + if s.pr, s.pw, err = os.Pipe(); err != nil { + return nil, err + } + var e int + if e = syscall.SetNonblock(s.pr.Fd(), true); e != 0 { + Errno: + err = &os.PathError{"setnonblock", s.pr.Name(), os.Errno(e)} + Error: + s.pr.Close() + s.pw.Close() + return nil, err + } + if e = syscall.SetNonblock(s.pw.Fd(), true); e != 0 { + goto Errno + } + if s.poll, err = newpollster(); err != nil { + goto Error + } + if err = s.poll.AddFD(s.pr.Fd(), 'r', true); err != nil { + s.poll.Close() + goto Error + } + s.pending = make(map[int]*netFD) + go s.Run() + return s, nil +} diff --git a/libgo/go/net/newpollserver_rtems.go b/libgo/go/net/newpollserver_rtems.go new file mode 100644 index 000000000..05cb71a54 --- /dev/null +++ b/libgo/go/net/newpollserver_rtems.go @@ -0,0 +1,78 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "syscall" +) + +func selfConnectedTCPSocket() (pr, pw *os.File, err os.Error) { + // See ../syscall/exec.go for description of ForkLock. + syscall.ForkLock.RLock() + sockfd, e := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, nil, os.Errno(e) + } + syscall.CloseOnExec(sockfd) + syscall.ForkLock.RUnlock() + + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(sockfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + var laTCP *TCPAddr + var la syscall.Sockaddr + if laTCP, err = ResolveTCPAddr("127.0.0.1:0"); err != nil { + Error: + return nil, nil, err + } + if la, err = laTCP.sockaddr(syscall.AF_INET); err != nil { + goto Error + } + e = syscall.Bind(sockfd, la) + if e != 0 { + Errno: + syscall.Close(sockfd) + return nil, nil, os.Errno(e) + } + + laddr, _ := syscall.Getsockname(sockfd) + e = syscall.Connect(sockfd, laddr) + if e != 0 { + goto Errno + } + + fd := os.NewFile(sockfd, "wakeupSocket") + return fd, fd, nil +} + +func newPollServer() (s *pollServer, err os.Error) { + s = new(pollServer) + s.cr = make(chan *netFD, 1) + s.cw = make(chan *netFD, 1) + // s.pr and s.pw are indistinguishable. + if s.pr, s.pw, err = selfConnectedTCPSocket(); err != nil { + return nil, err + } + var e int + if e = syscall.SetNonblock(s.pr.Fd(), true); e != 0 { + Errno: + err = &os.PathError{"setnonblock", s.pr.Name(), os.Errno(e)} + Error: + s.pr.Close() + return nil, err + } + if s.poll, err = newpollster(); err != nil { + goto Error + } + if err = s.poll.AddFD(s.pr.Fd(), 'r', true); err != nil { + s.poll.Close() + goto Error + } + s.pending = make(map[int]*netFD) + go s.Run() + return s, nil +} diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go new file mode 100644 index 000000000..605f3110b --- /dev/null +++ b/libgo/go/net/parse.go @@ -0,0 +1,214 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Simple file i/o and string manipulation, to avoid +// depending on strconv and bufio and strings. + +package net + +import ( + "io" + "os" +) + +type file struct { + file *os.File + data []byte + atEOF bool +} + +func (f *file) close() { f.file.Close() } + +func (f *file) getLineFromData() (s string, ok bool) { + data := f.data + i := 0 + for i = 0; i < len(data); i++ { + if data[i] == '\n' { + s = string(data[0:i]) + ok = true + // move data + i++ + n := len(data) - i + copy(data[0:], data[i:]) + f.data = data[0:n] + return + } + } + if f.atEOF && len(f.data) > 0 { + // EOF, return all we have + s = string(data) + f.data = f.data[0:0] + ok = true + } + return +} + +func (f *file) readLine() (s string, ok bool) { + if s, ok = f.getLineFromData(); ok { + return + } + if len(f.data) < cap(f.data) { + ln := len(f.data) + n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)]) + if n >= 0 { + f.data = f.data[0 : ln+n] + } + if err == os.EOF { + f.atEOF = true + } + } + s, ok = f.getLineFromData() + return +} + +func open(name string) (*file, os.Error) { + fd, err := os.Open(name, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + return &file{fd, make([]byte, 1024)[0:0], false}, nil +} + +func byteIndex(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} + +// Count occurrences in s of any bytes in t. +func countAnyByte(s string, t string) int { + n := 0 + for i := 0; i < len(s); i++ { + if byteIndex(t, s[i]) >= 0 { + n++ + } + } + return n +} + +// Split s at any bytes in t. +func splitAtBytes(s string, t string) []string { + a := make([]string, 1+countAnyByte(s, t)) + n := 0 + last := 0 + for i := 0; i < len(s); i++ { + if byteIndex(t, s[i]) >= 0 { + if last < i { + a[n] = string(s[last:i]) + n++ + } + last = i + 1 + } + } + if last < len(s) { + a[n] = string(s[last:]) + n++ + } + return a[0:n] +} + +func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") } + +// Bigger than we need, not too big to worry about overflow +const big = 0xFFFFFF + +// Decimal to integer starting at &s[i0]. +// Returns number, new offset, success. +func dtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + n = n*10 + int(s[i]-'0') + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// Hexadecimal to integer starting at &s[i0]. +// Returns number, new offset, success. +func xtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s); i++ { + if '0' <= s[i] && s[i] <= '9' { + n *= 16 + n += int(s[i] - '0') + } else if 'a' <= s[i] && s[i] <= 'f' { + n *= 16 + n += int(s[i]-'a') + 10 + } else if 'A' <= s[i] && s[i] <= 'F' { + n *= 16 + n += int(s[i]-'A') + 10 + } else { + break + } + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// Integer to decimal. +func itoa(i int) string { + var buf [30]byte + n := len(buf) + neg := false + if i < 0 { + i = -i + neg = true + } + ui := uint(i) + for ui > 0 || n == len(buf) { + n-- + buf[n] = byte('0' + ui%10) + ui /= 10 + } + if neg { + n-- + buf[n] = '-' + } + return string(buf[n:]) +} + +// Number of occurrences of b in s. +func count(s string, b byte) int { + n := 0 + for i := 0; i < len(s); i++ { + if s[i] == b { + n++ + } + } + return n +} + +// Returns the prefix of s up to but not including the character c +func prefixBefore(s string, c byte) string { + for i, v := range s { + if v == int(c) { + return s[0:i] + } + } + return s +} + +// Index of rightmost occurrence of b in s. +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +} diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go new file mode 100644 index 000000000..2b7784eee --- /dev/null +++ b/libgo/go/net/parse_test.go @@ -0,0 +1,50 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bufio" + "os" + "testing" + "runtime" +) + +func TestReadLine(t *testing.T) { + // /etc/services file does not exist on windows. + if runtime.GOOS == "windows" { + return + } + filename := "/etc/services" // a nice big file + + fd, err := os.Open(filename, os.O_RDONLY, 0) + if err != nil { + t.Fatalf("open %s: %v", filename, err) + } + br := bufio.NewReader(fd) + + file, err := open(filename) + if file == nil { + t.Fatalf("net.open(%s) = nil", filename) + } + + lineno := 1 + byteno := 0 + for { + bline, berr := br.ReadString('\n') + if n := len(bline); n > 0 { + bline = bline[0 : n-1] + } + line, ok := file.readLine() + if (berr != nil) != !ok || bline != line { + t.Fatalf("%s:%d (#%d)\nbufio => %q, %v\nnet => %q, %v", + filename, lineno, byteno, bline, berr, line, ok) + } + if !ok { + break + } + lineno++ + byteno += len(line) + 1 + } +} diff --git a/libgo/go/net/pipe.go b/libgo/go/net/pipe.go new file mode 100644 index 000000000..c0bbd356b --- /dev/null +++ b/libgo/go/net/pipe.go @@ -0,0 +1,62 @@ +package net + +import ( + "io" + "os" +) + +// Pipe creates a synchronous, in-memory, full duplex +// network connection; both ends implement the Conn interface. +// Reads on one end are matched with writes on the other, +// copying data directly between the two; there is no internal +// buffering. +func Pipe() (Conn, Conn) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + return &pipe{r1, w2}, &pipe{r2, w1} +} + +type pipe struct { + *io.PipeReader + *io.PipeWriter +} + +type pipeAddr int + +func (pipeAddr) Network() string { + return "pipe" +} + +func (pipeAddr) String() string { + return "pipe" +} + +func (p *pipe) Close() os.Error { + err := p.PipeReader.Close() + err1 := p.PipeWriter.Close() + if err == nil { + err = err1 + } + return err +} + +func (p *pipe) LocalAddr() Addr { + return pipeAddr(0) +} + +func (p *pipe) RemoteAddr() Addr { + return pipeAddr(0) +} + +func (p *pipe) SetTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} + +func (p *pipe) SetReadTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} + +func (p *pipe) SetWriteTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} diff --git a/libgo/go/net/pipe_test.go b/libgo/go/net/pipe_test.go new file mode 100644 index 000000000..7e4c6db44 --- /dev/null +++ b/libgo/go/net/pipe_test.go @@ -0,0 +1,57 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "io" + "os" + "testing" +) + +func checkWrite(t *testing.T, w io.Writer, data []byte, c chan int) { + n, err := w.Write(data) + if err != nil { + t.Errorf("write: %v", err) + } + if n != len(data) { + t.Errorf("short write: %d != %d", n, len(data)) + } + c <- 0 +} + +func checkRead(t *testing.T, r io.Reader, data []byte, wantErr os.Error) { + buf := make([]byte, len(data)+10) + n, err := r.Read(buf) + if err != wantErr { + t.Errorf("read: %v", err) + return + } + if n != len(data) || !bytes.Equal(buf[0:n], data) { + t.Errorf("bad read: got %q", buf[0:n]) + return + } +} + +// Test a simple read/write/close sequence. +// Assumes that the underlying io.Pipe implementation +// is solid and we're just testing the net wrapping. + +func TestPipe(t *testing.T) { + c := make(chan int) + cli, srv := Pipe() + go checkWrite(t, cli, []byte("hello, world"), c) + checkRead(t, srv, []byte("hello, world"), nil) + <-c + go checkWrite(t, srv, []byte("line 2"), c) + checkRead(t, cli, []byte("line 2"), nil) + <-c + go checkWrite(t, cli, []byte("a third line"), c) + checkRead(t, srv, []byte("a third line"), nil) + <-c + go srv.Close() + checkRead(t, cli, nil, os.EOF) + cli.Close() +} diff --git a/libgo/go/net/port.go b/libgo/go/net/port.go new file mode 100644 index 000000000..7d25058b2 --- /dev/null +++ b/libgo/go/net/port.go @@ -0,0 +1,70 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read system port mappings from /etc/services + +package net + +import ( + "os" + "sync" +) + +var services map[string]map[string]int +var servicesError os.Error +var onceReadServices sync.Once + +func readServices() { + services = make(map[string]map[string]int) + var file *file + if file, servicesError = open("/etc/services"); servicesError != nil { + return + } + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // "http 80/tcp www www-http # World Wide Web HTTP" + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + portnet := f[1] // "tcp/80" + port, j, ok := dtoi(portnet, 0) + if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' { + continue + } + netw := portnet[j+1:] // "tcp" + m, ok1 := services[netw] + if !ok1 { + m = make(map[string]int) + services[netw] = m + } + for i := 0; i < len(f); i++ { + if i != 1 { // f[1] was port/net + m[f[i]] = port + } + } + } + file.close() +} + +// LookupPort looks up the port for the given network and service. +func LookupPort(network, service string) (port int, err os.Error) { + onceReadServices.Do(readServices) + + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + + if m, ok := services[network]; ok { + if port, ok = m[service]; ok { + return + } + } + return 0, &AddrError{"unknown port", network + "/" + service} +} diff --git a/libgo/go/net/port_test.go b/libgo/go/net/port_test.go new file mode 100644 index 000000000..329b169f3 --- /dev/null +++ b/libgo/go/net/port_test.go @@ -0,0 +1,53 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" +) + +type portTest struct { + netw string + name string + port int + ok bool +} + +var porttests = []portTest{ + {"tcp", "echo", 7, true}, + {"tcp", "discard", 9, true}, + {"tcp", "systat", 11, true}, + {"tcp", "daytime", 13, true}, + {"tcp", "chargen", 19, true}, + {"tcp", "ftp-data", 20, true}, + {"tcp", "ftp", 21, true}, + {"tcp", "telnet", 23, true}, + {"tcp", "smtp", 25, true}, + {"tcp", "time", 37, true}, + {"tcp", "domain", 53, true}, + {"tcp", "finger", 79, true}, + + {"udp", "echo", 7, true}, + {"udp", "tftp", 69, true}, + {"udp", "bootpc", 68, true}, + {"udp", "bootps", 67, true}, + {"udp", "domain", 53, true}, + {"udp", "ntp", 123, true}, + {"udp", "snmp", 161, true}, + {"udp", "syslog", 514, true}, + + {"--badnet--", "zzz", 0, false}, + {"tcp", "--badport--", 0, false}, +} + +func TestLookupPort(t *testing.T) { + for i := 0; i < len(porttests); i++ { + tt := porttests[i] + if port, err := LookupPort(tt.netw, tt.name); port != tt.port || (err == nil) != tt.ok { + t.Errorf("LookupPort(%q, %q) = %v, %s; want %v", + tt.netw, tt.name, port, err, tt.port) + } + } +} diff --git a/libgo/go/net/resolv_windows.go b/libgo/go/net/resolv_windows.go new file mode 100644 index 000000000..f3d854ff2 --- /dev/null +++ b/libgo/go/net/resolv_windows.go @@ -0,0 +1,112 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + "unsafe" + "os" + "sync" +) + +var hostentLock sync.Mutex +var serventLock sync.Mutex + +func LookupHost(name string) (cname string, addrs []string, err os.Error) { + hostentLock.Lock() + defer hostentLock.Unlock() + h, e := syscall.GetHostByName(name) + if e != 0 { + return "", nil, os.NewSyscallError("GetHostByName", e) + } + cname = name + switch h.AddrType { + case syscall.AF_INET: + i := 0 + addrs = make([]string, 100) // plenty of room to grow + for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ { + addrs[i] = IPv4(p[i][0], p[i][1], p[i][2], p[i][3]).String() + } + addrs = addrs[0:i] + default: // TODO(vcc): Implement non IPv4 address lookups. + return "", nil, os.NewSyscallError("LookupHost", syscall.EWINDOWS) + } + return cname, addrs, nil +} + +type SRV struct { + Target string + Port uint16 + Priority uint16 + Weight uint16 +} + +func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err os.Error) { + var r *syscall.DNSRecord + target := "_" + service + "._" + proto + "." + name + e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) + if int(e) != 0 { + return "", nil, os.NewSyscallError("LookupSRV", int(e)) + } + defer syscall.DnsRecordListFree(r, 1) + addrs = make([]*SRV, 100) + i := 0 + for p := r; p != nil && p.Type == syscall.DNS_TYPE_SRV; p = p.Next { + v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) + addrs[i] = &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight} + i++ + } + addrs = addrs[0:i] + return name, addrs, nil +} + +func LookupPort(network, service string) (port int, err os.Error) { + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" + } + serventLock.Lock() + defer serventLock.Unlock() + s, e := syscall.GetServByName(service, network) + if e != 0 { + return 0, os.NewSyscallError("GetServByName", e) + } + return int(syscall.Ntohs(s.Port)), nil +} + +// TODO(brainman): Following code is only to get tests running. + +func isDomainName(s string) bool { + panic("unimplemented") +} + +func reverseaddr(addr string) (arpa string, err os.Error) { + panic("unimplemented") +} + +// DNSError represents a DNS lookup error. +type DNSError struct { + Error string // description of the error + Name string // name looked for + Server string // server used + IsTimeout bool +} + +func (e *DNSError) String() string { + if e == nil { + return "<nil>" + } + s := "lookup " + e.Name + if e.Server != "" { + s += " on " + e.Server + } + s += ": " + e.Error + return s +} + +func (e *DNSError) Timeout() bool { return e.IsTimeout } +func (e *DNSError) Temporary() bool { return e.IsTimeout } diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go new file mode 100644 index 000000000..3f2442a46 --- /dev/null +++ b/libgo/go/net/server_test.go @@ -0,0 +1,203 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "flag" + "io" + "os" + "strings" + "syscall" + "testing" + "runtime" +) + +// Do not test empty datagrams by default. +// It causes unexplained timeouts on some systems, +// including Snow Leopard. I think that the kernel +// doesn't quite expect them. +var testUDP = flag.Bool("udp", false, "whether to test UDP datagrams") + +func runEcho(fd io.ReadWriter, done chan<- int) { + var buf [1024]byte + + for { + n, err := fd.Read(buf[0:]) + if err != nil || n == 0 { + break + } + fd.Write(buf[0:n]) + } + done <- 1 +} + +func runServe(t *testing.T, network, addr string, listening chan<- string, done chan<- int) { + l, err := Listen(network, addr) + if err != nil { + t.Fatalf("net.Listen(%q, %q) = _, %v", network, addr, err) + } + listening <- l.Addr().String() + + for { + fd, err := l.Accept() + if err != nil { + break + } + echodone := make(chan int) + go runEcho(fd, echodone) + <-echodone // make sure Echo stops + l.Close() + } + done <- 1 +} + +func connect(t *testing.T, network, addr string, isEmpty bool) { + var laddr string + if network == "unixgram" { + laddr = addr + ".local" + } + fd, err := Dial(network, laddr, addr) + if err != nil { + t.Fatalf("net.Dial(%q, %q, %q) = _, %v", network, laddr, addr, err) + } + fd.SetReadTimeout(1e9) // 1s + + var b []byte + if !isEmpty { + b = []byte("hello, world\n") + } + var b1 [100]byte + + n, err1 := fd.Write(b) + if n != len(b) { + t.Fatalf("fd.Write(%q) = %d, %v", b, n, err1) + } + + n, err1 = fd.Read(b1[0:]) + if n != len(b) || err1 != nil { + t.Fatalf("fd.Read() = %d, %v (want %d, nil)", n, err1, len(b)) + } + fd.Close() +} + +func doTest(t *testing.T, network, listenaddr, dialaddr string) { + t.Logf("Test %s %s %s\n", network, listenaddr, dialaddr) + listening := make(chan string) + done := make(chan int) + if network == "tcp" { + listenaddr += ":0" // any available port + } + go runServe(t, network, listenaddr, listening, done) + addr := <-listening // wait for server to start + if network == "tcp" { + dialaddr += addr[strings.LastIndex(addr, ":"):] + } + connect(t, network, dialaddr, false) + <-done // make sure server stopped +} + +func TestTCPServer(t *testing.T) { + doTest(t, "tcp", "0.0.0.0", "127.0.0.1") + doTest(t, "tcp", "", "127.0.0.1") + if kernelSupportsIPv6() { + doTest(t, "tcp", "[::]", "[::ffff:127.0.0.1]") + doTest(t, "tcp", "[::]", "127.0.0.1") + doTest(t, "tcp", "0.0.0.0", "[::ffff:127.0.0.1]") + } +} + +func TestUnixServer(t *testing.T) { + // "unix" sockets are not supported on windows. + if runtime.GOOS == "windows" { + return + } + os.Remove("/tmp/gotest.net") + doTest(t, "unix", "/tmp/gotest.net", "/tmp/gotest.net") + os.Remove("/tmp/gotest.net") + if syscall.OS == "linux" { + doTest(t, "unixpacket", "/tmp/gotest.net", "/tmp/gotest.net") + os.Remove("/tmp/gotest.net") + // Test abstract unix domain socket, a Linux-ism + doTest(t, "unix", "@gotest/net", "@gotest/net") + doTest(t, "unixpacket", "@gotest/net", "@gotest/net") + } +} + +func runPacket(t *testing.T, network, addr string, listening chan<- string, done chan<- int) { + c, err := ListenPacket(network, addr) + if err != nil { + t.Fatalf("net.ListenPacket(%q, %q) = _, %v", network, addr, err) + } + listening <- c.LocalAddr().String() + c.SetReadTimeout(10e6) // 10ms + var buf [1000]byte + for { + n, addr, err := c.ReadFrom(buf[0:]) + if e, ok := err.(Error); ok && e.Timeout() { + if done <- 1 { + break + } + continue + } + if err != nil { + break + } + if _, err = c.WriteTo(buf[0:n], addr); err != nil { + t.Fatalf("WriteTo %v: %v", addr, err) + } + } + c.Close() + done <- 1 +} + +func doTestPacket(t *testing.T, network, listenaddr, dialaddr string, isEmpty bool) { + t.Logf("TestPacket %s %s %s\n", network, listenaddr, dialaddr) + listening := make(chan string) + done := make(chan int) + if network == "udp" { + listenaddr += ":0" // any available port + } + go runPacket(t, network, listenaddr, listening, done) + addr := <-listening // wait for server to start + if network == "udp" { + dialaddr += addr[strings.LastIndex(addr, ":"):] + } + connect(t, network, dialaddr, isEmpty) + <-done // tell server to stop + <-done // wait for stop +} + +func TestUDPServer(t *testing.T) { + if !*testUDP { + return + } + for _, isEmpty := range []bool{false, true} { + doTestPacket(t, "udp", "0.0.0.0", "127.0.0.1", isEmpty) + doTestPacket(t, "udp", "", "127.0.0.1", isEmpty) + if kernelSupportsIPv6() { + doTestPacket(t, "udp", "[::]", "[::ffff:127.0.0.1]", isEmpty) + doTestPacket(t, "udp", "[::]", "127.0.0.1", isEmpty) + doTestPacket(t, "udp", "0.0.0.0", "[::ffff:127.0.0.1]", isEmpty) + } + } +} + +func TestUnixDatagramServer(t *testing.T) { + // "unix" sockets are not supported on windows. + if runtime.GOOS == "windows" { + return + } + for _, isEmpty := range []bool{false} { + os.Remove("/tmp/gotest1.net") + os.Remove("/tmp/gotest1.net.local") + doTestPacket(t, "unixgram", "/tmp/gotest1.net", "/tmp/gotest1.net", isEmpty) + os.Remove("/tmp/gotest1.net") + os.Remove("/tmp/gotest1.net.local") + if syscall.OS == "linux" { + // Test abstract unix domain socket, a Linux-ism + doTestPacket(t, "unixgram", "@gotest1/net", "@gotest1/net", isEmpty) + } + } +} diff --git a/libgo/go/net/sock.go b/libgo/go/net/sock.go new file mode 100644 index 000000000..5a88ddcbc --- /dev/null +++ b/libgo/go/net/sock.go @@ -0,0 +1,181 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Sockets + +package net + +import ( + "os" + "reflect" + "syscall" +) + +// Boolean to int. +func boolint(b bool) int { + if b { + return 1 + } + return 0 +} + +// Generic socket creation. +func socket(net string, f, p, t int, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err os.Error) { + // See ../syscall/exec.go for description of ForkLock. + syscall.ForkLock.RLock() + s, e := syscall.Socket(f, p, t) + if e != 0 { + syscall.ForkLock.RUnlock() + return nil, os.Errno(e) + } + syscall.CloseOnExec(s) + syscall.ForkLock.RUnlock() + + // Allow reuse of recently-used addresses. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + + // Allow broadcast. + syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + + if f == syscall.AF_INET6 { + // using ip, tcp, udp, etc. + // allow both protocols even if the OS default is otherwise. + syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) + } + + if la != nil { + e = syscall.Bind(s, la) + if e != 0 { + closesocket(s) + return nil, os.Errno(e) + } + } + + if ra != nil { + e = syscall.Connect(s, ra) + for e == syscall.EINTR { + e = syscall.Connect(s, ra) + } + if e != 0 { + closesocket(s) + return nil, os.Errno(e) + } + } + + sa, _ := syscall.Getsockname(s) + laddr := toAddr(sa) + sa, _ = syscall.Getpeername(s) + raddr := toAddr(sa) + + fd, err = newFD(s, f, p, net, laddr, raddr) + if err != nil { + closesocket(s) + return nil, err + } + + return fd, nil +} + +func setsockoptInt(fd, level, opt int, value int) os.Error { + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd, level, opt, value)) +} + +func setsockoptNsec(fd, level, opt int, nsec int64) os.Error { + var tv = syscall.NsecToTimeval(nsec) + return os.NewSyscallError("setsockopt", syscall.SetsockoptTimeval(fd, level, opt, &tv)) +} + +func setReadBuffer(fd *netFD, bytes int) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) +} + +func setWriteBuffer(fd *netFD, bytes int) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes) +} + +func setReadTimeout(fd *netFD, nsec int64) os.Error { + fd.rdeadline_delta = nsec + return nil +} + +func setWriteTimeout(fd *netFD, nsec int64) os.Error { + fd.wdeadline_delta = nsec + return nil +} + +func setTimeout(fd *netFD, nsec int64) os.Error { + if e := setReadTimeout(fd, nsec); e != nil { + return e + } + return setWriteTimeout(fd, nsec) +} + +func setReuseAddr(fd *netFD, reuse bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse)) +} + +func bindToDevice(fd *netFD, dev string) os.Error { + // TODO(rsc): call setsockopt with null-terminated string pointer + return os.EINVAL +} + +func setDontRoute(fd *netFD, dontroute bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute)) +} + +func setKeepAlive(fd *netFD, keepalive bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)) +} + +func setNoDelay(fd *netFD, noDelay bool) os.Error { + fd.incref() + defer fd.decref() + return setsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)) +} + +func setLinger(fd *netFD, sec int) os.Error { + var l syscall.Linger + if sec >= 0 { + l.Onoff = 1 + l.Linger = int32(sec) + } else { + l.Onoff = 0 + l.Linger = 0 + } + fd.incref() + defer fd.decref() + e := syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l) + return os.NewSyscallError("setsockopt", e) +} + +type UnknownSocketError struct { + sa syscall.Sockaddr +} + +func (e *UnknownSocketError) String() string { + return "unknown socket address type " + reflect.Typeof(e.sa).String() +} + +func sockaddrToString(sa syscall.Sockaddr) (name string, err os.Error) { + switch a := sa.(type) { + case *syscall.SockaddrInet4: + return joinHostPort(IP(a.Addr[0:]).String(), itoa(a.Port)), nil + case *syscall.SockaddrInet6: + return joinHostPort(IP(a.Addr[0:]).String(), itoa(a.Port)), nil + case *syscall.SockaddrUnix: + return a.Name, nil + } + + return "", &UnknownSocketError{sa} +} diff --git a/libgo/go/net/srv_test.go b/libgo/go/net/srv_test.go new file mode 100644 index 000000000..4dd6089cd --- /dev/null +++ b/libgo/go/net/srv_test.go @@ -0,0 +1,22 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO It would be nice to use a mock DNS server, to eliminate +// external dependencies. + +package net + +import ( + "testing" +) + +func TestGoogleSRV(t *testing.T) { + _, addrs, err := LookupSRV("xmpp-server", "tcp", "google.com") + if err != nil { + t.Errorf("failed: %s", err) + } + if len(addrs) == 0 { + t.Errorf("no results") + } +} diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go new file mode 100644 index 000000000..a4bca11bb --- /dev/null +++ b/libgo/go/net/tcpsock.go @@ -0,0 +1,293 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TCP sockets + +package net + +import ( + "os" + "syscall" +) + +func sockaddrToTCP(sa syscall.Sockaddr) Addr { + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &TCPAddr{sa.Addr[0:], sa.Port} + case *syscall.SockaddrInet6: + return &TCPAddr{sa.Addr[0:], sa.Port} + } + return nil +} + +// TCPAddr represents the address of a TCP end point. +type TCPAddr struct { + IP IP + Port int +} + +// Network returns the address's network name, "tcp". +func (a *TCPAddr) Network() string { return "tcp" } + +func (a *TCPAddr) String() string { + if a == nil { + return "<nil>" + } + return joinHostPort(a.IP.String(), itoa(a.Port)) +} + +func (a *TCPAddr) family() int { + if a == nil || len(a.IP) <= 4 { + return syscall.AF_INET + } + if ip := a.IP.To4(); ip != nil { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { + return ipToSockaddr(family, a.IP, a.Port) +} + +func (a *TCPAddr) toAddr() sockaddr { + if a == nil { // nil *TCPAddr + return nil // nil interface + } + return a +} + +// ResolveTCPAddr parses addr as a TCP address of the form +// host:port and resolves domain names or port names to +// numeric addresses. A literal IPv6 host address must be +// enclosed in square brackets, as in "[::]:80". +func ResolveTCPAddr(addr string) (*TCPAddr, os.Error) { + ip, port, err := hostPortToIP("tcp", addr) + if err != nil { + return nil, err + } + return &TCPAddr{ip, port}, nil +} + +// TCPConn is an implementation of the Conn interface +// for TCP network connections. +type TCPConn struct { + fd *netFD +} + +func newTCPConn(fd *netFD) *TCPConn { + c := &TCPConn{fd} + c.SetNoDelay(true) + return c +} + +func (c *TCPConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *TCPConn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the net.Conn Write method. +func (c *TCPConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the TCP connection. +func (c *TCPConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address, a *TCPAddr. +func (c *TCPConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *TCPAddr. +func (c *TCPConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *TCPConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *TCPConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *TCPConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *TCPConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *TCPConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// SetLinger sets the behavior of Close() on a connection +// which still has data waiting to be sent or to be acknowledged. +// +// If sec < 0 (the default), Close returns immediately and +// the operating system finishes sending the data in the background. +// +// If sec == 0, Close returns immediately and the operating system +// discards any unsent or unacknowledged data. +// +// If sec > 0, Close blocks for at most sec seconds waiting for +// data to be sent and acknowledged. +func (c *TCPConn) SetLinger(sec int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setLinger(c.fd, sec) +} + +// SetKeepAlive sets whether the operating system should send +// keepalive messages on the connection. +func (c *TCPConn) SetKeepAlive(keepalive bool) os.Error { + if !c.ok() { + return os.EINVAL + } + return setKeepAlive(c.fd, keepalive) +} + +// SetNoDelay controls whether the operating system should delay +// packet transmission in hopes of sending fewer packets +// (Nagle's algorithm). The default is true (no delay), meaning +// that data is sent as soon as possible after a Write. +func (c *TCPConn) SetNoDelay(noDelay bool) os.Error { + if !c.ok() { + return os.EINVAL + } + return setNoDelay(c.fd, noDelay) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *TCPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + +// DialTCP is like Dial but can only connect to TCP networks +// and returns a TCPConn structure. +func DialTCP(net string, laddr, raddr *TCPAddr) (c *TCPConn, err os.Error) { + if raddr == nil { + return nil, &OpError{"dial", "tcp", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_STREAM, 0, "dial", sockaddrToTCP) + if e != nil { + return nil, e + } + return newTCPConn(fd), nil +} + +// TCPListener is a TCP network listener. +// Clients should typically use variables of type Listener +// instead of assuming TCP. +type TCPListener struct { + fd *netFD +} + +// ListenTCP announces on the TCP address laddr and returns a TCP listener. +// Net must be "tcp", "tcp4", or "tcp6". +// If laddr has a port of 0, it means to listen on some available port. +// The caller can use l.Addr() to retrieve the chosen address. +func ListenTCP(net string, laddr *TCPAddr) (l *TCPListener, err os.Error) { + fd, err := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_STREAM, 0, "listen", sockaddrToTCP) + if err != nil { + return nil, err + } + errno := syscall.Listen(fd.sysfd, listenBacklog()) + if errno != 0 { + closesocket(fd.sysfd) + return nil, &OpError{"listen", "tcp", laddr, os.Errno(errno)} + } + l = new(TCPListener) + l.fd = fd + return l, nil +} + +// AcceptTCP accepts the next incoming call and returns the new connection +// and the remote address. +func (l *TCPListener) AcceptTCP() (c *TCPConn, err os.Error) { + if l == nil || l.fd == nil || l.fd.sysfd < 0 { + return nil, os.EINVAL + } + fd, err := l.fd.accept(sockaddrToTCP) + if err != nil { + return nil, err + } + return newTCPConn(fd), nil +} + +// Accept implements the Accept method in the Listener interface; +// it waits for the next call and returns a generic Conn. +func (l *TCPListener) Accept() (c Conn, err os.Error) { + c1, err := l.AcceptTCP() + if err != nil { + return nil, err + } + return c1, nil +} + +// Close stops listening on the TCP address. +// Already Accepted connections are not closed. +func (l *TCPListener) Close() os.Error { + if l == nil || l.fd == nil { + return os.EINVAL + } + return l.fd.Close() +} + +// Addr returns the listener's network address, a *TCPAddr. +func (l *TCPListener) Addr() Addr { return l.fd.laddr } + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (l *TCPListener) File() (f *os.File, err os.Error) { return l.fd.dup() } diff --git a/libgo/go/net/textproto/pipeline.go b/libgo/go/net/textproto/pipeline.go new file mode 100644 index 000000000..8c25884b3 --- /dev/null +++ b/libgo/go/net/textproto/pipeline.go @@ -0,0 +1,117 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "sync" +) + +// A Pipeline manages a pipelined in-order request/response sequence. +// +// To use a Pipeline p to manage multiple clients on a connection, +// each client should run: +// +// id := p.Next() // take a number +// +// p.StartRequest(id) // wait for turn to send request +// «send request» +// p.EndRequest(id) // notify Pipeline that request is sent +// +// p.StartResponse(id) // wait for turn to read response +// «read response» +// p.EndResponse(id) // notify Pipeline that response is read +// +// A pipelined server can use the same calls to ensure that +// responses computed in parallel are written in the correct order. +type Pipeline struct { + mu sync.Mutex + id uint + request sequencer + response sequencer +} + +// Next returns the next id for a request/response pair. +func (p *Pipeline) Next() uint { + p.mu.Lock() + id := p.id + p.id++ + p.mu.Unlock() + return id +} + +// StartRequest blocks until it is time to send (or, if this is a server, receive) +// the request with the given id. +func (p *Pipeline) StartRequest(id uint) { + p.request.Start(id) +} + +// EndRequest notifies p that the request with the given id has been sent +// (or, if this is a server, received). +func (p *Pipeline) EndRequest(id uint) { + p.request.End(id) +} + +// StartResponse blocks until it is time to receive (or, if this is a server, send) +// the request with the given id. +func (p *Pipeline) StartResponse(id uint) { + p.response.Start(id) +} + +// EndResponse notifies p that the response with the given id has been received +// (or, if this is a server, sent). +func (p *Pipeline) EndResponse(id uint) { + p.response.End(id) +} + +// A sequencer schedules a sequence of numbered events that must +// happen in order, one after the other. The event numbering must start +// at 0 and increment without skipping. The event number wraps around +// safely as long as there are not 2^32 simultaneous events pending. +type sequencer struct { + mu sync.Mutex + id uint + wait map[uint]chan uint +} + +// Start waits until it is time for the event numbered id to begin. +// That is, except for the first event, it waits until End(id-1) has +// been called. +func (s *sequencer) Start(id uint) { + s.mu.Lock() + if s.id == id { + s.mu.Unlock() + return + } + c := make(chan uint) + if s.wait == nil { + s.wait = make(map[uint]chan uint) + } + s.wait[id] = c + s.mu.Unlock() + <-c +} + +// End notifies the sequencer that the event numbered id has completed, +// allowing it to schedule the event numbered id+1. It is a run-time error +// to call End with an id that is not the number of the active event. +func (s *sequencer) End(id uint) { + s.mu.Lock() + if s.id != id { + panic("out of sync") + } + id++ + s.id = id + if s.wait == nil { + s.wait = make(map[uint]chan uint) + } + c, ok := s.wait[id] + if ok { + s.wait[id] = nil, false + } + s.mu.Unlock() + if ok { + c <- 1 + } +} diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go new file mode 100644 index 000000000..c8e34b758 --- /dev/null +++ b/libgo/go/net/textproto/reader.go @@ -0,0 +1,492 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "container/vector" + "io" + "io/ioutil" + "os" + "strconv" +) + +// BUG(rsc): To let callers manage exposure to denial of service +// attacks, Reader should allow them to set and reset a limit on +// the number of bytes read from the connection. + +// A Reader implements convenience methods for reading requests +// or responses from a text protocol network connection. +type Reader struct { + R *bufio.Reader + dot *dotReader +} + +// NewReader returns a new Reader reading from r. +func NewReader(r *bufio.Reader) *Reader { + return &Reader{R: r} +} + +// ReadLine reads a single line from r, +// eliding the final \n or \r\n from the returned string. +func (r *Reader) ReadLine() (string, os.Error) { + line, err := r.ReadLineBytes() + return string(line), err +} + +// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +func (r *Reader) ReadLineBytes() ([]byte, os.Error) { + r.closeDot() + line, err := r.R.ReadBytes('\n') + n := len(line) + if n > 0 && line[n-1] == '\n' { + n-- + if n > 0 && line[n-1] == '\r' { + n-- + } + } + return line[0:n], err +} + +// ReadContinuedLine reads a possibly continued line from r, +// eliding the final trailing ASCII white space. +// Lines after the first are considered continuations if they +// begin with a space or tab character. In the returned data, +// continuation lines are separated from the previous line +// only by a single space: the newline and leading white space +// are removed. +// +// For example, consider this input: +// +// Line 1 +// continued... +// Line 2 +// +// The first call to ReadContinuedLine will return "Line 1 continued..." +// and the second will return "Line 2". +// +// A line consisting of only white space is never continued. +// +func (r *Reader) ReadContinuedLine() (string, os.Error) { + line, err := r.ReadContinuedLineBytes() + return string(line), err +} + +// trim returns s with leading and trailing spaces and tabs removed. +// It does not assume Unicode or UTF-8. +func trim(s []byte) []byte { + i := 0 + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + n := len(s) + for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { + n-- + } + return s[i:n] +} + +// ReadContinuedLineBytes is like ReadContinuedLine but +// returns a []byte instead of a string. +func (r *Reader) ReadContinuedLineBytes() ([]byte, os.Error) { + // Read the first line. + line, err := r.ReadLineBytes() + if err != nil { + return line, err + } + if len(line) == 0 { // blank line - no continuation + return line, nil + } + line = trim(line) + + // Look for a continuation line. + c, err := r.R.ReadByte() + if err != nil { + // Delay err until we read the byte next time. + return line, nil + } + if c != ' ' && c != '\t' { + // Not a continuation. + r.R.UnreadByte() + return line, nil + } + + // Read continuation lines. + for { + // Consume leading spaces; one already gone. + for { + c, err = r.R.ReadByte() + if err != nil { + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + } + var cont []byte + cont, err = r.ReadLineBytes() + cont = trim(cont) + line = append(line, ' ') + line = append(line, cont...) + if err != nil { + break + } + + // Check for leading space on next line. + if c, err = r.R.ReadByte(); err != nil { + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + } + + // Delay error until next call. + if len(line) > 0 { + err = nil + } + return line, err +} + +func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err os.Error) { + line, err := r.ReadLine() + if err != nil { + return + } + if len(line) < 4 || line[3] != ' ' && line[3] != '-' { + err = ProtocolError("short response: " + line) + return + } + continued = line[3] == '-' + code, err = strconv.Atoi(line[0:3]) + if err != nil || code < 100 { + err = ProtocolError("invalid response code: " + line) + return + } + message = line[4:] + if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || + 10 <= expectCode && expectCode < 100 && code/10 != expectCode || + 100 <= expectCode && expectCode < 1000 && code != expectCode { + err = &Error{code, message} + } + return +} + +// ReadCodeLine reads a response code line of the form +// code message +// where code is a 3-digit status code and the message +// extends to the rest of the line. An example of such a line is: +// 220 plan9.bell-labs.com ESMTP +// +// If the prefix of the status does not match the digits in expectCode, +// ReadCodeLine returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// If the response is multi-line, ReadCodeLine returns an error. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err os.Error) { + code, continued, message, err := r.readCodeLine(expectCode) + if err == nil && continued { + err = ProtocolError("unexpected multi-line response: " + message) + } + return +} + +// ReadResponse reads a multi-line response of the form +// code-message line 1 +// code-message line 2 +// ... +// code message line n +// where code is a 3-digit status code. Each line should have the same code. +// The response is terminated by a line that uses a space between the code and +// the message line rather than a dash. Each line in message is separated by +// a newline (\n). +// +// If the prefix of the status does not match the digits in expectCode, +// ReadResponse returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// An expectCode <= 0 disables the check of the status code. +// +func (r *Reader) ReadResponse(expectCode int) (code int, message string, err os.Error) { + code, continued, message, err := r.readCodeLine(expectCode) + for err == nil && continued { + var code2 int + var moreMessage string + code2, continued, moreMessage, err = r.readCodeLine(expectCode) + if code != code2 { + err = ProtocolError("status code mismatch: " + strconv.Itoa(code) + ", " + strconv.Itoa(code2)) + } + message += "\n" + moreMessage + } + return +} + +// DotReader returns a new Reader that satisfies Reads using the +// decoded text of a dot-encoded block read from r. +// The returned Reader is only valid until the next call +// to a method on r. +// +// Dot encoding is a common framing used for data blocks +// in text protcols like SMTP. The data consists of a sequence +// of lines, each of which ends in "\r\n". The sequence itself +// ends at a line containing just a dot: ".\r\n". Lines beginning +// with a dot are escaped with an additional dot to avoid +// looking like the end of the sequence. +// +// The decoded form returned by the Reader's Read method +// rewrites the "\r\n" line endings into the simpler "\n", +// removes leading dot escapes if present, and stops with error os.EOF +// after consuming (and discarding) the end-of-sequence line. +func (r *Reader) DotReader() io.Reader { + r.closeDot() + r.dot = &dotReader{r: r} + return r.dot +} + +type dotReader struct { + r *Reader + state int +} + +// Read satisfies reads by decoding dot-encoded data read from d.r. +func (d *dotReader) Read(b []byte) (n int, err os.Error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + br := d.r.R + for n < len(b) && d.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + break + } + switch d.state { + case stateBeginLine: + if c == '.' { + d.state = stateDot + continue + } + if c == '\r' { + d.state = stateCR + continue + } + d.state = stateData + + case stateDot: + if c == '\r' { + d.state = stateDotCR + continue + } + if c == '\n' { + d.state = stateEOF + continue + } + d.state = stateData + + case stateDotCR: + if c == '\n' { + d.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateCR: + if c == '\n' { + d.state = stateBeginLine + break + } + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateData: + if c == '\r' { + d.state = stateCR + continue + } + if c == '\n' { + d.state = stateBeginLine + } + } + b[n] = c + n++ + } + if err == nil && d.state == stateEOF { + err = os.EOF + } + if err != nil && d.r.dot == d { + d.r.dot = nil + } + return +} + +// closeDot drains the current DotReader if any, +// making sure that it reads until the ending dot line. +func (r *Reader) closeDot() { + if r.dot == nil { + return + } + buf := make([]byte, 128) + for r.dot != nil { + // When Read reaches EOF or an error, + // it will set r.dot == nil. + r.dot.Read(buf) + } +} + +// ReadDotBytes reads a dot-encoding and returns the decoded data. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotBytes() ([]byte, os.Error) { + return ioutil.ReadAll(r.DotReader()) +} + +// ReadDotLines reads a dot-encoding and returns a slice +// containing the decoded lines, with the final \r\n or \n elided from each. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotLines() ([]string, os.Error) { + // We could use ReadDotBytes and then Split it, + // but reading a line at a time avoids needing a + // large contiguous block of memory and is simpler. + var v vector.StringVector + var err os.Error + for { + var line string + line, err = r.ReadLine() + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + break + } + + // Dot by itself marks end; otherwise cut one dot. + if len(line) > 0 && line[0] == '.' { + if len(line) == 1 { + break + } + line = line[1:] + } + v.Push(line) + } + return v, err +} + +// ReadMIMEHeader reads a MIME-style header from r. +// The header is a sequence of possibly continued Key: Value lines +// ending in a blank line. +// The returned map m maps CanonicalHeaderKey(key) to a +// sequence of values in the same order encountered in the input. +// +// For example, consider this input: +// +// My-Key: Value 1 +// Long-Key: Even +// Longer Value +// My-Key: Value 2 +// +// Given that input, ReadMIMEHeader returns the map: +// +// map[string][]string{ +// "My-Key": []string{"Value 1", "Value 2"}, +// "Long-Key": []string{"Even Longer Value"}, +// } +// +func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) { + m := make(map[string][]string) + for { + kv, err := r.ReadContinuedLineBytes() + if len(kv) == 0 { + return m, err + } + + // Key ends at first colon; must not have spaces. + i := bytes.IndexByte(kv, ':') + if i < 0 || bytes.IndexByte(kv[0:i], ' ') >= 0 { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + key := CanonicalHeaderKey(string(kv[0:i])) + + // Skip initial spaces in value. + i++ // skip colon + for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { + i++ + } + value := string(kv[i:]) + + v := vector.StringVector(m[key]) + v.Push(value) + m[key] = v + + if err != nil { + return m, err + } + } + panic("unreachable") +} + +// CanonicalHeaderKey returns the canonical format of the +// MIME header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +func CanonicalHeaderKey(s string) string { + // Quick check for canonical encoding. + needUpper := true + for i := 0; i < len(s); i++ { + c := s[i] + if needUpper && 'a' <= c && c <= 'z' { + goto MustRewrite + } + if !needUpper && 'A' <= c && c <= 'Z' { + goto MustRewrite + } + needUpper = c == '-' + } + return s + +MustRewrite: + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + a := []byte(s) + upper := true + for i, v := range a { + if upper && 'a' <= v && v <= 'z' { + a[i] = v + 'A' - 'a' + } + if !upper && 'A' <= v && v <= 'Z' { + a[i] = v + 'a' - 'A' + } + upper = v == '-' + } + return string(a) +} diff --git a/libgo/go/net/textproto/reader_test.go b/libgo/go/net/textproto/reader_test.go new file mode 100644 index 000000000..2cecbc75f --- /dev/null +++ b/libgo/go/net/textproto/reader_test.go @@ -0,0 +1,140 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "io" + "os" + "reflect" + "strings" + "testing" +) + +type canonicalHeaderKeyTest struct { + in, out string +} + +var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ + {"a-b-c", "A-B-C"}, + {"a-1-c", "A-1-C"}, + {"User-Agent", "User-Agent"}, + {"uSER-aGENT", "User-Agent"}, + {"user-agent", "User-Agent"}, + {"USER-AGENT", "User-Agent"}, +} + +func TestCanonicalHeaderKey(t *testing.T) { + for _, tt := range canonicalHeaderKeyTests { + if s := CanonicalHeaderKey(tt.in); s != tt.out { + t.Errorf("CanonicalHeaderKey(%q) = %q, want %q", tt.in, s, tt.out) + } + } +} + +func reader(s string) *Reader { + return NewReader(bufio.NewReader(strings.NewReader(s))) +} + +func TestReadLine(t *testing.T) { + r := reader("line1\nline2\n") + s, err := r.ReadLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "line2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "" || err != os.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadContinuedLine(t *testing.T) { + r := reader("line1\nline\n 2\nline3\n") + s, err := r.ReadContinuedLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line 2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line3" || err != nil { + t.Fatalf("Line 3: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "" || err != os.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadCodeLine(t *testing.T) { + r := reader("123 hi\n234 bye\n345 no way\n") + code, msg, err := r.ReadCodeLine(0) + if code != 123 || msg != "hi" || err != nil { + t.Fatalf("Line 1: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(23) + if code != 234 || msg != "bye" || err != nil { + t.Fatalf("Line 2: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(346) + if code != 345 || msg != "no way" || err == nil { + t.Fatalf("Line 3: %d, %s, %v", code, msg, err) + } + if e, ok := err.(*Error); !ok || e.Code != code || e.Msg != msg { + t.Fatalf("Line 3: wrong error %v\n", err) + } + code, msg, err = r.ReadCodeLine(1) + if code != 0 || msg != "" || err != os.EOF { + t.Fatalf("EOF: %d, %s, %v", code, msg, err) + } +} + +func TestReadDotLines(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanother\n") + s, err := r.ReadDotLines() + want := []string{"dotlines", "foo", ".bar", "..baz", "quux", ""} + if !reflect.DeepEqual(s, want) || err != nil { + t.Fatalf("ReadDotLines: %v, %v", s, err) + } + + s, err = r.ReadDotLines() + want = []string{"another"} + if !reflect.DeepEqual(s, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotLines2: %v, %v", s, err) + } +} + +func TestReadDotBytes(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanot.her\r\n") + b, err := r.ReadDotBytes() + want := []byte("dotlines\nfoo\n.bar\n..baz\nquux\n\n") + if !reflect.DeepEqual(b, want) || err != nil { + t.Fatalf("ReadDotBytes: %q, %v", b, err) + } + + b, err = r.ReadDotBytes() + want = []byte("anot.her\n") + if !reflect.DeepEqual(b, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotBytes2: %q, %v", b, err) + } +} + +func TestReadMIMEHeader(t *testing.T) { + r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n") + m, err := r.ReadMIMEHeader() + want := map[string][]string{ + "My-Key": {"Value 1", "Value 2"}, + "Long-Key": {"Even Longer Value"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} diff --git a/libgo/go/net/textproto/textproto.go b/libgo/go/net/textproto/textproto.go new file mode 100644 index 000000000..f62009c52 --- /dev/null +++ b/libgo/go/net/textproto/textproto.go @@ -0,0 +1,122 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The textproto package implements generic support for +// text-based request/response protocols in the style of +// HTTP, NNTP, and SMTP. +// +// The package provides: +// +// Error, which represents a numeric error response from +// a server. +// +// Pipeline, to manage pipelined requests and responses +// in a client. +// +// Reader, to read numeric response code lines, +// key: value headers, lines wrapped with leading spaces +// on continuation lines, and whole text blocks ending +// with a dot on a line by itself. +// +// Writer, to write dot-encoded text blocks. +// +package textproto + +import ( + "bufio" + "fmt" + "io" + "net" + "os" +) + +// An Error represents a numeric error response from a server. +type Error struct { + Code int + Msg string +} + +func (e *Error) String() string { + return fmt.Sprintf("%03d %s", e.Code, e.Msg) +} + +// A ProtocolError describes a protocol violation such +// as an invalid response or a hung-up connection. +type ProtocolError string + +func (p ProtocolError) String() string { + return string(p) +} + +// A Conn represents a textual network protocol connection. +// It consists of a Reader and Writer to manage I/O +// and a Pipeline to sequence concurrent requests on the connection. +// These embedded types carry methods with them; +// see the documentation of those types for details. +type Conn struct { + Reader + Writer + Pipeline + conn io.ReadWriteCloser +} + +// NewConn returns a new Conn using conn for I/O. +func NewConn(conn io.ReadWriteCloser) *Conn { + return &Conn{ + Reader: Reader{R: bufio.NewReader(conn)}, + Writer: Writer{W: bufio.NewWriter(conn)}, + conn: conn, + } +} + +// Close closes the connection. +func (c *Conn) Close() os.Error { + return c.conn.Close() +} + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, os.Error) { + c, err := net.Dial(network, "", addr) + if err != nil { + return nil, err + } + return NewConn(c), nil +} + +// Cmd is a convenience method that sends a command after +// waiting its turn in the pipeline. The command text is the +// result of formatting format with args and appending \r\n. +// Cmd returns the id of the command, for use with StartResponse and EndResponse. +// +// For example, a client might run a HELP command that returns a dot-body +// by using: +// +// id, err := c.Cmd("HELP") +// if err != nil { +// return nil, err +// } +// +// c.StartResponse(id) +// defer c.EndResponse(id) +// +// if _, _, err = c.ReadCodeLine(110); err != nil { +// return nil, err +// } +// text, err := c.ReadDotAll() +// if err != nil { +// return nil, err +// } +// return c.ReadCodeLine(250) +// +func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err os.Error) { + id = c.Next() + c.StartRequest(id) + err = c.PrintfLine(format, args...) + c.EndRequest(id) + if err != nil { + return 0, err + } + return id, nil +} diff --git a/libgo/go/net/textproto/writer.go b/libgo/go/net/textproto/writer.go new file mode 100644 index 000000000..4e705f6c3 --- /dev/null +++ b/libgo/go/net/textproto/writer.go @@ -0,0 +1,119 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "fmt" + "io" + "os" +) + +// A Writer implements convenience methods for writing +// requests or responses to a text protocol network connection. +type Writer struct { + W *bufio.Writer + dot *dotWriter +} + +// NewWriter returns a new Writer writing to w. +func NewWriter(w *bufio.Writer) *Writer { + return &Writer{W: w} +} + +var crnl = []byte{'\r', '\n'} +var dotcrnl = []byte{'.', '\r', '\n'} + +// PrintfLine writes the formatted output followed by \r\n. +func (w *Writer) PrintfLine(format string, args ...interface{}) os.Error { + w.closeDot() + fmt.Fprintf(w.W, format, args...) + w.W.Write(crnl) + return w.W.Flush() +} + +// DotWriter returns a writer that can be used to write a dot-encoding to w. +// It takes care of inserting leading dots when necessary, +// translating line-ending \n into \r\n, and adding the final .\r\n line +// when the DotWriter is closed. The caller should close the +// DotWriter before the next call to a method on w. +// +// See the documentation for Reader's DotReader method for details about dot-encoding. +func (w *Writer) DotWriter() io.WriteCloser { + w.closeDot() + w.dot = &dotWriter{w: w} + return w.dot +} + +func (w *Writer) closeDot() { + if w.dot != nil { + w.dot.Close() // sets w.dot = nil + } +} + +type dotWriter struct { + w *Writer + state int +} + +const ( + wstateBeginLine = iota // beginning of line; initial state; must be zero + wstateCR // wrote \r (possibly at end of line) + wstateData // writing data in middle of line +) + +func (d *dotWriter) Write(b []byte) (n int, err os.Error) { + bw := d.w.W + for n < len(b) { + c := b[n] + switch d.state { + case wstateBeginLine: + d.state = wstateData + if c == '.' { + // escape leading dot + bw.WriteByte('.') + } + fallthrough + + case wstateData: + if c == '\r' { + d.state = wstateCR + } + if c == '\n' { + bw.WriteByte('\r') + d.state = wstateBeginLine + } + + case wstateCR: + d.state = wstateData + if c == '\n' { + d.state = wstateBeginLine + } + } + if err = bw.WriteByte(c); err != nil { + break + } + n++ + } + return +} + +func (d *dotWriter) Close() os.Error { + if d.w.dot == d { + d.w.dot = nil + } + bw := d.w.W + switch d.state { + default: + bw.WriteByte('\r') + fallthrough + case wstateCR: + bw.WriteByte('\n') + fallthrough + case wstateBeginLine: + bw.Write(dotcrnl) + } + return bw.Flush() +} diff --git a/libgo/go/net/textproto/writer_test.go b/libgo/go/net/textproto/writer_test.go new file mode 100644 index 000000000..e03ab5e15 --- /dev/null +++ b/libgo/go/net/textproto/writer_test.go @@ -0,0 +1,35 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "testing" +) + +func TestPrintfLine(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(bufio.NewWriter(&buf)) + err := w.PrintfLine("foo %d", 123) + if s := buf.String(); s != "foo 123\r\n" || err != nil { + t.Fatalf("s=%q; err=%s", s, err) + } +} + +func TestDotWriter(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte("abc\n.def\n..ghi\n.jkl\n.")) + if n != 21 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "abc\r\n..def\r\n...ghi\r\n..jkl\r\n..\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q", s) + } +} diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go new file mode 100644 index 000000000..09a257dc8 --- /dev/null +++ b/libgo/go/net/timeout_test.go @@ -0,0 +1,57 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "testing" + "time" +) + +func testTimeout(t *testing.T, network, addr string, readFrom bool) { + fd, err := Dial(network, "", addr) + if err != nil { + t.Errorf("dial %s %s failed: %v", network, addr, err) + return + } + defer fd.Close() + t0 := time.Nanoseconds() + fd.SetReadTimeout(1e8) // 100ms + var b [100]byte + var n int + var err1 os.Error + if readFrom { + n, _, err1 = fd.(PacketConn).ReadFrom(b[0:]) + } else { + n, err1 = fd.Read(b[0:]) + } + t1 := time.Nanoseconds() + what := "Read" + if readFrom { + what = "ReadFrom" + } + if n != 0 || err1 == nil || !err1.(Error).Timeout() { + t.Errorf("fd.%s on %s %s did not return 0, timeout: %v, %v", what, network, addr, n, err1) + } + if t1-t0 < 0.5e8 || t1-t0 > 1.5e8 { + t.Errorf("fd.%s on %s %s took %f seconds, expected 0.1", what, network, addr, float64(t1-t0)/1e9) + } +} + +func TestTimeoutUDP(t *testing.T) { + testTimeout(t, "udp", "127.0.0.1:53", false) + testTimeout(t, "udp", "127.0.0.1:53", true) +} + +func TestTimeoutTCP(t *testing.T) { + // set up a listener that won't talk back + listening := make(chan string) + done := make(chan int) + go runServe(t, "tcp", "127.0.0.1:0", listening, done) + addr := <-listening + + testTimeout(t, "tcp", addr, false) + <-done +} diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go new file mode 100644 index 000000000..0270954c1 --- /dev/null +++ b/libgo/go/net/udpsock.go @@ -0,0 +1,281 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// UDP sockets + +package net + +import ( + "os" + "syscall" +) + +func sockaddrToUDP(sa syscall.Sockaddr) Addr { + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &UDPAddr{sa.Addr[0:], sa.Port} + case *syscall.SockaddrInet6: + return &UDPAddr{sa.Addr[0:], sa.Port} + } + return nil +} + +// UDPAddr represents the address of a UDP end point. +type UDPAddr struct { + IP IP + Port int +} + +// Network returns the address's network name, "udp". +func (a *UDPAddr) Network() string { return "udp" } + +func (a *UDPAddr) String() string { + if a == nil { + return "<nil>" + } + return joinHostPort(a.IP.String(), itoa(a.Port)) +} + +func (a *UDPAddr) family() int { + if a == nil || len(a.IP) <= 4 { + return syscall.AF_INET + } + if ip := a.IP.To4(); ip != nil { + return syscall.AF_INET + } + return syscall.AF_INET6 +} + +func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, os.Error) { + return ipToSockaddr(family, a.IP, a.Port) +} + +func (a *UDPAddr) toAddr() sockaddr { + if a == nil { // nil *UDPAddr + return nil // nil interface + } + return a +} + +// ResolveUDPAddr parses addr as a UDP address of the form +// host:port and resolves domain names or port names to +// numeric addresses. A literal IPv6 host address must be +// enclosed in square brackets, as in "[::]:80". +func ResolveUDPAddr(addr string) (*UDPAddr, os.Error) { + ip, port, err := hostPortToIP("udp", addr) + if err != nil { + return nil, err + } + return &UDPAddr{ip, port}, nil +} + +// UDPConn is the implementation of the Conn and PacketConn +// interfaces for UDP network connections. +type UDPConn struct { + fd *netFD +} + +func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{fd} } + +func (c *UDPConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *UDPConn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the net.Conn Write method. +func (c *UDPConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the UDP connection. +func (c *UDPConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address. +func (c *UDPConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *UDPAddr. +func (c *UDPConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *UDPConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *UDPConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *UDPConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *UDPConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *UDPConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// UDP-specific methods. + +// ReadFromUDP reads a UDP packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromUDP can be made to time out and return an error with Timeout() == true +// after a fixed time limit; see SetTimeout and SetReadTimeout. +func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, sa, err := c.fd.ReadFrom(b) + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + addr = &UDPAddr{sa.Addr[0:], sa.Port} + case *syscall.SockaddrInet6: + addr = &UDPAddr{sa.Addr[0:], sa.Port} + } + return +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *UDPConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, uaddr, err := c.ReadFromUDP(b) + return n, uaddr.toAddr(), err +} + +// WriteToUDP writes a UDP packet to addr via c, copying the payload from b. +// +// WriteToUDP can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + sa, err1 := addr.sockaddr(c.fd.family) + if err1 != nil { + return 0, &OpError{Op: "write", Net: "udp", Addr: addr, Error: err1} + } + return c.fd.WriteTo(b, sa) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *UDPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*UDPAddr) + if !ok { + return 0, &OpError{"writeto", "udp", addr, os.EINVAL} + } + return c.WriteToUDP(b, a) +} + +// DialUDP connects to the remote address raddr on the network net, +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used +// as the local address for the connection. +func DialUDP(net string, laddr, raddr *UDPAddr) (c *UDPConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + default: + return nil, UnknownNetworkError(net) + } + if raddr == nil { + return nil, &OpError{"dial", "udp", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), raddr.toAddr(), syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) + if e != nil { + return nil, e + } + return newUDPConn(fd), nil +} + +// ListenUDP listens for incoming UDP packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send UDP +// packets with per-packet addressing. +func ListenUDP(net string, laddr *UDPAddr) (c *UDPConn, err os.Error) { + switch net { + case "udp", "udp4", "udp6": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + return nil, &OpError{"listen", "udp", nil, errMissingAddress} + } + fd, e := internetSocket(net, laddr.toAddr(), nil, syscall.SOCK_DGRAM, 0, "dial", sockaddrToUDP) + if e != nil { + return nil, e + } + return newUDPConn(fd), nil +} + +// BindToDevice binds a UDPConn to a network interface. +func (c *UDPConn) BindToDevice(device string) os.Error { + if !c.ok() { + return os.EINVAL + } + c.fd.incref() + defer c.fd.decref() + return os.NewSyscallError("setsockopt", syscall.BindToDevice(c.fd.sysfd, device)) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *UDPConn) File() (f *os.File, err os.Error) { return c.fd.dup() } diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go new file mode 100644 index 000000000..8c26a7baf --- /dev/null +++ b/libgo/go/net/unixsock.go @@ -0,0 +1,449 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Unix domain sockets + +package net + +import ( + "os" + "syscall" +) + +func unixSocket(net string, laddr, raddr *UnixAddr, mode string) (fd *netFD, err os.Error) { + var proto int + switch net { + default: + return nil, UnknownNetworkError(net) + case "unix": + proto = syscall.SOCK_STREAM + case "unixgram": + proto = syscall.SOCK_DGRAM + case "unixpacket": + proto = syscall.SOCK_SEQPACKET + } + + var la, ra syscall.Sockaddr + switch mode { + default: + panic("unixSocket mode " + mode) + + case "dial": + if laddr != nil { + la = &syscall.SockaddrUnix{Name: laddr.Name} + } + if raddr != nil { + ra = &syscall.SockaddrUnix{Name: raddr.Name} + } else if proto != syscall.SOCK_DGRAM || laddr == nil { + return nil, &OpError{Op: mode, Net: net, Error: errMissingAddress} + } + + case "listen": + if laddr == nil { + return nil, &OpError{mode, net, nil, errMissingAddress} + } + la = &syscall.SockaddrUnix{Name: laddr.Name} + if raddr != nil { + return nil, &OpError{Op: mode, Net: net, Addr: raddr, Error: &AddrError{Error: "unexpected remote address", Addr: raddr.String()}} + } + } + + f := sockaddrToUnix + if proto == syscall.SOCK_DGRAM { + f = sockaddrToUnixgram + } else if proto == syscall.SOCK_SEQPACKET { + f = sockaddrToUnixpacket + } + + fd, oserr := socket(net, syscall.AF_UNIX, proto, 0, la, ra, f) + if oserr != nil { + goto Error + } + return fd, nil + +Error: + addr := raddr + if mode == "listen" { + addr = laddr + } + return nil, &OpError{Op: mode, Net: net, Addr: addr, Error: oserr} +} + +// UnixAddr represents the address of a Unix domain socket end point. +type UnixAddr struct { + Name string + Net string +} + +func sockaddrToUnix(sa syscall.Sockaddr) Addr { + if s, ok := sa.(*syscall.SockaddrUnix); ok { + return &UnixAddr{s.Name, "unix"} + } + return nil +} + +func sockaddrToUnixgram(sa syscall.Sockaddr) Addr { + if s, ok := sa.(*syscall.SockaddrUnix); ok { + return &UnixAddr{s.Name, "unixgram"} + } + return nil +} + +func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr { + if s, ok := sa.(*syscall.SockaddrUnix); ok { + return &UnixAddr{s.Name, "unixpacket"} + } + return nil +} + +func protoToNet(proto int) string { + switch proto { + case syscall.SOCK_STREAM: + return "unix" + case syscall.SOCK_SEQPACKET: + return "unixpacket" + case syscall.SOCK_DGRAM: + return "unixgram" + default: + panic("protoToNet unknown protocol") + } + return "" +} + +// Network returns the address's network name, "unix" or "unixgram". +func (a *UnixAddr) Network() string { + return a.Net +} + +func (a *UnixAddr) String() string { + if a == nil { + return "<nil>" + } + return a.Name +} + +func (a *UnixAddr) toAddr() Addr { + if a == nil { // nil *UnixAddr + return nil // nil interface + } + return a +} + +// ResolveUnixAddr parses addr as a Unix domain socket address. +// The string net gives the network name, "unix", "unixgram" or +// "unixpacket". +func ResolveUnixAddr(net, addr string) (*UnixAddr, os.Error) { + switch net { + case "unix": + case "unixpacket": + case "unixgram": + default: + return nil, UnknownNetworkError(net) + } + return &UnixAddr{addr, net}, nil +} + +// UnixConn is an implementation of the Conn interface +// for connections to Unix domain sockets. +type UnixConn struct { + fd *netFD +} + +func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{fd} } + +func (c *UnixConn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the net.Conn Read method. +func (c *UnixConn) Read(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the net.Conn Write method. +func (c *UnixConn) Write(b []byte) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + return c.fd.Write(b) +} + +// Close closes the Unix domain connection. +func (c *UnixConn) Close() os.Error { + if !c.ok() { + return os.EINVAL + } + err := c.fd.Close() + c.fd = nil + return err +} + +// LocalAddr returns the local network address, a *UnixAddr. +// Unlike in other protocols, LocalAddr is usually nil for dialed connections. +func (c *UnixConn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address, a *UnixAddr. +// Unlike in other protocols, RemoteAddr is usually nil for connections +// accepted by a listener. +func (c *UnixConn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetTimeout implements the net.Conn SetTimeout method. +func (c *UnixConn) SetTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setTimeout(c.fd, nsec) +} + +// SetReadTimeout implements the net.Conn SetReadTimeout method. +func (c *UnixConn) SetReadTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadTimeout(c.fd, nsec) +} + +// SetWriteTimeout implements the net.Conn SetWriteTimeout method. +func (c *UnixConn) SetWriteTimeout(nsec int64) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteTimeout(c.fd, nsec) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *UnixConn) SetReadBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *UnixConn) SetWriteBuffer(bytes int) os.Error { + if !c.ok() { + return os.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// ReadFromUnix reads a packet from c, copying the payload into b. +// It returns the number of bytes copied into b and the return address +// that was on the packet. +// +// ReadFromUnix can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetReadTimeout. +func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, sa, err := c.fd.ReadFrom(b) + switch sa := sa.(type) { + case *syscall.SockaddrUnix: + addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)} + } + return +} + +// ReadFrom implements the net.PacketConn ReadFrom method. +func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err os.Error) { + if !c.ok() { + return 0, nil, os.EINVAL + } + n, uaddr, err := c.ReadFromUnix(b) + return n, uaddr.toAddr(), err +} + +// WriteToUnix writes a packet to addr via c, copying the payload from b. +// +// WriteToUnix can be made to time out and return +// an error with Timeout() == true after a fixed time limit; +// see SetTimeout and SetWriteTimeout. +// On packet-oriented connections, write timeouts are rare. +func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + if addr.Net != protoToNet(c.fd.proto) { + return 0, os.EAFNOSUPPORT + } + sa := &syscall.SockaddrUnix{Name: addr.Name} + return c.fd.WriteTo(b, sa) +} + +// WriteTo implements the net.PacketConn WriteTo method. +func (c *UnixConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { + if !c.ok() { + return 0, os.EINVAL + } + a, ok := addr.(*UnixAddr) + if !ok { + return 0, &OpError{"writeto", "unix", addr, os.EINVAL} + } + return c.WriteToUnix(b, a) +} + +func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err os.Error) { + if !c.ok() { + return 0, 0, 0, nil, os.EINVAL + } + n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) + switch sa := sa.(type) { + case *syscall.SockaddrUnix: + addr = &UnixAddr{sa.Name, protoToNet(c.fd.proto)} + } + return +} + +func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err os.Error) { + if !c.ok() { + return 0, 0, os.EINVAL + } + if addr != nil { + if addr.Net != protoToNet(c.fd.proto) { + return 0, 0, os.EAFNOSUPPORT + } + sa := &syscall.SockaddrUnix{Name: addr.Name} + return c.fd.WriteMsg(b, oob, sa) + } + return c.fd.WriteMsg(b, oob, nil) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *UnixConn) File() (f *os.File, err os.Error) { return c.fd.dup() } + +// DialUnix connects to the remote address raddr on the network net, +// which must be "unix" or "unixgram". If laddr is not nil, it is used +// as the local address for the connection. +func DialUnix(net string, laddr, raddr *UnixAddr) (c *UnixConn, err os.Error) { + fd, e := unixSocket(net, laddr, raddr, "dial") + if e != nil { + return nil, e + } + return newUnixConn(fd), nil +} + +// UnixListener is a Unix domain socket listener. +// Clients should typically use variables of type Listener +// instead of assuming Unix domain sockets. +type UnixListener struct { + fd *netFD + path string +} + +// ListenUnix announces on the Unix domain socket laddr and returns a Unix listener. +// Net must be "unix" (stream sockets). +func ListenUnix(net string, laddr *UnixAddr) (l *UnixListener, err os.Error) { + if net != "unix" && net != "unixgram" && net != "unixpacket" { + return nil, UnknownNetworkError(net) + } + if laddr != nil { + laddr = &UnixAddr{laddr.Name, net} // make our own copy + } + fd, err := unixSocket(net, laddr, nil, "listen") + if err != nil { + return nil, err + } + e1 := syscall.Listen(fd.sysfd, 8) // listenBacklog()); + if e1 != 0 { + closesocket(fd.sysfd) + return nil, &OpError{Op: "listen", Net: "unix", Addr: laddr, Error: os.Errno(e1)} + } + return &UnixListener{fd, laddr.Name}, nil +} + +// AcceptUnix accepts the next incoming call and returns the new connection +// and the remote address. +func (l *UnixListener) AcceptUnix() (c *UnixConn, err os.Error) { + if l == nil || l.fd == nil { + return nil, os.EINVAL + } + fd, e := l.fd.accept(sockaddrToUnix) + if e != nil { + return nil, e + } + c = newUnixConn(fd) + return c, nil +} + +// Accept implements the Accept method in the Listener interface; +// it waits for the next call and returns a generic Conn. +func (l *UnixListener) Accept() (c Conn, err os.Error) { + c1, err := l.AcceptUnix() + if err != nil { + return nil, err + } + return c1, nil +} + +// Close stops listening on the Unix address. +// Already accepted connections are not closed. +func (l *UnixListener) Close() os.Error { + if l == nil || l.fd == nil { + return os.EINVAL + } + + // The operating system doesn't clean up + // the file that announcing created, so + // we have to clean it up ourselves. + // There's a race here--we can't know for + // sure whether someone else has come along + // and replaced our socket name already-- + // but this sequence (remove then close) + // is at least compatible with the auto-remove + // sequence in ListenUnix. It's only non-Go + // programs that can mess us up. + if l.path[0] != '@' { + syscall.Unlink(l.path) + } + err := l.fd.Close() + l.fd = nil + return err +} + +// Addr returns the listener's network address. +func (l *UnixListener) Addr() Addr { return l.fd.laddr } + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (l *UnixListener) File() (f *os.File, err os.Error) { return l.fd.dup() } + +// ListenUnixgram listens for incoming Unix datagram packets addressed to the +// local address laddr. The returned connection c's ReadFrom +// and WriteTo methods can be used to receive and send UDP +// packets with per-packet addressing. The network net must be "unixgram". +func ListenUnixgram(net string, laddr *UnixAddr) (c *UDPConn, err os.Error) { + switch net { + case "unixgram": + default: + return nil, UnknownNetworkError(net) + } + if laddr == nil { + return nil, &OpError{"listen", "unixgram", nil, errMissingAddress} + } + fd, e := unixSocket(net, laddr, nil, "listen") + if e != nil { + return nil, e + } + return newUDPConn(fd), nil +} |