mirror of https://github.com/perkeep/perkeep.git
localhost auth: resolve localhost to [::1] if using ipv6
http://camlistore.org/issue/238 Change-Id: Icab7d87fe651365fb44db4c2874d4976fa631ad6
This commit is contained in:
parent
dfed62b76a
commit
10d67c6d20
|
@ -281,11 +281,11 @@ func (da *DevAuth) AddAuthHeader(req *http.Request) {
|
|||
|
||||
func localhostAuthorized(req *http.Request) bool {
|
||||
uid := os.Getuid()
|
||||
from, err := netutil.HostPortToIP(req.RemoteAddr)
|
||||
from, err := netutil.HostPortToIP(req.RemoteAddr, nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
to, err := netutil.HostPortToIP(req.Host)
|
||||
to, err := netutil.HostPortToIP(req.Host, from)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -18,6 +18,10 @@ package auth
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
@ -51,3 +55,78 @@ func TestFromConfig(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testServer(t *testing.T, l net.Listener) *httptest.Server {
|
||||
ts := &httptest.Server{
|
||||
Listener: l,
|
||||
Config: &http.Server{
|
||||
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if localhostAuthorized(r) {
|
||||
fmt.Fprintf(rw, "authorized")
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(rw, "unauthorized")
|
||||
}),
|
||||
},
|
||||
}
|
||||
ts.Start()
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
func TestLocalhostAuthIPv6(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "[::1]:0")
|
||||
if err != nil {
|
||||
t.Skip("skipping IPv6 test; can't listen on [::1]:0")
|
||||
}
|
||||
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts := testServer(t, l)
|
||||
defer ts.Close()
|
||||
|
||||
// Use an explicit transport to force IPv6 (http.Get resolves localhost in IPv4 otherwise)
|
||||
trans := &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
c, err := net.Dial("tcp6", addr)
|
||||
return c, err
|
||||
},
|
||||
}
|
||||
|
||||
testLoginRequest(t, &http.Client{Transport: trans}, "http://[::1]:"+port)
|
||||
testLoginRequest(t, &http.Client{Transport: trans}, "http://localhost:"+port)
|
||||
}
|
||||
|
||||
func TestLocalhostAuthIPv4(t *testing.T) {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Skip("skipping IPv4 test; can't listen on 127.0.0.1:0")
|
||||
}
|
||||
_, port, err := net.SplitHostPort(l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ts := testServer(t, l)
|
||||
defer ts.Close()
|
||||
|
||||
testLoginRequest(t, &http.Client{}, "http://127.0.0.1:"+port)
|
||||
testLoginRequest(t, &http.Client{}, "http://localhost:"+port)
|
||||
}
|
||||
|
||||
func testLoginRequest(t *testing.T, client *http.Client, URL string) {
|
||||
res, err := client.Get(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
const exp = "authorized"
|
||||
if string(body) != exp {
|
||||
t.Errorf("got %q (instead of %v)", string(body), exp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,9 +46,9 @@ func ConnUserid(conn net.Conn) (uid int, err error) {
|
|||
return AddrPairUserid(conn.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// HostPortToIP parses a host:port to a TCPAddr without resolving names
|
||||
// other than localhost. It will return an error instead of resolving.
|
||||
func HostPortToIP(hostport string) (hostaddr *net.TCPAddr, err error) {
|
||||
// HostPortToIP parses a host:port to a TCPAddr without resolving names.
|
||||
// If given a context IP, it will resolve localhost to match the context's IP family.
|
||||
func HostPortToIP(hostport string, ctx *net.TCPAddr) (hostaddr *net.TCPAddr, err error) {
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -58,8 +58,12 @@ func HostPortToIP(hostport string) (hostaddr *net.TCPAddr, err error) {
|
|||
return nil, fmt.Errorf("invalid port %s", iport)
|
||||
}
|
||||
var addr net.IP
|
||||
if host == "localhost" {
|
||||
addr = net.IPv4(127, 0, 0, 1)
|
||||
if ctx != nil && host == "localhost" {
|
||||
if ctx.IP.To4() != nil {
|
||||
addr = net.IPv4(127, 0, 0, 1)
|
||||
} else {
|
||||
addr = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||
}
|
||||
} else if addr = net.ParseIP(host); addr == nil {
|
||||
return nil, fmt.Errorf("could not parse IP %s", host)
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ func testLocalListener(t *testing.T, ln net.Listener) {
|
|||
func TestHTTPAuth(t *testing.T) {
|
||||
var ts *httptest.Server
|
||||
ts = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
from, err := HostPortToIP(r.RemoteAddr)
|
||||
from, err := HostPortToIP(r.RemoteAddr, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue