Replace github.com/bmizerany/pq with github.com/lib/pq postgres driver.

Change-Id: If62fd5b489696171c8be6b42a988e1c7d0a634d0
This commit is contained in:
Josh Huckabee 2013-06-25 16:13:22 -07:00
parent dccd401ff0
commit 65f3079498
12 changed files with 3 additions and 1859 deletions

View File

@ -29,7 +29,7 @@ import (
"camlistore.org/pkg/index/postgres"
"camlistore.org/pkg/index/sqlite"
_ "camlistore.org/third_party/github.com/bmizerany/pq"
_ "camlistore.org/third_party/github.com/lib/pq"
_ "camlistore.org/third_party/github.com/ziutek/mymysql/godrv"
)

View File

@ -29,7 +29,7 @@ import (
"camlistore.org/pkg/index/postgres"
"camlistore.org/pkg/test/testdep"
_ "camlistore.org/third_party/github.com/bmizerany/pq"
_ "camlistore.org/third_party/github.com/lib/pq"
)
var (

View File

@ -27,7 +27,7 @@ import (
"camlistore.org/pkg/index/sqlindex"
"camlistore.org/pkg/jsonconfig"
_ "camlistore.org/third_party/github.com/bmizerany/pq"
_ "camlistore.org/third_party/github.com/lib/pq"
)
type myIndexStorage struct {

View File

@ -1,74 +0,0 @@
package pq
import (
"bytes"
"encoding/binary"
)
type readBuf []byte
func (b *readBuf) int32() (n int) {
n = int(binary.BigEndian.Uint32(*b))
*b = (*b)[4:]
return
}
func (b *readBuf) int16() (n int) {
n = int(binary.BigEndian.Uint16(*b))
*b = (*b)[2:]
return
}
var stringTerm = []byte{0}
func (b *readBuf) string() string {
i := bytes.Index(*b, stringTerm)
if i < 0 {
errorf("invalid message format; expected string terminator")
}
s := (*b)[:i]
*b = (*b)[i+1:]
return string(s)
}
func (b *readBuf) next(n int) (v []byte) {
v = (*b)[:n]
*b = (*b)[n:]
return
}
func (b *readBuf) byte() byte {
return b.next(1)[0]
}
type writeBuf []byte
func newWriteBuf(c byte) *writeBuf {
b := make(writeBuf, 5)
b[0] = c
return &b
}
func (b *writeBuf) int32(n int) {
x := make([]byte, 4)
binary.BigEndian.PutUint32(x, uint32(n))
*b = append(*b, x...)
}
func (b *writeBuf) int16(n int) {
x := make([]byte, 2)
binary.BigEndian.PutUint16(x, uint16(n))
*b = append(*b, x...)
}
func (b *writeBuf) string(s string) {
*b = append(*b, (s + "\000")...)
}
func (b *writeBuf) byte(c byte) {
*b = append(*b, c)
}
func (b *writeBuf) bytes(v []byte) {
*b = append(*b, v...)
}

View File

@ -1,659 +0,0 @@
package pq
import (
"bufio"
"crypto/md5"
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"path"
"strconv"
"strings"
)
var (
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrNotSupported = errors.New("pq: invalid command")
)
type drv struct{}
func (d *drv) Open(name string) (driver.Conn, error) {
return Open(name)
}
func init() {
sql.Register("postgres", &drv{})
}
type conn struct {
c net.Conn
buf *bufio.Reader
namei int
}
func Open(name string) (_ driver.Conn, err error) {
defer errRecover(&err)
defer errRecoverWithPGReason(&err)
o := make(Values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o.Set("host", "localhost")
o.Set("port", "5432")
// Default the username, but ignore errors, because a user
// passed in via environment variable or connection string
// would be okay. This can result in connections failing
// *sometimes* if the client relies on being able to determine
// the current username and there are intermittent problems.
u, err := user.Current()
if err == nil {
o.Set("user", u.Username)
}
for k, v := range parseEnviron(os.Environ()) {
o.Set(k, v)
}
parseOpts(name, o)
c, err := net.Dial(network(o))
if err != nil {
return nil, err
}
cn := &conn{c: c}
cn.ssl(o)
cn.buf = bufio.NewReader(cn.c)
cn.startup(o)
return cn, nil
}
func network(o Values) (string, string) {
host := o.Get("host")
if strings.HasPrefix(host, "/") {
sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
return "unix", sockPath
}
return "tcp", host + ":" + o.Get("port")
}
type Values map[string]string
func (vs Values) Set(k, v string) {
vs[k] = v
}
func (vs Values) Get(k string) (v string) {
v, _ = vs[k]
return
}
func parseOpts(name string, o Values) {
if len(name) == 0 {
return
}
ps := strings.Split(name, " ")
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
}
func (cn *conn) Begin() (driver.Tx, error) {
_, err := cn.Exec("BEGIN", nil)
if err != nil {
return nil, err
}
return cn, err
}
func (cn *conn) Commit() error {
_, err := cn.Exec("COMMIT", nil)
return err
}
func (cn *conn) Rollback() error {
_, err := cn.Exec("ROLLBACK", nil)
return err
}
func (cn *conn) gname() string {
cn.namei++
return strconv.FormatInt(int64(cn.namei), 10)
}
func (cn *conn) simpleQuery(q string) (res driver.Result, err error) {
defer errRecover(&err)
b := newWriteBuf('Q')
b.string(q)
cn.send(b)
for {
t, r := cn.recv1()
switch t {
case 'C':
res = parseComplete(r.string())
case 'Z':
// done
return
case 'E':
err = parseError(r)
case 'T', 'N':
// ignore
default:
errorf("unknown response for simple query: %q", t)
}
}
panic("not reached")
}
func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
defer errRecover(&err)
st := &stmt{cn: cn, name: stmtName, query: q}
b := newWriteBuf('P')
b.string(st.name)
b.string(q)
b.int16(0)
cn.send(b)
b = newWriteBuf('D')
b.byte('S')
b.string(st.name)
cn.send(b)
cn.send(newWriteBuf('S'))
for {
t, r := cn.recv1()
switch t {
case '1', '2', 'N':
case 't':
st.nparams = int(r.int16())
case 'T':
n := r.int16()
st.cols = make([]string, n)
st.ooid = make([]int, n)
for i := range st.cols {
st.cols[i] = r.string()
r.next(6)
st.ooid[i] = r.int32()
r.next(8)
}
case 'n':
// no data
case 'Z':
return st, err
case 'E':
err = parseError(r)
default:
errorf("unexpected describe rows response: %q", t)
}
}
panic("not reached")
}
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
return cn.prepareTo(q, cn.gname())
}
func (cn *conn) Close() (err error) {
defer errRecover(&err)
cn.send(newWriteBuf('X'))
return cn.c.Close()
}
// Implement the optional "Execer" interface for one-shot queries
func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
defer errRecover(&err)
// Check to see if we can use the "simpleQuery" interface, which is
// *much* faster than going through prepare/exec
if len(args) == 0 {
return cn.simpleQuery(query)
}
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st, err := cn.prepareTo(query, "")
if err != nil {
panic(err)
}
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
}
// Assumes len(*m) is > 5
func (cn *conn) send(m *writeBuf) {
b := (*m)[1:]
binary.BigEndian.PutUint32(b, uint32(len(b)))
if (*m)[0] == 0 {
*m = b
}
_, err := cn.c.Write(*m)
if err != nil {
panic(err)
}
}
func (cn *conn) recv() (t byte, r *readBuf) {
for {
t, r = cn.recv1()
switch t {
case 'E':
panic(parseError(r))
case 'N':
// ignore
default:
return
}
}
panic("not reached")
}
func (cn *conn) recv1() (byte, *readBuf) {
x := make([]byte, 5)
_, err := io.ReadFull(cn.buf, x)
if err != nil {
panic(err)
}
b := readBuf(x[1:])
y := make([]byte, b.int32()-4)
_, err = io.ReadFull(cn.buf, y)
if err != nil {
panic(err)
}
return x[0], (*readBuf)(&y)
}
func (cn *conn) ssl(o Values) {
tlsConf := tls.Config{}
switch mode := o.Get("sslmode"); mode {
case "require", "":
tlsConf.InsecureSkipVerify = true
case "verify-full":
// fall out
case "disable":
return
default:
errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode)
}
w := newWriteBuf(0)
w.int32(80877103)
cn.send(w)
b := make([]byte, 1)
_, err := io.ReadFull(cn.c, b)
if err != nil {
panic(err)
}
if b[0] != 'S' {
panic(ErrSSLNotSupported)
}
cn.c = tls.Client(cn.c, &tlsConf)
}
func (cn *conn) startup(o Values) {
w := newWriteBuf(0)
w.int32(196608)
w.string("user")
w.string(o.Get("user"))
w.string("database")
w.string(o.Get("dbname"))
w.string("")
cn.send(w)
for {
t, r := cn.recv()
switch t {
case 'K', 'S':
case 'R':
cn.auth(r, o)
case 'Z':
return
default:
errorf("unknown response for startup: %q", t)
}
}
}
func (cn *conn) auth(r *readBuf, o Values) {
switch code := r.int32(); code {
case 0:
// OK
case 5:
s := string(r.next(4))
w := newWriteBuf('p')
w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 0 {
errorf("unexpected authentication resoonse: %q", t)
}
default:
errorf("unknown authentication response: %d", code)
}
}
type stmt struct {
cn *conn
name string
query string
cols []string
nparams int
ooid []int
closed bool
}
func (st *stmt) Close() (err error) {
if st.closed {
return nil
}
defer errRecover(&err)
w := newWriteBuf('C')
w.byte('S')
w.string(st.name)
st.cn.send(w)
st.cn.send(newWriteBuf('S'))
t, _ := st.cn.recv()
if t != '3' {
errorf("unexpected close response: %q", t)
}
st.closed = true
t, _ = st.cn.recv()
if t != 'Z' {
errorf("expected ready for query, but got: %q", t)
}
return nil
}
func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) {
defer errRecover(&err)
st.exec(v)
return &rows{st: st}, nil
}
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
defer errRecover(&err)
if len(v) == 0 {
return st.cn.simpleQuery(st.query)
}
st.exec(v)
for {
t, r := st.cn.recv1()
switch t {
case 'E':
err = parseError(r)
case 'C':
res = parseComplete(r.string())
case 'Z':
// done
return
case 'D':
errorf("unexpected data row returned in Exec; check your query")
case 'S', 'N':
// Ignore
default:
errorf("unknown exec response: %q", t)
}
}
panic("not reached")
}
func (st *stmt) exec(v []driver.Value) {
w := newWriteBuf('B')
w.string("")
w.string(st.name)
w.int16(0)
w.int16(len(v))
for _, x := range v {
if x == nil {
w.int32(-1)
} else {
b := encode(x)
w.int32(len(b))
w.bytes(b)
}
}
w.int16(0)
st.cn.send(w)
w = newWriteBuf('E')
w.string("")
w.int32(0)
st.cn.send(w)
st.cn.send(newWriteBuf('S'))
var err error
for {
t, r := st.cn.recv1()
switch t {
case 'E':
err = parseError(r)
case '2':
if err != nil {
panic(err)
}
return
case 'Z':
if err != nil {
panic(err)
}
return
case 'N':
// ignore
default:
errorf("unexpected bind response: %q", t)
}
}
}
func (st *stmt) NumInput() int {
return st.nparams
}
type result int64
func (i result) RowsAffected() (int64, error) {
return int64(i), nil
}
func (i result) LastInsertId() (int64, error) {
return 0, ErrNotSupported
}
func parseComplete(s string) driver.Result {
parts := strings.Split(s, " ")
n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64)
return result(n)
}
type rows struct {
st *stmt
done bool
}
func (rs *rows) Close() error {
for {
err := rs.Next(nil)
switch err {
case nil:
case io.EOF:
return nil
default:
return err
}
}
panic("not reached")
}
func (rs *rows) Columns() []string {
return rs.st.cols
}
func (rs *rows) Next(dest []driver.Value) (err error) {
if rs.done {
return io.EOF
}
defer errRecover(&err)
for {
t, r := rs.st.cn.recv1()
switch t {
case 'E':
err = parseError(r)
case 'C', 'S', 'N':
continue
case 'Z':
rs.done = true
if err != nil {
return err
}
return io.EOF
case 'D':
n := r.int16()
for i := 0; i < len(dest) && i < n; i++ {
l := r.int32()
if l == -1 {
dest[i] = nil
continue
}
dest[i] = decode(r.next(l), rs.st.ooid[i])
}
return
default:
errorf("unexpected message after execute: %q", t)
}
}
panic("not reached")
}
func md5s(s string) string {
h := md5.New()
h.Write([]byte(s))
return fmt.Sprintf("%x", h.Sum(nil))
}
// parseEnviron tries to mimic some of libpq's environment handling
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
func parseEnviron(env []string) (out map[string]string) {
out = make(map[string]string)
for _, v := range env {
parts := strings.SplitN(v, "=", 2)
accrue := func(keyname string) {
out[keyname] = parts[1]
}
// The order of these is the same as is seen in the
// PostgreSQL 9.1 manual, with omissions briefly
// noted.
switch parts[0] {
case "PGHOST":
accrue("host")
case "PGHOSTADDR":
accrue("hostaddr")
case "PGPORT":
accrue("port")
case "PGDATABASE":
accrue("dbname")
case "PGUSER":
accrue("user")
case "PGPASSWORD":
accrue("password")
// skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
// PGREALM
case "PGOPTIONS":
accrue("options")
case "PGAPPNAME":
accrue("application_name")
case "PGSSLMODE":
accrue("sslmode")
case "PGREQUIRESSL":
accrue("requiressl")
case "PGSSLCERT":
accrue("sslcert")
case "PGSSLKEY":
accrue("sslkey")
case "PGSSLROOTCERT":
accrue("sslrootcert")
case "PGSSLCRL":
accrue("sslcrl")
case "PGREQUIREPEER":
accrue("requirepeer")
case "PGKRBSRVNAME":
accrue("krbsrvname")
case "PGGSSLIB":
accrue("gsslib")
case "PGCONNECT_TIMEOUT":
accrue("connect_timeout")
case "PGCLIENTENCODING":
accrue("client_encoding")
// skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
// PGLOCALEDIR
}
}
return out
}

