diff options
Diffstat (limited to 'libgo/go/net/dnsclient_unix.go')
-rw-r--r-- | libgo/go/net/dnsclient_unix.go | 225 |
1 files changed, 125 insertions, 100 deletions
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 17188f0024c0..8f2dff46751b 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -16,6 +16,7 @@ package net import ( + "context" "errors" "io" "math/rand" @@ -26,10 +27,10 @@ import ( // A dnsDialer provides dialing suitable for DNS queries. type dnsDialer interface { - dialDNS(string, string) (dnsConn, error) + dialDNS(ctx context.Context, network, addr string) (dnsConn, error) } -var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} } +var testHookDNSDialer = func() dnsDialer { return &Dialer{} } // A dnsConn represents a DNS transport endpoint. type dnsConn interface { @@ -37,46 +38,67 @@ type dnsConn interface { SetDeadline(time.Time) error - // readDNSResponse reads a DNS response message from the DNS - // transport endpoint and returns the received DNS response - // message. - readDNSResponse() (*dnsMsg, error) + // dnsRoundTrip executes a single DNS transaction, returning a + // DNS response message for the provided DNS query message. + dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) +} - // writeDNSQuery writes a DNS query message to the DNS - // connection endpoint. - writeDNSQuery(*dnsMsg) error +func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { + return dnsRoundTripUDP(c, query) } -func (c *UDPConn) readDNSResponse() (*dnsMsg, error) { - b := make([]byte, 512) // see RFC 1035 - n, err := c.Read(b) - if err != nil { +// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's +// "UDP usage" transport mechanism. c should be a packet-oriented connection, +// such as a *UDPConn. +func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { + b, ok := query.Pack() + if !ok { + return nil, errors.New("cannot marshal DNS message") + } + if _, err := c.Write(b); err != nil { return nil, err } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { - return nil, errors.New("cannot unmarshal DNS message") + + b = make([]byte, 512) // see RFC 1035 + for { + n, err := c.Read(b) + if err != nil { + return nil, err + } + resp := &dnsMsg{} + if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) { + // Ignore invalid responses as they may be malicious + // forgery attempts. Instead continue waiting until + // timeout. See golang.org/issue/13281. + continue + } + return resp, nil } - return msg, nil } -func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error { - b, ok := msg.Pack() +func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) { + return dnsRoundTripTCP(c, out) +} + +// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's +// "TCP usage" transport mechanism. c should be a stream-oriented connection, +// such as a *TCPConn. +func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { + b, ok := query.Pack() if !ok { - return errors.New("cannot marshal DNS message") + return nil, errors.New("cannot marshal DNS message") } + l := len(b) + b = append([]byte{byte(l >> 8), byte(l)}, b...) if _, err := c.Write(b); err != nil { - return err + return nil, err } - return nil -} -func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { - b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 + b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 if _, err := io.ReadFull(c, b[:2]); err != nil { return nil, err } - l := int(b[0])<<8 | int(b[1]) + l = int(b[0])<<8 | int(b[1]) if l > len(b) { b = make([]byte, l) } @@ -84,27 +106,17 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { if err != nil { return nil, err } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { + resp := &dnsMsg{} + if !resp.Unpack(b[:n]) { return nil, errors.New("cannot unmarshal DNS message") } - return msg, nil -} - -func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error { - b, ok := msg.Pack() - if !ok { - return errors.New("cannot marshal DNS message") + if !resp.IsResponseTo(query) { + return nil, errors.New("invalid DNS response") } - l := uint16(len(b)) - b = append([]byte{byte(l >> 8), byte(l)}, b...) - if _, err := c.Write(b); err != nil { - return err - } - return nil + return resp, nil } -func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { +func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": default: @@ -115,9 +127,9 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { // call back here to translate it. The DNS config parser has // already checked that all the cfg.servers[i] are IP // addresses, which Dial will use without a DNS lookup. - c, err := d.Dial(network, server) + c, err := d.DialContext(ctx, network, server) if err != nil { - return nil, err + return nil, mapErr(err) } switch network { case "tcp", "tcp4", "tcp6": @@ -129,8 +141,8 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) { } // exchange sends a query on the connection and hopes for a response. -func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { - d := testHookDNSDialer(timeout) +func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, error) { + d := testHookDNSDialer() out := dnsMsg{ dnsMsgHdr: dnsMsgHdr{ recursion_desired: true, @@ -140,24 +152,18 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg }, } for _, network := range []string{"udp", "tcp"} { - c, err := d.dialDNS(network, server) + c, err := d.dialDNS(ctx, network, server) if err != nil { return nil, err } defer c.Close() - if timeout > 0 { - c.SetDeadline(time.Now().Add(timeout)) + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + c.SetDeadline(d) } out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) - if err := c.writeDNSQuery(&out); err != nil { - return nil, err - } - in, err := c.readDNSResponse() + in, err := c.dnsRoundTrip(&out) if err != nil { - return nil, err - } - if in.id != out.id { - return nil, errors.New("DNS message ID mismatch") + return nil, mapErr(err) } if in.truncated { // see RFC 5966 continue @@ -169,16 +175,22 @@ func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { +func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { if len(cfg.servers) == 0 { return "", nil, &DNSError{Err: "no DNS servers", Name: name} } - timeout := time.Duration(cfg.timeout) * time.Second + + deadline := time.Now().Add(cfg.timeout) + if old, ok := ctx.Deadline(); !ok || deadline.Before(old) { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + } + var lastErr error for i := 0; i < cfg.attempts; i++ { for _, server := range cfg.servers { - server = JoinHostPort(server, "53") - msg, err := exchange(server, name, qtype, timeout) + msg, err := exchange(ctx, server, name, qtype) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -190,6 +202,12 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err } continue } + // libresolv continues to the next server when it receives + // an invalid referral response. See golang.org/issue/15434. + if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 { + lastErr = &DNSError{Err: "lame referral", Name: name, Server: server} + continue + } cname, rrs, err := answer(name, server, msg, qtype) // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError, // it means the response in msg was not useful and trying another @@ -229,7 +247,6 @@ type resolverConfig struct { // time to recheck resolv.conf. ch chan struct{} // guards lastChecked and modTime lastChecked time.Time // last time resolv.conf was checked - modTime time.Time // time of resolv.conf modification mu sync.RWMutex // protects dnsConfig dnsConfig *dnsConfig // parsed resolv.conf structure used in lookups @@ -239,16 +256,12 @@ var resolvConf resolverConfig // init initializes conf and is only called via conf.initOnce. func (conf *resolverConfig) init() { - // Set dnsConfig, modTime, and lastChecked so we don't parse + // Set dnsConfig and lastChecked so we don't parse // resolv.conf twice the first time. conf.dnsConfig = systemConf().resolv if conf.dnsConfig == nil { conf.dnsConfig = dnsReadConfig("/etc/resolv.conf") } - - if fi, err := os.Stat("/etc/resolv.conf"); err == nil { - conf.modTime = fi.ModTime() - } conf.lastChecked = time.Now() // Prepare ch so that only one update of resolverConfig may @@ -274,17 +287,12 @@ func (conf *resolverConfig) tryUpdate(name string) { } conf.lastChecked = now + var mtime time.Time if fi, err := os.Stat(name); err == nil { - if fi.ModTime().Equal(conf.modTime) { - return - } - conf.modTime = fi.ModTime() - } else { - // If modTime wasn't set prior, assume nothing has changed. - if conf.modTime.IsZero() { - return - } - conf.modTime = time.Time{} + mtime = fi.ModTime() + } + if mtime.Equal(conf.dnsConfig.mtime) { + return } dnsConf := dnsReadConfig(name) @@ -306,7 +314,7 @@ func (conf *resolverConfig) releaseSema() { <-conf.ch } -func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { +func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { if !isDomainName(name) { return "", nil, &DNSError{Err: "invalid domain name", Name: name} } @@ -315,7 +323,7 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() for _, fqdn := range conf.nameList(name) { - cname, rrs, err = tryOneName(conf, fqdn, qtype) + cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype) if err == nil { break } @@ -329,30 +337,47 @@ func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) { return } +// avoidDNS reports whether this is a hostname for which we should not +// use DNS. Currently this includes only .onion and .local names, +// per RFC 7686 and RFC 6762, respectively. See golang.org/issue/13705. +func avoidDNS(name string) bool { + if name == "" { + return true + } + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return stringsHasSuffixFold(name, ".onion") || stringsHasSuffixFold(name, ".local") +} + // nameList returns a list of names for sequential DNS queries. func (conf *dnsConfig) nameList(name string) []string { + if avoidDNS(name) { + return nil + } + // If name is rooted (trailing dot), try only that name. rooted := len(name) > 0 && name[len(name)-1] == '.' if rooted { return []string{name} } + + hasNdots := count(name, '.') >= conf.ndots + name += "." + // Build list of search choices. names := make([]string, 0, 1+len(conf.search)) // If name has enough dots, try unsuffixed first. - if count(name, '.') >= conf.ndots { - names = append(names, name+".") + if hasNdots { + names = append(names, name) } // Try suffixes. for _, suffix := range conf.search { - suffixed := name + "." + suffix - if suffixed[len(suffixed)-1] != '.' { - suffixed += "." - } - names = append(names, suffixed) + names = append(names, name+suffix) } // Try unsuffixed, if not tried first above. - if count(name, '.') < conf.ndots { - names = append(names, name+".") + if !hasNdots { + names = append(names, name) } return names } @@ -392,11 +417,11 @@ func (o hostLookupOrder) String() string { // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupHost(name string) (addrs []string, err error) { - return goLookupHostOrder(name, hostLookupFilesDNS) +func goLookupHost(ctx context.Context, name string) (addrs []string, err error) { + return goLookupHostOrder(ctx, name, hostLookupFilesDNS) } -func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) { +func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) @@ -404,7 +429,7 @@ func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err return } } - ips, err := goLookupIPOrder(name, order) + ips, err := goLookupIPOrder(ctx, name, order) if err != nil { return } @@ -430,11 +455,11 @@ func goLookupIPFiles(name string) (addrs []IPAddr) { // goLookupIP is the native Go implementation of LookupIP. // The libc versions are in cgo_*.go. -func goLookupIP(name string) (addrs []IPAddr, err error) { - return goLookupIPOrder(name, hostLookupFilesDNS) +func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) { + return goLookupIPOrder(ctx, name, hostLookupFilesDNS) } -func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) { +func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { addrs = goLookupIPFiles(name) if len(addrs) > 0 || order == hostLookupFiles { @@ -459,7 +484,7 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { go func(qtype uint16) { - _, rrs, err := tryOneName(conf, fqdn, qtype) + _, rrs, err := tryOneName(ctx, conf, fqdn, qtype) lane <- racer{fqdn, rrs, err} }(qtype) } @@ -502,8 +527,8 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupCNAME(name string) (cname string, err error) { - _, rrs, err := lookup(name, dnsTypeCNAME) +func goLookupCNAME(ctx context.Context, name string) (cname string, err error) { + _, rrs, err := lookup(ctx, name, dnsTypeCNAME) if err != nil { return } @@ -516,7 +541,7 @@ func goLookupCNAME(name string) (cname string, err error) { // only if cgoLookupPTR is the stub in cgo_stub.go). // Normally we let cgo use the C library resolver instead of depending // on our lookup code, so that Go and C get the same answers. -func goLookupPTR(addr string) ([]string, error) { +func goLookupPTR(ctx context.Context, addr string) ([]string, error) { names := lookupStaticAddr(addr) if len(names) > 0 { return names, nil @@ -525,7 +550,7 @@ func goLookupPTR(addr string) ([]string, error) { if err != nil { return nil, err } - _, rrs, err := lookup(arpa, dnsTypePTR) + _, rrs, err := lookup(ctx, arpa, dnsTypePTR) if err != nil { return nil, err } |