remove Rx state from general state.
move RxState as a mutable mvar in the context directly.
This commit is contained in:
parent
6ff5e692d0
commit
49ff6e933c
5 changed files with 42 additions and 33 deletions
|
@ -43,6 +43,7 @@ module Network.TLS.Context
|
|||
, ctxEstablished
|
||||
, ctxLogging
|
||||
, ctxWithHooks
|
||||
, ctxRxState
|
||||
, setEOF
|
||||
, setEstablished
|
||||
, contextFlush
|
||||
|
@ -76,6 +77,7 @@ module Network.TLS.Context
|
|||
, throwCore
|
||||
, usingState
|
||||
, usingState_
|
||||
, runRxState
|
||||
, usingHState
|
||||
, getStateRNG
|
||||
) where
|
||||
|
@ -90,6 +92,7 @@ import Network.TLS.Compression
|
|||
import Network.TLS.Crypto
|
||||
import Network.TLS.State
|
||||
import Network.TLS.Handshake.State
|
||||
import Network.TLS.Record.State
|
||||
import Network.TLS.Measurement
|
||||
import Network.TLS.X509
|
||||
import Network.TLS.Types (Role(..))
|
||||
|
@ -312,6 +315,7 @@ data Context = Context
|
|||
, ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello.
|
||||
-- the flag will be set to false regardless of its initial value
|
||||
-- after the first packet received.
|
||||
, ctxRxState :: MVar RecordState -- ^ current rx state
|
||||
, ctxHooks :: IORef Hooks -- ^ hooks for this context
|
||||
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
|
||||
, ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state)
|
||||
|
@ -390,6 +394,7 @@ contextNew backend params rng = liftIO $ do
|
|||
-- server context, where we might be dealing with an old/compat client.
|
||||
sslv2Compat <- newIORef (role == ServerRole)
|
||||
hooks <- newIORef defaultHooks
|
||||
rx <- newMVar newRecordState
|
||||
lockWrite <- newMVar ()
|
||||
lockRead <- newMVar ()
|
||||
lockState <- newMVar ()
|
||||
|
@ -397,6 +402,7 @@ contextNew backend params rng = liftIO $ do
|
|||
{ ctxConnection = backend
|
||||
, ctxParams = params
|
||||
, ctxState = stvar
|
||||
, ctxRxState = rx
|
||||
, ctxMeasurement = stats
|
||||
, ctxEOF_ = eof
|
||||
, ctxEstablished_ = established
|
||||
|
@ -440,6 +446,14 @@ usingState_ ctx f = do
|
|||
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
|
||||
usingHState ctx f = usingState_ ctx $ withHandshakeM f
|
||||
|
||||
runRxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a)
|
||||
runRxState ctx f = do
|
||||
ver <- usingState_ ctx getVersion
|
||||
liftIO $ modifyMVar (ctxRxState ctx) $ \st ->
|
||||
case runRecordM f ver st of
|
||||
Left err -> return (st, Left err)
|
||||
Right (a, newSt) -> return (newSt, Right a)
|
||||
|
||||
getStateRNG :: MonadIO m => Context -> Int -> m Bytes
|
||||
getStateRNG ctx n = usingState_ ctx $ genRandom n
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ import Network.TLS.Session
|
|||
import Network.TLS.Struct
|
||||
import Network.TLS.IO
|
||||
import Network.TLS.State hiding (getNegotiatedProtocol)
|
||||
import Network.TLS.Receiving
|
||||
import Network.TLS.Handshake.Process
|
||||
import Network.TLS.Measurement
|
||||
import Network.TLS.Types
|
||||
|
|
|
@ -15,7 +15,7 @@ module Network.TLS.IO
|
|||
) where
|
||||
|
||||
import Network.TLS.Context
|
||||
import Network.TLS.State (needEmptyPacket, runRxState)
|
||||
import Network.TLS.State (needEmptyPacket)
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
|
@ -68,7 +68,7 @@ recvRecord compatSSLv2 ctx
|
|||
| otherwise = readExact ctx 5 >>= either (return . Left) recvLength . decodeHeader
|
||||
where recvLength header@(Header _ _ readlen)
|
||||
| readlen > 16384 + 2048 = return $ Left maximumSizeExceeded
|
||||
| otherwise = readExact ctx (fromIntegral readlen) >>= makeRecord header
|
||||
| otherwise = readExact ctx (fromIntegral readlen) >>= getRecord header
|
||||
#ifdef SSLV2_COMPATIBLE
|
||||
recvDeprecatedLength readlen
|
||||
| readlen > 1024 * 4 = return $ Left maximumSizeExceeded
|
||||
|
@ -76,12 +76,13 @@ recvRecord compatSSLv2 ctx
|
|||
content <- readExact ctx (fromIntegral readlen)
|
||||
case decodeDeprecatedHeader readlen content of
|
||||
Left err -> return $ Left err
|
||||
Right header -> makeRecord header content
|
||||
Right header -> getRecord header content
|
||||
#endif
|
||||
maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow)
|
||||
makeRecord header content = do
|
||||
getRecord :: MonadIO m => Header -> Bytes -> m (Either TLSError (Record Plaintext))
|
||||
getRecord header content = do
|
||||
liftIO $ (loggingIORecv $ ctxLogging ctx) header content
|
||||
usingState ctx $ runRxState $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
|
||||
runRxState ctx $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
|
||||
|
||||
|
||||
-- | receive one packet from the context that contains 1 or
|
||||
|
@ -94,7 +95,7 @@ recvPacket ctx = do
|
|||
case erecord of
|
||||
Left err -> return $ Left err
|
||||
Right record -> do
|
||||
pktRecv <- usingState ctx $ processPacket record
|
||||
pktRecv <- processPacket ctx record
|
||||
pkt <- case pktRecv of
|
||||
Right (Handshake hss) ->
|
||||
ctxWithHooks ctx $ \hooks ->
|
||||
|
|
|
@ -15,29 +15,33 @@ module Network.TLS.Receiving
|
|||
import Control.Applicative ((<$>))
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Error
|
||||
import Control.Concurrent.MVar
|
||||
|
||||
import Network.TLS.Context
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
import Network.TLS.State
|
||||
import Network.TLS.Cipher
|
||||
import Network.TLS.Util
|
||||
|
||||
returnEither :: Either TLSError a -> TLSSt a
|
||||
returnEither (Left err) = throwError err
|
||||
returnEither (Right a) = return a
|
||||
|
||||
processPacket :: Record Plaintext -> TLSSt Packet
|
||||
processPacket :: MonadIO m => Context -> Record Plaintext -> m (Either TLSError Packet)
|
||||
|
||||
processPacket (Record ProtocolType_AppData _ fragment) = return $ AppData $ fragmentGetBytes fragment
|
||||
processPacket _ (Record ProtocolType_AppData _ fragment) = return $ Right $ AppData $ fragmentGetBytes fragment
|
||||
|
||||
processPacket (Record ProtocolType_Alert _ fragment) = return . Alert =<< returnEither (decodeAlerts $ fragmentGetBytes fragment)
|
||||
processPacket _ (Record ProtocolType_Alert _ fragment) = return (Alert `fmapEither` (decodeAlerts $ fragmentGetBytes fragment))
|
||||
|
||||
processPacket (Record ProtocolType_ChangeCipherSpec _ fragment) = do
|
||||
returnEither $ decodeChangeCipherSpec $ fragmentGetBytes fragment
|
||||
switchRxEncryption
|
||||
return ChangeCipherSpec
|
||||
processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
|
||||
case decodeChangeCipherSpec $ fragmentGetBytes fragment of
|
||||
Left err -> return $ Left err
|
||||
Right _ -> do switchRxEncryption ctx
|
||||
return $ Right ChangeCipherSpec
|
||||
|
||||
processPacket (Record ProtocolType_Handshake ver fragment) = do
|
||||
processPacket ctx (Record ProtocolType_Handshake ver fragment) = usingState ctx $ do
|
||||
keyxchg <- gets (\st -> case stHandshake st of
|
||||
Nothing -> Nothing
|
||||
Just hst -> cipherKeyExchange <$> hstPendingCipher hst)
|
||||
|
@ -54,7 +58,12 @@ processPacket (Record ProtocolType_Handshake ver fragment) = do
|
|||
Right hs -> return hs
|
||||
return $ Handshake hss
|
||||
|
||||
processPacket (Record ProtocolType_DeprecatedHandshake _ fragment) =
|
||||
processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
|
||||
case decodeDeprecatedHandshake $ fragmentGetBytes fragment of
|
||||
Left err -> throwError err
|
||||
Right hs -> return $ Handshake [hs]
|
||||
Left err -> return $ Left err
|
||||
Right hs -> return $ Right $ Handshake [hs]
|
||||
|
||||
switchRxEncryption :: MonadIO m => Context -> m ()
|
||||
switchRxEncryption ctx =
|
||||
usingHState ctx (gets hstPendingRxState) >>= \rx ->
|
||||
liftIO $ modifyMVar_ (ctxRxState ctx) (\_ -> return $ fromJust "rx-state" rx)
|
||||
|
|
|
@ -18,7 +18,6 @@ module Network.TLS.State
|
|||
, newTLSState
|
||||
, withTLSRNG
|
||||
, runTxState
|
||||
, runRxState
|
||||
, assert -- FIXME move somewhere else (Internal.hs ?)
|
||||
, updateVerifiedData
|
||||
, finishHandshakeTypeMaterial
|
||||
|
@ -44,7 +43,6 @@ module Network.TLS.State
|
|||
, isSessionResuming
|
||||
, needEmptyPacket
|
||||
, switchTxEncryption
|
||||
, switchRxEncryption
|
||||
, isClientContext
|
||||
, startHandshakeClient
|
||||
, getHandshakeDigest
|
||||
|
@ -79,7 +77,6 @@ data TLSState = TLSState
|
|||
, stSession :: Session
|
||||
, stSessionResuming :: Bool
|
||||
, stTxState :: RecordState
|
||||
, stRxState :: RecordState
|
||||
, stSecureRenegotiation :: Bool -- RFC 5746
|
||||
, stClientVerifiedData :: Bytes -- RFC 5746
|
||||
, stServerVerifiedData :: Bytes -- RFC 5746
|
||||
|
@ -115,20 +112,12 @@ runTxState f = do
|
|||
Left err -> throwError err
|
||||
Right (a, newSt) -> put (st { stTxState = newSt }) >> return a
|
||||
|
||||
runRxState :: RecordM a -> TLSSt a
|
||||
runRxState f = do
|
||||
st <- get
|
||||
case runRecordM f (stVersion st) (stRxState st) of
|
||||
Left err -> throwError err
|
||||
Right (a, newSt) -> put (st { stRxState = newSt }) >> return a
|
||||
|
||||
newTLSState :: CPRG g => g -> Role -> TLSState
|
||||
newTLSState rng clientContext = TLSState
|
||||
{ stHandshake = Nothing
|
||||
, stSession = Session Nothing
|
||||
, stSessionResuming = False
|
||||
, stTxState = newRecordState
|
||||
, stRxState = newRecordState
|
||||
, stSecureRenegotiation = False
|
||||
, stClientVerifiedData = B.empty
|
||||
, stServerVerifiedData = B.empty
|
||||
|
@ -180,13 +169,10 @@ certVerifyHandshakeTypeMaterial HandshakeType_NPN = False
|
|||
certVerifyHandshakeMaterial :: Handshake -> Bool
|
||||
certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake
|
||||
|
||||
switchTxEncryption, switchRxEncryption :: TLSSt ()
|
||||
switchTxEncryption :: TLSSt ()
|
||||
switchTxEncryption =
|
||||
withHandshakeM (gets hstPendingTxState)
|
||||
>>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }
|
||||
switchRxEncryption =
|
||||
withHandshakeM (gets hstPendingRxState)
|
||||
>>= \newRxState -> modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState }
|
||||
|
||||
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
||||
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
||||
|
|
Loading…
Reference in a new issue