Separate tx/rx state from a single RecordState

unroll a reader/state/error monad into a single simple monad,
and move back version and client context in state.
This commit is contained in:
Vincent Hanquez 2013-07-25 21:53:32 +01:00
parent e3b3483560
commit e2d5170af7
11 changed files with 110 additions and 128 deletions

View file

@ -218,7 +218,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
-- 4. Send it to the server.
--
sendCertificateVerify = do
usedVersion <- usingState_ ctx $ stVersion . stRecordState <$> get
usedVersion <- usingState_ ctx getVersion
-- Only send a certificate verify message when we
-- have sent a non-empty list of certificates.

View file

@ -183,7 +183,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
-- certificates.
--
when (serverWantClientCert sparams) $ do
usedVersion <- usingState_ ctx $ getRecordState stVersion
usedVersion <- usingState_ ctx getVersion
let certTypes = [ CertificateType_RSA_Sign ]
hashSigs = if usedVersion < TLS12
then Nothing
@ -247,7 +247,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
-- Fetch all handshake messages up to now.
msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages
usedVersion <- usingState_ ctx $ getRecordState stVersion
usedVersion <- usingState_ ctx getVersion
(signature, hsh) <- case usedVersion of
SSL3 -> do

View file

@ -15,7 +15,7 @@ module Network.TLS.IO
) where
import Network.TLS.Context
import Network.TLS.State (needEmptyPacket, runRecordStateSt)
import Network.TLS.State (needEmptyPacket, runRxState)
import Network.TLS.Struct
import Network.TLS.Record
import Network.TLS.Packet
@ -78,7 +78,7 @@ recvRecord compatSSLv2 ctx
maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow)
makeRecord header content = do
liftIO $ (loggingIORecv $ ctxLogging ctx) header content
usingState ctx $ runRecordStateSt $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
usingState ctx $ runRxState $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
-- | receive one packet from the context that contains 1 or
@ -109,7 +109,7 @@ sendPacket ctx pkt = do
-- in ver <= TLS1.0, block ciphers using CBC are using CBC residue as IV, which can be guessed
-- by an attacker. Hence, an empty packet is sent before a normal data packet, to
-- prevent guessability.
withEmptyPacket <- usingState_ ctx (runRecordStateSt needEmptyPacket)
withEmptyPacket <- usingState_ ctx needEmptyPacket
when (isNonNullAppData pkt && withEmptyPacket) $ sendPacket ctx $ AppData B.empty
liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt)

View file

@ -99,7 +99,7 @@ processHandshake hs = do
decryptRSA :: ByteString -> TLSSt (Either KxError ByteString)
decryptRSA econtent = do
ver <- getRecordState stVersion
ver <- getVersion
rsapriv <- fromJust "rsa private key" . hstRSAPrivateKey . fromJust "handshake" . stHandshake <$> get
let cipher = if ver < TLS10 then econtent else B.drop 2 econtent
st <- get

View file

@ -28,9 +28,7 @@ module Network.TLS.Record
, engageRecord
, disengageRecord
-- * State tracking
, RecordState(..)
, RecordM
, newRecordState
) where
import Network.TLS.Record.Types

View file

