diff --git a/lib/go/camli/db/db.go b/lib/go/camli/db/db.go index d622d3d02..958c68c70 100644 --- a/lib/go/camli/db/db.go +++ b/lib/go/camli/db/db.go @@ -226,8 +226,26 @@ func (s *Stmt) Exec(args ...interface{}) os.Error { return err } defer s.db.putConn(ci) - // TODO(bradfitz): convert args from full set (package db) to - // restricted set (package dbimpl) + + if want := si.NumInput(); len(args) != want { + return fmt.Errorf("db: expected %d arguments, got %d", want, len(args)) + } + + // Convert args if the driver knows its own types. + if cc, ok := si.(dbimpl.ColumnConverter); ok { + for n, arg := range args { + args[n], err = cc.ColumnCoverter(n).ConvertValue(arg) + if err != nil { + return fmt.Errorf("db: converting Exec column index %d: %v", n, err) + } + } + } + + // Then convert everything into the restricted subset + // of types that the dbimpl package needs to know about. + // all integers -> int64, etc + // TODO(bradfitz): ^that + resi, err := si.Exec(args) if err != nil { return err diff --git a/lib/go/camli/db/db_test.go b/lib/go/camli/db/db_test.go index e4ece8fa1..4068d861f 100644 --- a/lib/go/camli/db/db_test.go +++ b/lib/go/camli/db/db_test.go @@ -49,12 +49,12 @@ func TestDb(t *testing.T) { {[]interface{}{7, 9}, ""}, // Invalid conversions: - //{[]interface{}{"Brad", int64(0xFFFFFFFF)}, "conversion"}, - //{[]interface{}{"Brad", "strconv fail"}, "conversion"}, + {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "db: converting Exec column index 1: value 4294967295 overflows int32"}, + {[]interface{}{"Brad", "strconv fail"}, "db: converting Exec column index 1: value \"strconv fail\" can't be converted to int32"}, // Wrong number of args: - {[]interface{}{}, "fakedb: expected 2 arguments, got 0"}, - {[]interface{}{1, 2, 3}, "fakedb: expected 2 arguments, got 3"}, + {[]interface{}{}, "db: expected 2 arguments, got 0"}, + {[]interface{}{1, 2, 3}, "db: expected 2 arguments, got 3"}, } for n, et := range execTests { err := stmt.Exec(et.args...) diff --git a/lib/go/camli/db/dbimpl/dbimpl.go b/lib/go/camli/db/dbimpl/dbimpl.go index 324a744f0..dc607d152 100644 --- a/lib/go/camli/db/dbimpl/dbimpl.go +++ b/lib/go/camli/db/dbimpl/dbimpl.go @@ -66,6 +66,16 @@ type Stmt interface { Query(args []interface{}) (Rows, os.Error) } +// ColumnConverter may be optionally implemented by Stmt to signal +// to the db package to do type conversions. +type ColumnConverter interface { + ColumnCoverter(idx int) ValueConverter +} + +type ValueConverter interface { + ConvertValue(v interface{}) (interface{}, os.Error) +} + type Rows interface { Columns() []string Close() os.Error diff --git a/lib/go/camli/db/fakedb_test.go b/lib/go/camli/db/fakedb_test.go index 2059688e9..f83683288 100644 --- a/lib/go/camli/db/fakedb_test.go +++ b/lib/go/camli/db/fakedb_test.go @@ -62,6 +62,15 @@ type table struct { rows []*row } +func (t *table) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + type row struct { cols []interface{} // must be same size as its table colname + coltype } @@ -87,6 +96,8 @@ type fakeStmt struct { colType []string // used by CREATE colValue []string // used by INSERT (mix of strings and "?" for bound params) placeholders int // number of ? params + + placeholderConverter []dbimpl.ValueConverter // used by INSERT } var driver dbimpl.Driver = &fakeDriver{} @@ -236,9 +247,9 @@ func (c *fakeConn) Prepare(query string) (dbimpl.Stmt, os.Error) { // TODO(bradfitz): check that // pre-bound value type conversion is // valid for this column type - _ = ctype } else { stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) } stmt.colName = append(stmt.colName, column) stmt.colValue = append(stmt.colValue, value) @@ -249,6 +260,10 @@ func (c *fakeConn) Prepare(query string) (dbimpl.Stmt, os.Error) { return stmt, nil } +func (s *fakeStmt) ColumnCoverter(idx int) dbimpl.ValueConverter { + return s.placeholderConverter[idx] +} + func (s *fakeStmt) Close() os.Error { return nil } @@ -274,7 +289,7 @@ func (s *fakeStmt) Exec(args []interface{}) (dbimpl.Result, os.Error) { func (s *fakeStmt) execInsert(args []interface{}) (dbimpl.Result, os.Error) { db := s.c.db if len(args) != s.placeholders { - return nil, fmt.Errorf("fakedb: expected %d arguments, got %d", s.placeholders, len(args)) + panic("error in pkg db; should only get here if size is correct") } db.mu.Lock() t, ok := db.table(s.table) @@ -286,9 +301,29 @@ func (s *fakeStmt) execInsert(args []interface{}) (dbimpl.Result, os.Error) { t.mu.Lock() defer t.mu.Unlock() - // TODO(bradfitz): type check columns, fill in defaults, etc + cols := make([]interface{}, len(t.colname)) + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val interface{} + if s.colValue[n] == "?" { + val = args[argPos] + argPos++ + } else { + val = s.colValue[n] + } + valType := fmt.Sprintf("%T", val) + if colType := t.coltype[colidx]; valType != colType { + return nil, fmt.Errorf("fakedb: column %q value of %v (%v) doesn't match column type of %q", + colname, val, valType, colType) + } + cols[colidx] = val + } - //t.rows = append(t.rows, &row{cols: vals}) + t.rows = append(t.rows, &row{cols: cols}) return dbimpl.RowsAffected(1), nil } @@ -299,7 +334,7 @@ func (s *fakeStmt) Query(args []interface{}) (dbimpl.Rows, os.Error) { } func (s *fakeStmt) NumInput() int { - return 0 + return s.placeholders } func (tx *fakeTx) Commit() os.Error { @@ -311,3 +346,15 @@ func (tx *fakeTx) Rollback() os.Error { tx.c.currTx = nil return nil } + +func converterForType(typ string) dbimpl.ValueConverter { + switch typ { + case "bool": + return dbimpl.Bool + case "int32": + return dbimpl.Int32 + case "string": + return dbimpl.String + } + panic("invalid fakedb column type of " + typ) +}