cleanup record layer
This commit is contained in:
parent
af3feba166
commit
c252ed8f49
6 changed files with 32 additions and 25 deletions
|
@ -68,8 +68,8 @@ data HandshakeState = HandshakeState
|
|||
, hstClientCertSent :: !Bool -- ^ Set to true when a client certificate chain was sent
|
||||
, hstCertReqSent :: !Bool -- ^ Set to true when a certificate request was sent
|
||||
, hstClientCertChain :: !(Maybe CertificateChain)
|
||||
, hstPendingTxState :: Maybe TransmissionState
|
||||
, hstPendingRxState :: Maybe TransmissionState
|
||||
, hstPendingTxState :: Maybe RecordState
|
||||
, hstPendingRxState :: Maybe RecordState
|
||||
, hstPendingCipher :: Maybe Cipher
|
||||
, hstPendingCompression :: Compression
|
||||
} deriving (Show)
|
||||
|
@ -186,7 +186,7 @@ setMasterSecret ver role masterSecret = modify $ \hst ->
|
|||
, hstPendingTxState = Just pendingTx
|
||||
, hstPendingRxState = Just pendingRx }
|
||||
|
||||
computeKeyBlock :: HandshakeState -> Bytes -> Version -> Role -> (TransmissionState, TransmissionState)
|
||||
computeKeyBlock :: HandshakeState -> Bytes -> Version -> Role -> (RecordState, RecordState)
|
||||
computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
|
||||
where cipher = fromJust "cipher" $ hstPendingCipher hst
|
||||
keyblockSize = cipherKeyBlockSize cipher
|
||||
|
@ -211,13 +211,13 @@ computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
|
|||
msClient = MacState { msSequence = 0 }
|
||||
msServer = MacState { msSequence = 0 }
|
||||
|
||||
pendingTx = TransmissionState
|
||||
pendingTx = RecordState
|
||||
{ stCryptState = if cc == ClientRole then cstClient else cstServer
|
||||
, stMacState = if cc == ClientRole then msClient else msServer
|
||||
, stCipher = Just cipher
|
||||
, stCompression = hstPendingCompression hst
|
||||
}
|
||||
pendingRx = TransmissionState
|
||||
pendingRx = RecordState
|
||||
{ stCryptState = if cc == ClientRole then cstServer else cstClient
|
||||
, stMacState = if cc == ClientRole then msServer else msClient
|
||||
, stCipher = Just cipher
|
||||
|
|
|
@ -29,6 +29,11 @@ module Network.TLS.Record
|
|||
, disengageRecord
|
||||
-- * State tracking
|
||||
, RecordM
|
||||
, runRecordM
|
||||
, RecordState(..)
|
||||
, newRecordState
|
||||
, getRecordVersion
|
||||
, setRecordIV
|
||||
) where
|
||||
|
||||
import Network.TLS.Record.Types
|
||||
|
|
|
@ -62,7 +62,7 @@ getCipherData (Record pt ver _) cdata = do
|
|||
|
||||
return $ cipherDataContent cdata
|
||||
|
||||
decryptData :: Version -> Record Ciphertext -> Bytes -> TransmissionState -> RecordM Bytes
|
||||
decryptData :: Version -> Record Ciphertext -> Bytes -> RecordState -> RecordM Bytes
|
||||
decryptData ver record econtent tst = decryptOf (bulkF bulk)
|
||||
where cipher = fromJust "cipher" $ stCipher tst
|
||||
bulk = cipherBulk cipher
|
||||
|
|
|
@ -11,11 +11,12 @@
|
|||
module Network.TLS.Record.State
|
||||
( CryptState(..)
|
||||
, MacState(..)
|
||||
, TransmissionState(..)
|
||||
, newTransmissionState
|
||||
, RecordState(..)
|
||||
, newRecordState
|
||||
, RecordM
|
||||
, runRecordM
|
||||
, getRecordVersion
|
||||
, setRecordIV
|
||||
, withCompression
|
||||
, computeDigest
|
||||
, makeDigest
|
||||
|
@ -32,7 +33,6 @@ import Network.TLS.Wire
|
|||
import Network.TLS.Packet
|
||||
import Network.TLS.MAC
|
||||
import Network.TLS.Util
|
||||
import Network.TLS.Types (Direction(..))
|
||||
|
||||
import qualified Data.ByteString as B
|
||||
|
||||
|
@ -46,7 +46,7 @@ newtype MacState = MacState
|
|||
{ msSequence :: Word64
|
||||
} deriving (Show)
|
||||
|
||||
data TransmissionState = TransmissionState
|
||||
data RecordState = RecordState
|
||||
{ stCipher :: Maybe Cipher
|
||||
, stCompression :: Compression
|
||||
, stCryptState :: !CryptState
|
||||
|
@ -54,8 +54,8 @@ data TransmissionState = TransmissionState
|
|||
} deriving (Show)
|
||||
|
||||
newtype RecordM a = RecordM { runRecordM :: Version
|
||||
-> TransmissionState
|
||||
-> Either TLSError (a, TransmissionState) }
|
||||
-> RecordState
|
||||
-> Either TLSError (a, RecordState) }
|
||||
|
||||
instance Monad RecordM where
|
||||
return a = RecordM $ \_ st -> Right (a, st)
|
||||
|
@ -73,7 +73,7 @@ instance Functor RecordM where
|
|||
getRecordVersion :: RecordM Version
|
||||
getRecordVersion = RecordM $ \ver st -> Right (ver, st)
|
||||
|
||||
instance MonadState TransmissionState RecordM where
|
||||
instance MonadState RecordState RecordM where
|
||||
put x = RecordM $ \_ _ -> Right ((), x)
|
||||
get = RecordM $ \_ st -> Right (st, st)
|
||||
#if MIN_VERSION_mtl(2,1,0)
|
||||
|
@ -87,18 +87,21 @@ instance MonadError TLSError RecordM where
|
|||
Left err -> runRecordM (f err) ver st
|
||||
r -> r
|
||||
|
||||
newTransmissionState :: TransmissionState
|
||||
newTransmissionState = TransmissionState
|
||||
newRecordState :: RecordState
|
||||
newRecordState = RecordState
|
||||
{ stCipher = Nothing
|
||||
, stCompression = nullCompression
|
||||
, stCryptState = CryptState B.empty B.empty B.empty
|
||||
, stMacState = MacState 0
|
||||
}
|
||||
|
||||
incrTransmissionState :: TransmissionState -> TransmissionState
|
||||
incrTransmissionState ts = ts { stMacState = MacState (ms + 1) }
|
||||
incrRecordState :: RecordState -> RecordState
|
||||
incrRecordState ts = ts { stMacState = MacState (ms + 1) }
|
||||
where (MacState ms) = stMacState ts
|
||||
|
||||
setRecordIV :: Bytes -> RecordState -> RecordState
|
||||
setRecordIV iv st = st { stCryptState = (stCryptState st) { cstIV = iv } }
|
||||
|
||||
withCompression :: (Compression -> (Compression, a)) -> RecordM a
|
||||
withCompression f = do
|
||||
st <- get
|
||||
|
@ -106,8 +109,8 @@ withCompression f = do
|
|||
put $ st { stCompression = nc }
|
||||
return a
|
||||
|
||||
computeDigest :: Version -> TransmissionState -> Header -> Bytes -> (Bytes, TransmissionState)
|
||||
computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate)
|
||||
computeDigest :: Version -> RecordState -> Header -> Bytes -> (Bytes, RecordState)
|
||||
computeDigest ver tstate hdr content = (digest, incrRecordState tstate)
|
||||
where digest = macF (cstMacSecret cst) msg
|
||||
cst = stCryptState tstate
|
||||
cipher = fromJust "cipher" $ stCipher tstate
|
||||
|
|
|
@ -24,7 +24,6 @@ import Network.TLS.Record
|
|||
import Network.TLS.Packet
|
||||
import Network.TLS.State
|
||||
import Network.TLS.Handshake.State
|
||||
import Network.TLS.Record.State
|
||||
import Network.TLS.Crypto
|
||||
import Network.TLS.Cipher
|
||||
|
||||
|
@ -72,7 +71,7 @@ prepareRecord f = do
|
|||
Just cipher -> bulkIVSize $ cipherBulk cipher
|
||||
if hasExplicitBlockIV ver && sz > 0
|
||||
then do newIV <- genRandom sz
|
||||
runTxState (modify $ \ts -> ts { stCryptState = (stCryptState ts) { cstIV = newIV } })
|
||||
runTxState (modify $ setRecordIV newIV)
|
||||
runTxState f
|
||||
else runTxState f
|
||||
|
||||
|
|
|
@ -76,8 +76,8 @@ data TLSState = TLSState
|
|||
{ stHandshake :: !(Maybe HandshakeState)
|
||||
, stSession :: Session
|
||||
, stSessionResuming :: Bool
|
||||
, stTxState :: TransmissionState
|
||||
, stRxState :: TransmissionState
|
||||
, stTxState :: RecordState
|
||||
, stRxState :: RecordState
|
||||
, stSecureRenegotiation :: Bool -- RFC 5746
|
||||
, stClientVerifiedData :: Bytes -- RFC 5746
|
||||
, stServerVerifiedData :: Bytes -- RFC 5746
|
||||
|
@ -125,8 +125,8 @@ newTLSState rng clientContext = TLSState
|
|||
{ stHandshake = Nothing
|
||||
, stSession = Session Nothing
|
||||
, stSessionResuming = False
|
||||
, stTxState = newTransmissionState
|
||||
, stRxState = newTransmissionState
|
||||
, stTxState = newRecordState
|
||||
, stRxState = newRecordState
|
||||
, stSecureRenegotiation = False
|
||||
, stClientVerifiedData = B.empty
|
||||
, stServerVerifiedData = B.empty
|
||||
|
|
Loading…
Reference in a new issue