hs-tls/core/Network/TLS/State.hs

364 lines
15 KiB
Haskell
Raw Normal View History

{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts, MultiParamTypeClasses, ExistentialQuantification, RankNTypes, CPP #-}
2010-09-09 21:47:19 +00:00
-- |
-- Module : Network.TLS.State
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : experimental
-- Portability : unknown
--
-- the State module contains calls related to state initialization/manipulation,
-- which are used by the Receiving module and the Sending module.
--
module Network.TLS.State
2013-07-12 06:27:28 +00:00
( TLSState(..)
, TLSSt
, RecordState(..)
, getRecordState
, runTLSState
, runRecordStateSt
2013-07-18 06:19:05 +00:00
, HandshakeState(..)
, withHandshakeM
2013-07-12 06:27:28 +00:00
, newTLSState
, withTLSRNG
2013-07-13 07:03:25 +00:00
, genRandom
2013-07-12 06:27:28 +00:00
, assert -- FIXME move somewhere else (Internal.hs ?)
, updateVerifiedData
, finishHandshakeTypeMaterial
, finishHandshakeMaterial
, certVerifyHandshakeTypeMaterial
, certVerifyHandshakeMaterial
, setMasterSecret
, setMasterSecretFromPre
, getMasterSecret
, setKeyBlock
, setVersion
2013-07-13 07:03:25 +00:00
, getVersion
2013-07-12 06:27:28 +00:00
, setCipher
, setServerRandom
, setSecureRenegotiation
, getSecureRenegotiation
, setExtensionNPN
, getExtensionNPN
, setNegotiatedProtocol
, getNegotiatedProtocol
, setServerNextProtocolSuggest
, getServerNextProtocolSuggest
, getClientCertificateChain
, setClientCertificateChain
, getVerifiedData
, setSession
, getSession
, getSessionData
, isSessionResuming
, needEmptyPacket
, switchTxEncryption
, switchRxEncryption
, getCipherKeyExchangeType
, isClientContext
, startHandshakeClient
, getHandshakeDigest
, endHandshake
) where
2010-09-09 21:47:19 +00:00
import Data.Maybe (isNothing)
import Network.TLS.Util
2010-09-09 21:47:19 +00:00
import Network.TLS.Struct
import Network.TLS.Packet
import Network.TLS.Crypto
import Network.TLS.Cipher
2013-07-13 07:03:25 +00:00
import Network.TLS.Record.State
2013-07-18 06:19:05 +00:00
import Network.TLS.Handshake.State
2013-07-13 07:03:25 +00:00
import Network.TLS.RNG
import Network.TLS.Types (Role(..))
import qualified Data.ByteString as B
2011-07-07 21:21:23 +00:00
import Control.Applicative ((<$>))
2010-09-09 21:47:19 +00:00
import Control.Monad
import Control.Monad.State
import Control.Monad.Error
import Crypto.Random.API
2013-05-19 07:05:46 +00:00
import Data.X509 (CertificateChain)
2010-09-09 21:47:19 +00:00
-- | Walk a list of named flags in order and 'fail' on the first one that is
-- 'True'.
--
-- NOTE: despite the name, each 'Bool' is a /failure/ condition, not an
-- assumption that must hold.  An entry whose flag is 'True' aborts with
-- @\<fctname\>: assumption about \<name\> failed@.  'hasValidHandshake'
-- relies on this: it passes @isNothing@ of the handshake.
assert :: Monad m => String -> [(String,Bool)] -> m ()
assert fctname = mapM_ check
  where
    check (name, violated)
        | violated  = fail (fctname ++ ": assumption about " ++ name ++ " failed")
        | otherwise = return ()
2010-09-09 21:47:19 +00:00
-- | Connection-wide TLS state threaded through 'TLSSt'.
data TLSState = TLSState
    { stHandshake                 :: !(Maybe HandshakeState)  -- ^ in-progress handshake; 'Nothing' when none is active
    , stSession                   :: Session                  -- ^ current session
    , stSessionResuming           :: Bool                     -- ^ whether 'stSession' is being resumed
    , stRecordState               :: RecordState              -- ^ record-layer state (version, role, tx/rx states)
    , stSecureRenegotiation       :: Bool                     -- ^ RFC 5746 secure renegotiation flag
    , stClientVerifiedData        :: Bytes                    -- ^ RFC 5746 client verify data
    , stServerVerifiedData        :: Bytes                    -- ^ RFC 5746 server verify data
    , stExtensionNPN              :: Bool                     -- ^ NPN draft extension flag
    , stNegotiatedProtocol        :: Maybe B.ByteString       -- ^ NPN negotiated protocol
    , stServerNextProtocolSuggest :: Maybe [B.ByteString]     -- ^ NPN protocol list suggested by the server
    , stClientCertificateChain    :: Maybe CertificateChain   -- ^ certificate chain recorded for the client
    } deriving (Show)
2010-09-09 21:47:19 +00:00
-- | The TLS state monad: 'State' over 'TLSState', wrapped in 'ErrorT' so
-- that computations can abort with a 'TLSError'.
newtype TLSSt a = TLSSt { runTLSSt :: ErrorT TLSError (State TLSState) a }
    deriving (Monad, MonadError TLSError)

instance Functor TLSSt where
    fmap f (TLSSt m) = TLSSt (fmap f m)

-- State operations are lifted through the 'ErrorT' layer into the
-- underlying 'State TLSState'.
instance MonadState TLSState TLSSt where
    get   = TLSSt (lift get)
    put   = TLSSt . lift . put
#if MIN_VERSION_mtl(2,1,0)
    state = TLSSt . lift . state
#endif
-- | Run a 'TLSSt' computation from an initial state.  Returns its result,
-- or the error that aborted it, together with the final state.
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState = runState . runErrorT . runTLSSt
2010-09-09 21:47:19 +00:00
-- | Read a value projected out of the record-layer state.
getRecordState :: MonadState TLSState m => (RecordState -> a) -> m a
getRecordState f = liftM (f . stRecordState) get
-- | Run a 'RecordM' action against the record state embedded in a 'TLSState'.
--
-- On success the updated record state is written back.  On failure the
-- original 'TLSState' is returned untouched.
runRecordState :: RecordM a -> TLSState -> (Either TLSError a, TLSState)
runRecordState f st =
    case runState (runErrorT (runRecordM f)) (stRecordState st) of
        (r@(Left _),  _)    -> (r, st)
        (r@(Right _), nrst) -> (r, st { stRecordState = nrst })
-- | Lift a 'RecordM' action into 'TLSSt'.  Its error is rethrown in
-- 'TLSSt'; on success the updated state is committed.
runRecordStateSt :: RecordM a -> TLSSt a
runRecordStateSt f = do
    (result, newSt) <- liftM (runRecordState f) get
    case result of
        Left err -> throwError err
        Right a  -> put newSt >> return a
-- | Build a fresh 'TLSState' for the given role.
--
-- The new state has no handshake in progress, an empty session, empty
-- RFC 5746 verified data and no NPN data.  The RNG and role are handed to
-- 'newRecordState'.
newTLSState :: CPRG g => g -> Role -> TLSState
newTLSState rng role = TLSState
    { stHandshake                 = Nothing
    , stSession                   = Session Nothing
    , stSessionResuming           = False
    , stRecordState               = newRecordState rng role
    , stSecureRenegotiation       = False
    , stClientVerifiedData        = B.empty
    , stServerVerifiedData        = B.empty
    , stExtensionNPN              = False
    , stNegotiatedProtocol        = Nothing
    , stServerNextProtocolSuggest = Nothing
    , stClientCertificateChain    = Nothing
    }
-- | Store RFC 5746 verify data in either the client or the server slot.
--
-- The slot is chosen by comparing @sending@ with this context's own role
-- ('isClientContext'):
--
-- * equal roles store into 'stClientVerifiedData';
-- * different roles store into 'stServerVerifiedData'.
--
-- NOTE(review): this mapping only makes sense if @sending@ is used as a
-- direction marker rather than as the peer's role — confirm against the
-- callers.
updateVerifiedData :: MonadState TLSState m => Role -> Bytes -> m ()
updateVerifiedData sending bs = isClientContext >>= \cc ->
    modify $ if cc == sending
        then \st -> st { stClientVerifiedData = bs }
        else \st -> st { stServerVerifiedData = bs }
2010-09-09 21:47:19 +00:00
-- | Classify which handshake message types count as "finish material".
-- Every listed type does except HelloRequest.
-- NOTE(review): presumably these are the messages fed into the Finished
-- hash — confirm at the call sites.
finishHandshakeTypeMaterial :: HandshakeType -> Bool
finishHandshakeTypeMaterial ty = case ty of
    HandshakeType_HelloRequest    -> False
    HandshakeType_ClientHello     -> True
    HandshakeType_ServerHello     -> True
    HandshakeType_Certificate     -> True
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg   -> True
    HandshakeType_ServerKeyXchg   -> True
    HandshakeType_CertRequest     -> True
    HandshakeType_CertVerify      -> True
    HandshakeType_Finished        -> True
    HandshakeType_NPN             -> True
2010-09-09 21:47:19 +00:00
-- | 'finishHandshakeTypeMaterial' applied to a message's handshake type.
finishHandshakeMaterial :: Handshake -> Bool
finishHandshakeMaterial hs = finishHandshakeTypeMaterial (typeOfHandshake hs)
-- | Classify which handshake message types count as "certificate verify
-- material".  HelloRequest, CertVerify, Finished and NPN do not; every other
-- listed type does.
certVerifyHandshakeTypeMaterial :: HandshakeType -> Bool
certVerifyHandshakeTypeMaterial ty = case ty of
    HandshakeType_ClientHello     -> True
    HandshakeType_ServerHello     -> True
    HandshakeType_Certificate     -> True
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg   -> True
    HandshakeType_ServerKeyXchg   -> True
    HandshakeType_CertRequest     -> True
    HandshakeType_HelloRequest    -> False
    HandshakeType_CertVerify      -> False
    HandshakeType_Finished        -> False
    HandshakeType_NPN             -> False
-- | 'certVerifyHandshakeTypeMaterial' applied to a message's handshake type.
certVerifyHandshakeMaterial :: Handshake -> Bool
certVerifyHandshakeMaterial hs = certVerifyHandshakeTypeMaterial (typeOfHandshake hs)
-- | Promote the pending transmit (resp. receive) state prepared in the
-- handshake state (see 'setKeyBlock') to be the active record-layer state.
--
-- 'withHandshakeM' fails with "handshake missing" if no handshake is in
-- progress.  If no pending state was prepared, the computation aborts via
-- 'fail'.  In 'TLSSt' that 'fail' becomes a 'TLSError' through 'ErrorT'.
-- The previous @fromJust@ raised an 'error' that escaped that channel and,
-- because the record update is lazy, could surface only later.
switchTxEncryption, switchRxEncryption :: TLSSt ()
switchTxEncryption =
    withHandshakeM (gets hstPendingTxState)
        >>= maybe (fail "switchTxEncryption: no pending tx state")
                  (\newTxState -> runRecordStateSt (modify $ \st -> st { stTxState = newTxState }))
switchRxEncryption =
    withHandshakeM (gets hstPendingRxState)
        >>= maybe (fail "switchRxEncryption: no pending rx state")
                  (\newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = newRxState }))
