mirror of https://github.com/perkeep/perkeep.git
Add github.com/lib/pq third party dependency.
Change-Id: I0c2240c1615b463adb57b618ca39a626e858dcc7
This commit is contained in:
parent
5d556aac1c
commit
4c3b49b2cf
|
@ -0,0 +1,8 @@
|
|||
Copyright (c) 2011-2013, 'pq' Contributors
|
||||
Portions Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,103 @@
|
|||
# pq - A pure Go postgres driver for Go's database/sql package
|
||||
|
||||
## Install
|
||||
|
||||
go get github.com/lib/pq
|
||||
|
||||
## Docs
|
||||
|
||||
<http://godoc.org/github.com/lib/pq>
|
||||
|
||||
## Use
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/lib/pq"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full")
|
||||
// ...
|
||||
}
|
||||
|
||||
**Connection String Parameters**
|
||||
|
||||
These are a subset of the libpq connection parameters. In addition, a
|
||||
number of the [environment
|
||||
variables](http://www.postgresql.org/docs/9.1/static/libpq-envars.html)
|
||||
supported by libpq are also supported. Just like libpq, these have
|
||||
lower precedence than explicitly provided connection parameters.
|
||||
|
||||
See http://www.postgresql.org/docs/9.1/static/libpq-connect.html.
|
||||
|
||||
* `dbname` - The name of the database to connect to
|
||||
* `user` - The user to sign in as
|
||||
* `password` - The user's password
|
||||
* `host` - The host to connect to. Values that start with `/` are for unix domain sockets. (default is `localhost`)
|
||||
* `port` - The port to bind to. (default is `5432`)
|
||||
* `sslmode` - Whether or not to use SSL (default is `require`, this is not the default for libpq)
|
||||
Valid values are:
|
||||
* `disable` - No SSL
|
||||
* `require` - Always SSL (skip verification)
|
||||
* `verify-full` - Always SSL (require verification)
|
||||
|
||||
See http://golang.org/pkg/database/sql to learn how to use with `pq` through the `database/sql` package.
|
||||
|
||||
## Tests
|
||||
|
||||
`go test` is used for testing. A running PostgreSQL server is
|
||||
required, with the ability to log in. The default database to connect
|
||||
to test with is "pqgotest," but it can be overridden using environment
|
||||
variables.
|
||||
|
||||
Example:
|
||||
|
||||
PGHOST=/var/run/postgresql go test pq
|
||||
|
||||
## Features
|
||||
|
||||
* SSL
|
||||
* Handles bad connections for `database/sql`
|
||||
* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`)
|
||||
* Scan binary blobs correctly (i.e. `bytea`)
|
||||
* pq.ParseURL for converting urls to connection strings for sql.Open.
|
||||
* Many libpq compatible environment variables
|
||||
* Unix socket support
|
||||
|
||||
## Future / Things you can help with
|
||||
|
||||
* Notifications: `LISTEN`/`NOTIFY`
|
||||
* `hstore` sugar (i.e. handling hstore in `rows.Scan`)
|
||||
|
||||
## Thank you (alphabetical)
|
||||
|
||||
Some of these contributors are from the original library `bmizerany/pq.go` whose
|
||||
code still exists in here.
|
||||
|
||||
* Andy Balholm (andybalholm)
|
||||
* Ben Berkert (benburkert)
|
||||
* Bill Mill (llimllib)
|
||||
* Bjørn Madsen (aeons)
|
||||
* Blake Gentry (bgentry)
|
||||
* Brad Fitzpatrick (bradfitz)
|
||||
* Chris Walsh (cwds)
|
||||
* Daniel Farina (fdr)
|
||||
* Everyone at The Go Team
|
||||
* Ewan Chou (coocood)
|
||||
* Federico Romero (federomero)
|
||||
* Gary Burd (garyburd)
|
||||
* Heroku (heroku)
|
||||
* Jason McVetta (jmcvetta)
|
||||
* Joakim Sernbrant (serbaut)
|
||||
* John Gallagher (jgallagher)
|
||||
* Kamil Kisiel (kisielk)
|
||||
* Keith Rarick (kr)
|
||||
* Maciek Sakrejda (deafbybeheading)
|
||||
* Marc Brinkmann (mbr)
|
||||
* Martin Olsen (martinolsen)
|
||||
* Mike Lewis (mikelikespie)
|
||||
* Ryan Smith (ryandotsmith)
|
||||
* Samuel Stauffer (samuel)
|
||||
* notedit (notedit)
|
|
@ -0,0 +1,81 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"camlistore.org/third_party/github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
type readBuf []byte
|
||||
|
||||
func (b *readBuf) int32() (n int) {
|
||||
n = int(int32(binary.BigEndian.Uint32(*b)))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) oid() (n oid.Oid) {
|
||||
n = oid.Oid(binary.BigEndian.Uint32(*b))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) int16() (n int) {
|
||||
n = int(binary.BigEndian.Uint16(*b))
|
||||
*b = (*b)[2:]
|
||||
return
|
||||
}
|
||||
|
||||
var stringTerm = []byte{0}
|
||||
|
||||
func (b *readBuf) string() string {
|
||||
i := bytes.Index(*b, stringTerm)
|
||||
if i < 0 {
|
||||
errorf("invalid message format; expected string terminator")
|
||||
}
|
||||
s := (*b)[:i]
|
||||
*b = (*b)[i+1:]
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (b *readBuf) next(n int) (v []byte) {
|
||||
v = (*b)[:n]
|
||||
*b = (*b)[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) byte() byte {
|
||||
return b.next(1)[0]
|
||||
}
|
||||
|
||||
type writeBuf []byte
|
||||
|
||||
func newWriteBuf(c byte) *writeBuf {
|
||||
b := make(writeBuf, 5)
|
||||
b[0] = c
|
||||
return &b
|
||||
}
|
||||
|
||||
func (b *writeBuf) int32(n int) {
|
||||
x := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(x, uint32(n))
|
||||
*b = append(*b, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) int16(n int) {
|
||||
x := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(x, uint16(n))
|
||||
*b = append(*b, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) string(s string) {
|
||||
*b = append(*b, (s + "\000")...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) byte(c byte) {
|
||||
*b = append(*b, c)
|
||||
}
|
||||
|
||||
func (b *writeBuf) bytes(v []byte) {
|
||||
*b = append(*b, v...)
|
||||
}
|
|
@ -0,0 +1,684 @@
|
|||
// Package pq is a pure Go Postgres driver for the database/sql package.
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"camlistore.org/third_party/github.com/lib/pq/oid"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
|
||||
ErrNotSupported = errors.New("pq: invalid command")
|
||||
)
|
||||
|
||||
type drv struct{}
|
||||
|
||||
func (d *drv) Open(name string) (driver.Conn, error) {
|
||||
return Open(name)
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("postgres", &drv{})
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
c net.Conn
|
||||
buf *bufio.Reader
|
||||
namei int
|
||||
}
|
||||
|
||||
func Open(name string) (_ driver.Conn, err error) {
|
||||
defer errRecover(&err)
|
||||
defer errRecoverWithPGReason(&err)
|
||||
|
||||
o := make(Values)
|
||||
|
||||
// A number of defaults are applied here, in this order:
|
||||
//
|
||||
// * Very low precedence defaults applied in every situation
|
||||
// * Environment variables
|
||||
// * Explicitly passed connection information
|
||||
o.Set("host", "localhost")
|
||||
o.Set("port", "5432")
|
||||
|
||||
for k, v := range parseEnviron(os.Environ()) {
|
||||
o.Set(k, v)
|
||||
}
|
||||
|
||||
parseOpts(name, o)
|
||||
|
||||
// If a user is not provided by any other means, the last
|
||||
// resort is to use the current operating system provided user
|
||||
// name.
|
||||
if o.Get("user") == "" {
|
||||
u, err := userCurrent()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
o.Set("user", u)
|
||||
}
|
||||
}
|
||||
|
||||
c, err := net.Dial(network(o))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cn := &conn{c: c}
|
||||
cn.ssl(o)
|
||||
cn.buf = bufio.NewReader(cn.c)
|
||||
cn.startup(o)
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func network(o Values) (string, string) {
|
||||
host := o.Get("host")
|
||||
|
||||
if strings.HasPrefix(host, "/") {
|
||||
sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
|
||||
return "unix", sockPath
|
||||
}
|
||||
|
||||
return "tcp", host + ":" + o.Get("port")
|
||||
}
|
||||
|
||||
type Values map[string]string
|
||||
|
||||
func (vs Values) Set(k, v string) {
|
||||
vs[k] = v
|
||||
}
|
||||
|
||||
func (vs Values) Get(k string) (v string) {
|
||||
v, _ = vs[k]
|
||||
return
|
||||
}
|
||||
|
||||
func parseOpts(name string, o Values) {
|
||||
if len(name) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
name = strings.TrimSpace(name)
|
||||
|
||||
ps := strings.Split(name, " ")
|
||||
for _, p := range ps {
|
||||
kv := strings.Split(p, "=")
|
||||
if len(kv) < 2 {
|
||||
errorf("invalid option: %q", p)
|
||||
}
|
||||
o.Set(kv[0], kv[1])
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) Begin() (driver.Tx, error) {
|
||||
_, err := cn.Exec("BEGIN", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cn, err
|
||||
}
|
||||
|
||||
func (cn *conn) Commit() error {
|
||||
_, err := cn.Exec("COMMIT", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cn *conn) Rollback() error {
|
||||
_, err := cn.Exec("ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cn *conn) gname() string {
|
||||
cn.namei++
|
||||
return strconv.FormatInt(int64(cn.namei), 10)
|
||||
}
|
||||
|
||||
func (cn *conn) simpleQuery(q string) (res driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
b := newWriteBuf('Q')
|
||||
b.string(q)
|
||||
cn.send(b)
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'C':
|
||||
res = parseComplete(r.string())
|
||||
case 'Z':
|
||||
// done
|
||||
return
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'T', 'N', 'S', 'D':
|
||||
// ignore
|
||||
default:
|
||||
errorf("unknown response for simple query: %q", t)
|
||||
}
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
st := &stmt{cn: cn, name: stmtName, query: q}
|
||||
|
||||
b := newWriteBuf('P')
|
||||
b.string(st.name)
|
||||
b.string(q)
|
||||
b.int16(0)
|
||||
cn.send(b)
|
||||
|
||||
b = newWriteBuf('D')
|
||||
b.byte('S')
|
||||
b.string(st.name)
|
||||
cn.send(b)
|
||||
|
||||
cn.send(newWriteBuf('S'))
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case '1', '2', 'N':
|
||||
case 't':
|
||||
st.nparams = int(r.int16())
|
||||
st.paramTyps = make([]oid.Oid, st.nparams, st.nparams)
|
||||
|
||||
for i := 0; i < st.nparams; i += 1 {
|
||||
st.paramTyps[i] = r.oid()
|
||||
}
|
||||
case 'T':
|
||||
n := r.int16()
|
||||
st.cols = make([]string, n)
|
||||
st.rowTyps = make([]oid.Oid, n)
|
||||
for i := range st.cols {
|
||||
st.cols[i] = r.string()
|
||||
r.next(6)
|
||||
st.rowTyps[i] = r.oid()
|
||||
r.next(8)
|
||||
}
|
||||
case 'n':
|
||||
// no data
|
||||
case 'Z':
|
||||
return st, err
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C':
|
||||
// command complete
|
||||
return st, err
|
||||
default:
|
||||
errorf("unexpected describe rows response: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
|
||||
return cn.prepareTo(q, cn.gname())
|
||||
}
|
||||
|
||||
func (cn *conn) Close() (err error) {
|
||||
defer errRecover(&err)
|
||||
cn.send(newWriteBuf('X'))
|
||||
|
||||
return cn.c.Close()
|
||||
}
|
||||
|
||||
// Implement the optional "Execer" interface for one-shot queries
|
||||
func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
// Check to see if we can use the "simpleQuery" interface, which is
|
||||
// *much* faster than going through prepare/exec
|
||||
if len(args) == 0 {
|
||||
return cn.simpleQuery(query)
|
||||
}
|
||||
|
||||
// Use the unnamed statement to defer planning until bind
|
||||
// time, or else value-based selectivity estimates cannot be
|
||||
// used.
|
||||
st, err := cn.prepareTo(query, "")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r, err := st.Exec(args)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Assumes len(*m) is > 5
|
||||
func (cn *conn) send(m *writeBuf) {
|
||||
b := (*m)[1:]
|
||||
binary.BigEndian.PutUint32(b, uint32(len(b)))
|
||||
|
||||
if (*m)[0] == 0 {
|
||||
*m = b
|
||||
}
|
||||
|
||||
_, err := cn.c.Write(*m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) recv() (t byte, r *readBuf) {
|
||||
for {
|
||||
t, r = cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
panic(parseError(r))
|
||||
case 'N':
|
||||
// ignore
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) recv1() (byte, *readBuf) {
|
||||
x := make([]byte, 5)
|
||||
_, err := io.ReadFull(cn.buf, x)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b := readBuf(x[1:])
|
||||
y := make([]byte, b.int32()-4)
|
||||
_, err = io.ReadFull(cn.buf, y)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return x[0], (*readBuf)(&y)
|
||||
}
|
||||
|
||||
func (cn *conn) ssl(o Values) {
|
||||
tlsConf := tls.Config{}
|
||||
switch mode := o.Get("sslmode"); mode {
|
||||
case "require", "":
|
||||
tlsConf.InsecureSkipVerify = true
|
||||
case "verify-full":
|
||||
// fall out
|
||||
case "disable":
|
||||
return
|
||||
default:
|
||||
errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode)
|
||||
}
|
||||
|
||||
w := newWriteBuf(0)
|
||||
w.int32(80877103)
|
||||
cn.send(w)
|
||||
|
||||
b := make([]byte, 1)
|
||||
_, err := io.ReadFull(cn.c, b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if b[0] != 'S' {
|
||||
panic(ErrSSLNotSupported)
|
||||
}
|
||||
|
||||
cn.c = tls.Client(cn.c, &tlsConf)
|
||||
}
|
||||
|
||||
func (cn *conn) startup(o Values) {
|
||||
w := newWriteBuf(0)
|
||||
w.int32(196608)
|
||||
w.string("user")
|
||||
w.string(o.Get("user"))
|
||||
w.string("database")
|
||||
w.string(o.Get("dbname"))
|
||||
w.string("")
|
||||
cn.send(w)
|
||||
|
||||
for {
|
||||
t, r := cn.recv()
|
||||
switch t {
|
||||
case 'K', 'S':
|
||||
case 'R':
|
||||
cn.auth(r, o)
|
||||
case 'Z':
|
||||
return
|
||||
default:
|
||||
errorf("unknown response for startup: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) auth(r *readBuf, o Values) {
|
||||
switch code := r.int32(); code {
|
||||
case 0:
|
||||
// OK
|
||||
case 3:
|
||||
w := newWriteBuf('p')
|
||||
w.string(o.Get("password"))
|
||||
cn.send(w)
|
||||
|
||||
t, r := cn.recv()
|
||||
if t != 'R' {
|
||||
errorf("unexpected password response: %q", t)
|
||||
}
|
||||
|
||||
if r.int32() != 0 {
|
||||
errorf("unexpected authentication response: %q", t)
|
||||
}
|
||||
case 5:
|
||||
s := string(r.next(4))
|
||||
w := newWriteBuf('p')
|
||||
w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
|
||||
cn.send(w)
|
||||
|
||||
t, r := cn.recv()
|
||||
if t != 'R' {
|
||||
errorf("unexpected password response: %q", t)
|
||||
}
|
||||
|
||||
if r.int32() != 0 {
|
||||
errorf("unexpected authentication response: %q", t)
|
||||
}
|
||||
default:
|
||||
errorf("unknown authentication response: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
cn *conn
|
||||
name string
|
||||
query string
|
||||
cols []string
|
||||
nparams int
|
||||
rowTyps []oid.Oid
|
||||
paramTyps []oid.Oid
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (st *stmt) Close() (err error) {
|
||||
if st.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer errRecover(&err)
|
||||
|
||||
w := newWriteBuf('C')
|
||||
w.byte('S')
|
||||
w.string(st.name)
|
||||
st.cn.send(w)
|
||||
|
||||
st.cn.send(newWriteBuf('S'))
|
||||
|
||||
t, _ := st.cn.recv()
|
||||
if t != '3' {
|
||||
errorf("unexpected close response: %q", t)
|
||||
}
|
||||
st.closed = true
|
||||
|
||||
t, _ = st.cn.recv()
|
||||
if t != 'Z' {
|
||||
errorf("expected ready for query, but got: %q", t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) {
|
||||
defer errRecover(&err)
|
||||
st.exec(v)
|
||||
return &rows{st: st}, nil
|
||||
}
|
||||
|
||||
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
if len(v) == 0 {
|
||||
return st.cn.simpleQuery(st.query)
|
||||
}
|
||||
st.exec(v)
|
||||
|
||||
for {
|
||||
t, r := st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C':
|
||||
res = parseComplete(r.string())
|
||||
case 'Z':
|
||||
// done
|
||||
return
|
||||
case 'T', 'N', 'S', 'D':
|
||||
// Ignore
|
||||
default:
|
||||
errorf("unknown exec response: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (st *stmt) exec(v []driver.Value) {
|
||||
w := newWriteBuf('B')
|
||||
w.string("")
|
||||
w.string(st.name)
|
||||
w.int16(0)
|
||||
w.int16(len(v))
|
||||
for i, x := range v {
|
||||
if x == nil {
|
||||
w.int32(-1)
|
||||
} else {
|
||||
b := encode(x, st.paramTyps[i])
|
||||
w.int32(len(b))
|
||||
w.bytes(b)
|
||||
}
|
||||
}
|
||||
w.int16(0)
|
||||
st.cn.send(w)
|
||||
|
||||
w = newWriteBuf('E')
|
||||
w.string("")
|
||||
w.int32(0)
|
||||
st.cn.send(w)
|
||||
|
||||
st.cn.send(newWriteBuf('S'))
|
||||
|
||||
var err error
|
||||
for {
|
||||
t, r := st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case '2':
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
case 'Z':
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
case 'N':
|
||||
// ignore
|
||||
default:
|
||||
errorf("unexpected bind response: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (st *stmt) NumInput() int {
|
||||
return st.nparams
|
||||
}
|
||||
|
||||
type result int64
|
||||
|
||||
func (i result) RowsAffected() (int64, error) {
|
||||
return int64(i), nil
|
||||
}
|
||||
|
||||
func (i result) LastInsertId() (int64, error) {
|
||||
return 0, ErrNotSupported
|
||||
}
|
||||
|
||||
func parseComplete(s string) driver.Result {
|
||||
parts := strings.Split(s, " ")
|
||||
n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64)
|
||||
return result(n)
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
st *stmt
|
||||
done bool
|
||||
}
|
||||
|
||||
func (rs *rows) Close() error {
|
||||
for {
|
||||
err := rs.Next(nil)
|
||||
switch err {
|
||||
case nil:
|
||||
case io.EOF:
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (rs *rows) Columns() []string {
|
||||
return rs.st.cols
|
||||
}
|
||||
|
||||
func (rs *rows) Next(dest []driver.Value) (err error) {
|
||||
if rs.done {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
defer errRecover(&err)
|
||||
|
||||
for {
|
||||
t, r := rs.st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C', 'S', 'N':
|
||||
continue
|
||||
case 'Z':
|
||||
rs.done = true
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
case 'D':
|
||||
n := r.int16()
|
||||
for i := 0; i < len(dest) && i < n; i++ {
|
||||
l := r.int32()
|
||||
if l == -1 {
|
||||
dest[i] = nil
|
||||
continue
|
||||
}
|
||||
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
|
||||
}
|
||||
return
|
||||
default:
|
||||
errorf("unexpected message after execute: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func md5s(s string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(s))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// parseEnviron tries to mimic some of libpq's environment handling
|
||||
//
|
||||
// To ease testing, it does not directly reference os.Environ, but is
|
||||
// designed to accept its output.
|
||||
//
|
||||
// Environment-set connection information is intended to have a higher
|
||||
// precedence than a library default but lower than any explicitly
|
||||
// passed information (such as in the URL or connection string).
|
||||
func parseEnviron(env []string) (out map[string]string) {
|
||||
out = make(map[string]string)
|
||||
|
||||
for _, v := range env {
|
||||
parts := strings.SplitN(v, "=", 2)
|
||||
|
||||
accrue := func(keyname string) {
|
||||
out[keyname] = parts[1]
|
||||
}
|
||||
|
||||
// The order of these is the same as is seen in the
|
||||
// PostgreSQL 9.1 manual, with omissions briefly
|
||||
// noted.
|
||||
switch parts[0] {
|
||||
case "PGHOST":
|
||||
accrue("host")
|
||||
case "PGHOSTADDR":
|
||||
accrue("hostaddr")
|
||||
case "PGPORT":
|
||||
accrue("port")
|
||||
case "PGDATABASE":
|
||||
accrue("dbname")
|
||||
case "PGUSER":
|
||||
accrue("user")
|
||||
case "PGPASSWORD":
|
||||
accrue("password")
|
||||
// skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
|
||||
// PGREALM
|
||||
case "PGOPTIONS":
|
||||
accrue("options")
|
||||
case "PGAPPNAME":
|
||||
accrue("application_name")
|
||||
case "PGSSLMODE":
|
||||
accrue("sslmode")
|
||||
case "PGREQUIRESSL":
|
||||
accrue("requiressl")
|
||||
case "PGSSLCERT":
|
||||
accrue("sslcert")
|
||||
case "PGSSLKEY":
|
||||
accrue("sslkey")
|
||||
case "PGSSLROOTCERT":
|
||||
accrue("sslrootcert")
|
||||
case "PGSSLCRL":
|
||||
accrue("sslcrl")
|
||||
case "PGREQUIREPEER":
|
||||
accrue("requirepeer")
|
||||
case "PGKRBSRVNAME":
|
||||
accrue("krbsrvname")
|
||||
case "PGGSSLIB":
|
||||
accrue("gsslib")
|
||||
case "PGCONNECT_TIMEOUT":
|
||||
accrue("connect_timeout")
|
||||
case "PGCLIENTENCODING":
|
||||
accrue("client_encoding")
|
||||
// skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
|
||||
// PGLOCALEDIR
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
|
@ -0,0 +1,528 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Fatalistic interface {
|
||||
Fatal(args ...interface{})
|
||||
}
|
||||
|
||||
func openTestConn(t Fatalistic) *sql.DB {
|
||||
datname := os.Getenv("PGDATABASE")
|
||||
sslmode := os.Getenv("PGSSLMODE")
|
||||
|
||||
if datname == "" {
|
||||
os.Setenv("PGDATABASE", "pqgotest")
|
||||
}
|
||||
|
||||
if sslmode == "" {
|
||||
os.Setenv("PGSSLMODE", "disable")
|
||||
}
|
||||
|
||||
conn, err := sql.Open("postgres", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec("CREATE TEMP TABLE temp (a int)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := db.Exec("INSERT INTO temp VALUES (1)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if n, _ := r.RowsAffected(); n != 1 {
|
||||
t.Fatalf("expected 1 row affected, not %d", n)
|
||||
}
|
||||
|
||||
r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if n, _ := r.RowsAffected(); n != 3 {
|
||||
t.Fatalf("expected 3 rows affected, not %d", n)
|
||||
}
|
||||
|
||||
r, err = db.Exec("SELECT g FROM generate_series(1, 2) g")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n, _ := r.RowsAffected(); n != 2 {
|
||||
t.Fatalf("expected 2 rows affected, not %d", n)
|
||||
}
|
||||
|
||||
r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n, _ := r.RowsAffected(); n != 3 {
|
||||
t.Fatalf("expected 3 rows affected, not %d", n)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestStatment(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
st, err := db.Prepare("SELECT 1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
st1, err := db.Prepare("SELECT 2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
t.Fatal("expected row")
|
||||
}
|
||||
|
||||
var i int
|
||||
err = r.Scan(&i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if i != 1 {
|
||||
t.Fatalf("expected 1, got %d", i)
|
||||
}
|
||||
|
||||
// st1
|
||||
|
||||
r1, err := st1.Query()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r1.Close()
|
||||
|
||||
if !r1.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(r1.Err())
|
||||
}
|
||||
t.Fatal("expected row")
|
||||
}
|
||||
|
||||
err = r1.Scan(&i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if i != 2 {
|
||||
t.Fatalf("expected 2, got %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsCloseBeforeDone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
r, err := db.Query("SELECT 1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = r.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Next() {
|
||||
t.Fatal("unexpected row")
|
||||
}
|
||||
|
||||
if r.Err() != nil {
|
||||
t.Fatal(r.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecode(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
'\x000102'::bytea,
|
||||
'foobar'::text,
|
||||
NULL::integer,
|
||||
'2000-1-1 01:02:03.04-7'::timestamptz,
|
||||
0::boolean,
|
||||
123,
|
||||
3.14::float8
|
||||
WHERE
|
||||
'\x000102'::bytea = $1
|
||||
AND 'foobar'::text = $2
|
||||
AND $3::integer is NULL
|
||||
`
|
||||
// AND '2000-1-1 12:00:00.000000-7'::timestamp = $3
|
||||
|
||||
exp1 := []byte{0, 1, 2}
|
||||
exp2 := "foobar"
|
||||
|
||||
r, err := db.Query(q, exp1, exp2, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(r.Err())
|
||||
}
|
||||
t.Fatal("expected row")
|
||||
}
|
||||
|
||||
var got1 []byte
|
||||
var got2 string
|
||||
var got3 = sql.NullInt64{Valid: true}
|
||||
var got4 time.Time
|
||||
var got5, got6, got7 interface{}
|
||||
|
||||
err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(exp1, got1) {
|
||||
t.Errorf("expected %q byte: %q", exp1, got1)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(exp2, got2) {
|
||||
t.Errorf("expected %q byte: %q", exp2, got2)
|
||||
}
|
||||
|
||||
if got3.Valid {
|
||||
t.Fatal("expected invalid")
|
||||
}
|
||||
|
||||
if got4.Year() != 2000 {
|
||||
t.Fatal("wrong year")
|
||||
}
|
||||
|
||||
if got5 != false {
|
||||
t.Fatalf("expected false, got %q", got5)
|
||||
}
|
||||
|
||||
if got6 != int64(123) {
|
||||
t.Fatalf("expected 123, got %d", got6)
|
||||
}
|
||||
|
||||
if got7 != float64(3.14) {
|
||||
t.Fatalf("expected 3.14, got %f", got7)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoData(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
st, err := db.Prepare("SELECT 1 WHERE true = false")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(r.Err())
|
||||
}
|
||||
t.Fatal("unexpected row")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPGError(t *testing.T) {
|
||||
// Don't use the normal connection setup, this is intended to
|
||||
// blow up in the startup packet from a non-existent user.
|
||||
db, err := sql.Open("postgres", "user=thisuserreallydoesntexist")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Begin()
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
if err, ok := err.(PGError); !ok {
|
||||
t.Fatalf("expected a PGError, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadConn(t *testing.T) {
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer errRecover(&err)
|
||||
panic(io.EOF)
|
||||
}()
|
||||
|
||||
if err != driver.ErrBadConn {
|
||||
t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
|
||||
}
|
||||
|
||||
func() {
|
||||
defer errRecover(&err)
|
||||
e := &pgError{c: make(map[byte]string)}
|
||||
e.c['S'] = Efatal
|
||||
panic(e)
|
||||
}()
|
||||
|
||||
if err != driver.ErrBadConn {
|
||||
t.Fatalf("expected driver.ErrBadConn, got: %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorOnExec(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;"
|
||||
_, err := db.Exec(sql)
|
||||
_, ok := err.(PGError)
|
||||
if !ok {
|
||||
t.Fatalf("expected PGError, was: %#v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("SELECT 1 WHERE true = false") // returns no rows
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorOnQuery(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;"
|
||||
r, err := db.Query(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if r.Next() {
|
||||
t.Fatal("unexpected row, want error")
|
||||
}
|
||||
|
||||
_, ok := r.Err().(PGError)
|
||||
if !ok {
|
||||
t.Fatalf("expected PGError, was: %#v", r.Err())
|
||||
}
|
||||
|
||||
r, err = db.Query("SELECT 1 WHERE true = false") // returns no rows
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Next() {
|
||||
t.Fatal("unexpected row")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindError(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec("create temp table test (i integer)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Query("select * from test where i=$1", "hhh")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
|
||||
// Should not get error here
|
||||
r, err := db.Query("select * from test where i=$1", 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
}
|
||||
|
||||
func TestParseEnviron(t *testing.T) {
|
||||
expected := map[string]string{"dbname": "hello", "user": "goodbye"}
|
||||
results := parseEnviron([]string{"PGDATABASE=hello", "PGUSER=goodbye"})
|
||||
if !reflect.DeepEqual(expected, results) {
|
||||
t.Fatalf("Expected: %#v Got: %#v", expected, results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecerInterface(t *testing.T) {
|
||||
// Gin up a straw man private struct just for the type check
|
||||
cn := &conn{c: nil}
|
||||
var cni interface{} = cn
|
||||
|
||||
_, ok := cni.(driver.Execer)
|
||||
if !ok {
|
||||
t.Fatal("Driver doesn't implement Execer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullAfterNonNull(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n sql.NullInt64
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected row")
|
||||
}
|
||||
|
||||
if err := r.Scan(&n); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if n.Int64 != 9 {
|
||||
t.Fatalf("expected 2, not %d", n.Int64)
|
||||
}
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected row")
|
||||
}
|
||||
|
||||
if err := r.Scan(&n); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if n.Valid {
|
||||
t.Fatal("expected n to be invalid")
|
||||
}
|
||||
|
||||
if n.Int64 != 0 {
|
||||
t.Fatalf("expected n to 2, not %d", n.Int64)
|
||||
}
|
||||
}
|
||||
|
||||
// Stress test the performance of parsing results from the wire.
|
||||
func BenchmarkResultParsing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
_, err := db.Exec("BEGIN")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, err := db.Query("SELECT generate_series(1, 50000)")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
res.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func Test64BitErrorChecking(t *testing.T) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
t.Fatal("panic due to 0xFFFFFFFF != -1 " +
|
||||
"when int is 64 bits")
|
||||
}
|
||||
}()
|
||||
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
r, err := db.Query(`SELECT *
|
||||
FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer r.Close()
|
||||
|
||||
for r.Next() {
|
||||
}
|
||||
}
|
||||
|
||||
// Open transaction, issue INSERT query inside transaction, rollback
|
||||
// transaction, issue SELECT query to same db used to create the tx. No rows
|
||||
// should be returned.
|
||||
func TestRollback(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec("CREATE TEMP TABLE temp (a int)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sqlInsert := "INSERT INTO temp VALUES (1)"
|
||||
sqlSelect := "SELECT * FROM temp"
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = tx.Query(sqlInsert)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, err := db.Query(sqlSelect)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Next() returns false if query returned no rows.
|
||||
if r.Next() {
|
||||
t.Fatal("Transaction rollback failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnTrailingSpace(t *testing.T) {
|
||||
o := make(Values)
|
||||
expected := Values{"dbname": "hello", "user": "goodbye"}
|
||||
parseOpts("dbname=hello user=goodbye ", o)
|
||||
if !reflect.DeepEqual(expected, o) {
|
||||
t.Fatalf("Expected: %#v Got: %#v", expected, o)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"camlistore.org/third_party/github.com/lib/pq/oid"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func encode(x interface{}, pgtypOid oid.Oid) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return []byte(fmt.Sprintf("%d", v))
|
||||
case float32, float64:
|
||||
return []byte(fmt.Sprintf("%f", v))
|
||||
case []byte:
|
||||
if pgtypOid == oid.T_bytea {
|
||||
return []byte(fmt.Sprintf("\\x%x", v))
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
if pgtypOid == oid.T_bytea {
|
||||
return []byte(fmt.Sprintf("\\x%x", v))
|
||||
}
|
||||
|
||||
return []byte(v)
|
||||
case bool:
|
||||
return []byte(fmt.Sprintf("%t", v))
|
||||
case time.Time:
|
||||
return []byte(v.Format(time.RFC3339Nano))
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func decode(s []byte, typ oid.Oid) interface{} {
|
||||
switch typ {
|
||||
case oid.T_bytea:
|
||||
s = s[2:] // trim off "\\x"
|
||||
d := make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(d, s)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return d
|
||||
case oid.T_timestamptz:
|
||||
return mustParse("2006-01-02 15:04:05-07", typ, s)
|
||||
case oid.T_timestamp:
|
||||
return mustParse("2006-01-02 15:04:05", typ, s)
|
||||
case oid.T_time:
|
||||
return mustParse("15:04:05", typ, s)
|
||||
case oid.T_timetz:
|
||||
return mustParse("15:04:05-07", typ, s)
|
||||
case oid.T_date:
|
||||
return mustParse("2006-01-02", typ, s)
|
||||
case oid.T_bool:
|
||||
return s[0] == 't'
|
||||
case oid.T_int8, oid.T_int2, oid.T_int4:
|
||||
i, err := strconv.ParseInt(string(s), 10, 64)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return i
|
||||
case oid.T_float4, oid.T_float8:
|
||||
bits := 64
|
||||
if typ == oid.T_float4 {
|
||||
bits = 32
|
||||
}
|
||||
f, err := strconv.ParseFloat(string(s), bits)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func mustParse(f string, typ oid.Oid, s []byte) time.Time {
|
||||
str := string(s)
|
||||
|
||||
// Special case until time.Parse bug is fixed:
|
||||
// http://code.google.com/p/go/issues/detail?id=3487
|
||||
if str[len(str)-2] == '.' {
|
||||
str += "0"
|
||||
}
|
||||
|
||||
// check for a 30-minute-offset timezone
|
||||
if (typ == oid.T_timestamptz || typ == oid.T_timetz) &&
|
||||
str[len(str)-3] == ':' {
|
||||
f += ":00"
|
||||
}
|
||||
t, err := time.Parse(f, str)
|
||||
if err != nil {
|
||||
errorf("decode: %s", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
nt.Time, nt.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
|
@ -0,0 +1,164 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestScanTimestamp(t *testing.T) {
|
||||
var nt NullTime
|
||||
tn := time.Now()
|
||||
(&nt).Scan(tn)
|
||||
if !nt.Valid {
|
||||
t.Errorf("Expected Valid=false")
|
||||
}
|
||||
if nt.Time != tn {
|
||||
t.Errorf("Time value mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNilTimestamp(t *testing.T) {
|
||||
var nt NullTime
|
||||
(&nt).Scan(nil)
|
||||
if nt.Valid {
|
||||
t.Errorf("Expected Valid=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampWithTimeZone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create temp table test (t timestamp with time zone)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// try several different locations, all included in Go's zoneinfo.zip
|
||||
for _, locName := range []string{
|
||||
"UTC",
|
||||
"America/Chicago",
|
||||
"America/New_York",
|
||||
"Australia/Darwin",
|
||||
"Australia/Perth",
|
||||
} {
|
||||
loc, err := time.LoadLocation(locName)
|
||||
if err != nil {
|
||||
t.Logf("Could not load time zone %s - skipping", locName)
|
||||
continue
|
||||
}
|
||||
|
||||
// Postgres timestamps have a resolution of 1 microsecond, so don't
|
||||
// use the full range of the Nanosecond argument
|
||||
refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc)
|
||||
_, err = tx.Exec("insert into test(t) values($1)", refTime)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} {
|
||||
// Switch Postgres's timezone to test different output timestamp formats
|
||||
_, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var gotTime time.Time
|
||||
row := tx.QueryRow("select t from test")
|
||||
err = row.Scan(&gotTime)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !refTime.Equal(gotTime) {
|
||||
t.Errorf("timestamps not equal: %s != %s", refTime, gotTime)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.Exec("delete from test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampWithOutTimezone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
test := func(ts, pgts string) {
|
||||
r, err := db.Query("SELECT $1::timestamp", pgts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not run query: %v", err)
|
||||
}
|
||||
|
||||
n := r.Next()
|
||||
|
||||
if n != true {
|
||||
t.Fatal("Expected at least one row")
|
||||
}
|
||||
|
||||
var result time.Time
|
||||
err = r.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error scanning row: %v", err)
|
||||
}
|
||||
|
||||
expected, err := time.Parse(time.RFC3339, ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse test time literal: %v", err)
|
||||
}
|
||||
|
||||
if !result.Equal(expected) {
|
||||
t.Fatalf("Expected time to match %v: got mismatch %v",
|
||||
expected, result)
|
||||
}
|
||||
|
||||
n = r.Next()
|
||||
if n != false {
|
||||
t.Fatal("Expected only one row")
|
||||
}
|
||||
}
|
||||
|
||||
test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00")
|
||||
|
||||
// Test higher precision time
|
||||
test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033")
|
||||
}
|
||||
|
||||
func TestStringWithNul(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
hello0world := string("hello\x00world")
|
||||
_, err := db.Query("SELECT $1::text", &hello0world)
|
||||
if err == nil {
|
||||
t.Fatal("Postgres accepts a string with nul in it; " +
|
||||
"injection attacks may be plausible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteToText(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte("hello world")
|
||||
row := db.QueryRow("SELECT $1::text", b)
|
||||
|
||||
var result []byte
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(result) != string(b) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
Efatal = "FATAL"
|
||||
Epanic = "PANIC"
|
||||
Ewarning = "WARNING"
|
||||
Enotice = "NOTICE"
|
||||
Edebug = "DEBUG"
|
||||
Einfo = "INFO"
|
||||
Elog = "LOG"
|
||||
)
|
||||
|
||||
type Error error
|
||||
|
||||
type PGError interface {
|
||||
Error() string
|
||||
Fatal() bool
|
||||
Get(k byte) (v string)
|
||||
}
|
||||
type pgError struct {
|
||||
c map[byte]string
|
||||
}
|
||||
|
||||
func parseError(r *readBuf) *pgError {
|
||||
err := &pgError{make(map[byte]string)}
|
||||
for t := r.byte(); t != 0; t = r.byte() {
|
||||
err.c[t] = r.string()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (err *pgError) Get(k byte) (v string) {
|
||||
v, _ = err.c[k]
|
||||
return
|
||||
}
|
||||
|
||||
func (err *pgError) Fatal() bool {
|
||||
return err.Get('S') == Efatal
|
||||
}
|
||||
|
||||
func (err *pgError) Error() string {
|
||||
var s string
|
||||
for k, v := range err.c {
|
||||
s += fmt.Sprintf(" %c:%q", k, v)
|
||||
}
|
||||
return "pq: " + s[1:]
|
||||
}
|
||||
|
||||
func errorf(s string, args ...interface{}) {
|
||||
panic(Error(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))))
|
||||
}
|
||||
|
||||
type SimplePGError struct {
|
||||
pgError
|
||||
}
|
||||
|
||||
func (err *SimplePGError) Error() string {
|
||||
return "pq: " + err.Get('M')
|
||||
}
|
||||
|
||||
func errRecoverWithPGReason(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case *pgError:
|
||||
// Return a SimplePGError in place
|
||||
*err = &SimplePGError{*v}
|
||||
default:
|
||||
// Otherwise re-panic
|
||||
panic(e)
|
||||
}
|
||||
}
|
||||
|
||||
func errRecover(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case runtime.Error:
|
||||
panic(v)
|
||||
case *pgError:
|
||||
if v.Fatal() {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
case *net.OpError:
|
||||
*err = driver.ErrBadConn
|
||||
case error:
|
||||
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown error: %#v", e))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
package oid
|
||||
|
||||
// Generated via massaging this catalog query:
|
||||
//
|
||||
// SELECT 'T_' || typname || ' = ' || oid
|
||||
// FROM pg_type WHERE oid < 10000
|
||||
// ORDER BY oid;
|
||||
//
|
||||
// This should probably be done one per release. Postgres does not
|
||||
// re-appropriate the system OID space below 10000 as a general rule.
|
||||
|
||||
type Oid uint32
|
||||
|
||||
const (
|
||||
T_bool Oid = 16
|
||||
T_bytea = 17
|
||||
T_char = 18
|
||||
T_name = 19
|
||||
T_int8 = 20
|
||||
T_int2 = 21
|
||||
T_int2vector = 22
|
||||
T_int4 = 23
|
||||
T_regproc = 24
|
||||
T_text = 25
|
||||
T_oid = 26
|
||||
T_tid = 27
|
||||
T_xid = 28
|
||||
T_cid = 29
|
||||
T_oidvector = 30
|
||||
T_pg_type = 71
|
||||
T_pg_attribute = 75
|
||||
T_pg_proc = 81
|
||||
T_pg_class = 83
|
||||
T_json = 114
|
||||
T_xml = 142
|
||||
T__xml = 143
|
||||
T_pg_node_tree = 194
|
||||
T__json = 199
|
||||
T_smgr = 210
|
||||
T_point = 600
|
||||
T_lseg = 601
|
||||
T_path = 602
|
||||
T_box = 603
|
||||
T_polygon = 604
|
||||
T_line = 628
|
||||
T__line = 629
|
||||
T_cidr = 650
|
||||
T__cidr = 651
|
||||
T_float4 = 700
|
||||
T_float8 = 701
|
||||
T_abstime = 702
|
||||
T_reltime = 703
|
||||
T_tinterval = 704
|
||||
T_unknown = 705
|
||||
T_circle = 718
|
||||
T__circle = 719
|
||||
T_money = 790
|
||||
T__money = 791
|
||||
T_macaddr = 829
|
||||
T_inet = 869
|
||||
T__bool = 1000
|
||||
T__bytea = 1001
|
||||
T__char = 1002
|
||||
T__name = 1003
|
||||
T__int2 = 1005
|
||||
T__int2vector = 1006
|
||||
T__int4 = 1007
|
||||
T__regproc = 1008
|
||||
T__text = 1009
|
||||
T__tid = 1010
|
||||
T__xid = 1011
|
||||
T__cid = 1012
|
||||
T__oidvector = 1013
|
||||
T__bpchar = 1014
|
||||
T__varchar = 1015
|
||||
T__int8 = 1016
|
||||
T__point = 1017
|
||||
T__lseg = 1018
|
||||
T__path = 1019
|
||||
T__box = 1020
|
||||
T__float4 = 1021
|
||||
T__float8 = 1022
|
||||
T__abstime = 1023
|
||||
T__reltime = 1024
|
||||
T__tinterval = 1025
|
||||
T__polygon = 1027
|
||||
T__oid = 1028
|
||||
T_aclitem = 1033
|
||||
T__aclitem = 1034
|
||||
T__macaddr = 1040
|
||||
T__inet = 1041
|
||||
T_bpchar = 1042
|
||||
T_varchar = 1043
|
||||
T_date = 1082
|
||||
T_time = 1083
|
||||
T_timestamp = 1114
|
||||
T__timestamp = 1115
|
||||
T__date = 1182
|
||||
T__time = 1183
|
||||
T_timestamptz = 1184
|
||||
T__timestamptz = 1185
|
||||
T_interval = 1186
|
||||
T__interval = 1187
|
||||
T__numeric = 1231
|
||||
T_pg_database = 1248
|
||||
T__cstring = 1263
|
||||
T_timetz = 1266
|
||||
T__timetz = 1270
|
||||
T_bit = 1560
|
||||
T__bit = 1561
|
||||
T_varbit = 1562
|
||||
T__varbit = 1563
|
||||
T_numeric = 1700
|
||||
T_refcursor = 1790
|
||||
T__refcursor = 2201
|
||||
T_regprocedure = 2202
|
||||
T_regoper = 2203
|
||||
T_regoperator = 2204
|
||||
T_regclass = 2205
|
||||
T_regtype = 2206
|
||||
T__regprocedure = 2207
|
||||
T__regoper = 2208
|
||||
T__regoperator = 2209
|
||||
T__regclass = 2210
|
||||
T__regtype = 2211
|
||||
T_record = 2249
|
||||
T_cstring = 2275
|
||||
T_any = 2276
|
||||
T_anyarray = 2277
|
||||
T_void = 2278
|
||||
T_trigger = 2279
|
||||
T_language_handler = 2280
|
||||
T_internal = 2281
|
||||
T_opaque = 2282
|
||||
T_anyelement = 2283
|
||||
T__record = 2287
|
||||
T_anynonarray = 2776
|
||||
T_pg_authid = 2842
|
||||
T_pg_auth_members = 2843
|
||||
T__txid_snapshot = 2949
|
||||
T_uuid = 2950
|
||||
T__uuid = 2951
|
||||
T_txid_snapshot = 2970
|
||||
T_fdw_handler = 3115
|
||||
T_anyenum = 3500
|
||||
T_tsvector = 3614
|
||||
T_tsquery = 3615
|
||||
T_gtsvector = 3642
|
||||
T__tsvector = 3643
|
||||
T__gtsvector = 3644
|
||||
T__tsquery = 3645
|
||||
T_regconfig = 3734
|
||||
T__regconfig = 3735
|
||||
T_regdictionary = 3769
|
||||
T__regdictionary = 3770
|
||||
T_anyrange = 3831
|
||||
T_int4range = 3904
|
||||
T__int4range = 3905
|
||||
T_numrange = 3906
|
||||
T__numrange = 3907
|
||||
T_tsrange = 3908
|
||||
T__tsrange = 3909
|
||||
T_tstzrange = 3910
|
||||
T__tstzrange = 3911
|
||||
T_daterange = 3912
|
||||
T__daterange = 3913
|
||||
T_int8range = 3926
|
||||
T__int8range = 3927
|
||||
)
|
|
@ -0,0 +1,68 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nurl "net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseURL converts url to a connection string for driver.Open.
|
||||
// Example:
|
||||
//
|
||||
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
|
||||
//
|
||||
// converts to:
|
||||
//
|
||||
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
|
||||
//
|
||||
// A minimal example:
|
||||
//
|
||||
// "postgres://"
|
||||
//
|
||||
// This will be blank, causing driver.Open to use all of the defaults
|
||||
func ParseURL(url string) (string, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if u.Scheme != "postgres" {
|
||||
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
accrue := func(k, v string) {
|
||||
if v != "" {
|
||||
kvs = append(kvs, k+"="+v)
|
||||
}
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
v := u.User.Username()
|
||||
accrue("user", v)
|
||||
|
||||
v, _ = u.User.Password()
|
||||
accrue("password", v)
|
||||
}
|
||||
|
||||
i := strings.Index(u.Host, ":")
|
||||
if i < 0 {
|
||||
accrue("host", u.Host)
|
||||
} else {
|
||||
accrue("host", u.Host[:i])
|
||||
accrue("port", u.Host[i+1:])
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
accrue("dbname", u.Path[1:])
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
for k, _ := range q {
|
||||
accrue(k, q.Get(k))
|
||||
}
|
||||
|
||||
sort.Strings(kvs) // Makes testing easier (not a performance concern)
|
||||
return strings.Join(kvs, " "), nil
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSimpleParseURL(t *testing.T) {
|
||||
expected := "host=hostname.remote"
|
||||
str, err := ParseURL("postgres://hostname.remote")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if str != expected {
|
||||
t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFullParseURL(t *testing.T) {
|
||||
expected := "dbname=database host=hostname.remote password=secret port=1234 user=username"
|
||||
str, err := ParseURL("postgres://username:secret@hostname.remote:1234/database")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if str != expected {
|
||||
t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidProtocolParseURL(t *testing.T) {
|
||||
_, err := ParseURL("http://hostname.remote")
|
||||
switch err {
|
||||
case nil:
|
||||
t.Fatal("Expected an error from parsing invalid protocol")
|
||||
default:
|
||||
msg := "invalid connection protocol: http"
|
||||
if err.Error() != msg {
|
||||
t.Fatalf("Unexpected error message:\n+ %s\n- %s",
|
||||
err.Error(), msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinimalURL(t *testing.T) {
|
||||
cs, err := ParseURL("postgres://")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if cs != "" {
|
||||
t.Fatalf("expected blank connection string, got: %q", cs)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Package pq is a pure Go Postgres driver for the database/sql package.
|
||||
|
||||
// +build darwin freebsd linux netbsd openbsd
|
||||
|
||||
package pq
|
||||
|
||||
import "os/user"
|
||||
|
||||
func userCurrent() (string, error) {
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return u.Username, nil
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
// Package pq is a pure Go Postgres driver for the database/sql package.
|
||||
package pq
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Perform Windows user name lookup identically to libpq.
|
||||
//
|
||||
// The PostgreSQL code makes use of the legacy Win32 function
|
||||
// GetUserName, and that function has not been imported into stock Go.
|
||||
// GetUserNameEx is available though, the difference being that a
|
||||
// wider range of names are available. To get the output to be the
|
||||
// same as GetUserName, only the base (or last) component of the
|
||||
// result is returned.
|
||||
func userCurrent() (string, error) {
|
||||
pw_name := make([]uint16, 128)
|
||||
pwname_size := uint32(len(pw_name)) - 1
|
||||
err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s := syscall.UTF16ToString(pw_name)
|
||||
u := filepath.Base(s)
|
||||
return u, nil
|
||||
}
|
Loading…
Reference in New Issue