cleanup record layer

This commit is contained in:
Vincent Hanquez 2013-07-27 08:32:27 +01:00
parent af3feba166
commit c252ed8f49
6 changed files with 32 additions and 25 deletions

View file

@ -68,8 +68,8 @@ data HandshakeState = HandshakeState
, hstClientCertSent :: !Bool -- ^ Set to true when a client certificate chain was sent
, hstCertReqSent :: !Bool -- ^ Set to true when a certificate request was sent
, hstClientCertChain :: !(Maybe CertificateChain)
, hstPendingTxState :: Maybe TransmissionState
, hstPendingRxState :: Maybe TransmissionState
, hstPendingTxState :: Maybe RecordState
, hstPendingRxState :: Maybe RecordState
, hstPendingCipher :: Maybe Cipher
, hstPendingCompression :: Compression
} deriving (Show)
@ -186,7 +186,7 @@ setMasterSecret ver role masterSecret = modify $ \hst ->
, hstPendingTxState = Just pendingTx
, hstPendingRxState = Just pendingRx }
computeKeyBlock :: HandshakeState -> Bytes -> Version -> Role -> (TransmissionState, TransmissionState)
computeKeyBlock :: HandshakeState -> Bytes -> Version -> Role -> (RecordState, RecordState)
computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
where cipher = fromJust "cipher" $ hstPendingCipher hst
keyblockSize = cipherKeyBlockSize cipher
@ -211,13 +211,13 @@ computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
msClient = MacState { msSequence = 0 }
msServer = MacState { msSequence = 0 }
pendingTx = TransmissionState
pendingTx = RecordState
{ stCryptState = if cc == ClientRole then cstClient else cstServer
, stMacState = if cc == ClientRole then msClient else msServer
, stCipher = Just cipher
, stCompression = hstPendingCompression hst
}
pendingRx = TransmissionState
pendingRx = RecordState
{ stCryptState = if cc == ClientRole then cstServer else cstClient
, stMacState = if cc == ClientRole then msServer else msClient
, stCipher = Just cipher

View file

@ -29,6 +29,11 @@ module Network.TLS.Record
, disengageRecord
-- * State tracking
, RecordM
, runRecordM
, RecordState(..)
, newRecordState
, getRecordVersion
, setRecordIV
) where
import Network.TLS.Record.Types

View file

@ -62,7 +62,7 @@ getCipherData (Record pt ver _) cdata = do
return $ cipherDataContent cdata
decryptData :: Version -> Record Ciphertext -> Bytes -> TransmissionState -> RecordM Bytes
decryptData :: Version -> Record Ciphertext -> Bytes -> RecordState -> RecordM Bytes
decryptData ver record econtent tst = decryptOf (bulkF bulk)
where cipher = fromJust "cipher" $ stCipher tst
bulk = cipherBulk cipher

View file

@ -11,11 +11,12 @@
module Network.TLS.Record.State
( CryptState(..)
, MacState(..)
, TransmissionState(..)
, newTransmissionState
, RecordState(..)
, newRecordState
, RecordM
, runRecordM
, getRecordVersion
, setRecordIV
, withCompression
, computeDigest
, makeDigest
@ -32,7 +33,6 @@ import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Types (Direction(..))
import qualified Data.ByteString as B
@ -46,7 +46,7 @@ newtype MacState = MacState
{ msSequence :: Word64
} deriving (Show)
data TransmissionState = TransmissionState
data RecordState = RecordState
{ stCipher :: Maybe Cipher
, stCompression :: Compression
, stCryptState :: !CryptState
@ -54,8 +54,8 @@ data TransmissionState = TransmissionState
} deriving (Show)
newtype RecordM a = RecordM { runRecordM :: Version
-> TransmissionState
-> Either TLSError (a, TransmissionState) }
-> RecordState
-> Either TLSError (a, RecordState) }
instance Monad RecordM where
return a = RecordM $ \_ st -> Right (a, st)
@ -73,7 +73,7 @@ instance Functor RecordM where
getRecordVersion :: RecordM Version
getRecordVersion = RecordM $ \ver st -> Right (ver, st)
instance MonadState TransmissionState RecordM where
instance MonadState RecordState RecordM where
put x = RecordM $ \_ _ -> Right ((), x)
get = RecordM $ \_ st -> Right (st, st)
#if MIN_VERSION_mtl(2,1,0)
@ -87,18 +87,21 @@ instance MonadError TLSError RecordM where
Left err -> runRecordM (f err) ver st
r -> r
newTransmissionState :: TransmissionState
newTransmissionState = TransmissionState
newRecordState :: RecordState
newRecordState = RecordState
{ stCipher = Nothing
, stCompression = nullCompression
, stCryptState = CryptState B.empty B.empty B.empty
, stMacState = MacState 0
}
incrTransmissionState :: TransmissionState -> TransmissionState
incrTransmissionState ts = ts { stMacState = MacState (ms + 1) }
incrRecordState :: RecordState -> RecordState
incrRecordState ts = ts { stMacState = MacState (ms + 1) }
where (MacState ms) = stMacState ts
setRecordIV :: Bytes -> RecordState -> RecordState
setRecordIV iv st = st { stCryptState = (stCryptState st) { cstIV = iv } }
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = do
st <- get
@ -106,8 +109,8 @@ withCompression f = do
put $ st { stCompression = nc }
return a
computeDigest :: Version -> TransmissionState -> Header -> Bytes -> (Bytes, TransmissionState)
computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate)
computeDigest :: Version -> RecordState -> Header -> Bytes -> (Bytes, RecordState)
computeDigest ver tstate hdr content = (digest, incrRecordState tstate)
where digest = macF (cstMacSecret cst) msg
cst = stCryptState tstate
cipher = fromJust "cipher" $ stCipher tstate

View file

@ -24,7 +24,6 @@ import Network.TLS.Record
import Network.TLS.Packet
import Network.TLS.State
import Network.TLS.Handshake.State
import Network.TLS.Record.State
import Network.TLS.Crypto
import Network.TLS.Cipher
@ -72,7 +71,7 @@ prepareRecord f = do
Just cipher -> bulkIVSize $ cipherBulk cipher
if hasExplicitBlockIV ver && sz > 0
then do newIV <- genRandom sz
runTxState (modify $ \ts -> ts { stCryptState = (stCryptState ts) { cstIV = newIV } })
runTxState (modify $ setRecordIV newIV)
runTxState f
else runTxState f

View file

@ -76,8 +76,8 @@ data TLSState = TLSState
{ stHandshake :: !(Maybe HandshakeState)
, stSession :: Session
, stSessionResuming :: Bool
, stTxState :: TransmissionState
, stRxState :: TransmissionState
, stTxState :: RecordState
, stRxState :: RecordState
, stSecureRenegotiation :: Bool -- RFC 5746
, stClientVerifiedData :: Bytes -- RFC 5746
, stServerVerifiedData :: Bytes -- RFC 5746
@ -125,8 +125,8 @@ newTLSState rng clientContext = TLSState
{ stHandshake = Nothing
, stSession = Session Nothing
, stSessionResuming = False
, stTxState = newTransmissionState
, stRxState = newTransmissionState
, stTxState = newRecordState
, stRxState = newRecordState
, stSecureRenegotiation = False
, stClientVerifiedData = B.empty
, stServerVerifiedData = B.empty