From 11575711bcdfbcd32f1b2d9d4de9f32bea4582f8 Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Sun, 26 Jan 2014 06:37:17 +0000 Subject: [PATCH] add dynamic recv certificate hook and remove the static one. --- core/Network/TLS/Context.hs | 6 ++++++ core/Network/TLS/Handshake/Client.hs | 3 +++ core/Network/TLS/Handshake/Server.hs | 2 ++ core/Network/TLS/Hooks.hs | 5 ++++- core/Network/TLS/Parameters.hs | 6 ++---- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/core/Network/TLS/Context.hs b/core/Network/TLS/Context.hs index 41971b7..299371f 100644 --- a/core/Network/TLS/Context.hs +++ b/core/Network/TLS/Context.hs @@ -45,6 +45,7 @@ module Network.TLS.Context -- * Context hooks , contextHookSetHandshakeRecv + , contextHookSetCertificateRecv -- * Using context states , throwCore @@ -69,6 +70,7 @@ import Network.TLS.Parameters import Network.TLS.Measurement import Network.TLS.Types (Role(..)) import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith) +import Network.TLS.X509 import Data.Maybe (isJust) import Crypto.Random @@ -207,3 +209,7 @@ contextNewOnSocket sock params st = contextNew sock params st contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO () contextHookSetHandshakeRecv context f = liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvHandshake = f }) + +contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO () +contextHookSetCertificateRecv context f = + liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvCertificates = f }) diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index 59e32d9..0c36170 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -278,6 +278,9 @@ onServerHello _ _ _ p = unexpected (show p) (Just "server hello") processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO) processCertificate cparams ctx (Certificates certs) = do + -- run certificate recv hook + ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks $ certs) + -- then run certificate validation usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException case usage of CertificateUsageAccept -> return () diff --git a/core/Network/TLS/Handshake/Server.hs b/core/Network/TLS/Handshake/Server.hs index b66279b..4fadf79 100644 --- a/core/Network/TLS/Handshake/Server.hs +++ b/core/Network/TLS/Handshake/Server.hs @@ -266,6 +266,8 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes recvClientData :: ServerParams -> Context -> IO () recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientCertificate) where processClientCertificate (Certificates certs) = do + -- run certificate recv hook + ctxWithHooks ctx (\hooks -> hookRecvCertificates hooks $ certs) -- Call application callback to see whether the -- certificate chain is acceptable. -- diff --git a/core/Network/TLS/Hooks.hs b/core/Network/TLS/Hooks.hs index 4413967..7d98d7e 100644 --- a/core/Network/TLS/Hooks.hs +++ b/core/Network/TLS/Hooks.hs @@ -13,6 +13,7 @@ module Network.TLS.Hooks import qualified Data.ByteString as B import Network.TLS.Struct (Header, Handshake(..)) +import Network.TLS.X509 (CertificateChain) import Data.Default.Class -- | Hooks for logging @@ -36,12 +37,14 @@ instance Default Logging where -- | A collection of hooks actions. data Hooks = Hooks - { hookRecvHandshake :: Handshake -> IO Handshake + { hookRecvHandshake :: Handshake -> IO Handshake + , hookRecvCertificates :: CertificateChain -> IO () } defaultHooks :: Hooks defaultHooks = Hooks { hookRecvHandshake = \hs -> return hs + , hookRecvCertificates = return . const () } instance Default Hooks where diff --git a/core/Network/TLS/Parameters.hs b/core/Network/TLS/Parameters.hs index c1544ee..0823136 100644 --- a/core/Network/TLS/Parameters.hs +++ b/core/Network/TLS/Parameters.hs @@ -247,8 +247,7 @@ instance Default ServerHooks where def = defaultServerHooks data CommonHooks = CommonHooks - { onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain. - , onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake + { onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake , logging :: Logging -- ^ callback for logging } @@ -257,7 +256,6 @@ instance Show CommonHooks where instance Default CommonHooks where def = CommonHooks - { onCertificatesRecv = \_ -> return CertificateUsageAccept - , logging = def + { logging = def , onHandshake = \_ -> return True }