diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index ea576ecf7..462091bd1 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -281,11 +281,11 @@ func (da *DevAuth) AddAuthHeader(req *http.Request) { func localhostAuthorized(req *http.Request) bool { uid := os.Getuid() - from, err := netutil.HostPortToIP(req.RemoteAddr) + from, err := netutil.HostPortToIP(req.RemoteAddr, nil) if err != nil { return false } - to, err := netutil.HostPortToIP(req.Host) + to, err := netutil.HostPortToIP(req.Host, from) if err != nil { return false } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 994b8f1a8..80ffeabe4 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -18,6 +18,10 @@ package auth import ( "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" "reflect" "testing" ) @@ -51,3 +55,78 @@ func TestFromConfig(t *testing.T) { } } } + +func testServer(t *testing.T, l net.Listener) *httptest.Server { + ts := &httptest.Server{ + Listener: l, + Config: &http.Server{ + Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if localhostAuthorized(r) { + fmt.Fprintf(rw, "authorized") + return + } + fmt.Fprintf(rw, "unauthorized") + }), + }, + } + ts.Start() + + return ts +} + +func TestLocalhostAuthIPv6(t *testing.T) { + l, err := net.Listen("tcp", "[::1]:0") + if err != nil { + t.Skip("skipping IPv6 test; can't listen on [::1]:0") + } + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + ts := testServer(t, l) + defer ts.Close() + + // Use an explicit transport to force IPv6 (http.Get resolves localhost in IPv4 otherwise) + trans := &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { + c, err := net.Dial("tcp6", addr) + return c, err + }, + } + + testLoginRequest(t, &http.Client{Transport: trans}, "http://[::1]:"+port) + testLoginRequest(t, &http.Client{Transport: trans}, "http://localhost:"+port) +} + +func TestLocalhostAuthIPv4(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Skip("skipping IPv4 test; can't listen on 127.0.0.1:0") + } + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + ts := testServer(t, l) + defer ts.Close() + + testLoginRequest(t, &http.Client{}, "http://127.0.0.1:"+port) + testLoginRequest(t, &http.Client{}, "http://localhost:"+port) +} + +func testLoginRequest(t *testing.T, client *http.Client, URL string) { + res, err := client.Get(URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + const exp = "authorized" + if string(body) != exp { + t.Errorf("got %q (instead of %v)", string(body), exp) + } +} diff --git a/pkg/netutil/ident.go b/pkg/netutil/ident.go index f55fffeba..d144bd4d8 100644 --- a/pkg/netutil/ident.go +++ b/pkg/netutil/ident.go @@ -46,9 +46,9 @@ func ConnUserid(conn net.Conn) (uid int, err error) { return AddrPairUserid(conn.LocalAddr(), conn.RemoteAddr()) } -// HostPortToIP parses a host:port to a TCPAddr without resolving names -// other than localhost. It will return an error instead of resolving. -func HostPortToIP(hostport string) (hostaddr *net.TCPAddr, err error) { +// HostPortToIP parses a host:port to a TCPAddr without resolving names. +// If given a context IP, it will resolve localhost to match the context's IP family. +func HostPortToIP(hostport string, ctx *net.TCPAddr) (hostaddr *net.TCPAddr, err error) { host, port, err := net.SplitHostPort(hostport) if err != nil { return nil, err @@ -58,8 +58,12 @@ func HostPortToIP(hostport string) (hostaddr *net.TCPAddr, err error) { return nil, fmt.Errorf("invalid port %s", iport) } var addr net.IP - if host == "localhost" { - addr = net.IPv4(127, 0, 0, 1) + if ctx != nil && host == "localhost" { + if ctx.IP.To4() != nil { + addr = net.IPv4(127, 0, 0, 1) + } else { + addr = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + } } else if addr = net.ParseIP(host); addr == nil { return nil, fmt.Errorf("could not parse IP %s", host) } diff --git a/pkg/netutil/ident_test.go b/pkg/netutil/ident_test.go index da1cfb4b1..a28f210bc 100644 --- a/pkg/netutil/ident_test.go +++ b/pkg/netutil/ident_test.go @@ -98,7 +98,7 @@ func testLocalListener(t *testing.T, ln net.Listener) { func TestHTTPAuth(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - from, err := HostPortToIP(r.RemoteAddr) + from, err := HostPortToIP(r.RemoteAddr, nil) if err != nil { t.Fatal(err) }