move tx state into a mvar in the context.
This commit is contained in:
parent
49ff6e933c
commit
7994f4ba27
4 changed files with 58 additions and 45 deletions
|
@ -44,6 +44,8 @@ module Network.TLS.Context
|
|||
, ctxLogging
|
||||
, ctxWithHooks
|
||||
, ctxRxState
|
||||
, ctxTxState
|
||||
, ctxNeedEmptyPacket
|
||||
, setEOF
|
||||
, setEstablished
|
||||
, contextFlush
|
||||
|
@ -77,6 +79,7 @@ module Network.TLS.Context
|
|||
, throwCore
|
||||
, usingState
|
||||
, usingState_
|
||||
, runTxState
|
||||
, runRxState
|
||||
, usingHState
|
||||
, getStateRNG
|
||||
|
@ -312,9 +315,11 @@ data Context = Context
|
|||
, ctxMeasurement :: IORef Measurement
|
||||
, ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not.
|
||||
, ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful.
|
||||
, ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability.
|
||||
, ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello.
|
||||
-- the flag will be set to false regardless of its initial value
|
||||
-- after the first packet received.
|
||||
, ctxTxState :: MVar RecordState -- ^ current tx state
|
||||
, ctxRxState :: MVar RecordState -- ^ current rx state
|
||||
, ctxHooks :: IORef Hooks -- ^ hooks for this context
|
||||
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
|
||||
|
@ -393,7 +398,9 @@ contextNew backend params rng = liftIO $ do
|
|||
-- we enable the reception of SSLv2 ClientHello message only in the
|
||||
-- server context, where we might be dealing with an old/compat client.
|
||||
sslv2Compat <- newIORef (role == ServerRole)
|
||||
needEmptyPacket <- newIORef False
|
||||
hooks <- newIORef defaultHooks
|
||||
tx <- newMVar newRecordState
|
||||
rx <- newMVar newRecordState
|
||||
lockWrite <- newMVar ()
|
||||
lockRead <- newMVar ()
|
||||
|
@ -402,11 +409,13 @@ contextNew backend params rng = liftIO $ do
|
|||
{ ctxConnection = backend
|
||||
, ctxParams = params
|
||||
, ctxState = stvar
|
||||
, ctxTxState = tx
|
||||
, ctxRxState = rx
|
||||
, ctxMeasurement = stats
|
||||
, ctxEOF_ = eof
|
||||
, ctxEstablished_ = established
|
||||
, ctxSSLv2ClientHello = sslv2Compat
|
||||
, ctxNeedEmptyPacket = needEmptyPacket
|
||||
, ctxHooks = hooks
|
||||
, ctxLockWrite = lockWrite
|
||||
, ctxLockRead = lockRead
|
||||
|
@ -446,6 +455,14 @@ usingState_ ctx f = do
|
|||
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
|
||||
usingHState ctx f = usingState_ ctx $ withHandshakeM f
|
||||
|
||||
runTxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a)
|
||||
runTxState ctx f = do
|
||||
ver <- usingState_ ctx getVersion
|
||||
liftIO $ modifyMVar (ctxTxState ctx) $ \st ->
|
||||
case runRecordM f ver st of
|
||||
Left err -> return (st, Left err)
|
||||
Right (a, newSt) -> return (newSt, Right a)
|
||||
|
||||
runRxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a)
|
||||
runRxState ctx f = do
|
||||
ver <- usingState_ ctx getVersion
|
||||
|
|
|
@ -15,7 +15,6 @@ module Network.TLS.IO
|
|||
) where
|
||||
|
||||
import Network.TLS.Context
|
||||
import Network.TLS.State (needEmptyPacket)
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
|
@ -25,6 +24,7 @@ import Data.Data
|
|||
import qualified Data.ByteString as B
|
||||
import Data.ByteString.Char8 ()
|
||||
|
||||
import Data.IORef
|
||||
import Control.Monad.State
|
||||
import Control.Exception (throwIO, Exception())
|
||||
import System.IO.Error (mkIOError, eofErrorType)
|
||||
|
@ -113,12 +113,15 @@ sendPacket ctx pkt = do
|
|||
-- in ver <= TLS1.0, block ciphers using CBC are using CBC residue as IV, which can be guessed
|
||||
-- by an attacker. Hence, an empty packet is sent before a normal data packet, to
|
||||
-- prevent guessability.
|
||||
withEmptyPacket <- usingState_ ctx needEmptyPacket
|
||||
withEmptyPacket <- liftIO $ readIORef $ ctxNeedEmptyPacket ctx
|
||||
when (isNonNullAppData pkt && withEmptyPacket) $ sendPacket ctx $ AppData B.empty
|
||||
|
||||
liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt)
|
||||
dataToSend <- usingState_ ctx $ writePacket pkt
|
||||
liftIO $ (loggingIOSent $ ctxLogging ctx) dataToSend
|
||||
liftIO $ contextSend ctx dataToSend
|
||||
edataToSend <- liftIO (writePacket ctx pkt)
|
||||
case edataToSend of
|
||||
Left err -> throwCore err
|
||||
Right dataToSend -> do
|
||||
liftIO $ (loggingIOSent $ ctxLogging ctx) dataToSend
|
||||
liftIO $ contextSend ctx dataToSend
|
||||
where isNonNullAppData (AppData b) = not $ B.null b
|
||||
isNonNullAppData _ = False
|
||||
|
|
|
@ -10,7 +10,10 @@
|
|||
--
|
||||
module Network.TLS.Sending (writePacket) where
|
||||
|
||||
import Control.Applicative
|
||||
import Control.Monad.State
|
||||
import Control.Concurrent.MVar
|
||||
import Data.IORef
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString as B
|
||||
|
@ -20,9 +23,11 @@ import Network.TLS.Cap
|
|||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
import Network.TLS.Context
|
||||
import Network.TLS.State
|
||||
import Network.TLS.Handshake.State
|
||||
import Network.TLS.Cipher
|
||||
import Network.TLS.Util
|
||||
|
||||
-- | 'makePacketData' create a Header and a content bytestring related to a packet
|
||||
-- this doesn't change any state
|
||||
|
@ -42,32 +47,43 @@ encodeRecord record = return $ B.concat [ encodeHeader hdr, content ]
|
|||
|
||||
-- | writePacket transform a packet into marshalled data related to current state
|
||||
-- and updating state on the go
|
||||
writePacket :: Packet -> TLSSt ByteString
|
||||
writePacket pkt@(Handshake hss) = do
|
||||
forM_ hss $ \hs -> do
|
||||
writePacket :: Context -> Packet -> IO (Either TLSError ByteString)
|
||||
writePacket ctx pkt@(Handshake hss) = do
|
||||
usingState_ ctx $ forM_ hss $ \hs -> do
|
||||
case hs of
|
||||
Finished fdata -> updateVerifiedData ClientRole fdata
|
||||
_ -> return ()
|
||||
let encoded = encodeHandshake hs
|
||||
when (certVerifyHandshakeMaterial hs) $ withHandshakeM $ addHandshakeMessage encoded
|
||||
when (finishHandshakeTypeMaterial $ typeOfHandshake hs) $ withHandshakeM $ updateHandshakeDigest encoded
|
||||
prepareRecord (makeRecord pkt >>= engageRecord >>= encodeRecord)
|
||||
writePacket pkt = do
|
||||
d <- prepareRecord (makeRecord pkt >>= engageRecord >>= encodeRecord)
|
||||
when (pkt == ChangeCipherSpec) $ switchTxEncryption
|
||||
prepareRecord ctx (makeRecord pkt >>= engageRecord >>= encodeRecord)
|
||||
writePacket ctx pkt = do
|
||||
d <- prepareRecord ctx (makeRecord pkt >>= engageRecord >>= encodeRecord)
|
||||
when (pkt == ChangeCipherSpec) $ switchTxEncryption ctx
|
||||
return d
|
||||
|
||||
-- before TLS 1.1, the block cipher IV is made of the residual of the previous block,
|
||||
-- so we use cstIV as is, however in other case we generate an explicit IV
|
||||
prepareRecord :: RecordM a -> TLSSt a
|
||||
prepareRecord f = do
|
||||
st <- get
|
||||
ver <- getVersion
|
||||
let sz = case stCipher $ stTxState st of
|
||||
prepareRecord :: Context -> RecordM a -> IO (Either TLSError a)
|
||||
prepareRecord ctx f = do
|
||||
ver <- usingState_ ctx getVersion
|
||||
txState <- readMVar $ ctxTxState ctx
|
||||
let sz = case stCipher $ txState of
|
||||
Nothing -> 0
|
||||
Just cipher -> bulkIVSize $ cipherBulk cipher
|
||||
if hasExplicitBlockIV ver && sz > 0
|
||||
then do newIV <- genRandom sz
|
||||
runTxState (modify $ setRecordIV newIV)
|
||||
runTxState f
|
||||
else runTxState f
|
||||
then do newIV <- getStateRNG ctx sz
|
||||
runTxState ctx (modify (setRecordIV newIV) >> f)
|
||||
else runTxState ctx f
|
||||
|
||||
switchTxEncryption :: MonadIO m => Context -> m ()
|
||||
switchTxEncryption ctx = do
|
||||
tx <- usingHState ctx (fromJust "tx-state" <$> gets hstPendingTxState)
|
||||
(ver, cc) <- usingState_ ctx $ do v <- getVersion
|
||||
c <- isClientContext
|
||||
return (v, c)
|
||||
liftIO $ modifyMVar_ (ctxTxState ctx) (\_ -> return tx)
|
||||
-- set empty packet counter measure if condition are met
|
||||
when (ver <= TLS10 && cc == ClientRole && isCBC tx) $ liftIO $ writeIORef (ctxNeedEmptyPacket ctx) True
|
||||
where isCBC tx = maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher tx)
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ module Network.TLS.State
|
|||
, withHandshakeM
|
||||
, newTLSState
|
||||
, withTLSRNG
|
||||
, runTxState
|
||||
, assert -- FIXME move somewhere else (Internal.hs ?)
|
||||
, updateVerifiedData
|
||||
, finishHandshakeTypeMaterial
|
||||
|
@ -41,8 +40,6 @@ module Network.TLS.State
|
|||
, getSession
|
||||
, getSessionData
|
||||
, isSessionResuming
|
||||
, needEmptyPacket
|
||||
, switchTxEncryption
|
||||
, isClientContext
|
||||
, startHandshakeClient
|
||||
, getHandshakeDigest
|
||||
|
@ -76,7 +73,6 @@ data TLSState = TLSState
|
|||
{ stHandshake :: !(Maybe HandshakeState)
|
||||
, stSession :: Session
|
||||
, stSessionResuming :: Bool
|
||||
, stTxState :: RecordState
|
||||
, stSecureRenegotiation :: Bool -- RFC 5746
|
||||
, stClientVerifiedData :: Bytes -- RFC 5746
|
||||
, stServerVerifiedData :: Bytes -- RFC 5746
|
||||
|
@ -105,19 +101,11 @@ instance MonadState TLSState TLSSt where
|
|||
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
|
||||
runTLSState f st = runState (runErrorT (runTLSSt f)) st
|
||||
|
||||
runTxState :: RecordM a -> TLSSt a
|
||||
runTxState f = do
|
||||
st <- get
|
||||
case runRecordM f (stVersion st) (stTxState st) of
|
||||
Left err -> throwError err
|
||||
Right (a, newSt) -> put (st { stTxState = newSt }) >> return a
|
||||
|
||||
newTLSState :: CPRG g => g -> Role -> TLSState
|
||||
newTLSState rng clientContext = TLSState
|
||||
{ stHandshake = Nothing
|
||||
, stSession = Session Nothing
|
||||
, stSessionResuming = False
|
||||
, stTxState = newRecordState
|
||||
, stSecureRenegotiation = False
|
||||
, stClientVerifiedData = B.empty
|
||||
, stServerVerifiedData = B.empty
|
||||
|
@ -169,17 +157,12 @@ certVerifyHandshakeTypeMaterial HandshakeType_NPN = False
|
|||
certVerifyHandshakeMaterial :: Handshake -> Bool
|
||||
certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake
|
||||
|
||||
switchTxEncryption :: TLSSt ()
|
||||
switchTxEncryption =
|
||||
withHandshakeM (gets hstPendingTxState)
|
||||
>>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }
|
||||
|
||||
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
||||
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
||||
where wrapSessionData st masterSecret = do
|
||||
return $ SessionData
|
||||
{ sessionVersion = stVersion st
|
||||
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st
|
||||
, sessionCipher = undefined -- cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st
|
||||
, sessionSecret = masterSecret
|
||||
}
|
||||
|
||||
|
@ -192,12 +175,6 @@ getSession = gets stSession
|
|||
isSessionResuming :: MonadState TLSState m => m Bool
|
||||
isSessionResuming = gets stSessionResuming
|
||||
|
||||
needEmptyPacket :: MonadState TLSState m => m Bool
|
||||
needEmptyPacket = gets f
|
||||
where f st = (stVersion st <= TLS10)
|
||||
&& stClientContext st == ClientRole
|
||||
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
|
||||
|
||||
setVersion :: MonadState TLSState m => Version -> m ()
|
||||
setVersion ver = modify (\st -> st { stVersion = ver })
|
||||
|
||||
|
|
Loading…
Reference in a new issue