View File

@ -1,443 +0,0 @@
package pq
import (
"database/sql"
"database/sql/driver"
"io"
"os"
"reflect"
"testing"
"time"
)
type Fatalistic interface {
Fatal(args ...interface{})
}
func openTestConn(t Fatalistic) *sql.DB {
datname := os.Getenv("PGDATABASE")
sslmode := os.Getenv("PGSSLMODE")
if datname == "" {
os.Setenv("PGDATABASE", "pqgotest")
}
if sslmode == "" {
os.Setenv("PGSSLMODE", "disable")
}
conn, err := sql.Open("postgres", "")
if err != nil {
t.Fatal(err)
}
return conn
}
func TestExec(t *testing.T) {
db := openTestConn(t)
defer db.Close()
_, err := db.Exec("CREATE TEMP TABLE temp (a int)")
if err != nil {
t.Fatal(err)
}
r, err := db.Exec("INSERT INTO temp VALUES (1)")
if err != nil {
t.Fatal(err)
}
if n, _ := r.RowsAffected(); n != 1 {
t.Fatalf("expected 1 row affected, not %d", n)
}
r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3)
if err != nil {
t.Fatal(err)
}
if n, _ := r.RowsAffected(); n != 3 {
t.Fatalf("expected 3 row affected, not %d", n)
}
}
func TestStatment(t *testing.T) {
db := openTestConn(t)
defer db.Close()
st, err := db.Prepare("SELECT 1")
if err != nil {
t.Fatal(err)
}
st1, err := db.Prepare("SELECT 2")
if err != nil {
t.Fatal(err)
}
r, err := st.Query()
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
t.Fatal("expected row")
}
var i int
err = r.Scan(&i)
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Fatalf("expected 1, got %d", i)
}
// st1
r1, err := st1.Query()
if err != nil {
t.Fatal(err)
}
defer r1.Close()
if !r1.Next() {
if r.Err() != nil {
t.Fatal(r1.Err())
}
t.Fatal("expected row")
}
err = r1.Scan(&i)
if err != nil {
t.Fatal(err)
}
if i != 2 {
t.Fatalf("expected 2, got %d", i)
}
}
func TestRowsCloseBeforeDone(t *testing.T) {
db := openTestConn(t)
defer db.Close()
r, err := db.Query("SELECT 1")
if err != nil {
t.Fatal(err)
}
err = r.Close()
if err != nil {
t.Fatal(err)
}
if r.Next() {
t.Fatal("unexpected row")
}
if r.Err() != nil {
t.Fatal(r.Err())
}
}
func TestEncodeDecode(t *testing.T) {
db := openTestConn(t)
defer db.Close()
q := `
SELECT
'\x000102'::bytea,
'foobar'::text,
NULL::integer,
'2000-1-1 01:02:03.04-7'::timestamptz,
0::boolean,
123,
3.14::float8
WHERE
'\x000102'::bytea = $1
AND 'foobar'::text = $2
AND $3::integer is NULL
`
// AND '2000-1-1 12:00:00.000000-7'::timestamp = $3
exp1 := []byte{0, 1, 2}
exp2 := "foobar"
r, err := db.Query(q, exp1, exp2, nil)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if !r.Next() {
if r.Err() != nil {
t.Fatal(r.Err())
}
t.Fatal("expected row")
}
var got1 []byte
var got2 string
var got3 = sql.NullInt64{Valid: true}
var got4 time.Time
var got5, got6, got7 interface{}
err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(exp1, got1) {
t.Errorf("expected %q byte: %q", exp1, got1)
}
if !reflect.DeepEqual(exp2, got2) {
t.Errorf("expected %q byte: %q", exp2, got2)
}
if got3.Valid {
t.Fatal("expected invalid")
}
if got4.Year() != 2000 {
t.Fatal("wrong year")
}
if got5 != false {
t.Fatalf("expected false, got %q", got5)
}
if got6 != int64(123) {
t.Fatalf("expected 123, got %d", got6)
}
if got7 != float64(3.14) {
t.Fatalf("expected 3.14, got %f", got7)
}
}
func TestNoData(t *testing.T) {
db := openTestConn(t)
defer db.Close()
st, err := db.Prepare("SELECT 1 WHERE true = false")
if err != nil {
t.Fatal(err)
}
defer st.Close()
r, err := st.Query()
if err != nil {
t.Fatal(err)
}
defer r.Close()
if r.Next() {
if r.Err() != nil {
t.Fatal(r.Err())
}
t.Fatal("unexpected row")
}
}
func TestPGError(t *testing.T) {
// Don't use the normal connection setup, this is intended to
// blow up in the startup packet from a non-existent user.
db, err := sql.Open("postgres", "user=thisuserreallydoesntexist")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Begin()
if err == nil {
t.Fatal("expected error")
}
if err, ok := err.(PGError); !ok {
t.Fatalf("expected a PGError, got: %v", err)
}
}
func TestBadConn(t *testing.T) {
var err error
func() {
defer errRecover(&err)
panic(io.EOF)
}()
if err != driver.ErrBadConn {
t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
}
func() {
defer errRecover(&err)
e := &pgError{c: make(map[byte]string)}
e.c['S'] = Efatal
panic(e)
}()
if err != driver.ErrBadConn {
t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
}
}
func TestErrorOnExec(t *testing.T) {
db := openTestConn(t)
defer db.Close()
sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;"
_, err := db.Exec(sql)
_, ok := err.(PGError)
if !ok {
t.Fatalf("expected PGError, was: %#v", err)
}
_, err = db.Exec("SELECT 1 WHERE true = false") // returns no rows
if err != nil {
t.Fatal(err)
}
}
func TestErrorOnQuery(t *testing.T) {
db := openTestConn(t)
defer db.Close()
sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;"
r, err := db.Query(sql)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if r.Next() {
t.Fatal("unexpected row, want error")
}
_, ok := r.Err().(PGError)
if !ok {
t.Fatalf("expected PGError, was: %#v", r.Err())
}
r, err = db.Query("SELECT 1 WHERE true = false") // returns no rows
if err != nil {
t.Fatal(err)
}
if r.Next() {
t.Fatal("unexpected row")
}
}
func TestBindError(t *testing.T) {
db := openTestConn(t)
defer db.Close()
_, err := db.Exec("create temp table test (i integer)")
if err != nil {
t.Fatal(err)
}
_, err = db.Query("select * from test where i=$1", "hhh")
if err == nil {
t.Fatal("expected an error")
}
// Should not get error here
r, err := db.Query("select * from test where i=$1", 1)
if err != nil {
t.Fatal(err)
}
defer r.Close()
}
func TestParseEnviron(t *testing.T) {
expected := map[string]string{"dbname": "hello", "user": "goodbye"}
results := parseEnviron([]string{"PGDATABASE=hello", "PGUSER=goodbye"})
if !reflect.DeepEqual(expected, results) {
t.Fatalf("Expected: %#v Got: %#v", expected, results)
}
}
func TestExecerInterface(t *testing.T) {
// Gin up a straw man private struct just for the type check
cn := &conn{c: nil}
var cni interface{} = cn
_, ok := cni.(driver.Execer)
if !ok {
t.Fatal("Driver doesn't implement Execer")
}
}
func TestNullAfterNonNull(t *testing.T) {
db := openTestConn(t)
defer db.Close()
r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer")
if err != nil {
t.Fatal(err)
}
var n sql.NullInt64
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected row")
}
if err := r.Scan(&n); err != nil {
t.Fatal(err)
}
if n.Int64 != 9 {
t.Fatalf("expected 2, not %d", n.Int64)
}
if !r.Next() {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected row")
}
if err := r.Scan(&n); err != nil {
t.Fatal(err)
}
if n.Valid {
t.Fatal("expected n to be invalid")
}
if n.Int64 != 0 {
t.Fatalf("expected n to 2, not %d", n.Int64)
}
}
// Stress test the performance of parsing results from the wire.
func BenchmarkResultParsing(b *testing.B) {
b.StopTimer()
db := openTestConn(b)
defer db.Close()
_, err := db.Exec("BEGIN")
if err != nil {
b.Fatal(err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
res, err := db.Query("SELECT generate_series(1, 50000)")
if err != nil {
b.Fatal(err)
}
res.Close()
}
}

View File

@ -1,108 +0,0 @@
package pq
import (
"database/sql/driver"
"encoding/hex"
"fmt"
"strconv"
"time"
)
func encode(x interface{}) []byte {
const timeFormat = "2006-01-02 15:04:05.0000-07"
switch v := x.(type) {
case int64:
return []byte(fmt.Sprintf("%d", v))
case float32, float64:
return []byte(fmt.Sprintf("%f", v))
case []byte:
return []byte(fmt.Sprintf("\\x%x", v))
case string:
return []byte(v)
case bool:
return []byte(fmt.Sprintf("%t", v))
case time.Time:
return []byte(v.Format(timeFormat))
default:
errorf("encode: unknown type for %T", v)
}
panic("not reached")
}
func decode(s []byte, typ int) interface{} {
switch typ {
case t_bytea:
s = s[2:] // trim off "\\x"
d := make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(d, s)
if err != nil {
errorf("%s", err)
}
return d
case t_timestamptz:
return mustParse("2006-01-02 15:04:05-07", s)
case t_timestamp:
return mustParse("2006-01-02 15:04:05", s)
case t_time:
return mustParse("15:04:05", s)
case t_timetz:
return mustParse("15:04:05-07", s)
case t_date:
return mustParse("2006-01-02", s)
case t_bool:
return s[0] == 't'
case t_int8, t_int2, t_int4:
i, err := strconv.ParseInt(string(s), 10, 64)
if err != nil {
errorf("%s", err)
}
return i
case t_float4, t_float8:
bits := 64
if typ == t_float4 {
bits = 32
}
f, err := strconv.ParseFloat(string(s), bits)
if err != nil {
errorf("%s", err)
}
return f
}
return s
}
func mustParse(f string, s []byte) time.Time {
str := string(s)
// Special case until time.Parse bug is fixed:
// http://code.google.com/p/go/issues/detail?id=3487
if str[len(str)-2] == '.' {
str += "0"
}
t, err := time.Parse(f, str)
if err != nil {
errorf("decode: %s", err)
}
return t
}
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}

