diff --git a/cluster.go b/cluster.go index 4ed29bea1..a0da8d3bb 100644 --- a/cluster.go +++ b/cluster.go @@ -33,9 +33,10 @@ import ( "sync" "time" - "gopkg.in/mgo.v2-unstable/bson" "strconv" "strings" + + "gopkg.in/mgo.v2-unstable/bson" ) // --------------------------------------------------------------------------- @@ -122,10 +123,10 @@ func (cluster *mongoCluster) removeServer(server *mongoServer) { other := cluster.servers.Remove(server) cluster.Unlock() if other != nil { - other.Close() + other.CloseIdle() log("Removed server ", server.Addr, " from cluster.") } - server.Close() + server.CloseIdle() } type isMasterResult struct { diff --git a/cluster_test.go b/cluster_test.go index 524acbc93..ea09feb3f 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -27,11 +27,13 @@ package mgo_test import ( + "errors" "fmt" "io" "net" "strings" "sync" + "sync/atomic" "time" . "gopkg.in/check.v1" @@ -1477,7 +1479,6 @@ func (s *S) TestSecondaryModeWithMongosInsert(c *C) { c.Assert(result.A, Equals, 1) } - func (s *S) TestRemovalOfClusterMember(c *C) { if *fast { c.Skip("-fast") @@ -2088,3 +2089,72 @@ func (s *S) TestDoNotFallbackToMonotonic(c *C) { c.Assert(q13b, Equals, q13a) } } + +func (s *S) TestConnectServerFailed(c *C) { + dials := int32(0) + maxDials := 50 + info := &mgo.DialInfo{ + Addrs: []string{"localhost:40001"}, + DialServer: func(addr *mgo.ServerAddr) (net.Conn, error) { + n := atomic.AddInt32(&dials, 1) + if n == int32(maxDials/2) { + return nil, errors.New("expected dial failed") + } + return net.Dial("tcp", addr.String()) + }, + } + + session, err := mgo.DialWithInfo(info) + c.Assert(err, IsNil) + defer session.Close() + + mgo.ResetStats() + + errs := make(chan error, 1) + var done int32 + var finished sync.WaitGroup + var starting sync.WaitGroup + defer func() { + atomic.StoreInt32(&done, 1) + finished.Wait() + }() + for i := 0; i < maxDials; i++ { + finished.Add(1) + starting.Add(1) + go func(s0 *mgo.Session) { + defer finished.Done() + for i := 0; ; i++ { + if atomic.LoadInt32(&done) == 1 { + break + } + err := func(s0 *mgo.Session) error { + s := s0.Copy() + defer s.Close() + coll := s.DB("mydb").C("mycoll") + + var ret []interface{} + return coll.Find(nil).All(&ret) + }(s0) + if err != nil { + select { + case errs <- err: + default: + } + } + if i == 0 { + starting.Done() + } + } + }(session) + time.Sleep(10 * time.Millisecond) + } + starting.Wait() + + // no errors expect. + var opErr error + select { + case opErr = <-errs: + default: + } + c.Assert(opErr, IsNil) +} diff --git a/server.go b/server.go index ba0480e58..8ae789aa5 100644 --- a/server.go +++ b/server.go @@ -187,6 +187,16 @@ func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) // Close forces closing all sockets that are alive, whether // they're currently in use or not. func (server *mongoServer) Close() { + server.close(false) +} + +// CloseIdle closing all sockets that are idle, +// sockets currently in use will be closed after idle. +func (server *mongoServer) CloseIdle() { + server.close(true) +} + +func (server *mongoServer) close(waitForIdle bool) { server.Lock() server.closed = true liveSockets := server.liveSockets @@ -196,7 +206,11 @@ func (server *mongoServer) Close() { server.Unlock() logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets)) for i, s := range liveSockets { - s.Close() + if waitForIdle { + s.CloseAfterIdle() + } else { + s.Close() + } liveSockets[i] = nil } for i := range unusedSockets { diff --git a/socket.go b/socket.go index a2343354d..9783136e0 100644 --- a/socket.go +++ b/socket.go @@ -40,19 +40,20 @@ type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) type mongoSocket struct { sync.Mutex - server *mongoServer // nil when cached - conn net.Conn - timeout time.Duration - addr string // For debugging only. - nextRequestId uint32 - replyFuncs map[uint32]replyFunc - references int - creds []Credential - logout []Credential - cachedNonce string - gotNonce sync.Cond - dead error - serverInfo *mongoServerInfo + server *mongoServer // nil when cached + conn net.Conn + timeout time.Duration + addr string // For debugging only. + nextRequestId uint32 + replyFuncs map[uint32]replyFunc + references int + creds []Credential + logout []Credential + cachedNonce string + gotNonce sync.Cond + dead error + serverInfo *mongoServerInfo + closeAfterIdle bool } type queryOpFlags uint32 @@ -264,10 +265,13 @@ func (socket *mongoSocket) Release() { if socket.references == 0 { stats.socketsInUse(-1) server := socket.server + closeAfterIdle := socket.closeAfterIdle socket.Unlock() socket.LogoutAll() - // If the socket is dead server is nil. - if server != nil { + if closeAfterIdle { + socket.Close() + } else if server != nil { + // If the socket is dead server is nil. server.RecycleSocket(socket) } } else { @@ -316,6 +320,21 @@ func (socket *mongoSocket) Close() { socket.kill(errors.New("Closed explicitly"), false) } +// CloseAfterIdle terminates an idle socket, which has a zero +// reference, or marks the socket to be terminate after idle. +func (socket *mongoSocket) CloseAfterIdle() { + socket.Lock() + if socket.references == 0 { + socket.Unlock() + socket.Close() + logf("Socket %p to %s: idle and close.", socket, socket.addr) + return + } + socket.closeAfterIdle = true + socket.Unlock() + logf("Socket %p to %s: close after idle.", socket, socket.addr) +} + func (socket *mongoSocket) kill(err error, abend bool) { socket.Lock() if socket.dead != nil {