add dynamic recv certificate hook and remove the static one.

This commit is contained in:
Vincent Hanquez 2014-01-26 06:37:17 +00:00
parent a880d4081e
commit 11575711bc
5 changed files with 17 additions and 5 deletions

View file

@ -45,6 +45,7 @@ module Network.TLS.Context
-- * Context hooks
, contextHookSetHandshakeRecv
, contextHookSetCertificateRecv
-- * Using context states
, throwCore
@ -69,6 +70,7 @@ import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.X509
import Data.Maybe (isJust)
import Crypto.Random
@ -207,3 +209,7 @@ contextNewOnSocket sock params st = contextNew sock params st
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f =
liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvHandshake = f })
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context f =
liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvCertificates = f })

View file

@ -278,6 +278,9 @@ onServerHello _ _ _ p = unexpected (show p) (Just "server hello")
processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO)
processCertificate cparams ctx (Certificates certs) = do
-- run certificate recv hook
ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks $ certs)
-- then run certificate validation
usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException
case usage of
CertificateUsageAccept -> return ()

View file

@ -266,6 +266,8 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes
recvClientData :: ServerParams -> Context -> IO ()
recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientCertificate)
where processClientCertificate (Certificates certs) = do
-- run certificate recv hook
ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks $ certs)
-- Call application callback to see whether the
-- certificate chain is acceptable.
--

View file

@ -13,6 +13,7 @@ module Network.TLS.Hooks
import qualified Data.ByteString as B
import Network.TLS.Struct (Header, Handshake(..))
import Network.TLS.X509 (CertificateChain)
import Data.Default.Class
-- | Hooks for logging
@ -36,12 +37,14 @@ instance Default Logging where
-- | A collection of hooks actions.
data Hooks = Hooks
{ hookRecvHandshake :: Handshake -> IO Handshake
{ hookRecvHandshake :: Handshake -> IO Handshake
, hookRecvCertificates :: CertificateChain -> IO ()
}
defaultHooks :: Hooks
defaultHooks = Hooks
{ hookRecvHandshake = \hs -> return hs
, hookRecvCertificates = return . const ()
}
instance Default Hooks where

View file

@ -247,8 +247,7 @@ instance Default ServerHooks where
def = defaultServerHooks
data CommonHooks = CommonHooks
{ onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain.
, onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake
{ onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake
, logging :: Logging -- ^ callback for logging
}
@ -257,7 +256,6 @@ instance Show CommonHooks where
instance Default CommonHooks where
def = CommonHooks
{ onCertificatesRecv = \_ -> return CertificateUsageAccept
, logging = def
{ logging = def
, onHandshake = \_ -> return True
}