2010-09-09 21:47:19 +00:00
-- | Record the server random in the in-progress handshake state.
-- Goes through 'updateHandshake', tagged "srand".
setServerRandom :: MonadState TLSState m => ServerRandom -> m ()
setServerRandom ran = updateHandshake "srand" withRandom
  where
    withRandom hst = hst { hstServerRandom = Just ran }
-- | Store the master secret in the handshake state, then immediately derive
-- the pending tx/rx states from it with 'setKeyBlock'.
setMasterSecret :: Version -> Role -> Bytes -> HandshakeM ()
setMasterSecret ver role masterSecret =
    modify (\hst -> hst { hstMasterSecret = Just masterSecret })
        >> setKeyBlock ver role
2010-09-09 21:47:19 +00:00
-- | Derive the master secret from the pre-master secret and both randoms,
-- then install it with 'setMasterSecret'.
--
-- The server random must already be set; otherwise @fromJust@ errors when
-- the secret is evaluated.
setMasterSecretFromPre :: Version -> Role -> Bytes -> HandshakeM ()
setMasterSecretFromPre ver role premasterSecret =
    gets deriveSecret >>= setMasterSecret ver role
  where
    deriveSecret hst =
        generateMasterSecret ver premasterSecret
                             (hstClientRandom hst)
                             (fromJust "server random" $ hstServerRandom hst)
