Separate tx/rx state from a single RecordState
unroll a reader/state/error monad into a single simple monad, and move back version and client context in state.
This commit is contained in:
parent
e3b3483560
commit
e2d5170af7
11 changed files with 110 additions and 128 deletions
|
@ -218,7 +218,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
|
|||
-- 4. Send it to the server.
|
||||
--
|
||||
sendCertificateVerify = do
|
||||
usedVersion <- usingState_ ctx $ stVersion . stRecordState <$> get
|
||||
usedVersion <- usingState_ ctx getVersion
|
||||
|
||||
-- Only send a certificate verify message when we
|
||||
-- have sent a non-empty list of certificates.
|
||||
|
|
|
@ -183,7 +183,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
|
|||
-- certificates.
|
||||
--
|
||||
when (serverWantClientCert sparams) $ do
|
||||
usedVersion <- usingState_ ctx $ getRecordState stVersion
|
||||
usedVersion <- usingState_ ctx getVersion
|
||||
let certTypes = [ CertificateType_RSA_Sign ]
|
||||
hashSigs = if usedVersion < TLS12
|
||||
then Nothing
|
||||
|
@ -247,7 +247,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
|
|||
-- Fetch all handshake messages up to now.
|
||||
msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages
|
||||
|
||||
usedVersion <- usingState_ ctx $ getRecordState stVersion
|
||||
usedVersion <- usingState_ ctx getVersion
|
||||
|
||||
(signature, hsh) <- case usedVersion of
|
||||
SSL3 -> do
|
||||
|
|
|
@ -15,7 +15,7 @@ module Network.TLS.IO
|
|||
) where
|
||||
|
||||
import Network.TLS.Context
|
||||
import Network.TLS.State (needEmptyPacket, runRecordStateSt)
|
||||
import Network.TLS.State (needEmptyPacket, runRxState)
|
||||
import Network.TLS.Struct
|
||||
import Network.TLS.Record
|
||||
import Network.TLS.Packet
|
||||
|
@ -78,7 +78,7 @@ recvRecord compatSSLv2 ctx
|
|||
maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow)
|
||||
makeRecord header content = do
|
||||
liftIO $ (loggingIORecv $ ctxLogging ctx) header content
|
||||
usingState ctx $ runRecordStateSt $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
|
||||
usingState ctx $ runRxState $ disengageRecord $ rawToRecord header (fragmentCiphertext content)
|
||||
|
||||
|
||||
-- | receive one packet from the context that contains 1 or
|
||||
|
@ -109,7 +109,7 @@ sendPacket ctx pkt = do
|
|||
-- in ver <= TLS1.0, block ciphers using CBC are using CBC residue as IV, which can be guessed
|
||||
-- by an attacker. Hence, an empty packet is sent before a normal data packet, to
|
||||
-- prevent guessability.
|
||||
withEmptyPacket <- usingState_ ctx (runRecordStateSt needEmptyPacket)
|
||||
withEmptyPacket <- usingState_ ctx needEmptyPacket
|
||||
when (isNonNullAppData pkt && withEmptyPacket) $ sendPacket ctx $ AppData B.empty
|
||||
|
||||
liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt)
|
||||
|
|
|
@ -99,7 +99,7 @@ processHandshake hs = do
|
|||
|
||||
decryptRSA :: ByteString -> TLSSt (Either KxError ByteString)
|
||||
decryptRSA econtent = do
|
||||
ver <- getRecordState stVersion
|
||||
ver <- getVersion
|
||||
rsapriv <- fromJust "rsa private key" . hstRSAPrivateKey . fromJust "handshake" . stHandshake <$> get
|
||||
let cipher = if ver < TLS10 then econtent else B.drop 2 econtent
|
||||
st <- get
|
||||
|
|
|
@ -28,9 +28,7 @@ module Network.TLS.Record
|
|||
, engageRecord
|
||||
, disengageRecord
|
||||
-- * State tracking
|
||||
, RecordState(..)
|
||||
, RecordM
|
||||
, newRecordState
|
||||
) where
|
||||
|
||||
import Network.TLS.Record.Types
|
||||
|
|
|
@ -30,14 +30,14 @@ disengageRecord = decryptRecord >=> uncompressRecord
|
|||
|
||||
uncompressRecord :: Record Compressed -> RecordM (Record Plaintext)
|
||||
uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes ->
|
||||
withRxCompression $ compressionInflate bytes
|
||||
withCompression $ compressionInflate bytes
|
||||
|
||||
decryptRecord :: Record Ciphertext -> RecordM (Record Compressed)
|
||||
decryptRecord record = onRecordFragment record $ fragmentUncipher $ \e -> do
|
||||
st <- get
|
||||
case stCipher $ stRxState st of
|
||||
case stCipher st of
|
||||
Nothing -> return e
|
||||
_ -> decryptData record e st
|
||||
_ -> getRecordVersion >>= \ver -> decryptData ver record e st
|
||||
|
||||
getCipherData :: Record a -> CipherData -> RecordM ByteString
|
||||
getCipherData (Record pt ver _) cdata = do
|
||||
|
@ -46,14 +46,14 @@ getCipherData (Record pt ver _) cdata = do
|
|||
Nothing -> return True
|
||||
Just digest -> do
|
||||
let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata)
|
||||
expected_digest <- makeDigest False new_hdr $ cipherDataContent cdata
|
||||
expected_digest <- makeDigest new_hdr $ cipherDataContent cdata
|
||||
return (expected_digest `bytesEq` digest)
|
||||
|
||||
-- check if the padding is filled with the correct pattern if it exists
|
||||
paddingValid <- case cipherDataPadding cdata of
|
||||
Nothing -> return True
|
||||
Just pad -> do
|
||||
cver <- gets stVersion
|
||||
cver <- getRecordVersion
|
||||
let b = B.length pad - 1
|
||||
return (if cver < TLS10 then True else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad)
|
||||
|
||||
|
@ -62,10 +62,9 @@ getCipherData (Record pt ver _) cdata = do
|
|||
|
||||
return $ cipherDataContent cdata
|
||||
|
||||
decryptData :: Record Ciphertext -> Bytes -> RecordState -> RecordM Bytes
|
||||
decryptData record econtent st = decryptOf (bulkF bulk)
|
||||
where tst = stRxState st
|
||||
cipher = fromJust "cipher" $ stCipher tst
|
||||
decryptData :: Version -> Record Ciphertext -> Bytes -> TransmissionState -> RecordM Bytes
|
||||
decryptData ver record econtent tst = decryptOf (bulkF bulk)
|
||||
where cipher = fromJust "cipher" $ stCipher tst
|
||||
bulk = cipherBulk cipher
|
||||
cst = stCryptState tst
|
||||
macSize = hashSize $ cipherHash cipher
|
||||
|
@ -73,7 +72,7 @@ decryptData record econtent st = decryptOf (bulkF bulk)
|
|||
blockSize = bulkBlockSize bulk
|
||||
econtentLen = B.length econtent
|
||||
|
||||
explicitIV = hasExplicitBlockIV $ stVersion st
|
||||
explicitIV = hasExplicitBlockIV ver
|
||||
|
||||
sanityCheckError = throwError (Error_Packet "encrypted content too small for encryption parameters")
|
||||
|
||||
|
@ -86,7 +85,7 @@ decryptData record econtent st = decryptOf (bulkF bulk)
|
|||
then get2 econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk)
|
||||
else return (cstIV cst, econtent)
|
||||
let newiv = fromJust "new iv" $ takelast (bulkBlockSize bulk) econtent'
|
||||
modifyRxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
|
||||
modify $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
|
||||
|
||||
let content' = decryptF writekey iv econtent'
|
||||
let paddinglength = fromIntegral (B.last content') + 1
|
||||
|
@ -104,7 +103,7 @@ decryptData record econtent st = decryptOf (bulkF bulk)
|
|||
{- update Ctx -}
|
||||
let contentlen = B.length content' - macSize
|
||||
(content, mac) <- get2 content' (contentlen, macSize)
|
||||
modifyRxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
|
||||
modify $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
|
||||
getCipherData record $ CipherData
|
||||
{ cipherDataContent = content
|
||||
, cipherDataMAC = Just mac
|
||||
|
|
|
@ -29,7 +29,7 @@ engageRecord = compressRecord >=> encryptRecord
|
|||
compressRecord :: Record Plaintext -> RecordM (Record Compressed)
|
||||
compressRecord record =
|
||||
onRecordFragment record $ fragmentCompress $ \bytes -> do
|
||||
withTxCompression $ compressionDeflate bytes
|
||||
withCompression $ compressionDeflate bytes
|
||||
|
||||
{-
|
||||
- when Tx Encrypted is set, we pass the data through encryptContent, otherwise
|
||||
|
@ -38,20 +38,20 @@ compressRecord record =
|
|||
encryptRecord :: Record Compressed -> RecordM (Record Ciphertext)
|
||||
encryptRecord record = onRecordFragment record $ fragmentCipher $ \bytes -> do
|
||||
st <- get
|
||||
case stCipher $ stTxState st of
|
||||
case stCipher st of
|
||||
Nothing -> return bytes
|
||||
_ -> encryptContent record bytes
|
||||
|
||||
encryptContent :: Record Compressed -> ByteString -> RecordM ByteString
|
||||
encryptContent record content = do
|
||||
digest <- makeDigest True (recordToHeader record) content
|
||||
digest <- makeDigest (recordToHeader record) content
|
||||
encryptData $ B.concat [content, digest]
|
||||
|
||||
encryptData :: ByteString -> RecordM ByteString
|
||||
encryptData content = do
|
||||
st <- get
|
||||
tstate <- get
|
||||
ver <- getRecordVersion
|
||||
|
||||
let tstate = stTxState st
|
||||
let cipher = fromJust "cipher" $ stCipher tstate
|
||||
let bulk = cipherBulk cipher
|
||||
let cst = stCryptState tstate
|
||||
|
@ -71,14 +71,14 @@ encryptData content = do
|
|||
B.empty
|
||||
|
||||
let e = encrypt writekey (cstIV cst) (B.concat [ content, padding ])
|
||||
if hasExplicitBlockIV $ stVersion st
|
||||
if hasExplicitBlockIV ver
|
||||
then return $ B.concat [cstIV cst,e]
|
||||
else do
|
||||
let newiv = fromJust "new iv" $ takelast (bulkIVSize bulk) e
|
||||
modifyTxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
|
||||
put $ tstate { stCryptState = cst { cstIV = newiv } }
|
||||
return e
|
||||
BulkStreamF initF encryptF _ -> do
|
||||
let iv = cstIV cst
|
||||
let (e, newiv) = encryptF (if iv /= B.empty then iv else initF writekey) content
|
||||
modifyTxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } }
|
||||
put $ tstate { stCryptState = cst { cstIV = newiv } }
|
||||
return e
|
||||
|
|
|
@ -12,15 +12,11 @@ module Network.TLS.Record.State
|
|||
( CryptState(..)
|
||||
, MacState(..)
|
||||
, TransmissionState(..)
|
||||
, RecordState(..)
|
||||
, newRecordState
|
||||
, RecordM(..)
|
||||
, withTxCompression
|
||||
, withRxCompression
|
||||
, modifyTxState
|
||||
, modifyRxState
|
||||
, modifyTxState_
|
||||
, modifyRxState_
|
||||
, newTransmissionState
|
||||
, RecordM
|
||||
, runRecordM
|
||||
, getRecordVersion
|
||||
, withCompression
|
||||
, computeDigest
|
||||
, makeDigest
|
||||
) where
|
||||
|
@ -36,7 +32,7 @@ import Network.TLS.Wire
|
|||
import Network.TLS.Packet
|
||||
import Network.TLS.MAC
|
||||
import Network.TLS.Util
|
||||
import Network.TLS.Types (Role(..))
|
||||
import Network.TLS.Types (Direction(..))
|
||||
|
||||
import qualified Data.ByteString as B
|
||||
|
||||
|
@ -57,26 +53,40 @@ data TransmissionState = TransmissionState
|
|||
, stMacState :: !MacState
|
||||
} deriving (Show)
|
||||
|
||||
data RecordState = RecordState
|
||||
{ stClientContext :: Role
|
||||
, stVersion :: !Version
|
||||
, stTxState :: TransmissionState
|
||||
, stRxState :: TransmissionState
|
||||
} deriving (Show)
|
||||
newtype RecordM a = RecordM { runRecordM :: Version
|
||||
-> TransmissionState
|
||||
-> Either TLSError (a, TransmissionState) }
|
||||
|
||||
newtype RecordM a = RecordM { runRecordM :: ErrorT TLSError (State RecordState) a }
|
||||
deriving (Monad, MonadError TLSError)
|
||||
instance Monad RecordM where
|
||||
return a = RecordM $ \_ st -> Right (a, st)
|
||||
m1 >>= m2 = RecordM $ \ver st -> do
|
||||
case runRecordM m1 ver st of
|
||||
Left err -> Left err
|
||||
Right (a, st2) -> runRecordM (m2 a) ver st2
|
||||
|
||||
instance Functor RecordM where
|
||||
fmap f = RecordM . fmap f . runRecordM
|
||||
fmap f m = RecordM $ \ver st ->
|
||||
case runRecordM m ver st of
|
||||
Left err -> Left err
|
||||
Right (a, st2) -> Right (f a, st2)
|
||||
|
||||
instance MonadState RecordState RecordM where
|
||||
put x = RecordM (lift $ put x)
|
||||
get = RecordM (lift get)
|
||||
getRecordVersion :: RecordM Version
|
||||
getRecordVersion = RecordM $ \ver st -> Right (ver, st)
|
||||
|
||||
instance MonadState TransmissionState RecordM where
|
||||
put x = RecordM $ \_ _ -> Right ((), x)
|
||||
get = RecordM $ \_ st -> Right (st, st)
|
||||
#if MIN_VERSION_mtl(2,1,0)
|
||||
state f = RecordM (lift $ state f)
|
||||
state f = RecordM $ \_ st -> Right (f st)
|
||||
#endif
|
||||
|
||||
instance MonadError TLSError RecordM where
|
||||
throwError e = RecordM $ \_ _ -> Left e
|
||||
catchError m f = RecordM $ \ver st ->
|
||||
case runRecordM m ver st of
|
||||
Left err -> runRecordM (f err) ver st
|
||||
r -> r
|
||||
|
||||
newTransmissionState :: TransmissionState
|
||||
newTransmissionState = TransmissionState
|
||||
{ stCipher = Nothing
|
||||
|
@ -89,39 +99,12 @@ incrTransmissionState :: TransmissionState -> TransmissionState
|
|||
incrTransmissionState ts = ts { stMacState = MacState (ms + 1) }
|
||||
where (MacState ms) = stMacState ts
|
||||
|
||||
newRecordState :: Role -> RecordState
|
||||
newRecordState clientContext = RecordState
|
||||
{ stClientContext = clientContext
|
||||
, stVersion = TLS10
|
||||
, stTxState = newTransmissionState
|
||||
, stRxState = newTransmissionState
|
||||
}
|
||||
|
||||
modifyTxState :: (TransmissionState -> (TransmissionState, a)) -> RecordM a
|
||||
modifyTxState f =
|
||||
get >>= \st -> case f $ stTxState st of
|
||||
(nst, a) -> put (st { stTxState = nst }) >> return a
|
||||
|
||||
modifyTxState_ :: (TransmissionState -> TransmissionState) -> RecordM ()
|
||||
modifyTxState_ f = modifyTxState (\t -> (f t, ()))
|
||||
|
||||
modifyRxState :: (TransmissionState -> (TransmissionState, a)) -> RecordM a
|
||||
modifyRxState f =
|
||||
get >>= \st -> case f $ stRxState st of
|
||||
(nst, a) -> put (st { stRxState = nst }) >> return a
|
||||
|
||||
modifyRxState_ :: (TransmissionState -> TransmissionState) -> RecordM ()
|
||||
modifyRxState_ f = modifyRxState (\t -> (f t, ()))
|
||||
|
||||
modifyCompression :: TransmissionState -> (Compression -> (Compression, a)) -> (TransmissionState, a)
|
||||
modifyCompression tst f = case f (stCompression tst) of
|
||||
(nc, a) -> (tst { stCompression = nc }, a)
|
||||
|
||||
withTxCompression :: (Compression -> (Compression, a)) -> RecordM a
|
||||
withTxCompression f = modifyTxState $ \tst -> modifyCompression tst f
|
||||
|
||||
withRxCompression :: (Compression -> (Compression, a)) -> RecordM a
|
||||
withRxCompression f = modifyRxState $ \tst -> modifyCompression tst f
|
||||
withCompression :: (Compression -> (Compression, a)) -> RecordM a
|
||||
withCompression f = do
|
||||
st <- get
|
||||
let (nc, a) = f $ stCompression st
|
||||
put $ st { stCompression = nc }
|
||||
return a
|
||||
|
||||
computeDigest :: Version -> TransmissionState -> Header -> Bytes -> (Bytes, TransmissionState)
|
||||
computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate)
|
||||
|
@ -135,12 +118,10 @@ computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate)
|
|||
| ver < TLS10 = (macSSL hashf, B.concat [ encodedSeq, encodeHeaderNoVer hdr, content ])
|
||||
| otherwise = (hmac hashf 64, B.concat [ encodedSeq, encodeHeader hdr, content ])
|
||||
|
||||
makeDigest :: Bool -> Header -> Bytes -> RecordM Bytes
|
||||
makeDigest w hdr content = do
|
||||
makeDigest :: Header -> Bytes -> RecordM Bytes
|
||||
makeDigest hdr content = do
|
||||
ver <- getRecordVersion
|
||||
st <- get
|
||||
let (digest, nstate) = computeDigest (stVersion st)
|
||||
(if w then stTxState st else stRxState st) hdr content
|
||||
put $ if w
|
||||
then st { stTxState = nstate }
|
||||
else st { stRxState = nstate }
|
||||
let (digest, nstate) = computeDigest ver st hdr content
|
||||
put nstate
|
||||
return digest
|
||||
|
|
|
@ -32,7 +32,7 @@ import Network.TLS.Cipher
|
|||
-- this doesn't change any state
|
||||
makeRecord :: Packet -> RecordM (Record Plaintext)
|
||||
makeRecord pkt = do
|
||||
ver <- stVersion <$> get
|
||||
ver <- getRecordVersion
|
||||
return $ Record (packetType pkt) ver (fragmentPlaintext $ writePacketContent pkt)
|
||||
where writePacketContent (Handshake hss) = encodeHandshakes hss
|
||||
writePacketContent (Alert a) = encodeAlerts a
|
||||
|
@ -67,17 +67,14 @@ prepareRecord :: RecordM a -> TLSSt a
|
|||
prepareRecord f = do
|
||||
st <- get
|
||||
ver <- getVersion
|
||||
let sz = case stCipher $ stTxState $ stRecordState st of
|
||||
let sz = case stCipher $ stTxState st of
|
||||
Nothing -> 0
|
||||
Just cipher -> bulkIVSize $ cipherBulk cipher
|
||||
if hasExplicitBlockIV ver && sz > 0
|
||||
then do newIV <- genRandom sz
|
||||
runRecordStateSt $ modify $ \rts ->
|
||||
let ts = stTxState rts
|
||||
nts = ts { stCryptState = (stCryptState ts) { cstIV = newIV } }
|
||||
in rts { stTxState = nts }
|
||||
runRecordStateSt f
|
||||
else runRecordStateSt f
|
||||
runTxState (modify $ \ts -> ts { stCryptState = (stCryptState ts) { cstIV = newIV } })
|
||||
runTxState f
|
||||
else runTxState f
|
||||
|
||||
{------------------------------------------------------------------------------}
|
||||
{- SENDING Helpers -}
|
||||
|
|
|
@ -12,14 +12,13 @@
|
|||
module Network.TLS.State
|
||||
( TLSState(..)
|
||||
, TLSSt
|
||||
, RecordState(..)
|
||||
, getRecordState
|
||||
, runTLSState
|
||||
, runRecordStateSt
|
||||
, HandshakeState(..)
|
||||
, withHandshakeM
|
||||
, newTLSState
|
||||
, withTLSRNG
|
||||
, runTxState
|
||||
, runRxState
|
||||
, genRandom
|
||||
, assert -- FIXME move somewhere else (Internal.hs ?)
|
||||
, updateVerifiedData
|
||||
|
@ -77,7 +76,8 @@ data TLSState = TLSState
|
|||
{ stHandshake :: !(Maybe HandshakeState)
|
||||
, stSession :: Session
|
||||
, stSessionResuming :: Bool
|
||||
, stRecordState :: RecordState
|
||||
, stTxState :: TransmissionState
|
||||
, stRxState :: TransmissionState
|
||||
, stSecureRenegotiation :: Bool -- RFC 5746
|
||||
, stClientVerifiedData :: Bytes -- RFC 5746
|
||||
, stServerVerifiedData :: Bytes -- RFC 5746
|
||||
|
@ -86,6 +86,8 @@ data TLSState = TLSState
|
|||
, stServerNextProtocolSuggest :: Maybe [B.ByteString]
|
||||
, stClientCertificateChain :: Maybe CertificateChain
|
||||
, stRandomGen :: StateRNG
|
||||
, stVersion :: Version
|
||||
, stClientContext :: Role
|
||||
} deriving (Show)
|
||||
|
||||
newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a }
|
||||
|
@ -104,29 +106,27 @@ instance MonadState TLSState TLSSt where
|
|||
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
|
||||
runTLSState f st = runState (runErrorT (runTLSSt f)) st
|
||||
|
||||
getRecordState :: MonadState TLSState m => (RecordState -> a) -> m a
|
||||
getRecordState f = gets (f . stRecordState)
|
||||
|
||||
runRecordState :: RecordM a -> TLSState -> (Either TLSError a, TLSState)
|
||||
runRecordState f st =
|
||||
let (r, nrst) = runState (runErrorT (runRecordM f)) (stRecordState st)
|
||||
in case r of
|
||||
Left _ -> (r, st)
|
||||
Right _ -> (r, st { stRecordState = nrst })
|
||||
|
||||
runRecordStateSt :: RecordM a -> TLSSt a
|
||||
runRecordStateSt f = do
|
||||
runTxState :: RecordM a -> TLSSt a
|
||||
runTxState f = do
|
||||
st <- get
|
||||
case runRecordState f st of
|
||||
(Left e, _) -> throwError e
|
||||
(Right a, newSt) -> put newSt >> return a
|
||||
case runRecordM f (stVersion st) (stTxState st) of
|
||||
Left err -> throwError err
|
||||
Right (a, newSt) -> put (st { stTxState = newSt }) >> return a
|
||||
|
||||
runRxState :: RecordM a -> TLSSt a
|
||||
runRxState f = do
|
||||
st <- get
|
||||
case runRecordM f (stVersion st) (stRxState st) of
|
||||
Left err -> throwError err
|
||||
Right (a, newSt) -> put (st { stRxState = newSt }) >> return a
|
||||
|
||||
newTLSState :: CPRG g => g -> Role -> TLSState
|
||||
newTLSState rng clientContext = TLSState
|
||||
{ stHandshake = Nothing
|
||||
, stSession = Session Nothing
|
||||
, stSessionResuming = False
|
||||
, stRecordState = newRecordState clientContext
|
||||
, stTxState = newTransmissionState
|
||||
, stRxState = newTransmissionState
|
||||
, stSecureRenegotiation = False
|
||||
, stClientVerifiedData = B.empty
|
||||
, stServerVerifiedData = B.empty
|
||||
|
@ -135,6 +135,8 @@ newTLSState rng clientContext = TLSState
|
|||
, stServerNextProtocolSuggest = Nothing
|
||||
, stClientCertificateChain = Nothing
|
||||
, stRandomGen = StateRNG rng
|
||||
, stVersion = TLS10
|
||||
, stClientContext = clientContext
|
||||
}
|
||||
|
||||
updateVerifiedData :: MonadState TLSState m => Role -> Bytes -> m ()
|
||||
|
@ -179,17 +181,17 @@ certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake
|
|||
switchTxEncryption, switchRxEncryption :: TLSSt ()
|
||||
switchTxEncryption =
|
||||
withHandshakeM (gets hstPendingTxState)
|
||||
>>= \newTxState -> runRecordStateSt (modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState })
|
||||
>>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }
|
||||
switchRxEncryption =
|
||||
withHandshakeM (gets hstPendingRxState)
|
||||
>>= \newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState })
|
||||
>>= \newRxState -> modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState }
|
||||
|
||||
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
||||
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
||||
where wrapSessionData st masterSecret = do
|
||||
return $ SessionData
|
||||
{ sessionVersion = stVersion $ stRecordState st
|
||||
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ stRecordState st
|
||||
{ sessionVersion = stVersion st
|
||||
, sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st
|
||||
, sessionSecret = masterSecret
|
||||
}
|
||||
|
||||
|
@ -202,17 +204,17 @@ getSession = gets stSession
|
|||
isSessionResuming :: MonadState TLSState m => m Bool
|
||||
isSessionResuming = gets stSessionResuming
|
||||
|
||||
needEmptyPacket :: MonadState RecordState m => m Bool
|
||||
needEmptyPacket :: MonadState TLSState m => m Bool
|
||||
needEmptyPacket = gets f
|
||||
where f st = (stVersion st <= TLS10)
|
||||
&& stClientContext st == ClientRole
|
||||
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
|
||||
|
||||
setVersion :: MonadState TLSState m => Version -> m ()
|
||||
setVersion ver = modify (\st -> st { stRecordState = (stRecordState st) { stVersion = ver } })
|
||||
setVersion ver = modify (\st -> st { stVersion = ver })
|
||||
|
||||
getVersion :: MonadState TLSState m => m Version
|
||||
getVersion = gets (stVersion . stRecordState)
|
||||
getVersion = gets stVersion
|
||||
|
||||
setSecureRenegotiation :: MonadState TLSState m => Bool -> m ()
|
||||
setSecureRenegotiation b = modify (\st -> st { stSecureRenegotiation = b })
|
||||
|
@ -248,7 +250,7 @@ getVerifiedData :: MonadState TLSState m => Role -> m Bytes
|
|||
getVerifiedData client = gets (if client == ClientRole then stClientVerifiedData else stServerVerifiedData)
|
||||
|
||||
isClientContext :: MonadState TLSState m => m Role
|
||||
isClientContext = getRecordState stClientContext
|
||||
isClientContext = gets stClientContext
|
||||
|
||||
startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m ()
|
||||
startHandshakeClient ver crand = do
|
||||
|
|
|
@ -13,6 +13,7 @@ module Network.TLS.Types
|
|||
, CompressionID
|
||||
, Role(..)
|
||||
, invertRole
|
||||
, Direction(..)
|
||||
) where
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
|
@ -43,6 +44,10 @@ type CompressionID = Word8
|
|||
data Role = ClientRole | ServerRole
|
||||
deriving (Show,Eq)
|
||||
|
||||
-- | Direction
|
||||
data Direction = Tx | Rx
|
||||
deriving (Show,Eq)
|
||||
|
||||
invertRole :: Role -> Role
|
||||
invertRole ClientRole = ServerRole
|
||||
invertRole ServerRole = ClientRole
|
||||
|
|
Loading…
Reference in a new issue