diff --git a/pkg/webserver/webserver.go b/pkg/webserver/webserver.go index 3fd60bc8e..1d183648e 100644 --- a/pkg/webserver/webserver.go +++ b/pkg/webserver/webserver.go @@ -29,6 +29,8 @@ import ( "os" "strings" "time" + + "camlistore.org/third_party/github.com/bradfitz/runsit/listen" ) var Listen = flag.String("listen", "", "host:port to listen on, or :0 to auto-select") @@ -99,23 +101,27 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Listen starts listening on the given host:port string. // If listen is empty the *Listen flag will be used instead. -func (s *Server) Listen(listen string) error { +func (s *Server) Listen(addr string) error { if s.listener != nil { return nil } doLog := os.Getenv("TESTING_PORT_WRITE_FD") == "" // Don't make noise during unit tests - if listen == "" { + if addr == "" { if *Listen == "" { return fmt.Errorf("Cannot start listening: host:port needs to be provided with the -listen flag") } - listen = *Listen + addr = *Listen } - var err error - s.listener, err = net.Listen("tcp", listen) + var laddr listen.Addr + err := laddr.Set(addr) if err != nil { - log.Fatalf("Failed to listen on %s: %v", listen, err) + return err + } + s.listener, err = laddr.Listen() + if err != nil { + log.Fatalf("Failed to listen on %s: %v", addr, err) } base := s.BaseURL() if doLog { diff --git a/third_party/github.com/bradfitz/runsit/listen/listen.go b/third_party/github.com/bradfitz/runsit/listen/listen.go new file mode 100644 index 000000000..c7547b08f --- /dev/null +++ b/third_party/github.com/bradfitz/runsit/listen/listen.go @@ -0,0 +1,125 @@ +package listen + +import ( + "errors" + "flag" + "fmt" + "net" + "os" + "strconv" + "strings" +) + +func NewFlag(flagName, defaultValue string, serverType string) *Addr { + addr := &Addr{ + s: defaultValue, + } + flag.Var(addr, flagName, Usage(serverType)) + return addr +} + +// Usage returns a descriptive usage message for a flag given the name +// of thing being addressed. +func Usage(name string) string { + if name == "" { + name = "Listen address" + } + if !strings.HasSuffix(name, " address") { + name += " address" + } + return name + "; may be port, :port, ip:port, FD:, or ADDR: to use named runsit ports" +} + +// Addr is a flag variable. Use like: +// +// var webPort listen.Addr +// flag.Var(&webPort, "web_addr", listen.Usage("Web server address")) +// flag.Parse() +// webListener, err := webPort.Listen() +type Addr struct { + s string + ln net.Listener + err error +} + +func (a *Addr) String() string { + return a.s +} + +// Set implements the flag.Value interface. +func (a *Addr) Set(v string) error { + a.s = v + + // Try the requested port by runsit port name first. + fd, ok, err := namedPort(v) + if err != nil { + return err + } + if ok { + return a.listenOnFD(fd) + } + + if strings.HasPrefix(v, "FD:") { + fdStr := v[len("FD:"):] + fd, err := strconv.ParseUint(fdStr, 10, 32) + if err != nil { + return fmt.Errorf("invalid file descriptor %q: %v", fdStr, err) + } + return a.listenOnFD(uintptr(fd)) + } + + ipPort := v + if isPort(v) { + ipPort = ":" + v + } + + _, _, err = net.SplitHostPort(ipPort) + if err != nil { + return fmt.Errorf("invalid PORT or IP:PORT %q: %v", v, err) + } + a.ln, err = net.Listen("tcp", ipPort) + return err +} + +func isPort(s string) bool { + _, err := strconv.ParseUint(s, 10, 16) + return err == nil +} + +func (a *Addr) listenOnFD(fd uintptr) (err error) { + f := os.NewFile(fd, fmt.Sprintf("fd #%d from runsit parent", fd)) + a.ln, err = net.FileListener(f) + return +} + +func namedPort(name string) (fd uintptr, ok bool, err error) { + s := os.Getenv("RUNSIT_PORTFD_" + name) + if s == "" { + return + } + u64, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return + } + return uintptr(u64), true, nil +} + +var _ flag.Value = (*Addr)(nil) + +// Listen returns the address's TCP listener. +func (a *Addr) Listen() (net.Listener, error) { + // Start the listener now, if there's a default + // and nothing's called Set yet. + if a.err == nil && a.ln == nil && a.s != "" { + if err := a.Set(a.s); err != nil { + return nil, err + } + } + if a.err != nil { + return nil, a.err + } + if a.ln != nil { + return a.ln, nil + } + return nil, errors.New("listen: no error or listener") +}