2012-07-28 12:22:16 +00:00
-- | The master secret of the in-progress handshake.  'Nothing' if there is
-- no handshake or the secret has not been set yet.
getMasterSecret :: MonadState TLSState m => m (Maybe Bytes)
getMasterSecret = gets $ \st -> stHandshake st >>= hstMasterSecret
2012-07-28 12:22:16 +00:00
2011-12-20 07:38:35 +00:00
-- | Snapshot the session parameters (version, cipher id, master secret).
--
-- Returns 'Nothing' when there is no handshake or no master secret.  The
-- cipher is read from the active transmit state; @fromJust@ errors on
-- evaluation if none is set.
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
getSessionData = gets toSessionData
  where
    toSessionData st = case stHandshake st >>= hstMasterSecret of
        Nothing     -> Nothing
        Just secret -> Just (SessionData
            { sessionVersion = stVersion (stRecordState st)
            , sessionCipher  = cipherID $ fromJust "cipher" $ stCipher $ stTxState $ stRecordState st
            , sessionSecret  = secret
            })
2011-12-20 07:38:35 +00:00
-- | Set the current session and whether it is being resumed.
setSession :: MonadState TLSState m => Session -> Bool -> m ()
setSession session resuming = modify $ \st ->
    st { stSession = session, stSessionResuming = resuming }

