vendor: Update github.com/gorilla/websocket to v1.2.0

Maps to rev: gorilla/websocket@ea4d1f681b

Also adjust the search websocket code to use the new API
which does proper origin checks.

Changelog: https://github.com/gorilla/websocket/releases/tag/v1.2.0

Change-Id: Ie25490b6319a98345aa6272e9eabc35c298bc06d
This commit is contained in:
Paul Lindner 2018-01-11 09:12:32 -08:00
parent 020c6923ce
commit 837fe8ac46
31 changed files with 3471 additions and 516 deletions

5
Gopkg.lock generated
View File

@ -117,7 +117,8 @@
[[projects]]
name = "github.com/gorilla/websocket"
packages = ["."]
revision = "2119675aadf"
revision = "ea4d1f681babbce9545c9c5f3d5194a789c89f5b"
version = "v1.2.0"
[[projects]]
name = "github.com/hjfreyer/taglib-go"
@ -345,6 +346,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "df510f0b314e2cffb83b9b69ad0cd2cfa06b3cf97f575ba11f0a35b1152be7ca"
inputs-digest = "36cd029d45eecfdd9b745610c39865be9efc9aa37f6331570b798e9c8ed649d9"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -131,7 +131,7 @@ required = [
[[constraint]]
name = "github.com/gorilla/websocket"
revision = "2119675aadf"
version = "1.2.0"
[[constraint]]
name = "github.com/hjfreyer/taglib-go"

View File

@ -291,8 +291,15 @@ func (c *wsConn) writePump() {
}
}
// upgrader is used in serveWebSocket to construct websocket connections.
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// uses a default origin check policy
}
func (sh *Handler) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
ws, err := websocket.Upgrade(rw, req, nil, 1024, 1024)
ws, err := upgrader.Upgrade(rw, req, nil)
if _, ok := err.(websocket.HandshakeError); ok {
http.Error(rw, "Not a websocket handshake", 400)
return

8
vendor/github.com/gorilla/websocket/AUTHORS generated vendored Normal file
View File

@ -0,0 +1,8 @@
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Joachim Bauch <mail@joachim-bauch.de>

View File

@ -1,23 +1,22 @@
Copyright (c) 2013, Gorilla web toolkit
All rights reserved.
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,18 +1,24 @@
# Gorilla WebSocket
# Gorilla WebSocket
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket)
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
### Documentation
* [Reference](http://godoc.org/github.com/gorilla/websocket)
* [API Reference](http://godoc.org/github.com/gorilla/websocket)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
### Status
The Gorilla WebSocket package provides a complete and tested implementation of
the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The
package API is stable.
package API is stable.
### Installation
@ -20,7 +26,7 @@ package API is stable.
### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn WebSockets Test
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
@ -29,21 +35,30 @@ subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">gorilla</a></th>
<th><a href="http://godoc.org/code.google.com/p/go.net/websocket">go.net</a></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td>Protocol support</td><td>RFC 6455</td><td>RFC 6455</td></tr>
<tr><td>Limit size of received message</td><td>Yes</td><td>No</td></tr>
<tr><td>Send pings and receive pongs</td><td>Yes</td><td>No</td></tr>
<tr><td>Send close message</td><td>Yes</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td>Yes</td><td>No, see note</td></tr>
<tr><td>Write message using io.WriteCloser</td><td>Yes</td><td>No, see note</td></tr>
<tr><td>Encode, decode JSON message</td><td>Yes</td><td>Yes</td></tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="http://autobahn.ws/testsuite/">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Note: The go.net io.Reader and io.Writer operate across WebSocket message
boundaries. Read returns when the input buffer is full or a message boundary is
encountered, Each call to Write sends a message. The Gorilla io.Reader and
io.WriteCloser operate on a single WebSocket message.
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

View File

@ -1,69 +1,392 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
// NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Set-WebSocket-Protocol) and cookies
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
// (Cookie). Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
//
// Deprecated: Use Dialer instead.
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
d := Dialer{
ReadBufferSize: readBufSize,
WriteBufferSize: writeBufSize,
NetDial: func(net, addr string) (net.Conn, error) {
return netConn, nil
},
}
return d.Dial(u.String(), requestHeader)
}
// A Dialer contains options for connecting to WebSocket server.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar
}
var errMalformedURL = errors.New("malformed ws or wss URL")
// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL
switch {
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return nil, errMalformedURL
}
if i := strings.Index(s, "?"); i >= 0 {
u.RawQuery = s[i+1:]
s = s[:i]
}
if i := strings.Index(s, "/"); i >= 0 {
u.Opaque = s[i:]
s = s[:i]
} else {
u.Opaque = "/"
}
u.Host = s
if strings.Contains(u.Host, "@") {
// Don't bother parsing user information because user information is
// not allowed in websocket URIs.
return nil, errMalformedURL
}
return &u, nil
}
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
switch u.Scheme {
case "wss":
hostPort += ":443"
case "https":
hostPort += ":443"
default:
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
// DefaultDialer is a dialer with all fields set to the default zero values.
var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
// Dial creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
d = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
challengeKey, err := generateChallengeKey()
if err != nil {
return nil, nil, err
}
acceptKey := computeAcceptKey(challengeKey)
c = newConn(netConn, false, readBufSize, writeBufSize)
p := c.writeBuf[:0]
p = append(p, "GET "...)
p = append(p, u.RequestURI()...)
p = append(p, " HTTP/1.1\r\nHost: "...)
p = append(p, u.Host...)
p = append(p, "\r\nUpgrade: websocket\r\nConnection: upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...)
p = append(p, challengeKey...)
p = append(p, "\r\n"...)
for k, vs := range requestHeader {
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
p = append(p, v...)
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
if _, err := netConn.Write(p); err != nil {
return nil, nil, err
}
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
u, err := parseURL(urlStr)
if err != nil {
return nil, nil, err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return nil, nil, errMalformedURL
}
if u.User != nil {
// User name and password are not allowed in websocket URIs.
return nil, nil, errMalformedURL
}
req := &http.Request{
Method: "GET",
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: u.Host,
}
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
for _, cookie := range d.Jar.Cookies(u) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(d.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
}
for k, vs := range requestHeader {
switch {
case k == "Host":
if len(vs) > 0 {
req.Host = vs[0]
}
case k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
default:
req.Header[k] = vs
}
}
if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
hostPort, hostNoPort := hostPortNoPort(u)
var proxyURL *url.URL
// Check wether the proxy method has been configured
if d.Proxy != nil {
proxyURL, err = d.Proxy(req)
}
if err != nil {
return nil, nil, err
}
var targetHostPort string
if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL)
} else {
targetHostPort = hostPort
}
var deadline time.Time
if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout)
}
netDial := d.NetDial
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline}
netDial = netDialer.Dial
}
netConn, err := netDial("tcp", targetHostPort)
if err != nil {
return nil, nil, err
}
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if err := netConn.SetDeadline(deadline); err != nil {
return nil, nil, err
}
if proxyURL != nil {
connectHeader := make(http.Header)
if user := proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return nil, nil, err
}
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
return nil, nil, err
}
if d.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
d.Jar.SetCookies(u, rc)
}
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}
return c, resp, nil
for _, ext := range parseExtensions(resp.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, errInvalidCompression
}
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
break
}
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}

