diff --git a/third_party/labix.org/v2/mgo/LICENSE b/third_party/labix.org/v2/mgo/LICENSE index 83e319bb8..770c7672b 100644 --- a/third_party/labix.org/v2/mgo/LICENSE +++ b/third_party/labix.org/v2/mgo/LICENSE @@ -1,6 +1,6 @@ mgo - MongoDB driver for Go -Copyright (c) 2010-2012 - Gustavo Niemeyer +Copyright (c) 2010-2013 - Gustavo Niemeyer All rights reserved. diff --git a/third_party/labix.org/v2/mgo/auth.go b/third_party/labix.org/v2/mgo/auth.go index df19b6d2f..56cdfc295 100644 --- a/third_party/labix.org/v2/mgo/auth.go +++ b/third_party/labix.org/v2/mgo/auth.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,11 +27,11 @@ package mgo import ( + "camlistore.org/third_party/labix.org/v2/mgo/bson" "crypto/md5" "encoding/hex" "errors" "fmt" - "camlistore.org/third_party/labix.org/v2/mgo/bson" "sync" ) @@ -91,13 +91,13 @@ func (socket *mongoSocket) resetNonce() { op.limit = -1 op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { if err != nil { - socket.kill(errors.New("getNonce: " + err.Error()), true) + socket.kill(errors.New("getNonce: "+err.Error()), true) return } result := &getNonceResult{} err = bson.Unmarshal(docData, &result) if err != nil { - socket.kill(errors.New("Failed to unmarshal nonce: " + err.Error()), true) + socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true) return } debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result) @@ -125,7 +125,7 @@ func (socket *mongoSocket) resetNonce() { } err := socket.Query(op) if err != nil { - socket.kill(errors.New("resetNonce: " + err.Error()), true) + socket.kill(errors.New("resetNonce: "+err.Error()), true) } } diff --git a/third_party/labix.org/v2/mgo/auth_test.go b/third_party/labix.org/v2/mgo/auth_test.go index 1136fc96a..2b31454b9 100644 --- a/third_party/labix.org/v2/mgo/auth_test.go +++ b/third_party/labix.org/v2/mgo/auth_test.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. 
Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,54 +27,62 @@ package mgo_test import ( - . "camlistore.org/third_party/launchpad.net/gocheck" "camlistore.org/third_party/labix.org/v2/mgo" + . "camlistore.org/third_party/launchpad.net/gocheck" + "fmt" "sync" + "time" ) func (s *S) TestAuthLogin(c *C) { - session, err := mgo.Dial("localhost:40002") - c.Assert(err, IsNil) - defer session.Close() + // Test both with a normal database and with an authenticated shard. 
+ for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() - coll := session.DB("mydb").C("mycoll") - err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") - admindb := session.DB("admin") + admindb := session.DB("admin") - err = admindb.Login("root", "wrong") - c.Assert(err, ErrorMatches, "auth fails") + err = admindb.Login("root", "wrong") + c.Assert(err, ErrorMatches, "auth fails") - err = admindb.Login("root", "rapadura") - c.Assert(err, IsNil) + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) - err = coll.Insert(M{"n": 1}) - c.Assert(err, IsNil) + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + } } func (s *S) TestAuthLoginLogout(c *C) { - session, err := mgo.Dial("localhost:40002") - c.Assert(err, IsNil) - defer session.Close() + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() - admindb := session.DB("admin") - err = admindb.Login("root", "rapadura") - c.Assert(err, IsNil) + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) - admindb.Logout() + admindb.Logout() - coll := session.DB("mydb").C("mycoll") - err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") - // Must have dropped auth from the session too. - session = session.Copy() - defer session.Close() + // Must have dropped auth from the session too. 
+ session = session.Copy() + defer session.Close() - coll = session.DB("mydb").C("mycoll") - err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + } } func (s *S) TestAuthLoginLogoutAll(c *C) { @@ -90,7 +98,7 @@ func (s *S) TestAuthLoginLogoutAll(c *C) { coll := session.DB("mydb").C("mycoll") err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") // Must have dropped auth from the session too. session = session.Copy() @@ -98,7 +106,181 @@ func (s *S) TestAuthLoginLogoutAll(c *C) { coll = session.DB("mydb").C("mycoll") err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") +} + +func (s *S) TestAuthUpsertUserErrors(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + + err = mydb.UpsertUser(&mgo.User{}) + c.Assert(err, ErrorMatches, "user has no Username") + + err = mydb.UpsertUser(&mgo.User{Username: "user", Password: "pass", UserSource: "source"}) + c.Assert(err, ErrorMatches, "user has both Password/PasswordHash and UserSource set") + + err = mydb.UpsertUser(&mgo.User{Username: "user", Password: "pass", OtherDBRoles: map[string][]mgo.Role{"db": nil}}) + c.Assert(err, ErrorMatches, "user with OtherDBRoles is only supported in admin database") +} + +func (s *S) TestAuthUpsertUser(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", 
"rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + myotherdb := session.DB("myotherdb") + + ruser := &mgo.User{ + Username: "myruser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleRead}, + } + rwuser := &mgo.User{ + Username: "myrwuser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleReadWrite}, + } + rwuserother := &mgo.User{ + Username: "myrwuser", + UserSource: "mydb", + Roles: []mgo.Role{mgo.RoleRead}, + } + + err = mydb.UpsertUser(ruser) + c.Assert(err, IsNil) + err = mydb.UpsertUser(rwuser) + c.Assert(err, IsNil) + err = myotherdb.UpsertUser(rwuserother) + c.Assert(err, IsNil) + + err = mydb.Login("myruser", "mypass") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = mydb.Login("myrwuser", "mypass") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + // Test indirection via UserSource: we can't write to it, because + // the roles for myrwuser are different there. + othercoll := myotherdb.C("myothercoll") + err = othercoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // Reading works, though. + err = othercoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + + // Can't login directly into the database using UserSource, though. 
+ err = myotherdb.Login("myrwuser", "mypass") + c.Assert(err, ErrorMatches, "auth fails") +} + +func (s *S) TestAuthUpserUserOtherDBRoles(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + ruser := &mgo.User{ + Username: "myruser", + Password: "mypass", + OtherDBRoles: map[string][]mgo.Role{"mydb": []mgo.Role{mgo.RoleRead}}, + } + + err = admindb.UpsertUser(ruser) + c.Assert(err, IsNil) + defer admindb.RemoveUser("myruser") + + admindb.Logout() + err = admindb.Login("myruser", "mypass") + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = coll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthUpserUserUnsetFields(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Insert a user with most fields set. + user := &mgo.User{ + Username: "myruser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleRead}, + OtherDBRoles: map[string][]mgo.Role{"mydb": []mgo.Role{mgo.RoleRead}}, + } + err = admindb.UpsertUser(user) + c.Assert(err, IsNil) + defer admindb.RemoveUser("myruser") + + // Now update the user with few things set. + user = &mgo.User{ + Username: "myruser", + UserSource: "mydb", + } + err = admindb.UpsertUser(user) + c.Assert(err, IsNil) + + // Everything that was unset must have been dropped. 
+ var userm M + err = admindb.C("system.users").Find(M{"user": "myruser"}).One(&userm) + c.Assert(err, IsNil) + delete(userm, "_id") + c.Assert(userm, DeepEquals, M{"user": "myruser", "userSource": "mydb", "roles": []interface{}{}}) + + // Now set password again... + user = &mgo.User{ + Username: "myruser", + Password: "mypass", + } + err = admindb.UpsertUser(user) + c.Assert(err, IsNil) + + // ... and assert that userSource has been dropped. + err = admindb.C("system.users").Find(M{"user": "myruser"}).One(&userm) + _, found := userm["userSource"] + c.Assert(found, Equals, false) } func (s *S) TestAuthAddUser(c *C) { @@ -123,7 +305,7 @@ func (s *S) TestAuthAddUser(c *C) { coll := session.DB("mydb").C("mycoll") err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") err = mydb.Login("mywuser", "mypass") c.Assert(err, IsNil) @@ -156,7 +338,7 @@ func (s *S) TestAuthAddUserReplaces(c *C) { // ReadOnly flag was changed too. err = mydb.C("mycoll").Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") } func (s *S) TestAuthRemoveUser(c *C) { @@ -233,7 +415,7 @@ func (s *S) TestAuthLoginSwitchUser(c *C) { // Can't write. err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") // But can read. result := struct{ N int }{} @@ -268,7 +450,7 @@ func (s *S) TestAuthLoginChangePassword(c *C) { // The second login must be in effect, which means read-only. 
err = mydb.C("mycoll").Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") } func (s *S) TestAuthLoginCachingWithSessionRefresh(c *C) { @@ -335,7 +517,7 @@ func (s *S) TestAuthLoginCachingWithNewSession(c *C) { coll := session.DB("mydb").C("mycoll") err = coll.Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized for .*") } func (s *S) TestAuthLoginCachingAcrossPool(c *C) { @@ -432,7 +614,7 @@ func (s *S) TestAuthLoginCachingAcrossPoolWithLogout(c *C) { // Can't write, since root has been implicitly logged out // when the collection went into the pool, and not revalidated. err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) - c.Assert(err, ErrorMatches, "unauthorized") + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") // But can read due to the revalidated myuser login. result := struct{ N int }{} @@ -515,3 +697,82 @@ func (s *S) TestAuthURLWithNewSession(c *C) { err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) c.Assert(err, IsNil) } + +func (s *S) TestAuthURLWithDatabase(c *C) { + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + mydb := session.DB("mydb") + err = mydb.AddUser("myruser", "mypass", true) + c.Assert(err, IsNil) + + usession, err := mgo.Dial("mongodb://myruser:mypass@localhost:40002/mydb") + c.Assert(err, IsNil) + defer usession.Close() + + ucoll := usession.DB("mydb").C("mycoll") + err = ucoll.FindId(0).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + err = ucoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") +} + +func (s *S) TestDefaultDatabase(c *C) { + tests := []struct{ url, db string }{ + {"mongodb://root:rapadura@localhost:40002", "test"}, + {"mongodb://root:rapadura@localhost:40002/admin", "admin"}, + {"mongodb://localhost:40001", "test"}, + 
{"mongodb://localhost:40001/", "test"}, + {"mongodb://localhost:40001/mydb", "mydb"}, + } + + for _, test := range tests { + session, err := mgo.Dial(test.url) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("test: %#v", test) + c.Assert(session.DB("").Name, Equals, test.db) + + scopy := session.Copy() + c.Check(scopy.DB("").Name, Equals, test.db) + scopy.Close() + } +} + +func (s *S) TestAuthDirect(c *C) { + // Direct connections must work to the master and slaves. + for _, port := range []string{"40031", "40032", "40033"} { + url := fmt.Sprintf("mongodb://root:rapadura@localhost:%s/?connect=direct", port) + session, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + var result struct{} + err = session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + } +} + +func (s *S) TestAuthDirectWithLogin(c *C) { + // Direct connections must work to the master and slaves. + for _, port := range []string{"40031", "40032", "40033"} { + url := fmt.Sprintf("mongodb://localhost:%s/?connect=direct", port) + session, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + session.SetSyncTimeout(3 * time.Second) + + err = session.DB("admin").Login("root", "rapadura") + c.Assert(err, IsNil) + + var result struct{} + err = session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + } +} diff --git a/third_party/labix.org/v2/mgo/bson/bson.go b/third_party/labix.org/v2/mgo/bson/bson.go index c33b49534..9671ace0f 100644 --- a/third_party/labix.org/v2/mgo/bson/bson.go +++ b/third_party/labix.org/v2/mgo/bson/bson.go @@ -1,18 +1,18 @@ // BSON library for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -24,14 +24,22 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// Package bson is an implementation of the BSON specification for Go: +// +// http://bsonspec.org +// +// It was created as part of the mgo MongoDB driver for Go, but is standalone +// and may be used on its own without the driver. package bson import ( "crypto/md5" + "crypto/rand" "encoding/binary" "encoding/hex" "errors" "fmt" + "io" "os" "reflect" "runtime" @@ -58,11 +66,12 @@ type Getter interface { // value via the SetBSON method during unmarshaling, and the object // itself will not be changed as usual. // -// If setting the value works, the method should return nil. If it returns -// a bson.TypeError value, the BSON value will be omitted from a map or -// slice being decoded and the unmarshalling will continue. If it returns -// any other non-nil error, the unmarshalling procedure will stop and error -// out with the provided value. 
+// If setting the value works, the method should return nil or alternatively +// bson.SetZero to set the respective field to its zero value (nil for +// pointer types). If SetBSON returns a value of type bson.TypeError, the +// BSON value will be omitted from a map or slice being decoded and the +// unmarshalling will continue. If it returns any other non-nil error, the +// unmarshalling procedure will stop and error out with the provided value. // // This interface is generally useful in pointer receivers, since the method // will want to change the receiver. A type field that implements the Setter @@ -84,6 +93,11 @@ type Setter interface { SetBSON(raw Raw) error } +// SetZero may be returned from a SetBSON method to have the value set to +// its respective zero value. When used in pointer values, this will set the +// field to nil rather than to the pre-allocated value. +var SetZero = errors.New("set to zero") + // M is a convenient alias for a map[string]interface{} map, useful for // dealing with BSON in a native way. For instance: // @@ -94,23 +108,30 @@ type Setter interface { // undefined ordered. See also the bson.D type for an ordered alternative. type M map[string]interface{} -// D is a type for dealing with documents containing ordered elements in a -// native fashion. For instance: +// D represents a BSON document containing ordered elements. For example: // // bson.D{{"a", 1}, {"b", true}} // // In some situations, such as when creating indexes for MongoDB, the order in // which the elements are defined is important. If the order is not important, -// using a map is generally more comfortable. See the bson.M type and the -// Map() method for D. +// using a map is generally more comfortable. See bson.M and bson.RawD. type D []DocElem -// See the bson.D type. +// See the D type. type DocElem struct { Name string Value interface{} } +// Map returns a map out of the ordered element name/value pairs in d. 
+func (d D) Map() (m M) { + m = make(M, len(d)) + for _, item := range d { + m[item.Name] = item.Value + } + return m +} + // The Raw type represents raw unprocessed BSON documents and elements. // Kind is the kind of element as defined per the BSON specification, and // Data is the raw unprocessed data for the respective element. @@ -125,13 +146,16 @@ type Raw struct { Data []byte } -// Map returns a map out of the ordered element name/value pairs in d. -func (d D) Map() (m M) { - m = make(M, len(d)) - for _, item := range d { - m[item.Name] = item.Value - } - return m +// RawD represents a BSON document containing raw unprocessed elements. +// This low-level representation may be useful when lazily processing +// documents of uncertain content, or when manipulating the raw content +// documents in general. +type RawD []RawDocElem + +// See the RawD type. +type RawDocElem struct { + Name string + Value Raw } // ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes @@ -143,7 +167,7 @@ type ObjectId string // ObjectIdHex returns an ObjectId from the provided hex representation. // Calling this function with an invalid hex representation will -// cause a runtime panic. +// cause a runtime panic. See the IsObjectIdHex function. func ObjectIdHex(s string) ObjectId { d, err := hex.DecodeString(s) if err != nil || len(d) != 12 { @@ -152,40 +176,52 @@ func ObjectIdHex(s string) ObjectId { return ObjectId(d) } +// IsObjectIdHex returns whether s is a valid hex representation of +// an ObjectId. See the ObjectIdHex function. +func IsObjectIdHex(s string) bool { + if len(s) != 24 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + // objectIdCounter is atomically incremented when generating a new ObjectId // using NewObjectId() function. It's used as a counter part of an id. var objectIdCounter uint32 = 0 // machineId stores machine id generated once and used in subsequent calls // to NewObjectId function. 
-var machineId []byte +var machineId = readMachineId() // initMachineId generates machine id and puts it into the machineId global // variable. If this function fails to get the hostname, it will cause // a runtime error. -func initMachineId() { +func readMachineId() []byte { var sum [3]byte - hostname, err := os.Hostname() - if err != nil { - panic("Failed to get hostname: " + err.Error()) + id := sum[:] + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id } hw := md5.New() hw.Write([]byte(hostname)) - copy(sum[:3], hw.Sum(nil)) - machineId = sum[:] + copy(id, hw.Sum(nil)) + return id } // NewObjectId returns a new unique ObjectId. // This function causes a runtime error if it fails to get the hostname // of the current machine. func NewObjectId() ObjectId { - b := make([]byte, 12) + var b [12]byte // Timestamp, 4 bytes, big endian - binary.BigEndian.PutUint32(b, uint32(time.Now().Unix())) + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) // Machine, first 3 bytes of md5(hostname) - if machineId == nil { - initMachineId() - } b[4] = machineId[0] b[5] = machineId[1] b[6] = machineId[2] @@ -198,7 +234,7 @@ func NewObjectId() ObjectId { b[9] = byte(i >> 16) b[10] = byte(i >> 8) b[11] = byte(i) - return ObjectId(b) + return ObjectId(b[:]) } // NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled @@ -294,7 +330,7 @@ type Symbol string // why this function exists. Using the time.Now function also works fine // otherwise. 
func Now() time.Time { - return time.Unix(0, time.Now().UnixNano() / 1e6 * 1e6) + return time.Unix(0, time.Now().UnixNano()/1e6*1e6) } // MongoTimestamp is a special internal type used by MongoDB that for some @@ -383,16 +419,16 @@ func handleErr(err *error) { // // The following flags are currently supported: // -// omitempty Only include the field if it's not set to the zero -// value for the type or to empty slices or maps. -// Does not apply to zero valued structs. +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. // -// minsize Marshal an int64 value as an int32, if that's feasible -// while preserving the numeric value. +// minsize Marshal an int64 value as an int32, if that's feasible +// while preserving the numeric value. // -// inline Inline the field, which must be a struct, causing all -// of its fields to be processed as if they were part of -// the outer struct. +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the bson keys of other struct fields. // // Some examples: // @@ -404,7 +440,7 @@ func handleErr(err *error) { // E int64 ",minsize" // F int64 "myf,omitempty,minsize" // } -// +// func Marshal(in interface{}) (out []byte, err error) { defer handleErr(&err) e := &encoder{make([]byte, 0, initialBufferSize)} @@ -413,10 +449,24 @@ func Marshal(in interface{}) (out []byte, err error) { } // Unmarshal deserializes data from in into the out value. The out value -// must be a map or a pointer to a struct (or a pointer to a struct pointer). +// must be a map, a pointer to a struct, or a pointer to a bson.D value. // The lowercased field name is used as the key for each exported field, // but this behavior may be changed using the respective field tag. -// Uninitialized pointer values are properly initialized only when necessary. 
+// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported during unmarshal (see the +// Marshal method for other flags): +// +// inline Inline the field, which must be a struct or a map. +// Inlined structs are handled as if its fields were part +// of the outer struct. An inlined map causes keys that do +// not match any other struct field to be inserted in the +// map rather than being discarded as usual. // // The target field or element types of out may not necessarily match // the BSON values of the provided data. The following conversions are @@ -428,14 +478,16 @@ func Marshal(in interface{}) (out []byte, err error) { // - Numeric types are converted to bools as true if not 0 or false otherwise // - Binary and string BSON data is converted to a string, array or byte slice // -// If the value would not fit the type and cannot be converted, it's silently -// skipped. +// If the value would not fit the type and cannot be converted, it's +// silently skipped. +// +// Pointer values are initialized when necessary. func Unmarshal(in []byte, out interface{}) (err error) { defer handleErr(&err) v := reflect.ValueOf(out) switch v.Kind() { case reflect.Map, reflect.Ptr: - d := &decoder{in: in} + d := newDecoder(in) d.readDocTo(v) case reflect.Struct: return errors.New("Unmarshal can't deal with struct values. 
Use a pointer.") @@ -458,7 +510,7 @@ func (raw Raw) Unmarshal(out interface{}) (err error) { v = v.Elem() fallthrough case reflect.Map: - d := &decoder{in: raw.Data} + d := newDecoder(raw.Data) good := d.readElemTo(v, raw.Kind) if !good { return &TypeError{v.Type(), raw.Kind} @@ -486,6 +538,7 @@ func (e *TypeError) Error() string { type structInfo struct { FieldsMap map[string]fieldInfo FieldsList []fieldInfo + InlineMap int Zero reflect.Value } @@ -516,6 +569,7 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { n := st.NumField() fieldsMap := make(map[string]fieldInfo) fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 for i := 0; i != n; i++ { field := st.Field(i) if field.PkgPath != "" { @@ -570,25 +624,35 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { } if inline { - if field.Type.Kind() != reflect.Struct { - panic("Option ,inline needs a struct value field") - } - sinfo, err := getStructInfo(field.Type) - if err != nil { - return nil, err - } - for _, finfo := range sinfo.FieldsList { - if _, found := fieldsMap[finfo.Key]; found { - msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() - return nil, errors.New(msg) + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("Multiple ,inline maps in struct " + st.String()) } - if finfo.Inline == nil { - finfo.Inline = []int{i, finfo.Num} - } else { - finfo.Inline = append([]int{i}, finfo.Inline...) 
+ if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) } - fieldsMap[finfo.Key] = finfo - fieldsList = append(fieldsList, finfo) + inlineMap = info.Num + case reflect.Struct: + sinfo, err := getStructInfo(field.Type) + if err != nil { + return nil, err + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) + } + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + default: + panic("Option ,inline needs a struct value or map field") } continue } @@ -609,7 +673,8 @@ func getStructInfo(st reflect.Type) (*structInfo, error) { } sinfo = &structInfo{ fieldsMap, - fieldsList[:len(fieldsMap)], + fieldsList, + inlineMap, reflect.New(st).Elem(), } structMapMutex.Lock() diff --git a/third_party/labix.org/v2/mgo/bson/bson_test.go b/third_party/labix.org/v2/mgo/bson/bson_test.go index 21864cd6a..3db1b64d1 100644 --- a/third_party/labix.org/v2/mgo/bson/bson_test.go +++ b/third_party/labix.org/v2/mgo/bson/bson_test.go @@ -1,18 +1,18 @@ // BSON library for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. 
Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -28,11 +28,11 @@ package bson_test import ( + "camlistore.org/third_party/labix.org/v2/mgo/bson" + . "camlistore.org/third_party/launchpad.net/gocheck" "encoding/binary" "encoding/json" "errors" - . "camlistore.org/third_party/launchpad.net/gocheck" - "camlistore.org/third_party/labix.org/v2/mgo/bson" "net/url" "reflect" "testing" @@ -59,12 +59,17 @@ func wrapInDoc(data string) string { func makeZeroDoc(value interface{}) (zero interface{}) { v := reflect.ValueOf(value) t := v.Type() - if t.Kind() == reflect.Map { + switch t.Kind() { + case reflect.Map: mv := reflect.MakeMap(t) zero = mv.Interface() - } else { + case reflect.Ptr: pv := reflect.New(v.Type().Elem()) zero = pv.Interface() + case reflect.Slice: + zero = reflect.New(t).Interface() + default: + panic("unsupported doc type") } return zero } @@ -225,6 +230,18 @@ func (s *S) TestUnmarshalZeroesMap(c *C) { c.Assert(m, DeepEquals, bson.M{"b": 2}) } +func (s *S) TestUnmarshalNonNilInterface(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + var i interface{} + i = m + err = bson.Unmarshal(data, &i) + c.Assert(err, IsNil) + c.Assert(i, DeepEquals, bson.M{"b": 2}) + c.Assert(m, DeepEquals, bson.M{"a": 1}) +} + // -------------------------------------------------------------------------- // Some one way marshaling operations which would unmarshal differently. 
@@ -336,6 +353,20 @@ func (s *S) TestUnmarshalStructSampleItems(c *C) { } } +func (s *S) Test64bitInt(c *C) { + var i int64 = (1 << 31) + if int(i) > 0 { + data, err := bson.Marshal(bson.M{"i": int(i)}) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00")) + + var result struct{ I int } + err = bson.Unmarshal(data, &result) + c.Assert(err, IsNil) + c.Assert(int64(result.I), Equals, i) + } +} + // -------------------------------------------------------------------------- // Generic two-way struct marshaling tests. @@ -429,8 +460,18 @@ var marshalItems = []testItemType{ // Ordered document dump. Will unmarshal as a dictionary by default. {bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {MyD{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, {&dOnIface{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + {bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {MyRawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {&dOnIface{bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00"+"\x0Ac\x00"+"\x08b\x00\x01")}, + {&ignoreField{"before", "ignore", "after"}, "\x02before\x00\a\x00\x00\x00before\x00\x02after\x00\x06\x00\x00\x00after\x00"}, @@ -491,9 +532,17 @@ var unmarshalItems = []testItemType{ {&bson.Raw{0x03, []byte(wrapInDoc("\x10byte\x00\x08\x00\x00\x00"))}, "\x10byte\x00\x08\x00\x00\x00"}, + // RawD document. 
+ {&struct{ bson.RawD }{bson.RawD{{"a", bson.Raw{0x0A, []byte{}}}, {"c", bson.Raw{0x0A, []byte{}}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03rawd\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x08b\x00\x01")}, + // Decode old binary. {bson.M{"_": []byte("old")}, "\x05_\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // Decode old binary without length. According to the spec, this shouldn't happen. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x03\x00\x00\x00\x02old"}, } func (s *S) TestUnmarshalOneWayItems(c *C) { @@ -533,9 +582,15 @@ var marshalErrorItems = []testItemType{ {bson.Raw{0x0A, []byte{}}, "Attempted to unmarshal Raw kind 10 as a document"}, {&inlineCantPtr{&struct{ A, B int }{1, 2}}, - "Option ,inline needs a struct value field"}, + "Option ,inline needs a struct value or map field"}, {&inlineDupName{1, struct{ A, B int }{2, 3}}, "Duplicated key 'a' in struct bson_test.inlineDupName"}, + {&inlineDupMap{}, + "Multiple ,inline maps in struct bson_test.inlineDupMap"}, + {&inlineBadKeyMap{}, + "Option ,inline needs a map with string keys in struct bson_test.inlineBadKeyMap"}, + {&inlineMap{A: 1, M: map[string]interface{}{"a": 1}}, + `Can't have key "a" in inlined map; conflicts with struct field`}, } func (s *S) TestMarshalErrorItems(c *C) { @@ -753,9 +808,7 @@ func (s *S) TestUnmarshalSetterOmits(c *C) { func (s *S) TestUnmarshalSetterErrors(c *C) { boom := errors.New("BOOM") setterResult["2"] = boom - defer func() { - delete(setterResult, "2") - }() + defer delete(setterResult, "2") m := map[string]*setterType{} data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + @@ -775,6 +828,22 @@ func (s *S) TestDMap(c *C) { c.Assert(d.Map(), DeepEquals, bson.M{"a": 1, "b": 2}) } +func (s *S) TestUnmarshalSetterSetZero(c *C) { + setterResult["foo"] = bson.SetZero + defer delete(setterResult, "field") + + data, err := bson.Marshal(bson.M{"field": "foo"}) + c.Assert(err, IsNil) + + m := map[string]*setterType{} + err = bson.Unmarshal([]byte(data), m) + c.Assert(err, 
IsNil) + + value, ok := m["field"] + c.Assert(ok, Equals, true) + c.Assert(value, IsNil) +} + // -------------------------------------------------------------------------- // Getter test cases. @@ -869,6 +938,9 @@ type condInt struct { type condUInt struct { V uint ",omitempty" } +type condFloat struct { + V float64 ",omitempty" +} type condIface struct { V interface{} ",omitempty" } @@ -887,6 +959,9 @@ type namedCondStr struct { type condTime struct { V time.Time ",omitempty" } +type condStruct struct { + V struct{ A []int } ",omitempty" +} type shortInt struct { V int64 ",minsize" @@ -914,17 +989,44 @@ type inlineDupName struct { A int V struct{ A, B int } ",inline" } +type inlineMap struct { + A int + M map[string]interface{} ",inline" +} +type inlineMapInt struct { + A int + M map[string]int ",inline" +} +type inlineMapMyM struct { + A int + M MyM ",inline" +} +type inlineDupMap struct { + M1 map[string]interface{} ",inline" + M2 map[string]interface{} ",inline" +} +type inlineBadKeyMap struct { + M map[int]int ",inline" +} -type MyBytes []byte -type MyBool bool +type ( + MyString string + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} +) -var truevar = true -var falsevar = false +var ( + truevar = true + falsevar = false -var int64var = int64(42) -var int64ptr = &int64var -var intvar = int(42) -var intptr = &intvar + int64var = int64(42) + int64ptr = &int64var + intvar = int(42) + intptr = &intvar +) func parseURL(s string) *url.URL { u, err := url.Parse(s) @@ -1040,6 +1142,7 @@ var twoWayCrossItems = []crossTypeItem{ {&condInt{}, map[string]int{}}, {&condUInt{1}, map[string]uint{"v": 1}}, {&condUInt{}, map[string]uint{}}, + {&condFloat{}, map[string]int{}}, {&condStr{"yo"}, map[string]string{"v": "yo"}}, {&condStr{}, map[string]string{}}, {&condStrNS{"yo"}, map[string]string{"v": "yo"}}, @@ -1058,6 +1161,9 @@ var twoWayCrossItems = []crossTypeItem{ {&condTime{time.Unix(123456789, 123e6)}, 
map[string]time.Time{"v": time.Unix(123456789, 123e6)}}, {&condTime{}, map[string]string{}}, + {&condStruct{struct{ A []int }{[]int{1}}}, bson.M{"v": bson.M{"a": []interface{}{1}}}}, + {&condStruct{struct{ A []int }{}}, bson.M{}}, + {&namedCondStr{"yo"}, map[string]string{"myv": "yo"}}, {&namedCondStr{}, map[string]string{}}, @@ -1074,6 +1180,11 @@ var twoWayCrossItems = []crossTypeItem{ {&shortNonEmptyInt{}, map[string]interface{}{}}, {&inlineInt{struct{ A, B int }{1, 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: map[string]interface{}{"b": 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: nil}, map[string]interface{}{"a": 1}}, + {&inlineMapInt{A: 1, M: map[string]int{"b": 2}}, map[string]int{"a": 1, "b": 2}}, + {&inlineMapInt{A: 1, M: nil}, map[string]int{"a": 1}}, + {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}}, // []byte <=> MyBytes {&struct{ B MyBytes }{[]byte("abc")}, map[string]string{"b": "abc"}}, @@ -1096,13 +1207,30 @@ var twoWayCrossItems = []crossTypeItem{ // zero time + 1 second + 1 millisecond; overflows int64 as nanoseconds {&struct{ V time.Time }{time.Unix(-62135596799, 1e6).Local()}, map[string]interface{}{"v": time.Unix(-62135596799, 1e6).Local()}}, + + // bson.D <=> []DocElem + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}}, + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}}, + + // bson.RawD <=> []RawDocElem + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &MyRawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + + // bson.M <=> map + {bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": MyM{"b": 1, "c": 2}}}, + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}}, + + // bson.M <=> map[MyString] + {bson.M{"a": 
bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 2}}}, } // Same thing, but only one way (obj1 => obj2). var oneWayCrossItems = []crossTypeItem{ // map <=> struct - {map[string]interface{}{"a": 1, "b": "2", "c": 3}, - map[string]int{"a": 1, "c": 3}}, + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, map[string]int{"a": 1, "c": 3}}, + + // inline map elides badly typed values + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, &inlineMapInt{A: 1, M: map[string]int{"c": 3}}}, // Can't decode int into struct. {bson.M{"a": bson.M{"b": 2}}, &struct{ A bool }{}}, @@ -1146,6 +1274,21 @@ func (s *S) TestObjectIdHex(c *C) { c.Assert(id.Hex(), Equals, "4d88e15b60f486e428412dc9") } +func (s *S) TestIsObjectIdHex(c *C) { + test := []struct { + id string + valid bool + }{ + {"4d88e15b60f486e428412dc9", true}, + {"4d88e15b60f486e428412dc", false}, + {"4d88e15b60f486e428412dc9e", false}, + {"4d88e15b60f486e428412dcx", false}, + } + for _, t := range test { + c.Assert(bson.IsObjectIdHex(t.id), Equals, t.valid) + } +} + // -------------------------------------------------------------------------- // ObjectId parts extraction tests. diff --git a/third_party/labix.org/v2/mgo/bson/decode.go b/third_party/labix.org/v2/mgo/bson/decode.go index ce65c7669..ef85735ec 100644 --- a/third_party/labix.org/v2/mgo/bson/decode.go +++ b/third_party/labix.org/v2/mgo/bson/decode.go @@ -1,18 +1,18 @@ // BSON library for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. 
Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -37,8 +37,15 @@ import ( ) type decoder struct { - in []byte - i int + in []byte + i int + docType reflect.Type +} + +var typeM = reflect.TypeOf(M{}) + +func newDecoder(in []byte) *decoder { + return &decoder{in, 0, typeM} } // -------------------------------------------------------------------------- @@ -106,6 +113,13 @@ func getSetter(outt reflect.Type, out reflect.Value) Setter { return out.Interface().(Setter) } +func clearMap(m reflect.Value) { + var none reflect.Value + for _, k := range m.MapKeys() { + m.SetMapIndex(k, none) + } +} + func (d *decoder) readDocTo(out reflect.Value) { var elemType reflect.Type outt := out.Type() @@ -134,31 +148,44 @@ func (d *decoder) readDocTo(out reflect.Value) { } var fieldsMap map[string]fieldInfo + var inlineMap reflect.Value start := d.i - switch outk { - case reflect.Interface: - if !out.IsNil() { - panic("Found non-nil interface. 
Please contact the developers.") + origout := out + if outk == reflect.Interface { + if d.docType.Kind() == reflect.Map { + mv := reflect.MakeMap(d.docType) + out.Set(mv) + out = mv + } else { + dv := reflect.New(d.docType).Elem() + out.Set(dv) + out = dv } - mv := reflect.ValueOf(make(M)) - out.Set(mv) - out = mv outt = out.Type() outk = outt.Kind() - fallthrough + } + + docType := d.docType + keyType := typeString + convertKey := false + switch outk { case reflect.Map: - if outt.Key().Kind() != reflect.String { + keyType = outt.Key() + if keyType.Kind() != reflect.String { panic("BSON map must have string keys. Got: " + outt.String()) } + if keyType != typeString { + convertKey = true + } elemType = outt.Elem() + if elemType == typeIface { + d.docType = outt + } if out.IsNil() { out.Set(reflect.MakeMap(out.Type())) } else if out.Len() > 0 { - var none reflect.Value - for _, k := range out.MapKeys() { - out.SetMapIndex(k, none) - } + clearMap(out) } case reflect.Struct: if outt != typeRaw { @@ -168,12 +195,33 @@ func (d *decoder) readDocTo(out reflect.Value) { } fieldsMap = sinfo.FieldsMap out.Set(sinfo.Zero) + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + if !inlineMap.IsNil() && inlineMap.Len() > 0 { + clearMap(inlineMap) + } + elemType = inlineMap.Type().Elem() + if elemType == typeIface { + d.docType = inlineMap.Type() + } + } } + case reflect.Slice: + switch outt.Elem() { + case typeDocElem: + origout.Set(d.readDocElems(outt)) + return + case typeRawDocElem: + origout.Set(d.readRawDocElems(outt)) + return + } + fallthrough default: panic("Unsupported document type for unmarshalling: " + out.Type().String()) } - end := d.i - 4 + int(d.readInt32()) + end := int(d.readInt32()) + end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } @@ -188,11 +236,15 @@ func (d *decoder) readDocTo(out reflect.Value) { case reflect.Map: e := reflect.New(elemType).Elem() if d.readElemTo(e, kind) { - 
out.SetMapIndex(reflect.ValueOf(name), e) + k := reflect.ValueOf(name) + if convertKey { + k = k.Convert(keyType) + } + out.SetMapIndex(k, e) } case reflect.Struct: if outt == typeRaw { - d.readElemTo(blackHole, kind) + d.dropElem(kind) } else { if info, ok := fieldsMap[name]; ok { if info.Inline == nil { @@ -200,10 +252,19 @@ func (d *decoder) readDocTo(out reflect.Value) { } else { d.readElemTo(out.FieldByIndex(info.Inline), kind) } + } else if inlineMap.IsValid() { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + inlineMap.SetMapIndex(reflect.ValueOf(name), e) + } } else { d.dropElem(kind) } } + case reflect.Slice: } if d.i >= end { @@ -214,17 +275,16 @@ func (d *decoder) readDocTo(out reflect.Value) { if d.i != end { corrupted() } + d.docType = docType - switch outk { - case reflect.Struct: - if outt == typeRaw { - out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) - } + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) } } func (d *decoder) readArrayDocTo(out reflect.Value) { - end := d.i - 4 + int(d.readInt32()) + end := int(d.readInt32()) + end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } @@ -261,7 +321,8 @@ func (d *decoder) readSliceDoc(t reflect.Type) interface{} { tmp := make([]reflect.Value, 0, 8) elemType := t.Elem() - end := d.i - 4 + int(d.readInt32()) + end := int(d.readInt32()) + end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } @@ -295,11 +356,13 @@ func (d *decoder) readSliceDoc(t reflect.Type) interface{} { return slice.Interface() } -var typeD = reflect.TypeOf(D{}) var typeSlice = reflect.TypeOf([]interface{}{}) +var typeIface = typeSlice.Elem() -func (d *decoder) readDocD() interface{} { - slice := make(D, 0, 8) +func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := 
make([]DocElem, 0, 8) d.readDocWith(func(kind byte, name string) { e := DocElem{Name: name} v := reflect.ValueOf(&e.Value) @@ -307,11 +370,32 @@ func (d *decoder) readDocD() interface{} { slice = append(slice, e) } }) - return slice + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]RawDocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := RawDocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev } func (d *decoder) readDocWith(f func(kind byte, name string)) { - end := d.i - 4 + int(d.readInt32()) + end := int(d.readInt32()) + end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } @@ -354,9 +438,12 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: d.readDocTo(out) default: - if _, ok := out.Interface().(D); ok { - out.Set(reflect.ValueOf(d.readDocD())) - } else { + switch out.Interface().(type) { + case D: + out.Set(d.readDocElems(out.Type())) + case RawD: + out.Set(d.readRawDocElems(out.Type())) + default: d.readDocTo(blackHole) } } @@ -443,9 +530,14 @@ func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { if setter := getSetter(outt, out); setter != nil { err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } if err == nil { return true - } else if _, ok := err.(*TypeError); !ok { + } + if _, ok := err.(*TypeError); !ok { panic(err) } return false @@ -621,7 +713,7 @@ func (d *decoder) readBinary() Binary { b := Binary{} b.Kind = d.readByte() b.Data = d.readBytes(l) - if 
b.Kind == 0x02 { + if b.Kind == 0x02 && len(b.Data) >= 4 { // Weird obsolete format with redundant length. b.Data = b.Data[4:] } diff --git a/third_party/labix.org/v2/mgo/bson/encode.go b/third_party/labix.org/v2/mgo/bson/encode.go index 6e37e35a4..e37868a45 100644 --- a/third_party/labix.org/v2/mgo/bson/encode.go +++ b/third_party/labix.org/v2/mgo/bson/encode.go @@ -1,18 +1,18 @@ // BSON library for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -28,6 +28,7 @@ package bson import ( + "fmt" "math" "net/url" "reflect" @@ -45,9 +46,11 @@ var ( typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0)) typeOrderKey = reflect.TypeOf(MinKey) typeDocElem = reflect.TypeOf(DocElem{}) + typeRawDocElem = reflect.TypeOf(RawDocElem{}) typeRaw = reflect.TypeOf(Raw{}) typeURL = reflect.TypeOf(url.URL{}) typeTime = reflect.TypeOf(time.Time{}) + typeString = reflect.TypeOf("") ) const itoaCacheSize = 32 @@ -130,6 +133,18 @@ func (e *encoder) addStruct(v reflect.Value) { panic(err) } var value reflect.Value + if sinfo.InlineMap >= 0 { + m := v.Field(sinfo.InlineMap) + if m.Len() > 0 { + for _, k := range m.MapKeys() { + ks := k.String() + if _, found := sinfo.FieldsMap[ks]; found { + panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) + } + e.addElem(ks, m.MapIndex(k), false) + } + } + } for _, info := range sinfo.FieldsList { if info.Inline == nil { value = v.Field(info.Num) @@ -157,25 +172,56 @@ func isZero(v reflect.Value) bool { return v.Int() == 0 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 case reflect.Bool: return !v.Bool() case reflect.Struct: if v.Type() == typeTime { return v.Interface().(time.Time).IsZero() } + for i := v.NumField() - 1; i >= 0; i-- { + if !isZero(v.Field(i)) { + return false + } + } + return true } return false } func (e *encoder) addSlice(v reflect.Value) { - if d, ok := v.Interface().(D); ok { + vi := v.Interface() + if d, ok := vi.(D); ok { for _, elem := range d { e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) } - } else { - for i := 0; i != v.Len(); i++ { - 
e.addElem(itoa(i), v.Index(i), false) + return + } + if d, ok := vi.(RawD); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) } + return + } + l := v.Len() + et := v.Type().Elem() + if et == typeDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(DocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if et == typeRawDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(RawDocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + for i := 0; i < l; i++ { + e.addElem(itoa(i), v.Index(i), false) } } @@ -247,32 +293,27 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if v.Type().Kind() <= reflect.Int32 { - e.addElemName('\x10', name) - e.addInt32(int32(v.Int())) - } else { - switch v.Type() { - case typeMongoTimestamp: - e.addElemName('\x11', name) - e.addInt64(v.Int()) + switch v.Type() { + case typeMongoTimestamp: + e.addElemName('\x11', name) + e.addInt64(v.Int()) - case typeOrderKey: - if v.Int() == int64(MaxKey) { - e.addElemName('\x7F', name) - } else { - e.addElemName('\xFF', name) - } + case typeOrderKey: + if v.Int() == int64(MaxKey) { + e.addElemName('\x7F', name) + } else { + e.addElemName('\xFF', name) + } - default: - i := v.Int() - if minSize && i >= math.MinInt32 && i <= math.MaxInt32 { - // It fits into an int32, encode as such. - e.addElemName('\x10', name) - e.addInt32(int32(i)) - } else { - e.addElemName('\x12', name) - e.addInt64(i) - } + default: + i := v.Int() + if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 { + // It fits into an int32, encode as such. 
+ e.addElemName('\x10', name) + e.addInt32(int32(i)) + } else { + e.addElemName('\x12', name) + e.addInt64(i) } } @@ -294,7 +335,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { if et.Kind() == reflect.Uint8 { e.addElemName('\x05', name) e.addBinary('\x00', v.Bytes()) - } else if et == typeDocElem { + } else if et == typeDocElem || et == typeRawDocElem { e.addElemName('\x03', name) e.addDoc(v) } else { @@ -347,7 +388,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { case time.Time: // MongoDB handles timestamps as milliseconds. e.addElemName('\x09', name) - e.addInt64(s.Unix() * 1000 + int64(s.Nanosecond() / 1e6)) + e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6)) case url.URL: e.addElemName('\x02', name) diff --git a/third_party/labix.org/v2/mgo/cluster.go b/third_party/labix.org/v2/mgo/cluster.go index 2e9be0ef7..a9491e8b6 100644 --- a/third_party/labix.org/v2/mgo/cluster.go +++ b/third_party/labix.org/v2/mgo/cluster.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,7 +27,9 @@ package mgo import ( + "camlistore.org/third_party/labix.org/v2/mgo/bson" "errors" + "net" "sync" "time" ) @@ -51,12 +53,19 @@ type mongoCluster struct { direct bool cachedIndex map[string]bool sync chan bool + dial dialer } -func newCluster(userSeeds []string, direct bool) *mongoCluster { - cluster := &mongoCluster{userSeeds: userSeeds, references: 1, direct: direct} +func newCluster(userSeeds []string, direct bool, dial dialer) *mongoCluster { + cluster := &mongoCluster{ + userSeeds: userSeeds, + references: 1, + direct: direct, + dial: dial, + } cluster.serverSynced.L = cluster.RWMutex.RLocker() cluster.sync = make(chan bool, 1) + stats.cluster(+1) go cluster.syncServersLoop() return cluster } @@ -84,6 +93,7 @@ func (cluster *mongoCluster) Release() { } // Wake up the sync loop so it can die. cluster.syncServers() + stats.cluster(-1) } cluster.Unlock() } @@ -115,34 +125,55 @@ type isMasterResult struct { Primary string Hosts []string Passives []string + Tags bson.D + Msg string } -func (cluster *mongoCluster) syncServer(server *mongoServer) (hosts []string, err error) { +func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { + // Monotonic let's it talk to a slave and still hold the socket. 
+ session := newSession(Monotonic, cluster, 10*time.Second) + session.setSocket(socket) + err := session.Run("ismaster", result) + session.Close() + return err +} + +type possibleTimeout interface { + Timeout() bool +} + +var syncSocketTimeout = 5 * time.Second + +func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { addr := server.Addr log("SYNC Processing ", addr, "...") + // Retry a few times to avoid knocking a server down for a hiccup. var result isMasterResult var tryerr error for retry := 0; ; retry++ { if retry == 3 { - return nil, tryerr + return nil, nil, tryerr + } + if retry > 0 { + // Don't abuse the server needlessly if there's something actually wrong. + if err, ok := tryerr.(possibleTimeout); ok && err.Timeout() { + // Give a chance for waiters to timeout as well. + cluster.serverSynced.Broadcast() + } + time.Sleep(500 * time.Millisecond) } - socket, err := server.AcquireSocket(0) + // It's not clear what would be a good timeout here. Is it + // better to wait longer or to retry? + socket, _, err := server.AcquireSocket(0, syncSocketTimeout) if err != nil { tryerr = err logf("SYNC Failed to get socket to %s: %v", addr, err) continue } - - // Monotonic will let us talk to a slave and still hold the socket. - session := newSession(Monotonic, cluster, socket, 10 * time.Second) - - // session holds the socket now. + err = cluster.isMaster(socket, &result) socket.Release() - - err = session.Run("ismaster", &result) - session.Close() if err != nil { tryerr = err logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) @@ -156,15 +187,22 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (hosts []string, er debugf("SYNC %s is a master.", addr) // Made an incorrect assumption above, so fix stats. 
stats.conn(-1, false) - server.SetMaster(true) stats.conn(+1, true) } else if result.Secondary { debugf("SYNC %s is a slave.", addr) + } else if cluster.direct { + logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) } else { logf("SYNC %s is neither a master nor a slave.", addr) // Made an incorrect assumption above, so fix stats. stats.conn(-1, false) - return nil, errors.New(addr + " is not a master nor slave") + return nil, nil, errors.New(addr + " is not a master nor slave") + } + + info = &mongoServerInfo{ + Master: result.IsMaster, + Mongos: result.Msg == "isdbgrid", + Tags: result.Tags, } hosts = make([]string, 0, 1+len(result.Hosts)+len(result.Passives)) @@ -176,33 +214,48 @@ func (cluster *mongoCluster) syncServer(server *mongoServer) (hosts []string, er hosts = append(hosts, result.Passives...) debugf("SYNC %s knows about the following peers: %#v", addr, hosts) - return hosts, nil + return info, hosts, nil } -func (cluster *mongoCluster) mergeServer(server *mongoServer) { +type syncKind bool + +const ( + completeSync syncKind = true + partialSync syncKind = false +) + +func (cluster *mongoCluster) addServer(server *mongoServer, info *mongoServerInfo, syncKind syncKind) { cluster.Lock() - previous := cluster.servers.Search(server) - isMaster := server.IsMaster() - if previous == nil { + current := cluster.servers.Search(server.ResolvedAddr) + if current == nil { + if syncKind == partialSync { + cluster.Unlock() + server.Close() + log("SYNC Discarding unknown server ", server.Addr, " due to partial sync.") + return + } cluster.servers.Add(server) - if isMaster { + if info.Master { cluster.masters.Add(server) log("SYNC Adding ", server.Addr, " to cluster as a master.") } else { log("SYNC Adding ", server.Addr, " to cluster as a slave.") } } else { - if isMaster != previous.IsMaster() { - if isMaster { + if server != current { + panic("addServer attempting to add duplicated server") + } + if server.Info().Master != 
info.Master { + if info.Master { log("SYNC Server ", server.Addr, " is now a master.") - cluster.masters.Add(previous) + cluster.masters.Add(server) } else { log("SYNC Server ", server.Addr, " is now a slave.") - cluster.masters.Remove(previous) + cluster.masters.Remove(server) } } - previous.Merge(server) } + server.SetInfo(info) debugf("SYNC Broadcasting availability of server %s", server.Addr) cluster.serverSynced.Broadcast() cluster.Unlock() @@ -246,7 +299,7 @@ func (cluster *mongoCluster) syncServers() { // How long to wait for a checkup of the cluster topology if nothing // else kicks a synchronization before that. -const syncServersDelay = 3 * time.Minute +const syncServersDelay = 30 * time.Second // syncServersLoop loops while the cluster is alive to keep its idea of // the server topology up-to-date. It must be called just once from @@ -281,7 +334,7 @@ func (cluster *mongoCluster) syncServersLoop() { // Hold off before allowing another sync. No point in // burning CPU looking for down servers. 
- time.Sleep(5e8) + time.Sleep(500 * time.Millisecond) cluster.Lock() if cluster.references == 0 { @@ -312,15 +365,42 @@ func (cluster *mongoCluster) syncServersLoop() { debugf("SYNC Cluster %p is stopping its sync loop.", cluster) } +func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { + cluster.RLock() + server := cluster.servers.Search(tcpaddr.String()) + cluster.RUnlock() + if server != nil { + return server + } + return newServer(addr, tcpaddr, cluster.sync, cluster.dial) +} + +func resolveAddr(addr string) (*net.TCPAddr, error) { + tcpaddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + log("SYNC Failed to resolve ", addr, ": ", err.Error()) + return nil, err + } + if tcpaddr.String() != addr { + debug("SYNC Address ", addr, " resolved as ", tcpaddr.String()) + } + return tcpaddr, nil +} + +type pendingAdd struct { + server *mongoServer + info *mongoServerInfo +} + func (cluster *mongoCluster) syncServersIteration(direct bool) { log("SYNC Starting full topology synchronization...") var wg sync.WaitGroup var m sync.Mutex - mergePending := make(map[string]*mongoServer) - mergeRequested := make(map[string]bool) + notYetAdded := make(map[string]pendingAdd) + addIfFound := make(map[string]bool) seen := make(map[string]bool) - goodSync := false + syncKind := partialSync var spawnSync func(addr string, byMaster bool) spawnSync = func(addr string, byMaster bool) { @@ -328,66 +408,71 @@ func (cluster *mongoCluster) syncServersIteration(direct bool) { go func() { defer wg.Done() - server, err := newServer(addr, cluster.sync) + tcpaddr, err := resolveAddr(addr) if err != nil { log("SYNC Failed to start sync of ", addr, ": ", err.Error()) return } + resolvedAddr := tcpaddr.String() m.Lock() if byMaster { - if s, found := mergePending[server.ResolvedAddr]; found { - delete(mergePending, server.ResolvedAddr) + if pending, ok := notYetAdded[resolvedAddr]; ok { + delete(notYetAdded, resolvedAddr) m.Unlock() - 
cluster.mergeServer(s) + cluster.addServer(pending.server, pending.info, completeSync) return } - mergeRequested[server.ResolvedAddr] = true + addIfFound[resolvedAddr] = true } - if seen[server.ResolvedAddr] { + if seen[resolvedAddr] { m.Unlock() return } - seen[server.ResolvedAddr] = true + seen[resolvedAddr] = true m.Unlock() - hosts, err := cluster.syncServer(server) - if err == nil { - isMaster := server.IsMaster() - if !direct { - for _, addr := range hosts { - spawnSync(addr, isMaster) - } - } + server := cluster.server(addr, tcpaddr) + info, hosts, err := cluster.syncServer(server) + if err != nil { + cluster.removeServer(server) + return + } - m.Lock() - merge := direct || isMaster - if mergeRequested[server.ResolvedAddr] { - merge = true - } else if !merge { - mergePending[server.ResolvedAddr] = server - } - if merge { - goodSync = true - } - m.Unlock() - if merge { - cluster.mergeServer(server) + m.Lock() + add := direct || info.Master || addIfFound[resolvedAddr] + if add { + syncKind = completeSync + } else { + notYetAdded[resolvedAddr] = pendingAdd{server, info} + } + m.Unlock() + if add { + cluster.addServer(server, info, completeSync) + } + if !direct { + for _, addr := range hosts { + spawnSync(addr, info.Master) } } }() } - for _, addr := range cluster.getKnownAddrs() { + knownAddrs := cluster.getKnownAddrs() + for _, addr := range knownAddrs { spawnSync(addr, false) } wg.Wait() - for _, server := range mergePending { - if goodSync { - cluster.removeServer(server) - } else { - server.Close() + if syncKind == completeSync { + logf("SYNC Synchronization was complete (got data from primary).") + for _, pending := range notYetAdded { + cluster.removeServer(pending.server) + } + } else { + logf("SYNC Synchronization was partial (cannot talk to primary).") + for _, pending := range notYetAdded { + cluster.addServer(pending.server, pending.info, partialSync) } } @@ -397,7 +482,7 @@ func (cluster *mongoCluster) syncServersIteration(direct bool) { // Update 
dynamic seeds, but only if we have any good servers. Otherwise, // leave them alone for better chances of a successful sync in the future. - if goodSync { + if syncKind == completeSync { dynaSeeds := make([]string, cluster.servers.Len()) for i, server := range cluster.servers.Slice() { dynaSeeds[i] = server.Addr @@ -413,7 +498,7 @@ var socketsPerServer = 4096 // AcquireSocket returns a socket to a server in the cluster. If slaveOk is // true, it will attempt to return a socket to a slave server. If it is // false, the socket will necessarily be to a master server. -func (cluster *mongoCluster) AcquireSocket(slaveOk bool, syncTimeout time.Duration) (s *mongoSocket, err error) { +func (cluster *mongoCluster) AcquireSocket(slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D) (s *mongoSocket, err error) { var started time.Time warnedLimit := false for { @@ -440,13 +525,19 @@ func (cluster *mongoCluster) AcquireSocket(slaveOk bool, syncTimeout time.Durati var server *mongoServer if slaveOk { - server = cluster.servers.MostAvailable() + server = cluster.servers.BestFit(serverTags) } else { - server = cluster.masters.MostAvailable() + server = cluster.masters.BestFit(nil) } cluster.RUnlock() - s, err = server.AcquireSocket(socketsPerServer) + if server == nil { + // Must have failed the requested tags. Sleep to avoid spinning. 
+ time.Sleep(1e8) + continue + } + + s, abended, err := server.AcquireSocket(socketsPerServer, socketTimeout) if err == errSocketLimit { if !warnedLimit { log("WARNING: Per-server connection limit reached.") @@ -459,6 +550,17 @@ func (cluster *mongoCluster) AcquireSocket(slaveOk bool, syncTimeout time.Durati cluster.syncServers() continue } + if abended && !slaveOk { + var result isMasterResult + err := cluster.isMaster(s, &result) + if err != nil || !result.IsMaster { + logf("Cannot confirm server %s as master (%v)", server.Addr, err) + s.Release() + cluster.syncServers() + time.Sleep(1e8) + continue + } + } return s, nil } panic("unreached") diff --git a/third_party/labix.org/v2/mgo/cluster_test.go b/third_party/labix.org/v2/mgo/cluster_test.go index b09617887..be9b75188 100644 --- a/third_party/labix.org/v2/mgo/cluster_test.go +++ b/third_party/labix.org/v2/mgo/cluster_test.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,11 +27,14 @@ package mgo_test import ( - "io" - . "camlistore.org/third_party/launchpad.net/gocheck" "camlistore.org/third_party/labix.org/v2/mgo" "camlistore.org/third_party/labix.org/v2/mgo/bson" + . "camlistore.org/third_party/launchpad.net/gocheck" + "fmt" + "io" + "net" "strings" + "sync" "time" ) @@ -81,7 +84,7 @@ func (s *S) TestNewSession(c *C) { m := M{} ok := iter.Next(m) c.Assert(ok, Equals, true) - err = iter.Err() + err = iter.Close() c.Assert(err, IsNil) // If Batch(-1) is in effect, a single document must have been received. @@ -146,7 +149,7 @@ func (s *S) TestCloneSession(c *C) { m := M{} ok := iter.Next(m) c.Assert(ok, Equals, true) - err = iter.Err() + err = iter.Close() c.Assert(err, IsNil) // If Batch(-1) is in effect, a single document must have been received. @@ -226,7 +229,7 @@ func (s *S) TestSetModeMonotonic(c *C) { stats := mgo.GetStats() c.Assert(stats.MasterConns, Equals, 1) c.Assert(stats.SlaveConns, Equals, 2) - c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketsInUse, Equals, 2) session.SetMode(mgo.Monotonic, true) @@ -307,6 +310,51 @@ func (s *S) TestSetModeStrongAfterMonotonic(c *C) { c.Assert(result["ismaster"], Equals, true) } +func (s *S) TestSetModeMonotonicWriteOnIteration(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. 
+ session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + c.Assert(session.Mode(), Equals, mgo.Monotonic) + + coll1 := session.DB("mydb").C("mycoll1") + coll2 := session.DB("mydb").C("mycoll2") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll1.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + // Release master so we can grab a slave again. + session.Refresh() + + // Wait until synchronization is done. + for { + n, err := coll1.Count() + c.Assert(err, IsNil) + if n == len(ns) { + break + } + } + + iter := coll1.Find(nil).Batch(2).Iter() + i := 0 + m := M{} + for iter.Next(&m) { + i++ + if i > 3 { + err := coll2.Insert(M{"n": 47 + i}) + c.Assert(err, IsNil) + } + } + c.Assert(i, Equals, len(ns)) +} + func (s *S) TestSetModeEventual(c *C) { // Must necessarily connect to a slave, otherwise the // master connection will be available first. @@ -417,6 +465,61 @@ func (s *S) TestPrimaryShutdownStrong(c *C) { err = session.Run("serverStatus", result) c.Assert(err, IsNil) c.Assert(result.Host, Not(Equals), host) + + // Insert some data to confirm it's indeed a master. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 42}) + c.Assert(err, IsNil) +} + +func (s *S) TestPrimaryHiccup(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Establish a few extra sessions to create spare sockets to + // the master. This increases a bit the chances of getting an + // incorrect cached socket. 
+ var sessions []*mgo.Session + for i := 0; i < 20; i++ { + sessions = append(sessions, session.Copy()) + err = sessions[len(sessions)-1].Run("serverStatus", result) + c.Assert(err, IsNil) + } + for i := range sessions { + sessions[i].Close() + } + + // Kill the master, but bring it back immediatelly. + host := result.Host + s.Stop(host) + s.StartAll() + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With strong consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. + // Increase the timeout since this may take quite a while. + session.SetSyncTimeout(3 * time.Minute) + + // Insert some data to confirm it's indeed a master. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 42}) + c.Assert(err, IsNil) } func (s *S) TestPrimaryShutdownMonotonic(c *C) { @@ -435,6 +538,9 @@ func (s *S) TestPrimaryShutdownMonotonic(c *C) { err = coll.Insert(M{"a": 1}) c.Assert(err, IsNil) + // Wait a bit for this to be synchronized to slaves. + time.Sleep(3 * time.Second) + result := &struct{ Host string }{} err = session.Run("serverStatus", result) c.Assert(err, IsNil) @@ -559,6 +665,9 @@ func (s *S) TestPrimaryShutdownEventual(c *C) { err = coll.Insert(M{"a": 1}) c.Assert(err, IsNil) + // Wait a bit for this to be synchronized to slaves. + time.Sleep(3 * time.Second) + // Kill the master. s.Stop(master) @@ -664,7 +773,7 @@ func (s *S) TestTopologySyncWithSlaveSeed(c *C) { c.Assert(result.Ok, Equals, true) // One connection to each during discovery. Master - // socket recycled for insert. + // socket recycled for insert. 
stats := mgo.GetStats() c.Assert(stats.MasterConns, Equals, 1) c.Assert(stats.SlaveConns, Equals, 2) @@ -721,6 +830,75 @@ func (s *S) TestDialWithTimeout(c *C) { c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) } +func (s *S) TestSocketTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + s.Freeze("localhost:40001") + + timeout := 3 * time.Second + session.SetSocketTimeout(timeout) + started := time.Now() + + // Do something. + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, ErrorMatches, ".*: i/o timeout") + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestSocketTimeoutOnDial(c *C) { + if *fast { + c.Skip("-fast") + } + + timeout := 1 * time.Second + + defer mgo.HackSyncSocketTimeout(timeout)() + + s.Freeze("localhost:40001") + + started := time.Now() + + session, err := mgo.DialWithTimeout("localhost:40001", timeout) + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(session, IsNil) + + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-20*time.Second)), Equals, true) +} + +func (s *S) TestSocketTimeoutOnInactiveSocket(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + timeout := 2 * time.Second + session.SetSocketTimeout(timeout) + + // Do something that relies on the timeout and works. + c.Assert(session.Ping(), IsNil) + + // Freeze and wait for the timeout to go by. + s.Freeze("localhost:40001") + time.Sleep(timeout + 500*time.Millisecond) + s.Thaw("localhost:40001") + + // Do something again. The timeout above should not have killed + // the socket as there was nothing to be done. 
+ c.Assert(session.Ping(), IsNil) +} + func (s *S) TestDirect(c *C) { session, err := mgo.Dial("localhost:40012?connect=direct") c.Assert(err, IsNil) @@ -746,13 +924,51 @@ func (s *S) TestDirect(c *C) { err = coll.Insert(M{"test": 1}) c.Assert(err, ErrorMatches, "no reachable servers") - // Slave is still reachable. + // Writing to the local database is okay. + coll = session.DB("local").C("mycoll") + defer coll.RemoveAll(nil) + id := bson.NewObjectId() + err = coll.Insert(M{"_id": id}) + c.Assert(err, IsNil) + + // Data was stored in the right server. + n, err := coll.Find(M{"_id": id}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) + + // Server hasn't changed. result.Host = "" err = session.Run("serverStatus", result) c.Assert(err, IsNil) c.Assert(strings.HasSuffix(result.Host, ":40012"), Equals, true) } +func (s *S) TestDirectToUnknownStateMember(c *C) { + session, err := mgo.Dial("localhost:40041?connect=direct") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40041"), Equals, true) + + // We've got no master, so it'll timeout. + session.SetSyncTimeout(5e8 * time.Nanosecond) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"test": 1}) + c.Assert(err, ErrorMatches, "no reachable servers") + + // Slave is still reachable. 
+ result.Host = "" + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40041"), Equals, true) +} + type OpCounters struct { Insert int Query int @@ -862,7 +1078,10 @@ func (s *S) TestRemovalOfClusterMember(c *C) { c.Fatalf("Test started with bad cluster state: %v", master.LiveServers()) } - result := &struct{ IsMaster bool; Me string }{} + result := &struct { + IsMaster bool + Me string + }{} slave := master.Copy() slave.SetMode(mgo.Monotonic, true) // Monotonic can hold a non-master socket persistently. err = slave.Run("isMaster", result) @@ -875,10 +1094,6 @@ func (s *S) TestRemovalOfClusterMember(c *C) { master.Run(bson.D{{"$eval", `rs.add("` + slaveAddr + `")`}}, nil) master.Close() slave.Close() - - s.Stop(slaveAddr) - // For some reason it remains FATAL if we don't wait. - time.Sleep(3e9) }() c.Logf("========== Removing slave: %s ==========", slaveAddr) @@ -909,6 +1124,8 @@ func (s *S) TestRemovalOfClusterMember(c *C) { if len(live) != 2 { c.Errorf("Removed server still considered live: %#s", live) } + + c.Log("========== Test succeeded. 
==========") } func (s *S) TestSocketLimit(c *C) { @@ -955,3 +1172,366 @@ func (s *S) TestSocketLimit(c *C) { c.Assert(delay > 3e9, Equals, true) c.Assert(delay < 6e9, Equals, true) } + +func (s *S) TestSetModeEventualIterBug(c *C) { + session1, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session1.Close() + + session1.SetMode(mgo.Eventual, false) + + coll1 := session1.DB("mydb").C("mycoll") + + const N = 100 + for i := 0; i < N; i++ { + err = coll1.Insert(M{"_id": i}) + c.Assert(err, IsNil) + } + + c.Logf("Waiting until secondary syncs") + for { + n, err := coll1.Count() + c.Assert(err, IsNil) + if n == N { + c.Logf("Found all") + break + } + } + + session2, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session2.Close() + + session2.SetMode(mgo.Eventual, false) + + coll2 := session2.DB("mydb").C("mycoll") + + i := 0 + iter := coll2.Find(nil).Batch(10).Iter() + var result struct{} + for iter.Next(&result) { + i++ + } + c.Assert(iter.Close(), Equals, nil) + c.Assert(i, Equals, N) +} + +func (s *S) TestCustomDialOld(c *C) { + dials := make(chan bool, 16) + dial := func(addr net.Addr) (net.Conn, error) { + tcpaddr, ok := addr.(*net.TCPAddr) + if !ok { + return nil, fmt.Errorf("unexpected address type: %T", addr) + } + dials <- true + return net.DialTCP("tcp", nil, tcpaddr) + } + info := mgo.DialInfo{ + Addrs: []string{"localhost:40012"}, + Dial: dial, + } + + // Use hostname here rather than IP, to make things trickier. 
+ session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + defer session.Close() + + const N = 3 + for i := 0; i < N; i++ { + select { + case <-dials: + case <-time.After(5 * time.Second): + c.Fatalf("expected %d dials, got %d", N, i) + } + } + select { + case <-dials: + c.Fatalf("got more dials than expected") + case <-time.After(100 * time.Millisecond): + } +} + +func (s *S) TestCustomDialNew(c *C) { + dials := make(chan bool, 16) + dial := func(addr *mgo.ServerAddr) (net.Conn, error) { + dials <- true + if addr.TCPAddr().Port == 40012 { + c.Check(addr.String(), Equals, "localhost:40012") + } + return net.DialTCP("tcp", nil, addr.TCPAddr()) + } + info := mgo.DialInfo{ + Addrs: []string{"localhost:40012"}, + DialServer: dial, + } + + // Use hostname here rather than IP, to make things trickier. + session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + defer session.Close() + + const N = 3 + for i := 0; i < N; i++ { + select { + case <-dials: + case <-time.After(5 * time.Second): + c.Fatalf("expected %d dials, got %d", N, i) + } + } + select { + case <-dials: + c.Fatalf("got more dials than expected") + case <-time.After(100 * time.Millisecond): + } +} + +func (s *S) TestPrimaryShutdownOnAuthShard(c *C) { + if *fast { + c.Skip("-fast") + } + + // Dial the shard. + session, err := mgo.Dial("localhost:40203") + c.Assert(err, IsNil) + defer session.Close() + + // Login and insert something to make it more realistic. + session.DB("admin").Login("root", "rapadura") + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(bson.M{"n": 1}) + c.Assert(err, IsNil) + + // Dial the replica set to figure the master out. + rs, err := mgo.Dial("root:rapadura@localhost:40031") + c.Assert(err, IsNil) + defer rs.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = rs.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. 
+ host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = rs.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // This won't work because the master just died. + err = coll.Insert(bson.M{"n": 2}) + c.Assert(err, NotNil) + + // Refresh session and wait for re-election. + session.Refresh() + for i := 0; i < 60; i++ { + err = coll.Insert(bson.M{"n": 3}) + if err == nil { + break + } + c.Logf("Waiting for replica set to elect a new master. Last error: %v", err) + time.Sleep(500 * time.Millisecond) + } + c.Assert(err, IsNil) + + count, err := coll.Count() + c.Assert(count > 1, Equals, true) +} + +func (s *S) TestNearestSecondary(c *C) { + defer mgo.HackPingDelay(3 * time.Second)() + + rs1a := "127.0.0.1:40011" + rs1b := "127.0.0.1:40012" + rs1c := "127.0.0.1:40013" + s.Freeze(rs1b) + + session, err := mgo.Dial(rs1a) + c.Assert(err, IsNil) + defer session.Close() + + // Wait for the sync up to run through the first couple of servers. + for len(session.LiveServers()) != 2 { + c.Log("Waiting for two servers to be alive...") + time.Sleep(100 * time.Millisecond) + } + + // Extra delay to ensure the third server gets penalized. + time.Sleep(500 * time.Millisecond) + + // Release third server. + s.Thaw(rs1b) + + // Wait for it to come up. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for all servers to be alive...") + time.Sleep(100 * time.Millisecond) + } + + session.SetMode(mgo.Monotonic, true) + var result struct{ Host string } + + // See which slave picks the line, several times to avoid chance. + for i := 0; i < 10; i++ { + session.Refresh() + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, hostPort(rs1c)) + } + + if *fast { + // Don't hold back for several seconds. + return + } + + // Now hold the other server for long enough to penalize it. + s.Freeze(rs1c) + time.Sleep(5 * time.Second) + s.Thaw(rs1c) + + // Wait for the ping to be processed. 
+ time.Sleep(500 * time.Millisecond) + + // Repeating the test should now pick the former server consistently. + for i := 0; i < 10; i++ { + session.Refresh() + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, hostPort(rs1b)) + } +} + +func (s *S) TestConnectCloseConcurrency(c *C) { + restore := mgo.HackPingDelay(500 * time.Millisecond) + defer restore() + var wg sync.WaitGroup + const n = 500 + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + session, err := mgo.Dial("localhost:40001") + if err != nil { + c.Fatal(err) + } + time.Sleep(1) + session.Close() + }() + } + wg.Wait() +} + +func (s *S) TestSelectServers(c *C) { + if !s.versionAtLeast(2, 2) { + c.Skip("read preferences introduced in 2.2") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Eventual, true) + + var result struct{ Host string } + + session.Refresh() + session.SelectServers(bson.D{{"rs1", "b"}}) + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, "40012") + + session.Refresh() + session.SelectServers(bson.D{{"rs1", "c"}}) + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, "40013") +} + +func (s *S) TestSelectServersWithMongos(c *C) { + if !s.versionAtLeast(2, 2) { + c.Skip("read preferences introduced in 2.2") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. 
+ err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + var slave1, slave2 string + switch hostPort(master) { + case "40021": + slave1, slave2 = "b", "c" + case "40022": + slave1, slave2 = "a", "c" + case "40023": + slave1, slave2 = "a", "b" + } + + // Collect op counters for everyone. + opc21a, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22a, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23a, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + // Do a SlaveOk query through MongoS + mongos, err := mgo.Dial("localhost:40202") + c.Assert(err, IsNil) + defer mongos.Close() + + mongos.SetMode(mgo.Monotonic, true) + + mongos.Refresh() + mongos.SelectServers(bson.D{{"rs2", slave1}}) + coll := mongos.DB("mydb").C("mycoll") + result := &struct{}{} + for i := 0; i != 5; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + mongos.Refresh() + mongos.SelectServers(bson.D{{"rs2", slave2}}) + coll = mongos.DB("mydb").C("mycoll") + for i := 0; i != 7; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + // Collect op counters for everyone again. 
+ opc21b, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22b, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23b, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + switch hostPort(master) { + case "40021": + c.Check(opc21b.Query-opc21a.Query, Equals, 0) + c.Check(opc22b.Query-opc22a.Query, Equals, 5) + c.Check(opc23b.Query-opc23a.Query, Equals, 7) + case "40022": + c.Check(opc21b.Query-opc21a.Query, Equals, 5) + c.Check(opc22b.Query-opc22a.Query, Equals, 0) + c.Check(opc23b.Query-opc23a.Query, Equals, 7) + case "40023": + c.Check(opc21b.Query-opc21a.Query, Equals, 5) + c.Check(opc22b.Query-opc22a.Query, Equals, 7) + c.Check(opc23b.Query-opc23a.Query, Equals, 0) + default: + c.Fatal("Uh?") + } +} diff --git a/third_party/labix.org/v2/mgo/export_test.go b/third_party/labix.org/v2/mgo/export_test.go index d56e6af8c..12b2a59e1 100644 --- a/third_party/labix.org/v2/mgo/export_test.go +++ b/third_party/labix.org/v2/mgo/export_test.go @@ -1,5 +1,9 @@ package mgo +import ( + "time" +) + func HackSocketsPerServer(newLimit int) (restore func()) { oldLimit := newLimit restore = func() { @@ -8,3 +12,21 @@ func HackSocketsPerServer(newLimit int) (restore func()) { socketsPerServer = newLimit return } + +func HackPingDelay(newDelay time.Duration) (restore func()) { + oldDelay := pingDelay + restore = func() { + pingDelay = oldDelay + } + pingDelay = newDelay + return +} + +func HackSyncSocketTimeout(newTimeout time.Duration) (restore func()) { + oldTimeout := syncSocketTimeout + restore = func() { + syncSocketTimeout = oldTimeout + } + syncSocketTimeout = newTimeout + return +} diff --git a/third_party/labix.org/v2/mgo/gridfs.go b/third_party/labix.org/v2/mgo/gridfs.go index 5478fa33b..253b276dc 100644 --- a/third_party/labix.org/v2/mgo/gridfs.go +++ b/third_party/labix.org/v2/mgo/gridfs.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights 
reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,12 +27,12 @@ package mgo import ( + "camlistore.org/third_party/labix.org/v2/mgo/bson" "crypto/md5" "encoding/hex" "errors" "hash" "io" - "camlistore.org/third_party/labix.org/v2/mgo/bson" "os" "sync" "time" @@ -72,10 +72,10 @@ type GridFile struct { } type gfsFile struct { - Id interface{} "_id" - ChunkSize int "chunkSize" - UploadDate time.Time "uploadDate" - Length int64 ",minsize" + Id interface{} "_id" + ChunkSize int "chunkSize" + UploadDate time.Time "uploadDate" + Length int64 ",minsize" MD5 string Filename string ",omitempty" ContentType string "contentType,omitempty" @@ -157,15 +157,15 @@ func (gfs *GridFS) Create(name string) (file *GridFile, err error) { return } -// OpenId returns a file with the provided id in case it exists or an error -// instead. If the file isn't found, err will be set to mgo.ErrNotFound. +// OpenId returns the file with the provided id, for reading. +// If the file isn't found, err will be set to mgo.ErrNotFound. 
// // It's important to Close files whether they are being written to // or read from, and to check the err result to ensure the operation // completed successfully. // // The following example will print the first 8192 bytes from the file: -// +// // func check(err os.Error) { // if err != nil { // panic(err.String()) @@ -205,15 +205,16 @@ func (gfs *GridFS) OpenId(id interface{}) (file *GridFile, err error) { return } -// Open returns the most recent uploaded file with the provided name, or an -// error instead. If the file isn't found, err will be set to mgo.ErrNotFound. +// Open returns the most recently uploaded file with the provided +// name, for reading. If the file isn't found, err will be set +// to mgo.ErrNotFound. // // It's important to Close files whether they are being written to // or read from, and to check the err result to ensure the operation // completed successfully. // // The following example will print the first 8192 bytes from the file: -// +// // file, err := db.GridFS("fs").Open("myfile.txt") // check(err) // b := make([]byte, 8192) @@ -248,17 +249,17 @@ func (gfs *GridFS) Open(name string) (file *GridFile, err error) { return } -// OpenNext opens the next file from iter, sets *file to it, and returns -// true on the success case. If no more documents are available on iter or -// an error occurred, *file is set to nil and the result is false. Errors -// will be available on iter.Err(). +// OpenNext opens the next file from iter for reading, sets *file to it, +// and returns true on the success case. If no more documents are available +// on iter or an error occurred, *file is set to nil and the result is false. +// Errors will be available via iter.Err(). // // The iter parameter must be an iterator on the GridFS files collection. // Using the GridFS.Find method is an easy way to obtain such an iterator, // but any iterator on the collection will work. 
// -// If the provided *file is non-nil, OpenNext will close it before -// iterating to the next element. This means that in a loop one only +// If the provided *file is non-nil, OpenNext will close it before attempting +// to iterate to the next element. This means that in a loop one only // has to worry about closing files when breaking out of the loop early // (break, return, or panic). // @@ -271,8 +272,8 @@ func (gfs *GridFS) Open(name string) (file *GridFile, err error) { // for gfs.OpenNext(iter, &f) { // fmt.Printf("Filename: %s\n", f.Name()) // } -// if iter.Err() != nil { -// panic(iter.Err()) +// if iter.Close() != nil { +// panic(iter.Close()) // } // func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool { @@ -280,7 +281,7 @@ func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool { // Ignoring the error here shouldn't be a big deal // as we're reading the file and the loop iteration // for this file is finished. - _ = file.Close() + _ = (*file).Close() } var doc gfsFile if !iter.Next(&doc) { @@ -306,7 +307,7 @@ func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool { // // files := db.C("fs" + ".files") // iter := files.Find(nil).Iter() -// +// func (gfs *GridFS) Find(query interface{}) *Query { return gfs.Files.Find(query) } @@ -335,7 +336,7 @@ func (gfs *GridFS) Remove(name string) (err error) { } } if err == nil { - err = iter.Err() + err = iter.Close() } return err } @@ -509,6 +510,8 @@ func (file *GridFile) Close() (err error) { // // The file will internally cache the data so that all but the last // chunk sent to the database have the size defined by SetChunkSize. +// This also means that errors may be deferred until a future call +// to Write or Close. // // The parameters and behavior of this function turn the file // into an io.Writer. 
diff --git a/third_party/labix.org/v2/mgo/gridfs_test.go b/third_party/labix.org/v2/mgo/gridfs_test.go index 11b6e2372..caa938d63 100644 --- a/third_party/labix.org/v2/mgo/gridfs_test.go +++ b/third_party/labix.org/v2/mgo/gridfs_test.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,10 +27,10 @@ package mgo_test import ( - "io" - . "camlistore.org/third_party/launchpad.net/gocheck" "camlistore.org/third_party/labix.org/v2/mgo" "camlistore.org/third_party/labix.org/v2/mgo/bson" + . 
"camlistore.org/third_party/launchpad.net/gocheck" + "io" "os" "time" ) @@ -161,7 +161,7 @@ func (s *S) TestGridFSFileDetails(c *C) { ud := file.UploadDate() now := time.Now() c.Assert(ud.Before(now), Equals, true) - c.Assert(ud.After(now.Add(-3 * time.Second)), Equals, true) + c.Assert(ud.After(now.Add(-3*time.Second)), Equals, true) result := M{} err = db.C("fs.files").Find(nil).One(result) @@ -177,7 +177,7 @@ func (s *S) TestGridFSFileDetails(c *C) { "md5": "1e50210a0202497fb79bc38b6ade6c34", "filename": "myfile2.txt", "contentType": "text/plain", - "metadata": bson.M{"any": "thing"}, + "metadata": M{"any": "thing"}, } c.Assert(result, DeepEquals, expected) } @@ -249,7 +249,7 @@ func (s *S) TestGridFSCreateWithChunking(c *C) { } break } - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) result["_id"] = "" @@ -589,7 +589,7 @@ func (s *S) TestGridFSOpenNext(c *C) { ok = gfs.OpenNext(iter, &f) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) c.Assert(f, IsNil) // Do it again with a more restrictive query to make sure @@ -602,6 +602,6 @@ func (s *S) TestGridFSOpenNext(c *C) { ok = gfs.OpenNext(iter, &f) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) c.Assert(f, IsNil) } diff --git a/third_party/labix.org/v2/mgo/log.go b/third_party/labix.org/v2/mgo/log.go index f25546547..f1d74b30d 100644 --- a/third_party/labix.org/v2/mgo/log.go +++ b/third_party/labix.org/v2/mgo/log.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. 
+// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/third_party/labix.org/v2/mgo/queue.go b/third_party/labix.org/v2/mgo/queue.go index 94eedb555..e9245de70 100644 --- a/third_party/labix.org/v2/mgo/queue.go +++ b/third_party/labix.org/v2/mgo/queue.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -84,5 +84,8 @@ func (q *queue) expand() { copy(elems[newpopi:], q.elems[q.popi:]) q.popi = newpopi } + for i := range q.elems { + q.elems[i] = nil // Help GC. + } q.elems = elems } diff --git a/third_party/labix.org/v2/mgo/queue_test.go b/third_party/labix.org/v2/mgo/queue_test.go index fb36d68c8..a912aedfc 100644 --- a/third_party/labix.org/v2/mgo/queue_test.go +++ b/third_party/labix.org/v2/mgo/queue_test.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE diff --git a/third_party/labix.org/v2/mgo/server.go b/third_party/labix.org/v2/mgo/server.go index ffb593851..5569e4add 100644 --- a/third_party/labix.org/v2/mgo/server.go +++ b/third_party/labix.org/v2/mgo/server.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. 
+// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,10 +27,12 @@ package mgo import ( + "camlistore.org/third_party/labix.org/v2/mgo/bson" "errors" "net" "sort" "sync" + "time" ) // --------------------------------------------------------------------------- @@ -44,31 +46,50 @@ type mongoServer struct { unusedSockets []*mongoSocket liveSockets []*mongoSocket closed bool - master bool + abended bool sync chan bool + dial dialer + pingValue time.Duration + pingIndex int + pingCount uint32 + pingWindow [6]time.Duration + info *mongoServerInfo } -func newServer(addr string, sync chan bool) (server *mongoServer, err error) { - tcpaddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - log("Failed to resolve ", addr, ": ", err.Error()) - return nil, err - } +type dialer struct { + old func(addr net.Addr) (net.Conn, error) + new func(addr *ServerAddr) (net.Conn, error) +} - resolvedAddr := tcpaddr.String() - if resolvedAddr != addr { - debug("Address ", addr, " resolved as ", resolvedAddr) - } - server = &mongoServer{ +func (dial dialer) isSet() bool { + return dial.old != nil || dial.new != nil +} + +type mongoServerInfo struct { + Master bool + Mongos bool + Tags bson.D +} + +var defaultServerInfo mongoServerInfo + +func newServer(addr string, tcpaddr *net.TCPAddr, sync chan bool, dial dialer) *mongoServer { + server := &mongoServer{ Addr: addr, - ResolvedAddr: resolvedAddr, + ResolvedAddr: tcpaddr.String(), tcpaddr: tcpaddr, sync: sync, + dial: dial, + info: &defaultServerInfo, } - return + // Once so the server gets a ping value, then loop in background. 
+ server.pinger(false) + go server.pinger(true) + return server } var errSocketLimit = errors.New("per-server connection limit reached") +var errServerClosed = errors.New("server was closed") // AcquireSocket returns a socket for communicating with the server. // This will attempt to reuse an old connection, if one is available. Otherwise, @@ -77,56 +98,82 @@ var errSocketLimit = errors.New("per-server connection limit reached") // the same number of times as AcquireSocket + Acquire were called for it. // If the limit argument is not zero, a socket will only be returned if the // number of sockets in use for this server is under the provided limit. -func (server *mongoServer) AcquireSocket(limit int) (socket *mongoSocket, err error) { +func (server *mongoServer) AcquireSocket(limit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { for { server.Lock() + abended = server.abended + if server.closed { + server.Unlock() + return nil, abended, errServerClosed + } n := len(server.unusedSockets) if limit > 0 && len(server.liveSockets)-n >= limit { server.Unlock() - return nil, errSocketLimit + return nil, false, errSocketLimit } if n > 0 { socket = server.unusedSockets[n-1] server.unusedSockets[n-1] = nil // Help GC. server.unusedSockets = server.unusedSockets[:n-1] + info := server.info server.Unlock() - err = socket.InitialAcquire() + err = socket.InitialAcquire(info, timeout) if err != nil { continue } } else { server.Unlock() - socket, err = server.Connect() + socket, err = server.Connect(timeout) if err == nil { server.Lock() + // We've waited for the Connect, see if we got + // closed in the meantime + if server.closed { + server.Unlock() + socket.Release() + socket.Close() + return nil, abended, errServerClosed + } server.liveSockets = append(server.liveSockets, socket) server.Unlock() } } return } - panic("unreached") + panic("unreachable") } // Connect establishes a new connection to the server. 
This should // generally be done through server.AcquireSocket(). -func (server *mongoServer) Connect() (*mongoSocket, error) { +func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { server.RLock() - addr := server.Addr - tcpaddr := server.tcpaddr - master := server.master + master := server.info.Master + dial := server.dial server.RUnlock() - log("Establishing new connection to ", addr, "...") - conn, err := net.DialTCP("tcp", nil, tcpaddr) + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + var conn net.Conn + var err error + switch { + case !dial.isSet(): + // Cannot do this because it lacks timeout support. :-( + //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + case dial.old != nil: + conn, err = dial.old(server.tcpaddr) + case dial.new != nil: + conn, err = dial.new(&ServerAddr{server.Addr, server.tcpaddr}) + default: + panic("dialer is set, but both dial.old and dial.new are nil") + } if err != nil { - log("Connection to ", addr, " failed: ", err.Error()) + logf("Connection to %s failed: %v", server.Addr, err.Error()) return nil, err } - log("Connection to ", addr, " established.") + logf("Connection to %s established.", server.Addr) stats.conn(+1, master) - return newSocket(server, conn), nil + return newSocket(server, conn, timeout), nil } // Close forces closing all sockets that are alive, whether @@ -138,9 +185,8 @@ func (server *mongoServer) Close() { unusedSockets := server.unusedSockets server.liveSockets = nil server.unusedSockets = nil - addr := server.Addr server.Unlock() - logf("Connections to %s closing (%d live sockets).", addr, len(liveSockets)) + logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets)) for i, s := range liveSockets { s.Close() liveSockets[i] = nil @@ -176,6 +222,7 @@ func removeSocket(sockets []*mongoSocket, socket *mongoSocket) []*mongoSocket { // abnormally, and thus 
should be discarded rather than cached. func (server *mongoServer) AbendSocket(socket *mongoSocket) { server.Lock() + server.abended = true if server.closed { server.Unlock() return @@ -190,29 +237,83 @@ func (server *mongoServer) AbendSocket(socket *mongoSocket) { } } -// Merge other into server, which must both be communicating with -// the same server address. -func (server *mongoServer) Merge(other *mongoServer) { +func (server *mongoServer) SetInfo(info *mongoServerInfo) { server.Lock() - server.master = other.master - server.Unlock() - // Sockets of other are ignored for the moment. Merging them - // would mean a large number of sockets being cached on longer - // recovering situations. - other.Close() -} - -func (server *mongoServer) SetMaster(isMaster bool) { - server.Lock() - server.master = isMaster + server.info = info server.Unlock() } -func (server *mongoServer) IsMaster() bool { - server.RLock() - result := server.master - server.RUnlock() - return result +func (server *mongoServer) Info() *mongoServerInfo { + server.Lock() + info := server.info + server.Unlock() + return info +} + +func (server *mongoServer) hasTags(serverTags []bson.D) bool { +NextTagSet: + for _, tags := range serverTags { + NextReqTag: + for _, req := range tags { + for _, has := range server.info.Tags { + if req.Name == has.Name { + if req.Value == has.Value { + continue NextReqTag + } + continue NextTagSet + } + } + continue NextTagSet + } + return true + } + return false +} + +var pingDelay = 5 * time.Second + +func (server *mongoServer) pinger(loop bool) { + op := queryOp{ + collection: "admin.$cmd", + query: bson.D{{"ping", 1}}, + flags: flagSlaveOk, + limit: -1, + } + for { + if loop { + time.Sleep(pingDelay) + } + op := op + socket, _, err := server.AcquireSocket(0, 3*pingDelay) + if err == nil { + start := time.Now() + _, _ = socket.SimpleQuery(&op) + delay := time.Now().Sub(start) + + server.pingWindow[server.pingIndex] = delay + server.pingIndex = (server.pingIndex + 1) 
% len(server.pingWindow) + server.pingCount++ + var max time.Duration + for i := 0; i < len(server.pingWindow) && uint32(i) < server.pingCount; i++ { + if server.pingWindow[i] > max { + max = server.pingWindow[i] + } + } + socket.Release() + server.Lock() + if server.closed { + loop = false + } + server.pingValue = max + server.Unlock() + logf("Ping for %s is %d ms", server.Addr, max/time.Millisecond) + } else if err == errServerClosed { + return + } + if !loop { + return + } + } } type mongoServerSlice []*mongoServer @@ -233,8 +334,7 @@ func (s mongoServerSlice) Sort() { sort.Sort(s) } -func (s mongoServerSlice) Search(other *mongoServer) (i int, ok bool) { - resolvedAddr := other.ResolvedAddr +func (s mongoServerSlice) Search(resolvedAddr string) (i int, ok bool) { n := len(s) i = sort.Search(n, func(i int) bool { return s[i].ResolvedAddr >= resolvedAddr @@ -246,8 +346,8 @@ type mongoServers struct { slice mongoServerSlice } -func (servers *mongoServers) Search(other *mongoServer) (server *mongoServer) { - if i, ok := servers.slice.Search(other); ok { +func (servers *mongoServers) Search(resolvedAddr string) (server *mongoServer) { + if i, ok := servers.slice.Search(resolvedAddr); ok { return servers.slice[i] } return nil @@ -259,7 +359,7 @@ func (servers *mongoServers) Add(server *mongoServer) { } func (servers *mongoServers) Remove(other *mongoServer) (server *mongoServer) { - if i, found := servers.slice.Search(other); found { + if i, found := servers.slice.Search(other.ResolvedAddr); found { server = servers.slice[i] copy(servers.slice[i:], servers.slice[i+1:]) n := len(servers.slice) - 1 @@ -285,26 +385,31 @@ func (servers *mongoServers) Empty() bool { return len(servers.slice) == 0 } -// MostAvailable returns the best guess of what would be the -// most interesting server to perform operations on at this -// point in time. 
-func (servers *mongoServers) MostAvailable() *mongoServer { - if len(servers.slice) == 0 { - panic("MostAvailable: can't be used on empty server list") - } +// BestFit returns the best guess of what would be the most interesting +// server to perform operations on at this point in time. +func (servers *mongoServers) BestFit(serverTags []bson.D) *mongoServer { var best *mongoServer - for i, next := range servers.slice { - if i == 0 { + for _, next := range servers.slice { + if best == nil { best = next best.RLock() + if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + best.RUnlock() + best = nil + } continue } next.RLock() swap := false switch { - case next.master != best.master: + case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + // Must have requested tags. + case next.info.Master != best.info.Master: // Prefer slaves. - swap = best.master + swap = best.info.Master + case absDuration(next.pingValue-best.pingValue) > 15*time.Millisecond: + // Prefer nearest server. + swap = next.pingValue < best.pingValue case len(next.liveSockets)-len(next.unusedSockets) < len(best.liveSockets)-len(best.unusedSockets): // Prefer servers with less connections. swap = true @@ -316,6 +421,15 @@ func (servers *mongoServers) MostAvailable() *mongoServer { next.RUnlock() } } - best.RUnlock() + if best != nil { + best.RUnlock() + } return best } + +func absDuration(d time.Duration) time.Duration { + if d < 0 { + return -d + } + return d +} diff --git a/third_party/labix.org/v2/mgo/session.go b/third_party/labix.org/v2/mgo/session.go index 5cf637e66..4364058a6 100644 --- a/third_party/labix.org/v2/mgo/session.go +++ b/third_party/labix.org/v2/mgo/session.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,14 +27,14 @@ package mgo import ( + "camlistore.org/third_party/labix.org/v2/mgo/bson" "crypto/md5" "encoding/hex" "errors" "fmt" - "camlistore.org/third_party/labix.org/v2/mgo/bson" "math" + "net" "reflect" - "runtime" "sort" "strconv" "strings" @@ -54,17 +54,19 @@ const ( // need to be updated too. 
type Session struct { - m sync.RWMutex - cluster_ *mongoCluster - socket *mongoSocket - socketIsMaster bool - slaveOk bool - consistency mode - queryConfig query - safeOp *queryOp - syncTimeout time.Duration - urlauth *authInfo - auth []authInfo + m sync.RWMutex + cluster_ *mongoCluster + slaveSocket *mongoSocket + masterSocket *mongoSocket + slaveOk bool + consistency mode + queryConfig query + safeOp *queryOp + syncTimeout time.Duration + sockTimeout time.Duration + defaultdb string + dialAuth *authInfo + auth []authInfo } type Database struct { @@ -102,12 +104,13 @@ type Iter struct { m sync.Mutex gotReply sync.Cond session *Session + server *mongoServer docData queue err error op getMoreOp prefetch float64 limit int32 - pendingDocs int + docsToReceive int docsBeforeMore int timeout time.Duration timedout bool @@ -124,8 +127,8 @@ const defaultPrefetch = 0.25 // // Dial will timeout after 10 seconds if a server isn't reached. The returned // session will timeout operations after one minute by default if servers -// aren't available. To customize the timeout, see DialWithTimeout -// and SetSyncTimeout. +// aren't available. To customize the timeout, see DialWithTimeout, +// SetSyncTimeout, and SetSocketTimeout. // // This method is generally called just once for a given cluster. 
Further // sessions to the same cluster are then established using the New or Copy @@ -169,12 +172,13 @@ const defaultPrefetch = 0.25 // // http://www.mongodb.org/display/DOCS/Connections // -func Dial(url string) (session *Session, err error) { - session, err = DialWithTimeout(url, 10*time.Second) +func Dial(url string) (*Session, error) { + session, err := DialWithTimeout(url, 10*time.Second) if err == nil { - session.SetSyncTimeout(time.Minute) + session.SetSyncTimeout(1 * time.Minute) + session.SetSocketTimeout(1 * time.Minute) } - return + return session, err } // DialWithTimeout works like Dial, but uses timeout as the amount of time to @@ -183,13 +187,13 @@ func Dial(url string) (session *Session, err error) { // forever waiting for a connection to be made. // // See SetSyncTimeout for customizing the timeout for the session. -func DialWithTimeout(url string, timeout time.Duration) (session *Session, err error) { - servers, auth, options, err := parseURL(url) +func DialWithTimeout(url string, timeout time.Duration) (*Session, error) { + uinfo, err := parseURL(url) if err != nil { return nil, err } direct := false - for k, v := range options { + for k, v := range uinfo.options { switch k { case "connect": if v == "direct" { @@ -201,15 +205,102 @@ func DialWithTimeout(url string, timeout time.Duration) (session *Session, err e } fallthrough default: - err = errors.New("Unsupported connection URL option: " + k + "=" + v) - return + return nil, errors.New("Unsupported connection URL option: " + k + "=" + v) } } - cluster := newCluster(servers, direct) - session = newSession(Eventual, cluster, nil, timeout) - if auth.user != "" { - session.urlauth = &auth - session.auth = []authInfo{auth} + info := DialInfo{ + Addrs: uinfo.addrs, + Direct: direct, + Timeout: timeout, + Username: uinfo.user, + Password: uinfo.pass, + Database: uinfo.db, + } + return DialWithInfo(&info) +} + +// DialInfo holds options for establishing a session with a MongoDB cluster. 
+// To use a URL, see the Dial function. +type DialInfo struct { + // Addrs holds the addresses for the seed servers. + Addrs []string + + // Direct informs whether to establish connections only with the + // specified seed servers, or to obtain information for the whole + // cluster and establish connections with further servers too. + Direct bool + + // Timeout is the amount of time to wait for a server to respond when + // first connecting and on follow up operations in the session. If + // timeout is zero, the call may block forever waiting for a connection + // to be established. + Timeout time.Duration + + // Database is the database name used during the initial authentication. + // If set, the value is also returned as the default result from the + // Session.DB method, in place of "test". + Database string + + // Username and Password inform the credentials for the initial + // authentication done against Database, if that is set, + // or the "admin" database otherwise. See the Session.Login method too. + Username string + Password string + + // Dial optionally specifies the dial function for establishing + // connections with the MongoDB servers. + Dial func(addr net.Addr) (net.Conn, error) + + // DialServer optionally specifies the dial function for establishing + // connections with the MongoDB servers. + // + // WARNING: This interface is experimental and may change. + DialServer func(addr *ServerAddr) (net.Conn, error) +} + +// ServerAddr represents the address for establishing a connection to an +// individual MongoDB server. +// +// WARNING: This interface is experimental and may change. +type ServerAddr struct { + str string + tcp *net.TCPAddr +} + +// String returns the address that was provided for the server before resolution. +func (addr *ServerAddr) String() string { + return addr.str +} + +// TCPAddr returns the resolved TCP address for the server. 
+func (addr *ServerAddr) TCPAddr() *net.TCPAddr { + return addr.tcp +} + +// DialWithInfo establishes a new session to the cluster identified by info. +func DialWithInfo(info *DialInfo) (*Session, error) { + addrs := make([]string, len(info.Addrs)) + for i, addr := range info.Addrs { + p := strings.LastIndexAny(addr, "]:") + if p == -1 || addr[p] != ':' { + // XXX This is untested. The test suite doesn't use the standard port. + addr += ":27017" + } + addrs[i] = addr + } + cluster := newCluster(addrs, info.Direct, dialer{info.Dial, info.DialServer}) + session := newSession(Eventual, cluster, info.Timeout) + session.defaultdb = info.Database + if session.defaultdb == "" { + session.defaultdb = "test" + } + if info.Username != "" { + db := info.Database + if db == "" { + db = "admin" + } + session.dialAuth = &authInfo{db, info.Username, info.Password} + session.auth = []authInfo{*session.dialAuth} } cluster.Release() @@ -229,105 +320,80 @@ func isOptSep(c rune) bool { return c == ';' || c == '&' } -func parseURL(url string) (servers []string, auth authInfo, options map[string]string, err error) { +type urlInfo struct { + addrs []string + user string + pass string + db string + options map[string]string +} + +func parseURL(url string) (*urlInfo, error) { if strings.HasPrefix(url, "mongodb://") { url = url[10:] } - options = make(map[string]string) + info := &urlInfo{options: make(map[string]string)} if c := strings.Index(url, "?"); c != -1 { for _, pair := range strings.FieldsFunc(url[c+1:], isOptSep) { l := strings.SplitN(pair, "=", 2) if len(l) != 2 || l[0] == "" || l[1] == "" { - err = errors.New("Connection option must be key=value: " + pair) - return + return nil, errors.New("Connection option must be key=value: " + pair) } - options[l[0]] = l[1] + info.options[l[0]] = l[1] } url = url[:c] } if c := strings.Index(url, "@"); c != -1 { pair := strings.SplitN(url[:c], ":", 2) if len(pair) != 2 || pair[0] == "" { - err = errors.New("Credentials must be provided as 
user:pass@host") - return + return nil, errors.New("Credentials must be provided as user:pass@host") } - auth.user = pair[0] - auth.pass = pair[1] + info.user = pair[0] + info.pass = pair[1] url = url[c+1:] - auth.db = "admin" } if c := strings.Index(url, "/"); c != -1 { - if c != len(url)-1 { - auth.db = url[c+1:] - } + info.db = url[c+1:] url = url[:c] } - if auth.user == "" { - if auth.db != "" { - err = errors.New("Database name only makes sense with credentials") - return - } - } else if auth.db == "" { - auth.db = "admin" - } - servers = strings.Split(url, ",") - // XXX This is untested. The test suite doesn't use the standard port. - for i, server := range servers { - p := strings.LastIndexAny(server, "]:") - if p == -1 || server[p] != ':' { - servers[i] = server + ":27017" - } - } - return + info.addrs = strings.Split(url, ",") + return info, nil } -func newSession(consistency mode, cluster *mongoCluster, socket *mongoSocket, syncTimeout time.Duration) (session *Session) { +func newSession(consistency mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { cluster.Acquire() - session = &Session{cluster_: cluster, syncTimeout: syncTimeout} + session = &Session{cluster_: cluster, syncTimeout: timeout, sockTimeout: timeout} debugf("New session %p on cluster %p", session, cluster) session.SetMode(consistency, true) session.SetSafe(&Safe{}) - session.setSocket(socket) session.queryConfig.prefetch = defaultPrefetch - runtime.SetFinalizer(session, finalizeSession) return session } func copySession(session *Session, keepAuth bool) (s *Session) { cluster := session.cluster() cluster.Acquire() - if session.socket != nil { - session.socket.Acquire() + if session.masterSocket != nil { + session.masterSocket.Acquire() + } + if session.slaveSocket != nil { + session.slaveSocket.Acquire() } var auth []authInfo if keepAuth { auth = make([]authInfo, len(session.auth)) copy(auth, session.auth) - } else if session.urlauth != nil { - auth = 
[]authInfo{*session.urlauth} - } - // Copy everything but the mutex. - s = &Session{ - cluster_: session.cluster_, - socket: session.socket, - socketIsMaster: session.socketIsMaster, - slaveOk: session.slaveOk, - consistency: session.consistency, - queryConfig: session.queryConfig, - safeOp: session.safeOp, - syncTimeout: session.syncTimeout, - urlauth: session.urlauth, - auth: auth, + } else if session.dialAuth != nil { + auth = []authInfo{*session.dialAuth} } + scopy := *session + scopy.m = sync.RWMutex{} + scopy.auth = auth + s = &scopy debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) - runtime.SetFinalizer(s, finalizeSession) return s } -func finalizeSession(session *Session) { - session.Close() -} - // LiveServers returns a list of server addresses which are // currently known to be alive. func (s *Session) LiveServers() (addrs []string) { @@ -337,16 +403,24 @@ func (s *Session) LiveServers() (addrs []string) { return addrs } -// DB returns a value representing the named database. -// Creating this value is a very lightweight operation, and involves -// no network communication. +// DB returns a value representing the named database. If name +// is empty, the database name provided in the dialed URL is +// used instead. If that is also empty, "test" is used as a +// fallback in a way equivalent to the mongo shell. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. func (s *Session) DB(name string) *Database { + if name == "" { + name = s.defaultdb + } return &Database{s, name} } // C returns a value representing the named collection. -// Creating this object is a very lightweight operation, and involves -// no network communication. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. func (db *Database) C(name string) *Collection { return &Collection{db, name, db.Name + "." 
+ name} } @@ -422,7 +496,7 @@ func (db *Database) Login(user, pass string) (err error) { session := db.Session dbname := db.Name - socket, err := session.acquireSocket(false) + socket, err := session.acquireSocket(true) if err != nil { return err } @@ -447,6 +521,15 @@ func (db *Database) Login(user, pass string) (err error) { return nil } +func (s *Session) socketLogin(socket *mongoSocket) error { + for _, a := range s.auth { + if err := socket.Login(a.db, a.user, a.pass); err != nil { + return err + } + } + return nil +} + // Logout removes any established authentication credentials for the database. func (db *Database) Logout() { session := db.Session @@ -461,8 +544,13 @@ func (db *Database) Logout() { break } } - if found && session.socket != nil { - session.socket.Logout(dbname) + if found { + if session.masterSocket != nil { + session.masterSocket.Logout(dbname) + } + if session.slaveSocket != nil { + session.slaveSocket.Logout(dbname) + } } session.m.Unlock() } @@ -471,14 +559,120 @@ func (db *Database) Logout() { func (s *Session) LogoutAll() { s.m.Lock() for _, a := range s.auth { - s.socket.Logout(a.db) + if s.masterSocket != nil { + s.masterSocket.Logout(a.db) + } + if s.slaveSocket != nil { + s.slaveSocket.Logout(a.db) + } } s.auth = s.auth[0:0] s.m.Unlock() } +// User represents a MongoDB user. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// http://docs.mongodb.org/manual/reference/user-privileges/ +// +type User struct { + // Username is how the user identifies itself to the system. + Username string `bson:"user"` + + // Password is the plaintext password for the user. If set, + // the UpsertUser method will hash it into PasswordHash and + // unset it before the user is added to the database. + Password string `bson:",omitempty"` + + // PasswordHash is the MD5 hash of Username+":mongo:"+Password. 
+	PasswordHash string `bson:"pwd,omitempty"`
+
+	// UserSource indicates where to look for this user's credentials.
+	// It may be set to a database name, or to "$external" for
+	// consulting an external resource such as Kerberos. UserSource
+	// must not be set if Password or PasswordHash are present.
+	UserSource string `bson:"userSource,omitempty"`
+
+	// Roles indicates the set of roles the user will be provided.
+	// See the Role constants.
+	Roles []Role `bson:"roles"`
+
+	// OtherDBRoles allows assigning roles in other databases from
+	// user documents inserted in the admin database. This field
+	// only works in the admin database.
+	OtherDBRoles map[string][]Role `bson:"otherDBRoles,omitempty"`
+}
+
+type Role string
+
+const (
+	// Relevant documentation:
+	//
+	//     http://docs.mongodb.org/manual/reference/user-privileges/
+	//
+	RoleRead         Role = "read"
+	RoleReadAny      Role = "readAnyDatabase"
+	RoleReadWrite    Role = "readWrite"
+	RoleReadWriteAny Role = "readWriteAnyDatabase"
+	RoleDBAdmin      Role = "dbAdmin"
+	RoleDBAdminAny   Role = "dbAdminAnyDatabase"
+	RoleUserAdmin    Role = "userAdmin"
+	RoleUserAdminAny Role = "userAdminAnyDatabase"
+	RoleClusterAdmin Role = "clusterAdmin"
+)
+
+// UpsertUser updates the authentication credentials and the roles for
+// a MongoDB user within the db database. If the named user doesn't exist
+// it will be created.
+//
+// This method should only be used from MongoDB 2.4 and on. For older
+// MongoDB releases, use the obsolete AddUser method instead.
+// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/user-privileges/ +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// +func (db *Database) UpsertUser(user *User) error { + if user.Username == "" { + return fmt.Errorf("user has no Username") + } + if user.Password != "" { + psum := md5.New() + psum.Write([]byte(user.Username + ":mongo:" + user.Password)) + user.PasswordHash = hex.EncodeToString(psum.Sum(nil)) + user.Password = "" + } + if user.PasswordHash != "" && user.UserSource != "" { + return fmt.Errorf("user has both Password/PasswordHash and UserSource set") + } + if len(user.OtherDBRoles) > 0 && db.Name != "admin" { + return fmt.Errorf("user with OtherDBRoles is only supported in admin database") + } + var unset bson.D + if user.PasswordHash == "" { + unset = append(unset, bson.DocElem{"pwd", 1}) + } + if user.UserSource == "" { + unset = append(unset, bson.DocElem{"userSource", 1}) + } + // user.Roles is always sent, as it's the way MongoDB distinguishes + // old-style documents from new-style documents. + if len(user.OtherDBRoles) == 0 { + unset = append(unset, bson.DocElem{"otherDBRoles", 1}) + } + c := db.C("system.users") + _, err := c.Upsert(bson.D{{"user", user.Username}}, bson.D{{"$unset", unset}, {"$set", user}}) + return err +} + // AddUser creates or updates the authentication credentials of user within -// the database. +// the db database. +// +// This method is obsolete and should only be used with MongoDB 2.2 or +// earlier. For MongoDB 2.4 and on, use UpsertUser instead. 
func (db *Database) AddUser(user, pass string, readOnly bool) error { psum := md5.New() psum.Write([]byte(user + ":mongo:" + pass)) @@ -502,6 +696,7 @@ type indexSpec struct { Background bool ",omitempty" Sparse bool ",omitempty" Bits, Min, Max int ",omitempty" + ExpireAfter int "expireAfterSeconds,omitempty" } type Index struct { @@ -511,6 +706,8 @@ type Index struct { Background bool // Build index in background and return immediately Sparse bool // Only index documents containing the Key fields + ExpireAfter time.Duration // Periodically delete docs with indexed time.Time older than that. + Name string // Index name, computed by EnsureIndex Bits, Min, Max int // Properties for spatial indexes @@ -519,15 +716,30 @@ type Index struct { func parseIndexKey(key []string) (name string, realKey bson.D, err error) { var order interface{} for _, field := range key { + raw := field if name != "" { name += "_" } + var kind string if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + name += field + "_" + kind + } + } switch field[0] { + case '$': + // Logic above failed. Reset and error. + field = "" case '@': order = "2d" field = field[1:] - name += field + "_" // Why don't they put 2d here? + // The shell used to render this field as key_ instead of key_2d, + // and mgo followed suit. This has been fixed in recent server + // releases, and mgo followed as well. 
+ name += field + "_2d" case '-': order = -1 field = field[1:] @@ -536,12 +748,16 @@ func parseIndexKey(key []string) (name string, realKey bson.D, err error) { field = field[1:] fallthrough default: - order = 1 - name += field + "_1" + if kind == "" { + order = 1 + name += field + "_1" + } else { + order = kind + } } } - if field == "" { - return "", nil, errors.New("Invalid index key: empty field name") + if field == "" || kind != "" && order != kind { + return "", nil, fmt.Errorf(`Invalid index key: want "[$:][-]", got %q`, raw) } realKey = append(realKey, bson.DocElem{field, order}) } @@ -602,16 +818,21 @@ func (c *Collection) EnsureIndexKey(key ...string) error { // included in the index. When using a sparse index for sorting, only indexed // documents will be returned. // -// Spatial indexes are also supported through that API. Here is an example: +// If ExpireAfter is non-zero, the server will periodically scan the collection +// and remove documents containing an indexed time.Time field with a value +// older than ExpireAfter. See the documentation for details: +// +// http://docs.mongodb.org/manual/tutorial/expire-data +// +// Other kinds of indexes are also supported through that API. Here is an example: // // index := Index{ -// Key: []string{"@loc"}, +// Key: []string{"$2d:loc"}, // Bits: 26, // } // err := collection.EnsureIndex(index) // -// The "@" prefix in the field name will request the creation of a "2d" index -// for the given field. +// The example above requests the creation of a "2d" index for the "loc" field. // // The 2D index bounds may be changed using the Min and Max attributes of the // Index value. 
The default bound setting of (-180, 180) is suitable for @@ -642,16 +863,17 @@ func (c *Collection) EnsureIndex(index Index) error { } spec := indexSpec{ - Name: name, - NS: c.FullName, - Key: realKey, - Unique: index.Unique, - DropDups: index.DropDups, - Background: index.Background, - Sparse: index.Sparse, - Bits: index.Bits, - Min: index.Min, - Max: index.Max, + Name: name, + NS: c.FullName, + Key: realKey, + Unique: index.Unique, + DropDups: index.DropDups, + Background: index.Background, + Sparse: index.Sparse, + Bits: index.Bits, + Min: index.Min, + Max: index.Max, + ExpireAfter: int(index.ExpireAfter / time.Second), } session = session.Clone() @@ -714,12 +936,12 @@ func (c *Collection) DropIndex(key ...string) error { // // indexes, err := collection.Indexes() // if err != nil { -// panic(err) +// return err // } // for _, index := range indexes { // err = collection.DropIndex(index.Key...) // if err != nil { -// panic(err) +// return err // } // } // @@ -733,34 +955,38 @@ func (c *Collection) Indexes() (indexes []Index, err error) { break } index := Index{ - Name: spec.Name, - Key: simpleIndexKey(spec.Key), - Unique: spec.Unique, - DropDups: spec.DropDups, - Background: spec.Background, - Sparse: spec.Sparse, + Name: spec.Name, + Key: simpleIndexKey(spec.Key), + Unique: spec.Unique, + DropDups: spec.DropDups, + Background: spec.Background, + Sparse: spec.Sparse, + ExpireAfter: time.Duration(spec.ExpireAfter) * time.Second, } indexes = append(indexes, index) } - err = iter.Err() + err = iter.Close() return } func simpleIndexKey(realKey bson.D) (key []string) { for i := range realKey { field := realKey[i].Name - i, _ := realKey[i].Value.(int) - if i == 1 { + vi, ok := realKey[i].Value.(int) + if !ok { + vf, _ := realKey[i].Value.(float64) + vi = int(vf) + } + if vi == 1 { key = append(key, field) continue } - if i == -1 { + if vi == -1 { key = append(key, "-"+field) continue } - s, _ := realKey[i].Value.(string) - if s == "2d" { - key = append(key, "@"+field) 
+ if vs, ok := realKey[i].Value.(string); ok { + key = append(key, "$"+vs+":"+field) continue } panic("Got unknown index key type for field " + field) @@ -776,7 +1002,7 @@ func (s *Session) ResetIndexCache() { // New creates a new session with the same parameters as the original // session, including consistency, batch size, prefetching, safety mode, -// etc. The returned session will use sockets from the poll, so there's +// etc. The returned session will use sockets from the pool, so there's // a chance that writes just performed in another session may not yet // be visible. // @@ -823,7 +1049,7 @@ func (s *Session) Close() { s.m.Lock() if s.cluster_ != nil { debugf("Closing session %p", s) - s.setSocket(nil) + s.unsetSocket() s.cluster_.Release() s.cluster_ = nil } @@ -842,7 +1068,7 @@ func (s *Session) cluster() *mongoCluster { func (s *Session) Refresh() { s.m.Lock() s.slaveOk = s.consistency != Strong - s.setSocket(nil) + s.unsetSocket() s.m.Unlock() } @@ -888,14 +1114,14 @@ func (s *Session) Refresh() { // connection is unsuitable (to a slave server in a Strong session). func (s *Session) SetMode(consistency mode, refresh bool) { s.m.Lock() - debugf("Session %p: setting mode %d with refresh=%v (socket=%p)", s, consistency, refresh, s.socket) + debugf("Session %p: setting mode %d with refresh=%v (master=%p, slave=%p)", s, consistency, refresh, s.masterSocket, s.slaveSocket) s.consistency = consistency if refresh { s.slaveOk = s.consistency != Strong - s.setSocket(nil) + s.unsetSocket() } else if s.consistency == Strong { s.slaveOk = false - } else if s.socket == nil { + } else if s.masterSocket == nil { s.slaveOk = true } s.m.Unlock() @@ -919,6 +1145,33 @@ func (s *Session) SetSyncTimeout(d time.Duration) { s.m.Unlock() } +// SetSocketTimeout sets the amount of time to wait for a non-responding +// socket to the database before it is forcefully closed. 
+func (s *Session) SetSocketTimeout(d time.Duration) {
+	s.m.Lock()
+	s.sockTimeout = d
+	if s.masterSocket != nil {
+		s.masterSocket.SetTimeout(d)
+	}
+	if s.slaveSocket != nil {
+		s.slaveSocket.SetTimeout(d)
+	}
+	s.m.Unlock()
+}
+
+// SetCursorTimeout changes the standard timeout period that the server
+// enforces on created cursors. The only supported value right now is
+// 0, which disables the timeout. The standard server timeout is 10 minutes.
+func (s *Session) SetCursorTimeout(d time.Duration) {
+	s.m.Lock()
+	if d == 0 {
+		s.queryConfig.op.flags |= flagNoCursorTimeout
+	} else {
+		panic("SetCursorTimeout: only 0 (disable timeout) supported for now")
+	}
+	s.m.Unlock()
+}
+
 // SetBatch sets the default batch size used when fetching documents from the
 // database. It's possible to change this setting on a per-query basis as
 // well, using the Query.Batch method.
@@ -927,6 +1180,10 @@ func (s *Session) SetSyncTimeout(d time.Duration) {
 // writing, MongoDB will use an initial size of min(100 docs, 4MB) on the
 // first batch, and 4MB on remaining ones.
 func (s *Session) SetBatch(n int) {
+	if n == 1 {
+		// Server interprets 1 as -1 and closes the cursor (!?)
+		n = 2
+	}
 	s.m.Lock()
 	s.queryConfig.op.limit = int32(n)
 	s.m.Unlock()
@@ -1146,6 +1403,30 @@ func (s *Session) Run(cmd interface{}, result interface{}) error {
 	return s.DB("admin").Run(cmd, result)
 }
 
+// SelectServers restricts communication to servers configured with the
+// given tags. For example, the following statement restricts servers
+// used for reading operations to those with both tag "disk" set to
+// "ssd" and tag "rack" set to 1:
+//
+//     session.SelectServers(bson.D{{"disk", "ssd"}, {"rack", 1}})
+//
+// Multiple sets of tags may be provided, in which case the used server
+// must match all tags within any one set.
+// +// If a connection was previously assigned to the session due to the +// current session mode (see Session.SetMode), the tag selection will +// only be enforced after the session is refreshed. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/tutorial/configure-replica-set-tag-sets +// +func (s *Session) SelectServers(tags ...bson.D) { + s.m.Lock() + s.queryConfig.op.serverTags = tags + s.m.Unlock() +} + // Ping runs a trivial ping command just to get in touch with the server. func (s *Session) Ping() error { return s.Run("ping", nil) @@ -1224,8 +1505,7 @@ func (c *Collection) Find(query interface{}) *Query { return q } -// FindId prepares a query to find a document by its _id field. -// It is a convenience helper equivalent to: +// FindId is a convenience helper equivalent to: // // query := collection.Find(bson.M{"_id": id}) // @@ -1234,13 +1514,81 @@ func (c *Collection) FindId(id interface{}) *Query { return c.Find(bson.D{{"_id", id}}) } +type Pipe struct { + session *Session + collection *Collection + pipeline interface{} +} + +// Pipe prepares a pipeline to aggregate. The pipeline document +// must be a slice built in terms of the aggregation framework language. +// +// For example: +// +// pipe := collection.Pipe([]bson.M{{"$match": bson.M{"name": "Otavio"}}}) +// iter := pipe.Iter() +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/aggregation +// http://docs.mongodb.org/manual/applications/aggregation +// http://docs.mongodb.org/manual/tutorial/aggregation-examples +// +func (c *Collection) Pipe(pipeline interface{}) *Pipe { + session := c.Database.Session + return &Pipe{ + session: session, + collection: c, + pipeline: pipeline, + } +} + +// Iter executes the pipeline and returns an iterator capable of going +// over all the generated results. 
+func (p *Pipe) Iter() *Iter { + iter := &Iter{ + session: p.session, + timeout: -1, + } + iter.gotReply.L = &iter.m + var result struct{ Result []bson.Raw } + c := p.collection + iter.err = c.Database.Run(bson.D{{"aggregate", c.Name}, {"pipeline", p.pipeline}}, &result) + if iter.err != nil { + return iter + } + for i := range result.Result { + iter.docData.Push(result.Result[i].Data) + } + return iter +} + +// All works like Iter.All. +func (p *Pipe) All(result interface{}) error { + return p.Iter().All(result) +} + +// One executes the pipeline and unmarshals the first item from the +// result set into the result parameter. +// It returns ErrNotFound if no items are generated by the pipeline. +func (p *Pipe) One(result interface{}) error { + iter := p.Iter() + if iter.Next(result) { + return nil + } + if err := iter.Err(); err != nil { + return err + } + return ErrNotFound +} + type LastError struct { Err string Code, N, Waited int - FSyncFiles int "fsyncFiles" + FSyncFiles int `bson:"fsyncFiles"` WTimeout bool - UpdatedExisting bool "updatedExisting" - UpsertedId interface{} "upserted" + UpdatedExisting bool `bson:"updatedExisting"` + UpsertedId interface{} `bson:"upserted"` } func (err *LastError) Error() string { @@ -1266,12 +1614,27 @@ func (err *QueryError) Error() string { return err.Message } +// IsDup returns whether err informs of a duplicate key error because +// a primary key index or a secondary unique index already has an entry +// with the given value. +func IsDup(err error) bool { + // Besides being handy, helps with https://jira.mongodb.org/browse/SERVER-7164 + // What follows makes me sad. Hopefully conventions will be more clear over time. + switch e := err.(type) { + case *LastError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 + case *QueryError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 + } + return false +} + // Insert inserts one or more documents in the respective collection. 
In // case the session is in safe mode (see the SetSafe method) and an error // happens while inserting the provided documents, the returned error will // be of type *LastError. func (c *Collection) Insert(docs ...interface{}) error { - _, err := c.Database.Session.writeQuery(&insertOp{c.FullName, docs}) + _, err := c.writeQuery(&insertOp{c.FullName, docs}) return err } @@ -1287,14 +1650,22 @@ func (c *Collection) Insert(docs ...interface{}) error { // http://www.mongodb.org/display/DOCS/Atomic+Operations // func (c *Collection) Update(selector interface{}, change interface{}) error { - session := c.Database.Session - lerr, err := session.writeQuery(&updateOp{c.FullName, selector, change, 0}) + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, change, 0}) if err == nil && lerr != nil && !lerr.UpdatedExisting { return ErrNotFound } return err } +// UpdateId is a convenience helper equivalent to: +// +// err := collection.Update(bson.M{"_id": id}, change) +// +// See the Update method for more details. +func (c *Collection) UpdateId(id interface{}, change interface{}) error { + return c.Update(bson.D{{"_id", id}}, change) +} + // ChangeInfo holds details about the outcome of a change operation. 
type ChangeInfo struct { Updated int // Number of existing documents updated @@ -1315,8 +1686,7 @@ type ChangeInfo struct { // http://www.mongodb.org/display/DOCS/Atomic+Operations // func (c *Collection) UpdateAll(selector interface{}, change interface{}) (info *ChangeInfo, err error) { - session := c.Database.Session - lerr, err := session.writeQuery(&updateOp{c.FullName, selector, change, 2}) + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, change, 2}) if err == nil && lerr != nil { info = &ChangeInfo{Updated: lerr.N} } @@ -1342,8 +1712,7 @@ func (c *Collection) Upsert(selector interface{}, change interface{}) (info *Cha return nil, err } change = bson.Raw{0x03, data} - session := c.Database.Session - lerr, err := session.writeQuery(&updateOp{c.FullName, selector, change, 1}) + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, change, 1}) if err == nil && lerr != nil { info = &ChangeInfo{} if lerr.UpdatedExisting { @@ -1355,6 +1724,15 @@ func (c *Collection) Upsert(selector interface{}, change interface{}) (info *Cha return info, err } +// UpsertId is a convenience helper equivalent to: +// +// info, err := collection.Upsert(bson.M{"_id": id}, change) +// +// See the Upsert method for more details. +func (c *Collection) UpsertId(id interface{}, change interface{}) (info *ChangeInfo, err error) { + return c.Upsert(bson.D{{"_id", id}}, change) +} + // Remove finds a single document matching the provided selector document // and removes it from the database. 
// If the session is in safe mode (see SetSafe) a ErrNotFound error is @@ -1366,14 +1744,22 @@ func (c *Collection) Upsert(selector interface{}, change interface{}) (info *Cha // http://www.mongodb.org/display/DOCS/Removing // func (c *Collection) Remove(selector interface{}) error { - session := c.Database.Session - lerr, err := session.writeQuery(&deleteOp{c.FullName, selector, 1}) + lerr, err := c.writeQuery(&deleteOp{c.FullName, selector, 1}) if err == nil && lerr != nil && lerr.N == 0 { return ErrNotFound } return err } +// RemoveId is a convenience helper equivalent to: +// +// err := collection.Remove(bson.M{"_id": id}) +// +// See the Remove method for more details. +func (c *Collection) RemoveId(id interface{}) error { + return c.Remove(bson.D{{"_id", id}}) +} + // RemoveAll finds all documents matching the provided selector document // and removes them from the database. In case the session is in safe mode // (see the SetSafe method) and an error happens when attempting the change, @@ -1384,8 +1770,7 @@ func (c *Collection) Remove(selector interface{}) error { // http://www.mongodb.org/display/DOCS/Removing // func (c *Collection) RemoveAll(selector interface{}) (info *ChangeInfo, err error) { - session := c.Database.Session - lerr, err := session.writeQuery(&deleteOp{c.FullName, selector, 0}) + lerr, err := c.writeQuery(&deleteOp{c.FullName, selector, 0}) if err == nil && lerr != nil { info = &ChangeInfo{Removed: lerr.N} } @@ -1469,6 +1854,10 @@ func (c *Collection) Create(info *CollectionInfo) error { // writing, MongoDB will use an initial size of min(100 docs, 4MB) on the // first batch, and 4MB on remaining ones. func (q *Query) Batch(n int) *Query { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) 
+ n = 2 + } q.m.Lock() q.op.limit = int32(n) q.m.Unlock() @@ -1542,28 +1931,6 @@ func (q *Query) Select(selector interface{}) *Query { return q } -type queryWrapper struct { - Query interface{} "$query" - OrderBy interface{} "$orderby,omitempty" - Hint interface{} "$hint,omitempty" - Explain bool "$explain,omitempty" - Snapshot bool "$snapshot,omitempty" -} - -func (q *Query) wrap() *queryWrapper { - w, ok := q.op.query.(*queryWrapper) - if !ok { - if q.op.query == nil { - var empty bson.D - w = &queryWrapper{Query: empty} - } else { - w = &queryWrapper{Query: q.op.query} - } - q.op.query = w - } - return w -} - // Sort asks the database to order returned documents according to the // provided field names. A field name may be prefixed by - (minus) for // it to be sorted in reverse order. @@ -1580,7 +1947,6 @@ func (q *Query) wrap() *queryWrapper { // func (q *Query) Sort(fields ...string) *Query { q.m.Lock() - w := q.wrap() var order bson.D for _, field := range fields { n := 1 @@ -1598,7 +1964,8 @@ func (q *Query) Sort(fields ...string) *Query { } order = append(order, bson.DocElem{field, n}) } - w.OrderBy = order + q.op.options.OrderBy = order + q.op.hasOptions = true q.m.Unlock() return q } @@ -1620,13 +1987,13 @@ func (q *Query) Sort(fields ...string) *Query { // // http://www.mongodb.org/display/DOCS/Optimization // http://www.mongodb.org/display/DOCS/Query+Optimizer -// +// func (q *Query) Explain(result interface{}) error { q.m.Lock() clone := &Query{session: q.session, query: q.query} q.m.Unlock() - w := clone.wrap() - w.Explain = true + clone.op.options.Explain = true + clone.op.hasOptions = true if clone.op.limit > 0 { clone.op.limit = -q.op.limit } @@ -1634,7 +2001,7 @@ func (q *Query) Explain(result interface{}) error { if iter.Next(result) { return nil } - return iter.Err() + return iter.Close() } // Hint will include an explicit "hint" in the query to force the server @@ -1656,8 +2023,8 @@ func (q *Query) Explain(result interface{}) error { func (q 
*Query) Hint(indexKey ...string) *Query { q.m.Lock() _, realKey, err := parseIndexKey(indexKey) - w := q.wrap() - w.Hint = realKey + q.op.options.Hint = realKey + q.op.hasOptions = true q.m.Unlock() if err != nil { panic(err) @@ -1690,8 +2057,19 @@ func (q *Query) Hint(indexKey ...string) *Query { // func (q *Query) Snapshot() *Query { q.m.Lock() - w := q.wrap() - w.Snapshot = true + q.op.options.Snapshot = true + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// LogReplay enables an option that optimizes queries that are typically +// made against the MongoDB oplog for replaying it. This is an internal +// implementation aspect and most likely uninteresting for other uses. +// It has seen at least one use case, though, so it's exposed via the API. +func (q *Query) LogReplay() *Query { + q.m.Lock() + q.op.flags |= flagLogReplay q.m.Unlock() return q } @@ -1717,6 +2095,7 @@ func checkQueryError(fullname string, d []byte) error { Error: result := &queryError{} bson.Unmarshal(d, result) + logf("queryError: %#v\n", result) if result.LastError != nil { return result.LastError } @@ -1787,7 +2166,7 @@ func (q *Query) One(result interface{}) (err error) { // optionally a database name. // // See the FindRef methods on Session and on Database. -// +// // Relevant documentation: // // http://www.mongodb.org/display/DOCS/Database+References @@ -1807,7 +2186,7 @@ type DBRef struct { // See also the DBRef type and the FindRef method on Session. // // Relevant documentation: -// +// // http://www.mongodb.org/display/DOCS/Database+References // func (db *Database) FindRef(ref *DBRef) *Query { @@ -1827,7 +2206,7 @@ func (db *Database) FindRef(ref *DBRef) *Query { // See also the DBRef type and the FindRef method on Database. 
// // Relevant documentation: -// +// // http://www.mongodb.org/display/DOCS/Database+References // func (s *Session) FindRef(ref *DBRef) *Query { @@ -1841,17 +2220,17 @@ func (s *Session) FindRef(ref *DBRef) *Query { // CollectionNames returns the collection names present in database. func (db *Database) CollectionNames() (names []string, err error) { c := len(db.Name) + 1 + iter := db.C("system.namespaces").Find(nil).Iter() var result *struct{ Name string } - err = db.C("system.namespaces").Find(nil).For(&result, func() error { + for iter.Next(&result) { if strings.Index(result.Name, "$") < 0 || strings.Index(result.Name, ".oplog.$") >= 0 { names = append(names, result.Name[c:]) } - return nil - }) - if err != nil { + } + if err := iter.Close(); err != nil { return nil, err } - sort.StringSlice(names).Sort() + sort.Strings(names) return names, nil } @@ -1874,6 +2253,7 @@ func (s *Session) DatabaseNames() (names []string, err error) { names = append(names, db.Name) } } + sort.Strings(names) return names, nil } @@ -1889,12 +2269,17 @@ func (q *Query) Iter() *Iter { limit := q.limit q.m.Unlock() - iter := &Iter{session: session, prefetch: prefetch, limit: limit} + iter := &Iter{ + session: session, + prefetch: prefetch, + limit: limit, + timeout: -1, + } iter.gotReply.L = &iter.m iter.op.collection = op.collection iter.op.limit = op.limit iter.op.replyFunc = iter.replyFunc() - iter.pendingDocs++ + iter.docsToReceive++ op.replyFunc = iter.op.replyFunc op.flags |= session.slaveOkFlag() @@ -1903,6 +2288,7 @@ func (q *Query) Iter() *Iter { iter.err = err } else { iter.err = socket.Query(&op) + iter.server = socket.Server() socket.Release() } return iter @@ -1938,13 +2324,13 @@ func (q *Query) Iter() *Iter { // fmt.Println(result.Id) // lastId = result.Id // } -// if iter.Err() != nil { -// panic(err) +// if err := iter.Close(); err != nil { +// return err // } // if iter.Timeout() { // continue // } -// query := collection.Find(bson.M{"_id", bson.M{"$gt", lastId}}) +// 
query := collection.Find(bson.M{"_id": bson.M{"$gt": lastId}}) // iter = query.Sort("$natural").Tail(5 * time.Second) // } // @@ -1967,24 +2353,25 @@ func (q *Query) Tail(timeout time.Duration) *Iter { iter.op.collection = op.collection iter.op.limit = op.limit iter.op.replyFunc = iter.replyFunc() - iter.pendingDocs++ + iter.docsToReceive++ op.replyFunc = iter.op.replyFunc - op.flags |= 2 | 32 | session.slaveOkFlag() // Tailable | AwaitData [| SlaveOk] + op.flags |= flagTailable | flagAwaitData | session.slaveOkFlag() socket, err := session.acquireSocket(true) if err != nil { iter.err = err } else { iter.err = socket.Query(&op) + iter.server = socket.Server() socket.Release() } return iter } -func (s *Session) slaveOkFlag() (flag uint32) { +func (s *Session) slaveOkFlag() (flag queryOpFlags) { s.m.RLock() if s.slaveOk { - flag = 4 + flag = flagSlaveOk } s.m.RUnlock() return @@ -2006,6 +2393,49 @@ func (iter *Iter) Err() error { return err } +// Close kills the server cursor used by the iterator, if any, and returns +// nil if no errors happened during iteration, or the actual error otherwise. +// +// Server cursors are automatically closed at the end of an iteration, which +// means close will do nothing unless the iteration was interrupted before +// the server finished sending results to the driver. If Close is not called +// in such a situation, the cursor will remain available at the server until +// the default cursor timeout period is reached. No further problems arise. +// +// Close is idempotent. That means it can be called repeatedly and will +// return the same result every time. +// +// In case a resulting document included a field named $err or errmsg, which are +// standard ways for MongoDB to report an improper query, the returned value has +// a *QueryError type. 
+func (iter *Iter) Close() error { + iter.m.Lock() + iter.killCursor() + err := iter.err + iter.m.Unlock() + if err == ErrNotFound { + return nil + } + return err +} + +func (iter *Iter) killCursor() error { + if iter.op.cursorId != 0 { + socket, err := iter.acquireSocket() + if err == nil { + // TODO Batch kills. + err = socket.Query(&killCursorsOp{[]int64{iter.op.cursorId}}) + socket.Release() + } + if err != nil && (iter.err == nil || iter.err == ErrNotFound) { + iter.err = err + } + iter.op.cursorId = 0 + return err + } + return nil +} + // Timeout returns true if Next returned false due to a timeout of // a tailable cursor. In those cases, Next may be called again to continue // the iteration at the previous cursor position. @@ -2033,27 +2463,25 @@ func (iter *Iter) Timeout() bool { // for iter.Next(&result) { // fmt.Printf("Result: %v\n", result.Id) // } -// if iter.Err() != nil { -// panic(iter.Err()) +// if err := iter.Close(); err != nil { +// return err // } // func (iter *Iter) Next(result interface{}) bool { - timeouts := false - timeout := time.Time{} - if iter.timeout >= 0 { - timeouts = true - timeout = time.Now().Add(iter.timeout) - } - iter.m.Lock() iter.timedout = false - for iter.err == nil && iter.docData.Len() == 0 && (iter.pendingDocs > 0 || iter.op.cursorId != 0) { - if iter.pendingDocs == 0 && iter.op.cursorId != 0 { - // Tailable cursor exhausted. - if timeouts && time.Now().After(timeout) { - iter.timedout = true - iter.m.Unlock() - return false + timeout := time.Time{} + for iter.err == nil && iter.docData.Len() == 0 && (iter.docsToReceive > 0 || iter.op.cursorId != 0) { + if iter.docsToReceive == 0 { + if iter.timeout >= 0 { + if timeout.IsZero() { + timeout = time.Now().Add(iter.timeout) + } + if time.Now().After(timeout) { + iter.timedout = true + iter.m.Unlock() + return false + } } iter.getMore() } @@ -2062,16 +2490,23 @@ func (iter *Iter) Next(result interface{}) bool { // Exhaust available data before reporting any errors. 
if docData, ok := iter.docData.Pop().([]byte); ok { - iter.limit-- - if iter.limit == 0 { - // XXX Must kill the cursor here. - iter.err = ErrNotFound + if iter.limit > 0 { + iter.limit-- + if iter.limit == 0 { + if iter.docData.Len() > 0 { + panic(fmt.Errorf("data remains after limit exhausted: %d", iter.docData.Len())) + } + iter.err = ErrNotFound + if iter.killCursor() != nil { + return false + } + } } if iter.op.cursorId != 0 && iter.err == nil { - iter.docsBeforeMore-- if iter.docsBeforeMore == 0 { iter.getMore() } + iter.docsBeforeMore-- // Goes negative. } iter.m.Unlock() err := bson.Unmarshal(docData, result) @@ -2084,7 +2519,11 @@ func (iter *Iter) Next(result interface{}) bool { // XXX Only have to check first document for a query error? err = checkQueryError(iter.op.collection, docData) if err != nil { - iter.err = err + iter.m.Lock() + if iter.err == nil { + iter.err = err + } + iter.m.Unlock() return false } return true @@ -2102,7 +2541,8 @@ func (iter *Iter) Next(result interface{}) bool { panic("unreachable") } -// All retrieves all documents from the result set into the provided slice. +// All retrieves all documents from the result set into the provided slice +// and closes the iterator. // // The result argument must necessarily be the address for a slice. The slice // may be nil or previously allocated. @@ -2111,14 +2551,14 @@ func (iter *Iter) Next(result interface{}) bool { // potentially large, since it may consume all memory until the system // crashes. Consider building the query with a Limit clause to ensure the // result size is bounded. 
-// +// // For instance: // // var result []struct{ Value int } // iter := collection.Find(nil).Limit(100).Iter() // err := iter.All(&result) // if err != nil { -// panic(iter.Err()) +// return err // } // func (iter *Iter) All(result interface{}) error { @@ -2146,7 +2586,7 @@ func (iter *Iter) All(result interface{}) error { i++ } resultv.Elem().Set(slicev.Slice(0, i)) - return iter.Err() + return iter.Close() } // All works like Iter.All. @@ -2189,8 +2629,35 @@ func (iter *Iter) For(result interface{}, f func() error) (err error) { return iter.Err() } -func (iter *Iter) getMore() { +func (iter *Iter) acquireSocket() (*mongoSocket, error) { socket, err := iter.session.acquireSocket(true) + if err != nil { + return nil, err + } + if socket.Server() != iter.server { + // Socket server changed during iteration. This may happen + // with Eventual sessions, if a Refresh is done, or if a + // monotonic session gets a write and shifts from secondary + // to primary. Our cursor is in a specific server, though. 
+ iter.session.m.Lock() + sockTimeout := iter.session.sockTimeout + iter.session.m.Unlock() + socket.Release() + socket, _, err = iter.server.AcquireSocket(0, sockTimeout) + if err != nil { + return nil, err + } + err := iter.session.socketLogin(socket) + if err != nil { + socket.Release() + return nil, err + } + } + return socket, nil +} + +func (iter *Iter) getMore() { + socket, err := iter.acquireSocket() if err != nil { iter.err = err return @@ -2198,17 +2665,16 @@ func (iter *Iter) getMore() { defer socket.Release() debugf("Iter %p requesting more documents", iter) - iter.pendingDocs++ - if iter.limit > 0 && iter.op.limit > iter.limit { - iter.op.limit = iter.limit + if iter.limit > 0 { + limit := iter.limit - int32(iter.docsToReceive) - int32(iter.docData.Len()) + if limit < iter.op.limit { + iter.op.limit = limit + } } - if iter.op.limit == 1 { - iter.op.limit = -1 - } - err = socket.Query(&iter.op) - if err != nil { + if err := socket.Query(&iter.op); err != nil { iter.err = err } + iter.docsToReceive++ } type countCmd struct { @@ -2234,13 +2700,8 @@ func (q *Query) Count() (n int, err error) { dbname := op.collection[:c] cname := op.collection[c+1:] - qdoc := op.query - if wrapper, ok := qdoc.(*queryWrapper); ok { - qdoc = wrapper.Query - } - result := struct{ N int }{} - err = session.DB(dbname).Run(countCmd{cname, qdoc, limit, op.skip}, &result) + err = session.DB(dbname).Run(countCmd{cname, op.query, limit, op.skip}, &result) return result.N, err } @@ -2282,13 +2743,8 @@ func (q *Query) Distinct(key string, result interface{}) error { dbname := op.collection[:c] cname := op.collection[c+1:] - qdoc := op.query - if wrapper, ok := qdoc.(*queryWrapper); ok { - qdoc = wrapper.Query - } - var doc struct{ Values bson.Raw } - err := session.DB(dbname).Run(distinctCmd{cname, key, qdoc}, &doc) + err := session.DB(dbname).Run(distinctCmd{cname, key, op.query}, &doc) if err != nil { return err } @@ -2390,12 +2846,12 @@ type MapReduceTime struct { // var result 
[]struct { Id int "_id"; Value int } // _, err := collection.Find(nil).MapReduce(job, &result) // if err != nil { -// panic(err) +// return err // } // for _, item := range result { // fmt.Println(item.Value) // } -// +// // This function is compatible with MongoDB 1.7.4+. // // Relevant documentation: @@ -2417,23 +2873,16 @@ func (q *Query) MapReduce(job *MapReduce, result interface{}) (info *MapReduceIn dbname := op.collection[:c] cname := op.collection[c+1:] - qdoc := op.query - var sort interface{} - if wrapper, ok := qdoc.(*queryWrapper); ok { - qdoc = wrapper.Query - sort = wrapper.OrderBy - } - cmd := mapReduceCmd{ Collection: cname, Map: job.Map, Reduce: job.Reduce, Finalize: job.Finalize, - Out: job.Out, + Out: fixMROut(job.Out), Scope: job.Scope, Verbose: job.Verbose, - Query: qdoc, - Sort: sort, + Query: op.query, + Sort: op.options.OrderBy, Limit: limit, } @@ -2483,6 +2932,36 @@ func (q *Query) MapReduce(job *MapReduce, result interface{}) (info *MapReduceIn return info, nil } +// The "out" option in the MapReduce command must be ordered. This was +// found after the implementation was accepting maps for a long time, +// so rather than breaking the API, we'll fix the order if necessary. 
+// Details about the order requirement may be seen in MongoDB's code: +// +// http://goo.gl/L8jwJX +// +func fixMROut(out interface{}) interface{} { + outv := reflect.ValueOf(out) + if outv.Kind() != reflect.Map || outv.Type().Key() != reflect.TypeOf("") { + return out + } + outs := make(bson.D, outv.Len()) + + outTypeIndex := -1 + for i, k := range outv.MapKeys() { + ks := k.String() + outs[i].Name = ks + outs[i].Value = outv.MapIndex(k).Interface() + switch ks { + case "normal", "replace", "merge", "reduce", "inline": + outTypeIndex = i + } + } + if outTypeIndex > 0 { + outs[0], outs[outTypeIndex] = outs[outTypeIndex], outs[0] + } + return outs +} + type Change struct { Update interface{} // The change document Upsert bool // Whether to insert in case the document isn't found @@ -2520,6 +2999,8 @@ type valueResult struct { // info, err = col.Find(M{"_id": id}).Apply(change, &doc) // fmt.Println(doc.N) // +// This method depends on MongoDB >= 2.0 to work properly. +// // Relevant documentation: // // http://www.mongodb.org/display/DOCS/findAndModify+Command @@ -2540,24 +3021,21 @@ func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err dbname := op.collection[:c] cname := op.collection[c+1:] - qdoc := op.query - var sort interface{} - if wrapper, ok := qdoc.(*queryWrapper); ok { - qdoc = wrapper.Query - sort = wrapper.OrderBy - } - cmd := findModifyCmd{ Collection: cname, Update: change.Update, Upsert: change.Upsert, Remove: change.Remove, New: change.ReturnNew, - Query: qdoc, - Sort: sort, + Query: op.query, + Sort: op.options.OrderBy, Fields: op.selector, } + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + var doc valueResult err = session.DB(dbname).Run(&cmd, &doc) if err != nil { @@ -2566,12 +3044,14 @@ func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err } return nil, err } - if doc.Value.Kind == 0x0A { + if doc.LastError.N == 0 { return nil, ErrNotFound } - err = 
doc.Value.Unmarshal(result) - if err != nil { - return nil, err + if doc.Value.Kind != 0x0A { + err = doc.Value.Unmarshal(result) + if err != nil { + return nil, err + } } info = &ChangeInfo{} lerr := &doc.LastError @@ -2624,51 +3104,54 @@ func (s *Session) BuildInfo() (info BuildInfo, err error) { func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { - // Try to use a previously reserved socket, with a fast read-only lock. + // Read-only lock to check for previously reserved socket. s.m.RLock() - sock := s.socket - sockIsGood := sock != nil && (slaveOk && s.slaveOk || s.socketIsMaster) - s.m.RUnlock() - - if sockIsGood { - sock.Acquire() - return sock, nil + if s.masterSocket != nil { + socket := s.masterSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil } + if s.slaveSocket != nil && s.slaveOk && slaveOk { + socket := s.slaveSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + s.m.RUnlock() // No go. We may have to request a new socket and change the session, // so try again but with an exclusive lock now. s.m.Lock() defer s.m.Unlock() - sock = s.socket - sockIsGood = sock != nil && (slaveOk && s.slaveOk || s.socketIsMaster) - - if sockIsGood { - sock.Acquire() - return sock, nil + if s.masterSocket != nil { + s.masterSocket.Acquire() + return s.masterSocket, nil + } + if s.slaveSocket != nil && s.slaveOk && slaveOk { + s.slaveSocket.Acquire() + return s.slaveSocket, nil } // Still not good. We need a new socket. - sock, err := s.cluster().AcquireSocket(slaveOk && s.slaveOk, s.syncTimeout) + sock, err := s.cluster().AcquireSocket(slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags) if err != nil { return nil, err } // Authenticate the new socket. 
- for _, a := range s.auth { - err = sock.Login(a.db, a.user, a.pass) - if err != nil { - sock.Release() - return nil, err - } + if err = s.socketLogin(sock); err != nil { + sock.Release() + return nil, err } // Keep track of the new socket, if necessary. // Note that, as a special case, if the Eventual session was - // not refreshed (socket != nil), it means the developer asked - // to preserve an existing reserved socket, so we'll keep the - // master one around too before a Refresh happens. - if s.consistency != Eventual || s.socket != nil { + // not refreshed (s.slaveSocket != nil), it means the developer + // asked to preserve an existing reserved socket, so we'll + // keep a master one around too before a Refresh happens. + if s.consistency != Eventual || s.slaveSocket != nil { s.setSocket(sock) } @@ -2680,27 +3163,38 @@ func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { return sock, nil } -// Set the socket bound to this session. With a bound socket, all operations -// with this session will use the given socket if possible. When not possible -// (e.g. attempting to write to a slave) acquireSocket will replace the -// current socket. Note that this method will properly refcount the socket up -// and down when setting/releasing. +// setSocket binds socket to this section. func (s *Session) setSocket(socket *mongoSocket) { - if socket != nil { - s.socketIsMaster = socket.Acquire() + info := socket.Acquire() + if info.Master { + if s.masterSocket != nil { + panic("setSocket(master) with existing master socket reserved") + } + s.masterSocket = socket } else { - s.socketIsMaster = false + if s.slaveSocket != nil { + panic("setSocket(slave) with existing slave socket reserved") + } + s.slaveSocket = socket } - if s.socket != nil { - s.socket.Release() +} + +// unsetSocket releases any slave and/or master sockets reserved. 
+func (s *Session) unsetSocket() { + if s.masterSocket != nil { + s.masterSocket.Release() } - s.socket = socket + if s.slaveSocket != nil { + s.slaveSocket.Release() + } + s.masterSocket = nil + s.slaveSocket = nil } func (iter *Iter) replyFunc() replyFunc { return func(err error, op *replyOp, docNum int, docData []byte) { iter.m.Lock() - iter.pendingDocs-- + iter.docsToReceive-- if err != nil { iter.err = err debugf("Iter %p received an error: %s", iter, err.Error()) @@ -2715,8 +3209,13 @@ func (iter *Iter) replyFunc() replyFunc { } else { rdocs := int(op.replyDocs) if docNum == 0 { - iter.pendingDocs += rdocs - 1 - iter.docsBeforeMore = rdocs - int(iter.prefetch*float64(rdocs)) + iter.docsToReceive += rdocs - 1 + docsToProcess := iter.docData.Len() + rdocs + if iter.limit == 0 || int32(docsToProcess) < iter.limit { + iter.docsBeforeMore = docsToProcess - int(iter.prefetch*float64(rdocs)) + } else { + iter.docsBeforeMore = -1 + } iter.op.cursorId = op.cursorId } // XXX Handle errors and flags. @@ -2732,8 +3231,10 @@ func (iter *Iter) replyFunc() replyFunc { // by a getLastError command in case the session is in safe mode. The // LastError result is made available in lerr, and if lerr.Err is set it // will also be returned as err. -func (s *Session) writeQuery(op interface{}) (lerr *LastError, err error) { - socket, err := s.acquireSocket(false) +func (c *Collection) writeQuery(op interface{}) (lerr *LastError, err error) { + s := c.Database.Session + dbname := c.Database.Name + socket, err := s.acquireSocket(dbname == "local") if err != nil { return nil, err } @@ -2751,6 +3252,7 @@ func (s *Session) writeQuery(op interface{}) (lerr *LastError, err error) { var replyErr error mutex.Lock() query := *safeOp // Copy the data. 
+ query.collection = dbname + ".$cmd" query.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { replyData = docData replyErr = err @@ -2764,6 +3266,13 @@ func (s *Session) writeQuery(op interface{}) (lerr *LastError, err error) { if replyErr != nil { return nil, replyErr // XXX TESTME } + if hasErrMsg(replyData) { + // Looks like getLastError itself failed. + err = checkQueryError(query.collection, replyData) + if err != nil { + return nil, err + } + } result := &LastError{} bson.Unmarshal(replyData, &result) debugf("Result from writing query: %#v", result) @@ -2774,3 +3283,13 @@ func (s *Session) writeQuery(op interface{}) (lerr *LastError, err error) { } panic("unreachable") } + +func hasErrMsg(d []byte) bool { + l := len(d) + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + return true + } + } + return false +} diff --git a/third_party/labix.org/v2/mgo/session_test.go b/third_party/labix.org/v2/mgo/session_test.go index 96e4f522c..03b685346 100644 --- a/third_party/labix.org/v2/mgo/session_test.go +++ b/third_party/labix.org/v2/mgo/session_test.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. 
-// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,11 +27,14 @@ package mgo_test import ( - "errors" - . "camlistore.org/third_party/launchpad.net/gocheck" "camlistore.org/third_party/labix.org/v2/mgo" "camlistore.org/third_party/labix.org/v2/mgo/bson" + . "camlistore.org/third_party/launchpad.net/gocheck" + "flag" + "fmt" "math" + "reflect" + "runtime" "sort" "strconv" "strings" @@ -139,7 +142,7 @@ func (s *S) TestInsertFindOneNil(c *C) { coll := session.DB("mydb").C("mycoll") err = coll.Find(nil).One(nil) - c.Assert(err, ErrorMatches, "unauthorized.*") + c.Assert(err, ErrorMatches, "unauthorized.*|not authorized.*") } func (s *S) TestInsertFindOneMap(c *C) { @@ -261,7 +264,10 @@ func (s *S) TestDatabaseAndCollectionNames(c *C) { names, err := session.DatabaseNames() c.Assert(err, IsNil) - c.Assert(names, DeepEquals, []string{"db1", "db2"}) + if !reflect.DeepEqual(names, []string{"db1", "db2"}) { + // 2.4+ has "local" as well. 
+ c.Assert(names, DeepEquals, []string{"db1", "db2", "local"}) + } names, err = db1.CollectionNames() c.Assert(err, IsNil) @@ -288,6 +294,37 @@ func (s *S) TestSelect(c *C) { c.Assert(result.B, Equals, 2) } +func (s *S) TestInlineMap(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + var v, result1 struct { + A int + M map[string]int ",inline" + } + + v.A = 1 + v.M = map[string]int{"b": 2} + err = coll.Insert(v) + c.Assert(err, IsNil) + + noId := M{"_id": 0} + + err = coll.Find(nil).Select(noId).One(&result1) + c.Assert(err, IsNil) + c.Assert(result1.A, Equals, 1) + c.Assert(result1.M, DeepEquals, map[string]int{"b": 2}) + + var result2 M + err = coll.Find(nil).Select(noId).One(&result2) + c.Assert(err, IsNil) + c.Assert(result2, DeepEquals, M{"a": 1, "b": 2}) + +} + func (s *S) TestUpdate(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -316,6 +353,34 @@ func (s *S) TestUpdate(c *C) { c.Assert(err, Equals, mgo.ErrNotFound) } +func (s *S) TestUpdateId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"_id": n, "n": n}) + c.Assert(err, IsNil) + } + + err = coll.UpdateId(42, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.FindId(42).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.UpdateId(47, M{"k": 47, "n": 47}) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.FindId(47).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + func (s *S) TestUpdateNil(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -396,6 +461,39 @@ func (s *S) TestUpsert(c *C) { c.Assert(result["n"], Equals, 48) } +func (s *S) TestUpsertId(c *C) { + session, err := 
mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"_id": n, "n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.UpsertId(42, M{"n": 24}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result := M{} + err = coll.FindId(42).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 24) + + info, err = coll.UpsertId(47, M{"_id": 47, "n": 47}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + err = coll.FindId(47).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) +} + func (s *S) TestUpdateAll(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -460,6 +558,24 @@ func (s *S) TestRemove(c *C) { c.Assert(result.N, Equals, 44) } +func (s *S) TestRemoveId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"_id": 40}, M{"_id": 41}, M{"_id": 42}) + c.Assert(err, IsNil) + + err = coll.RemoveId(41) + c.Assert(err, IsNil) + + c.Assert(coll.FindId(40).One(nil), IsNil) + c.Assert(coll.FindId(41).One(nil), Equals, mgo.ErrNotFound) + c.Assert(coll.FindId(42).One(nil), IsNil) +} + func (s *S) TestRemoveAll(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -507,14 +623,20 @@ func (s *S) TestDropDatabase(c *C) { names, err := session.DatabaseNames() c.Assert(err, IsNil) - c.Assert(names, DeepEquals, []string{"db2"}) + if !reflect.DeepEqual(names, []string{"db2"}) { + // 2.4+ has "local" as well. 
+ c.Assert(names, DeepEquals, []string{"db2", "local"}) + } err = db2.DropDatabase() c.Assert(err, IsNil) names, err = session.DatabaseNames() c.Assert(err, IsNil) - c.Assert(names, DeepEquals, []string(nil)) + if !reflect.DeepEqual(names, []string(nil)) { + // 2.4+ has "local" as well. + c.Assert(names, DeepEquals, []string{"local"}) + } } func (s *S) TestDropCollection(c *C) { @@ -549,9 +671,9 @@ func (s *S) TestCreateCollectionCapped(c *C) { coll := session.DB("mydb").C("mycoll") info := &mgo.CollectionInfo{ - Capped: true, + Capped: true, MaxBytes: 1024, - MaxDocs: 3, + MaxDocs: 3, } err = coll.Create(info) c.Assert(err, IsNil) @@ -596,8 +718,8 @@ func (s *S) TestCreateCollectionForceIndex(c *C) { info := &mgo.CollectionInfo{ ForceIdIndex: true, - Capped: true, - MaxBytes: 1024, + Capped: true, + MaxBytes: 1024, } err = coll.Create(info) c.Assert(err, IsNil) @@ -609,15 +731,108 @@ func (s *S) TestCreateCollectionForceIndex(c *C) { c.Assert(indexes, HasLen, 1) } -func (s *S) TestFindAndModify(c *C) { +func (s *S) TestIsDupValues(c *C) { + c.Assert(mgo.IsDup(nil), Equals, false) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 1}), Equals, false) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 1}), Equals, false) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 11000}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 11000}), Equals, true) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 11001}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 11001}), Equals, true) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 12582}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 12582}), Equals, true) +} + +func (s *S) TestIsDupPrimary(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) defer session.Close() coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) 
TestIsDupUnique(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + index := mgo.Index{ + Key: []string{"a", "b"}, + Unique: true, + } + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + + err = coll.Insert(M{"a": 1, "b": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 1, "b": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupCapped(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + ForceIdIndex: true, + Capped: true, + MaxBytes: 1024, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 1}) + // Quite unfortunate that the error is different for capped collections. + c.Assert(err, ErrorMatches, "duplicate key.*capped collection") + // The issue is reduced by using IsDup. 
+ c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupFindAndModify(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{Key: []string{"n"}, Unique: true}) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"n": 2}) + c.Assert(err, IsNil) + _, err = coll.Find(M{"n": 1}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, bson.M{}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestFindAndModify(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 42}) + session.SetMode(mgo.Monotonic, true) + result := M{} info, err := coll.Find(M{"n": 42}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, result) c.Assert(err, IsNil) @@ -677,10 +892,16 @@ func (s *S) TestFindAndModifyBug997828(c *C) { result := make(M) _, err = coll.Find(M{"n": "not-a-number"}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, result) - c.Assert(err, ErrorMatches, `Cannot apply \$inc modifier to non-number`) - lerr, _ := err.(*mgo.LastError) - c.Assert(lerr, NotNil) - c.Assert(lerr.Code, Equals, 10140) + c.Assert(err, ErrorMatches, `(exception: )?Cannot apply \$inc modifier to non-number`) + if s.versionAtLeast(2, 1) { + qerr, _ := err.(*mgo.QueryError) + c.Assert(qerr, NotNil, Commentf("err: %#v", err)) + c.Assert(qerr.Code, Equals, 10140) + } else { + lerr, _ := err.(*mgo.LastError) + c.Assert(lerr, NotNil, Commentf("err: %#v", err)) + c.Assert(lerr.Code, Equals, 10140) + } } func (s *S) TestCountCollection(c *C) { @@ -773,8 +994,8 @@ func (s *S) TestQueryExplain(c *C) { } m := M{} - query := coll.Find(nil).Batch(1).Limit(2) - err = query.Batch(2).Explain(m) + query := coll.Find(nil).Limit(2) + err = query.Explain(m) c.Assert(err, 
IsNil) c.Assert(m["cursor"], Equals, "BasicCursor") c.Assert(m["nscanned"], Equals, 2) @@ -782,11 +1003,11 @@ func (s *S) TestQueryExplain(c *C) { n := 0 var result M - err = query.For(&result, func() error { + iter := query.Iter() + for iter.Next(&result) { n++ - return nil - }) - c.Assert(err, IsNil) + } + c.Assert(iter.Close(), IsNil) c.Assert(n, Equals, 2) } @@ -802,7 +1023,7 @@ func (s *S) TestQueryHint(c *C) { err = coll.Find(nil).Hint("a").Explain(m) c.Assert(err, IsNil) c.Assert(m["indexBounds"], NotNil) - c.Assert(m["indexBounds"].(bson.M)["a"], NotNil) + c.Assert(m["indexBounds"].(M)["a"], NotNil) } func (s *S) TestFindOneNotFound(c *C) { @@ -869,8 +1090,7 @@ func (s *S) TestFindIterAll(c *C) { mgo.ResetStats() - query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) - iter := query.Iter() + iter := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2).Iter() result := struct{ N int }{} for i := 2; i < 7; i++ { ok := iter.Next(&result) @@ -884,7 +1104,7 @@ func (s *S) TestFindIterAll(c *C) { ok := iter.Next(&result) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) session.Refresh() // Release socket. @@ -933,7 +1153,7 @@ func (s *S) TestFindIterWithoutResults(c *C) { result := struct{ N int }{} ok := iter.Next(&result) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) c.Assert(result.N, Equals, 0) } @@ -965,17 +1185,62 @@ func (s *S) TestFindIterLimit(c *C) { ok := iter.Next(&result) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) session.Refresh() // Release socket. 
stats := mgo.GetStats() - c.Assert(stats.SentOps, Equals, 1) // 1*QUERY_OP + c.Assert(stats.SentOps, Equals, 2) // 1*QUERY_OP + 1*KILL_CURSORS_OP c.Assert(stats.ReceivedOps, Equals, 1) // and its REPLY_OP c.Assert(stats.ReceivedDocs, Equals, 3) c.Assert(stats.SocketsInUse, Equals, 0) } +func (s *S) TestTooManyItemsLimitBug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(runtime.NumCPU())) + + mgo.SetDebug(false) + coll := session.DB("mydb").C("mycoll") + words := strings.Split("foo bar baz", " ") + for i := 0; i < 5; i++ { + words = append(words, words...) + } + doc := bson.D{{"words", words}} + inserts := 10000 + limit := 5000 + iters := 0 + c.Assert(inserts > limit, Equals, true) + for i := 0; i < inserts; i++ { + err := coll.Insert(&doc) + c.Assert(err, IsNil) + } + iter := coll.Find(nil).Limit(limit).Iter() + for iter.Next(&doc) { + if iters%100 == 0 { + c.Logf("Seen %d docments", iters) + } + iters++ + } + c.Assert(iter.Close(), IsNil) + c.Assert(iters, Equals, limit) +} + +func serverCursorsOpen(session *mgo.Session) int { + var result struct { + Cursors struct { + TotalOpen int `bson:"totalOpen"` + TimedOut int `bson:"timedOut"` + } + } + err := session.Run("serverStatus", &result) + if err != nil { + panic(err) + } + return result.Cursors.TotalOpen +} func (s *S) TestFindIterLimitWithMore(c *C) { session, err := mgo.Dial("localhost:40001") @@ -1015,6 +1280,8 @@ func (s *S) TestFindIterLimitWithMore(c *C) { c.Fatalf("Bad result size with negative limit: %d", nresults) } + cursorsOpen := serverCursorsOpen(session) + // Try again, with a positive limit. Should reach the end now, // using multiple chunks. nresults = 0 @@ -1024,6 +1291,9 @@ func (s *S) TestFindIterLimitWithMore(c *C) { } c.Assert(nresults, Equals, total) + // Ensure the cursor used is properly killed. + c.Assert(serverCursorsOpen(session), Equals, cursorsOpen) + // Edge case, -MinInt == -MinInt. 
nresults = 0 iter = coll.Find(nil).Limit(math.MinInt32).Iter() @@ -1069,12 +1339,12 @@ func (s *S) TestFindIterLimitWithBatch(c *C) { ok := iter.Next(&result) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) session.Refresh() // Release socket. stats := mgo.GetStats() - c.Assert(stats.SentOps, Equals, 2) // 1*QUERY_OP + 1*GET_MORE_OP + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 1*GET_MORE_OP + 1*KILL_CURSORS_OP c.Assert(stats.ReceivedOps, Equals, 2) // and its REPLY_OPs c.Assert(stats.ReceivedDocs, Equals, 3) c.Assert(stats.SocketsInUse, Equals, 0) @@ -1120,7 +1390,7 @@ func (s *S) TestFindIterSortWithBatch(c *C) { ok := iter.Next(&result) c.Assert(ok, Equals, false) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) session.Refresh() // Release socket. @@ -1224,7 +1494,7 @@ func (s *S) TestFindTailTimeoutWithSleep(c *C) { coll.Insert(M{"n": 48}) ok = iter.Next(&result) c.Assert(ok, Equals, true) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) c.Assert(iter.Timeout(), Equals, false) c.Assert(result.N, Equals, 48) } @@ -1317,7 +1587,7 @@ func (s *S) TestFindTailTimeoutNoSleep(c *C) { coll.Insert(M{"n": 48}) ok = iter.Next(&result) c.Assert(ok, Equals, true) - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) c.Assert(iter.Timeout(), Equals, false) c.Assert(result.N, Equals, 48) } @@ -1423,6 +1693,76 @@ func (s *S) TestFindTailNoTimeout(c *C) { } } +func (s *S) TestIterNextResetsResult(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{1, 2, 3} + for _, n := range ns { + coll.Insert(M{"n" + strconv.Itoa(n): n}) + } + + query := coll.Find(nil).Sort("$natural") + + i := 0 + var sresult *struct{ N1, N2, N3 int } + iter := query.Iter() + for iter.Next(&sresult) { + switch i { + case 0: + c.Assert(sresult.N1, Equals, 1) + c.Assert(sresult.N2+sresult.N3, Equals, 0) + case 1: + 
c.Assert(sresult.N2, Equals, 2) + c.Assert(sresult.N1+sresult.N3, Equals, 0) + case 2: + c.Assert(sresult.N3, Equals, 3) + c.Assert(sresult.N1+sresult.N2, Equals, 0) + } + i++ + } + c.Assert(iter.Close(), IsNil) + + i = 0 + var mresult M + iter = query.Iter() + for iter.Next(&mresult) { + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, M{"n3": 3}) + } + i++ + } + c.Assert(iter.Close(), IsNil) + + i = 0 + var iresult interface{} + iter = query.Iter() + for iter.Next(&iresult) { + mresult, ok := iresult.(bson.M) + c.Assert(ok, Equals, true, Commentf("%#v", iresult)) + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, bson.M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, bson.M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, bson.M{"n3": 3}) + } + i++ + } + c.Assert(iter.Close(), IsNil) +} + func (s *S) TestFindForOnIter(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -1525,7 +1865,7 @@ func (s *S) TestFindForStopOnError(c *C) { c.Assert(i < 4, Equals, true) c.Assert(result.N, Equals, ns[i]) if i == 3 { - return errors.New("stop!") + return fmt.Errorf("stop!") } i++ return nil @@ -1627,7 +1967,9 @@ func (s *S) TestFindIterSnapshot(c *C) { iter := query.Iter() seen := map[int]bool{} - result := struct{ Id int "_id" }{} + result := struct { + Id int "_id" + }{} for iter.Next(&result) { if len(seen) == 2 { // Grow all entries so that they have to move. 
@@ -1642,7 +1984,7 @@ func (s *S) TestFindIterSnapshot(c *C) { } seen[result.Id] = true } - c.Assert(iter.Err(), IsNil) + c.Assert(iter.Close(), IsNil) } func (s *S) TestSort(c *C) { @@ -1703,56 +2045,72 @@ func (s *S) TestPrefetching(c *C) { coll := session.DB("mydb").C("mycoll") - docs := make([]interface{}, 200) - for i := 0; i != 200; i++ { - docs[i] = M{"n": i} + mgo.SetDebug(false) + docs := make([]interface{}, 800) + for i := 0; i != 600; i++ { + docs[i] = bson.D{{"n", i}} } coll.Insert(docs...) - // Same test three times. Once with prefetching via query, then with the - // default prefetching, and a third time tweaking the default settings in - // the session. - for testi := 0; testi != 3; testi++ { + for testi := 0; testi < 5; testi++ { mgo.ResetStats() var iter *mgo.Iter - var nextn int + var beforeMore int switch testi { - case 0: // First, using query methods. - iter = coll.Find(M{}).Prefetch(0.27).Batch(100).Iter() - nextn = 73 - - case 1: // Then, the default session value. + case 0: // The default session value. session.SetBatch(100) iter = coll.Find(M{}).Iter() - nextn = 75 + beforeMore = 75 - case 2: // Then, tweaking the session value. + case 2: // Changing the session value. session.SetBatch(100) session.SetPrefetch(0.27) iter = coll.Find(M{}).Iter() - nextn = 73 + beforeMore = 73 + + case 1: // Changing via query methods. + iter = coll.Find(M{}).Prefetch(0.27).Batch(100).Iter() + beforeMore = 73 + + case 3: // With prefetch on first document. + iter = coll.Find(M{}).Prefetch(1.0).Batch(100).Iter() + beforeMore = 0 + + case 4: // Without prefetch. 
+ iter = coll.Find(M{}).Prefetch(0).Batch(100).Iter() + beforeMore = 100 } - result := struct{ N int }{} - for i := 0; i != nextn; i++ { + pings := 0 + for batchi := 0; batchi < len(docs)/100-1; batchi++ { + c.Logf("Iterating over %d documents on batch %d", beforeMore, batchi) + var result struct{ N int } + for i := 0; i < beforeMore; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true, Commentf("iter.Err: %v", iter.Err())) + } + beforeMore = 99 + c.Logf("Done iterating.") + + session.Run("ping", nil) // Roundtrip to settle down. + pings++ + + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, (batchi+1)*100+pings) + + c.Logf("Iterating over one more document on batch %d", batchi) ok := iter.Next(&result) - c.Assert(ok, Equals, true) + c.Assert(ok, Equals, true, Commentf("iter.Err: %v", iter.Err())) + c.Logf("Done iterating.") + + session.Run("ping", nil) // Roundtrip to settle down. + pings++ + + stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, (batchi+2)*100+pings) } - - stats := mgo.GetStats() - c.Assert(stats.ReceivedDocs, Equals, 100) - - ok := iter.Next(&result) - c.Assert(ok, Equals, true) - - // Ping the database just to wait for the fetch above - // to get delivered. - session.Run("ping", M{}) // XXX Should support nil here. 
- - stats = mgo.GetStats() - c.Assert(stats.ReceivedDocs, Equals, 201) // 200 + the ping result } } @@ -1933,10 +2291,11 @@ func (s *S) TestQueryErrorNext(c *C) { ok := iter.Next(&result) c.Assert(ok, Equals, false) - err = iter.Err() + err = iter.Close() c.Assert(err, ErrorMatches, "Unsupported projection option: b") c.Assert(err.(*mgo.QueryError).Message, Matches, "Unsupported projection option: b") c.Assert(err.(*mgo.QueryError).Code, Equals, 13097) + c.Assert(iter.Err(), Equals, err) // The result should be properly unmarshalled with QueryError c.Assert(result.Err, Matches, "Unsupported projection option: b") @@ -1958,8 +2317,16 @@ func (s *S) TestEnsureIndex(c *C) { DropDups: true, } + // Obsolete: index3 := mgo.Index{ - Key: []string{"@loc"}, + Key: []string{"@loc_old"}, + Min: -500, + Max: 500, + Bits: 32, + } + + index4 := mgo.Index{ + Key: []string{"$2d:loc"}, Min: -500, Max: 500, Bits: 32, @@ -1967,14 +2334,10 @@ func (s *S) TestEnsureIndex(c *C) { coll := session.DB("mydb").C("mycoll") - err = coll.EnsureIndex(index1) - c.Assert(err, IsNil) - - err = coll.EnsureIndex(index2) - c.Assert(err, IsNil) - - err = coll.EnsureIndex(index3) - c.Assert(err, IsNil) + for _, index := range []mgo.Index{index1, index2, index3, index4} { + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + } sysidx := session.DB("mydb").C("system.indexes") @@ -1987,13 +2350,17 @@ func (s *S) TestEnsureIndex(c *C) { c.Assert(err, IsNil) result3 := M{} - err = sysidx.Find(M{"name": "loc_"}).One(result3) + err = sysidx.Find(M{"name": "loc_old_2d"}).One(result3) + c.Assert(err, IsNil) + + result4 := M{} + err = sysidx.Find(M{"name": "loc_2d"}).One(result4) c.Assert(err, IsNil) delete(result1, "v") expected1 := M{ "name": "a_1", - "key": bson.M{"a": 1}, + "key": M{"a": 1}, "ns": "mydb.mycoll", "background": true, } @@ -2002,7 +2369,7 @@ func (s *S) TestEnsureIndex(c *C) { delete(result2, "v") expected2 := M{ "name": "a_1_b_-1", - "key": bson.M{"a": 1, "b": -1}, + "key": M{"a": 1, "b": 
-1}, "ns": "mydb.mycoll", "unique": true, "dropDups": true, @@ -2011,8 +2378,8 @@ func (s *S) TestEnsureIndex(c *C) { delete(result3, "v") expected3 := M{ - "name": "loc_", - "key": bson.M{"loc": "2d"}, + "name": "loc_old_2d", + "key": M{"loc_old": "2d"}, "ns": "mydb.mycoll", "min": -500, "max": 500, @@ -2020,11 +2387,23 @@ func (s *S) TestEnsureIndex(c *C) { } c.Assert(result3, DeepEquals, expected3) + delete(result4, "v") + expected4 := M{ + "name": "loc_2d", + "key": M{"loc": "2d"}, + "ns": "mydb.mycoll", + "min": -500, + "max": 500, + "bits": 32, + } + c.Assert(result4, DeepEquals, expected4) + // Ensure the index actually works for real. err = coll.Insert(M{"a": 1, "b": 1}) c.Assert(err, IsNil) err = coll.Insert(M{"a": 1, "b": 1}) c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) } func (s *S) TestEnsureIndexWithBadInfo(c *C) { @@ -2092,7 +2471,7 @@ func (s *S) TestEnsureIndexKey(c *C) { delete(result1, "v") expected1 := M{ "name": "a_1", - "key": bson.M{"a": 1}, + "key": M{"a": 1}, "ns": "mydb.mycoll", } c.Assert(result1, DeepEquals, expected1) @@ -2100,7 +2479,7 @@ func (s *S) TestEnsureIndexKey(c *C) { delete(result2, "v") expected2 := M{ "name": "a_1_b_-1", - "key": bson.M{"a": 1, "b": -1}, + "key": M{"a": 1, "b": -1}, "ns": "mydb.mycoll", } c.Assert(result2, DeepEquals, expected2) @@ -2195,18 +2574,105 @@ func (s *S) TestEnsureIndexGetIndexes(c *C) { err = coll.EnsureIndexKey("a") c.Assert(err, IsNil) + // Obsolete. 
err = coll.EnsureIndexKey("@c") c.Assert(err, IsNil) + err = coll.EnsureIndexKey("$2d:d") + c.Assert(err, IsNil) + indexes, err := coll.Indexes() + c.Assert(err, IsNil) c.Assert(indexes[0].Name, Equals, "_id_") c.Assert(indexes[1].Name, Equals, "a_1") c.Assert(indexes[1].Key, DeepEquals, []string{"a"}) c.Assert(indexes[2].Name, Equals, "b_-1") c.Assert(indexes[2].Key, DeepEquals, []string{"-b"}) - c.Assert(indexes[3].Name, Equals, "c_") - c.Assert(indexes[3].Key, DeepEquals, []string{"@c"}) + c.Assert(indexes[3].Name, Equals, "c_2d") + c.Assert(indexes[3].Key, DeepEquals, []string{"$2d:c"}) + c.Assert(indexes[4].Name, Equals, "d_2d") + c.Assert(indexes[4].Key, DeepEquals, []string{"$2d:d"}) +} + +func (s *S) TestEnsureIndexEvalGetIndexes(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({b: -1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({a: 1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({c: '2d'})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({d: -1, e: 1})"}}, nil) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + + c.Assert(indexes[0].Name, Equals, "_id_") + c.Assert(indexes[1].Name, Equals, "a_1") + c.Assert(indexes[1].Key, DeepEquals, []string{"a"}) + c.Assert(indexes[2].Name, Equals, "b_-1") + c.Assert(indexes[2].Key, DeepEquals, []string{"-b"}) + c.Assert(indexes[3].Name, Equals, "c_2d") + c.Assert(indexes[3].Key, DeepEquals, []string{"$2d:c"}) + c.Assert(indexes[4].Name, Equals, "d_-1_e_1") + c.Assert(indexes[4].Key, DeepEquals, []string{"-d", "e"}) +} + +var testTTL = flag.Bool("test-ttl", false, "test TTL collections (may take 1 minute)") + +func (s 
*S) TestEnsureIndexExpireAfter(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + session.SetSafe(nil) + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": 1, "t": time.Now().Add(-120 * time.Second)}) + c.Assert(err, IsNil) + err = coll.Insert(M{"n": 2, "t": time.Now()}) + c.Assert(err, IsNil) + + // The TTL index is not unique, so the duplicated t values must not make it fail. + index := mgo.Index{ + Key: []string{"t"}, + ExpireAfter: 1 * time.Minute, + } + + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + c.Assert(indexes[1].Name, Equals, "t_1") + c.Assert(indexes[1].ExpireAfter, Equals, 1*time.Minute) + + if *testTTL { + worked := false + stop := time.Now().Add(70 * time.Second) + for time.Now().Before(stop) { + n, err := coll.Count() + c.Assert(err, IsNil) + if n == 1 { + worked = true + break + } + c.Assert(n, Equals, 2) + c.Logf("Still has 2 entries...") + time.Sleep(1 * time.Second) + } + if !worked { + c.Fatalf("TTL index didn't work") + } + } } func (s *S) TestDistinct(c *C) { @@ -2333,13 +2799,13 @@ func (s *S) TestMapReduceToCollection(c *C) { Value int } mr := session.DB("mydb").C("mr") - err = mr.Find(nil).For(&item, func() error { + iter := mr.Find(nil).Iter() + for iter.Next(&item) { c.Logf("Item: %#v", &item) c.Assert(item.Value, Equals, expected[item.Id]) expected[item.Id] = -1 - return nil - }) - c.Assert(err, IsNil) + } + c.Assert(iter.Close(), IsNil) } func (s *S) TestMapReduceToOtherDb(c *C) { @@ -2375,13 +2841,36 @@ func (s *S) TestMapReduceToOtherDb(c *C) { Value int } mr := session.DB("otherdb").C("mr") - err = mr.Find(nil).For(&item, func() error { + iter := mr.Find(nil).Iter() + for iter.Next(&item) { c.Logf("Item: %#v", &item) c.Assert(item.Value, Equals, expected[item.Id]) expected[item.Id] = -1 - return nil - }) + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestMapReduceOutOfOrder(c *C) { + session, err := 
mgo.Dial("localhost:40001") c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: bson.M{"a": "a", "z": "z", "replace": "mr", "db": "otherdb", "b": "b", "y": "y"}, + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.Collection, Equals, "mr") + c.Assert(info.Database, Equals, "otherdb") } func (s *S) TestMapReduceScope(c *C) { @@ -2458,10 +2947,16 @@ func (s *S) TestBuildInfo(c *C) { c.Assert(err, IsNil) var v []int - for _, a := range strings.Split(info.Version, ".") { - i, err := strconv.Atoi(a) + for i, a := range strings.Split(info.Version, ".") { + for _, token := range []string{"-rc", "-pre"} { + if i == 2 && strings.Contains(a, token) { + a = a[:strings.Index(a, token)] + info.VersionArray[len(info.VersionArray)-1] = 0 + } + } + n, err := strconv.Atoi(a) c.Assert(err, IsNil) - v = append(v, i) + v = append(v, n) } for len(v) < 4 { v = append(v, 0) @@ -2536,3 +3031,183 @@ func (s *S) TestFsync(c *C) { err = session.Fsync(true) c.Assert(err, IsNil) } + +func (s *S) TestPipeIter(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + iter := coll.Pipe([]M{{"$match": M{"n": M{"$gte": 42}}}}).Iter() + result := struct{ N int }{} + for i := 2; i < 7; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + } + + c.Assert(iter.Next(&result), Equals, false) + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestPipeAll(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe 
only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + var result []struct{ N int } + err = coll.Pipe([]M{{"$match": M{"n": M{"$gte": 42}}}}).All(&result) + c.Assert(err, IsNil) + for i := 2; i < 7; i++ { + c.Assert(result[i-2].N, Equals, ns[i]) + } +} + +func (s *S) TestPipeOne(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ A, B int }{} + + pipe := coll.Pipe([]M{{"$project": M{"a": 1, "b": M{"$add": []interface{}{"$b", 1}}}}}) + err = pipe.One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 1) + c.Assert(result.B, Equals, 3) + + pipe = coll.Pipe([]M{{"$match": M{"a": 2}}}) + err = pipe.One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestBatch1Bug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 3; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + var ns []struct{ N int } + err = coll.Find(nil).Batch(1).All(&ns) + c.Assert(err, IsNil) + c.Assert(len(ns), Equals, 3) + + session.SetBatch(1) + err = coll.Find(nil).All(&ns) + c.Assert(err, IsNil) + c.Assert(len(ns), Equals, 3) +} + +func (s *S) TestInterfaceIterBug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 3; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + var result interface{} + + i := 0 + iter := coll.Find(nil).Sort("n").Iter() + for iter.Next(&result) 
{ + c.Assert(result.(bson.M)["n"], Equals, i) + i++ + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestFindIterCloseKillsCursor(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cursors := serverCursorsOpen(session) + + coll := session.DB("mydb").C("mycoll") + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err = coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + iter := coll.Find(nil).Batch(2).Iter() + c.Assert(iter.Next(bson.M{}), Equals, true) + + c.Assert(iter.Close(), IsNil) + c.Assert(serverCursorsOpen(session), Equals, cursors) +} + +func (s *S) TestLogReplay(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + for i := 0; i < 5; i++ { + err = coll.Insert(M{"ts": time.Now()}) + c.Assert(err, IsNil) + } + + iter := coll.Find(nil).LogReplay().Iter() + c.Assert(iter.Next(bson.M{}), Equals, false) + c.Assert(iter.Err(), ErrorMatches, "no ts field in query") +} + +func (s *S) TestSetCursorTimeout(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 42}) + + // This is just a smoke test. Won't wait 10 minutes for an actual timeout. + + session.SetCursorTimeout(0) + + var result struct{ N int } + iter := coll.Find(nil).Iter() + c.Assert(iter.Next(&result), Equals, true) + c.Assert(result.N, Equals, 42) + c.Assert(iter.Next(&result), Equals, false) +} diff --git a/third_party/labix.org/v2/mgo/socket.go b/third_party/labix.org/v2/mgo/socket.go index 59a93a8bc..2fefa6d08 100644 --- a/third_party/labix.org/v2/mgo/socket.go +++ b/third_party/labix.org/v2/mgo/socket.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. 
// // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,10 +27,11 @@ package mgo import ( - "errors" "camlistore.org/third_party/labix.org/v2/mgo/bson" + "errors" "net" "sync" + "time" ) type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) @@ -38,7 +39,8 @@ type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) type mongoSocket struct { sync.Mutex server *mongoServer // nil when cached - conn *net.TCPConn + conn net.Conn + timeout time.Duration addr string // For debugging only. 
nextRequestId uint32 replyFuncs map[uint32]replyFunc @@ -48,16 +50,59 @@ type mongoSocket struct { cachedNonce string gotNonce sync.Cond dead error + serverInfo *mongoServerInfo } +type queryOpFlags uint32 + +const ( + _ queryOpFlags = 1 << iota + flagTailable + flagSlaveOk + flagLogReplay + flagNoCursorTimeout + flagAwaitData +) + type queryOp struct { collection string query interface{} skip int32 limit int32 selector interface{} - flags uint32 + flags queryOpFlags replyFunc replyFunc + + options queryWrapper + hasOptions bool + serverTags []bson.D +} + +type queryWrapper struct { + Query interface{} "$query" + OrderBy interface{} "$orderby,omitempty" + Hint interface{} "$hint,omitempty" + Explain bool "$explain,omitempty" + Snapshot bool "$snapshot,omitempty" + ReadPreference bson.D "$readPreference,omitempty" +} + +func (op *queryOp) finalQuery(socket *mongoSocket) interface{} { + if op.flags&flagSlaveOk != 0 && len(op.serverTags) > 0 && socket.ServerInfo().Mongos { + op.hasOptions = true + op.options.ReadPreference = bson.D{{"mode", "secondaryPreferred"}, {"tags", op.serverTags}} + } + if op.hasOptions { + if op.query == nil { + var empty bson.D + op.options.Query = empty + } else { + op.options.Query = op.query + } + debugf("final query is %#v\n", &op.options) + return &op.options + } + return op.query } type getMoreOp struct { @@ -92,17 +137,24 @@ type deleteOp struct { flags uint32 } +type killCursorsOp struct { + cursorIds []int64 +} + type requestInfo struct { bufferPos int replyFunc replyFunc } -func newSocket(server *mongoServer, conn *net.TCPConn) *mongoSocket { - socket := &mongoSocket{conn: conn, addr: server.Addr} +func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { + socket := &mongoSocket{ + conn: conn, + addr: server.Addr, + server: server, + replyFuncs: make(map[uint32]replyFunc), + } socket.gotNonce.L = &socket.Mutex - socket.replyFuncs = make(map[uint32]replyFunc) - socket.server = server - if err := 
socket.InitialAcquire(); err != nil { + if err := socket.InitialAcquire(server.Info(), timeout); err != nil { panic("newSocket: InitialAcquire returned error: " + err.Error()) } stats.socketsAlive(+1) @@ -112,10 +164,28 @@ func newSocket(server *mongoServer, conn *net.TCPConn) *mongoSocket { return socket } +// Server returns the server that the socket is associated with. +// It returns nil while the socket is cached in its respective server. +func (socket *mongoSocket) Server() *mongoServer { + socket.Lock() + server := socket.server + socket.Unlock() + return server +} + +// ServerInfo returns details for the server at the time the socket +// was initially acquired. +func (socket *mongoSocket) ServerInfo() *mongoServerInfo { + socket.Lock() + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + // InitialAcquire obtains the first reference to the socket, either // right after the connection is made or once a recycled socket is // being put back in use. -func (socket *mongoSocket) InitialAcquire() error { +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { socket.Lock() if socket.references > 0 { panic("Socket acquired out of cache with references") @@ -125,6 +195,8 @@ func (socket *mongoSocket) InitialAcquire() error { return socket.dead } socket.references++ + socket.serverInfo = serverInfo + socket.timeout = timeout stats.socketsInUse(+1) stats.socketRefs(+1) socket.Unlock() @@ -134,20 +206,18 @@ func (socket *mongoSocket) InitialAcquire() error { // Acquire obtains an additional reference to the socket. // The socket will only be recycled when it's released as many // times as it's been acquired. 
-func (socket *mongoSocket) Acquire() (isMaster bool) { +func (socket *mongoSocket) Acquire() (info *mongoServerInfo) { socket.Lock() if socket.references == 0 { panic("Socket got non-initial acquire with references == 0") } - socket.references++ - stats.socketRefs(+1) // We'll track references to dead sockets as well. // Caller is still supposed to release the socket. - if socket.dead == nil { - isMaster = socket.server.IsMaster() - } + socket.references++ + stats.socketRefs(+1) + serverInfo := socket.serverInfo socket.Unlock() - return isMaster + return serverInfo } // Release decrements a socket reference. The socket will be @@ -173,6 +243,42 @@ func (socket *mongoSocket) Release() { } } +// SetTimeout changes the timeout used on socket operations. +func (socket *mongoSocket) SetTimeout(d time.Duration) { + socket.Lock() + socket.timeout = d + socket.Unlock() +} + +type deadlineType int + +const ( + readDeadline deadlineType = 1 + writeDeadline deadlineType = 2 +) + +func (socket *mongoSocket) updateDeadline(which deadlineType) { + var when time.Time + if socket.timeout > 0 { + when = time.Now().Add(socket.timeout) + } + whichstr := "" + switch which { + case readDeadline | writeDeadline: + whichstr = "read/write" + socket.conn.SetDeadline(when) + case readDeadline: + whichstr = "read" + socket.conn.SetReadDeadline(when) + case writeDeadline: + whichstr = "write" + socket.conn.SetWriteDeadline(when) + default: + panic("invalid parameter to updateDeadline") + } + debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) +} + // Close terminates the socket use. 
func (socket *mongoSocket) Close() { socket.kill(errors.New("Closed explicitly"), false) @@ -279,7 +385,7 @@ func (socket *mongoSocket) Query(ops ...interface{}) (err error) { buf = addCString(buf, op.collection) buf = addInt32(buf, op.skip) buf = addInt32(buf, op.limit) - buf, err = addBSON(buf, op.query) + buf, err = addBSON(buf, op.finalQuery(socket)) if err != nil { return err } @@ -310,6 +416,14 @@ func (socket *mongoSocket) Query(ops ...interface{}) (err error) { return err } + case *killCursorsOp: + buf = addHeader(buf, 2007) + buf = addInt32(buf, 0) // Reserved + buf = addInt32(buf, int32(len(op.cursorIds))) + for _, cursorId := range op.cursorIds { + buf = addInt64(buf, cursorId) + } + default: panic("Internal error: unknown operation type") } @@ -329,7 +443,7 @@ func (socket *mongoSocket) Query(ops ...interface{}) (err error) { socket.Lock() if socket.dead != nil { socket.Unlock() - debug("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error()) + debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error()) // XXX This seems necessary in case the session is closed concurrently // with a query being performed, but it's not yet tested: for i := 0; i != requestCount; i++ { @@ -341,6 +455,8 @@ func (socket *mongoSocket) Query(ops ...interface{}) (err error) { return socket.dead } + wasWaiting := len(socket.replyFuncs) > 0 + // Reserve id 0 for requests which should have no responses. 
requestId := socket.nextRequestId + 1 if requestId == 0 { @@ -357,12 +473,16 @@ func (socket *mongoSocket) Query(ops ...interface{}) (err error) { debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf)) stats.sentOps(len(ops)) + socket.updateDeadline(writeDeadline) _, err = socket.conn.Write(buf) + if !wasWaiting && requestCount > 0 { + socket.updateDeadline(readDeadline) + } socket.Unlock() return err } -func fill(r *net.TCPConn, b []byte) error { +func fill(r net.Conn, b []byte) error { l := len(b) n, err := r.Read(b) for n != l && err == nil { @@ -460,6 +580,12 @@ func (socket *mongoSocket) readLoop() { if replyFuncFound { delete(socket.replyFuncs, uint32(responseTo)) } + if len(socket.replyFuncs) == 0 { + // Nothing else to read for now. Disable deadline. + socket.conn.SetReadDeadline(time.Time{}) + } else { + socket.updateDeadline(readDeadline) + } socket.Unlock() // XXX Do bound checking against totalLen. diff --git a/third_party/labix.org/v2/mgo/stats.go b/third_party/labix.org/v2/mgo/stats.go index ab58263a9..59723e60c 100644 --- a/third_party/labix.org/v2/mgo/stats.go +++ b/third_party/labix.org/v2/mgo/stats.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. 
-// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -58,6 +58,7 @@ func ResetStats() { old := stats stats = &Stats{} // These are absolute values: + stats.Clusters = old.Clusters stats.SocketsInUse = old.SocketsInUse stats.SocketsAlive = old.SocketsAlive stats.SocketRefs = old.SocketRefs @@ -66,6 +67,7 @@ func ResetStats() { } type Stats struct { + Clusters int MasterConns int SlaveConns int SentOps int @@ -76,6 +78,14 @@ type Stats struct { SocketRefs int } +func (stats *Stats) cluster(delta int) { + if stats != nil { + statsMutex.Lock() + stats.Clusters += delta + statsMutex.Unlock() + } +} + func (stats *Stats) conn(delta int, master bool) { if stats != nil { statsMutex.Lock() diff --git a/third_party/labix.org/v2/mgo/suite_test.go b/third_party/labix.org/v2/mgo/suite_test.go index 19f52295a..205276b55 100644 --- a/third_party/labix.org/v2/mgo/suite_test.go +++ b/third_party/labix.org/v2/mgo/suite_test.go @@ -1,18 +1,18 @@ // mgo - MongoDB driver for Go -// +// // Copyright (c) 2010-2012 - Gustavo Niemeyer -// +// // All rights reserved. // // Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// +// modification, are permitted provided that the following conditions are met: +// // 1. Redistributions of source code must retain the above copyright notice, this -// list of conditions and the following disclaimer. +// list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. 
-// +// and/or other materials provided with the distribution. +// // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -27,15 +27,17 @@ package mgo_test import ( + "camlistore.org/third_party/labix.org/v2/mgo" + "camlistore.org/third_party/labix.org/v2/mgo/bson" + . "camlistore.org/third_party/launchpad.net/gocheck" "errors" "flag" "fmt" - . "camlistore.org/third_party/launchpad.net/gocheck" - "camlistore.org/third_party/labix.org/v2/mgo" - "camlistore.org/third_party/labix.org/v2/mgo/bson" + "net" "os/exec" + "strconv" + "syscall" - "strings" "testing" "time" ) @@ -60,6 +62,20 @@ func TestAll(t *testing.T) { type S struct { session *mgo.Session stopped bool + build mgo.BuildInfo + frozen []string +} + +func (s *S) versionAtLeast(v ...int) bool { + for i := range v { + if i == len(s.build.VersionArray) { + return false + } + if s.build.VersionArray[i] < v[i] { + return false + } + } + return true } var _ = Suite(&S{}) @@ -68,6 +84,12 @@ func (s *S) SetUpSuite(c *C) { mgo.SetDebug(true) mgo.SetStats(true) s.StartAll() + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + s.build, err = session.BuildInfo() + c.Check(err, IsNil) + session.Close() } func (s *S) SetUpTest(c *C) { @@ -83,8 +105,14 @@ func (s *S) TearDownTest(c *C) { if s.stopped { s.StartAll() } + for _, host := range s.frozen { + if host != "" { + s.Thaw(host) + } + } + var stats mgo.Stats for i := 0; ; i++ { - stats := mgo.GetStats() + stats = mgo.GetStats() if stats.SocketsInUse == 0 && stats.SocketsAlive == 0 { break } @@ -92,24 +120,70 @@ func (s *S) TearDownTest(c *C) { c.Fatal("Test left sockets in a dirty state") } c.Logf("Waiting for sockets to die: %d in use, %d alive", stats.SocketsInUse, stats.SocketsAlive) - time.Sleep(5e8) + time.Sleep(500 * time.Millisecond) + } + for i := 0; ; i++ { + 
stats = mgo.GetStats() + if stats.Clusters == 0 { + break + } + if i == 60 { + c.Fatal("Test left clusters alive") + } + c.Logf("Waiting for clusters to die: %d alive", stats.Clusters) + time.Sleep(1 * time.Second) } } func (s *S) Stop(host string) { + // Give a moment for slaves to sync and avoid getting rollback issues. + time.Sleep(2 * time.Second) err := run("cd _testdb && supervisorctl stop " + supvName(host)) if err != nil { - panic(err.Error()) + panic(err) } s.stopped = true } +func (s *S) pid(host string) int { + output, err := exec.Command("lsof", "-iTCP:"+hostPort(host), "-sTCP:LISTEN", "-Fp").CombinedOutput() + if err != nil { + panic(err) + } + pidstr := string(output[1 : len(output)-1]) + pid, err := strconv.Atoi(pidstr) + if err != nil { + panic("cannot convert pid to int: " + pidstr) + } + return pid +} + +func (s *S) Freeze(host string) { + err := syscall.Kill(s.pid(host), syscall.SIGSTOP) + if err != nil { + panic(err) + } + s.frozen = append(s.frozen, host) +} + +func (s *S) Thaw(host string) { + err := syscall.Kill(s.pid(host), syscall.SIGCONT) + if err != nil { + panic(err) + } + for i, frozen := range s.frozen { + if frozen == host { + s.frozen[i] = "" + } + } +} + func (s *S) StartAll() { // Restart any stopped nodes. run("cd _testdb && supervisorctl start all") err := run("cd testdb && mongo --nodb wait.js") if err != nil { - panic(err.Error()) + panic(err) } s.stopped = false } @@ -123,29 +197,44 @@ func run(command string) error { return nil } +var supvNames = map[string]string{ + "40001": "db1", + "40002": "db2", + "40011": "rs1a", + "40012": "rs1b", + "40013": "rs1c", + "40021": "rs2a", + "40022": "rs2b", + "40023": "rs2c", + "40031": "rs3a", + "40032": "rs3b", + "40033": "rs3c", + "40041": "rs4a", + "40101": "cfg1", + "40102": "cfg2", + "40103": "cfg3", + "40201": "s1", + "40202": "s2", + "40203": "s3", +} + // supvName returns the supervisord name for the given host address. 
func supvName(host string) string { - switch { - case strings.HasSuffix(host, ":40001"): - return "db1" - case strings.HasSuffix(host, ":40011"): - return "rs1a" - case strings.HasSuffix(host, ":40012"): - return "rs1b" - case strings.HasSuffix(host, ":40013"): - return "rs1c" - case strings.HasSuffix(host, ":40021"): - return "rs2a" - case strings.HasSuffix(host, ":40022"): - return "rs2b" - case strings.HasSuffix(host, ":40023"): - return "rs2c" - case strings.HasSuffix(host, ":40101"): - return "cfg1" - case strings.HasSuffix(host, ":40201"): - return "s1" - case strings.HasSuffix(host, ":40202"): - return "s2" + host, port, err := net.SplitHostPort(host) + if err != nil { + panic(err) } - panic("Unknown host: " + host) + name, ok := supvNames[port] + if !ok { + panic("Unknown host: " + host) + } + return name +} + +func hostPort(host string) string { + _, port, err := net.SplitHostPort(host) + if err != nil { + panic(err) + } + return port }