diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index 2dc0006..cad9585 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -36,6 +36,7 @@ import Control.Exception (SomeException) import qualified Control.Exception as E import Network.TLS.Handshake.Common +import Network.TLS.Handshake.Process import Network.TLS.Handshake.Certificate import Network.TLS.Handshake.Signature import Network.TLS.Handshake.Key @@ -75,7 +76,7 @@ handshakeClient cparams ctx = do crand <- getStateRNG ctx 32 >>= return . ClientRandom let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams extensions <- getExtensions - usingState_ ctx (startHandshakeClient (pConnectVersion params) crand) + startHandshake ctx (pConnectVersion params) crand sendPacket ctx $ Handshake [ ClientHello (pConnectVersion params) crand clientSession (map cipherID ciphers) (map compressionID compressions) extensions Nothing diff --git a/core/Network/TLS/Handshake/Common.hs b/core/Network/TLS/Handshake/Common.hs index 8840399..b9fc40a 100644 --- a/core/Network/TLS/Handshake/Common.hs +++ b/core/Network/TLS/Handshake/Common.hs @@ -121,4 +121,3 @@ runRecvState ctx iniState = recvPacketHandshake ctx >>= loop iniState > processHandshake ctx x loop nstate xs loop _ _ = unexpected "spurious handshake" Nothing - diff --git a/core/Network/TLS/Handshake/Process.hs b/core/Network/TLS/Handshake/Process.hs index ea16584..7dd65fc 100644 --- a/core/Network/TLS/Handshake/Process.hs +++ b/core/Network/TLS/Handshake/Process.hs @@ -9,13 +9,15 @@ -- module Network.TLS.Handshake.Process ( processHandshake + , startHandshake ) where import Data.ByteString (ByteString) +import Data.Maybe (isNothing) import Control.Applicative import Control.Monad.Error -import Control.Monad.State (gets) +import Control.Monad.State (gets, modify) import Network.TLS.Types (Role(..), invertRole) import Network.TLS.Util @@ -23,6 +25,7 @@ import Network.TLS.Packet import Network.TLS.Struct import Network.TLS.State import Network.TLS.Context +import Network.TLS.Crypto import Network.TLS.Handshake.State import Network.TLS.Handshake.Key import Network.TLS.Extension @@ -34,7 +37,7 @@ processHandshake ctx hs = do case hs of ClientHello cver ran _ _ _ ex _ -> when (role == ServerRole) $ do mapM_ (usingState_ ctx . processClientExtension) ex - usingState_ ctx $ startHandshakeClient cver ran + startHandshake ctx cver ran Certificates certs -> processCertificates role certs ClientKeyXchg content -> when (role == ServerRole) $ do processClientKeyXchg ctx content @@ -90,3 +93,11 @@ processClientFinished ctx fdata = do usingState_ ctx $ updateVerifiedData ServerRole fdata return () +startHandshake :: MonadIO m => Context -> Version -> ClientRandom -> m () +startHandshake ctx ver crand = do + -- FIXME check if handshake is already not null + let initCtx = if ver < TLS12 then hashMD5SHA1 else hashSHA256 + usingState_ ctx $ do + chs <- gets stHandshake + when (isNothing chs) $ + modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx }) diff --git a/core/Network/TLS/State.hs b/core/Network/TLS/State.hs index 7fba73c..88f0d48 100644 --- a/core/Network/TLS/State.hs +++ b/core/Network/TLS/State.hs @@ -44,7 +44,6 @@ module Network.TLS.State , getSessionData , isSessionResuming , isClientContext - , startHandshakeClient , getHandshakeDigest , endHandshake -- * random @@ -53,14 +52,11 @@ module Network.TLS.State ) where import Control.Applicative -import Data.Maybe (isNothing) import Network.TLS.Struct -import Network.TLS.Crypto -import Network.TLS.Handshake.State import Network.TLS.RNG +import Network.TLS.Handshake.State import Network.TLS.Types (Role(..)) import qualified Data.ByteString as B -import Control.Monad import Control.Monad.State import Control.Monad.Error import Crypto.Random.API @@ -211,14 +207,6 @@ getVerifiedData client = gets (if client == ClientRole then stClientVerifiedData isClientContext :: MonadState TLSState m => m Role isClientContext = gets stClientContext -startHandshakeClient :: MonadState TLSState m => Version -> ClientRandom -> m () -startHandshakeClient ver crand = do - -- FIXME check if handshake is already not null - let initCtx = if ver < TLS12 then hashMD5SHA1 else hashSHA256 - chs <- get >>= return . stHandshake - when (isNothing chs) $ - modify (\st -> st { stHandshake = Just $ newEmptyHandshake ver crand initCtx }) - withHandshakeM :: MonadState TLSState m => HandshakeM a -> m a withHandshakeM f = get >>= \st -> case stHandshake st of