16
vendor/github.com/gorilla/websocket/client_clone.go generated vendored Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View File

@ -0,0 +1,38 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -1,86 +1,421 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
package websocket
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
)
type wsHandler struct {
*testing.T
var cstUpgrader = Upgrader{
Subprotocols: []string{"p0", "p1"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
EnableCompression: true,
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
http.Error(w, reason.Error(), status)
},
}
func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
t.Logf("bad method: %s", r.Method)
var cstDialer = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
type cstHandler struct{ *testing.T }
type cstServer struct {
*httptest.Server
URL string
}
const (
cstPath = "/a/b"
cstRawQuery = "x=y"
cstRequestURI = cstPath + "?" + cstRawQuery
)
func newServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewServer(cstHandler{t})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}
func newTLSServer(t *testing.T) *cstServer {
var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t})
s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL)
return &s
}
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", 400)
return
}
if r.Header.Get("Origin") != "http://"+r.Host {
http.Error(w, "Origin not allowed", 403)
t.Logf("bad origin: %s", r.Header.Get("Origin"))
if r.URL.RawQuery != cstRawQuery {
t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
http.Error(w, "bad path", 400)
return
}
ws, err := websocket.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}, 1024, 1024)
if _, ok := err.(websocket.HandshakeError); ok {
t.Logf("bad handshake: %v", err)
http.Error(w, "Not a websocket handshake", 400)
subprotos := Subprotocols(r)
if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
http.Error(w, "bad protocol", 400)
return
} else if err != nil {
t.Logf("upgrade error: %v", err)
}
ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
if err != nil {
t.Logf("Upgrade: %v", err)
return
}
defer ws.Close()
for {
op, r, err := ws.NextReader()
if err != nil {
if err != io.EOF {
t.Logf("NextReader: %v", err)
if ws.Subprotocol() != "p1" {
t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
ws.Close()
return
}
op, rd, err := ws.NextReader()
if err != nil {
t.Logf("NextReader: %v", err)
return
}
wr, err := ws.NextWriter(op)
if err != nil {
t.Logf("NextWriter: %v", err)
return
}
if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return
}
if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
}
func makeWsProto(s string) string {
return "ws" + strings.TrimPrefix(s, "http")
}
func sendRecv(t *testing.T, ws *Conn) {
const message = "Hello World!"
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
t.Fatalf("WriteMessage: %v", err)
}
if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetReadDeadline: %v", err)
}
_, p, err := ws.ReadMessage()
if err != nil {
t.Fatalf("ReadMessage: %v", err)
}
if string(p) != message {
t.Fatalf("message=%s, want %s", p, message)
}
}
func TestProxyDial(t *testing.T) {
s := newServer(t)
defer s.Close()
surl, _ := url.Parse(s.URL)
cstDialer.Proxy = http.ProxyURL(surl)
connect := false
origHandler := s.Server.Config.Handler
// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
connect = true
w.WriteHeader(200)
return
}
return
if !connect {
t.Log("connect not recieved")
http.Error(w, "connect not recieved", 405)
return
}
origHandler.ServeHTTP(w, r)
})
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
cstDialer.Proxy = http.ProxyFromEnvironment
}
func TestProxyAuthorizationDial(t *testing.T) {
s := newServer(t)
defer s.Close()
surl, _ := url.Parse(s.URL)
surl.User = url.UserPassword("username", "password")
cstDialer.Proxy = http.ProxyURL(surl)
connect := false
origHandler := s.Server.Config.Handler
// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
proxyAuth := r.Header.Get("Proxy-Authorization")
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
connect = true
w.WriteHeader(200)
return
}
if !connect {
t.Log("connect with proxy authorization not recieved")
http.Error(w, "connect with proxy authorization not recieved", 405)
return
}
origHandler.ServeHTTP(w, r)
})
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
cstDialer.Proxy = http.ProxyFromEnvironment
}
func TestDial(t *testing.T) {
s := newServer(t)
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}
func TestDialCookieJar(t *testing.T) {
s := newServer(t)
defer s.Close()
jar, _ := cookiejar.New(nil)
d := cstDialer
d.Jar = jar
u, _ := parseURL(s.URL)
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
}
cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}}
d.Jar.SetCookies(u, cookies)
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
var gorilla string
var sessionID string
for _, c := range d.Jar.Cookies(u) {
if c.Name == "gorilla" {
gorilla = c.Value
}
if op == websocket.PongMessage {
continue
if c.Name == "sessionID" {
sessionID = c.Value
}
w, err := ws.NextWriter(op)
}
if gorilla != "ws" {
t.Error("Cookie not present in jar.")
}
if sessionID != "1234" {
t.Error("Set-Cookie not received from the server.")
}
sendRecv(t, ws)
}
func TestDialTLS(t *testing.T) {
s := newTLSServer(t)
defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Logf("NextWriter: %v", err)
return
t.Fatalf("error parsing server's root cert: %v", err)
}
if _, err = io.Copy(w, r); err != nil {
t.Logf("Copy: %v", err)
return
for _, root := range roots {
certs.AddCert(root)
}
if err := w.Close(); err != nil {
t.Logf("Close: %v", err)
return
}
d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}
func xTestDialTLSBadCert(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t)
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
}
func TestDialTLSNoVerify(t *testing.T) {
s := newTLSServer(t)
defer s.Close()
d := cstDialer
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}
func TestDialTimeout(t *testing.T) {
s := newServer(t)
defer s.Close()
d := cstDialer
d.HandshakeTimeout = -1
ws, _, err := d.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
}
func TestDialBadScheme(t *testing.T) {
s := newServer(t)
defer s.Close()
ws, _, err := cstDialer.Dial(s.Server.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
}
func TestDialBadOrigin(t *testing.T) {
s := newServer(t)
defer s.Close()
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
if resp == nil {
t.Fatalf("resp=nil, err=%v", err)
}
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
}
}
func TestDialBadHeader(t *testing.T) {
s := newServer(t)
defer s.Close()
for _, k := range []string{"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Protocol"} {
h := http.Header{}
h.Set(k, "bad")
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
if err == nil {
ws.Close()
t.Errorf("Dial with header %s returned nil", k)
}
}
}
func TestClientServer(t *testing.T) {
s := httptest.NewServer(wsHandler{t})
func TestBadMethod(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := cstUpgrader.Upgrade(w, r, nil)
if err == nil {
t.Errorf("handshake succeeded, expect fail")
ws.Close()
}
}))
defer s.Close()
u, _ := url.Parse(s.URL)
c, err := net.Dial("tcp", u.Host)
resp, err := http.PostForm(s.URL, url.Values{})
if err != nil {
t.Fatalf("PostForm returned error %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
}
}
func TestHandshake(t *testing.T) {
s := newServer(t)
defer s.Close()
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
if err != nil {
t.Fatalf("Dial: %v", err)
}
ws, resp, err := websocket.NewClient(c, u, http.Header{"Origin": {s.URL}}, 1024, 1024)
if err != nil {
t.Fatalf("NewClient: %v", err)
}
defer ws.Close()
var sessionID string
@ -93,22 +428,85 @@ func TestClientServer(t *testing.T) {
t.Error("Set-Cookie not received from the server.")
}
w, _ := ws.NextWriter(websocket.TextMessage)
io.WriteString(w, "HELLO")
w.Close()
ws.SetReadDeadline(time.Now().Add(1 * time.Second))
op, r, err := ws.NextReader()
if ws.Subprotocol() != "p1" {
t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
}
sendRecv(t, ws)
}
func TestRespOnBadHandshake(t *testing.T) {
const expectedStatus = http.StatusGone
const expectedBody = "This is the response body."
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus)
io.WriteString(w, expectedBody)
}))
defer s.Close()
ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
if resp == nil {
t.Fatalf("resp=nil, err=%v", err)
}
if resp.StatusCode != expectedStatus {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
}
p, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("NextReader: %v", err)
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
}
if op != websocket.TextMessage {
t.Fatalf("op=%d, want %d", op, websocket.TextMessage)
}
b, err := ioutil.ReadAll(r)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if string(b) != "HELLO" {
t.Fatalf("message=%s, want %s", b, "HELLO")
if string(p) != expectedBody {
t.Errorf("resp.Body=%s, want %s", p, expectedBody)
}
}
// TestHostHeader confirms that the host header provided in the call to Dial is
// sent to the server.
func TestHostHeader(t *testing.T) {
s := newServer(t)
defer s.Close()
specifiedHost := make(chan string, 1)
origHandler := s.Server.Config.Handler
// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
specifiedHost <- r.Host
origHandler.ServeHTTP(w, r)
})
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
if gotHost := <-specifiedHost; gotHost != "testhost" {
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
}
sendRecv(t, ws)
}
func TestDialCompression(t *testing.T) {
s := newServer(t)
defer s.Close()
dialer := cstDialer
dialer.EnableCompression = true
ws, _, err := dialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

72
vendor/github.com/gorilla/websocket/client_test.go generated vendored Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"net/url"
"reflect"
"testing"
)
var parseURLTests = []struct {
s string
u *url.URL
rui string
}{
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"},
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"},
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"},
{"ss://example.com/a/b", nil, ""},
{"ws://webmaster@example.com/", nil, ""},
{"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"},
{"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"},
}
func TestParseURL(t *testing.T) {
for _, tt := range parseURLTests {
u, err := parseURL(tt.s)
if tt.u != nil && err != nil {
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
continue
}
if tt.u == nil {
if err == nil {
t.Errorf("parseURL(%q) did not return error", tt.s)
}
continue
}
if !reflect.DeepEqual(u, tt.u) {
t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u)
continue
}
if u.RequestURI() != tt.rui {
t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui)
}
}
}
var hostPortNoPortTests = []struct {
u *url.URL
hostPort, hostNoPort string
}{
{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
}
func TestHostPortNoPort(t *testing.T) {
for _, tt := range hostPortNoPortTests {
hostPort, hostNoPort := hostPortNoPort(tt.u)
if hostPort != tt.hostPort {
t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
}
if hostNoPort != tt.hostNoPort {
t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
}
}
}

148
vendor/github.com/gorilla/websocket/compression.go generated vendored Normal file
View File

@ -0,0 +1,148 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"compress/flate"
"errors"
"io"
"strings"
"sync"
)
const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression
defaultCompressionLevel = 1
)
var (
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
)
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
}
func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
p *sync.Pool
}
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}
func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
err1 := w.fw.Flush()
w.p.Put(w.fw)
w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
type flateReadWrapper struct {
fr io.ReadCloser
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}
func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}

