remove Rx state from general state.

move RxState as a mutable mvar in the context directly.
This commit is contained in:
Vincent Hanquez 2013-07-30 08:58:58 +01:00
parent 6ff5e692d0
commit 49ff6e933c
5 changed files with 42 additions and 33 deletions

View file

@ -43,6 +43,7 @@ module Network.TLS.Context
, ctxEstablished
, ctxLogging
, ctxWithHooks
, ctxRxState
, setEOF
, setEstablished
, contextFlush
@ -76,6 +77,7 @@ module Network.TLS.Context
, throwCore
, usingState
, usingState_
, runRxState
, usingHState
, getStateRNG
) where
@ -90,6 +92,7 @@ import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Record.State
import Network.TLS.Measurement
import Network.TLS.X509
import Network.TLS.Types (Role(..))
@ -312,6 +315,7 @@ data Context = Context
, ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello.
-- the flag will be set to false regardless of its initial value
-- after the first packet received.
, ctxRxState :: MVar RecordState -- ^ current rx state
, ctxHooks :: IORef Hooks -- ^ hooks for this context
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
, ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state)
@ -390,6 +394,7 @@ contextNew backend params rng = liftIO $ do
-- server context, where we might be dealing with an old/compat client.
sslv2Compat <- newIORef (role == ServerRole)
hooks <- newIORef defaultHooks
rx <- newMVar newRecordState
lockWrite <- newMVar ()
lockRead <- newMVar ()
lockState <- newMVar ()
@ -397,6 +402,7 @@ contextNew backend params rng = liftIO $ do
{ ctxConnection = backend
, ctxParams = params
, ctxState = stvar
, ctxRxState = rx
, ctxMeasurement = stats
, ctxEOF_ = eof
, ctxEstablished_ = established
@ -440,6 +446,14 @@ usingState_ ctx f = do
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
usingHState ctx f = usingState_ ctx $ withHandshakeM f
runRxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a)
runRxState ctx f = do
ver <- usingState_ ctx getVersion
liftIO $ modifyMVar (ctxRxState ctx) $ \st ->
case runRecordM f ver st of
Left err -> return (st, Left err)
Right (a, newSt) -> return (newSt, Right a)
getStateRNG :: MonadIO m => Context -> Int -> m Bytes
getStateRNG ctx n = usingState_ ctx $ genRandom n

View file

@ -20,7 +20,6 @@ import Network.TLS.Session
import Network.TLS.Struct
import Network.TLS.IO
import Network.TLS.State hiding (getNegotiatedProtocol)
import Network.TLS.Receiving
import Network.TLS.Handshake.Process
import Network.TLS.Measurement
import Network.TLS.Types

View file

@ -15,7 +15,7 @@ module Network.TLS.IO
) where
import Network.TLS.Context
import Network.TLS.State (needEmptyPacket, runRxState)
import Network.TLS.State (needEmptyPacket)
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
@ -68,7 +68,7 @@ recvRecord compatSSLv2 ctx
| otherwise = readExact ctx 5 >>= either (return . Left) recvLength . decodeHeader
where recvLength header@(Header _ _ readlen)
| readlen > 16384 + 2048 = return $ Left maximumSizeExceeded
| otherwise = readExact ctx (fromIntegral readlen) >>= makeRecord header
| otherwise = readExact ctx (fromIntegral readlen) >>= getRecord header
#ifdef SSLV2_COMPATIBLE
recvDeprecatedLength readlen
| readlen > 1024 * 4 = return $ Left maximumSizeExceeded
@ -76,12 +76,13 @@ recvRecord compatSSLv2 ctx
content <- readExact ctx (fromIntegral readlen)
case decodeDeprecatedHeader readlen content of
Left err -> return $ Left err
Right header -> makeRecord header content
Right header -> getRecord header content
#endif
maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow)
makeRecord header content = do
getRecord :: MonadIO m => Header -> Bytes -> m (Either TLSError (Record Plaintext))
getRecord header content = do
liftIO $ (loggingIORecv $ ctxLogging ctx) header content
usingState ctx $ runRxState $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
runRxState ctx $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
-- | receive one packet from the context that contains 1 or
@ -94,7 +95,7 @@ recvPacket ctx = do
case erecord of
Left err -> return $ Left err
Right record -> do
pktRecv <- usingState ctx $ processPacket record
pktRecv <- processPacket ctx record
pkt <- case pktRecv of
Right (Handshake hss) ->
ctxWithHooks ctx $ \hooks ->

View file

