@@ -101,7 +101,7 @@ import Control.Concurrent.MVar
101
101
)
102
102
import Control.Category ((>>>) )
103
103
import Control.Applicative ((<$>) )
104
- import Control.Monad (when , unless , join , mplus )
104
+ import Control.Monad (when , unless , join , mplus , (<=<) )
105
105
import Control.Exception
106
106
( IOException
107
107
, SomeException
@@ -112,9 +112,12 @@ import Control.Exception
112
112
, try
113
113
, bracketOnError
114
114
, fromException
115
+ , finally
115
116
, catch
117
+ , bracket
118
+ , mask_
116
119
)
117
- import Data.IORef (IORef , newIORef , writeIORef , readIORef )
120
+ import Data.IORef (IORef , newIORef , writeIORef , readIORef , writeIORef )
118
121
import Data.ByteString (ByteString )
119
122
import qualified Data.ByteString as BS (concat )
120
123
import qualified Data.ByteString.Char8 as BSC (pack , unpack )
@@ -132,6 +135,7 @@ import qualified Data.Set as Set
132
135
)
133
136
import Data.Map (Map )
134
137
import qualified Data.Map as Map (empty )
138
+ import Data.Traversable (traverse )
135
139
import Data.Accessor (Accessor , accessor , (^.) , (^=) , (^:) )
136
140
import qualified Data.Accessor.Container as DAC (mapMaybe )
137
141
import Data.Foldable (forM_ , mapM_ )
@@ -617,26 +621,25 @@ apiConnect params ourEndPoint theirAddress _reliability hints =
617
621
-- | Close a connection
618
622
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
619
623
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
620
- void . tryIO . asyncWhenCancelled return $ do
621
- mAct <- modifyMVar (remoteState theirEndPoint) $ \ st -> case st of
622
- RemoteEndPointValid vst -> do
623
- alive <- readIORef connAlive
624
- if alive
625
- then do
626
- writeIORef connAlive False
627
- act <- schedule theirEndPoint $
628
- sendOn vst [encodeInt32 CloseConnection , encodeInt32 connId]
629
- return ( RemoteEndPointValid
630
- . (remoteOutgoing ^: (\ x -> x - 1 ))
631
- $ vst
632
- , Just act
633
- )
634
- else
635
- return (RemoteEndPointValid vst, Nothing )
636
- _ ->
637
- return (st, Nothing )
638
- forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
639
- closeIfUnused (ourEndPoint, theirEndPoint)
624
+ void . tryIO . asyncWhenCancelled return $ finally
625
+ (withScheduledAction ourEndPoint $ \ sched -> do
626
+ modifyMVar_ (remoteState theirEndPoint) $ \ st -> case st of
627
+ RemoteEndPointValid vst -> do
628
+ alive <- readIORef connAlive
629
+ if alive
630
+ then do
631
+ writeIORef connAlive False
632
+ sched theirEndPoint $
633
+ sendOn vst [encodeInt32 CloseConnection , encodeInt32 connId]
634
+ return ( RemoteEndPointValid
635
+ . (remoteOutgoing ^: (\ x -> x - 1 ))
636
+ $ vst
637
+ )
638
+ else
639
+ return (RemoteEndPointValid vst)
640
+ _ ->
641
+ return st)
642
+ (closeIfUnused (ourEndPoint, theirEndPoint))
640
643
641
644
642
645
-- | Send data across a connection
@@ -647,16 +650,16 @@ apiSend :: EndPointPair -- ^ Local and remote endpoint
647
650
-> IO (Either (TransportError SendErrorCode ) () )
648
651
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
649
652
-- We don't need the overhead of asyncWhenCancelled here
650
- try . mapIOException sendFailed $ do
651
- act <- withMVar (remoteState theirEndPoint) $ \ st -> case st of
653
+ try . mapIOException sendFailed $ withScheduledAction ourEndPoint $ \ sched -> do
654
+ withMVar (remoteState theirEndPoint) $ \ st -> case st of
652
655
RemoteEndPointInvalid _ ->
653
656
relyViolation (ourEndPoint, theirEndPoint) " apiSend"
654
657
RemoteEndPointInit _ _ _ ->
655
658
relyViolation (ourEndPoint, theirEndPoint) " apiSend"
656
659
RemoteEndPointValid vst -> do
657
660
alive <- readIORef connAlive
658
661
if alive
659
- then schedule theirEndPoint $
662
+ then sched theirEndPoint $
660
663
sendOn vst (encodeInt32 connId : prependLength payload)
661
664
else throwIO $ TransportError SendClosed " Connection closed"
662
665
RemoteEndPointClosing _ _ -> do
@@ -674,7 +677,6 @@ apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
674
677
if alive
675
678
then throwIO $ TransportError SendFailed (show err)
676
679
else throwIO $ TransportError SendClosed " Connection closed"
677
- runScheduledAction (ourEndPoint, theirEndPoint) act
678
680
where
679
681
sendFailed = TransportError SendFailed . show
680
682
@@ -700,33 +702,32 @@ apiCloseEndPoint transport evs ourEndPoint =
700
702
where
701
703
-- Close the remote socket and return the set of all incoming connections
702
704
tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
703
- tryCloseRemoteSocket theirEndPoint = do
705
+ tryCloseRemoteSocket theirEndPoint = withScheduledAction ourEndPoint $ \ sched -> do
704
706
-- We make an attempt to close the connection nicely
705
707
-- (by sending a CloseSocket first)
706
708
let closed = RemoteEndPointFailed . userError $ " apiCloseEndPoint"
707
- mAct <- modifyMVar (remoteState theirEndPoint) $ \ st ->
709
+ modifyMVar_ (remoteState theirEndPoint) $ \ st ->
708
710
case st of
709
711
RemoteEndPointInvalid _ ->
710
- return (st, Nothing )
712
+ return st
711
713
RemoteEndPointInit resolved _ _ -> do
712
714
putMVar resolved ()
713
- return ( closed, Nothing )
715
+ return closed
714
716
RemoteEndPointValid vst -> do
715
- act <- schedule theirEndPoint $ do
717
+ sched theirEndPoint $ do
716
718
tryIO $ sendOn vst [ encodeInt32 CloseSocket
717
719
, encodeInt32 (vst ^. remoteMaxIncoming)
718
720
]
719
721
tryCloseSocket (remoteSocket vst)
720
- return ( closed, Just act)
722
+ return closed
721
723
RemoteEndPointClosing resolved vst -> do
722
724
putMVar resolved ()
723
- act <- schedule theirEndPoint $ tryCloseSocket (remoteSocket vst)
724
- return ( closed, Just act)
725
+ sched theirEndPoint $ tryCloseSocket (remoteSocket vst)
726
+ return closed
725
727
RemoteEndPointClosed ->
726
- return (st, Nothing )
728
+ return st
727
729
RemoteEndPointFailed err ->
728
- return (RemoteEndPointFailed err, Nothing )
729
- forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
730
+ return (RemoteEndPointFailed err)
730
731
731
732
732
733
--------------------------------------------------------------------------------
@@ -1486,6 +1487,15 @@ runScheduledAction (ourEndPoint, theirEndPoint) mvar = do
1486
1487
writeChan (localChannel ourEndPoint) $ ErrorEvent err
1487
1488
return (RemoteEndPointFailed ex)
1488
1489
1490
+ -- | Use 'schedule' action 'runScheduled' action in a safe way, it's assumed that
1491
+ -- callback is used only once, otherwise guarantees of runScheduledAction are not
1492
+ -- respected.
1493
+ withScheduledAction :: LocalEndPoint -> ((RemoteEndPoint -> IO a -> IO () ) -> IO () ) -> IO ()
1494
+ withScheduledAction ourEndPoint f =
1495
+ bracket (newIORef Nothing )
1496
+ (traverse (\ (tp, a) -> runScheduledAction (ourEndPoint, tp) a) <=< readIORef)
1497
+ (\ ref -> f (\ rp g -> mask_ $ schedule rp g >>= \ x -> writeIORef ref (Just (rp,x)) ))
1498
+
1489
1499
--------------------------------------------------------------------------------
1490
1500
-- "Stateless" (MVar free) functions --
1491
1501
--------------------------------------------------------------------------------
0 commit comments