View File

@ -0,0 +1,80 @@
package websocket
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
)
type nopCloser struct{ io.Writer }
func (nopCloser) Close() error { return nil }
func TestTruncWriter(t *testing.T) {
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
for n := 1; n <= 10; n++ {
var b bytes.Buffer
w := &truncWriter{w: nopCloser{&b}}
p := []byte(data)
for len(p) > 0 {
m := len(p)
if m > n {
m = n
}
w.Write(p[:m])
p = p[m:]
}
if b.String() != data[:len(data)-len(w.p)] {
t.Errorf("%d: %q", n, b.String())
}
}
}
func textMessages(num int) [][]byte {
messages := make([][]byte, num)
for i := 0; i < num; i++ {
msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i)
messages[i] = []byte(msg)
}
return messages
}
func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}
func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
messages := textMessages(100)
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}
func TestValidCompressionLevel(t *testing.T) {
c := newConn(fakeNetConn{}, false, 1024, 1024)
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
if err := c.SetCompressionLevel(level); err == nil {
t.Errorf("no error for level %d", level)
}
}
for _, level := range []int{minCompressionLevel, maxCompressionLevel} {
if err := c.SetCompressionLevel(level); err != nil {
t.Errorf("error for level %d", level)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,134 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
package websocket
import (
"io"
"io/ioutil"
"sync/atomic"
"testing"
)
// broadcastBench allows to run broadcast benchmarks.
// In every broadcast benchmark we create many connections, then send the same
// message into every connection and wait for all writes complete. This emulates
// an application where many connections listen to the same data - i.e. PUB/SUB
// scenarios with many subscribers in one channel.
type broadcastBench struct {
w io.Writer
message *broadcastMessage
closeCh chan struct{}
doneCh chan struct{}
count int32
conns []*broadcastConn
compression bool
usePrepared bool
}
type broadcastMessage struct {
payload []byte
prepared *PreparedMessage
}
type broadcastConn struct {
conn *Conn
msgCh chan *broadcastMessage
}
func newBroadcastConn(c *Conn) *broadcastConn {
return &broadcastConn{
conn: c,
msgCh: make(chan *broadcastMessage, 1),
}
}
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{
w: ioutil.Discard,
doneCh: make(chan struct{}),
closeCh: make(chan struct{}),
usePrepared: usePrepared,
compression: compression,
}
msg := &broadcastMessage{
payload: textMessages(1)[0],
}
if usePrepared {
pm, _ := NewPreparedMessage(TextMessage, msg.payload)
msg.prepared = pm
}
bench.message = msg
bench.makeConns(10000)
return bench
}
func (b *broadcastBench) makeConns(numConns int) {
conns := make([]*broadcastConn, numConns)
for i := 0; i < numConns; i++ {
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
if b.compression {
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover
}
conns[i] = newBroadcastConn(c)
go func(c *broadcastConn) {
for {
select {
case msg := <-c.msgCh:
if b.usePrepared {
c.conn.WritePreparedMessage(msg.prepared)
} else {
c.conn.WriteMessage(TextMessage, msg.payload)
}
val := atomic.AddInt32(&b.count, 1)
if val%int32(numConns) == 0 {
b.doneCh <- struct{}{}
}
case <-b.closeCh:
return
}
}
}(conns[i])
}
b.conns = conns
}
func (b *broadcastBench) close() {
close(b.closeCh)
}
func (b *broadcastBench) runOnce() {
for _, c := range b.conns {
c.msgCh <- b.message
}
<-b.doneCh
}
func BenchmarkBroadcast(b *testing.B) {
benchmarks := []struct {
name string
usePrepared bool
compression bool
}{
{"NoCompression", false, false},
{"WithCompression", false, true},
{"NoCompressionPrepared", true, false},
{"WithCompressionPrepared", true, true},
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
bench := newBroadcastBench(bm.usePrepared, bm.compression)
defer bench.close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bench.runOnce()
}
b.ReportAllocs()
})
}
}

