diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index 6cba68f..3490556 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -37,6 +37,7 @@ import qualified Control.Exception as E import Network.TLS.Handshake.Common import Network.TLS.Handshake.Certificate import Network.TLS.Handshake.Signature +import Network.TLS.Handshake.State -- client part of handshake. send a bunch of handshake of client -- values intertwined with response from the server. @@ -142,7 +143,7 @@ handshakeClient cparams ctx = do -- certificate, we simply store the -- information for later. -- - usingState_ ctx $ setClientCertRequest (cTypes, sigAlgs, dNames) + usingHState ctx $ setClientCertRequest (cTypes, sigAlgs, dNames) return $ RecvStateHandshake processServerHelloDone processCertificateRequest p = processServerHelloDone p @@ -165,7 +166,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi -- use. -- sendCertificate = do - certRequested <- usingState_ ctx getClientCertRequest + certRequested <- usingHState ctx getClientCertRequest case certRequested of Nothing -> return () @@ -174,7 +175,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi certChain <- liftIO $ onCertificateRequest cparams req `E.catch` throwMiscErrorOnException "certificate request callback failed" - usingState_ ctx $ setClientCertSent False + usingHState ctx $ setClientCertSent False case certChain of Nothing -> sendPacket ctx $ Handshake [Certificates (CertificateChain [])] Just (CertificateChain [], _) -> sendPacket ctx $ Handshake [Certificates (CertificateChain [])] @@ -183,7 +184,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi PubKeyRSA _ -> return () _ -> throwCore $ Error_Protocol ("no supported certificate type", True, HandshakeFailure) usingHState ctx $ setClientPrivateKey pk - usingState_ ctx $ setClientCertSent True + usingHState ctx $ setClientCertSent True sendPacket ctx $ Handshake [Certificates cc] sendClientKeyXchg = do @@ -217,9 +218,9 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi -- Only send a certificate verify message when we -- have sent a non-empty list of certificates. -- - certSent <- usingState_ ctx $ getClientCertSent + certSent <- usingHState ctx $ getClientCertSent case certSent of - Just True -> do + True -> do -- Fetch all handshake messages up to now. msgs <- usingState_ ctx $ B.concat <$> getHandshakeMessages @@ -240,7 +241,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi sendPacket ctx $ Handshake [CertVerify Nothing (CertVerifyData sigDig)] _ -> do - Just (_, Just hashSigs, _) <- usingState_ ctx $ getClientCertRequest + Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest let suppHashSigs = pHashSignatures $ ctxParams ctx hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs liftIO $ putStrLn $ " supported hash sig algorithms: " ++ show hashSigs' diff --git a/core/Network/TLS/Handshake/Server.hs b/core/Network/TLS/Handshake/Server.hs index ae5447c..aa01244 100644 --- a/core/Network/TLS/Handshake/Server.hs +++ b/core/Network/TLS/Handshake/Server.hs @@ -21,6 +21,7 @@ import Network.TLS.Packet import Network.TLS.Extension import Network.TLS.IO import Network.TLS.State hiding (getNegotiatedProtocol) +import Network.TLS.Handshake.State import Network.TLS.Receiving import Network.TLS.Measurement import Data.Maybe @@ -192,7 +193,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip else Just (pHashSignatures $ ctxParams ctx) creq = CertRequest certTypes hashSigs (map extractCAname $ serverCACertificates sparams) - usingState_ ctx $ setCertReqSent True + usingHState ctx $ setCertReqSent True sendPacket ctx (Handshake [creq]) -- Send HelloDone @@ -225,7 +226,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- Remember cert chain for later use. -- - usingState_ ctx $ setClientCertChain certs + usingHState ctx $ setClientCertChain certs -- FIXME: We should check whether the certificate -- matches our request and that we support @@ -275,7 +276,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- When verification succeeds, commit the -- client certificate chain to the context. -- - Just certs <- usingState_ ctx $ getClientCertChain + Just certs <- usingHState ctx $ getClientCertChain usingState_ ctx $ setClientCertificateChain certs return () @@ -292,13 +293,13 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- application callbacks accepts, we -- also commit the client certificate -- chain to the context. - Just certs <- usingState_ ctx $ getClientCertChain + Just certs <- usingHState ctx $ getClientCertChain usingState_ ctx $ setClientCertificateChain certs else throwCore $ Error_Protocol ("verification failed", True, BadCertificate) return $ RecvStateNext expectChangeCipher processCertificateVerify p = do - chain <- usingState_ ctx $ getClientCertChain + chain <- usingHState ctx $ getClientCertChain case chain of Just cc | isNullCertificateChain cc -> return () | otherwise -> throwCore $ Error_Protocol ("cert verify message missing", True, UnexpectedMessage) @@ -317,7 +318,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC expectFinish p = unexpected (show p) (Just "Handshake Finished") checkValidClientCertChain msg = do - chain <- usingState_ ctx $ getClientCertChain + chain <- usingHState ctx $ getClientCertChain let throwerror = Error_Protocol (msg , True, UnexpectedMessage) case chain of Nothing -> throwCore throwerror diff --git a/core/Network/TLS/Handshake/State.hs b/core/Network/TLS/Handshake/State.hs index a4cc984..aaacfcd 100644 --- a/core/Network/TLS/Handshake/State.hs +++ b/core/Network/TLS/Handshake/State.hs @@ -14,11 +14,20 @@ module Network.TLS.Handshake.State , HandshakeM , newEmptyHandshake , runHandshake - -- * accessors + -- * key accessors , setPublicKey , setPrivateKey , setClientPublicKey , setClientPrivateKey + -- * cert accessors + , setClientCertSent + , getClientCertSent + , setCertReqSent + , getCertReqSent + , setClientCertChain + , getClientCertChain + , setClientCertRequest + , getClientCertRequest ) where import Network.TLS.Util @@ -29,7 +38,6 @@ import qualified Data.ByteString as B import Control.Applicative ((<$>)) import Control.Monad import Control.Monad.State -import Control.Monad.Error import Data.X509 (CertificateChain) data HandshakeState = HandshakeState @@ -97,3 +105,28 @@ setClientPublicKey pk = modify (\hst -> hst { hstRSAClientPublicKey = Just pk }) setClientPrivateKey :: PrivKey -> HandshakeM () setClientPrivateKey pk = modify (\hst -> hst { hstRSAClientPrivateKey = Just pk }) + +setCertReqSent :: Bool -> HandshakeM () +setCertReqSent b = modify (\hst -> hst { hstCertReqSent = b }) + +getCertReqSent :: HandshakeM Bool +getCertReqSent = gets hstCertReqSent + +setClientCertSent :: Bool -> HandshakeM () +setClientCertSent b = modify (\hst -> hst { hstClientCertSent = b }) + +getClientCertSent :: HandshakeM Bool +getClientCertSent = gets hstClientCertSent + +setClientCertChain :: CertificateChain -> HandshakeM () +setClientCertChain b = modify (\hst -> hst { hstClientCertChain = Just b }) + +getClientCertChain :: HandshakeM (Maybe CertificateChain) +getClientCertChain = gets hstClientCertChain + +setClientCertRequest :: ClientCertRequestData -> HandshakeM () +setClientCertRequest d = modify (\hst -> hst { hstClientCertRequest = Just d }) + +getClientCertRequest :: HandshakeM (Maybe ClientCertRequestData) +getClientCertRequest = gets hstClientCertRequest + diff --git a/core/Network/TLS/Receiving.hs b/core/Network/TLS/Receiving.hs index b025a20..fb5ed51 100644 --- a/core/Network/TLS/Receiving.hs +++ b/core/Network/TLS/Receiving.hs @@ -27,6 +27,7 @@ import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet import Network.TLS.State +import Network.TLS.Handshake.State import Network.TLS.Cipher import Network.TLS.Crypto import Network.TLS.Extension diff --git a/core/Network/TLS/State.hs b/core/Network/TLS/State.hs index aafbec4..7d54d45 100644 --- a/core/Network/TLS/State.hs +++ b/core/Network/TLS/State.hs @@ -30,18 +30,6 @@ module Network.TLS.State , setMasterSecret , setMasterSecretFromPre , getMasterSecret - , setPublicKey - , setPrivateKey - , setClientPublicKey - , setClientPrivateKey - , setClientCertSent - , getClientCertSent - , setCertReqSent - , getCertReqSent - , setClientCertChain - , getClientCertChain - , setClientCertRequest - , getClientCertRequest , setKeyBlock , setVersion , getVersion @@ -241,30 +229,6 @@ setMasterSecretFromPre premasterSecret = do getMasterSecret :: MonadState TLSState m => m (Maybe Bytes) getMasterSecret = gets (stHandshake >=> hstMasterSecret) -setCertReqSent :: MonadState TLSState m => Bool -> m () -setCertReqSent b = updateHandshake "client cert req sent" (\hst -> hst { hstCertReqSent = b }) - -getCertReqSent :: MonadState TLSState m => m (Maybe Bool) -getCertReqSent = gets (stHandshake >=> Just . hstCertReqSent) - -setClientCertSent :: MonadState TLSState m => Bool -> m () -setClientCertSent b = updateHandshake "client cert sent" (\hst -> hst { hstClientCertSent = b }) - -getClientCertSent :: MonadState TLSState m => m (Maybe Bool) -getClientCertSent = gets (stHandshake >=> Just . hstClientCertSent) - -setClientCertChain :: MonadState TLSState m => CertificateChain -> m () -setClientCertChain b = updateHandshake "client certificate chain" (\hst -> hst { hstClientCertChain = Just b }) - -getClientCertChain :: MonadState TLSState m => m (Maybe CertificateChain) -getClientCertChain = gets (stHandshake >=> hstClientCertChain) - -setClientCertRequest :: MonadState TLSState m => ClientCertRequestData -> m () -setClientCertRequest d = updateHandshake "client cert data" (\hst -> hst { hstClientCertRequest = Just d }) - -getClientCertRequest :: MonadState TLSState m => m (Maybe ClientCertRequestData) -getClientCertRequest = gets (stHandshake >=> hstClientCertRequest) - getSessionData :: MonadState TLSState m => m (Maybe SessionData) getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st) where wrapSessionData st masterSecret = do