move tx state into a mvar in the context.

This commit is contained in:
Vincent Hanquez 2013-08-01 08:05:03 +01:00
parent 49ff6e933c
commit 7994f4ba27
4 changed files with 58 additions and 45 deletions

View file

@ -44,6 +44,8 @@ module Network.TLS.Context
, ctxLogging
, ctxWithHooks
, ctxRxState
, ctxTxState
, ctxNeedEmptyPacket
, setEOF
, setEstablished
, contextFlush
@ -77,6 +79,7 @@ module Network.TLS.Context
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getStateRNG
@ -312,9 +315,11 @@ data Context = Context
, ctxMeasurement :: IORef Measurement
, ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not.
, ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful.
, ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability.
, ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello.
-- the flag will be set to false regardless of its initial value
-- after the first packet received.
, ctxTxState :: MVar RecordState -- ^ current tx state
, ctxRxState :: MVar RecordState -- ^ current rx state
, ctxHooks :: IORef Hooks -- ^ hooks for this context
, ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state)
@ -393,7 +398,9 @@ contextNew backend params rng = liftIO $ do
-- we enable the reception of SSLv2 ClientHello message only in the
-- server context, where we might be dealing with an old/compat client.
sslv2Compat <- newIORef (role == ServerRole)
needEmptyPacket <- newIORef False
hooks <- newIORef defaultHooks
tx <- newMVar newRecordState
rx <- newMVar newRecordState
lockWrite <- newMVar ()
lockRead <- newMVar ()
@ -402,11 +409,13 @@ contextNew backend params rng = liftIO $ do
{ ctxConnection = backend
, ctxParams = params
, ctxState = stvar
, ctxTxState = tx
, ctxRxState = rx
, ctxMeasurement = stats
, ctxEOF_ = eof
, ctxEstablished_ = established
, ctxSSLv2ClientHello = sslv2Compat
, ctxNeedEmptyPacket = needEmptyPacket
, ctxHooks = hooks
, ctxLockWrite = lockWrite
, ctxLockRead = lockRead
@ -446,6 +455,14 @@ usingState_ ctx f = do
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
usingHState ctx f = usingState_ ctx $ withHandshakeM f
runTxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a)
runTxState ctx f = do
ver <- usingState_ ctx getVersion
liftIO $ modifyMVar (ctxTxState ctx) $ \st ->
case runRecordM f ver st of
Left err -> return (st, Left err)
Right (a, newSt) -> return (newSt, Right a)
runRxState :: MonadIO m => Context -> RecordM a -> m (Either TLSError a)
runRxState ctx f = do
ver <- usingState_ ctx getVersion

View file

@ -15,7 +15,6 @@ module Network.TLS.IO
) where
import Network.TLS.Context
import Network.TLS.State (needEmptyPacket)
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
@ -25,6 +24,7 @@ import Data.Data
import qualified Data.ByteString as B
import Data.ByteString.Char8 ()
import Data.IORef
import Control.Monad.State
import Control.Exception (throwIO, Exception())
import System.IO.Error (mkIOError, eofErrorType)
@ -113,12 +113,15 @@ sendPacket ctx pkt = do
-- in ver <= TLS1.0, block ciphers using CBC are using CBC residue as IV, which can be guessed
-- by an attacker. Hence, an empty packet is sent before a normal data packet, to
-- prevent guessability.
withEmptyPacket <- usingState_ ctx needEmptyPacket
withEmptyPacket <- liftIO $ readIORef $ ctxNeedEmptyPacket ctx
when (isNonNullAppData pkt && withEmptyPacket) $ sendPacket ctx $ AppData B.empty
liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt)
dataToSend <- usingState_ ctx $ writePacket pkt
liftIO $ (loggingIOSent $ ctxLogging ctx) dataToSend
liftIO $ contextSend ctx dataToSend
edataToSend <- liftIO (writePacket ctx pkt)
case edataToSend of
Left err -> throwCore err
Right dataToSend -> do
liftIO $ (loggingIOSent $ ctxLogging ctx) dataToSend
liftIO $ contextSend ctx dataToSend
where isNonNullAppData (AppData b) = not $ B.null b
isNonNullAppData _ = False

View file

