From 49ff6e933caf79574d29f5eeff387dff313cafb6 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Tue, 30 Jul 2013 08:58:58 +0100 Subject: [PATCH] remove Rx state from general state. move RxState as a mutable mvar in the context directly. --- core/Network/TLS/Context.hs | 14 +++++++++++++ core/Network/TLS/Handshake/Common.hs | 1 - core/Network/TLS/IO.hs | 13 ++++++------ core/Network/TLS/Receiving.hs | 31 ++++++++++++++++++---------- core/Network/TLS/State.hs | 16 +------------- 5 files changed, 42 insertions(+), 33 deletions(-) diff --git a/core/Network/TLS/Context.hs b/core/Network/TLS/Context.hs index 30ac4c9..80e5666 100644 --- a/core/Network/TLS/Context.hs +++ b/core/Network/TLS/Context.hs @@ -43,6 +43,7 @@ module Network.TLS.Context , ctxEstablished , ctxLogging , ctxWithHooks + , ctxRxState , setEOF , setEstablished , contextFlush @@ -76,6 +77,7 @@ module Network.TLS.Context , throwCore , usingState , usingState_ + , runRxState , usingHState , getStateRNG ) where @@ -90,6 +92,7 @@ import Network.TLS.Compression import Network.TLS.Crypto import Network.TLS.State import Network.TLS.Handshake.State +import Network.TLS.Record.State import Network.TLS.Measurement import Network.TLS.X509 import Network.TLS.Types (Role(..)) @@ -312,6 +315,7 @@ data Context = Context , ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello. -- the flag will be set to false regardless of its initial value -- after the first packet received. + , ctxRxState :: MVar RecordState -- ^ current rx state , ctxHooks :: IORef Hooks -- ^ hooks for this context , ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state) , ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state) @@ -390,6 +394,7 @@ contextNew backend params rng = liftIO $ do -- server context, where we might be dealing with an old/compat client. sslv2Compat <- newIORef (role == ServerRole) hooks <- newIORef defaultHooks + rx <- newMVar newRecordState lockWrite <- newMVar () lockRead <- newMVar () lockState <- newMVar () @@ -397,6 +402,7 @@ contextNew backend params rng = liftIO $ do { ctxConnection = backend , ctxParams = params , ctxState = stvar + , ctxRxState = rx , ctxMeasurement = stats , ctxEOF_ = eof , ctxEstablished_ = established @@ -440,6 +446,14 @@ usingState_ ctx f = do usingHState :: MonadIO m => Context -> HandshakeM a -> m a usingHState ctx f = usingState_ ctx $ withHandshakeM f +runRxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a) +runRxState ctx f = do + ver <- usingState_ ctx getVersion + liftIO $ modifyMVar (ctxRxState ctx) $ \st -> + case runRecordM f ver st of + Left err -> return (st, Left err) + Right (a, newSt) -> return (newSt, Right a) + getStateRNG :: MonadIO m => Context -> Int -> m Bytes getStateRNG ctx n = usingState_ ctx $ genRandom n diff --git a/core/Network/TLS/Handshake/Common.hs b/core/Network/TLS/Handshake/Common.hs index fcfd91c..8840399 100644 --- a/core/Network/TLS/Handshake/Common.hs +++ b/core/Network/TLS/Handshake/Common.hs @@ -20,7 +20,6 @@ import Network.TLS.Session import Network.TLS.Struct import Network.TLS.IO import Network.TLS.State hiding (getNegotiatedProtocol) -import Network.TLS.Receiving import Network.TLS.Handshake.Process import Network.TLS.Measurement import Network.TLS.Types diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index bf5f6a5..f519314 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -15,7 +15,7 @@ module Network.TLS.IO ) where import Network.TLS.Context -import Network.TLS.State (needEmptyPacket, runRxState) +import Network.TLS.State (needEmptyPacket) import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet @@ -68,7 +68,7 @@ recvRecord compatSSLv2 ctx | otherwise = readExact ctx 5 >>= either (return . Left) recvLength . decodeHeader where recvLength header@(Header _ _ readlen) | readlen > 16384 + 2048 = return $ Left maximumSizeExceeded - | otherwise = readExact ctx (fromIntegral readlen) >>= makeRecord header + | otherwise = readExact ctx (fromIntegral readlen) >>= getRecord header #ifdef SSLV2_COMPATIBLE recvDeprecatedLength readlen | readlen > 1024 * 4 = return $ Left maximumSizeExceeded @@ -76,12 +76,13 @@ recvRecord compatSSLv2 ctx content <- readExact ctx (fromIntegral readlen) case decodeDeprecatedHeader readlen content of Left err -> return $ Left err - Right header -> makeRecord header content + Right header -> getRecord header content #endif maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow) - makeRecord header content = do + getRecord :: MonadIO m => Header -> Bytes -> m (Either TLSError (Record Plaintext)) + getRecord header content = do liftIO $ (loggingIORecv $ ctxLogging ctx) header content - usingState ctx $ runRxState $ disengageRecord $ rawToRecord header (fragmentCiphertext content) + runRxState ctx $ disengageRecord $ rawToRecord header (fragmentCiphertext content) -- | receive one packet from the context that contains 1 or @@ -94,7 +95,7 @@ recvPacket ctx = do case erecord of Left err -> return $ Left err Right record -> do - pktRecv <- usingState ctx $ processPacket record + pktRecv <- processPacket ctx record pkt <- case pktRecv of Right (Handshake hss) -> ctxWithHooks ctx $ \hooks -> diff --git a/core/Network/TLS/Receiving.hs b/core/Network/TLS/Receiving.hs index d5dd2d8..9fd197e 100644 --- a/core/Network/TLS/Receiving.hs +++ b/core/Network/TLS/Receiving.hs @@ -15,29 +15,33 @@ module Network.TLS.Receiving import Control.Applicative ((<$>)) import Control.Monad.State import Control.Monad.Error +import Control.Concurrent.MVar +import Network.TLS.Context import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet import Network.TLS.State import Network.TLS.Cipher +import Network.TLS.Util returnEither :: Either TLSError a -> TLSSt a returnEither (Left err) = throwError err returnEither (Right a) = return a -processPacket :: Record Plaintext -> TLSSt Packet +processPacket :: MonadIO m => Context -> Record Plaintext -> m (Either TLSError Packet) -processPacket (Record ProtocolType_AppData _ fragment) = return $ AppData $ fragmentGetBytes fragment +processPacket _ (Record ProtocolType_AppData _ fragment) = return $ Right $ AppData $ fragmentGetBytes fragment -processPacket (Record ProtocolType_Alert _ fragment) = return . Alert =<< returnEither (decodeAlerts $ fragmentGetBytes fragment) +processPacket _ (Record ProtocolType_Alert _ fragment) = return (Alert `fmapEither` (decodeAlerts $ fragmentGetBytes fragment)) -processPacket (Record ProtocolType_ChangeCipherSpec _ fragment) = do - returnEither $ decodeChangeCipherSpec $ fragmentGetBytes fragment - switchRxEncryption - return ChangeCipherSpec +processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) = + case decodeChangeCipherSpec $ fragmentGetBytes fragment of + Left err -> return $ Left err + Right _ -> do switchRxEncryption ctx + return $ Right ChangeCipherSpec -processPacket (Record ProtocolType_Handshake ver fragment) = do +processPacket ctx (Record ProtocolType_Handshake ver fragment) = usingState ctx $ do keyxchg <- gets (\st -> case stHandshake st of Nothing -> Nothing Just hst -> cipherKeyExchange <$> hstPendingCipher hst) @@ -54,7 +58,12 @@ processPacket (Record ProtocolType_Handshake ver fragment) = do Right hs -> return hs return $ Handshake hss -processPacket (Record ProtocolType_DeprecatedHandshake _ fragment) = +processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) = case decodeDeprecatedHandshake $ fragmentGetBytes fragment of - Left err -> throwError err - Right hs -> return $ Handshake [hs] + Left err -> return $ Left err + Right hs -> return $ Right $ Handshake [hs] + +switchRxEncryption :: MonadIO m => Context -> m () +switchRxEncryption ctx = + usingHState ctx (gets hstPendingRxState) >>= \rx -> + liftIO $ modifyMVar_ (ctxRxState ctx) (\_ -> return $ fromJust "rx-state" rx) diff --git a/core/Network/TLS/State.hs b/core/Network/TLS/State.hs index 54724a5..95e6dc8 100644 --- a/core/Network/TLS/State.hs +++ b/core/Network/TLS/State.hs @@ -18,7 +18,6 @@ module Network.TLS.State , newTLSState , withTLSRNG , runTxState - , runRxState , assert -- FIXME move somewhere else (Internal.hs ?) , updateVerifiedData , finishHandshakeTypeMaterial @@ -44,7 +43,6 @@ module Network.TLS.State , isSessionResuming , needEmptyPacket , switchTxEncryption - , switchRxEncryption , isClientContext , startHandshakeClient , getHandshakeDigest @@ -79,7 +77,6 @@ data TLSState = TLSState , stSession :: Session , stSessionResuming :: Bool , stTxState :: RecordState - , stRxState :: RecordState , stSecureRenegotiation :: Bool -- RFC 5746 , stClientVerifiedData :: Bytes -- RFC 5746 , stServerVerifiedData :: Bytes -- RFC 5746 @@ -115,20 +112,12 @@ runTxState f = do Left err -> throwError err Right (a, newSt) -> put (st { stTxState = newSt }) >> return a -runRxState :: RecordM a -> TLSSt a -runRxState f = do - st <- get - case runRecordM f (stVersion st) (stRxState st) of - Left err -> throwError err - Right (a, newSt) -> put (st { stRxState = newSt }) >> return a - newTLSState :: CPRG g => g -> Role -> TLSState newTLSState rng clientContext = TLSState { stHandshake = Nothing , stSession = Session Nothing , stSessionResuming = False , stTxState = newRecordState - , stRxState = newRecordState , stSecureRenegotiation = False , stClientVerifiedData = B.empty , stServerVerifiedData = B.empty @@ -180,13 +169,10 @@ certVerifyHandshakeTypeMaterial HandshakeType_NPN = False certVerifyHandshakeMaterial :: Handshake -> Bool certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake -switchTxEncryption, switchRxEncryption :: TLSSt () +switchTxEncryption :: TLSSt () switchTxEncryption = withHandshakeM (gets hstPendingTxState) >>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState } -switchRxEncryption = - withHandshakeM (gets hstPendingRxState) - >>= \newRxState -> modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState } getSessionData :: MonadState TLSState m => m (Maybe SessionData) getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)