18
vendor/github.com/gorilla/websocket/conn_read.go generated vendored Normal file
View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}

View File

@ -0,0 +1,21 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}

View File

@ -1,32 +1,52 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"reflect"
"testing"
"testing/iotest"
"time"
)
var _ net.Error = errWriteTimeout
type fakeNetConn struct {
io.Reader
io.Writer
}
func (c fakeNetConn) Close() error { return nil }
func (c fakeNetConn) LocalAddr() net.Addr { return nil }
func (c fakeNetConn) RemoteAddr() net.Addr { return nil }
func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
type fakeAddr int
var (
localAddr = fakeAddr(1)
remoteAddr = fakeAddr(2)
)
func (a fakeAddr) Network() string {
return "net"
}
func (a fakeAddr) String() string {
return "str"
}
func TestFraming(t *testing.T) {
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
var readChunkers = []struct {
@ -37,66 +57,78 @@ func TestFraming(t *testing.T) {
{"one", iotest.OneByteReader},
{"asis", func(r io.Reader) io.Reader { return r }},
}
writeBuf := make([]byte, 65537)
for i := range writeBuf {
writeBuf[i] = byte(i)
}
var writers = []struct {
name string
f func(w io.Writer, n int) (int, error)
}{
{"iocopy", func(w io.Writer, n int) (int, error) {
nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
return int(nn), err
}},
{"write", func(w io.Writer, n int) (int, error) {
return w.Write(writeBuf[:n])
}},
{"string", func(w io.Writer, n int) (int, error) {
return io.WriteString(w, string(writeBuf[:n]))
}},
}
for _, isServer := range []bool{true, false} {
for _, chunker := range readChunkers {
for _, compress := range []bool{false, true} {
for _, isServer := range []bool{true, false} {
for _, chunker := range readChunkers {
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
if compress {
wc.newCompressionWriter = compressNoContextTakeover
rc.newDecompressionReader = decompressNoContextTakeover
}
for _, n := range frameSizes {
for _, writer := range writers {
name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
for _, n := range frameSizes {
for _, iocopy := range []bool{true, false} {
name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
continue
}
nn, err := writer.f(w, n)
if err != nil || nn != n {
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
continue
}
err = w.Close()
if err != nil {
t.Errorf("%s: w.Close() returned %v", name, err)
continue
}
w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
continue
}
var nn int
if iocopy {
var n64 int64
n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
nn = int(n64)
} else {
nn, err = w.Write(writeBuf[:n])
}
if err != nil || nn != n {
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
continue
}
err = w.Close()
if err != nil {
t.Errorf("%s: w.Close() returned %v", name, err)
continue
}
opCode, r, err := rc.NextReader()
if err != nil || opCode != TextMessage {
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
continue
}
rbuf, err := ioutil.ReadAll(r)
if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue
}
opCode, r, err := rc.NextReader()
if err != nil || opCode != TextMessage {
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
continue
}
rbuf, err := ioutil.ReadAll(r)
if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue
}
if len(rbuf) != n {
t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
continue
}
if len(rbuf) != n {
t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
continue
}
for i, b := range rbuf {
if byte(i) != b {
t.Errorf("%s: bad byte at offset %d", name, i)
break
for i, b := range rbuf {
if byte(i) != b {
t.Errorf("%s: bad byte at offset %d", name, i)
break
}
}
}
}
@ -105,6 +137,155 @@ func TestFraming(t *testing.T) {
}
}
func TestControl(t *testing.T) {
const message = "this is a ping/pong messsage"
for _, isServer := range []bool{true, false} {
for _, isWriteControl := range []bool{true, false} {
name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
if isWriteControl {
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
} else {
w, err := wc.NextWriter(PongMessage)
if err != nil {
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
continue
}
if _, err := w.Write([]byte(message)); err != nil {
t.Errorf("%s: w.Write() returned %v", name, err)
continue
}
if err := w.Close(); err != nil {
t.Errorf("%s: w.Close() returned %v", name, err)
continue
}
var actualMessage string
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
rc.NextReader()
if actualMessage != message {
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
continue
}
}
}
}
}
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
const bufSize = 512
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
w.Close()
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
}
_, _, err = rc.NextReader()
if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
}
}
func TestEOFWithinFrame(t *testing.T) {
const bufSize = 64
for n := 0; ; n++ {
var b bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize))
w.Close()
if n >= b.Len() {
break
}
b.Truncate(n)
op, r, err := rc.NextReader()
if err == errUnexpectedEOF {
continue
}
if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
}
_, _, err = rc.NextReader()
if err != errUnexpectedEOF {
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
}
}
}
func TestEOFBeforeFinalFrame(t *testing.T) {
const bufSize = 512
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
}
_, _, err = rc.NextReader()
if err != errUnexpectedEOF {
t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
}
}
func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
if err := w.Close(); err != nil {
t.Fatalf("unxpected error closing message writer, %v", err)
}
if _, err := io.WriteString(w, "world"); err == nil {
t.Fatalf("no error writing after close")
}
w, _ = wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
// close w by getting next writer
_, err := wc.NextWriter(BinaryMessage)
if err != nil {
t.Fatalf("unexpected error getting next writer, %v", err)
}
if _, err := io.WriteString(w, "world"); err == nil {
t.Fatalf("no error writing after close")
}
}
func TestReadLimit(t *testing.T) {
const readLimit = 512
@ -138,3 +319,179 @@ func TestReadLimit(t *testing.T) {
t.Fatalf("io.Copy() returned %v", err)
}
}
func TestAddrs(t *testing.T) {
c := newConn(&fakeNetConn{}, true, 1024, 1024)
if c.LocalAddr() != localAddr {
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
}
if c.RemoteAddr() != remoteAddr {
t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
}
}
func TestUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024)
ul := c.UnderlyingConn()
if ul != fc {
t.Fatalf("Underlying conn is not what it should be.")
}
}
func TestBufioReadBytes(t *testing.T) {
// Test calling bufio.ReadBytes for value longer than read buffer size.
m := make([]byte, 512)
m[len(m)-1] = '\n'
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(m)
w.Close()
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err)
}
br := bufio.NewReader(r)
p, err := br.ReadBytes('\n')
if err != nil {
t.Fatalf("ReadBytes() returned %v", err)
}
if len(p) != len(m) {
t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
}
}
var closeErrorTests = []struct {
err error
codes []int
ok bool
}{
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
{errors.New("hello"), []int{CloseNormalClosure}, false},
}
func TestCloseError(t *testing.T) {
for _, tt := range closeErrorTests {
ok := IsCloseError(tt.err, tt.codes...)
if ok != tt.ok {
t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
}
}
}
var unexpectedCloseErrorTests = []struct {
err error
codes []int
ok bool
}{
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
{errors.New("hello"), []int{CloseNormalClosure}, false},
}
func TestUnexpectedCloseErrors(t *testing.T) {
for _, tt := range unexpectedCloseErrorTests {
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
if ok != tt.ok {
t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
}
}
}
type blockingWriter struct {
c1, c2 chan struct{}
}
func (w blockingWriter) Write(p []byte) (int, error) {
// Allow main to continue
close(w.c1)
// Wait for panic in main
<-w.c2
return len(p), nil
}
func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
go func() {
c.WriteMessage(TextMessage, []byte{})
}()
// wait for goroutine to block in write.
<-w.c1
defer func() {
close(w.c2)
if v := recover(); v != nil {
return
}
}()
c.WriteMessage(TextMessage, []byte{})
t.Fatal("should not get here")
}
type failingReader struct{}
func (r failingReader) Read(p []byte) (int, error) {
return 0, io.EOF
}
func TestFailedConnectionReadPanic(t *testing.T) {
c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
defer func() {
if v := recover(); v != nil {
return
}
}()
for i := 0; i < 20000; i++ {
c.ReadMessage()
}
t.Fatal("should not get here")
}
func TestBufioReuse(t *testing.T) {
brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
c := newConnBRW(nil, false, 0, 0, brw)
if c.br != brw.Reader {
t.Error("connection did not reuse bufio.Reader")
}
var wh writeHook
brw.Writer.Reset(&wh)
brw.WriteByte(0)
brw.Flush()
if &c.writeBuf[0] != &wh.p[0] {
t.Error("connection did not reuse bufio.Writer")
}
brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
c = newConnBRW(nil, false, 0, 0, brw)
if c.br == brw.Reader {
t.Error("connection used bufio.Reader with small size")
}
brw.Writer.Reset(&wh)
brw.WriteByte(0)
brw.Flush()
if &c.writeBuf[0] != &wh.p[0] {
t.Error("connection used bufio.Writer with small size")
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -6,23 +6,25 @@
//
// Overview
//
// The Conn type represents a WebSocket connection. A server application calls
// the Upgrade function from an HTTP request handler to get a pointer to a
// Conn:
// The Conn type represents a WebSocket connection. A server application uses
// the Upgrade function from an Upgrader object with a HTTP request handler
// to get a pointer to a Conn:
//
// var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024,
// WriteBufferSize: 1024,
// }
//
// func handler(w http.ResponseWriter, r *http.Request) {
// conn, err := websocket.Upgrade(w, r.Header, nil, 1024, 1024)
// if _, ok := err.(websocket.HandshakeError); ok {
// http.Error(w, "Not a websocket handshake", 400)
// return
// } else if err != nil {
// conn, err := upgrader.Upgrade(w, r, nil)
// if err != nil {
// log.Println(err)
// return
// }
// ... Use conn to send and receive messages.
// }
//
// Call the connection WriteMessage and ReadMessages methods to send and
// Call the connection's WriteMessage and ReadMessage methods to send and
// receive messages as a slice of bytes. This snippet of code shows how to echo
// messages using these methods:
//
@ -31,7 +33,7 @@
// if err != nil {
// return
// }
// if _, err := conn.WriteMessaage(messageType, p); err != nil {
// if err = conn.WriteMessage(messageType, p); err != nil {
// return err
// }
// }
@ -44,8 +46,7 @@
// method to get an io.WriteCloser, write the message to the writer and close
// the writer when done. To receive a message, call the connection NextReader
// method to get an io.Reader and read until io.EOF is returned. This snippet
// snippet shows how to echo messages using the NextWriter and NextReader
// methods:
// shows how to echo messages using the NextWriter and NextReader methods:
//
// for {
// messageType, r, err := conn.NextReader()
@ -84,20 +85,96 @@
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer.
//
// Connections handle received ping and pong messages by invoking a callback
// function set with SetPingHandler and SetPongHandler methods. These callback
// functions can be invoked from the ReadMessage method, the NextReader method
// or from a call to the data message reader returned from NextReader.
// Connections handle received close messages by sending a close message to the
// peer and returning a *CloseError from the the NextReader, ReadMessage or the
// message Read method.
//
// Connections handle received close messages by returning an error from the
// ReadMessage method, the NextReader method or from a call to the data message
// reader returned from NextReader.
// Connections handle received ping and pong messages by invoking callback
// functions set with SetPingHandler and SetPongHandler methods. The callback
// functions are called from the NextReader, ReadMessage and the message Read
// methods.
//
// The default ping handler sends a pong to the peer. The application's reading
// goroutine can block for a short time while the handler writes the pong data
// to the connection.
//
// The application must read the connection to process ping, pong and close
// messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is:
//
// func readLoop(c *websocket.Conn) {
// for {
// if _, _, err := c.NextReader(); err != nil {
// c.Close()
// break
// }
// }
// }
//
// Concurrency
//
// A Conn supports a single concurrent caller to the write methods (NextWriter,
// SetWriteDeadline, WriteMessage) and a single concurrent caller to the read
// methods (NextReader, SetReadDeadline, ReadMessage). The Close and
// WriteControl methods can be called concurrently with all other methods.
// Connections support one concurrent reader and one concurrent writer.
//
// Applications are responsible for ensuring that no more than one goroutine
// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and
// that no more than one goroutine calls the read methods (NextReader,
// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler)
// concurrently.
//
// The Close and WriteControl methods can be called concurrently with all other
// methods.
//
// Origin Considerations
//
// Web browsers allow Javascript applications to open a WebSocket connection to
// any host. It's up to the server to enforce an origin policy using the Origin
// request header sent by the browser.
//
// The Upgrader calls the function specified in the CheckOrigin field to check
// the origin. If the CheckOrigin function returns false, then the Upgrade
// method fails the WebSocket handshake with HTTP status 403.
//
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and not equal to the
// Host request header.
//
// An application can allow connections from any origin by specifying a
// function that always returns true:
//
// var upgrader = websocket.Upgrader{
// CheckOrigin: func(r *http.Request) bool { return true },
// }
//
// The deprecated Upgrade function does not enforce an origin policy. It's the
// application's responsibility to check the Origin header before calling
// Upgrade.
//
// Compression EXPERIMENTAL
//
// Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
// support.
//
// var upgrader = websocket.Upgrader{
// EnableCompression: true,
// }
//
// If compression was successfully negotiated with the connection's peer, any
// message received in compressed form will be automatically decompressed.
// All Read methods will return uncompressed bytes.
//
// Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method:
//
// conn.EnableWriteCompression(false)
//
// Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation,
// without retaining sliding window or dictionary state across messages. For
// more details refer to RFC 7692.
//
// Use of compression is experimental and may result in decreased performance.
package websocket

46
vendor/github.com/gorilla/websocket/example_test.go generated vendored Normal file
View File

@ -0,0 +1,46 @@
// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"log"
"net/http"
"testing"
"github.com/gorilla/websocket"
)
var (
c *websocket.Conn
req *http.Request
)
// The websocket.IsUnexpectedCloseError function is useful for identifying
// application and protocol errors.
//
// This server application works with a client application running in the
// browser. The client application does not explicitly close the websocket. The
// only expected close message from the client has the code
// websocket.CloseGoingAway. All other other close messages are likely the
// result of an application or protocol error and are logged to aid debugging.
func ExampleIsUnexpectedCloseError() {
for {
messageType, p, err := c.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent"))
}
return
}
processMesage(messageType, p)
}
}
func processMesage(mt int, p []byte) {}
// TestX prevents godoc from showing this entire file in the example. Remove
// this function when a second example is added.
func TestX(t *testing.T) {}

