@@ -89,9 +89,10 @@ type Server struct {
8989 sem * semaphore.Weighted
9090
9191 // goroutine management fields
92- done chan struct {}
93- checkNow chan struct {}
94- closewg sync.WaitGroup
92+ done chan struct {}
93+ checkNow chan struct {}
94+ disconnecting chan struct {}
95+ closewg sync.WaitGroup
9596
9697 // description related fields
9798 desc atomic.Value // holds a description.Server
@@ -139,8 +140,9 @@ func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
139140
140141 sem : semaphore .NewWeighted (int64 (maxConns )),
141142
142- done : make (chan struct {}),
143- checkNow : make (chan struct {}, 1 ),
143+ done : make (chan struct {}),
144+ checkNow : make (chan struct {}, 1 ),
145+ disconnecting : make (chan struct {}),
144146
145147 subscribers : make (map [uint64 ]chan description.Server ),
146148 }
@@ -193,7 +195,14 @@ func (s *Server) Disconnect(ctx context.Context) error {
193195
194196 // For every call to Connect there must be at least 1 goroutine that is
195197 // waiting on the done channel.
196- s .done <- struct {}{}
198+ select {
199+ case <- ctx .Done ():
200+ // signal a disconnect and still wait for receiver of done
201+ // to finish.
202+ close (s .disconnecting )
203+ s .done <- struct {}{}
204+ case s .done <- struct {}{}:
205+ }
197206 err := s .pool .disconnect (ctx )
198207 if err != nil {
199208 return err
@@ -398,6 +407,13 @@ func (s *Server) update() {
398407 conn .nc .Close ()
399408 }
400409 for {
410+ select {
411+ case <- done :
412+ closeServer ()
413+ return
414+ default :
415+ }
416+
401417 select {
402418 case <- heartbeatTicker .C :
403419 case <- checkNow :
@@ -463,7 +479,15 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
463479 var desc description.Server
464480 var set bool
465481 var err error
466- ctx := context .Background ()
482+ ctx , cancel := context .WithCancel (context .Background ())
483+ defer cancel ()
484+ go func () {
485+ select {
486+ case <- ctx .Done ():
487+ case <- s .disconnecting :
488+ cancel ()
489+ }
490+ }()
467491
468492 for i := 1 ; i <= maxRetry ; i ++ {
469493 var now time.Time
@@ -499,7 +523,7 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
499523
500524 conn .connect (ctx )
501525
502- err : = conn .wait ()
526+ err = conn .wait ()
503527 if err == nil {
504528 descPtr = & conn .desc
505529 }
0 commit comments