-- | The current session.
getSession :: MonadState TLSState m => m Session
getSession = liftM stSession get

-- | Whether the current session is being resumed.
isSessionResuming :: MonadState TLSState m => m Bool
isSessionResuming = liftM stSessionResuming get
-- | Whether an empty packet is needed.  True only when all of these hold:
--
-- * the version is TLS 1.0 or older;
-- * this side is the client;
-- * the active transmit cipher's bulk algorithm has a block size > 0.
--
-- NOTE(review): this looks like the CBC/BEAST empty-fragment countermeasure
-- — confirm.
needEmptyPacket :: MonadState RecordState m => m Bool
needEmptyPacket = gets needed
  where
    needed st = and
        [ stVersion st <= TLS10
        , stClientContext st == ClientRole
        , maybe False usesBlockCipher (stCipher $ stTxState st)
        ]
    usesBlockCipher c = bulkBlockSize (cipherBulk c) > 0
-- | Derive the pending transmit and receive states and store them in
-- 'hstPendingTxState' / 'hstPendingRxState'.
--
-- Inputs are the handshake's pending cipher, both randoms and the master
-- secret.  The generated key block is split, in order, into:
--
--   client MAC secret, server MAC secret, client write key,
--   server write key, client IV, server IV
--
-- MAC secrets are 'hashSize' of the cipher's hash; keys and IVs use the
-- bulk algorithm's sizes.  @cc@ is this side's role: a client transmits
-- with the client material and receives with the server material; a server
-- does the opposite.  Both MAC sequence numbers start at 0.
--
-- Errors (via @fromJust@) if the pending cipher, server random or master
-- secret is missing, or if 'partition6' cannot split the key block.
setKeyBlock :: Version -> Role -> HandshakeM ()
setKeyBlock ver cc = modify $ \hst ->
    let cipher   = fromJust "cipher" $ hstPendingCipher hst
        bulk     = cipherBulk cipher
        macLen   = hashSize $ cipherHash cipher
        keyLen   = bulkKeySize bulk
        ivLen    = bulkIVSize bulk
        keyBlock = generateKeyBlock ver
                                    (hstClientRandom hst)
                                    (fromJust "server random" $ hstServerRandom hst)
                                    (fromJust "master secret" $ hstMasterSecret hst)
                                    (cipherKeyBlockSize cipher)
        (cMac, sMac, cKey, sKey, cIV, sIV) =
            fromJust "p6" $ partition6 keyBlock (macLen, macLen, keyLen, keyLen, ivLen, ivLen)

        clientSide = ( CryptState { cstKey = cKey, cstIV = cIV, cstMacSecret = cMac }
                     , MacState { msSequence = 0 } )
        serverSide = ( CryptState { cstKey = sKey, cstIV = sIV, cstMacSecret = sMac }
                     , MacState { msSequence = 0 } )

        -- our own material is used for sending, the peer's for receiving
        (txSide, rxSide)
            | cc == ClientRole = (clientSide, serverSide)
            | otherwise        = (serverSide, clientSide)

        mkState (crypt, mac) = TransmissionState
            { stCryptState  = crypt
            , stMacState    = mac
            , stCipher      = Just cipher
            , stCompression = hstPendingCompression hst
            }
    in hst { hstPendingTxState = Just (mkState txSide)
           , hstPendingRxState = Just (mkState rxSide)
           }