View File

@ -1,4 +1,4 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -6,9 +6,10 @@ package websocket
import (
"encoding/json"
"io"
)
// DEPRECATED: use c.WriteJSON instead.
// WriteJSON is deprecated, use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v)
}
@ -30,7 +31,7 @@ func (c *Conn) WriteJSON(v interface{}) error {
return err2
}
// DEPRECATED: use c.WriteJSON instead.
// ReadJSON is deprecated, use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v)
}
@ -38,12 +39,17 @@ func ReadJSON(c *Conn, v interface{}) error {
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Marshal function for details
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
return json.NewDecoder(r).Decode(v)
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

View File

@ -1,4 +1,4 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -6,6 +6,8 @@ package websocket
import (
"bytes"
"encoding/json"
"io"
"reflect"
"testing"
)
@ -36,6 +38,60 @@ func TestJSON(t *testing.T) {
}
}
func TestPartialJSONRead(t *testing.T) {
var buf bytes.Buffer
c := fakeNetConn{&buf, &buf}
wc := newConn(c, true, 1024, 1024)
rc := newConn(c, false, 1024, 1024)
var v struct {
A int
B string
}
v.A = 1
v.B = "hello"
messageCount := 0
// Partial JSON values.
data, err := json.Marshal(v)
if err != nil {
t.Fatal(err)
}
for i := len(data) - 1; i >= 0; i-- {
if err := wc.WriteMessage(TextMessage, data[:i]); err != nil {
t.Fatal(err)
}
messageCount++
}
// Whitespace.
if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil {
t.Fatal(err)
}
messageCount++
// Close.
if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil {
t.Fatal(err)
}
for i := 0; i < messageCount; i++ {
err := rc.ReadJSON(&v)
if err != io.ErrUnexpectedEOF {
t.Error("read", i, err)
}
}
err = rc.ReadJSON(&v)
if _, ok := err.(*CloseError); !ok {
t.Error("final", err)
}
}
func TestDeprecatedJSON(t *testing.T) {
var buf bytes.Buffer
c := fakeNetConn{&buf, &buf}

55
vendor/github.com/gorilla/websocket/mask.go generated vendored Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// +build !appengine
package websocket
import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

15
vendor/github.com/gorilla/websocket/mask_safe.go generated vendored Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// +build appengine
package websocket
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

73
vendor/github.com/gorilla/websocket/mask_test.go generated vendored Normal file
View File

@ -0,0 +1,73 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// Require 1.7 for sub-bencmarks
// +build go1.7,!appengine
package websocket
import (
"fmt"
"testing"
)
func maskBytesByByte(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
func notzero(b []byte) int {
for i := range b {
if b[i] != 0 {
return i
}
}
return -1
}
func TestMaskBytes(t *testing.T) {
key := [4]byte{1, 2, 3, 4}
for size := 1; size <= 1024; size++ {
for align := 0; align < wordSize; align++ {
for pos := 0; pos < 4; pos++ {
b := make([]byte, size+align)[align:]
maskBytes(key, pos, b)
maskBytesByByte(key, pos, b)
if i := notzero(b); i >= 0 {
t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i)
}
}
}
}
}
func BenchmarkMaskBytes(b *testing.B) {
for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} {
b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) {
for _, align := range []int{wordSize / 2} {
b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) {
for _, fn := range []struct {
name string
fn func(key [4]byte, pos int, b []byte) int
}{
{"byte", maskBytesByByte},
{"word", maskBytes},
} {
b.Run(fn.name, func(b *testing.B) {
key := newMaskKey()
data := make([]byte, size+align)[align:]
for i := 0; i < b.N; i++ {
fn.fn(key, 0, data)
}
b.SetBytes(int64(len(data)))
})
}
})
}
})
}
}