@ -30,14 +30,14 @@ disengageRecord = decryptRecord >=> uncompressRecord
uncompressRecord :: Record Compressed -> RecordM (Record Plaintext)
uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes ->
withRxCompression $ compressionInflate bytes
withCompression $ compressionInflate bytes
decryptRecord :: Record Ciphertext -> RecordM (Record Compressed)
decryptRecord record = onRecordFragment record $ fragmentUncipher $ \e -> do
st <- get
case stCipher $ stRxState st of
case stCipher st of
Nothing -> return e
_ -> decryptData record e st
_ -> getRecordVersion >>= \ver -> decryptData ver record e st
getCipherData :: Record a -> CipherData -> RecordM ByteString
getCipherData (Record pt ver _) cdata = do
@ -46,14 +46,14 @@ getCipherData (Record pt ver _) cdata = do
Nothing -> return True
Just digest -> do
let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata)
expected_digest <- makeDigest False new_hdr $ cipherDataContent cdata
expected_digest <- makeDigest new_hdr $ cipherDataContent cdata
return (expected_digest `bytesEq` digest)
-- check if the padding is filled with the correct pattern if it exists
paddingValid <- case cipherDataPadding cdata of
Nothing -> return True
Just pad -> do
cver <- gets stVersion
cver <- getRecordVersion
let b = B.length pad - 1
return (if cver < TLS10 then True else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad)
@ -62,10 +62,9 @@ getCipherData (Record pt ver _) cdata = do
return $ cipherDataContent cdata
decryptData :: Record Ciphertext -> Bytes -> RecordState -> RecordM Bytes
decryptData record econtent st = decryptOf (bulkF bulk)
where tst = stRxState st
cipher = fromJust "cipher" $ stCipher tst
decryptData :: Version -> Record Ciphertext -> Bytes -> TransmissionState -> RecordM Bytes
decryptData ver record econtent tst = decryptOf (bulkF bulk)
where cipher = fromJust "cipher" $ stCipher tst
bulk = cipherBulk cipher
cst = stCryptState tst
macSize = hashSize $ cipherHash cipher
@ -73,7 +72,7 @@ decryptData record econtent st = decryptOf (bulkF bulk)
blockSize = bulkBlockSize bulk
econtentLen = B.length econtent
explicitIV = hasExplicitBlockIV $ stVersion st
explicitIV = hasExplicitBlockIV ver
sanityCheckError = throwError (Error_Packet "encrypted content too small for encryption parameters")
@ -86,7 +85,7 @@ decryptData record econtent st = decryptOf (bulkF bulk)
then get2 econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk)
else return (cstIV cst, econtent)
let newiv = fromJust "new iv" $ takelast (bulkBlockSize bulk) econtent'
modifyRxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
modify $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
let content' = decryptF writekey iv econtent'
let paddinglength = fromIntegral (B.last content') + 1
@ -104,7 +103,7 @@ decryptData record econtent st = decryptOf (bulkF bulk)
{- update Ctx -}
let contentlen = B.length content' - macSize
(content, mac) <- get2 content' (contentlen, macSize)
modifyRxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
modify $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
getCipherData record $ CipherData
{ cipherDataContent = content
, cipherDataMAC = Just mac

View file

@ -29,7 +29,7 @@ engageRecord = compressRecord >=> encryptRecord
compressRecord :: Record Plaintext -> RecordM (Record Compressed)
compressRecord record =
onRecordFragment record $ fragmentCompress $ \bytes -> do
withTxCompression $ compressionDeflate bytes
withCompression $ compressionDeflate bytes
{-
- when Tx Encrypted is set, we pass the data through encryptContent, otherwise
@ -38,20 +38,20 @@ compressRecord record =
encryptRecord :: Record Compressed -> RecordM (Record Ciphertext)
encryptRecord record = onRecordFragment record $ fragmentCipher $ \bytes -> do
st <- get
case stCipher $ stTxState st of
case stCipher st of
Nothing -> return bytes
_ -> encryptContent record bytes
encryptContent :: Record Compressed -> ByteString -> RecordM ByteString
encryptContent record content = do
digest <- makeDigest True (recordToHeader record) content
digest <- makeDigest (recordToHeader record) content
encryptData $ B.concat [content, digest]
encryptData :: ByteString -> RecordM ByteString
encryptData content = do
st <- get
tstate <- get
ver <- getRecordVersion
let tstate = stTxState st
let cipher = fromJust "cipher" $ stCipher tstate
let bulk = cipherBulk cipher
let cst = stCryptState tstate
@ -71,14 +71,14 @@ encryptData content = do
B.empty
let e = encrypt writekey (cstIV cst) (B.concat [ content, padding ])
if hasExplicitBlockIV $ stVersion st
if hasExplicitBlockIV ver
then return $ B.concat [cstIV cst,e]
else do
let newiv = fromJust "new iv" $ takelast (bulkIVSize bulk) e
modifyTxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
put $ tstate { stCryptState = cst { cstIV = newiv } }
return e
BulkStreamF initF encryptF _ -> do
let iv = cstIV cst
let (e, newiv) = encryptF (if iv /= B.empty then iv else initF writekey) content
modifyTxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
put $ tstate { stCryptState = cst { cstIV = newiv } }
return e

View file

@ -12,15 +12,11 @@ module Network.TLS.Record.State
( CryptState(..)
, MacState(..)
, TransmissionState(..)
, RecordState(..)
, newRecordState
, RecordM(..)
, withTxCompression
, withRxCompression
, modifyTxState
, modifyRxState
, modifyTxState_
, modifyRxState_
, newTransmissionState
, RecordM
, runRecordM
, getRecordVersion
, withCompression
, computeDigest
, makeDigest
) where
@ -36,7 +32,7 @@ import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Types (Role(..))
import Network.TLS.Types (Direction(..))
import qualified Data.ByteString as B
@ -57,26 +53,40 @@ data TransmissionState = TransmissionState
, stMacState :: !MacState
} deriving (Show)
data RecordState = RecordState
{ stClientContext :: Role
, stVersion :: !Version
, stTxState :: TransmissionState
, stRxState :: TransmissionState
} deriving (Show)
newtype RecordM a = RecordM { runRecordM :: Version
-> TransmissionState
-> Either TLSError (a, TransmissionState) }
newtype RecordM a = RecordM { runRecordM :: ErrorT TLSError (State RecordState) a }
deriving (Monad, MonadError TLSError)
instance Monad RecordM where
return a = RecordM $ \_ st -> Right (a, st)
m1 >>= m2 = RecordM $ \ver st -> do
case runRecordM m1 ver st of
Left err -> Left err
Right (a, st2) -> runRecordM (m2 a) ver st2
instance Functor RecordM where
fmap f = RecordM . fmap f . runRecordM
fmap f m = RecordM $ \ver st ->
case runRecordM m ver st of
Left err -> Left err
Right (a, st2) -> Right (f a, st2)
instance MonadState RecordState RecordM where
put x = RecordM (lift $ put x)
get = RecordM (lift get)
getRecordVersion :: RecordM Version
getRecordVersion = RecordM $ \ver st -> Right (ver, st)
instance MonadState TransmissionState RecordM where
put x = RecordM $ \_ _ -> Right ((), x)
get = RecordM $ \_ st -> Right (st, st)
#if MIN_VERSION_mtl(2,1,0)
state f = RecordM (lift $ state f)
state f = RecordM $ \_ st -> Right (f st)
#endif
instance MonadError TLSError RecordM where
throwError e = RecordM $ \_ _ -> Left e
catchError m f = RecordM $ \ver st ->
case runRecordM m ver st of
Left err -> runRecordM (f err) ver st
r -> r
newTransmissionState :: TransmissionState
newTransmissionState = TransmissionState
{ stCipher = Nothing
@ -89,39 +99,12 @@ incrTransmissionState :: TransmissionState -> TransmissionState
incrTransmissionState ts = ts { stMacState = MacState (ms + 1) }
where (MacState ms) = stMacState ts
newRecordState :: Role -> RecordState
newRecordState clientContext = RecordState
{ stClientContext = clientContext
, stVersion = TLS10
, stTxState = newTransmissionState
, stRxState = newTransmissionState
}
modifyTxState :: (TransmissionState -> (TransmissionState, a)) -> RecordM a
modifyTxState f =
get >>= \st -> case f $ stTxState st of
(nst, a) -> put (st { stTxState = nst }) >> return a
modifyTxState_ :: (TransmissionState -> TransmissionState) -> RecordM ()
modifyTxState_ f = modifyTxState (\t -> (f t, ()))
modifyRxState :: (TransmissionState -> (TransmissionState, a)) -> RecordM a
modifyRxState f =
get >>= \st -> case f $ stRxState st of
(nst, a) -> put (st { stRxState = nst }) >> return a
modifyRxState_ :: (TransmissionState -> TransmissionState) -> RecordM ()
modifyRxState_ f = modifyRxState (\t -> (f t, ()))
modifyCompression :: TransmissionState -> (Compression -> (Compression, a)) -> (TransmissionState, a)
modifyCompression tst f = case f (stCompression tst) of
(nc, a) -> (tst { stCompression = nc }, a)
withTxCompression :: (Compression -> (Compression, a)) -> RecordM a
withTxCompression f = modifyTxState $ \tst -> modifyCompression tst f
withRxCompression :: (Compression -> (Compression, a)) -> RecordM a
withRxCompression f = modifyRxState $ \tst -> modifyCompression tst f
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = do
st <- get
let (nc, a) = f $ stCompression st
put $ st { stCompression = nc }
return a
computeDigest :: Version -> TransmissionState -> Header -> Bytes -> (Bytes, TransmissionState)
computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate)
@ -135,12 +118,10 @@ computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate)
| ver < TLS10 = (macSSL hashf, B.concat [ encodedSeq, encodeHeaderNoVer hdr, content ])
| otherwise = (hmac hashf 64, B.concat [ encodedSeq, encodeHeader hdr, content ])
makeDigest :: Bool -> Header -> Bytes -> RecordM Bytes
makeDigest w hdr content = do
makeDigest :: Header -> Bytes -> RecordM Bytes
makeDigest hdr content = do
ver <- getRecordVersion
st <- get
let (digest, nstate) = computeDigest (stVersion st)
(if w then stTxState st else stRxState st) hdr content
put $ if w
then st { stTxState = nstate }
else st { stRxState = nstate }
let (digest, nstate) = computeDigest ver st hdr content
put nstate
return digest

View file

@ -32,7 +32,7 @@ import Network.TLS.Cipher
-- this doesn't change any state
makeRecord :: Packet -> RecordM (Record Plaintext)
makeRecord pkt = do
ver <- stVersion <$> get
ver <- getRecordVersion
return $ Record (packetType pkt) ver (fragmentPlaintext $ writePacketContent pkt)
where writePacketContent (Handshake hss) = encodeHandshakes hss
writePacketContent (Alert a) = encodeAlerts a
@ -67,17 +67,14 @@ prepareRecord :: RecordM a -> TLSSt a
prepareRecord f = do
st <- get
ver <- getVersion
let sz = case stCipher $ stTxState $ stRecordState st of
let sz = case stCipher $ stTxState st of
Nothing -> 0
Just cipher -> bulkIVSize $ cipherBulk cipher
if hasExplicitBlockIV ver && sz > 0
then do newIV <- genRandom sz
runRecordStateSt $ modify $ \rts ->
let ts = stTxState rts
nts = ts { stCryptState = (stCryptState ts) { cstIV = newIV } }
in rts { stTxState = nts }
runRecordStateSt f
else runRecordStateSt f
runTxState (modify $ \ts -> ts { stCryptState = (stCryptState ts) { cstIV = newIV } })
runTxState f
else runTxState f
{------------------------------------------------------------------------------}
{- SENDING Helpers -}

View file

@ -12,14 +12,13 @@
module Network.TLS.State
( TLSState(..)
, TLSSt
, RecordState(..)
, getRecordState
, runTLSState
, runRecordStateSt
, HandshakeState(..)
, withHandshakeM
, newTLSState
, withTLSRNG
, runTxState
, runRxState
, genRandom
, assert -- FIXME move somewhere else (Internal.hs ?)
, updateVerifiedData
@ -77,7 +76,8 @@ data TLSState = TLSState
{ stHandshake :: !(Maybe HandshakeState)
, stSession :: Session
, stSessionResuming :: Bool
, stRecordState :: RecordState
, stTxState :: TransmissionState
, stRxState :: TransmissionState
, stSecureRenegotiation :: Bool -- RFC 5746
, stClientVerifiedData :: Bytes -- RFC 5746
, stServerVerifiedData :: Bytes -- RFC 5746
@ -86,6 +86,8 @@ data TLSState = TLSState
, stServerNextProtocolSuggest :: Maybe [B.ByteString]
, stClientCertificateChain :: Maybe CertificateChain
, stRandomGen :: StateRNG
, stVersion :: Version
, stClientContext :: Role
} deriving (Show)
newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a }
@ -104,29 +106,27 @@ instance MonadState TLSState TLSSt where
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState f st = runState (runErrorT (runTLSSt f)) st
getRecordState :: MonadState TLSState m => (RecordState -> a) -> m a
getRecordState f = gets (f . stRecordState)
runRecordState :: RecordM a -> TLSState -> (Either TLSError a, TLSState)
runRecordState f st =
let (r, nrst) = runState (runErrorT (runRecordM f)) (stRecordState st)
in case r of
Left _ -> (r, st)
Right _ -> (r, st { stRecordState = nrst })
runRecordStateSt :: RecordM a -> TLSSt a
runRecordStateSt f = do
runTxState :: RecordM a -> TLSSt a
runTxState f = do
st <- get
case runRecordState f st of
(Left e, _) -> throwError e
(Right a, newSt) -> put newSt >> return a
case runRecordM f (stVersion st) (stTxState st) of
Left err -> throwError err
Right (a, newSt) -> put (st { stTxState = newSt }) >> return a
runRxState :: RecordM a -> TLSSt a
runRxState f = do
st <- get
case runRecordM f (stVersion st) (stRxState st) of
Left err -> throwError err
Right (a, newSt) -> put (st { stRxState = newSt }) >> return a
newTLSState :: CPRG g => g -> Role -> TLSState
newTLSState rng clientContext = TLSState
{ stHandshake = Nothing
, stSession = Session Nothing
, stSessionResuming = False
, stRecordState = newRecordState clientContext
, stTxState = newTransmissionState
, stRxState = newTransmissionState
, stSecureRenegotiation = False
, stClientVerifiedData = B.empty
, stServerVerifiedData = B.empty
@ -135,6 +135,8 @@ newTLSState rng clientContext = TLSState
, stServerNextProtocolSuggest = Nothing
, stClientCertificateChain = Nothing
, stRandomGen = StateRNG rng
, stVersion = TLS10
, stClientContext = clientContext
}
updateVerifiedData :: MonadState TLSState m => Role -> Bytes -> m ()
@ -179,17 +181,17 @@ certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake
switchTxEncryption, switchRxEncryption :: TLSSt ()
switchTxEncryption =
withHandshakeM (gets hstPendingTxState)
>>= \newTxState -> runRecordStateSt (modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState })
>>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }
switchRxEncryption =
withHandshakeM (gets hstPendingRxState)
>>= \newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState })
>>= \newRxState -> modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState }
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
where wrapSessionData st masterSecret = do
return $ SessionData
{ sessionVersion = stVersion $ stRecordState st
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ stRecordState st
{ sessionVersion = stVersion st
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st
, sessionSecret = masterSecret
}
@ -202,17 +204,17 @@ getSession = gets stSession
isSessionResuming :: MonadState TLSState m => m Bool
isSessionResuming = gets stSessionResuming
needEmptyPacket :: MonadState RecordState m => m Bool
needEmptyPacket :: MonadState TLSState m => m Bool
needEmptyPacket = gets f
where f st = (stVersion st <= TLS10)
&& stClientContext st == ClientRole
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
setVersion :: MonadState TLSState m => Version -> m ()
setVersion ver = modify (\st -> st { stRecordState = (stRecordState st) { stVersion = ver } })
setVersion ver = modify (\st -> st { stVersion = ver })
getVersion :: MonadState TLSState m => m Version
getVersion = gets (stVersion . stRecordState)
getVersion = gets stVersion
setSecureRenegotiation :: MonadState TLSState m => Bool -> m ()
setSecureRenegotiation b = modify (\st -> st { stSecureRenegotiation = b })
@ -248,7 +250,7 @@ getVerifiedData :: MonadState TLSState m => Role -> m Bytes
getVerifiedData client = gets (if client == ClientRole then stClientVerifiedData else stServerVerifiedData)
isClientContext :: MonadState TLSState m => m Role
isClientContext = getRecordState stClientContext
isClientContext = gets stClientContext
startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m ()
startHandshakeClient ver crand = do

View file

@ -13,6 +13,7 @@ module Network.TLS.Types
, CompressionID
, Role(..)
, invertRole
, Direction(..)
) where
import Data.ByteString (ByteString)
@ -43,6 +44,10 @@ type CompressionID = Word8
data Role = ClientRole | ServerRole
deriving (Show,Eq)
-- | Direction
data Direction = Tx | Rx
deriving (Show,Eq)
invertRole :: Role -> Role
invertRole ClientRole = ServerRole
invertRole ServerRole = ClientRole