View File

@ -1,26 +0,0 @@
package pq
import (
"testing"
"time"
)
func TestScanTimestamp(t *testing.T) {
var nt NullTime
tn := time.Now()
(&nt).Scan(tn)
if !nt.Valid {
t.Errorf("Expected Valid=false")
}
if nt.Time != tn {
t.Errorf("Time value mismatch")
}
}
func TestScanNilTimestamp(t *testing.T) {
var nt NullTime
(&nt).Scan(nil)
if nt.Valid {
t.Errorf("Expected Valid=false")
}
}

View File

@ -1,108 +0,0 @@
package pq
import (
"database/sql/driver"
"fmt"
"io"
"net"
"runtime"
)
const (
Efatal = "FATAL"
Epanic = "PANIC"
Ewarning = "WARNING"
Enotice = "NOTICE"
Edebug = "DEBUG"
Einfo = "INFO"
Elog = "LOG"
)
type Error error
type PGError interface {
Error() string
Fatal() bool
Get(k byte) (v string)
}
type pgError struct {
c map[byte]string
}
func parseError(r *readBuf) *pgError {
err := &pgError{make(map[byte]string)}
for t := r.byte(); t != 0; t = r.byte() {
err.c[t] = r.string()
}
return err
}
func (err *pgError) Get(k byte) (v string) {
v, _ = err.c[k]
return
}
func (err *pgError) Fatal() bool {
return err.Get('S') == Efatal
}
func (err *pgError) Error() string {
var s string
for k, v := range err.c {
s += fmt.Sprintf(" %c:%q", k, v)
}
return "pq: " + s[1:]
}
func errorf(s string, args ...interface{}) {
panic(Error(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))))
}
type SimplePGError struct {
pgError
}
func (err *SimplePGError) Error() string {
return "pq: " + err.Get('M')
}
func errRecoverWithPGReason(err *error) {
e := recover()
switch v := e.(type) {
case nil:
// Do nothing
case *pgError:
// Return a SimplePGError in place
*err = &SimplePGError{*v}
default:
// Otherwise re-panic
panic(e)
}
}
func errRecover(err *error) {
e := recover()
switch v := e.(type) {
case nil:
// Do nothing
case runtime.Error:
panic(v)
case *pgError:
if v.Fatal() {
*err = driver.ErrBadConn
} else {
*err = v
}
case *net.OpError:
*err = driver.ErrBadConn
case error:
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
*err = driver.ErrBadConn
} else {
*err = v
}
default:
panic(fmt.Sprintf("unknown error: %#v", e))
}
}

