diff --git a/pkg/netutil/ident.go b/pkg/netutil/ident.go index 1776b25d8..ba3f5786c 100644 --- a/pkg/netutil/ident.go +++ b/pkg/netutil/ident.go @@ -273,16 +273,3 @@ func uidFromProcReader(lip net.IP, lport int, rip net.IP, rport int, r io.Reader } panic("unreachable") } - -// Localhost returns the first address found when -// doing a lookup of "localhost". -func Localhost() (net.IP, error) { - ips, err := net.LookupIP("localhost") - if err != nil { - return nil, err - } - if len(ips) < 1 { - return nil, errors.New("IP lookup for localhost returned no result") - } - return ips[0], nil -} diff --git a/pkg/netutil/netutil.go b/pkg/netutil/netutil.go index 5d517e078..2a341c270 100644 --- a/pkg/netutil/netutil.go +++ b/pkg/netutil/netutil.go @@ -17,6 +17,7 @@ limitations under the License. package netutil import ( + "errors" "fmt" "net" "net/url" @@ -67,3 +68,62 @@ func HostPort(urlStr string) (string, error) { } return hostPort, nil } + +// ListenOnLocalRandomPort returns a tcp listener on a local (see LoopbackIP) random port. +func ListenOnLocalRandomPort() (net.Listener, error) { + ip, err := Localhost() + if err != nil { + return nil, err + } + l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: ip, Port: 0}) + if err != nil { + return nil, err + } + return l, nil +} + +// Localhost returns the first address found when +// doing a lookup of "localhost". If not successful, +// it looks for an ip on the loopback interfaces. +func Localhost() (net.IP, error) { + if ip := localhostLookup(); ip != nil { + return ip, nil + } + if ip := loopbackIP(); ip != nil { + return ip, nil + } + return nil, errors.New("No loopback ip found.") +} + +// localhostLookup looks for a loopback IP by resolving localhost. +func localhostLookup() net.IP { + if ips, err := net.LookupIP("localhost"); err == nil && len(ips) > 0 { + return ips[0] + } + return nil +} + +const flagUpLoopback = net.FlagUp | net.FlagLoopback + +// loopbackIP finds the first loopback IP address sniffing network interfaces. +func loopbackIP() net.IP { + interfaces, err := net.Interfaces() + if err != nil { + return nil + } + for _, inf := range interfaces { + if inf.Flags&flagUpLoopback == flagUpLoopback { + addrs, err := inf.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + ip, _, err := net.ParseCIDR(addr.String()) + if err == nil && ip.IsLoopback() { + return ip + } + } + } + } + return nil +} diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go index a5cd5bd3c..644d7eb6b 100644 --- a/pkg/netutil/netutil_test.go +++ b/pkg/netutil/netutil_test.go @@ -17,6 +17,8 @@ limitations under the License. package netutil import ( + "net" + "strconv" "testing" ) @@ -120,3 +122,60 @@ func TestHostPort(t *testing.T) { } } } + +func testLocalhostResolver(t *testing.T, resolve func() net.IP) { + ip := resolve() + if ip == nil { + t.Fatal("no ip found.") + } + if !ip.IsLoopback() { + t.Errorf("expected a loopback address: %s", ip) + } +} + +func testLocalhost(t *testing.T) { + testLocalhostResolver(t, localhostLookup) +} + +func testLoopbackIp(t *testing.T) { + testLocalhostResolver(t, loopbackIP) +} + +func TestLocalhost(t *testing.T) { + _, err := Localhost() + if err != nil { + t.Fatal(err) + } +} + +func TestListenOnLocalRandomPort(t *testing.T) { + l, err := ListenOnLocalRandomPort() + if err != nil { + t.Fatalf("unexpected error %v", err) + } + defer l.Close() + + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + if p, _ := strconv.Atoi(port); p < 1 { + t.Fatalf("expected port(%d) to be > 0", p) + } +} + +func BenchmarkLocalhostLookup(b *testing.B) { + for i := 0; i < b.N; i++ { + if ip := localhostLookup(); ip == nil { + b.Fatal("no ip found.") + } + } +} + +func BenchmarkLoopbackIP(b *testing.B) { + for i := 0; i < b.N; i++ { + if ip := loopbackIP(); ip == nil { + b.Fatal("no ip found.") + } + } +}