mirror of https://github.com/perkeep/perkeep.git
sorted/postgres: create databases on boot
Previously, `create table ...` sql migrations were being made without the database existing. This resulted in a panic and error like: pq: database "pk_3a94488d_blobpacked" does not exist There seems to be an upstream issue with our postgres library in which `CREATE DATABASE ...` queries are not prepared so we have to build the sql manually. For now I've added a regex to make sure we don't allow anything too crazy in. Fixes #1022 Change-Id: I0da16759e9219347bb11713b92337021546f9d57
This commit is contained in:
parent
22734c9d29
commit
db2355bc13
|
@ -20,6 +20,7 @@ package postgres // import "perkeep.org/pkg/sorted/postgres"
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
|
@ -36,20 +37,39 @@ func init() {
|
|||
}
|
||||
|
||||
func newKeyValueFromJSONConfig(cfg jsonconfig.Obj) (sorted.KeyValue, error) {
|
||||
conninfo := fmt.Sprintf("user=%s dbname=%s host=%s password=%s sslmode=%s",
|
||||
cfg.RequiredString("user"),
|
||||
cfg.RequiredString("database"),
|
||||
cfg.OptionalString("host", "localhost"),
|
||||
cfg.OptionalString("password", ""),
|
||||
cfg.OptionalString("sslmode", "require"),
|
||||
var (
|
||||
user = cfg.RequiredString("user")
|
||||
database = cfg.RequiredString("database")
|
||||
host = cfg.OptionalString("host", "localhost")
|
||||
password = cfg.OptionalString("password", "")
|
||||
sslmode = cfg.OptionalString("sslmode", "require")
|
||||
)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// connect without a database, it may not exist yet
|
||||
conninfo := fmt.Sprintf("user=%s host=%s sslmode=%s", user, host, sslmode)
|
||||
if password != "" {
|
||||
conninfo += fmt.Sprintf(" password=%s", password)
|
||||
}
|
||||
db, err := sql.Open("postgres", conninfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = createDB(db, database)
|
||||
db.Close() // ignoring error, if createDB failed db.Close() will likely also fail
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// reconnect after database is created
|
||||
conninfo += fmt.Sprintf(" dbname=%s", database)
|
||||
db, err = sql.Open("postgres", conninfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tableSql := range SQLCreateTables() {
|
||||
if _, err := db.Exec(tableSql); err != nil {
|
||||
return nil, fmt.Errorf("error creating table with %q: %v", tableSql, err)
|
||||
|
@ -141,3 +161,34 @@ func (kv *keyValue) SchemaVersion() (version int, err error) {
|
|||
err = kv.db.QueryRow("SELECT value FROM meta WHERE metakey='version'").Scan(&version)
|
||||
return
|
||||
}
|
||||
|
||||
var validDatabaseRegex = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
|
||||
|
||||
func validDatabaseName(database string) bool {
|
||||
return validDatabaseRegex.MatchString(database)
|
||||
}
|
||||
|
||||
func createDB(db *sql.DB, database string) error {
|
||||
if database == "" {
|
||||
return errors.New("database name can't be empty")
|
||||
}
|
||||
|
||||
rows, err := db.Query(`SELECT 1 FROM pg_database WHERE datname = $1`, database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
return nil // database is already created
|
||||
}
|
||||
|
||||
// Verify database only has runes we expect
|
||||
if !validDatabaseName(database) {
|
||||
return fmt.Errorf("Invalid postgres database name: %q", database)
|
||||
}
|
||||
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -45,3 +45,21 @@ func TestPostgreSQLKV(t *testing.T) {
|
|||
}
|
||||
kvtest.TestSorted(t, kv)
|
||||
}
|
||||
|
||||
func TestPostgresDBNaming(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
valid bool
|
||||
}{
|
||||
{"perkeep", true},
|
||||
{"perkeep_2", true},
|
||||
{"perkeep-2", true},
|
||||
{"'; drop tables;", false}, // validDatabaseName doesn't actually check for sql injection
|
||||
}
|
||||
for i := range cases {
|
||||
res := validDatabaseName(cases[i].name)
|
||||
if res != cases[i].valid {
|
||||
t.Errorf("%q got %v expected %v", cases[i].name, res, cases[i].valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue