Skip to content
This repository was archived by the owner on Sep 3, 2024. It is now read-only.

Commit 9ec9c1a

Browse files
committed
Make schedule - runScheduled pair exception safe.
Instead of using schedule/runScheduled actions separately, because they are exceptions unsafe, we introduce a `withScheduledAction` function that cares about the safety by providing correct finalizers. X-Bug-URL: https://cloud-haskell.atlassian.net/browse/DP-109
1 parent 2e9d6da commit 9ec9c1a

File tree

1 file changed

+47
-37
lines changed

1 file changed

+47
-37
lines changed

src/Network/Transport/TCP.hs

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ import Control.Concurrent.MVar
101101
)
102102
import Control.Category ((>>>))
103103
import Control.Applicative ((<$>))
104-
import Control.Monad (when, unless, join, mplus)
104+
import Control.Monad (when, unless, join, mplus, (<=<))
105105
import Control.Exception
106106
( IOException
107107
, SomeException
@@ -112,9 +112,12 @@ import Control.Exception
112112
, try
113113
, bracketOnError
114114
, fromException
115+
, finally
115116
, catch
117+
, bracket
118+
, mask_
116119
)
117-
import Data.IORef (IORef, newIORef, writeIORef, readIORef)
120+
import Data.IORef (IORef, newIORef, writeIORef, readIORef, writeIORef)
118121
import Data.ByteString (ByteString)
119122
import qualified Data.ByteString as BS (concat)
120123
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
@@ -132,6 +135,7 @@ import qualified Data.Set as Set
132135
)
133136
import Data.Map (Map)
134137
import qualified Data.Map as Map (empty)
138+
import Data.Traversable (traverse)
135139
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:))
136140
import qualified Data.Accessor.Container as DAC (mapMaybe)
137141
import Data.Foldable (forM_, mapM_)
@@ -617,26 +621,25 @@ apiConnect params ourEndPoint theirAddress _reliability hints =
617621
-- | Close a connection
618622
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
619623
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
620-
void . tryIO . asyncWhenCancelled return $ do
621-
mAct <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
622-
RemoteEndPointValid vst -> do
623-
alive <- readIORef connAlive
624-
if alive
625-
then do
626-
writeIORef connAlive False
627-
act <- schedule theirEndPoint $
628-
sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId]
629-
return ( RemoteEndPointValid
630-
. (remoteOutgoing ^: (\x -> x - 1))
631-
$ vst
632-
, Just act
633-
)
634-
else
635-
return (RemoteEndPointValid vst, Nothing)
636-
_ ->
637-
return (st, Nothing)
638-
forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
639-
closeIfUnused (ourEndPoint, theirEndPoint)
624+
void . tryIO . asyncWhenCancelled return $ finally
625+
(withScheduledAction ourEndPoint $ \sched -> do
626+
modifyMVar_ (remoteState theirEndPoint) $ \st -> case st of
627+
RemoteEndPointValid vst -> do
628+
alive <- readIORef connAlive
629+
if alive
630+
then do
631+
writeIORef connAlive False
632+
sched theirEndPoint $
633+
sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId]
634+
return ( RemoteEndPointValid
635+
. (remoteOutgoing ^: (\x -> x - 1))
636+
$ vst
637+
)
638+
else
639+
return (RemoteEndPointValid vst)
640+
_ ->
641+
return st)
642+
(closeIfUnused (ourEndPoint, theirEndPoint))
640643

641644

642645
-- | Send data across a connection
@@ -647,16 +650,16 @@ apiSend :: EndPointPair -- ^ Local and remote endpoint
647650
-> IO (Either (TransportError SendErrorCode) ())
648651
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
649652
-- We don't need the overhead of asyncWhenCancelled here
650-
try . mapIOException sendFailed $ do
651-
act <- withMVar (remoteState theirEndPoint) $ \st -> case st of
653+
try . mapIOException sendFailed $ withScheduledAction ourEndPoint $ \sched -> do
654+
withMVar (remoteState theirEndPoint) $ \st -> case st of
652655
RemoteEndPointInvalid _ ->
653656
relyViolation (ourEndPoint, theirEndPoint) "apiSend"
654657
RemoteEndPointInit _ _ _ ->
655658
relyViolation (ourEndPoint, theirEndPoint) "apiSend"
656659
RemoteEndPointValid vst -> do
657660
alive <- readIORef connAlive
658661
if alive
659-
then schedule theirEndPoint $
662+
then sched theirEndPoint $
660663
sendOn vst (encodeInt32 connId : prependLength payload)
661664
else throwIO $ TransportError SendClosed "Connection closed"
662665
RemoteEndPointClosing _ _ -> do
@@ -674,7 +677,6 @@ apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
674677
if alive
675678
then throwIO $ TransportError SendFailed (show err)
676679
else throwIO $ TransportError SendClosed "Connection closed"
677-
runScheduledAction (ourEndPoint, theirEndPoint) act
678680
where
679681
sendFailed = TransportError SendFailed . show
680682

@@ -700,33 +702,32 @@ apiCloseEndPoint transport evs ourEndPoint =
700702
where
701703
-- Close the remote socket and return the set of all incoming connections
702704
tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
703-
tryCloseRemoteSocket theirEndPoint = do
705+
tryCloseRemoteSocket theirEndPoint = withScheduledAction ourEndPoint $ \sched -> do
704706
-- We make an attempt to close the connection nicely
705707
-- (by sending a CloseSocket first)
706708
let closed = RemoteEndPointFailed . userError $ "apiCloseEndPoint"
707-
mAct <- modifyMVar (remoteState theirEndPoint) $ \st ->
709+
modifyMVar_ (remoteState theirEndPoint) $ \st ->
708710
case st of
709711
RemoteEndPointInvalid _ ->
710-
return (st, Nothing)
712+
return st
711713
RemoteEndPointInit resolved _ _ -> do
712714
putMVar resolved ()
713-
return (closed, Nothing)
715+
return closed
714716
RemoteEndPointValid vst -> do
715-
act <- schedule theirEndPoint $ do
717+
sched theirEndPoint $ do
716718
tryIO $ sendOn vst [ encodeInt32 CloseSocket
717719
, encodeInt32 (vst ^. remoteMaxIncoming)
718720
]
719721
tryCloseSocket (remoteSocket vst)
720-
return (closed, Just act)
722+
return closed
721723
RemoteEndPointClosing resolved vst -> do
722724
putMVar resolved ()
723-
act <- schedule theirEndPoint $ tryCloseSocket (remoteSocket vst)
724-
return (closed, Just act)
725+
sched theirEndPoint $ tryCloseSocket (remoteSocket vst)
726+
return closed
725727
RemoteEndPointClosed ->
726-
return (st, Nothing)
728+
return st
727729
RemoteEndPointFailed err ->
728-
return (RemoteEndPointFailed err, Nothing)
729-
forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
730+
return (RemoteEndPointFailed err)
730731

731732

732733
--------------------------------------------------------------------------------
@@ -1486,6 +1487,15 @@ runScheduledAction (ourEndPoint, theirEndPoint) mvar = do
14861487
writeChan (localChannel ourEndPoint) $ ErrorEvent err
14871488
return (RemoteEndPointFailed ex)
14881489

1490+
-- | Use 'schedule' action 'runScheduled' action in a safe way, it's assumed that
1491+
-- callback is used only once, otherwise guarantees of runScheduledAction are not
1492+
-- respected.
1493+
withScheduledAction :: LocalEndPoint -> ((RemoteEndPoint -> IO a -> IO ()) -> IO ()) -> IO ()
1494+
withScheduledAction ourEndPoint f =
1495+
bracket (newIORef Nothing)
1496+
(traverse (\(tp, a) -> runScheduledAction (ourEndPoint, tp) a) <=< readIORef)
1497+
(\ref -> f (\rp g -> mask_ $ schedule rp g >>= \x -> writeIORef ref (Just (rp,x)) ))
1498+
14891499
--------------------------------------------------------------------------------
14901500
-- "Stateless" (MVar free) functions --
14911501
--------------------------------------------------------------------------------

0 commit comments

Comments
 (0)