diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index b9ef76e..2cfb7b1 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -218,7 +218,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi -- 4. Send it to the server. -- sendCertificateVerify = do - usedVersion <- usingState_ ctx $ stVersion . stRecordState <$> get + usedVersion <- usingState_ ctx getVersion -- Only send a certificate verify message when we -- have sent a non-empty list of certificates. diff --git a/core/Network/TLS/Handshake/Server.hs b/core/Network/TLS/Handshake/Server.hs index 8847fc7..3abf2f6 100644 --- a/core/Network/TLS/Handshake/Server.hs +++ b/core/Network/TLS/Handshake/Server.hs @@ -183,7 +183,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip -- certificates. -- when (serverWantClientCert sparams) $ do - usedVersion <- usingState_ ctx $ getRecordState stVersion + usedVersion <- usingState_ ctx getVersion let certTypes = [ CertificateType_RSA_Sign ] hashSigs = if usedVersion < TLS12 then Nothing @@ -247,7 +247,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- Fetch all handshake messages up to now. msgs <- usingHState ctx $ B.concat <$> getHandshakeMessages - usedVersion <- usingState_ ctx $ getRecordState stVersion + usedVersion <- usingState_ ctx getVersion (signature, hsh) <- case usedVersion of SSL3 -> do diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index aeade33..0a0587b 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -15,7 +15,7 @@ module Network.TLS.IO ) where import Network.TLS.Context -import Network.TLS.State (needEmptyPacket, runRecordStateSt) +import Network.TLS.State (needEmptyPacket, runRxState) import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet @@ -78,7 +78,7 @@ recvRecord compatSSLv2 ctx maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow) makeRecord header content = do liftIO $ (loggingIORecv $ ctxLogging ctx) header content - usingState ctx $ runRecordStateSt $ disengageRecord $ rawToRecord header (fragmentCiphertext content) + usingState ctx $ runRxState $ disengageRecord $ rawToRecord header (fragmentCiphertext content) -- | receive one packet from the context that contains 1 or @@ -109,7 +109,7 @@ sendPacket ctx pkt = do -- in ver <= TLS1.0, block ciphers using CBC are using CBC residue as IV, which can be guessed -- by an attacker. Hence, an empty packet is sent before a normal data packet, to -- prevent guessability. - withEmptyPacket <- usingState_ ctx (runRecordStateSt needEmptyPacket) + withEmptyPacket <- usingState_ ctx needEmptyPacket when (isNonNullAppData pkt && withEmptyPacket) $ sendPacket ctx $ AppData B.empty liftIO $ (loggingPacketSent $ ctxLogging ctx) (show pkt) diff --git a/core/Network/TLS/Receiving.hs b/core/Network/TLS/Receiving.hs index 65f58b5..f0ae245 100644 --- a/core/Network/TLS/Receiving.hs +++ b/core/Network/TLS/Receiving.hs @@ -99,7 +99,7 @@ processHandshake hs = do decryptRSA :: ByteString -> TLSSt (Either KxError ByteString) decryptRSA econtent = do - ver <- getRecordState stVersion + ver <- getVersion rsapriv <- fromJust "rsa private key" . hstRSAPrivateKey . fromJust "handshake" . stHandshake <$> get let cipher = if ver < TLS10 then econtent else B.drop 2 econtent st <- get diff --git a/core/Network/TLS/Record.hs b/core/Network/TLS/Record.hs index 8b00b31..8b07349 100644 --- a/core/Network/TLS/Record.hs +++ b/core/Network/TLS/Record.hs @@ -28,9 +28,7 @@ module Network.TLS.Record , engageRecord , disengageRecord -- * State tracking - , RecordState(..) , RecordM - , newRecordState ) where import Network.TLS.Record.Types diff --git a/core/Network/TLS/Record/Disengage.hs b/core/Network/TLS/Record/Disengage.hs index 6c928f9..e33399e 100644 --- a/core/Network/TLS/Record/Disengage.hs +++ b/core/Network/TLS/Record/Disengage.hs @@ -30,14 +30,14 @@ disengageRecord = decryptRecord >=> uncompressRecord uncompressRecord :: Record Compressed -> RecordM (Record Plaintext) uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes -> - withRxCompression $ compressionInflate bytes + withCompression $ compressionInflate bytes decryptRecord :: Record Ciphertext -> RecordM (Record Compressed) decryptRecord record = onRecordFragment record $ fragmentUncipher $ \e -> do st <- get - case stCipher $ stRxState st of + case stCipher st of Nothing -> return e - _ -> decryptData record e st + _ -> getRecordVersion >>= \ver -> decryptData ver record e st getCipherData :: Record a -> CipherData -> RecordM ByteString getCipherData (Record pt ver _) cdata = do @@ -46,14 +46,14 @@ getCipherData (Record pt ver _) cdata = do Nothing -> return True Just digest -> do let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata) - expected_digest <- makeDigest False new_hdr $ cipherDataContent cdata + expected_digest <- makeDigest new_hdr $ cipherDataContent cdata return (expected_digest `bytesEq` digest) -- check if the padding is filled with the correct pattern if it exists paddingValid <- case cipherDataPadding cdata of Nothing -> return True Just pad -> do - cver <- gets stVersion + cver <- getRecordVersion let b = B.length pad - 1 return (if cver < TLS10 then True else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad) @@ -62,10 +62,9 @@ getCipherData (Record pt ver _) cdata = do return $ cipherDataContent cdata -decryptData :: Record Ciphertext -> Bytes -> RecordState -> RecordM Bytes -decryptData record econtent st = decryptOf (bulkF bulk) - where tst = stRxState st - cipher = fromJust "cipher" $ stCipher tst +decryptData :: Version -> Record Ciphertext -> Bytes -> TransmissionState -> RecordM Bytes +decryptData ver record econtent tst = decryptOf (bulkF bulk) + where cipher = fromJust "cipher" $ stCipher tst bulk = cipherBulk cipher cst = stCryptState tst macSize = hashSize $ cipherHash cipher @@ -73,7 +72,7 @@ decryptData record econtent st = decryptOf (bulkF bulk) blockSize = bulkBlockSize bulk econtentLen = B.length econtent - explicitIV = hasExplicitBlockIV $ stVersion st + explicitIV = hasExplicitBlockIV ver sanityCheckError = throwError (Error_Packet "encrypted content too small for encryption parameters") @@ -86,7 +85,7 @@ decryptData record econtent st = decryptOf (bulkF bulk) then get2 econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk) else return (cstIV cst, econtent) let newiv = fromJust "new iv" $ takelast (bulkBlockSize bulk) econtent' - modifyRxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } } + modify $ \txs -> txs { stCryptState = cst { cstIV = newiv } } let content' = decryptF writekey iv econtent' let paddinglength = fromIntegral (B.last content') + 1 @@ -104,7 +103,7 @@ decryptData record econtent st = decryptOf (bulkF bulk) {- update Ctx -} let contentlen = B.length content' - macSize (content, mac) <- get2 content' (contentlen, macSize) - modifyRxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } } + modify $ \txs -> txs { stCryptState = cst { cstIV = newiv } } getCipherData record $ CipherData { cipherDataContent = content , cipherDataMAC = Just mac diff --git a/core/Network/TLS/Record/Engage.hs b/core/Network/TLS/Record/Engage.hs index 5058699..1c56933 100644 --- a/core/Network/TLS/Record/Engage.hs +++ b/core/Network/TLS/Record/Engage.hs @@ -29,7 +29,7 @@ engageRecord = compressRecord >=> encryptRecord compressRecord :: Record Plaintext -> RecordM (Record Compressed) compressRecord record = onRecordFragment record $ fragmentCompress $ \bytes -> do - withTxCompression $ compressionDeflate bytes + withCompression $ compressionDeflate bytes {- - when Tx Encrypted is set, we pass the data through encryptContent, otherwise @@ -38,20 +38,20 @@ compressRecord record = encryptRecord :: Record Compressed -> RecordM (Record Ciphertext) encryptRecord record = onRecordFragment record $ fragmentCipher $ \bytes -> do st <- get - case stCipher $ stTxState st of + case stCipher st of Nothing -> return bytes _ -> encryptContent record bytes encryptContent :: Record Compressed -> ByteString -> RecordM ByteString encryptContent record content = do - digest <- makeDigest True (recordToHeader record) content + digest <- makeDigest (recordToHeader record) content encryptData $ B.concat [content, digest] encryptData :: ByteString -> RecordM ByteString encryptData content = do - st <- get + tstate <- get + ver <- getRecordVersion - let tstate = stTxState st let cipher = fromJust "cipher" $ stCipher tstate let bulk = cipherBulk cipher let cst = stCryptState tstate @@ -71,14 +71,14 @@ encryptData content = do B.empty let e = encrypt writekey (cstIV cst) (B.concat [ content, padding ]) - if hasExplicitBlockIV $ stVersion st + if hasExplicitBlockIV ver then return $ B.concat [cstIV cst,e] else do let newiv = fromJust "new iv" $ takelast (bulkIVSize bulk) e - modifyTxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } } + put $ tstate { stCryptState = cst { cstIV = newiv } } return e BulkStreamF initF encryptF _ -> do let iv = cstIV cst let (e, newiv) = encryptF (if iv /= B.empty then iv else initF writekey) content - modifyTxState_ $ \txs -> txs { stCryptState = cst { cstIV = newiv } } + put $ tstate { stCryptState = cst { cstIV = newiv } } return e diff --git a/core/Network/TLS/Record/State.hs b/core/Network/TLS/Record/State.hs index 1f4098c..8599fe2 100644 --- a/core/Network/TLS/Record/State.hs +++ b/core/Network/TLS/Record/State.hs @@ -12,15 +12,11 @@ module Network.TLS.Record.State ( CryptState(..) , MacState(..) , TransmissionState(..) - , RecordState(..) - , newRecordState - , RecordM(..) - , withTxCompression - , withRxCompression - , modifyTxState - , modifyRxState - , modifyTxState_ - , modifyRxState_ + , newTransmissionState + , RecordM + , runRecordM + , getRecordVersion + , withCompression , computeDigest , makeDigest ) where @@ -36,7 +32,7 @@ import Network.TLS.Wire import Network.TLS.Packet import Network.TLS.MAC import Network.TLS.Util -import Network.TLS.Types (Role(..)) +import Network.TLS.Types (Direction(..)) import qualified Data.ByteString as B @@ -57,26 +53,40 @@ data TransmissionState = TransmissionState , stMacState :: !MacState } deriving (Show) -data RecordState = RecordState - { stClientContext :: Role - , stVersion :: !Version - , stTxState :: TransmissionState - , stRxState :: TransmissionState - } deriving (Show) +newtype RecordM a = RecordM { runRecordM :: Version + -> TransmissionState + -> Either TLSError (a, TransmissionState) } -newtype RecordM a = RecordM { runRecordM :: ErrorT TLSError (State RecordState) a } - deriving (Monad, MonadError TLSError) +instance Monad RecordM where + return a = RecordM $ \_ st -> Right (a, st) + m1 >>= m2 = RecordM $ \ver st -> do + case runRecordM m1 ver st of + Left err -> Left err + Right (a, st2) -> runRecordM (m2 a) ver st2 instance Functor RecordM where - fmap f = RecordM . fmap f . runRecordM + fmap f m = RecordM $ \ver st -> + case runRecordM m ver st of + Left err -> Left err + Right (a, st2) -> Right (f a, st2) -instance MonadState RecordState RecordM where - put x = RecordM (lift $ put x) - get = RecordM (lift get) +getRecordVersion :: RecordM Version +getRecordVersion = RecordM $ \ver st -> Right (ver, st) + +instance MonadState TransmissionState RecordM where + put x = RecordM $ \_ _ -> Right ((), x) + get = RecordM $ \_ st -> Right (st, st) #if MIN_VERSION_mtl(2,1,0) - state f = RecordM (lift $ state f) + state f = RecordM $ \_ st -> Right (f st) #endif +instance MonadError TLSError RecordM where + throwError e = RecordM $ \_ _ -> Left e + catchError m f = RecordM $ \ver st -> + case runRecordM m ver st of + Left err -> runRecordM (f err) ver st + r -> r + newTransmissionState :: TransmissionState newTransmissionState = TransmissionState { stCipher = Nothing @@ -89,39 +99,12 @@ incrTransmissionState :: TransmissionState -> TransmissionState incrTransmissionState ts = ts { stMacState = MacState (ms + 1) } where (MacState ms) = stMacState ts -newRecordState :: Role -> RecordState -newRecordState clientContext = RecordState - { stClientContext = clientContext - , stVersion = TLS10 - , stTxState = newTransmissionState - , stRxState = newTransmissionState - } - -modifyTxState :: (TransmissionState -> (TransmissionState, a)) -> RecordM a -modifyTxState f = - get >>= \st -> case f $ stTxState st of - (nst, a) -> put (st { stTxState = nst }) >> return a - -modifyTxState_ :: (TransmissionState -> TransmissionState) -> RecordM () -modifyTxState_ f = modifyTxState (\t -> (f t, ())) - -modifyRxState :: (TransmissionState -> (TransmissionState, a)) -> RecordM a -modifyRxState f = - get >>= \st -> case f $ stRxState st of - (nst, a) -> put (st { stRxState = nst }) >> return a - -modifyRxState_ :: (TransmissionState -> TransmissionState) -> RecordM () -modifyRxState_ f = modifyRxState (\t -> (f t, ())) - -modifyCompression :: TransmissionState -> (Compression -> (Compression, a)) -> (TransmissionState, a) -modifyCompression tst f = case f (stCompression tst) of - (nc, a) -> (tst { stCompression = nc }, a) - -withTxCompression :: (Compression -> (Compression, a)) -> RecordM a -withTxCompression f = modifyTxState $ \tst -> modifyCompression tst f - -withRxCompression :: (Compression -> (Compression, a)) -> RecordM a -withRxCompression f = modifyRxState $ \tst -> modifyCompression tst f +withCompression :: (Compression -> (Compression, a)) -> RecordM a +withCompression f = do + st <- get + let (nc, a) = f $ stCompression st + put $ st { stCompression = nc } + return a computeDigest :: Version -> TransmissionState -> Header -> Bytes -> (Bytes, TransmissionState) computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate) @@ -135,12 +118,10 @@ computeDigest ver tstate hdr content = (digest, incrTransmissionState tstate) | ver < TLS10 = (macSSL hashf, B.concat [ encodedSeq, encodeHeaderNoVer hdr, content ]) | otherwise = (hmac hashf 64, B.concat [ encodedSeq, encodeHeader hdr, content ]) -makeDigest :: Bool -> Header -> Bytes -> RecordM Bytes -makeDigest w hdr content = do +makeDigest :: Header -> Bytes -> RecordM Bytes +makeDigest hdr content = do + ver <- getRecordVersion st <- get - let (digest, nstate) = computeDigest (stVersion st) - (if w then stTxState st else stRxState st) hdr content - put $ if w - then st { stTxState = nstate } - else st { stRxState = nstate } + let (digest, nstate) = computeDigest ver st hdr content + put nstate return digest diff --git a/core/Network/TLS/Sending.hs b/core/Network/TLS/Sending.hs index f4d8453..9905a1d 100644 --- a/core/Network/TLS/Sending.hs +++ b/core/Network/TLS/Sending.hs @@ -32,7 +32,7 @@ import Network.TLS.Cipher -- this doesn't change any state makeRecord :: Packet -> RecordM (Record Plaintext) makeRecord pkt = do - ver <- stVersion <$> get + ver <- getRecordVersion return $ Record (packetType pkt) ver (fragmentPlaintext $ writePacketContent pkt) where writePacketContent (Handshake hss) = encodeHandshakes hss writePacketContent (Alert a) = encodeAlerts a @@ -67,17 +67,14 @@ prepareRecord :: RecordM a -> TLSSt a prepareRecord f = do st <- get ver <- getVersion - let sz = case stCipher $ stTxState $ stRecordState st of + let sz = case stCipher $ stTxState st of Nothing -> 0 Just cipher -> bulkIVSize $ cipherBulk cipher if hasExplicitBlockIV ver && sz > 0 then do newIV <- genRandom sz - runRecordStateSt $ modify $ \rts -> - let ts = stTxState rts - nts = ts { stCryptState = (stCryptState ts) { cstIV = newIV } } - in rts { stTxState = nts } - runRecordStateSt f - else runRecordStateSt f + runTxState (modify $ \ts -> ts { stCryptState = (stCryptState ts) { cstIV = newIV } }) + runTxState f + else runTxState f {------------------------------------------------------------------------------} {- SENDING Helpers -} diff --git a/core/Network/TLS/State.hs b/core/Network/TLS/State.hs index 3ca546c..911d587 100644 --- a/core/Network/TLS/State.hs +++ b/core/Network/TLS/State.hs @@ -12,14 +12,13 @@ module Network.TLS.State ( TLSState(..) , TLSSt - , RecordState(..) - , getRecordState , runTLSState - , runRecordStateSt , HandshakeState(..) , withHandshakeM , newTLSState , withTLSRNG + , runTxState + , runRxState , genRandom , assert -- FIXME move somewhere else (Internal.hs ?) , updateVerifiedData @@ -77,7 +76,8 @@ data TLSState = TLSState { stHandshake :: !(Maybe HandshakeState) , stSession :: Session , stSessionResuming :: Bool - , stRecordState :: RecordState + , stTxState :: TransmissionState + , stRxState :: TransmissionState , stSecureRenegotiation :: Bool -- RFC 5746 , stClientVerifiedData :: Bytes -- RFC 5746 , stServerVerifiedData :: Bytes -- RFC 5746 @@ -86,6 +86,8 @@ data TLSState = TLSState , stServerNextProtocolSuggest :: Maybe [B.ByteString] , stClientCertificateChain :: Maybe CertificateChain , stRandomGen :: StateRNG + , stVersion :: Version + , stClientContext :: Role } deriving (Show) newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a } @@ -104,29 +106,27 @@ instance MonadState TLSState TLSSt where runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState) runTLSState f st = runState (runErrorT (runTLSSt f)) st -getRecordState :: MonadState TLSState m => (RecordState -> a) -> m a -getRecordState f = gets (f . stRecordState) - -runRecordState :: RecordM a -> TLSState -> (Either TLSError a, TLSState) -runRecordState f st = - let (r, nrst) = runState (runErrorT (runRecordM f)) (stRecordState st) - in case r of - Left _ -> (r, st) - Right _ -> (r, st { stRecordState = nrst }) - -runRecordStateSt :: RecordM a -> TLSSt a -runRecordStateSt f = do +runTxState :: RecordM a -> TLSSt a +runTxState f = do st <- get - case runRecordState f st of - (Left e, _) -> throwError e - (Right a, newSt) -> put newSt >> return a + case runRecordM f (stVersion st) (stTxState st) of + Left err -> throwError err + Right (a, newSt) -> put (st { stTxState = newSt }) >> return a + +runRxState :: RecordM a -> TLSSt a +runRxState f = do + st <- get + case runRecordM f (stVersion st) (stRxState st) of + Left err -> throwError err + Right (a, newSt) -> put (st { stRxState = newSt }) >> return a newTLSState :: CPRG g => g -> Role -> TLSState newTLSState rng clientContext = TLSState { stHandshake = Nothing , stSession = Session Nothing , stSessionResuming = False - , stRecordState = newRecordState clientContext + , stTxState = newTransmissionState + , stRxState = newTransmissionState , stSecureRenegotiation = False , stClientVerifiedData = B.empty , stServerVerifiedData = B.empty @@ -135,6 +135,8 @@ newTLSState rng clientContext = TLSState , stServerNextProtocolSuggest = Nothing , stClientCertificateChain = Nothing , stRandomGen = StateRNG rng + , stVersion = TLS10 + , stClientContext = clientContext } updateVerifiedData :: MonadState TLSState m => Role -> Bytes -> m () @@ -179,17 +181,17 @@ certVerifyHandshakeMaterial = certVerifyHandshakeTypeMaterial . typeOfHandshake switchTxEncryption, switchRxEncryption :: TLSSt () switchTxEncryption = withHandshakeM (gets hstPendingTxState) - >>= \newTxState -> runRecordStateSt (modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState }) + >>= \newTxState -> modify $ \st -> st { stTxState = fromJust "pending-tx" newTxState } switchRxEncryption = withHandshakeM (gets hstPendingRxState) - >>= \newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState }) + >>= \newRxState -> modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState } getSessionData :: MonadState TLSState m => m (Maybe SessionData) getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st) where wrapSessionData st masterSecret = do return $ SessionData - { sessionVersion = stVersion $ stRecordState st - , sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ stRecordState st + { sessionVersion = stVersion st + , sessionCipher = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ st , sessionSecret = masterSecret } @@ -202,17 +204,17 @@ getSession = gets stSession isSessionResuming :: MonadState TLSState m => m Bool isSessionResuming = gets stSessionResuming -needEmptyPacket :: MonadState RecordState m => m Bool +needEmptyPacket :: MonadState TLSState m => m Bool needEmptyPacket = gets f where f st = (stVersion st <= TLS10) && stClientContext st == ClientRole && (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st)) setVersion :: MonadState TLSState m => Version -> m () -setVersion ver = modify (\st -> st { stRecordState = (stRecordState st) { stVersion = ver } }) +setVersion ver = modify (\st -> st { stVersion = ver }) getVersion :: MonadState TLSState m => m Version -getVersion = gets (stVersion . stRecordState) +getVersion = gets stVersion setSecureRenegotiation :: MonadState TLSState m => Bool -> m () setSecureRenegotiation b = modify (\st -> st { stSecureRenegotiation = b }) @@ -248,7 +250,7 @@ getVerifiedData :: MonadState TLSState m => Role -> m Bytes getVerifiedData client = gets (if client == ClientRole then stClientVerifiedData else stServerVerifiedData) isClientContext :: MonadState TLSState m => m Role -isClientContext = getRecordState stClientContext +isClientContext = gets stClientContext startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m () startHandshakeClient ver crand = do diff --git a/core/Network/TLS/Types.hs b/core/Network/TLS/Types.hs index 0682f8b..31a6ca7 100644 --- a/core/Network/TLS/Types.hs +++ b/core/Network/TLS/Types.hs @@ -13,6 +13,7 @@ module Network.TLS.Types , CompressionID , Role(..) , invertRole + , Direction(..) ) where import Data.ByteString (ByteString) @@ -43,6 +44,10 @@ type CompressionID = Word8 data Role = ClientRole | ServerRole deriving (Show,Eq) +-- | Direction +data Direction = Tx | Rx + deriving (Show,Eq) + invertRole :: Role -> Role invertRole ClientRole = ServerRole invertRole ServerRole = ClientRole