localhost auth: resolve localhost to [::1] if using ipv6

http://camlistore.org/issue/238

Change-Id: Icab7d87fe651365fb44db4c2874d4976fa631ad6
This commit is contained in:
Salman Aljammaz 2013-10-19 15:09:48 +01:00
parent dfed62b76a
commit 10d67c6d20
4 changed files with 91 additions and 8 deletions

View File

@ -281,11 +281,11 @@ func (da *DevAuth) AddAuthHeader(req *http.Request) {
func localhostAuthorized(req *http.Request) bool {
uid := os.Getuid()
from, err := netutil.HostPortToIP(req.RemoteAddr)
from, err := netutil.HostPortToIP(req.RemoteAddr, nil)
if err != nil {
return false
}
to, err := netutil.HostPortToIP(req.Host)
to, err := netutil.HostPortToIP(req.Host, from)
if err != nil {
return false
}

View File

@ -18,6 +18,10 @@ package auth
import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
@ -51,3 +55,78 @@ func TestFromConfig(t *testing.T) {
}
}
}
func testServer(t *testing.T, l net.Listener) *httptest.Server {
ts := &httptest.Server{
Listener: l,
Config: &http.Server{
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if localhostAuthorized(r) {
fmt.Fprintf(rw, "authorized")
return
}
fmt.Fprintf(rw, "unauthorized")
}),
},
}
ts.Start()
return ts
}
func TestLocalhostAuthIPv6(t *testing.T) {
l, err := net.Listen("tcp", "[::1]:0")
if err != nil {
t.Skip("skipping IPv6 test; can't listen on [::1]:0")
}
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
ts := testServer(t, l)
defer ts.Close()
// Use an explicit transport to force IPv6 (http.Get resolves localhost in IPv4 otherwise)
trans := &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
c, err := net.Dial("tcp6", addr)
return c, err
},
}
testLoginRequest(t, &http.Client{Transport: trans}, "http://[::1]:"+port)
testLoginRequest(t, &http.Client{Transport: trans}, "http://localhost:"+port)
}
func TestLocalhostAuthIPv4(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Skip("skipping IPv4 test; can't listen on 127.0.0.1:0")
}
_, port, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
ts := testServer(t, l)
defer ts.Close()
testLoginRequest(t, &http.Client{}, "http://127.0.0.1:"+port)
testLoginRequest(t, &http.Client{}, "http://localhost:"+port)
}
func testLoginRequest(t *testing.T, client *http.Client, URL string) {
res, err := client.Get(URL)
if err != nil {
t.Fatal(err)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
const exp = "authorized"
if string(body) != exp {
t.Errorf("got %q (instead of %v)", string(body), exp)
}
}

View File

@ -46,9 +46,9 @@ func ConnUserid(conn net.Conn) (uid int, err error) {
return AddrPairUserid(conn.LocalAddr(), conn.RemoteAddr())
}
// HostPortToIP parses a host:port to a TCPAddr without resolving names
// other than localhost. It will return an error instead of resolving.
func HostPortToIP(hostport string) (hostaddr *net.TCPAddr, err error) {
// HostPortToIP parses a host:port to a TCPAddr without resolving names.
// If given a context IP, it will resolve localhost to match the context's IP family.
func HostPortToIP(hostport string, ctx *net.TCPAddr) (hostaddr *net.TCPAddr, err error) {
host, port, err := net.SplitHostPort(hostport)
if err != nil {
return nil, err
@ -58,8 +58,12 @@ func HostPortToIP(hostport string) (hostaddr *net.TCPAddr, err error) {
return nil, fmt.Errorf("invalid port %s", iport)
}
var addr net.IP
if host == "localhost" {
addr = net.IPv4(127, 0, 0, 1)
if ctx != nil && host == "localhost" {
if ctx.IP.To4() != nil {
addr = net.IPv4(127, 0, 0, 1)
} else {
addr = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
}
} else if addr = net.ParseIP(host); addr == nil {
return nil, fmt.Errorf("could not parse IP %s", host)
}

View File

@ -98,7 +98,7 @@ func testLocalListener(t *testing.T, ln net.Listener) {
func TestHTTPAuth(t *testing.T) {
var ts *httptest.Server
ts = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
from, err := HostPortToIP(r.RemoteAddr)
from, err := HostPortToIP(r.RemoteAddr, nil)
if err != nil {
t.Fatal(err)
}