103
vendor/github.com/gorilla/websocket/prepared.go generated vendored Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"net"
"sync"
"time"
)
// PreparedMessage caches on the wire representations of a message payload.
// Use PreparedMessage to efficiently send a message payload to multiple
// connections. PreparedMessage is especially useful when compression is used
// because the CPU and memory expensive compression operation can be executed
// once for a given set of compression options.
type PreparedMessage struct {
messageType int
data []byte
err error
mu sync.Mutex
frames map[prepareKey]*preparedFrame
}
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct {
isServer bool
compress bool
compressionLevel int
}
// preparedFrame contains data in wire representation.
type preparedFrame struct {
once sync.Once
data []byte
}
// NewPreparedMessage returns an initialized PreparedMessage. You can then send
// it to connection using WritePreparedMessage method. Valid wire
// representation will be calculated lazily only once for a set of current
// connection options.
func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
pm := &PreparedMessage{
messageType: messageType,
frames: make(map[prepareKey]*preparedFrame),
data: data,
}
// Prepare a plain server frame.
_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
if err != nil {
return nil, err
}
// To protect against caller modifying the data argument, remember the data
// copied to the plain server frame.
pm.data = frameData[len(frameData)-len(data):]
return pm, nil
}
func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
pm.mu.Lock()
frame, ok := pm.frames[key]
if !ok {
frame = &preparedFrame{}
pm.frames[key] = frame
}
pm.mu.Unlock()
var err error
frame.once.Do(func() {
// Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of
// the frame.
mu := make(chan bool, 1)
mu <- true
var nc prepareConn
c := &Conn{
conn: &nc,
mu: mu,
isServer: key.isServer,
compressionLevel: key.compressionLevel,
enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
}
if key.compress {
c.newCompressionWriter = compressNoContextTakeover
}
err = c.WriteMessage(pm.messageType, pm.data)
frame.data = nc.buf.Bytes()
})
return pm.messageType, frame.data, err
}
type prepareConn struct {
buf bytes.Buffer
net.Conn
}
func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

