diff --git a/pkg/sorted/postgres/postgreskv.go b/pkg/sorted/postgres/postgreskv.go index d650d753f..6930f8b40 100644 --- a/pkg/sorted/postgres/postgreskv.go +++ b/pkg/sorted/postgres/postgreskv.go @@ -20,6 +20,7 @@ package postgres // import "perkeep.org/pkg/sorted/postgres" import ( "database/sql" + "errors" "fmt" "regexp" @@ -36,20 +37,39 @@ func init() { } func newKeyValueFromJSONConfig(cfg jsonconfig.Obj) (sorted.KeyValue, error) { - conninfo := fmt.Sprintf("user=%s dbname=%s host=%s password=%s sslmode=%s", - cfg.RequiredString("user"), - cfg.RequiredString("database"), - cfg.OptionalString("host", "localhost"), - cfg.OptionalString("password", ""), - cfg.OptionalString("sslmode", "require"), + var ( + user = cfg.RequiredString("user") + database = cfg.RequiredString("database") + host = cfg.OptionalString("host", "localhost") + password = cfg.OptionalString("password", "") + sslmode = cfg.OptionalString("sslmode", "require") ) if err := cfg.Validate(); err != nil { return nil, err } + + // connect without a database, it may not exist yet + conninfo := fmt.Sprintf("user=%s host=%s sslmode=%s", user, host, sslmode) + if password != "" { + conninfo += fmt.Sprintf(" password=%s", password) + } db, err := sql.Open("postgres", conninfo) if err != nil { return nil, err } + err = createDB(db, database) + db.Close() // ignoring error, if createDB failed db.Close() will likely also fail + if err != nil { + return nil, err + } + + // reconnect after database is created + conninfo += fmt.Sprintf(" dbname=%s", database) + db, err = sql.Open("postgres", conninfo) + if err != nil { + return nil, err + } + for _, tableSql := range SQLCreateTables() { if _, err := db.Exec(tableSql); err != nil { return nil, fmt.Errorf("error creating table with %q: %v", tableSql, err) @@ -141,3 +161,34 @@ func (kv *keyValue) SchemaVersion() (version int, err error) { err = kv.db.QueryRow("SELECT value FROM meta WHERE metakey='version'").Scan(&version) return } + +var validDatabaseRegex = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) + +func validDatabaseName(database string) bool { + return validDatabaseRegex.MatchString(database) +} + +func createDB(db *sql.DB, database string) error { + if database == "" { + return errors.New("database name can't be empty") + } + + rows, err := db.Query(`SELECT 1 FROM pg_database WHERE datname = $1`, database) + if err != nil { + return err + } + defer rows.Close() + if rows.Next() { + return nil // database is already created + } + + // Verify database only has runes we expect + if !validDatabaseName(database) { + return fmt.Errorf("Invalid postgres database name: %q", database) + } + _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database)) + if err != nil { + return err + } + return err +} diff --git a/pkg/sorted/postgres/postgreskv_test.go b/pkg/sorted/postgres/postgreskv_test.go index c6342491f..b5d719070 100644 --- a/pkg/sorted/postgres/postgreskv_test.go +++ b/pkg/sorted/postgres/postgreskv_test.go @@ -45,3 +45,21 @@ func TestPostgreSQLKV(t *testing.T) { } kvtest.TestSorted(t, kv) } + +func TestPostgresDBNaming(t *testing.T) { + cases := []struct { + name string + valid bool + }{ + {"perkeep", true}, + {"perkeep_2", true}, + {"perkeep-2", true}, + {"'; drop tables;", false}, // validDatabaseName doesn't actually check for sql injection + } + for i := range cases { + res := validDatabaseName(cases[i].name) + if res != cases[i].valid { + t.Errorf("%q got %v expected %v", cases[i].name, res, cases[i].valid) + } + } +}