@ -15,29 +15,33 @@ module Network.TLS.Receiving
import Control.Applicative ((<$>))
import Control.Monad.State
import Control.Monad.Error
import Control.Concurrent.MVar
import Network.TLS.Context
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
import Network.TLS.State
import Network.TLS.Cipher
import Network.TLS.Util
returnEither :: Either TLSError a -> TLSSt a
returnEither (Left err) = throwError err
returnEither (Right a) = return a
processPacket :: Record Plaintext -> TLSSt Packet
processPacket :: MonadIO m => Context -> Record Plaintext -> m (Either TLSError Packet)
processPacket (Record ProtocolType_AppData _ fragment) = return $ AppData $ fragmentGetBytes fragment
processPacket _ (Record ProtocolType_AppData _ fragment) = return $ Right $ AppData $ fragmentGetBytes fragment
processPacket (Record ProtocolType_Alert _ fragment) = return . Alert =<< returnEither (decodeAlerts $ fragmentGetBytes fragment)
processPacket _ (Record ProtocolType_Alert _ fragment) = return (Alert `fmapEither` (decodeAlerts $ fragmentGetBytes fragment))
processPacket (Record ProtocolType_ChangeCipherSpec _ fragment) = do
returnEither $ decodeChangeCipherSpec $ fragmentGetBytes fragment
switchRxEncryption
return ChangeCipherSpec
processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
case decodeChangeCipherSpec $ fragmentGetBytes fragment of
Left err -> return $ Left err
Right _ -> do switchRxEncryption ctx
return $ Right ChangeCipherSpec
processPacket (Record ProtocolType_Handshake ver fragment) = do
processPacket ctx (Record ProtocolType_Handshake ver fragment) = usingState ctx $ do
keyxchg <- gets (\st -> case stHandshake st of
Nothing -> Nothing
Just hst -> cipherKeyExchange <$> hstPendingCipher hst)
@ -54,7 +58,12 @@ processPacket (Record ProtocolType_Handshake ver fragment) = do
Right hs -> return hs
return $ Handshake hss
processPacket (Record ProtocolType_DeprecatedHandshake _ fragment) =
processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
case decodeDeprecatedHandshake $ fragmentGetBytes fragment of
Left err -> throwError err
Right hs -> return $ Handshake [hs]
Left err -> return $ Left err
Right hs -> return $ Right $ Handshake [hs]
switchRxEncryption :: MonadIO m => Context -> m ()
switchRxEncryption ctx =
usingHState ctx (gets hstPendingRxState) >>= \rx ->
liftIO $ modifyMVar_ (ctxRxState ctx) (\_ -> return $ fromJust "rx-state" rx)

View file

@ -18,7 +18,6 @@ module Network.TLS.State
, newTLSState
, withTLSRNG
, runTxState
, runRxState
, assert -- FIXME move somewhere else (Internal.hs ?)
, updateVerifiedData
, finishHandshakeTypeMaterial
@ -44,7 +43,6 @@ module Network.TLS.State
, isSessionResuming
, needEmptyPacket
, switchTxEncryption
, switchRxEncryption
, isClientContext
, startHandshakeClient
, getHandshakeDigest
@ -79,7 +77,6 @@ data TLSState = TLSState
, stSession :: Session
, stSessionResuming :: Bool
, stTxState :: RecordState
, stRxState :: RecordState
, stSecureRenegotiation :: Bool -- RFC 5746
, stClientVerifiedData :: Bytes -- RFC 5746
, stServerVerifiedData :: Bytes -- RFC 5746
@ -115,20 +112,12 @@ runTxState f = do
Left err -> throwError err
Right (a, newSt) -> put (st { stTxState = newSt }) >> return a
runRxState :: RecordM a -> TLSSt a
runRxState f = do
st <- get
case runRecordM f (stVersion st) (stRxState st) of
Left err -> throwError err
Right (a, newSt) -> put (st { stRxState = newSt }) >> return a
newTLSState :: CPRG g => g -> Role -> TLSState
newTLSState rng clientContext = TLSState
{ stHandshake = Nothing
, stSession = Session Nothing
, stSessionResuming = False
, stTxState = newRecordState
, stRxState = newRecordState
, stSecureRenegotiation = False
, stClientVerifiedData = B.empty
, stServerVerifiedData = B.empty
@ -180,13 +169,10 @@ certVerifyHandshakeTypeMaterial HandshakeType_NPN = False
certVerifyHandshakeMaterial :: Handshake -> Bool
certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake
switchTxEncryption, switchRxEncryption :: TLSSt ()
switchTxEncryption :: TLSSt ()
switchTxEncryption =
withHandshakeM (gets hstPendingTxState)
>>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }
switchRxEncryption =
withHandshakeM (gets hstPendingRxState)
>>= \newRxState -> modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState }
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)