74
vendor/github.com/gorilla/websocket/prepared_test.go generated vendored Normal file
View File

@ -0,0 +1,74 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"compress/flate"
"math/rand"
"testing"
)
var preparedMessageTests = []struct {
messageType int
isServer bool
enableWriteCompression bool
compressionLevel int
}{
// Server
{TextMessage, true, false, flate.BestSpeed},
{TextMessage, true, true, flate.BestSpeed},
{TextMessage, true, true, flate.BestCompression},
{PingMessage, true, false, flate.BestSpeed},
{PingMessage, true, true, flate.BestSpeed},
// Client
{TextMessage, false, false, flate.BestSpeed},
{TextMessage, false, true, flate.BestSpeed},
{TextMessage, false, true, flate.BestCompression},
{PingMessage, false, false, flate.BestSpeed},
{PingMessage, false, true, flate.BestSpeed},
}
func TestPreparedMessage(t *testing.T) {
for _, tt := range preparedMessageTests {
var data = []byte("this is a test")
var buf bytes.Buffer
c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
if tt.enableWriteCompression {
c.newCompressionWriter = compressNoContextTakeover
}
c.SetCompressionLevel(tt.compressionLevel)
// Seed random number generator for consistent frame mask.
rand.Seed(1234)
if err := c.WriteMessage(tt.messageType, data); err != nil {
t.Fatal(err)
}
want := buf.String()
pm, err := NewPreparedMessage(tt.messageType, data)
if err != nil {
t.Fatal(err)
}
// Scribble on data to ensure that NewPreparedMessage takes a snapshot.
copy(data, "hello world")
// Seed random number generator for consistent frame mask.
rand.Seed(1234)
buf.Reset()
if err := c.WritePreparedMessage(pm); err != nil {
t.Fatal(err)
}
got := buf.String()
if got != want {
t.Errorf("write message != prepared message for %+v", tt)
}
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -9,7 +9,9 @@ import (
"errors"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// HandshakeError describes an error with the handshake from the peer.
@ -19,76 +21,180 @@ type HandshakeError struct {
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client.
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request.
CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(w, r, status, err)
} else {
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return u.Host == r.Host
}
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol).
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// If the endpoint supports WebSocket subprotocols, then the application is
// responsible for selecting a subprotocol that is acceptable to the client and
// echoing that value back to the client. Use the Subprotocols function to get
// the list of protocols specified by the client. Use the
// Sec-Websocket-Protocol response header to echo the selected protocol back to
// the client.
//
// Appilcations can set cookies by adding a Set-Cookie header to the response
// header.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")
}
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
return nil, HandshakeError{"websocket: version != 13"}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")
}
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return nil, HandshakeError{"websocket: connection header != upgrade"}
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header")
}
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return nil, HandshakeError{"websocket: upgrade != websocket"}
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header")
}
var challengeKey string
values := r.Header["Sec-Websocket-Key"]
if len(values) == 0 || values[0] == "" {
return nil, HandshakeError{"websocket: key missing or blank"}
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed")
}
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
challengeKey = values[0]
var (
netConn net.Conn
br *bufio.Reader
err error
)
h, ok := w.(http.Hijacker)
if !ok {
return nil, errors.New("websocket: response does not implement http.Hijacker")
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err = h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
var rw *bufio.ReadWriter
netConn, rw, err = h.Hijack()
br = rw.Reader
if br.Buffered() > 0 {
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
c := newConn(netConn, true, readBufSize, writeBufSize)
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
continue
}
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
@ -105,14 +211,64 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// This function is deprecated, use websocket.Upgrader instead.
//
// The application is responsible for checking the request origin before
// calling Upgrade. An example implementation of the same origin policy is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string {
@ -126,3 +282,10 @@ func Subprotocols(r *http.Request) []string {
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket")
}

View File

@ -1,4 +1,4 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -31,3 +31,21 @@ func TestSubprotocols(t *testing.T) {
}
}
}
var isWebSocketUpgradeTests = []struct {
ok bool
h http.Header
}{
{false, http.Header{"Upgrade": {"websocket"}}},
{false, http.Header{"Connection": {"upgrade"}}},
{true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
}
func TestIsWebSocketUpgrade(t *testing.T) {
for _, tt := range isWebSocketUpgradeTests {
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
if tt.ok != ok {
t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
}
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2013 Gary Burd. All rights reserved.
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -13,19 +13,6 @@ import (
"strings"
)
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
for _, v := range header[name] {
for _, s := range strings.Split(v, ",") {
if strings.EqualFold(value, strings.TrimSpace(s)) {
return true
}
}
}
return false
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
@ -42,3 +29,186 @@ func generateChallengeKey() (string, error) {
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
}
}
return s[i:]
}
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
break
}
}
return s[:i], s[i:]
}
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j += 1
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j += 1
}
}
return "", ""
}
}
return "", ""
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
headers:
for _, s := range header[name] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
continue headers
}
if strings.EqualFold(t, value) {
return true
}
if s == "" {
continue headers
}
s = s[1:]
}
}
return false
}
// parseExtensiosn parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}

