From 513d13029fa87b5e4dab72be32cbbb41bd1034e2 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Tue, 30 Oct 2012 04:46:19 +0000 Subject: [PATCH] use gets where possible and make thing nicer --- core/Network/TLS/State.hs | 98 ++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/core/Network/TLS/State.hs b/core/Network/TLS/State.hs index b2190f9..97e584b 100644 --- a/core/Network/TLS/State.hs +++ b/core/Network/TLS/State.hs @@ -316,9 +316,7 @@ setMasterSecretFromPre premasterSecret = do (fromJust "server random" $ hstServerRandom hst) getMasterSecret :: MonadState TLSState m => m (Maybe Bytes) -getMasterSecret = do - st <- get - return (stHandshake st >>= hstMasterSecret) +getMasterSecret = gets (stHandshake >=> hstMasterSecret) setPublicKey :: MonadState TLSState m => PublicKey -> m () setPublicKey pk = updateHandshake "publickey" (\hst -> hst { hstRSAPublicKey = Just pk }) @@ -336,38 +334,28 @@ setCertReqSent :: MonadState TLSState m => Bool -> m () setCertReqSent b = updateHandshake "client cert req sent" (\hst -> hst { hstCertReqSent = b }) getCertReqSent :: MonadState TLSState m => m (Maybe Bool) -getCertReqSent = do - st <- get - return (stHandshake st >>= Just . hstCertReqSent) +getCertReqSent = gets (stHandshake >=> Just . hstCertReqSent) setClientCertSent :: MonadState TLSState m => Bool -> m () setClientCertSent b = updateHandshake "client cert sent" (\hst -> hst { hstClientCertSent = b }) getClientCertSent :: MonadState TLSState m => m (Maybe Bool) -getClientCertSent = do - st <- get - return (stHandshake st >>= Just . hstClientCertSent) +getClientCertSent = gets (stHandshake >=> Just . hstClientCertSent) setClientCertChain :: MonadState TLSState m => [X509] -> m () setClientCertChain b = updateHandshake "client certificate chain" (\hst -> hst { hstClientCertChain = Just b }) getClientCertChain :: MonadState TLSState m => m (Maybe [X509]) -getClientCertChain = do - st <- get - return (stHandshake st >>= hstClientCertChain) +getClientCertChain = gets (stHandshake >=> hstClientCertChain) setClientCertRequest :: MonadState TLSState m => ClientCertRequestData -> m () setClientCertRequest d = updateHandshake "client cert data" (\hst -> hst { hstClientCertRequest = Just d }) getClientCertRequest :: MonadState TLSState m => m (Maybe ClientCertRequestData) -getClientCertRequest = do - st <- get - return (stHandshake st >>= hstClientCertRequest) +getClientCertRequest = gets (stHandshake >=> hstClientCertRequest) getSessionData :: MonadState TLSState m => m (Maybe SessionData) -getSessionData = do - st <- get - return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st) +getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st) where wrapSessionData st masterSecret = do return $ SessionData { sessionVersion = stVersion st @@ -390,42 +378,36 @@ needEmptyPacket = gets f && (maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher st)) setKeyBlock :: MonadState TLSState m => m () -setKeyBlock = do - st <- get +setKeyBlock = modify setPendingState where + setPendingState st = st { stPendingTxCryptState = Just $ if cc then cstClient else cstServer + , stPendingRxCryptState = Just $ if cc then cstServer else cstClient + , stPendingTxMacState = Just $ if cc then msClient else msServer + , stPendingRxMacState = Just $ if cc then msServer else msClient + } + where hst = fromJust "handshake" $ stHandshake st + cc = stClientContext st + cipher = fromJust "cipher" $ stCipher st + keyblockSize = cipherKeyBlockSize cipher - let hst = fromJust "handshake" $ stHandshake st + bulk = cipherBulk cipher + digestSize = hashSize $ cipherHash cipher + keySize = bulkKeySize bulk + ivSize = bulkIVSize bulk + kb = generateKeyBlock (stVersion st) (hstClientRandom hst) + (fromJust "server random" $ hstServerRandom hst) + (fromJust "master secret" $ hstMasterSecret hst) keyblockSize - let cc = stClientContext st - let cipher = fromJust "cipher" $ stCipher st - let keyblockSize = cipherKeyBlockSize cipher + (cMACSecret, sMACSecret, cWriteKey, sWriteKey, cWriteIV, sWriteIV) = + fromJust "p6" $ partition6 kb (digestSize, digestSize, keySize, keySize, ivSize, ivSize) - let bulk = cipherBulk cipher - let digestSize = hashSize $ cipherHash cipher - let keySize = bulkKeySize bulk - let ivSize = bulkIVSize bulk - let kb = generateKeyBlock (stVersion st) (hstClientRandom hst) - (fromJust "server random" $ hstServerRandom hst) - (fromJust "master secret" $ hstMasterSecret hst) keyblockSize - - let (cMACSecret, sMACSecret, cWriteKey, sWriteKey, cWriteIV, sWriteIV) = - fromJust "p6" $ partition6 kb (digestSize, digestSize, keySize, keySize, ivSize, ivSize) - - let cstClient = TLSCryptState - { cstKey = cWriteKey - , cstIV = cWriteIV - , cstMacSecret = cMACSecret } - let cstServer = TLSCryptState - { cstKey = sWriteKey - , cstIV = sWriteIV - , cstMacSecret = sMACSecret } - let msClient = TLSMacState { msSequence = 0 } - let msServer = TLSMacState { msSequence = 0 } - put $ st - { stPendingTxCryptState = Just $ if cc then cstClient else cstServer - , stPendingRxCryptState = Just $ if cc then cstServer else cstClient - , stPendingTxMacState = Just $ if cc then msClient else msServer - , stPendingRxMacState = Just $ if cc then msServer else msClient - } + cstClient = TLSCryptState { cstKey = cWriteKey + , cstIV = cWriteIV + , cstMacSecret = cMACSecret } + cstServer = TLSCryptState { cstKey = sWriteKey + , cstIV = sWriteIV + , cstMacSecret = sMACSecret } + msClient = TLSMacState { msSequence = 0 } + msServer = TLSMacState { msSequence = 0 } setCipher :: MonadState TLSState m => Cipher -> m () setCipher cipher = modify (\st -> st { stCipher = Just cipher }) @@ -437,19 +419,19 @@ setSecureRenegotiation :: MonadState TLSState m => Bool -> m () setSecureRenegotiation b = modify (\st -> st { stSecureRenegotiation = b }) getSecureRenegotiation :: MonadState TLSState m => m Bool -getSecureRenegotiation = get >>= return . stSecureRenegotiation +getSecureRenegotiation = gets stSecureRenegotiation setExtensionNPN :: MonadState TLSState m => Bool -> m () setExtensionNPN b = modify (\st -> st { stExtensionNPN = b }) getExtensionNPN :: MonadState TLSState m => m Bool -getExtensionNPN = get >>= return . stExtensionNPN +getExtensionNPN = gets stExtensionNPN setNegotiatedProtocol :: MonadState TLSState m => B.ByteString -> m () setNegotiatedProtocol s = modify (\st -> st { stNegotiatedProtocol = Just s }) getNegotiatedProtocol :: MonadState TLSState m => m (Maybe B.ByteString) -getNegotiatedProtocol = get >>= return . stNegotiatedProtocol +getNegotiatedProtocol = gets stNegotiatedProtocol setServerNextProtocolSuggest :: MonadState TLSState m => [B.ByteString] -> m () setServerNextProtocolSuggest ps = modify (\st -> st { stServerNextProtocolSuggest = Just ps}) @@ -461,16 +443,16 @@ setClientCertificateChain :: MonadState TLSState m => [X509] -> m () setClientCertificateChain s = modify (\st -> st { stClientCertificateChain = Just s }) getClientCertificateChain :: MonadState TLSState m => m (Maybe [X509]) -getClientCertificateChain = get >>= return . stClientCertificateChain +getClientCertificateChain = gets stClientCertificateChain getCipherKeyExchangeType :: MonadState TLSState m => m (Maybe CipherKeyExchangeType) -getCipherKeyExchangeType = get >>= return . (maybe Nothing (Just . cipherKeyExchange) . stCipher) +getCipherKeyExchangeType = gets (\st -> cipherKeyExchange <$> stCipher st) getVerifiedData :: MonadState TLSState m => Bool -> m Bytes -getVerifiedData client = get >>= return . (if client then stClientVerifiedData else stServerVerifiedData) +getVerifiedData client = gets (if client then stClientVerifiedData else stServerVerifiedData) isClientContext :: MonadState TLSState m => m Bool -isClientContext = get >>= return . stClientContext +isClientContext = gets stClientContext -- create a new empty handshake state newEmptyHandshake :: Version -> ClientRandom -> HashCtx -> TLSHandshakeState