-- | Set the handshake's pending cipher; 'setKeyBlock' reads it.
setCipher :: Cipher -> HandshakeM ()
setCipher cipher = modify $ \hst -> hst { hstPendingCipher = Just cipher }
2010-09-09 21:47:19 +00:00
-- | Set the protocol version held in the record-layer state.
setVersion :: MonadState TLSState m => Version -> m ()
setVersion ver = modify $ \st ->
    let rs = stRecordState st
    in st { stRecordState = rs { stVersion = ver } }

-- | The protocol version held in the record-layer state.
getVersion :: MonadState TLSState m => m Version
getVersion = liftM (stVersion . stRecordState) get
-- | Set the RFC 5746 secure renegotiation flag.
setSecureRenegotiation :: MonadState TLSState m => Bool -> m ()
setSecureRenegotiation enabled = modify $ \st -> st { stSecureRenegotiation = enabled }

-- | The RFC 5746 secure renegotiation flag.
getSecureRenegotiation :: MonadState TLSState m => m Bool
getSecureRenegotiation = liftM stSecureRenegotiation get
-- | Set the NPN extension flag.
setExtensionNPN :: MonadState TLSState m => Bool -> m ()
setExtensionNPN enabled = modify $ \st -> st { stExtensionNPN = enabled }

-- | The NPN extension flag.
getExtensionNPN :: MonadState TLSState m => m Bool
getExtensionNPN = liftM stExtensionNPN get
-- | Record the NPN negotiated protocol.
setNegotiatedProtocol :: MonadState TLSState m => B.ByteString -> m ()
setNegotiatedProtocol proto = modify $ \st -> st { stNegotiatedProtocol = Just proto }

-- | The NPN negotiated protocol, if one was recorded.
getNegotiatedProtocol :: MonadState TLSState m => m (Maybe B.ByteString)
getNegotiatedProtocol = liftM stNegotiatedProtocol get
-- | Record the list of protocols suggested by the server (NPN).
setServerNextProtocolSuggest :: MonadState TLSState m => [B.ByteString] -> m ()
setServerNextProtocolSuggest protos = modify $ \st -> st { stServerNextProtocolSuggest = Just protos }

-- | The server-suggested protocol list, if one was recorded.
getServerNextProtocolSuggest :: MonadState TLSState m => m (Maybe [B.ByteString])
getServerNextProtocolSuggest = gets stServerNextProtocolSuggest
2013-05-19 07:05:46 +00:00
-- | Record the client certificate chain.
setClientCertificateChain :: MonadState TLSState m => CertificateChain -> m ()
setClientCertificateChain chain = modify $ \st -> st { stClientCertificateChain = Just chain }

-- | The recorded client certificate chain, if any.
getClientCertificateChain :: MonadState TLSState m => m (Maybe CertificateChain)
getClientCertificateChain = liftM stClientCertificateChain get
-- | Key exchange type of the handshake's pending cipher, if one is set.
getCipherKeyExchangeType :: HandshakeM (Maybe CipherKeyExchangeType)
getCipherKeyExchangeType = gets (fmap cipherKeyExchange . hstPendingCipher)
-- | RFC 5746 verify data: the client's when the argument is 'True', the
-- server's otherwise.
getVerifiedData :: MonadState TLSState m => Bool -> m Bytes
getVerifiedData client
    | client    = gets stClientVerifiedData
    | otherwise = gets stServerVerifiedData
