more cleanup / separation with handshake state.

This commit is contained in:
Vincent Hanquez 2013-07-23 07:30:13 +00:00
parent 0fdfd3d104
commit acc670e30e
5 changed files with 23 additions and 32 deletions

View file

@ -90,14 +90,17 @@ handshakeClient cparams ctx = do
recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello sentExts)
onServerHello :: MonadIO m => [ExtensionID] -> Handshake -> m (RecvState m)
onServerHello sentExts sh@(ServerHello rver _ serverSession cipher _ exts) = do
onServerHello sentExts sh@(ServerHello rver _ serverSession cipher compression exts) = do
when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion)
case find ((==) rver) allowedvers of
Nothing -> throwCore $ Error_Protocol ("version " ++ show rver ++ "is not supported", True, ProtocolVersion)
Just _ -> usingState_ ctx $ setVersion rver
case find ((==) cipher . cipherID) ciphers of
Nothing -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure)
Just c -> usingHState ctx $ setCipher c
-- find the compression and cipher methods that the server want to use.
case (find ((==) cipher . cipherID) ciphers, find ((==) compression . compressionID) compressions) of
(Nothing,_) -> throwCore $ Error_Protocol ("no cipher in common with the server", True, HandshakeFailure)
(_,Nothing) -> throwCore $ Error_Protocol ("no compression in common with the server", True, HandshakeFailure)
(Just cipherAlg, Just compressAlg) ->
usingHState ctx $ setPendingAlgs cipherAlg compressAlg
-- intersect sent extensions in client and the received extensions from server.
-- if server returns extensions that we didn't request, fail.

View file

@ -96,9 +96,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
when (null commonCompressions) $
throwCore $ Error_Protocol ("no compression in common with the client", True, HandshakeFailure)
usingState_ ctx $ setVersion ver
usingHState ctx $ do
setCipher usedCipher
modify (\hst -> hst { hstPendingCompression = usedCompression })
usingHState ctx $ setPendingAlgs usedCipher usedCompression
resumeSessionData <- case clientSession of
(Session (Just clientSessionId)) -> withSessionManager params (\s -> liftIO $ sessionResume s clientSessionId)
@ -165,7 +163,8 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
, extensionEncode $ NextProtocolNegotiation protos) ]
Nothing -> return []
let extensions = secRengExt ++ npnExt
usingState_ ctx (setVersion ver >> setServerRandom srand)
usingState_ ctx (setVersion ver)
usingHState ctx $ setServerRandom srand
return $ ServerHello ver srand session (cipherID usedCipher)
(compressionID usedCompression) extensions

View file

@ -35,6 +35,9 @@ module Network.TLS.Handshake.State
-- * master secret
, setMasterSecret
, setMasterSecretFromPre
-- * misc accessor
, setPendingAlgs
, setServerRandom
) where
import Network.TLS.Util
@ -211,3 +214,10 @@ computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
, stCipher = Just cipher
, stCompression = hstPendingCompression hst
}
setPendingAlgs :: Cipher -> Compression -> HandshakeM ()
setPendingAlgs cipher compression =
modify $ \hst -> hst { hstPendingCipher = Just cipher, hstPendingCompression = compression }
setServerRandom :: ServerRandom -> HandshakeM ()
setServerRandom ran = modify $ \hst -> hst { hstServerRandom = Just ran }

View file

@ -120,7 +120,7 @@ processServerHello (ServerHello sver ran _ _ _ ex) = do
-- secreneg <- getSecureRenegotiation
-- when (secreneg && (isNothing $ lookup 0xff01 ex)) $ ...
mapM_ processServerExtension ex
setServerRandom ran
withHandshakeM $ setServerRandom ran
setVersion sver
where processServerExtension (0xff01, content) = do
cv <- getVerifiedData True

View file

@ -29,8 +29,6 @@ module Network.TLS.State
, certVerifyHandshakeMaterial
, setVersion
, getVersion
, setCipher
, setServerRandom
, setSecureRenegotiation
, getSecureRenegotiation
, setExtensionNPN
@ -49,7 +47,6 @@ module Network.TLS.State
, needEmptyPacket
, switchTxEncryption
, switchRxEncryption
, getCipherKeyExchangeType
, isClientContext
, startHandshakeClient
, getHandshakeDigest
@ -67,7 +64,6 @@ import Network.TLS.Handshake.State
import Network.TLS.RNG
import Network.TLS.Types (Role(..))
import qualified Data.ByteString as B
import Control.Applicative ((<$>))
import Control.Monad
import Control.Monad.State
import Control.Monad.Error
@ -187,9 +183,6 @@ switchRxEncryption =
withHandshakeM (gets hstPendingRxState)
>>= \newRxState -> runRecordStateSt (modify $ \st -> st { stRxState = fromJust "pending-rx" newRxState })
setServerRandom :: MonadState TLSState m => ServerRandom -> m ()
setServerRandom ran = updateHandshake "srand" (\hst -> hst { hstServerRandom = Just ran })
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
where wrapSessionData st masterSecret = do
@ -214,9 +207,6 @@ needEmptyPacket = gets f
&& stClientContext st == ClientRole
&& (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher $ stTxState st))
setCipher :: Cipher -> HandshakeM ()
setCipher cipher = modify (\st -> st { hstPendingCipher = Just cipher })
setVersion :: MonadState TLSState m => Version -> m ()
setVersion ver = modify (\st -> st { stRecordState = (stRecordState st) { stVersion = ver } })
@ -253,9 +243,6 @@ setClientCertificateChain s = modify (\st -> st { stClientCertificateChain = Jus
getClientCertificateChain :: MonadState TLSState m => m (Maybe CertificateChain)
getClientCertificateChain = gets stClientCertificateChain
getCipherKeyExchangeType :: HandshakeM (Maybe CipherKeyExchangeType)
getCipherKeyExchangeType = gets (\st -> cipherKeyExchange <$> hstPendingCipher st)
getVerifiedData :: MonadState TLSState m => Bool -> m Bytes
getVerifiedData client = gets (if client then stClientVerifiedData else stServerVerifiedData)
@ -270,14 +257,6 @@ startHandshakeClient ver crand = do
when (isNothing chs) $
modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx })
hasValidHandshake :: MonadState TLSState m => String -> m ()
hasValidHandshake name = get >>= \st -> assert name [ ("valid handshake", isNothing $ stHandshake st) ]
updateHandshake :: MonadState TLSState m => String -> (HandshakeState -> HandshakeState) -> m ()
updateHandshake n f = do
hasValidHandshake n
modify (\st -> st { stHandshake = f <$> stHandshake st })
withHandshakeM :: MonadState TLSState m => HandshakeM a -> m a
withHandshakeM f =
get >>= \st -> case stHandshake st of
@ -287,12 +266,12 @@ withHandshakeM f =
return a
getHandshakeDigest :: MonadState TLSState m => Bool -> m Bytes
getHandshakeDigest client = do
getHandshakeDigest roleClient = do
st <- get
let hst = fromJust "handshake" $ stHandshake st
let hashctx = hstHandshakeDigest hst
let msecret = fromJust "master secret" $ hstMasterSecret hst
return $ (if client then generateClientFinished else generateServerFinished) (stVersion $ stRecordState st) msecret hashctx
return $ (if roleClient then generateClientFinished else generateServerFinished) (stVersion $ stRecordState st) msecret hashctx
endHandshake :: MonadState TLSState m => m ()
endHandshake = modify (\st -> st { stHandshake = Nothing })