more cleanup / separation with handshake state.
This commit is contained in:
parent
0fdfd3d104
commit
acc670e30e
5 changed files with 23 additions and 32 deletions
|
@ -90,14 +90,17 @@ handshakeClient cparams ctx = do
|
||||||
recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello sentExts)
|
recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello sentExts)
|
||||||
|
|
||||||
onServerHello :: MonadIO m => [ExtensionID] -> Handshake -> m (RecvState m)
|
onServerHello :: MonadIO m => [ExtensionID] -> Handshake -> m (RecvState m)
|
||||||
onServerHello sentExts sh@(ServerHello rver _ serverSession cipher _ exts) = do
|
onServerHello sentExts sh@(ServerHello rver _ serverSession cipher compression exts) = do
|
||||||
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
|
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
|
||||||
case find ((==) rver) allowedvers of
|
case find ((==) rver) allowedvers of
|
||||||
Nothing -> throwCore $ Error_Protocol ("version " ++ show rver ++ "is not supported", True, ProtocolVersion)
|
Nothing -> throwCore $ Error_Protocol ("version " ++ show rver ++ "is not supported", True, ProtocolVersion)
|
||||||
Just _ -> usingState_ ctx $ setVersion rver
|
Just _ -> usingState_ ctx $ setVersion rver
|
||||||
case find ((==) cipher . cipherID) ciphers of
|
-- find the compression and cipher methods that the server want to use.
|
||||||
Nothing -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure)
|
case (find ((==) cipher . cipherID) ciphers, find ((==) compression . compressionID) compressions) of
|
||||||
Just c -> usingHState ctx $ setCipher c
|
(Nothing,_) -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure)
|
||||||
|
(_,Nothing) -> throwCore $ Error_Protocol ("no compression in common with the server", True, HandshakeFailure)
|
||||||
|
(Just cipherAlg, Just compressAlg) ->
|
||||||
|
usingHState ctx $ setPendingAlgs cipherAlg compressAlg
|
||||||
|
|
||||||
-- intersect sent extensions in client and the received extensions from server.
|
-- intersect sent extensions in client and the received extensions from server.
|
||||||
-- if server returns extensions that we didn't request, fail.
|
-- if server returns extensions that we didn't request, fail.
|
||||||
|
|
|
@ -96,9 +96,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
|
||||||
when (null commonCompressions) $
|
when (null commonCompressions) $
|
||||||
throwCore $ Error_Protocol ("no compression in common with the client", True, HandshakeFailure)
|
throwCore $ Error_Protocol ("no compression in common with the client", True, HandshakeFailure)
|
||||||
usingState_ ctx $ setVersion ver
|
usingState_ ctx $ setVersion ver
|
||||||
usingHState ctx $ do
|
usingHState ctx $ setPendingAlgs usedCipher usedCompression
|
||||||
setCipher usedCipher
|
|
||||||
modify (\hst -> hst { hstPendingCompression = usedCompression })
|
|
||||||
|
|
||||||
resumeSessionData <- case clientSession of
|
resumeSessionData <- case clientSession of
|
||||||
(Session (Just clientSessionId)) -> withSessionManager params (\s -> liftIO $ sessionResume s clientSessionId)
|
(Session (Just clientSessionId)) -> withSessionManager params (\s -> liftIO $ sessionResume s clientSessionId)
|
||||||
|
@ -165,7 +163,8 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
|
||||||
, extensionEncode $ NextProtocolNegotiation protos) ]
|
, extensionEncode $ NextProtocolNegotiation protos) ]
|
||||||
Nothing -> return []
|
Nothing -> return []
|
||||||
let extensions = secRengExt ++ npnExt
|
let extensions = secRengExt ++ npnExt
|
||||||
usingState_ ctx (setVersion ver >> setServerRandom srand)
|
usingState_ ctx (setVersion ver)
|
||||||
|
usingHState ctx $ setServerRandom srand
|
||||||
return $ ServerHello ver srand session (cipherID usedCipher)
|
return $ ServerHello ver srand session (cipherID usedCipher)
|
||||||
(compressionID usedCompression) extensions
|
(compressionID usedCompression) extensions
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,9 @@ module Network.TLS.Handshake.State
|
||||||
-- * master secret
|
-- * master secret
|
||||||
, setMasterSecret
|
, setMasterSecret
|
||||||
, setMasterSecretFromPre
|
, setMasterSecretFromPre
|
||||||
|
-- * misc accessor
|
||||||
|
, setPendingAlgs
|
||||||
|
, setServerRandom
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Network.TLS.Util
|
import Network.TLS.Util
|
||||||
|
@ -211,3 +214,10 @@ computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
|
||||||
, stCipher = Just cipher
|
, stCipher = Just cipher
|
||||||
, stCompression = hstPendingCompression hst
|
, stCompression = hstPendingCompression hst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setPendingAlgs :: Cipher -> Compression -> HandshakeM ()
|
||||||
|
setPendingAlgs cipher compression =
|
||||||
|
modify $ \hst -> hst { hstPendingCipher = Just cipher, hstPendingCompression = compression }
|
||||||
|
|
||||||
|
setServerRandom :: ServerRandom -> HandshakeM ()
|
||||||
|
setServerRandom ran = modify $ \hst -> hst { hstServerRandom = Just ran }
|
||||||
|
|
|
@ -120,7 +120,7 @@ processServerHello (ServerHello sver ran _ _ _ ex) = do
|
||||||
-- secreneg <- getSecureRenegotiation
|
-- secreneg <- getSecureRenegotiation
|
||||||
-- when (secreneg && (isNothing $ lookup 0xff01 ex)) $ ...
|
-- when (secreneg && (isNothing $ lookup 0xff01 ex)) $ ...
|
||||||
mapM_ processServerExtension ex
|
mapM_ processServerExtension ex
|
||||||
setServerRandom ran
|
withHandshakeM $ setServerRandom ran
|
||||||
setVersion sver
|
setVersion sver
|
||||||
where processServerExtension (0xff01, content) = do
|
where processServerExtension (0xff01, content) = do
|
||||||
cv <- getVerifiedData True
|
cv <- getVerifiedData True
|
||||||
|
|
|
@ -29,8 +29,6 @@ module Network.TLS.State
|
||||||
, certVerifyHandshakeMaterial
|
, certVerifyHandshakeMaterial
|
||||||
, setVersion
|
, setVersion
|
||||||
, getVersion
|
, getVersion
|
||||||
, setCipher
|
|
||||||
, setServerRandom
|
|
||||||
, setSecureRenegotiation
|
, setSecureRenegotiation
|
||||||
, getSecureRenegotiation
|
, getSecureRenegotiation
|
||||||
, setExtensionNPN
|
, setExtensionNPN
|
||||||
|
@ -49,7 +47,6 @@ module Network.TLS.State
|
||||||
, needEmptyPacket
|
, needEmptyPacket
|
||||||
, switchTxEncryption
|
, switchTxEncryption
|
||||||
, switchRxEncryption
|
, switchRxEncryption
|
||||||
, getCipherKeyExchangeType
|
|
||||||
, isClientContext
|
, isClientContext
|
||||||
, startHandshakeClient
|
, startHandshakeClient
|
||||||
, getHandshakeDigest
|
, getHandshakeDigest
|
||||||
|
@ -67,7 +64,6 @@ import Network.TLS.Handshake.State
|
||||||
import Network.TLS.RNG
|
import Network.TLS.RNG
|
||||||
import Network.TLS.Types (Role(..))
|
import Network.TLS.Types (Role(..))
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString as B
|
||||||
import Control.Applicative ((<$>))
|
|
||||||
import Control.Monad
|
import Control.Monad
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Control.Monad.Error
|
import Control.Monad.Error
|
||||||
|
@ -187,9 +183,6 @@ switchRxEncryption =
|
||||||
withHandshakeM (gets hstPendingRxState)
|
withHandshakeM (gets hstPendingRxState)
|
||||||
>>= \newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState })
|
>>= \newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState })
|
||||||
|
|
||||||
setServerRandom :: MonadState TLSState m => ServerRandom -> m ()
|
|
||||||
setServerRandom ran = updateHandshake "srand" (\hst -> hst { hstServerRandom = Just ran })
|
|
||||||
|
|
||||||
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
||||||
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
||||||
where wrapSessionData st masterSecret = do
|
where wrapSessionData st masterSecret = do
|
||||||
|
@ -214,9 +207,6 @@ needEmptyPacket = gets f
|
||||||
&& stClientContext st == ClientRole
|
&& stClientContext st == ClientRole
|
||||||
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
|
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
|
||||||
|
|
||||||
setCipher :: Cipher -> HandshakeM ()
|
|
||||||
setCipher cipher = modify (\st -> st { hstPendingCipher = Just cipher })
|
|
||||||
|
|
||||||
setVersion :: MonadState TLSState m => Version -> m ()
|
setVersion :: MonadState TLSState m => Version -> m ()
|
||||||
setVersion ver = modify (\st -> st { stRecordState = (stRecordState st) { stVersion = ver } })
|
setVersion ver = modify (\st -> st { stRecordState = (stRecordState st) { stVersion = ver } })
|
||||||
|
|
||||||
|
@ -253,9 +243,6 @@ setClientCertificateChain s = modify (\st -> st { stClientCertificateChain = Jus
|
||||||
getClientCertificateChain :: MonadState TLSState m => m (Maybe CertificateChain)
|
getClientCertificateChain :: MonadState TLSState m => m (Maybe CertificateChain)
|
||||||
getClientCertificateChain = gets stClientCertificateChain
|
getClientCertificateChain = gets stClientCertificateChain
|
||||||
|
|
||||||
getCipherKeyExchangeType :: HandshakeM (Maybe CipherKeyExchangeType)
|
|
||||||
getCipherKeyExchangeType = gets (\st -> cipherKeyExchange <$> hstPendingCipher st)
|
|
||||||
|
|
||||||
getVerifiedData :: MonadState TLSState m => Bool -> m Bytes
|
getVerifiedData :: MonadState TLSState m => Bool -> m Bytes
|
||||||
getVerifiedData client = gets (if client then stClientVerifiedData else stServerVerifiedData)
|
getVerifiedData client = gets (if client then stClientVerifiedData else stServerVerifiedData)
|
||||||
|
|
||||||
|
@ -270,14 +257,6 @@ startHandshakeClient ver crand = do
|
||||||
when (isNothing chs) $
|
when (isNothing chs) $
|
||||||
modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx })
|
modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx })
|
||||||
|
|
||||||
hasValidHandshake :: MonadState TLSState m => String -> m ()
|
|
||||||
hasValidHandshake name = get >>= \st -> assert name [ ("valid handshake", isNothing $ stHandshake st) ]
|
|
||||||
|
|
||||||
updateHandshake :: MonadState TLSState m => String -> (HandshakeState -> HandshakeState) -> m ()
|
|
||||||
updateHandshake n f = do
|
|
||||||
hasValidHandshake n
|
|
||||||
modify (\st -> st { stHandshake = f <$> stHandshake st })
|
|
||||||
|
|
||||||
withHandshakeM :: MonadState TLSState m => HandshakeM a -> m a
|
withHandshakeM :: MonadState TLSState m => HandshakeM a -> m a
|
||||||
withHandshakeM f =
|
withHandshakeM f =
|
||||||
get >>= \st -> case stHandshake st of
|
get >>= \st -> case stHandshake st of
|
||||||
|
@ -287,12 +266,12 @@ withHandshakeM f =
|
||||||
return a
|
return a
|
||||||
|
|
||||||
getHandshakeDigest :: MonadState TLSState m => Bool -> m Bytes
|
getHandshakeDigest :: MonadState TLSState m => Bool -> m Bytes
|
||||||
getHandshakeDigest client = do
|
getHandshakeDigest roleClient = do
|
||||||
st <- get
|
st <- get
|
||||||
let hst = fromJust "handshake" $ stHandshake st
|
let hst = fromJust "handshake" $ stHandshake st
|
||||||
let hashctx = hstHandshakeDigest hst
|
let hashctx = hstHandshakeDigest hst
|
||||||
let msecret = fromJust "master secret" $ hstMasterSecret hst
|
let msecret = fromJust "master secret" $ hstMasterSecret hst
|
||||||
return $ (if client then generateClientFinished else generateServerFinished) (stVersion $ stRecordState st) msecret hashctx
|
return $ (if roleClient then generateClientFinished else generateServerFinished) (stVersion $ stRecordState st) msecret hashctx
|
||||||
|
|
||||||
endHandshake :: MonadState TLSState m => m ()
|
endHandshake :: MonadState TLSState m => m ()
|
||||||
endHandshake = modify (\st -> st { stHandshake = Nothing })
|
endHandshake = modify (\st -> st { stHandshake = Nothing })
|
||||||
|
|
Loading…
Reference in a new issue