-- | The role of this context, read from the record-layer state.
-- Despite the name, the result is a 'Role', not a 'Bool'.
isClientContext :: MonadState TLSState m => m Role
isClientContext = gets (stClientContext . stRecordState)
2010-09-09 21:47:19 +00:00
-- | Start a new handshake state for the given version and client random.
--
-- If a handshake is already in progress, nothing changes.  The context
-- passed to 'newEmptyHandshake' is 'hashMD5SHA1' before TLS 1.2 and
-- 'hashSHA256' from TLS 1.2 on.
startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m ()
startHandshakeClient ver crand = do
    -- FIXME: an already-present handshake is silently kept; decide whether
    -- that should be reported instead.
    existing <- gets stHandshake
    case existing of
        Just _  -> return ()
        Nothing -> modify $ \st -> st { stHandshake = Just (newEmptyHandshake ver crand hashCtx) }
  where
    hashCtx
        | ver < TLS12 = hashMD5SHA1
        | otherwise   = hashSHA256
2010-09-09 21:47:19 +00:00
-- | Fail (via 'assert', tagged with @name@) when no handshake is in
-- progress.
hasValidHandshake :: MonadState TLSState m => String -> m ()
hasValidHandshake name = do
    st <- get
    -- 'assert' fails when the flag is True, i.e. when the handshake is missing
    assert name [ ("valid handshake", isNothing (stHandshake st)) ]
2010-09-09 21:47:19 +00:00
2013-07-19 06:05:37 +00:00
-- | Apply @f@ to the in-progress handshake state.  First checks, via
-- 'hasValidHandshake' tagged with @n@, that a handshake exists.
updateHandshake :: MonadState TLSState m => String -> (HandshakeState -> HandshakeState) -> m ()
updateHandshake n f =
    hasValidHandshake n >> modify (\st -> st { stHandshake = fmap f (stHandshake st) })
2010-09-09 21:47:19 +00:00
-- | Run a 'HandshakeM' action against the in-progress handshake state and
-- store the updated handshake back.  Fails with "handshake missing" when
-- there is no handshake.
withHandshakeM :: MonadState TLSState m => HandshakeM a -> m a
withHandshakeM f = get >>= \st ->
    maybe (fail "handshake missing") (runOn st) (stHandshake st)
  where
    runOn st hst =
        let (a, nhst) = runHandshake hst f
        in put (st { stHandshake = Just nhst }) >> return a
-- | Compute the Finished data for the in-progress handshake: the client's
-- ('generateClientFinished') when the argument is 'True', the server's
-- ('generateServerFinished') otherwise.
--
-- Uses the record-layer version, the handshake's master secret and its
-- handshake digest.  Fails via 'fail' when no handshake is in progress or
-- its master secret is not set, consistent with 'withHandshakeM'.  In
-- 'TLSSt' that 'fail' becomes a 'TLSError' through 'ErrorT', unlike the
-- 'error' that @fromJust@ would raise.
getHandshakeDigest :: MonadState TLSState m => Bool -> m Bytes
getHandshakeDigest client = do
    st      <- get
    hst     <- maybe (fail "getHandshakeDigest: handshake missing") return (stHandshake st)
    msecret <- maybe (fail "getHandshakeDigest: master secret missing") return (hstMasterSecret hst)
    let generate = if client then generateClientFinished else generateServerFinished
    return $ generate (stVersion $ stRecordState st) msecret (hstHandshakeDigest hst)
2010-09-09 21:47:19 +00:00
-- | Discard the in-progress handshake state.
endHandshake :: MonadState TLSState m => m ()
endHandshake = modify clearHandshake
  where
    clearHandshake st = st { stHandshake = Nothing }
2013-07-13 07:03:25 +00:00
-- | Generate @n@ random bytes by running 'genTLSRandom' against the
-- record-layer state.
genRandom :: Int -> TLSSt Bytes
genRandom = runRecordStateSt . genTLSRandom