diff --git a/pkg/netutil/ident.go b/pkg/netutil/ident.go index 4b3565d9f..da1420127 100644 --- a/pkg/netutil/ident.go +++ b/pkg/netutil/ident.go @@ -19,6 +19,7 @@ package netutil import ( "bufio" "bytes" + "encoding/binary" "fmt" "io" "net" @@ -91,18 +92,32 @@ func AddrPairUserid(lipport, ripport string) (uid int, err error) { return uidFromReader(lip, lport, rip, rport, f) } -func reverseIPBytes(b []byte) []byte { - rb := make([]byte, len(b)) - for i, v := range b { - rb[len(b)-i-1] = v +func toLinuxIPv4Order(b []byte) []byte { + binary.BigEndian.PutUint32(b, binary.LittleEndian.Uint32(b)) + return b +} + +func toLinuxIPv6Order(b []byte) []byte { + for i := 0; i < 16; i += 4 { + sb := b[i : i+4] + binary.BigEndian.PutUint32(sb, binary.LittleEndian.Uint32(sb)) } - return rb + return b +} + +type maybeBrackets net.IP + +func (p maybeBrackets) String() string { + s := net.IP(p).String() + if strings.Contains(s, ":") { + return "[" + s + "]" + } + return s } func uidFromDarwinLsof(lip net.IP, lport int, rip net.IP, rport int) (uid int, err error) { - seek := fmt.Sprintf("%s:%d->%s:%d", lip, lport, rip, rport) + seek := fmt.Sprintf("%s:%d->%s:%d", maybeBrackets(lip), lport, maybeBrackets(rip), rport) seekb := []byte(seek) - cmd := exec.Command("lsof", "-a", "-n", "-i", "-P") stdout, err := cmd.StdoutPipe() if err != nil { @@ -152,18 +167,19 @@ func uidFromReader(lip net.IP, lport int, rip net.IP, rport int, r io.Reader) (u localHex := "" remoteHex := "" - if lip.To4() != nil { + ipv4 := lip.To4() != nil + if ipv4 { // In the kernel, the port is run through ntohs(), and // the inet_request_socket in // include/net/inet_socket.h says the "loc_addr" and // "rmt_addr" fields are __be32, but get_openreq4's // printf of them is raw, without byte order // converstion. - localHex = fmt.Sprintf("%08X:%04X", reverseIPBytes([]byte(lip.To4())), lport) - remoteHex = fmt.Sprintf("%08X:%04X", reverseIPBytes([]byte(rip.To4())), rport) + localHex = fmt.Sprintf("%08X:%04X", toLinuxIPv4Order([]byte(lip.To4())), lport) + remoteHex = fmt.Sprintf("%08X:%04X", toLinuxIPv4Order([]byte(rip.To4())), rport) } else { - localHex = fmt.Sprintf("%032X:%04X", []byte(lip.To16()), lport) - remoteHex = fmt.Sprintf("%032X:%04X", []byte(rip.To16()), rport) + localHex = fmt.Sprintf("%032X:%04X", toLinuxIPv6Order([]byte(lip.To16())), lport) + remoteHex = fmt.Sprintf("%032X:%04X", toLinuxIPv6Order([]byte(rip.To16())), rport) } for { diff --git a/pkg/netutil/ident_test.go b/pkg/netutil/ident_test.go index 109582c62..6a8fdb028 100644 --- a/pkg/netutil/ident_test.go +++ b/pkg/netutil/ident_test.go @@ -17,8 +17,12 @@ limitations under the License. package netutil import ( + "fmt" + "io/ioutil" "log" "net" + "net/http" + "net/http/httptest" "os" "strings" "testing" @@ -33,6 +37,19 @@ func TestLocalIPv4(t *testing.T) { if err != nil { t.Fatal(err) } + testLocalListener(t, ln) +} + +func TestLocalIPv6(t *testing.T) { + ln, err := net.Listen("tcp", "[::1]:0") + if err != nil { + t.Logf("skipping IPv6 test; not supported on host machine?") + return + } + testLocalListener(t, ln) +} + +func testLocalListener(t *testing.T, ln net.Listener) { defer ln.Close() // Accept a connection, run ConnUserId (what we're testing), and @@ -77,7 +94,31 @@ func TestLocalIPv4(t *testing.T) { } } -// TODO: test IPv6. probably not working. +func TestHTTPAuth(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + from := r.RemoteAddr + to := ts.Listener.Addr().String() + uid, err := AddrPairUserid(from, to) + if err != nil { + fmt.Fprintf(rw, "ERR: %v", err) + return + } + fmt.Fprintf(rw, "uid=%d", uid) + })) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), fmt.Sprintf("uid=%d", os.Getuid()); g != e { + t.Errorf("got body %q; want %q", g, e) + } +} func TestParseLinuxTCPStat4(t *testing.T) { lip, lport := net.ParseIP("67.218.110.129"), 43436