move more stuff in the HandshakeM
This commit is contained in:
parent
849f87c8ea
commit
7ecc341af6
5 changed files with 51 additions and 51 deletions
|
@ -37,6 +37,7 @@ import qualified Control.Exception as E
|
||||||
import Network.TLS.Handshake.Common
|
import Network.TLS.Handshake.Common
|
||||||
import Network.TLS.Handshake.Certificate
|
import Network.TLS.Handshake.Certificate
|
||||||
import Network.TLS.Handshake.Signature
|
import Network.TLS.Handshake.Signature
|
||||||
|
import Network.TLS.Handshake.State
|
||||||
|
|
||||||
-- client part of handshake. send a bunch of handshake of client
|
-- client part of handshake. send a bunch of handshake of client
|
||||||
-- values intertwined with response from the server.
|
-- values intertwined with response from the server.
|
||||||
|
@ -142,7 +143,7 @@ handshakeClient cparams ctx = do
|
||||||
-- certificate, we simply store the
|
-- certificate, we simply store the
|
||||||
-- information for later.
|
-- information for later.
|
||||||
--
|
--
|
||||||
usingState_ ctx $ setClientCertRequest (cTypes, sigAlgs, dNames)
|
usingHState ctx $ setClientCertRequest (cTypes, sigAlgs, dNames)
|
||||||
return $ RecvStateHandshake processServerHelloDone
|
return $ RecvStateHandshake processServerHelloDone
|
||||||
processCertificateRequest p = processServerHelloDone p
|
processCertificateRequest p = processServerHelloDone p
|
||||||
|
|
||||||
|
@ -165,7 +166,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
|
||||||
-- use.
|
-- use.
|
||||||
--
|
--
|
||||||
sendCertificate = do
|
sendCertificate = do
|
||||||
certRequested <- usingState_ ctx getClientCertRequest
|
certRequested <- usingHState ctx getClientCertRequest
|
||||||
case certRequested of
|
case certRequested of
|
||||||
Nothing ->
|
Nothing ->
|
||||||
return ()
|
return ()
|
||||||
|
@ -174,7 +175,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
|
||||||
certChain <- liftIO $ onCertificateRequest cparams req `E.catch`
|
certChain <- liftIO $ onCertificateRequest cparams req `E.catch`
|
||||||
throwMiscErrorOnException "certificate request callback failed"
|
throwMiscErrorOnException "certificate request callback failed"
|
||||||
|
|
||||||
usingState_ ctx $ setClientCertSent False
|
usingHState ctx $ setClientCertSent False
|
||||||
case certChain of
|
case certChain of
|
||||||
Nothing -> sendPacket ctx $ Handshake [Certificates (CertificateChain [])]
|
Nothing -> sendPacket ctx $ Handshake [Certificates (CertificateChain [])]
|
||||||
Just (CertificateChain [], _) -> sendPacket ctx $ Handshake [Certificates (CertificateChain [])]
|
Just (CertificateChain [], _) -> sendPacket ctx $ Handshake [Certificates (CertificateChain [])]
|
||||||
|
@ -183,7 +184,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
|
||||||
PubKeyRSA _ -> return ()
|
PubKeyRSA _ -> return ()
|
||||||
_ -> throwCore $ Error_Protocol ("no supported certificate type", True, HandshakeFailure)
|
_ -> throwCore $ Error_Protocol ("no supported certificate type", True, HandshakeFailure)
|
||||||
usingHState ctx $ setClientPrivateKey pk
|
usingHState ctx $ setClientPrivateKey pk
|
||||||
usingState_ ctx $ setClientCertSent True
|
usingHState ctx $ setClientCertSent True
|
||||||
sendPacket ctx $ Handshake [Certificates cc]
|
sendPacket ctx $ Handshake [Certificates cc]
|
||||||
|
|
||||||
sendClientKeyXchg = do
|
sendClientKeyXchg = do
|
||||||
|
@ -217,9 +218,9 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
|
||||||
-- Only send a certificate verify message when we
|
-- Only send a certificate verify message when we
|
||||||
-- have sent a non-empty list of certificates.
|
-- have sent a non-empty list of certificates.
|
||||||
--
|
--
|
||||||
certSent <- usingState_ ctx $ getClientCertSent
|
certSent <- usingHState ctx $ getClientCertSent
|
||||||
case certSent of
|
case certSent of
|
||||||
Just True -> do
|
True -> do
|
||||||
-- Fetch all handshake messages up to now.
|
-- Fetch all handshake messages up to now.
|
||||||
msgs <- usingState_ ctx $ B.concat <$> getHandshakeMessages
|
msgs <- usingState_ ctx $ B.concat <$> getHandshakeMessages
|
||||||
|
|
||||||
|
@ -240,7 +241,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi
|
||||||
sendPacket ctx $ Handshake [CertVerify Nothing (CertVerifyData sigDig)]
|
sendPacket ctx $ Handshake [CertVerify Nothing (CertVerifyData sigDig)]
|
||||||
|
|
||||||
_ -> do
|
_ -> do
|
||||||
Just (_, Just hashSigs, _) <- usingState_ ctx $ getClientCertRequest
|
Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest
|
||||||
let suppHashSigs = pHashSignatures $ ctxParams ctx
|
let suppHashSigs = pHashSignatures $ ctxParams ctx
|
||||||
hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs
|
hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs
|
||||||
liftIO $ putStrLn $ " supported hash sig algorithms: " ++ show hashSigs'
|
liftIO $ putStrLn $ " supported hash sig algorithms: " ++ show hashSigs'
|
||||||
|
|
|
@ -21,6 +21,7 @@ import Network.TLS.Packet
|
||||||
import Network.TLS.Extension
|
import Network.TLS.Extension
|
||||||
import Network.TLS.IO
|
import Network.TLS.IO
|
||||||
import Network.TLS.State hiding (getNegotiatedProtocol)
|
import Network.TLS.State hiding (getNegotiatedProtocol)
|
||||||
|
import Network.TLS.Handshake.State
|
||||||
import Network.TLS.Receiving
|
import Network.TLS.Receiving
|
||||||
import Network.TLS.Measurement
|
import Network.TLS.Measurement
|
||||||
import Data.Maybe
|
import Data.Maybe
|
||||||
|
@ -192,7 +193,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello ver _ clientSession cip
|
||||||
else Just (pHashSignatures $ ctxParams ctx)
|
else Just (pHashSignatures $ ctxParams ctx)
|
||||||
creq = CertRequest certTypes hashSigs
|
creq = CertRequest certTypes hashSigs
|
||||||
(map extractCAname $ serverCACertificates sparams)
|
(map extractCAname $ serverCACertificates sparams)
|
||||||
usingState_ ctx $ setCertReqSent True
|
usingHState ctx $ setCertReqSent True
|
||||||
sendPacket ctx (Handshake [creq])
|
sendPacket ctx (Handshake [creq])
|
||||||
|
|
||||||
-- Send HelloDone
|
-- Send HelloDone
|
||||||
|
@ -225,7 +226,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
|
||||||
|
|
||||||
-- Remember cert chain for later use.
|
-- Remember cert chain for later use.
|
||||||
--
|
--
|
||||||
usingState_ ctx $ setClientCertChain certs
|
usingHState ctx $ setClientCertChain certs
|
||||||
|
|
||||||
-- FIXME: We should check whether the certificate
|
-- FIXME: We should check whether the certificate
|
||||||
-- matches our request and that we support
|
-- matches our request and that we support
|
||||||
|
@ -275,7 +276,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
|
||||||
-- When verification succeeds, commit the
|
-- When verification succeeds, commit the
|
||||||
-- client certificate chain to the context.
|
-- client certificate chain to the context.
|
||||||
--
|
--
|
||||||
Just certs <- usingState_ ctx $ getClientCertChain
|
Just certs <- usingHState ctx $ getClientCertChain
|
||||||
usingState_ ctx $ setClientCertificateChain certs
|
usingState_ ctx $ setClientCertificateChain certs
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
|
@ -292,13 +293,13 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
|
||||||
-- application callbacks accepts, we
|
-- application callbacks accepts, we
|
||||||
-- also commit the client certificate
|
-- also commit the client certificate
|
||||||
-- chain to the context.
|
-- chain to the context.
|
||||||
Just certs <- usingState_ ctx $ getClientCertChain
|
Just certs <- usingHState ctx $ getClientCertChain
|
||||||
usingState_ ctx $ setClientCertificateChain certs
|
usingState_ ctx $ setClientCertificateChain certs
|
||||||
else throwCore $ Error_Protocol ("verification failed", True, BadCertificate)
|
else throwCore $ Error_Protocol ("verification failed", True, BadCertificate)
|
||||||
return $ RecvStateNext expectChangeCipher
|
return $ RecvStateNext expectChangeCipher
|
||||||
|
|
||||||
processCertificateVerify p = do
|
processCertificateVerify p = do
|
||||||
chain <- usingState_ ctx $ getClientCertChain
|
chain <- usingHState ctx $ getClientCertChain
|
||||||
case chain of
|
case chain of
|
||||||
Just cc | isNullCertificateChain cc -> return ()
|
Just cc | isNullCertificateChain cc -> return ()
|
||||||
| otherwise -> throwCore $ Error_Protocol ("cert verify message missing", True, UnexpectedMessage)
|
| otherwise -> throwCore $ Error_Protocol ("cert verify message missing", True, UnexpectedMessage)
|
||||||
|
@ -317,7 +318,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC
|
||||||
expectFinish p = unexpected (show p) (Just "Handshake Finished")
|
expectFinish p = unexpected (show p) (Just "Handshake Finished")
|
||||||
|
|
||||||
checkValidClientCertChain msg = do
|
checkValidClientCertChain msg = do
|
||||||
chain <- usingState_ ctx $ getClientCertChain
|
chain <- usingHState ctx $ getClientCertChain
|
||||||
let throwerror = Error_Protocol (msg , True, UnexpectedMessage)
|
let throwerror = Error_Protocol (msg , True, UnexpectedMessage)
|
||||||
case chain of
|
case chain of
|
||||||
Nothing -> throwCore throwerror
|
Nothing -> throwCore throwerror
|
||||||
|
|
|
@ -14,11 +14,20 @@ module Network.TLS.Handshake.State
|
||||||
, HandshakeM
|
, HandshakeM
|
||||||
, newEmptyHandshake
|
, newEmptyHandshake
|
||||||
, runHandshake
|
, runHandshake
|
||||||
-- * accessors
|
-- * key accessors
|
||||||
, setPublicKey
|
, setPublicKey
|
||||||
, setPrivateKey
|
, setPrivateKey
|
||||||
, setClientPublicKey
|
, setClientPublicKey
|
||||||
, setClientPrivateKey
|
, setClientPrivateKey
|
||||||
|
-- * cert accessors
|
||||||
|
, setClientCertSent
|
||||||
|
, getClientCertSent
|
||||||
|
, setCertReqSent
|
||||||
|
, getCertReqSent
|
||||||
|
, setClientCertChain
|
||||||
|
, getClientCertChain
|
||||||
|
, setClientCertRequest
|
||||||
|
, getClientCertRequest
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Network.TLS.Util
|
import Network.TLS.Util
|
||||||
|
@ -29,7 +38,6 @@ import qualified Data.ByteString as B
|
||||||
import Control.Applicative ((<$>))
|
import Control.Applicative ((<$>))
|
||||||
import Control.Monad
|
import Control.Monad
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Control.Monad.Error
|
|
||||||
import Data.X509 (CertificateChain)
|
import Data.X509 (CertificateChain)
|
||||||
|
|
||||||
data HandshakeState = HandshakeState
|
data HandshakeState = HandshakeState
|
||||||
|
@ -97,3 +105,28 @@ setClientPublicKey pk = modify (\hst -> hst { hstRSAClientPublicKey = Just pk })
|
||||||
|
|
||||||
setClientPrivateKey :: PrivKey -> HandshakeM ()
|
setClientPrivateKey :: PrivKey -> HandshakeM ()
|
||||||
setClientPrivateKey pk = modify (\hst -> hst { hstRSAClientPrivateKey = Just pk })
|
setClientPrivateKey pk = modify (\hst -> hst { hstRSAClientPrivateKey = Just pk })
|
||||||
|
|
||||||
|
setCertReqSent :: Bool -> HandshakeM ()
|
||||||
|
setCertReqSent b = modify (\hst -> hst { hstCertReqSent = b })
|
||||||
|
|
||||||
|
getCertReqSent :: HandshakeM Bool
|
||||||
|
getCertReqSent = gets hstCertReqSent
|
||||||
|
|
||||||
|
setClientCertSent :: Bool -> HandshakeM ()
|
||||||
|
setClientCertSent b = modify (\hst -> hst { hstClientCertSent = b })
|
||||||
|
|
||||||
|
getClientCertSent :: HandshakeM Bool
|
||||||
|
getClientCertSent = gets hstClientCertSent
|
||||||
|
|
||||||
|
setClientCertChain :: CertificateChain -> HandshakeM ()
|
||||||
|
setClientCertChain b = modify (\hst -> hst { hstClientCertChain = Just b })
|
||||||
|
|
||||||
|
getClientCertChain :: HandshakeM (Maybe CertificateChain)
|
||||||
|
getClientCertChain = gets hstClientCertChain
|
||||||
|
|
||||||
|
setClientCertRequest :: ClientCertRequestData -> HandshakeM ()
|
||||||
|
setClientCertRequest d = modify (\hst -> hst { hstClientCertRequest = Just d })
|
||||||
|
|
||||||
|
getClientCertRequest :: HandshakeM (Maybe ClientCertRequestData)
|
||||||
|
getClientCertRequest = gets hstClientCertRequest
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ import Network.TLS.Struct
|
||||||
import Network.TLS.Record
|
import Network.TLS.Record
|
||||||
import Network.TLS.Packet
|
import Network.TLS.Packet
|
||||||
import Network.TLS.State
|
import Network.TLS.State
|
||||||
|
import Network.TLS.Handshake.State
|
||||||
import Network.TLS.Cipher
|
import Network.TLS.Cipher
|
||||||
import Network.TLS.Crypto
|
import Network.TLS.Crypto
|
||||||
import Network.TLS.Extension
|
import Network.TLS.Extension
|
||||||
|
|
|
@ -30,18 +30,6 @@ module Network.TLS.State
|
||||||
, setMasterSecret
|
, setMasterSecret
|
||||||
, setMasterSecretFromPre
|
, setMasterSecretFromPre
|
||||||
, getMasterSecret
|
, getMasterSecret
|
||||||
, setPublicKey
|
|
||||||
, setPrivateKey
|
|
||||||
, setClientPublicKey
|
|
||||||
, setClientPrivateKey
|
|
||||||
, setClientCertSent
|
|
||||||
, getClientCertSent
|
|
||||||
, setCertReqSent
|
|
||||||
, getCertReqSent
|
|
||||||
, setClientCertChain
|
|
||||||
, getClientCertChain
|
|
||||||
, setClientCertRequest
|
|
||||||
, getClientCertRequest
|
|
||||||
, setKeyBlock
|
, setKeyBlock
|
||||||
, setVersion
|
, setVersion
|
||||||
, getVersion
|
, getVersion
|
||||||
|
@ -241,30 +229,6 @@ setMasterSecretFromPre premasterSecret = do
|
||||||
getMasterSecret :: MonadState TLSState m => m (Maybe Bytes)
|
getMasterSecret :: MonadState TLSState m => m (Maybe Bytes)
|
||||||
getMasterSecret = gets (stHandshake >=> hstMasterSecret)
|
getMasterSecret = gets (stHandshake >=> hstMasterSecret)
|
||||||
|
|
||||||
setCertReqSent :: MonadState TLSState m => Bool -> m ()
|
|
||||||
setCertReqSent b = updateHandshake "client cert req sent" (\hst -> hst { hstCertReqSent = b })
|
|
||||||
|
|
||||||
getCertReqSent :: MonadState TLSState m => m (Maybe Bool)
|
|
||||||
getCertReqSent = gets (stHandshake >=> Just . hstCertReqSent)
|
|
||||||
|
|
||||||
setClientCertSent :: MonadState TLSState m => Bool -> m ()
|
|
||||||
setClientCertSent b = updateHandshake "client cert sent" (\hst -> hst { hstClientCertSent = b })
|
|
||||||
|
|
||||||
getClientCertSent :: MonadState TLSState m => m (Maybe Bool)
|
|
||||||
getClientCertSent = gets (stHandshake >=> Just . hstClientCertSent)
|
|
||||||
|
|
||||||
setClientCertChain :: MonadState TLSState m => CertificateChain -> m ()
|
|
||||||
setClientCertChain b = updateHandshake "client certificate chain" (\hst -> hst { hstClientCertChain = Just b })
|
|
||||||
|
|
||||||
getClientCertChain :: MonadState TLSState m => m (Maybe CertificateChain)
|
|
||||||
getClientCertChain = gets (stHandshake >=> hstClientCertChain)
|
|
||||||
|
|
||||||
setClientCertRequest :: MonadState TLSState m => ClientCertRequestData -> m ()
|
|
||||||
setClientCertRequest d = updateHandshake "client cert data" (\hst -> hst { hstClientCertRequest = Just d })
|
|
||||||
|
|
||||||
getClientCertRequest :: MonadState TLSState m => m (Maybe ClientCertRequestData)
|
|
||||||
getClientCertRequest = gets (stHandshake >=> hstClientCertRequest)
|
|
||||||
|
|
||||||
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
getSessionData :: MonadState TLSState m => m (Maybe SessionData)
|
||||||
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
getSessionData = get >>= \st -> return (stHandshake st >>= hstMasterSecret >>= wrapSessionData st)
|
||||||
where wrapSessionData st masterSecret = do
|
where wrapSessionData st masterSecret = do
|
||||||
|
|
Loading…
Reference in a new issue