View File

@ -1,317 +0,0 @@
package pq
const (
t_bool = 16
t_bytea = 17
t_char = 18
t_name = 19
t_int8 = 20
t_int2 = 21
t_int2vector = 22
t_int4 = 23
t_regproc = 24
t_text = 25
t_oid = 26
t_tid = 27
t_xid = 28
t_cid = 29
t_oidvector = 30
t_pg_type = 71
t_pg_attribute = 75
t_pg_proc = 81
t_pg_class = 83
t_xml = 142
t__xml = 143
t_pg_node_tree = 194
t_smgr = 210
t_point = 600
t_lseg = 601
t_path = 602
t_box = 603
t_polygon = 604
t_line = 628
t__line = 629
t_float4 = 700
t_float8 = 701
t_abstime = 702
t_reltime = 703
t_tinterval = 704
t_unknown = 705
t_circle = 718
t__circle = 719
t_money = 790
t__money = 791
t_macaddr = 829
t_inet = 869
t_cidr = 650
t__bool = 1000
t__bytea = 1001
t__char = 1002
t__name = 1003
t__int2 = 1005
t__int2vector = 1006
t__int4 = 1007
t__regproc = 1008
t__text = 1009
t__oid = 1028
t__tid = 1010
t__xid = 1011
t__cid = 1012
t__oidvector = 1013
t__bpchar = 1014
t__varchar = 1015
t__int8 = 1016
t__point = 1017
t__lseg = 1018
t__path = 1019
t__box = 1020
t__float4 = 1021
t__float8 = 1022
t__abstime = 1023
t__reltime = 1024
t__tinterval = 1025
t__polygon = 1027
t_aclitem = 1033
t__aclitem = 1034
t__macaddr = 1040
t__inet = 1041
t__cidr = 651
t__cstring = 1263
t_bpchar = 1042
t_varchar = 1043
t_date = 1082
t_time = 1083
t_timestamp = 1114
t__timestamp = 1115
t__date = 1182
t__time = 1183
t_timestamptz = 1184
t__timestamptz = 1185
t_interval = 1186
t__interval = 1187
t__numeric = 1231
t_timetz = 1266
t__timetz = 1270
t_bit = 1560
t__bit = 1561
t_varbit = 1562
t__varbit = 1563
t_numeric = 1700
t_refcursor = 1790
t__refcursor = 2201
t_regprocedure = 2202
t_regoper = 2203
t_regoperator = 2204
t_regclass = 2205
t_regtype = 2206
t__regprocedure = 2207
t__regoper = 2208
t__regoperator = 2209
t__regclass = 2210
t__regtype = 2211
t_uuid = 2950
t__uuid = 2951
t_tsvector = 3614
t_gtsvector = 3642
t_tsquery = 3615
t_regconfig = 3734
t_regdictionary = 3769
t__tsvector = 3643
t__gtsvector = 3644
t__tsquery = 3645
t__regconfig = 3735
t__regdictionary = 3770
t_txid_snapshot = 2970
t__txid_snapshot = 2949
t_record = 2249
t__record = 2287
t_cstring = 2275
t_any = 2276
t_anyarray = 2277
t_void = 2278
t_trigger = 2279
t_language_handler = 2280
t_internal = 2281
t_opaque = 2282
t_anyelement = 2283
t_anynonarray = 2776
t_anyenum = 3500
t_fdw_handler = 3115
t_pg_attrdef = 10000
t_pg_constraint = 10001
t_pg_inherits = 10002
t_pg_index = 10003
t_pg_operator = 10004
t_pg_opfamily = 10005
t_pg_opclass = 10006
t_pg_am = 10117
t_pg_amop = 10118
t_pg_amproc = 10478
t_pg_language = 10731
t_pg_largeobject_metadata = 10732
t_pg_largeobject = 10733
t_pg_aggregate = 10734
t_pg_statistic = 10735
t_pg_rewrite = 10736
t_pg_trigger = 10737
t_pg_description = 10738
t_pg_cast = 10739
t_pg_enum = 10936
t_pg_namespace = 10937
t_pg_conversion = 10938
t_pg_depend = 10939
t_pg_database = 1248
t_pg_db_role_setting = 10940
t_pg_tablespace = 10941
t_pg_pltemplate = 10942
t_pg_authid = 2842
t_pg_auth_members = 2843
t_pg_shdepend = 10943
t_pg_shdescription = 10944
t_pg_ts_config = 10945
t_pg_ts_config_map = 10946
t_pg_ts_dict = 10947
t_pg_ts_parser = 10948
t_pg_ts_template = 10949
t_pg_extension = 10950
t_pg_foreign_data_wrapper = 10951
t_pg_foreign_server = 10952
t_pg_user_mapping = 10953
t_pg_foreign_table = 10954
t_pg_default_acl = 10955
t_pg_seclabel = 10956
t_pg_collation = 10957
t_pg_toast_2604 = 10958
t_pg_toast_2606 = 10959
t_pg_toast_2609 = 10960
t_pg_toast_1255 = 10961
t_pg_toast_2618 = 10962
t_pg_toast_3596 = 10963
t_pg_toast_2619 = 10964
t_pg_toast_2620 = 10965
t_pg_toast_1262 = 10966
t_pg_toast_2396 = 10967
t_pg_toast_2964 = 10968
t_pg_roles = 10970
t_pg_shadow = 10973
t_pg_group = 10976
t_pg_user = 10979
t_pg_rules = 10982
t_pg_views = 10986
t_pg_tables = 10989
t_pg_indexes = 10993
t_pg_stats = 10997
t_pg_locks = 11001
t_pg_cursors = 11004
t_pg_available_extensions = 11007
t_pg_available_extension_versions = 11010
t_pg_prepared_xacts = 11013
t_pg_prepared_statements = 11017
t_pg_seclabels = 11020
t_pg_settings = 11024
t_pg_timezone_abbrevs = 11029
t_pg_timezone_names = 11032
t_pg_stat_all_tables = 11035
t_pg_stat_xact_all_tables = 11039
t_pg_stat_sys_tables = 11043
t_pg_stat_xact_sys_tables = 11047
t_pg_stat_user_tables = 11050
t_pg_stat_xact_user_tables = 11054
t_pg_statio_all_tables = 11057
t_pg_statio_sys_tables = 11061
t_pg_statio_user_tables = 11064
t_pg_stat_all_indexes = 11067
t_pg_stat_sys_indexes = 11071
t_pg_stat_user_indexes = 11074
t_pg_statio_all_indexes = 11077
t_pg_statio_sys_indexes = 11081
t_pg_statio_user_indexes = 11084
t_pg_statio_all_sequences = 11087
t_pg_statio_sys_sequences = 11090
t_pg_statio_user_sequences = 11093
t_pg_stat_activity = 11096
t_pg_stat_replication = 11099
t_pg_stat_database = 11102
t_pg_stat_database_conflicts = 11105
t_pg_stat_user_functions = 11108
t_pg_stat_xact_user_functions = 11112
t_pg_stat_bgwriter = 11116
t_pg_user_mappings = 11119
t_cardinal_number = 11669
t_character_data = 11671
t_sql_identifier = 11672
t_information_schema_catalog_name = 11674
t_time_stamp = 11676
t_yes_or_no = 11677
t_applicable_roles = 11680
t_administrable_role_authorizations = 11684
t_attributes = 11687
t_character_sets = 11691
t_check_constraint_routine_usage = 11695
t_check_constraints = 11699
t_collations = 11703
t_collation_character_set_applicability = 11706
t_column_domain_usage = 11709
t_column_privileges = 11713
t_column_udt_usage = 11717
t_columns = 11721
t_constraint_column_usage = 11725
t_constraint_table_usage = 11729
t_domain_constraints = 11733
t_domain_udt_usage = 11737
t_domains = 11740
t_enabled_roles = 11744
t_key_column_usage = 11747
t_parameters = 11751
t_referential_constraints = 11755
t_role_column_grants = 11759
t_routine_privileges = 11762
t_role_routine_grants = 11766
t_routines = 11769
t_schemata = 11773
t_sequences = 11776
t_sql_features = 11780
t_pg_toast_11779 = 11782
t_sql_implementation_info = 11785
t_pg_toast_11784 = 11787
t_sql_languages = 11790
t_pg_toast_11789 = 11792
t_sql_packages = 11795
t_pg_toast_11794 = 11797
t_sql_parts = 11800
t_pg_toast_11799 = 11802
t_sql_sizing = 11805
t_pg_toast_11804 = 11807
t_sql_sizing_profiles = 11810
t_pg_toast_11809 = 11812
t_table_constraints = 11815
t_table_privileges = 11819
t_role_table_grants = 11823
t_tables = 11826
t_triggered_update_columns = 11830
t_triggers = 11834
t_usage_privileges = 11838
t_role_usage_grants = 11842
t_view_column_usage = 11845
t_view_routine_usage = 11849
t_view_table_usage = 11853
t_views = 11857
t_data_type_privileges = 11861
t_element_types = 11865
t__pg_foreign_data_wrappers = 11869
t_foreign_data_wrapper_options = 11872
t_foreign_data_wrappers = 11875
t__pg_foreign_servers = 11878
t_foreign_server_options = 11882
t_foreign_servers = 11885
t__pg_foreign_tables = 11888
t_foreign_table_options = 11892
t_foreign_tables = 11895
t__pg_user_mappings = 11898
t_user_mapping_options = 11901
t_user_mappings = 11905
t_t = 16806
t__t = 16805
t_temp = 16810
t__temp = 16809
)