74
vendor/github.com/gorilla/websocket/util_test.go generated vendored Normal file
View File

@ -0,0 +1,74 @@
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"net/http"
"reflect"
"testing"
)
var tokenListContainsValueTests = []struct {
value string
ok bool
}{
{"WebSocket", true},
{"WEBSOCKET", true},
{"websocket", true},
{"websockets", false},
{"x websocket", false},
{"websocket x", false},
{"other,websocket,more", true},
{"other, websocket, more", true},
}
func TestTokenListContainsValue(t *testing.T) {
for _, tt := range tokenListContainsValueTests {
h := http.Header{"Upgrade": {tt.value}}
ok := tokenListContainsValue(h, "Upgrade", "websocket")
if ok != tt.ok {
t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
}
}
}
var parseExtensionTests = []struct {
value string
extensions []map[string]string
}{
{`foo`, []map[string]string{map[string]string{"": "foo"}}},
{`foo, bar; baz=2`, []map[string]string{
map[string]string{"": "foo"},
map[string]string{"": "bar", "baz": "2"}}},
{`foo; bar="b,a;z"`, []map[string]string{
map[string]string{"": "foo", "bar": "b,a;z"}}},
{`foo , bar; baz = 2`, []map[string]string{
map[string]string{"": "foo"},
map[string]string{"": "bar", "baz": "2"}}},
{`foo, bar; baz=2 junk`, []map[string]string{
map[string]string{"": "foo"}}},
{`foo junk, bar; baz=2 junk`, nil},
{`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
map[string]string{"": "mux", "max-channels": "4", "flow-control": ""},
map[string]string{"": "deflate-stream"}}},
{`permessage-foo; x="10"`, []map[string]string{
map[string]string{"": "permessage-foo", "x": "10"}}},
{`permessage-foo; use_y, permessage-foo`, []map[string]string{
map[string]string{"": "permessage-foo", "use_y": ""},
map[string]string{"": "permessage-foo"}}},
{`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}},
}
func TestParseExtensions(t *testing.T) {
for _, tt := range parseExtensionTests {
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
extensions := parseExtensions(h)
if !reflect.DeepEqual(extensions, tt.extensions) {
t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions)
}
}
}