@ -10,7 +10,10 @@
--
module Network.TLS.Sending (writePacket) where
import Control.Applicative
import Control.Monad.State
import Control.Concurrent.MVar
import Data.IORef
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
@ -20,9 +23,11 @@ import Network.TLS.Cap
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
import Network.TLS.Context
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Cipher
import Network.TLS.Util
-- | 'makePacketData' create a Header and a content bytestring related to a packet
-- this doesn't change any state
@ -42,32 +47,43 @@ encodeRecord record = return $ B.concat [ encodeHeader hdr, content ]
-- | writePacket transform a packet into marshalled data related to current state
-- and updating state on the go
writePacket :: Packet -> TLSSt ByteString
writePacket pkt@(Handshake hss) = do
forM_ hss $ \hs -> do
writePacket :: Context -> Packet -> IO (Either TLSError ByteString)
writePacket ctx pkt@(Handshake hss) = do
usingState_ ctx $ forM_ hss $ \hs -> do
case hs of
Finished fdata -> updateVerifiedData ClientRole fdata
_ -> return ()
let encoded = encodeHandshake hs
when (certVerifyHandshakeMaterial hs) $ withHandshakeM $ addHandshakeMessage encoded
when (finishHandshakeTypeMaterial $ typeOfHandshake hs) $ withHandshakeM $ updateHandshakeDigest encoded
prepareRecord (makeRecord pkt >>= engageRecord >>= encodeRecord)
writePacket pkt = do
d <- prepareRecord (makeRecord pkt >>= engageRecord >>= encodeRecord)
when (pkt == ChangeCipherSpec) $ switchTxEncryption
prepareRecord ctx (makeRecord pkt >>= engageRecord >>= encodeRecord)
writePacket ctx pkt = do
d <- prepareRecord ctx (makeRecord pkt >>= engageRecord >>= encodeRecord)
when (pkt == ChangeCipherSpec) $ switchTxEncryption ctx
return d
-- before TLS 1.1, the block cipher IV is made of the residual of the previous block,
-- so we use cstIV as is, however in other case we generate an explicit IV
prepareRecord :: RecordM a -> TLSSt a
prepareRecord f = do
st <- get
ver <- getVersion
let sz = case stCipher $ stTxState st of
prepareRecord :: Context -> RecordM a -> IO (Either TLSError a)
prepareRecord ctx f = do
ver <- usingState_ ctx getVersion
txState <- readMVar $ ctxTxState ctx
let sz = case stCipher $ txState of
Nothing -> 0
Just cipher -> bulkIVSize $ cipherBulk cipher
if hasExplicitBlockIV ver && sz > 0
then do newIV <- genRandom sz
runTxState (modify $ setRecordIV newIV)
runTxState f
else runTxState f
then do newIV <- getStateRNG ctx sz
runTxState ctx (modify (setRecordIV newIV) >> f)
else runTxState ctx f
switchTxEncryption :: MonadIO m => Context -> m ()
switchTxEncryption ctx = do
tx <- usingHState ctx (fromJust "tx-state" <$> gets hstPendingTxState)
(ver, cc) <- usingState_ ctx $ do v <- getVersion
c <- isClientContext
return (v, c)
liftIO $ modifyMVar_ (ctxTxState ctx) (\_ -> return tx)
-- set empty packet counter measure if condition are met
when (ver <= TLS10 && cc == ClientRole && isCBC tx) $ liftIO $ writeIORef (ctxNeedEmptyPacket ctx) True
where isCBC tx = maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher tx)

View file

@ -17,7 +17,6 @@ module Network.TLS.State
, withHandshakeM
, newTLSState
, withTLSRNG
, runTxState
, assert -- FIXME move somewhere else (Internal.hs ?)
, updateVerifiedData
, finishHandshakeTypeMaterial
@ -41,8 +40,6 @@ module Network.TLS.State
, getSession
, getSessionData
, isSessionResuming
, needEmptyPacket
, switchTxEncryption
, isClientContext
, startHandshakeClient
, getHandshakeDigest
@ -76,7 +73,6 @@ data TLSState = TLSState
{ stHandshake :: !(Maybe HandshakeState)
, stSession :: Session
, stSessionResuming :: Bool
, stTxState :: RecordState
, stSecureRenegotiation :: Bool -- RFC 5746
, stClientVerifiedData :: Bytes -- RFC 5746
, stServerVerifiedData :: Bytes -- RFC 5746
@ -105,19 +101,11 @@ instance MonadState TLSState TLSSt where
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState f st = runState (runErrorT (runTLSSt f)) st
runTxState :: RecordM a -> TLSSt a
runTxState f = do
st <- get
case runRecordM f (stVersion st) (stTxState st) of
Left err -> throwError err
Right (a, newSt) -> put (st { stTxState = newSt }) >> return a
newTLSState :: CPRG g => g -> Role -> TLSState
newTLSState rng clientContext = TLSState
{ stHandshake = Nothing
, stSession = Session Nothing
, stSessionResuming = False
, stTxState = newRecordState
, stSecureRenegotiation = False
, stClientVerifiedData = B.empty
, stServerVerifiedData = B.empty
@ -169,17 +157,12 @@ certVerifyHandshakeTypeMaterial HandshakeType_NPN = False
certVerifyHandshakeMaterial :: Handshake -> Bool
certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake
switchTxEncryption :: TLSSt ()
switchTxEncryption =
withHandshakeM (gets hstPendingTxState)
>>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
where wrapSessionData st masterSecret = do
return $ SessionData
{ sessionVersion = stVersion st
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st
, sessionCipher = undefined -- cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st
, sessionSecret = masterSecret
}
@ -192,12 +175,6 @@ getSession = gets stSession
isSessionResuming :: MonadState TLSState m => m Bool
isSessionResuming = gets stSessionResuming
needEmptyPacket :: MonadState TLSState m => m Bool
needEmptyPacket = gets f
where f st = (stVersion st <= TLS10)
&& stClientContext st == ClientRole
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
setVersion :: MonadState TLSState m => Version -> m ()
setVersion ver = modify (\st -> st { stVersion = ver })