View File

@ -1,68 +0,0 @@
package pq
import (
"fmt"
nurl "net/url"
"sort"
"strings"
)
// ParseURL converts url to a connection string for driver.Open.
// Example:
//
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
//
// A minimal example:
//
// "postgres://"
//
// This will be blank, causing driver.Open to use all of the defaults
func ParseURL(url string) (string, error) {
u, err := nurl.Parse(url)
if err != nil {
return "", err
}
if u.Scheme != "postgres" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+v)
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
}
q := u.Query()
for k, _ := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}

View File

@ -1,53 +0,0 @@
package pq
import (
"testing"
)
func TestSimpleParseURL(t *testing.T) {
expected := "host=hostname.remote"
str, err := ParseURL("postgres://hostname.remote")
if err != nil {
t.Fatal(err)
}
if str != expected {
t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected)
}
}
func TestFullParseURL(t *testing.T) {
expected := "dbname=database host=hostname.remote password=secret port=1234 user=username"
str, err := ParseURL("postgres://username:secret@hostname.remote:1234/database")
if err != nil {
t.Fatal(err)
}
if str != expected {
t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected)
}
}
func TestInvalidProtocolParseURL(t *testing.T) {
_, err := ParseURL("http://hostname.remote")
switch err {
case nil:
t.Fatal("Expected an error from parsing invalid protocol")
default:
msg := "invalid connection protocol: http"
if err.Error() != msg {
t.Fatal("Unexpected error message:\n+ %s\n- %s", err.Error(), msg)
}
}
}
func TestMinimalURL(t *testing.T) {
cs, err := ParseURL("postgres://")
if err != nil {
t.Fatal(err)
}
if cs != "" {
t.Fatalf("expected blank connection string, got: %q", cs)
}
}